机器学习数字识别--从零到应用

204 阅读10分钟

机器学习数字识别--从零到应用

机器学习、计算机视觉、构建强大的API以及创建漂亮的UI都是令人兴奋的领域,见证了大量的创新。

前两者需要大量的数学和科学,而API和UI开发则以算法思维和设计灵活的架构为中心。它们是非常不同的,所以决定你下一步想学哪一个可能是有挑战性的。本文的目的是演示如何在创建一个图像处理应用程序时采用这四种方法。

我们要建立的应用程序是一个简单的数字识别器。你画画,机器就能预测出数字。简单性是至关重要的,因为它使我们能够看到大局而不是关注细节。

为了简单起见,我们将使用最流行和最容易学习的技术。机器学习部分将使用Python进行后端应用。至于应用程序的交互方面,我们将通过一个不需要介绍的JavaScript库来操作。React

机器学习来猜测数字

我们应用程序的核心部分是猜测所抽号码的算法。机器学习将是用于实现良好猜测质量的工具。这种基本的人工智能允许一个系统在给定的数据量下自动学习。从广义上讲,机器学习是一个在数据中寻找巧合或一组巧合的过程,依靠它们来猜测结果。

我们的图像识别过程包含三个步骤:

  • 获取绘制的数字图像进行训练
  • 训练系统通过训练数据来猜测数字
  • 用新的/未知的数据测试该系统

环境

我们需要一个虚拟环境来处理Python中的机器学习。这种方法很实用,因为它管理着所有需要的Python包,所以你不需要担心它们。

让我们用下面的终端命令来安装它。

python3 -m venv virtualenv
source virtualenv/bin/activate

训练模型

在我们开始写代码之前,我们需要为我们的机器选择一个合适的 "老师"。通常情况下,数据科学专家在选择最佳模型之前会尝试不同的模型。我们将跳过需要大量技巧的非常高级的模型,继续使用k-nearest neighbors算法

这是一种算法,它得到一些数据样本,并将它们安排在一个平面上,按照一组给定的特征排序。为了更好地理解它,让我们回顾一下下面的图片。

为了检测绿点的类型,我们应该检查k个最近的邻居的类型,其中k是参数集。考虑到上面的图片,如果k等于1、2、3或4,那么猜测将是一个黑三角,因为绿点的大多数最近的k个邻居是黑三角。如果我们将k增加到5,那么大多数物体都是蓝色方块,因此猜测将是一个蓝色方块

在创建我们的机器学习模型时,需要一些依赖性:

  • sklearn.neighbors.KNeighborsClassifier是我们要使用的分类器。
  • sklearn.model_selection.train_test_split是一个函数,它将帮助我们把数据分成训练数据和用于检查模型正确性的数据。
  • sklearn.model_selection.cross_val_score是用来获得模型正确性标记的函数。该值越高,正确性就越好。
  • sklearn.metrics.classification_report是用于显示模型猜测的统计报告的函数。
  • sklearn.datasets是用于获取训练数据的包(数字的图像)。
  • numpy是一个在科学界广泛使用的包,因为它为在Python中操作多维数据结构提供了一种高效和舒适的方式。
  • matplotlib.pyplot是用于可视化数据的包。

让我们首先安装并导入所有这些程序。

pip install sklearn numpy matplotlib scipy

from sklearn.datasets import load_digits
from sklearn.neighbors import KNeighborsClassifier
from sklearn.model_selection import train_test_split, cross_val_score
import numpy as np
import matplotlib.pyplot as plt 

现在,我们需要加载MNIST数据库。MNIST是一个经典的手写图像数据集,被机器学习领域的成千上万的新手使用。

digits = load_digits()

一旦数据被取来并准备好,我们就可以进入下一步,将数据分成两部分:训练测试

我们将使用75%的数据来训练我们的模型来猜测数字,我们将使用其余的数据来测试模型的正确性。

(X_train, X_test, y_train, y_test) = train_test_split(
    digits.data, digits.target, test_size=0.25, random_state=42
)

现在数据已经安排好了,我们准备使用它。我们将尝试为我们的模型找到最佳参数k,这样猜测就会更精确。在这个阶段,我们不能不考虑k值,因为我们必须用不同的k值来评估模型。

