用TensorFlowLite写个手写体识别APP

1,183 阅读5分钟

今天有个网友在手把手教你在Android上搭建tensorflow Lite2.0这篇文章下评论

求问如何进行一个图像的输入和数组的输出?

我想这也是很多初学者的痛点,很多入门同学都没有完整从模型建立,训练,到转换成TensorFlowLite,并在Android中实际的用。

于是我就把我之前写的demo给了他,想想还是抽空把这个demo写成文章,希望能够给帮助到更多的入门的同学。

虽然基于TensorFlow 实现手写体的文章,一抓一大把,但是我还是有必要啰嗦下,毕竟它是很好的入门人工智能的实例。

我不关注的手写体识别算法的细节,关注整个从模型到应用的整个过程,想对算法了解的,请自行学习。

有兴趣的同学可以关注下我的系列博客人工智能系列(更新中……),自己也在学习这方面的知识,一起学习和交流。

1 手写体基础知识

1.1 探索MINIST数据集

采用的MNIST数据集,它来自美国国家标准与技术研究所,National Institute of Standards and Technology(NIST)。 训练集 (training set) 由来自 250 个不同人手写的数字构成,其中 50% 是高中学生,50% 来自人口普查局 (the Census Bureau) 的工作人员. 测试集(test set)也是同样比例的手写数字数据。

数据集中每张图片是什么样的呢?

就张这样子: 在这里插入图片描述 通过下面代码获得:

# Plot ad hoc mnist instances
from tensorflow.keras.datasets import mnist
import matplotlib.pyplot as plt

# load (downloaded if needed) the MNIST dataset
(X_train, y_train), (X_test, y_test) = mnist.load_data()
# plot 4 images as gray scale
plt.subplot(221)
plt.imshow(X_train[0], cmap=plt.get_cmap("gray"))
plt.subplot(222)
plt.imshow(X_train[1], cmap=plt.get_cmap("gray"))
plt.subplot(223)
plt.imshow(X_train[2], cmap=plt.get_cmap("gray"))
plt.subplot(224)
plt.imshow(X_train[3], cmap=plt.get_cmap("gray"))
# show the plot
plt.show()

但是实际上存储是什么呢? 在这里插入图片描述 你可以发现这是一个0字,存储是0这张图片的RGB的值,凡是值为零的地方都是黑色,非零的地方都是不同灰阶。这就是一张图片灰阶RGB矩阵。

1.2 CNN基本介绍

本次采用手写体识别算法就是CNN(卷积神经网络),在计算机视觉中应用比较广泛。 最为经典的CNN手写体识别图,描述了手写体识别的整个过程,具体的细节就不讲了,有机会写一篇这个算法细节的文章,但是本文神经网络模型结构如下: CNN

1.3 基于TensorFlow 的手写体识别

采用TensorFlow 中Keras接口,比较适合新手使用。让你感觉创建神经网络模型就像是搭积木一样。

代码如下,留意注释。

import numpy
from tensorflow.keras.datasets import mnist
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense
from tensorflow.keras.layers import Dropout
from tensorflow.keras.layers import Flatten
from tensorflow.keras.layers import Conv2D
from tensorflow.keras.layers import MaxPooling2D
from tensorflow.python.keras.utils import np_utils
import tensorflow as tf
import pathlib

# fix random seed for reproducibility
seed = 7
numpy.random.seed(seed)
# load data
(X_train, y_train), (X_test, y_test) = mnist.load_data()
# reshape to be [samples][channels][width][height]
X_train = X_train.reshape(X_train.shape[0], 28, 28, 1).astype('float32')
X_test = X_test.reshape(X_test.shape[0], 28, 28, 1).astype('float32')

# normalize inputs from 0-255 to 0-1
X_train = X_train / 255
X_test = X_test / 255

print(X_train.shape)
# one hot encode outputs
y_train = np_utils.to_categorical(y_train)
y_test = np_utils.to_categorical(y_test)
print(X_train[0])

num_classes = y_test.shape[1]

def baseline_model():
    # create model
    model = Sequential()
    model.add(Conv2D(32, kernel_size=(5, 5),
                     input_shape=(28, 28, 1),//采用单通道的图片
                     activation='relu'))
    model.add(MaxPooling2D(pool_size=(2, 2)))
    model.add(Dropout(0.2))
    model.add(Flatten())
    model.add(Dense(128, activation='relu'))
    model.add(Dense(num_classes, activation='softmax'))
    # Compile model
    model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True), optimizer='adam',
                  metrics=['accuracy'])
    return model


model = baseline_model()
# Fit the model
model.fit(X_train, y_train, validation_data=(X_test, y_test), epochs=10, batch_size=200, verbose=2)

# Final evaluation of the model
scores = model.evaluate(X_test, y_test, verbose=0)
print("CNN Error: %.2f%%" % (100 - scores[1] * 100))

# 上面升级网络训练的过程
# 下面需要将其转换tensorflow Lite模型,便于在Android中使用。
converter = tf.lite.TFLiteConverter.from_keras_model(model)
tflite_model = converter.convert()

tflite_model_file = pathlib.Path('saved_model/model.tflite')
tflite_model_file.write_bytes(tflite_model)

2 在Android实现手写体识别

如果你不知道如何配置Android的环境,请参考手把手教你在Android上搭建tensorflow Lite2.0

