本人有幸参加了GLCC开源夏令营,并得到了参与由Amazon公开的Deep Java Library开源项目的机会,本人负责的Project是DJL-Zero,该项目旨在帮助没有机器学习或者深度学习相关经验的Java开发者们能够很方便地训练出自己的深度学习模型。在该文中以周为单位总结归纳自己项目中学习到的相关知识。
第一周(或者说正式工作的第一周)
作为深度学习的newbee,第一周(之前也有2周熟悉了)的主要任务就是熟悉DJL环境,并且对ImageClassification Class进行修改——目前ImageClassification中的train方法对于每个输入的dataset样例,train方法都会训练一个ResNet作为模型输出,而导师认为对于那些Small and fast的情况,可以训练一个MobileNet。然后问题主要就是:DJL的ModelZoo当中并没有MobileNet。因此我实现了一个MobileNetV1,然后成功提交PR合入了代码仓库。PR地址:Add MobileNetV1 into modelZoo and add corresponding test class MobileNetV1Test by warthecatalyst · Pull Request #1817 · deepjavalibrary/djl (github.com)
遇到的一些挑战
- 首先最大的挑战就是作为深度学习几乎0基础的学生(至少没有过代码经验),从0开始在DJL框架里面搭建一个MobileNet也算是第一个挑战
- 第二个问题就是跟自己平时写的trash项目不一样,代码中需要遵守的规范非常多。我搭建完之后花了一个下午的时间修改代码以及JavaDoc和注释等内容,项目中的命名规范、各种类成员变量和方法等等都需要遵守一定的规范。比如PR1815也是我的,因为代码没有通过测试因此没有被合入。
- 第三个问题就是不太会用git,这里推荐一个学习git branching的神器:Learn Git Branching
第二周
这周主要是被自己的导师gank了,他的横向项目好长时间没进展,然后我又是负责人,所以花了2天搞导师的项目。
这周在DJL方面的进展主要就是Build MobileNetV2,跟MobileNetV1不同的地方在于,MobileNetV2使用了ResNet中的残差网路结构,因此这周build的难点在于怎么在DJL中搭建残差网络。在DJL中提供了ParallelBlock,因此当遇到残差块时,我只需要build一个parallel block即可。在JavaDoc(青总)的帮助下,我知道了如何在最后一层使用conv1*1(因为之前遇到了Shape Unmatch)的问题。但是一点小问题是在运行trainer.steo()时报错了,但是先提交了一个PR可能让DJL的大佬们看看就知道哪里出现了问题。
PR地址:Add MobileNetV2 into ModelZoo and corresponding Test Class MobileNetV… by warthecatalyst · Pull Request #1847 · deepjavalibrary/djl (github.com)
以及在Amazon提供的服务器上train了一下MobileNetV1 Model,最开始对他们的环境瞎搞导致他们环境直接崩溃了...然后第二次的时候吸取教训使用docker来进行训练,感觉结果挺好:
第三-第四周
第三-第四周双线程了,参加了字节的夏令营机器人小组,不小心拿了个二等奖... 然后主要是实现了MobileNetV2构建的完成,以及MobileNetV1和MobileNetV2的训练,一个bug 有好多天以为是自己的问题,结果之后调用./gradlew clean就然后就成了。 两周的PR成果:Training of MobileNetV1 and MobileNetV2 on Mnist and Cifar10 by warthecatalyst · Pull Request #1878 · deepjavalibrary/djl (github.com)
最终结果
最终,为期12周的GLCC_DJL项目也是圆满结束了。中间这个理应每周一次的记录环节也是没有如期进行(俗话说计划赶不上变化)。但在12周的过程中我觉得我也是完成了不少的任务:在整个扩展DJL-零项目中,我的任务主要分为三个模块:Image Classification、Object Detection和Tabular Prediction。
Image Classification
这一块内容基本上Zach已经完成了,但是Zach认为DJL目前缺少一些小而快的模型,因此,在DJL中加入了MobileNetV1和MobileNetV2模型。MobileNetV1和MobileNetV2模型自身在此就不过多赘述了。其使用方式可以查看DJL-Zero中Image Classification的Train方法。代码如下所示:
block = MobileNetV2.builder().setOutSize(classes.size()).build();
Object Detection
关于目标检测,这一块的内容并没有完全完成。目前确认完成的部分有yolov3的model。然后Yolov3Loss目前还存在一点问题,对于某些数据集(实际上我就成功了一个TrainPikachu),会出现loss becomes NAN的问题,并且目前只有Pytorch Engine是可以使用的,MXNet会遇到shape inconsistent的bug。
如果有兴趣继续研究的,也欢迎在我的基础上再改动一些内容,或者在DJL中实现Yolov5等更高级的目标检测模型。
TabNet
对于表格数据,深度学习却一直没有照顾到这个拥有真实世界80%数据格式的领域。直到Google的AI research team发表了关于TabNet模型。TabNet的网络架构在此处也不赘述,直接阐述如何使用。
类似于DJL-Zero的部分,TabNet模型的构造十分简单。设置好输入参数和输出参数(分别是TabularDataset的featureSize和labelSize),然后可以选择性设置numShared和numIndependent,分别代表TabNet论文中的分享层数目和独立层数目。最后build()即可。
// for fast cases, we set the number of independent layers and shared layers lower
block =
TabNet.builder()
.setInputDim(featureSize)
.setOutDim(labelSize)
.optNumIndependent(1)
.optNumShared(1)
.build();
针对不同的表格学习任务(回归任务和分类任务),定义了不同的LossFunction,分别为TabNetRegressionLoss和TabNetClassificationLoss,以供用户使用。