让我们看看为什么必须考虑一系列的k值,以及这如何提高我们模型的准确性。

ks = np.arange(2, 10)
scores = []
for k in ks:
    model = KNeighborsClassifier(n_neighbors=k)
    score = cross_val_score(model, X_train, y_train, cv=5)
    score.mean()
    scores.append(score.mean())

plt.plot(scores, ks)
plt.xlabel('accuracy')
plt.ylabel('k')
plt.show()

执行这段代码,你会看到以下描述算法在不同k值下的准确性的图。

正如你所看到的,k值为3,可以确保我们的模型和数据集的最佳准确性。

使用Flask来构建一个API

应用程序的核心,也就是预测图像中数字的算法,现在已经准备好了。接下来,我们需要用一个API层来装饰这个算法,使其可以被使用。让我们使用流行的Flask网络框架来简洁地完成这项工作。

我们将首先在虚拟环境中安装Flask和与图像处理有关的依赖项。

pip install Flask Pillow scikit-image

当安装完成后,我们转到创建应用程序的入口点文件。

touch app.py

该文件的内容将看起来像这样

import os

from flask import Flask
from views import PredictDigitView, IndexView

app = Flask(__name__)

app.add_url_rule(
    '/api/predict',
    view_func=PredictDigitView.as_view('predict_digit'),
    methods=['POST']
)

app.add_url_rule(
    '/',
    view_func=IndexView.as_view('index'),
    methods=['GET']
)

if __name__ == 'main':
    port = int(os.environ.get("PORT", 5000))
    app.run(host='0.0.0.0', port=port)

你会得到一个错误,说PredictDigitViewIndexView 没有定义。下一步是创建一个文件来初始化这些视图。

from flask import render_template, request, Response
from flask.views import MethodView, View

from flask.views import View

from repo import ClassifierRepo
from services import PredictDigitService
from settings import CLASSIFIER_STORAGE

class IndexView(View):
    def dispatch_request(self):
        return render_template('index.html')

class PredictDigitView(MethodView):
    def post(self):
        repo = ClassifierRepo(CLASSIFIER_STORAGE)
        service = PredictDigitService(repo)
        image_data_uri = request.json['image']
        prediction = service.handle(image_data_uri)
        return Response(str(prediction).encode(), status=200)

再一次,我们会遇到一个关于未解决的导入的错误。视图包依赖于我们还没有的三个文件:

  • 设置
  • Repo
  • 服务

我们将逐一实现它们。

Settings是一个带有配置和常量变量的模块。它将为我们存储通往序列化分类器的路径。这引出了一个合乎逻辑的问题。为什么我需要保存分类器?

因为这是一个提高你的应用程序性能的简单方法。我们将存储分类器的准备版本,而不是在每次收到请求时训练分类器,使其能够开箱即用。

import os

BASE_DIR = os.getcwd()
CLASSIFIER_STORAGE = os.path.join(BASE_DIR, 'storage/classifier.txt')

设置的机制--获取分类器--将在我们列表中的下一个包中初始化,即Repo。这是一个带有两个方法的类,使用Python内置的pickle 模块检索和更新训练有素的分类器。

import pickle

class ClassifierRepo:
    def __init__(self, storage):
        self.storage = storage

    def get(self):
        with open(self.storage, 'wb') as out:
            try:
                classifier_str = out.read()
                if classifier_str != '':
                    return pickle.loads(classifier_str)
                else:
                    return None
            except Exception:
                return None

    def update(self, classifier):
        with open(self.storage, 'wb') as in_:
            pickle.dump(classifier, in_)

我们已经接近完成我们的API了。现在它只缺少服务模块。它的目的是什么?

  • 从存储器中获取训练好的分类器
  • 把从用户界面传来的图像转换成分类器能理解的格式
  • 通过分类器用格式化的图像计算预测值
  • 返回预测结果

让我们对这个算法进行编码。

from sklearn.datasets import load_digits

from classifier import ClassifierFactory
from image_processing import process_image

