Pytorch On Java 你的第一个java版本的【真】 神经网络 [AI Infra 3.0]

0 阅读3分钟

如何用Java 写 全真的神经网络 一直是java 众多程序员的梦想,奈何你们寻不到真经,被各种玩具框架 妖魔鬼怪 挟持认知,一个个都以为必须用jni 调python 要么就是冷言嘲讽java不适合做神经网络,在ai是的已经被淘汰了? 果真如此,当然是假的了,反而是 java 在AI时代有媲美Apache spark flink 神级 框架的存在,它就是pytorch

编辑版本将使用 JavaCPP Presets for PyTorch。JavaCPP 提供了 PyTorch C++ API(LibTorch)的直接映射,因此代码风格会非常接近 C++ 版的 LibTorch,但运行在 JVM 上,注意是几百万行代码的全量编译!!!

注意 不是 Pytorch 官方支持 java 版本,也不是 java Oracle支持 Pytorch,而是ByteDeco 旗下的Javacpp 支持 PyTorch ,Pytorch官方基金会在java 的支持上只限于 andriod ,其他都非常拉胯!!! 吃水不忘挖井人,你如果要感谢的话,一定要感谢 Bytedeco 这个伟大的天才开源组织

Image

以下是针对 javacpp-pytorch 2.1.0-1.5.13 版本的完整指南。

1. Maven 配置

首先,在你的 pom.xml 中引入依赖。pytorch-platform 会自动根据你的操作系统下载对应的本地库(包含 CPU 版本,如需 GPU 需额外配置)。

<dependencies>
    <dependency>
        <groupId>org.bytedeco</groupId>
        <artifactId>pytorch-platform</artifactId>
        <version>2.1.0-1.5.13</version>
    </dependency>
</dependencies>

Image

2. 张量(Tensor)操作

在 JavaCPP 中,Tensor 的操作主要通过 org.bytedeco.pytorch.global.torch 类中的静态方法实现。


import org.bytedeco.pytorch.*;
import org.bytedeco.pytorch.global.torch;

import static org.bytedeco.pytorch.global.torch.*;

public class TensorDemo {
public static void main(String[] args) {
// 1. 创建一个未初始化的 5x3 矩阵
Tensor x1 = torch.empty(new long[]{53}); // 默认类型
System.out.println("Empty Tensor:\n" + x1);

// 2. 创建一个随机初始化的 5x3 矩阵
Tensor x2 = torch.rand(new long[]{53});
System.out.println("Random Tensor:\n" + x2);

// 3. 创建一个全为 0,数据类型为 Long 的矩阵
Tensor x3 = torch.zeros(new long[]{53}, torch.dtype(kLong()));
System.out.println("Zeros Tensor:\n" + x3);

// 4. 直接使用数据初始化 (Java 数组转 Tensor)
Tensor x4 = torch.tensor(new float[]{5.5f3f});
System.out.println("Data Tensor:\n" + x4);
}
}

Image

3. 定义神经网络模型

在 JavaCPP 中定义模型需要继承 Module 类,并手动注册子模块(使用 register_module)。


import static org.bytedeco.pytorch.global.torch.*;
import org.bytedeco.pytorch.*;
import org.bytedeco.pytorch.Module; // 注意别导错包
import static org.bytedeco.pytorch.global.torch.*;

// 定义网络结构
class Net extends Module {
// 定义层
private Conv2dImpl conv1, conv2;
private LinearImpl fc1, fc2, fc3;

public Net() {
// 1. 初始化并注册层 (注意:JavaCPP 中使用 Options 对象配置参数)
        // Conv2d(输入通道, 输出通道, 卷积核大小)
conv1 = register_module("conv1"new Conv2dImpl(new Conv2dOptions(163)));
conv2 = register_module("conv2"new Conv2dImpl(new Conv2dOptions(6163)));

// Linear(输入特征数, 输出特征数)
        // 16*6*6 是根据输入图像大小推算出的展平后的特征数
fc1 = register_module("fc1"new LinearImpl(new LinearOptions(16 * 6 * 6120)));
fc2 = register_module("fc2"new LinearImpl(new LinearOptions(12084)));
fc3 = register_module("fc3"new LinearImpl(new LinearOptions(8410)));
}

// 前向传播
public Tensor forward(Tensor x) {
// 第一层卷积 -> ReLU -> 2x2 最大池化
x = max_pool2d(relu(conv1.forward(x)), new long[]{22});

// 第二层卷积 -> ReLU -> 2x2 最大池化
x = max_pool2d(relu(conv2.forward(x)), new long[]{22});

// 展平张量 (flatten),-1 表示自动推导 batch 维度
x = x.view(new long[]{-116 * 6 * 6});

// 全连接层 -> ReLU
x = relu(fc1.forward(x));
x = relu(fc2.forward(x));

// 输出层
x = fc3.forward(x);
return x;
}
}

4. 运行模型(Main 方法)

最后,我们将一切串联起来,创建一个网络实例并进行一次前向计算。



import org.bytedeco.pytorch.Adam;
import org.bytedeco.pytorch.BCELossImpl;
import org.bytedeco.pytorch.SGD;
import org.bytedeco.pytorch.Tensor;
import org.bytedeco.pytorch.global.torch;

public class Main {
public static void main(String[] args) {
// 实例化网络
Net net = new Net();
System.out.println("Network structure initialized.");

// 创建一个模拟输入:1张图像,1个通道,32x32 分辨率
        // 注意 模拟数据 代码中 6x6 的特征图推导通常对应 32x32 的输入
Tensor input = torch.rand(new long[]{113232});

Tensor target = torch.rand(new long[]{11}); // 模拟二分类标签

Adam optimizer = new Adam(net.parameters());
BCELossImpl lossFn = new BCELossImpl();
// 前向传播
Tensor output = net.forward(input);
optimizer.zero_grad();
var loss = lossFn.forward(output, target);
optimizer.step();
System.out.println("Initial loss: " + loss.item_double());

System.out.println("Output Tensor:");
System.out.println(output);

// 打印输出形状
System.out.println("Output sizes: " + java.util.Arrays.toString(output.sizes().vec().get()));
}
}

赞赏二维码****