如何使用机器学习识别图像中的刀

198 阅读4分钟

社会公益

不幸的是,持刀犯罪在英国是一个严重的问题。在截至 2018 年 3 月的一年中,英格兰和威尔士大约有40,100 起涉及刀具或利器的犯罪行为。

我认为通过能够识别图像中的刀具,可能会在犯罪侦查和减少方面有潜在的应用,并将 AI 用于某些社会公益活动。

方法

数据结构

首先,我整理了我的数据并将所有图像放入一个文件夹中,命名模式为:

knife_1.bmp notknife_1.bmp

数字是连续的:knife_1、knife_2 等。

运行 Jupyter Notebook

Jupyter notebooks是运行这些数据集的实际工具,我没有偏离这一点。

首先,我将我的笔记本设置为自动更新,并加载了我的 FastAI 库:

!pip install --upgrade fastai %reload_ext autoreload %autoreload 2 
%matplotlib inline from fastai.vision import * from fastai.metrics 
import error_rate

接下来我导入了我的数据,以便我的模型可以轻松访问它:

path = Path('/data/') knives = fnames = get_image_files(path) 
print(knives[:5])

我打印了一组五个文件名来检查它们是否被有效导入:

[PosixPath('/data/notknife_3198.bmp'),PosixPath('/data/notknife_2790.bmp'), PosixPath('/data/notknife_2296.bmp'), PosixPath('/data/notknife_8269.bmp'), PosixPath('/data/notknife_5795.bmp')]

接下来我创建了一个正则表达式模式来从我们的每个文件名中获取标签,以便能够将它们分类为一把刀或不是一把刀:

pat = r'/([^/]+)_\d+.bmp$'

然后我利用 FastAI 的一个功能,以适合计算机视觉的格式将我们的图像捆绑在一起:

data = (ImageDataBunch .from_name_re(path, knives, pat, 
ds_tfms=get_transforms(), size=224, bs=64) .normalize(imagenet_stats))

正如你在上面看到的,我们有访问数据的路径,我们的训练集knives以最适合我们稍后出现的 RestNet 数据),然后我们对数据集进行归一化。

为了检查这一切是否如我们所料,我输出了 4 行图像以及两个预期的标签knife, notknife::

data.show_batch(rows=4) print(data.classes)

作为输出,我得到以下内容:

image.png

太好了,所以我得到了我期望的图像并且它们被正确标记了。现在是时候开始训练我的模型了。

训练模型

为此,我使用卷积神经网络或 CNN,并针对 resnet50 进行训练,顾名思义,resnet50 是一个 50 层残差网络。这使我们能够使用一种叫做迁移学习的东西,它利用存储在一个模型中的知识并将其应用到另一个模型中。本质上,它是我们模型的起点。

使用 Fast.AI 这非常简单:

learn = create_cnn(data, models.resnet50, metrics=error_rate) 
learn.fit_one_cycle(4) learn.save('knives-stage-1')

正如您在上面看到的,我们正在创建我们的 CNN,根据我们提供的数据和从 resnet 推断的模型创建学习者对象。我们使用差异学习率来训练模型。

一旦完成,这可能需要一段时间!我们最终得到这样的结果:

image.png

正如你所看到的,我们从较高的错误率开始,逐渐将其从 98% 的准确率降低到 99.5% 的准确率。

所以这已经很棒了!但我们能做得更好吗?让我们弄清楚我们的模型在哪里变得混乱。

识别模型错误

interp = ClassificationInterpretation.from_learner(learn) 
interp.plot_top_losses(9)

通过运行上面的代码,我要求 FastAI 解释结果并显示前 9 个不正确的猜测,并返回以下结果:

image.png

唔。所以用我们的肉眼来看是有道理的。这位女士使用的蛇工具看起来很像一把刀,右上角的图像对比度很差,所以很难辨认出刀身。

很好,让我们找出我们的错误率从哪里开始飙升,看看我们是否可以改进:

learn.lr_find() learn.recorder.plot()

这将呈现以下图表:

图像

因此,在错误率再次飙升之前,我们可以看到图表右侧出现强烈下降,因此让我们校准我们的模型,以使用从下降顶部开始到结束位置之间的数据。

learn.fit_one_cycle(10, max_lr=slice(1e-03,1e-02)) learn.save('knives-stage-2')

在这里,我们正在运行十个周期(或纪元)的训练,并如上所述对数据进行切片以适合图表。

然后我们收到以下信息:

图像

因此,我们可以看到第二次通过时的错误率要低得多,并且最后的结果集特别好。到训练完成时,我们的准确率达到了惊人的99.84% !

接下来我将采取的步骤是将模型包装在 API 中,这样您就可以上传图像并可靠地了解是否存在刀具。我可以想象这对执法或安全特别有用,但这里使用的技术可以很容易地应用于任何主题。