class PredictDigitService:
    def __init__(self, repo):
        self.repo = repo

    def handle(self, image_data_uri):
        classifier = self.repo.get()
        if classifier is None:
            digits = load_digits()
            classifier = ClassifierFactory.create_with_fit(
                digits.data,
                digits.target
            )
            self.repo.update(classifier)
        
        x = process_image(image_data_uri)
        if x is None:
            return 0

        prediction = classifier.predict(x)[0]
        return prediction

在这里你可以看到,PredictDigitService 有两个依赖关系:ClassifierFactoryprocess_image

我们先创建一个类来创建和训练我们的模型。

from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier

class ClassifierFactory:
    @staticmethod
    def create_with_fit(data, target):
        model = KNeighborsClassifier(n_neighbors=3)
        model.fit(data, target)
        return model

API已经准备好了,可以行动了。现在我们可以进行图像处理的步骤了。

图像处理

图像处理是一种对图像进行某些操作的方法,以增强它或从中提取一些有用的信息。在我们的案例中,我们需要将用户绘制的图像平稳地过渡到机器学习模型的格式。

让我们导入一些辅助工具来实现这一目标。

import numpy as np
from skimage import exposure
import base64
from PIL import Image, ImageOps, ImageChops
from io import BytesIO

我们可以把过渡过程分成六个不同的部分。

1.用一种颜色替换透明背景

def replace_transparent_background(image):
    image_arr = np.array(image)

    if len(image_arr.shape) == 2:
        return image

    alpha1 = 0
    r2, g2, b2, alpha2 = 255, 255, 255, 255

    red, green, blue, alpha = image_arr[:, :, 0], image_arr[:, :, 1], image_arr[:, :, 2], image_arr[:, :, 3]
    mask = (alpha == alpha1)
    image_arr[:, :, :4][mask] = [r2, g2, b2, alpha2]

    return Image.fromarray(image_arr)

2.修剪开放的边框

def trim_borders(image):
    bg = Image.new(image.mode, image.size, image.getpixel((0,0)))
    diff = ImageChops.difference(image, bg)
    diff = ImageChops.add(diff, diff, 2.0, -100)
    bbox = diff.getbbox()
    if bbox:
        return image.crop(bbox)
    
    return image

3.增加同等大小的边框

def pad_image(image):
    return ImageOps.expand(image, border=30, fill='#fff')

4.将图像转换为灰度模式

def to_grayscale(image):
    return image.convert('L')

5.反转颜色

def invert_colors(image):
    return ImageOps.invert(image)

6.将图像的大小调整为8x8格式

def resize_image(image):
    return image.resize((8, 8), Image.LINEAR)

现在你可以测试该应用程序。运行应用程序,并输入下面的命令,向API发送一个带有这个iStock图片的请求。

export FLASK_APP=app
flask run
curl "http://localhost:5000/api/predict" -X "POST" -H "Content-Type: application/json" -d "{\"image\": \"data:image/png;base64,$(curl "https://media.istockphoto.com/vectors/number-eight-8-hand-drawn-with-dry-brush-vector-id484207302?k=6&m=484207302&s=170667a&w=0&h=s3YANDyuLS8u2so-uJbMA2uW6fYyyRkabc1a6OTq7iI=" | base64)\"}" -i

你应该看到以下输出:

HTTP/1.1 100 Continue

HTTP/1.0 200 OK
Content-Type: text/html; charset=utf-8
Content-Length: 1
Server: Werkzeug/0.14.1 Python/3.6.3
Date: Tue, 27 Mar 2018 07:02:08 GMT

8

样本图片描述的是数字8,我们的应用程序可以正确识别它。

通过React创建一个绘图窗格

为了快速启动前端应用程序,我们将使用CRA的模板

create-react-app frontend
cd frontend

在设置好工作场所后,我们还需要一个依赖关系来绘制数字。react-sketch包与我们的需求完美匹配。

npm i react-sketch

该应用程序只有一个组件。我们可以把这个组件分为两部分:逻辑和视图

视图部分负责表示绘图窗格、提交重置按钮。当进行交互时,我们还应该表示预测或错误。从逻辑角度看,它有以下职责:提交图像清除草图

每当用户点击提交时,该组件将从草图组件中提取图像,并向API模块的makePrediction 功能提出上诉。如果对后端的请求成功了,我们将设置预测状态变量。否则,我们将更新错误状态。

