一、需求
将PyTorch模型整合到Java后端中。具体要求是:不能通过启动Python进程运行PyTorch项目、再把结果返回给后端项目,而是直接在Java项目中运行模型,得到对应的结果。
二、技术路线
- 利用PyTorch中的torch.jit.trace函数,将训练得到的模型序列化为TorchScript文件(如model.pt)
- 利用DJL(Deep Java Library)加载模型并执行推理;输入数据借助ND4J从.npz文件中读取
三、maven依赖
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 https://maven.apache.org/xsd/maven-4.0.0.xsd">
<!-- Maven build for the Spring Boot demo that runs a serialized PyTorch model in-process through DJL. -->
<modelVersion>4.0.0</modelVersion>
<parent>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-parent</artifactId>
<version>2.2.1.RELEASE</version>
<relativePath/> <!-- lookup parent from repository -->
</parent>
<groupId>com.example</groupId>
<artifactId>djlproject</artifactId>
<version>0.0.1-SNAPSHOT</version>
<name>djlproject</name>
<description>Demo project for Spring Boot</description>
<dependencies>
<!--
DJL PyTorch artifacts.
NOTE(review): the versions below do not line up: pytorch-engine 0.9.0,
pytorch-jni 1.8.1-0.16.0 (the version string suggests PyTorch 1.8.1 for DJL 0.16.0)
and pytorch-native-cpu / pytorch-native-auto 1.7.0. Presumably these should all come
from one matching DJL release; confirm against the DJL PyTorch engine documentation.
-->
<dependency>
<groupId>ai.djl.pytorch</groupId>
<artifactId>pytorch-engine</artifactId>
<version>0.9.0</version>
<scope>runtime</scope>
</dependency>
<!-- NOTE(review): JNA is not referenced by the Java code shown here; presumably needed for native loading. Confirm it is still required. -->
<dependency>
<groupId>net.java.dev.jna</groupId>
<artifactId>jna</artifactId>
<version>5.3.0</version>
</dependency>
<dependency>
<groupId>ai.djl.pytorch</groupId>
<artifactId>pytorch-jni</artifactId>
<version>1.8.1-0.16.0</version>
<scope>runtime</scope>
</dependency>
<!--
The classifier selects the win-x86_64 build of pytorch-native-cpu.
NOTE(review): pytorch-native-auto is also declared below; presumably only one native
artifact is needed. Confirm which one should stay.
-->
<dependency>
<groupId>ai.djl.pytorch</groupId>
<artifactId>pytorch-native-cpu</artifactId>
<classifier>win-x86_64</classifier>
<scope>runtime</scope>
<version>1.7.0</version>
</dependency>
<dependency>
<groupId>ai.djl.pytorch</groupId>
<artifactId>pytorch-native-auto</artifactId>
<version>1.7.0</version>
<scope>runtime</scope>
</dependency>
<!-- ND4J: used by the controller to read the .npz input file (Nd4j.createFromNpzFile). -->
<dependency>
<groupId>org.nd4j</groupId>
<artifactId>nd4j-native-platform</artifactId>
<version>1.0.0-M1.1</version>
</dependency>
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter</artifactId>
</dependency>
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-devtools</artifactId>
<scope>runtime</scope>
<optional>true</optional>
</dependency>
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-test</artifactId>
<scope>test</scope>
</dependency>
<!-- Provides the REST layer used by the test controller. -->
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-web</artifactId>
</dependency>
</dependencies>
<build>
<plugins>
<plugin>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-maven-plugin</artifactId>
<configuration>
<!-- NOTE(review): no lombok dependency is declared in this POM; this exclusion looks like a template leftover. -->
<excludes>
<exclude>
<groupId>org.projectlombok</groupId>
<artifactId>lombok</artifactId>
</exclude>
</excludes>
</configuration>
</plugin>
</plugins>
</build>
</project>
四、配置类
@Configuration
public class ModelConfiguration {
@Bean("EegModel")
Model getModel(){
String modelPath_str = "D:\study-material\studyOpenFeign\djlproject\src\main\java\com\example\djlproject\model\model.pt";
Model model = Model.newInstance("mymodel");
try {
model.load(Paths.get(modelPath_str));
} catch (IOException e) {
e.printStackTrace();
} catch (MalformedModelException e) {
e.printStackTrace();
}
return model;
}
@Bean("ttanslator")
Translator<NDList, NDList> getTranslator(Model model){
Translator<NDList, NDList> translator = new Translator<NDList, NDList>(){
@Override
public NDList processInput(TranslatorContext ctx, NDList input) throws Exception {
return input;
}
@Override
public NDList processOutput(TranslatorContext ctx, NDList list) throws Exception {
return list;
}
@Override
public Batchifier getBatchifier() {
return null;
}
};
return translator;
}
@Bean("predictor")
Predictor<NDList, NDList> getPredictor(Model model,Translator<NDList, NDList> translator){
Predictor<NDList, NDList> predictor = model.newPredictor(translator);
return predictor;
}
}
五、Controller接口(测试)
@RestController
@RequestMapping("/eeg/")
public class EegController {
@Autowired
Predictor<NDList, NDList> predictor;
@GetMapping("/predict")
public void predict() throws Exception {
String filePath = "D:\study-material\studyOpenFeign\djlproject\src\main\java\com\example\djlproject\data\SN154.npz";
File file = new File(filePath);
Map<String, INDArray> fromNpzFile = Nd4j.createFromNpzFile(file); // 读取Numpy文件
INDArray data = fromNpzFile.get("x").reshape(1000L,7680L);
float[][] floats = data.toFloatMatrix();
NDManager ndManager = NDManager.newBaseManager();
List<NDList> datalist = new ArrayList<>();
for(int i=0;i<floats.length;i++){
datalist.add(new NDList(ndManager.create(floats[i]).reshape(new Shape(1L,1L,7680L))));
}
System.out.println(datalist.size());
List<NDList> ndLists = predictor.batchPredict(datalist);
for (int i = 0; i <ndLists.size() ; i++) {
NDArray ndArray = ndLists.get(i).get(0);
System.out.println(Arrays.toString(ndArray.toArray()));
System.out.println("-----------");
}
}
}