后端部署深度学习模型

419 阅读1分钟

一、需求

将 PyTorch 模型整合到 Java 后端中。具体要求是:不能通过启动 Python 进程运行 PyTorch 项目、再把结果返回给后端,而是要在 Java 项目中直接加载并运行模型,得到对应的推理结果。

二、技术路线

  • 利用 PyTorch 的 torch.jit.trace 函数对训练好的模型进行追踪,并将其序列化保存为 TorchScript 文件(如 model.pt)
  • 利用 DJL(Deep Java Library)加载序列化后的模型并执行推理;输入数据(NumPy 的 .npz 文件)通过 ND4J 读取

三、Maven 依赖

<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 https://maven.apache.org/xsd/maven-4.0.0.xsd">
	<modelVersion>4.0.0</modelVersion>
	<parent>
		<groupId>org.springframework.boot</groupId>
		<artifactId>spring-boot-starter-parent</artifactId>
		<version>2.2.1.RELEASE</version>
		<relativePath/> <!-- lookup parent from repository -->
	</parent>
	<groupId>com.example</groupId>
	<artifactId>djlproject</artifactId>
	<version>0.0.1-SNAPSHOT</version>
	<name>djlproject</name>
	<description>Demo project for Spring Boot</description>

	<dependencies>


		<!-- DJL PyTorch engine (Java API used to load and run the TorchScript model).
		     NOTE(review): the DJL PyTorch artifacts in this POM declare three different version lines:
		     pytorch-engine 0.9.0, pytorch-jni 1.8.1-0.16.0 and pytorch-native-cpu/-auto 1.7.0.
		     Engine, JNI and native libraries are normally released as a matched set - confirm this
		     combination is intended and actually loads. -->
		<dependency>
			<groupId>ai.djl.pytorch</groupId>
			<artifactId>pytorch-engine</artifactId>
			<version>0.9.0</version>
			<scope>runtime</scope>
		</dependency>

		<!-- NOTE(review): JNA is not referenced by the Java code in this article; presumably added for
		     native-library loading - confirm it is still required. -->
		<dependency>
            <groupId>net.java.dev.jna</groupId>
            <artifactId>jna</artifactId>
            <version>5.3.0</version>
        </dependency>


		<!-- NOTE(review): version 1.8.1-0.16.0 belongs to a different DJL release line than
		     pytorch-engine 0.9.0 above - confirm compatibility. -->
		<dependency>
			<groupId>ai.djl.pytorch</groupId>
			<artifactId>pytorch-jni</artifactId>
			<version>1.8.1-0.16.0</version>
			<scope>runtime</scope>
		</dependency>

		<!-- PyTorch CPU native binaries; the win-x86_64 classifier presumably restricts this build to
		     Windows x86-64 hosts. -->
		<dependency>
			<groupId>ai.djl.pytorch</groupId>
			<artifactId>pytorch-native-cpu</artifactId>
			<classifier>win-x86_64</classifier>
			<scope>runtime</scope>
			<version>1.7.0</version>
		</dependency>

        <!-- NOTE(review): declared in addition to pytorch-native-cpu above; both appear to supply the
             PyTorch native libraries - confirm only one of them is needed. -->
        <dependency>
            <groupId>ai.djl.pytorch</groupId>
            <artifactId>pytorch-native-auto</artifactId>
            <version>1.7.0</version>
            <scope>runtime</scope>
        </dependency>

		<!-- ND4J: the sample controller uses it to read the NumPy .npz input (Nd4j.createFromNpzFile). -->
		<dependency>
		<groupId>org.nd4j</groupId>
		<artifactId>nd4j-native-platform</artifactId>
		<version>1.0.0-M1.1</version>
	</dependency>


		<dependency>
			<groupId>org.springframework.boot</groupId>
			<artifactId>spring-boot-starter</artifactId>
		</dependency>

		<dependency>
			<groupId>org.springframework.boot</groupId>
			<artifactId>spring-boot-devtools</artifactId>
			<scope>runtime</scope>
			<optional>true</optional>
		</dependency>

		<dependency>
			<groupId>org.springframework.boot</groupId>
			<artifactId>spring-boot-starter-test</artifactId>
			<scope>test</scope>
		</dependency>
		<dependency>
			<groupId>org.springframework.boot</groupId>
			<artifactId>spring-boot-starter-web</artifactId>
		</dependency>


	</dependencies>

	<build>
		<plugins>
			<plugin>
				<groupId>org.springframework.boot</groupId>
				<artifactId>spring-boot-maven-plugin</artifactId>
				<configuration>
					<!-- NOTE(review): excludes Lombok from the repackaged jar, but no Lombok dependency is
					     declared in this POM - likely a leftover from the project template. -->
					<excludes>
						<exclude>
							<groupId>org.projectlombok</groupId>
							<artifactId>lombok</artifactId>
						</exclude>
					</excludes>
				</configuration>
			</plugin>
		</plugins>
	</build>

</project>

四、配置类

@Configuration
public class ModelConfiguration {

    @Bean("EegModel")
    Model getModel(){
        String modelPath_str = "D:\study-material\studyOpenFeign\djlproject\src\main\java\com\example\djlproject\model\model.pt";
        Model model = Model.newInstance("mymodel");
        try {
            model.load(Paths.get(modelPath_str));
        } catch (IOException e) {
            e.printStackTrace();
        } catch (MalformedModelException e) {
            e.printStackTrace();
        }
        return model;

    }


    @Bean("ttanslator")
    Translator<NDList, NDList> getTranslator(Model model){
        Translator<NDList, NDList> translator  = new Translator<NDList, NDList>(){

            @Override
            public NDList processInput(TranslatorContext ctx, NDList input) throws Exception {

                return input;
            }

            @Override
            public NDList processOutput(TranslatorContext ctx, NDList list) throws Exception {
                return list;
            }

            @Override
            public Batchifier getBatchifier() {

                return null;
            }
        };
        return translator;
    }

    @Bean("predictor")
    Predictor<NDList, NDList> getPredictor(Model model,Translator<NDList, NDList> translator){
        Predictor<NDList, NDList> predictor = model.newPredictor(translator);
        return predictor;
    }

}

五、Controller接口(测试)

@RestController
@RequestMapping("/eeg/")
public class EegController {

    @Autowired
    Predictor<NDList, NDList> predictor;

    @GetMapping("/predict")
    public void predict() throws Exception {
        String filePath = "D:\study-material\studyOpenFeign\djlproject\src\main\java\com\example\djlproject\data\SN154.npz";
        File file = new File(filePath);
        Map<String, INDArray> fromNpzFile = Nd4j.createFromNpzFile(file);                 // 读取Numpy文件
        INDArray data = fromNpzFile.get("x").reshape(1000L,7680L);

        float[][] floats = data.toFloatMatrix();
        NDManager ndManager = NDManager.newBaseManager();

        List<NDList> datalist = new ArrayList<>();
        for(int i=0;i<floats.length;i++){
            datalist.add(new NDList(ndManager.create(floats[i]).reshape(new Shape(1L,1L,7680L))));
        }


        System.out.println(datalist.size());


        List<NDList> ndLists = predictor.batchPredict(datalist);

        for (int i = 0; i <ndLists.size() ; i++) {
            NDArray ndArray = ndLists.get(i).get(0);
            System.out.println(Arrays.toString(ndArray.toArray()));
            System.out.println("-----------");
        }
    }

}