当用户点击重置时,草图将被清除。

import React, { useRef, useState } from "react";

import { makePrediction } from "./api";

const App = () => {
  const sketchRef = useRef(null);
  const [error, setError] = useState();
  const [prediction, setPrediction] = useState();

  const handleSubmit = () => {
    const image = sketchRef.current.toDataURL();

    setPrediction(undefined);
    setError(undefined);

    makePrediction(image).then(setPrediction).catch(setError);
  };

  const handleClear = (e) => sketchRef.current.clear();

  return null
}

这个逻辑已经足够了。现在我们可以给它添加视觉界面了。

import React, { useRef, useState } from "react";
import { SketchField, Tools } from "react-sketch";

import { makePrediction } from "./api";

import logo from "./logo.svg";
import "./App.css";

const pixels = (count) => `${count}px`;
const percents = (count) => `${count}%`;

const MAIN_CONTAINER_WIDTH_PX = 200;
const MAIN_CONTAINER_HEIGHT = 100;
const MAIN_CONTAINER_STYLE = {
  width: pixels(MAIN_CONTAINER_WIDTH_PX),
  height: percents(MAIN_CONTAINER_HEIGHT),
  margin: "0 auto",
};

const SKETCH_CONTAINER_STYLE = {
  border: "1px solid black",
  width: pixels(MAIN_CONTAINER_WIDTH_PX - 2),
  height: pixels(MAIN_CONTAINER_WIDTH_PX - 2),
  backgroundColor: "white",
};

const App = () => {
  const sketchRef = useRef(null);
  const [error, setError] = useState();
  const [prediction, setPrediction] = useState();

  const handleSubmit = () => {
    const image = sketchRef.current.toDataURL();

    setPrediction(undefined);
    setError(undefined);

    makePrediction(image).then(setPrediction).catch(setError);
  };

  const handleClear = (e) => sketchRef.current.clear();

  return (
    <div className="App" style={MAIN_CONTAINER_STYLE}>
      <div>
        <header className="App-header">
          <img src={logo} className="App-logo" alt="logo" />
          <h1 className="App-title">Draw a digit</h1>
        </header>
        <div style={SKETCH_CONTAINER_STYLE}>
          <SketchField
            ref={sketchRef}
            width="100%"
            height="100%"
            tool={Tools.Pencil}
            imageFormat="jpg"
            lineColor="#111"
            lineWidth={10}
          />
        </div>
        {prediction && <h3>Predicted value is: {prediction}</h3>}
        <button onClick={handleClear}>Clear</button>
        <button onClick={handleSubmit}>Guess the number</button>
        {error && <p style={{ color: "red" }}>Something went wrong</p>}
      </div>
    </div>
  );
};

export default App;

该组件已经准备好了,通过执行并在之后进入localhost:3000 来测试它。

npm run start

演示程序可以在这里找到。你也可以在GitHub上浏览源代码。

总结

这个分类器的质量并不完美,我也不假装它是完美的。我们用于训练的数据和来自用户界面的数据之间的差异是巨大的。尽管如此,我们还是在不到30分钟的时间里从头开始创建了一个工作的应用程序。

在这个过程中,我们磨练了我们在四个领域的技能。

  • 机器学习
  • 后端开发
  • 图像处理
  • 前端开发

能够识别手写数字的软件不乏潜在的使用案例,从教育和行政软件到邮政和金融服务。

因此,我希望这篇文章能够激励你提高你的机器学习能力、图像处理能力以及前端和后端开发能力,并利用这些技能来设计出精彩而有用的应用程序。

了解基础知识

什么是机器学习中的MNIST?

MNIST是计算机视觉中最受欢迎的入门级数据集之一。它包含了数以千计的手写数字的图像。

什么是机器学习中的训练?

机器学习模型从数据中学习。为了使模型足够聪明,我们需要提供我们拥有的预期结果的数据。该模型将使用这些数据来检测数据参数和预期结果之间的关系。

什么是机器学习中的图像处理?

图像处理是一种在图像上进行一些操作的方法,以获得一个增强的图像或提取有用的信息。