阅读 1525

DJL目标检测Demo

概述

Deep Java Library是AWS在2019年推出的深度学习Java库,目前已经支持MXNet、PyTorch、TensorFlow模型的训练和推理。DJL没有和固定的深度学习框架绑定,因此同一套代码可以适配不同的深度学习框架。

这里根据官网给的教程,介绍如何搭建目标检测的Demo,实现的功能包括:读取本地图片、加载官方Model Zoo提供的预训练模型、进行模型推理、将目标检测的结果图输出到本地。参考资料包括:SSD模型推理的官方教程、Maven依赖配置、DJL Maven的BOM配置、DJL版本依赖项搭配。

工程搭建

新建Maven项目

JDK>=1.8

项目结构

依赖引入

引入djl本身以及djl依赖的其他包。

    <build>
        <plugins>
            <!-- Compile at language level 8: the demo calls static interface methods,
                 which are rejected at language level 7 and below. -->
            <plugin>
                <groupId>org.apache.maven.plugins</groupId>
                <artifactId>maven-compiler-plugin</artifactId>
                <!-- Pin the plugin version so builds are reproducible and Maven does not warn. -->
                <version>3.8.1</version>
                <configuration>
                    <source>8</source>
                    <target>8</target>
                </configuration>
            </plugin>
        </plugins>
    </build>

    <!-- Import the DJL BOM so all ai.djl artifacts below share one managed version. -->
    <dependencyManagement>
        <dependencies>
            <dependency>
                <groupId>ai.djl</groupId>
                <artifactId>bom</artifactId>
                <version>0.9.0</version>
                <type>pom</type>
                <scope>import</scope>
            </dependency>
        </dependencies>
    </dependencyManagement>

    <dependencies>
        <dependency>
            <groupId>commons-cli</groupId>
            <artifactId>commons-cli</artifactId>
            <version>1.4</version>
        </dependency>
        <!-- 2.8.9+ fixes CVE-2022-25647. -->
        <dependency>
            <groupId>com.google.code.gson</groupId>
            <artifactId>gson</artifactId>
            <version>2.8.9</version>
        </dependency>
        <!-- Logging backend for SLF4J. 2.17.1+ is required: earlier 2.x releases pull a
             log4j-core affected by Log4Shell (CVE-2021-44228 and follow-up CVEs). -->
        <dependency>
            <groupId>org.apache.logging.log4j</groupId>
            <artifactId>log4j-slf4j-impl</artifactId>
            <version>2.17.1</version>
        </dependency>
        <!-- Core DJL API, required regardless of engine. -->
        <dependency>
            <groupId>ai.djl</groupId>
            <artifactId>api</artifactId>
        </dependency>
        <!-- Engine-specific dependencies follow; replace "mxnet" with "pytorch" to switch
             engines. Apache MXNet engine implementation. -->
        <dependency>
            <groupId>ai.djl.mxnet</groupId>
            <artifactId>mxnet-engine</artifactId>
        </dependency>
        <!-- Engine-specific: Apache MXNet native library, only needed at runtime. -->
        <dependency>
            <groupId>ai.djl.mxnet</groupId>
            <artifactId>mxnet-native-auto</artifactId>
            <scope>runtime</scope>
        </dependency>
        <!-- Engine-specific: a ModelZoo containing models exported from Apache MXNet.
             Without it, loading the model fails with ModelNotFoundException. -->
        <dependency>
            <groupId>ai.djl.mxnet</groupId>
            <artifactId>mxnet-model-zoo</artifactId>
        </dependency>
    </dependencies>
复制代码

日志配置文件

这部分没有固定要求,按实际需要配置即可。这里沿用了从网上找到的一份简单配置文件log4j2.xml:

<?xml version="1.0" encoding="UTF-8"?>
<!-- Log4j2 configuration: logs to the console, to a rolling file, and to a second
     rolling file that only accepts WARN and above. -->
<Configuration status="WARN">
    <Properties>
        <!-- Level of the root logger. -->
        <Property name="log_level" value="info" />
        <!-- Directory that holds the log files. -->
        <Property name="log_dir" value="log" />
        <Property name="log_pattern"
                  value="[%d{yyyy-MM-dd HH:mm:ss.SSS}] [%p] - [%t] %logger - %m%n" />
        <!-- Base name of the log files. -->
        <Property name="file_name" value="test" />
        <!-- A file is rolled over once it reaches this size. -->
        <Property name="every_file_size" value="100 MB" />
    </Properties>
    <Appenders>
        <Console name="Console" target="SYSTEM_OUT">
            <PatternLayout pattern="${log_pattern}" />
        </Console>
        <!-- "$$" escapes the date lookup so the month directory is resolved at rollover
             time rather than once at startup. -->
        <RollingFile name="RollingFile"
                     fileName="${log_dir}/${file_name}.log"
                     filePattern="${log_dir}/$${date:yyyy-MM}/${file_name}-%d{yyyy-MM-dd}-%i.log">
            <ThresholdFilter level="DEBUG" onMatch="ACCEPT"
                             onMismatch="DENY" />
            <PatternLayout pattern="${log_pattern}" />
            <!-- Roll on size, or when the date in filePattern changes. -->
            <Policies>
                <SizeBasedTriggeringPolicy
                        size="${every_file_size}" />
                <TimeBasedTriggeringPolicy modulate="true"
                                           interval="1" />
            </Policies>
            <DefaultRolloverStrategy max="20" />
        </RollingFile>

        <!-- Separate file that only receives WARN and ERROR events. -->
        <RollingFile name="RollingFileErr"
                     fileName="${log_dir}/${file_name}-warnerr.log"
                     filePattern="${log_dir}/$${date:yyyy-MM}/${file_name}-%d{yyyy-MM-dd}-warnerr-%i.log">
            <ThresholdFilter level="WARN" onMatch="ACCEPT"
                             onMismatch="DENY" />
            <PatternLayout pattern="${log_pattern}" />
            <Policies>
                <SizeBasedTriggeringPolicy
                        size="${every_file_size}" />
                <TimeBasedTriggeringPolicy modulate="true"
                                           interval="1" />
            </Policies>
        </RollingFile>
    </Appenders>
    <Loggers>
        <Root level="${log_level}">
            <AppenderRef ref="Console" />
            <AppenderRef ref="RollingFile" />
            <AppenderRef ref="RollingFileErr" />
        </Root>
    </Loggers>
</Configuration>
复制代码

业务代码

引自:github.com/awslabs/djl…

package org.town;

import ai.djl.Application;
import ai.djl.ModelException;
import ai.djl.engine.Engine;
import ai.djl.inference.Predictor;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.ImageFactory;
import ai.djl.modality.cv.output.DetectedObjects;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ModelZoo;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.training.util.ProgressBar;
import ai.djl.translate.TranslateException;
import java.io.IOException;
import java.io.OutputStream;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * An example of inference using an object detection model.
 *
 * <p>Reads {@code src/test/resources/dog_bike_car.jpg}, runs it through a pre-trained object
 * detection model loaded from the DJL model zoo, and writes a copy of the image with the
 * detected bounding boxes drawn to {@code build/output/detected-dog_bike_car.png}.
 */
public final class ObjectDetection {

    private static final Logger logger = LoggerFactory.getLogger(ObjectDetection.class);

    private ObjectDetection() {}

    /**
     * Runs {@link #predict()} and logs the detection result.
     *
     * @param args ignored
     * @throws IOException if the input image cannot be read or the output image cannot be written
     * @throws ModelException if a model matching the criteria cannot be found or loaded
     * @throws TranslateException if pre- or post-processing of the prediction fails
     */
    public static void main(String[] args) throws IOException, ModelException, TranslateException {
        DetectedObjects detection = ObjectDetection.predict();
        logger.info("{}", detection);
    }

    /**
     * Detects objects in the sample image and saves an annotated copy of it.
     *
     * <p>The model backbone is chosen from the active engine: {@code mobilenet_v2} when the
     * engine is TensorFlow, {@code resnet50} otherwise. {@code Engine.getInstance()} throws an
     * unchecked {@code EngineException} if no deep learning engine could be loaded.
     *
     * @return the objects detected in the image
     * @throws IOException if the input image cannot be read or the output image cannot be written
     * @throws ModelException if a model matching the criteria cannot be found or loaded
     * @throws TranslateException if pre- or post-processing of the prediction fails
     */
    public static DetectedObjects predict() throws IOException, ModelException, TranslateException {
        Path imageFile = Paths.get("src/test/resources/dog_bike_car.jpg");
        Image img = ImageFactory.getInstance().fromFile(imageFile);

        // The backbone filter must match a model the active engine's zoo provides.
        String backbone;
        if ("TensorFlow".equals(Engine.getInstance().getEngineName())) {
            backbone = "mobilenet_v2";
        } else {
            backbone = "resnet50";
        }

        Criteria<Image, DetectedObjects> criteria =
                Criteria.builder()
                        .optApplication(Application.CV.OBJECT_DETECTION)
                        .setTypes(Image.class, DetectedObjects.class)
                        .optFilter("backbone", backbone)
                        .optProgress(new ProgressBar())
                        .build();

        // Resources close in reverse order: the predictor first, then the model.
        try (ZooModel<Image, DetectedObjects> model = ModelZoo.loadModel(criteria);
                Predictor<Image, DetectedObjects> predictor = model.newPredictor()) {
            DetectedObjects detection = predictor.predict(img);
            saveBoundingBoxImage(img, detection);
            return detection;
        }
    }

    /**
     * Draws the detected bounding boxes on a copy of {@code img} and saves it as a PNG under
     * {@code build/output}, creating the directory if needed.
     *
     * @param img the original input image; it is not modified
     * @param detection the objects to draw
     * @throws IOException if the directory or file cannot be created or written
     */
    private static void saveBoundingBoxImage(Image img, DetectedObjects detection)
            throws IOException {
        Path outputDir = Paths.get("build/output");
        Files.createDirectories(outputDir);

        // Make image copy with alpha channel because original image was jpg
        Image newImage = img.duplicate(Image.Type.TYPE_INT_ARGB);
        newImage.drawBoundingBoxes(detection);

        Path imagePath = outputDir.resolve("detected-dog_bike_car.png");
        // OpenJDK can't save jpg with alpha channel, hence png.
        // Close the stream explicitly; the original code leaked the file handle.
        try (OutputStream os = Files.newOutputStream(imagePath)) {
            newImage.save(os, "png");
        }
        logger.info("Detected objects image has been saved in: {}", imagePath);
    }
}
复制代码

输出结果

问题记录

问题1:项目编译的语言级别(language level)过低(为7),导致报错:Static interface method calls are not supported at language level '7'

解决办法:pom配置文件中显式指定使用jdk8进行编译。

问题2:Exception in thread "main" ai.djl.repository.zoo.ModelNotFoundException: No matching model with specified Input/Output type found.

解决办法:原因是pom配置文件中没有引入ai.djl.mxnet:mxnet-model-zoo依赖包,Model Zoo中找不到匹配的模型;引入该依赖即可。

问题3:切换为pytorch引擎后出错。

切换方式:

<!-- PyTorch counterparts of the three mxnet-* dependencies (engine, native library,
     model zoo); replace the mxnet-* entries with these to switch engines. No <version>
     is given because versions come from the ai.djl BOM imported in dependencyManagement. -->
<dependency>
    <groupId>ai.djl.pytorch</groupId>
    <artifactId>pytorch-engine</artifactId>
</dependency>
<dependency>
    <groupId>ai.djl.pytorch</groupId>
    <artifactId>pytorch-native-auto</artifactId>
    <scope>runtime</scope>
</dependency>
<dependency>
    <groupId>ai.djl.pytorch</groupId>
    <artifactId>pytorch-model-zoo</artifactId>
</dependency>
复制代码

报错信息:

[2021-01-01 23:49:12.420] [WARN] - [main] ai.djl.engine.Engine - Failed to load engine from: ai.djl.pytorch.engine.PtEngineProvider
ai.djl.engine.EngineException: Failed to load PyTorch native library
	at ai.djl.pytorch.engine.PtEngine.newInstance(PtEngine.java:56) ~[pytorch-engine-0.9.0.jar:?]
	at ai.djl.pytorch.engine.PtEngineProvider.getEngine(PtEngineProvider.java:27) ~[pytorch-engine-0.9.0.jar:?]
	at ai.djl.engine.Engine.initEngine(Engine.java:59) [api-0.9.0.jar:?]
	at ai.djl.engine.Engine.<clinit>(Engine.java:49) [api-0.9.0.jar:?]
	at org.town.ObjectDetection.predict(ObjectDetection.java:45) [classes/:?]
	at org.town.ObjectDetection.main(ObjectDetection.java:36) [classes/:?]
Caused by: java.lang.UnsatisfiedLinkError: C:\Users\steel\.djl.ai\pytorch\1.7.0-cpu-win-x86_64\asmjit.dll: Can't find dependent libraries
	at java.lang.ClassLoader$NativeLibrary.load(Native Method) ~[?:1.8.0_251]
	at java.lang.ClassLoader.loadLibrary0(ClassLoader.java:1934) ~[?:1.8.0_251]
	at java.lang.ClassLoader.loadLibrary(ClassLoader.java:1817) ~[?:1.8.0_251]
	at java.lang.Runtime.load0(Runtime.java:809) ~[?:1.8.0_251]
	at java.lang.System.load(System.java:1086) ~[?:1.8.0_251]
	at java.util.stream.ForEachOps$ForEachOp$OfRef.accept(ForEachOps.java:184) ~[?:1.8.0_251]
	at java.util.stream.ReferencePipeline$3$1.accept(ReferencePipeline.java:193) ~[?:1.8.0_251]
	at java.util.stream.ReferencePipeline$2$1.accept(ReferencePipeline.java:175) ~[?:1.8.0_251]
	at java.util.stream.ReferencePipeline$3$1.accept(ReferencePipeline.java:193) ~[?:1.8.0_251]
	at java.util.Iterator.forEachRemaining(Iterator.java:116) ~[?:1.8.0_251]
	at java.util.Spliterators$IteratorSpliterator.forEachRemaining(Spliterators.java:1801) ~[?:1.8.0_251]
	at java.util.stream.AbstractPipeline.copyInto(AbstractPipeline.java:482) ~[?:1.8.0_251]
	at java.util.stream.AbstractPipeline.wrapAndCopyInto(AbstractPipeline.java:472) ~[?:1.8.0_251]
	at java.util.stream.ForEachOps$ForEachOp.evaluateSequential(ForEachOps.java:151) ~[?:1.8.0_251]
	at java.util.stream.ForEachOps$ForEachOp$OfRef.evaluateSequential(ForEachOps.java:174) ~[?:1.8.0_251]
	at java.util.stream.AbstractPipeline.evaluate(AbstractPipeline.java:234) ~[?:1.8.0_251]
	at java.util.stream.ReferencePipeline.forEach(ReferencePipeline.java:418) ~[?:1.8.0_251]
	at ai.djl.pytorch.jni.LibUtils.loadWinDependencies(LibUtils.java:119) ~[pytorch-engine-0.9.0.jar:?]
	at ai.djl.pytorch.jni.LibUtils.loadLibrary(LibUtils.java:75) ~[pytorch-engine-0.9.0.jar:?]
	at ai.djl.pytorch.engine.PtEngine.newInstance(PtEngine.java:44) ~[pytorch-engine-0.9.0.jar:?]
	... 5 more
Exception in thread "main" ai.djl.engine.EngineException: No deep learning engine found.
Please refer to https://github.com/awslabs/djl/blob/master/docs/development/troubleshooting.md for more details.
	at ai.djl.engine.Engine.getInstance(Engine.java:119)
	at org.town.ObjectDetection.predict(ObjectDetection.java:45)
	at org.town.ObjectDetection.main(ObjectDetection.java:36)
Caused by: ai.djl.engine.EngineException: Failed to load PyTorch native library
	at ai.djl.pytorch.engine.PtEngine.newInstance(PtEngine.java:56)
	at ai.djl.pytorch.engine.PtEngineProvider.getEngine(PtEngineProvider.java:27)
	at ai.djl.engine.Engine.initEngine(Engine.java:59)
	at ai.djl.engine.Engine.<clinit>(Engine.java:49)
	... 2 more
Caused by: java.lang.UnsatisfiedLinkError: C:\Users\steel\.djl.ai\pytorch\1.7.0-cpu-win-x86_64\asmjit.dll: Can't find dependent libraries
	at java.lang.ClassLoader$NativeLibrary.load(Native Method)
	at java.lang.ClassLoader.loadLibrary0(ClassLoader.java:1934)
	at java.lang.ClassLoader.loadLibrary(ClassLoader.java:1817)
	at java.lang.Runtime.load0(Runtime.java:809)
	at java.lang.System.load(System.java:1086)
	at java.util.stream.ForEachOps$ForEachOp$OfRef.accept(ForEachOps.java:184)
	at java.util.stream.ReferencePipeline$3$1.accept(ReferencePipeline.java:193)
	at java.util.stream.ReferencePipeline$2$1.accept(ReferencePipeline.java:175)
	at java.util.stream.ReferencePipeline$3$1.accept(ReferencePipeline.java:193)
	at java.util.Iterator.forEachRemaining(Iterator.java:116)
	at java.util.Spliterators$IteratorSpliterator.forEachRemaining(Spliterators.java:1801)
	at java.util.stream.AbstractPipeline.copyInto(AbstractPipeline.java:482)
	at java.util.stream.AbstractPipeline.wrapAndCopyInto(AbstractPipeline.java:472)
	at java.util.stream.ForEachOps$ForEachOp.evaluateSequential(ForEachOps.java:151)
	at java.util.stream.ForEachOps$ForEachOp$OfRef.evaluateSequential(ForEachOps.java:174)
	at java.util.stream.AbstractPipeline.evaluate(AbstractPipeline.java:234)
	at java.util.stream.ReferencePipeline.forEach(ReferencePipeline.java:418)
	at ai.djl.pytorch.jni.LibUtils.loadWinDependencies(LibUtils.java:119)
	at ai.djl.pytorch.jni.LibUtils.loadLibrary(LibUtils.java:75)
	at ai.djl.pytorch.engine.PtEngine.newInstance(PtEngine.java:44)
	... 5 more
复制代码

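解决办法:暂未找到。一个可能的原因是:Windows上的PyTorch原生库依赖Microsoft Visual C++ Redistributable运行库,缺少该运行库时加载asmjit.dll等文件会报“Can't find dependent libraries”。可参考报错信息中给出的DJL troubleshooting文档,安装VC++运行库后重试。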

文章分类
人工智能
文章标签