【ml5.js】Regression 回归 & 保存/加载模型

557 阅读4分钟

blog.csdn.net/fribbler/ar…

回归是想得到一个线性的答案,比如预测房价,识别物体在图片中的位置,回归的答案是一个连续的值。而分类是想得到一个离散的值,比如我们想分辨一张图片中的水果是苹果还是梨,虽然分类结果会给一个0-1的confidence值,但是我们想要的结果就是知道它是苹果还是梨;

观看本教程的视频:www.bilibili.com/video/BV1az…

一、开头一段代码

还是打开ml5.js的在线编程网页:chn.ai/ml5.html,录入下面代码,点击运行。第一次运行的时候浏览器会请求您电脑摄像头的权限,点“允许”。

<!DOCTYPE html>
<html lang="en">
<head>
<title>Regression</title>
<script src='js/ml5.min.js'></script>
<script src="https://cdn.bootcdn.net/ajax/libs/p5.js/1.1.9/p5.min.js"></script>
</head>
<body>
<script>

let video, canvas, logger;
let size = {w: 500, h: 400};
let featureExtractor, regressor;
let largeCt = 0, mediumCt = 0, smallCt = 0;
let img;
let value;

function preload() {
    img = loadImage('img/docStrange.png');
}

function setup() {
    video = createCapture(VIDEO);
    video.size(size.w, size.h);
    video.style('transform: rotateY(180deg)');
    
    canvas = createCanvas(size.w, size.h, WEBGL);
    canvas.style('position: absolute; z-index: 10;');
    
    createDiv();
    createButton('add large').mouseClicked(addLarge);
    createButton('add medium').mouseClicked(addMedium);
    createButton('add small').mouseClicked(addSmall);
    
    createDiv();
    createButton('train').mouseClicked(train);
    createButton('save').mouseClicked(saveModel);
    createButton('load').mouseClicked(function() {
        logging('loading custom model');
        regressor.load(files.elt.files, function() {
            logging('custom model loaded');
        });
    });
    files = createFileInput();
    files.attribute('multiple', true);
    
    createDiv();
    createButton('go').mouseClicked(go);
    
    logger = createP();

    initExtractor();
}

function initExtractor() {
    featureExtractor = ml5.featureExtractor('MobileNet', function() {
        logging('model loaded');
        regressor = featureExtractor.regression(video, function() {
            logging('regressor inited');
        })
    })
}

function saveModel() {
    regressor.save();
}

/*function loadModel() {
}*/

function addLarge() {
    regressor.addImage(1, function() {
        logging('large added ' + ++largeCt);
    })
}
function addMedium() {
    regressor.addImage(0.5, function() {
        logging('medium added ' + ++mediumCt);
    })
}
function addSmall() {
    regressor.addImage(0, function() {
        logging('small added ' + ++smallCt);
    })
}

function train() {
    regressor.train(function(loss) {
        logging(loss);
    })
}

function go() {
    setInterval(function() {
        regressor.predict(function(err, result) {
            if(err) {
                logging(err);
                return;
            }
            
            // else
            value = result.value;
            logging(value);
        })
    }, 100);
}

function draw() {
    // image(img, 0, 0, 50, 50);
    if(!value) return;
    
    let r = value * 500 / 2;

    //background('rgba(0,0,0, 0)');
    rotateZ(frameCount * 0.01);
    image(img, -r, -r, r * 2, r * 2);
}

function logging(c) {
    logger.html(c + '<br/>' + logger.html());
}
</script>

</body>
</html>

和上节课不同的是我们今天要引入一个p5.js的库,这个库能方便的创建页面上的元素,并能用简单的代码做出一些页面特效。只需要在页面开头写上

<script src="https://cdn.bootcdn.net/ajax/libs/p5.js/1.1.9/p5.min.js">
就可以引入了。

我们还是先来创建页面元素,今天我们还是需要用到摄像头。和上次不同,我们现在准备用p5的createCapture()方法了创建video元素并将摄像头捕获的内容投放到这个元素里面。如果我们正对摄像头,屏幕上显示的视频和我们照镜子的左右是相反的,用起来相当别扭。那我们用css video.style('transform: rotateY(180deg)');把视频水平180度翻转一下看起来就对了。

canvas = createCanvas(size.w, size.h, WEBGL);用来创建一个canvas画布,通过调整css将画布覆盖到video上面,我们要在上面绘制奇异博士的魔法圈。

createDiv, createButton, createFileInput, createP这些p5的方法用来创建页面的基本元素。

p5有一些特定名字的函数,只要我们创建了这些名字的函数,p5就会在不同的阶段调用它们:

  • setup(): 在页面初始化完成后就被p5调用,我们只需要将创建UI的代码写到这个函数里面就可以了。

  • preload(): 页面加载之前被p5调用,在这里我们把图片加载进来

  • draw(): 如果我们实现了这个函数,p5就会定期调用这个函数,我们在这个函数里面实现需要绘制的内容。我们要在画布上显示一张图片;

二、原理

回归是想得到一个线性的答案,比如预测房价,识别物体在图片中的位置,回归的答案是一个连续的值。而分类是想得到一个离散的值,比如我们想分辨图片中水果是苹果还是梨,虽然分类结果会给一个0-1的confidence值,但是我们想要的结果就是知道它是苹果还是梨。

在初始化模型的时候,我们用了featureExtractor.regression();记得在分类的时候,我们用的是featureExtractor.classification()

function initExtractor() {
    featureExtractor = ml5.featureExtractor('MobileNet', function() {
        logging('model loaded');
        regressor = featureExtractor.regression(video, function() {
            logging('regressor inited');
        })
    })
}

所以在训练regression的模型的时候,我们输入到网络里面的数据是图片 + 一个数字,这个数字代表这个图片状态的程度。所以我们添加训练数据的时候,调用addImage()方法,第一个参数是一个数字,这是一个0-1之间的值,用来告诉模型当前图像当前的状态,比如我们手掌全开的时候,就是addLarge函数里面,这个值是1;addMedium(),手掌半开的时候这个值是0.5;addSmall()手掌闭合的时候这个值是0;在预测的时候得到的结果就是根据刚才训练的标准来预测的值,总体也是0-1范围内的,手掌越打开越接近1,越闭合越接近0。

训练过程和之前一样,调用regressor.train()即可。我们通过regressor.predict()来预测结果,得到的结果也是一个 0 - 1之间的一个数字,表明模型预测到的状态程度。

三、保存/加载模型

可能之前大家都会问,我们网页上的网络模型,在每次刷新页面的时候都会全部清空,那是不是每次都需要重新训练呢?当然,ml5提供了save和load方法,用来保存和加载训练好的模型,这样即使页面被刷新,我们也有办法把原来训练好的模型加载到内存中来。

regressor.save();

这个save方法执行后页面会提示下载多个文件,如果浏览器询问权限,回答‘同意’下载多个文件。下载的文件包括model.json描述网络的结构;model.weights.bin存放的是网络里面神经节点的权重值。所以我们加载模型的时候也需要指定这两个文件。

files = createFileInput();
files.attribute('multiple', true);

再来看加载的方法:

regressor.load(files.elt.files, function() {
  logging('custom model loaded');
});

我们这里将files.elt是从p5的files变量里面拿到对应的DOM对象,然后再拿到DOM对象里面的files对象,这样刚才save的model和数据并不需要上传到服务器上,而是浏览器直接从文件系统里面读取出来。

好了,这节课我们学习了用ml5.js的 Regression 回归还介绍了怎么保存/加载模型。