2.1 加载模型

将训练好的TensorFlow Lite 文件放在Android的asset文件夹下。

public class TF {
    private static Context mContext;
    Interpreter mInterpreter;
    private static TF instance;

    public static TF newInstance(Context context) {
        mContext = context;
        if (instance == null) {
            instance = new TF();
        }
        return instance;
    }

    Interpreter get() {
        try {
            if (Objects.isNull(mInterpreter))
                mInterpreter = new Interpreter(loadModelFile(mContext));
        } catch (IOException e) {
            e.printStackTrace();
        }
        return mInterpreter;
    }

    // 获取文件
    private MappedByteBuffer loadModelFile(Context context) throws IOException {
        AssetFileDescriptor fileDescriptor = context.getAssets().openFd("model.tflite");
        FileInputStream inputStream = new FileInputStream(fileDescriptor.getFileDescriptor());
        FileChannel fileChannel = inputStream.getChannel();
        long startOffset = fileDescriptor.getStartOffset();
        long declaredLength = fileDescriptor.getDeclaredLength();
        return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength);
    }
}

2.2 自定义写画View

public class HandWriteView extends View {
    Path mPath = new Path();
    Paint mPaint;

    Bitmap mBitmap;
    Canvas mCanvas;

    public HandWriteView(Context context) {
        super(context);
        init();
    }

    public HandWriteView(Context context, AttributeSet attrs) {
        super(context, attrs);
        init();
    }

    void init() {
        mPaint = new Paint();
        mPaint.setColor(Color.WHITE);
        mPaint.setStyle(Paint.Style.STROKE);
        mPaint.setStrokeJoin(Paint.Join.ROUND);
        mPaint.setStrokeCap(Paint.Cap.ROUND);
        mPaint.setStrokeWidth(30);

    }

    @Override
    protected void onDraw(Canvas canvas) {
        super.onDraw(canvas);
        mBitmap = Bitmap.createBitmap(getWidth(), getHeight(), Bitmap.Config.ARGB_8888);
        mCanvas = new Canvas(mBitmap);
        mCanvas.drawColor(Color.BLACK);
        canvas.drawPath(mPath, mPaint);
        mCanvas.drawPath(mPath, mPaint);
    }

    @Override
    public boolean onTouchEvent(MotionEvent event) {
        switch (event.getAction()) {
            case MotionEvent.ACTION_DOWN:
                mPath.moveTo(event.getX(), event.getY());
                break;
            case MotionEvent.ACTION_MOVE:
                mPath.lineTo(event.getX(), event.getY());
                break;
            case MotionEvent.ACTION_UP:
            case MotionEvent.ACTION_CANCEL:
                break;
        }
        postInvalidate();
        return true;
    }

    Bitmap getBitmap() {
        mPath.reset();
        return mBitmap;
    }
}

2.3 将bitmap转成网络需要的格式

因为数据集中的数据都是28 * 28 * 3的,28为图片的宽和高,3为R,G,B三个通道,所以在输入到网络之前,我们需要将bitmap转成网络需要的格式。

private ByteBuffer convertBitmapToByteBuffer(Bitmap bitmap) {
        int inputShape[] = TF.newInstance(getApplicationContext()).get().getInputTensor(0).shape();
        int inputImageWidth = inputShape[1];
        int inputImageHeight = inputShape[2];
        Bitmap bs = Bitmap.createScaledBitmap(bitmap, inputImageWidth, inputImageHeight, true);
        mImageView.setImageBitmap(bs);
        ByteBuffer byteBuffer = ByteBuffer.allocateDirect(4 * inputImageHeight * inputImageWidth);
        byteBuffer.order(ByteOrder.nativeOrder());

        int[] pixels = new int[inputImageWidth * inputImageHeight];
        bs.getPixels(pixels, 0, bs.getWidth(), 0, 0, bs.getWidth(), bs.getHeight());

        for (int pixelValue : pixels) {
            int r = (pixelValue >> 16 & 0xFF);
            int g = (pixelValue >> 8 & 0xFF);
            int b = (pixelValue & 0xFF);

            // Convert RGB to grayscale and normalize pixel value to [0..1]
            float normalizedPixelValue = (r + g + b) / 3.0f / 255.0f;
            byteBuffer.putFloat(normalizedPixelValue);
        }
        return byteBuffer;
    }

2.4 识别结果的输出

识别的结果是根据0-9的概率进行判断,概率最大的就是识别的结果。

float[][] input = new float[1][10];
TF.newInstance(getApplicationContext()).get().run(convertBitmapToByteBuffer(mHandWriteView.getBitmap()), input);
int result = -1;
float value = 0f;
for (int j = 0; j < 10; j++) {
    if (input[0][j] > value) {
        value = input[0][j];
        result = j;
    }
Log.i("TAG", "result: " + j + " " + input[0][j]);
}
if (input[0][result] < 0.2f) {
    mTextView.setText("结果为:未识别");
} else {
    mTextView.setText("结果为:" + result);
}

识别结果: 在这里插入图片描述

若有需要,请自行点击demo下载。

3 总结

开发一个人工智能APP的主要流程就这么多,关键还是在于算法,要想得到更为精准的模型,除了要采用更好的模型之外,还需要对数据进行旋转,增强或者白质化,来提高数据的多样性。

欢迎大家一起交流!!!!