如何用Java 写 全真的神经网络 一直是java 众多程序员的梦想,奈何你们寻不到真经,被各种玩具框架 妖魔鬼怪 挟持认知,一个个都以为必须用jni 调python 要么就是冷言嘲讽java不适合做神经网络,在ai是的已经被淘汰了? 果真如此,当然是假的了,反而是 java 在AI时代有媲美Apache spark flink 神级 框架的存在,它就是pytorch
编辑版本将使用 JavaCPP Presets for PyTorch。JavaCPP 提供了 PyTorch C++ API(LibTorch)的直接映射,因此代码风格会非常接近 C++ 版的 LibTorch,但运行在 JVM 上,注意是几百万行代码的全量编译!!!
注意 不是 Pytorch 官方支持 java 版本,也不是 java Oracle支持 Pytorch,而是ByteDeco 旗下的Javacpp 支持 PyTorch ,Pytorch官方基金会在java 的支持上只限于 andriod ,其他都非常拉胯!!! 吃水不忘挖井人,你如果要感谢的话,一定要感谢 Bytedeco 这个伟大的天才开源组织
以下是针对 javacpp-pytorch 2.1.0-1.5.13 版本的完整指南。
1. Maven 配置
首先,在你的 pom.xml 中引入依赖。pytorch-platform 会自动根据你的操作系统下载对应的本地库(包含 CPU 版本,如需 GPU 需额外配置)。
<dependencies>
<dependency>
<groupId>org.bytedeco</groupId>
<artifactId>pytorch-platform</artifactId>
<version>2.1.0-1.5.13</version>
</dependency>
</dependencies>
2. 张量(Tensor)操作
在 JavaCPP 中,Tensor 的操作主要通过 org.bytedeco.pytorch.global.torch 类中的静态方法实现。
import org.bytedeco.pytorch.*;
import org.bytedeco.pytorch.global.torch;
import static org.bytedeco.pytorch.global.torch.*;
public class TensorDemo {
public static void main(String[] args) {
// 1. 创建一个未初始化的 5x3 矩阵
Tensor x1 = torch.empty(new long[]{5, 3}); // 默认类型
System.out.println("Empty Tensor:\n" + x1);
// 2. 创建一个随机初始化的 5x3 矩阵
Tensor x2 = torch.rand(new long[]{5, 3});
System.out.println("Random Tensor:\n" + x2);
// 3. 创建一个全为 0,数据类型为 Long 的矩阵
Tensor x3 = torch.zeros(new long[]{5, 3}, torch.dtype(kLong()));
System.out.println("Zeros Tensor:\n" + x3);
// 4. 直接使用数据初始化 (Java 数组转 Tensor)
Tensor x4 = torch.tensor(new float[]{5.5f, 3f});
System.out.println("Data Tensor:\n" + x4);
}
}
3. 定义神经网络模型
在 JavaCPP 中定义模型需要继承 Module 类,并手动注册子模块(使用 register_module)。
import static org.bytedeco.pytorch.global.torch.*;
import org.bytedeco.pytorch.*;
import org.bytedeco.pytorch.Module; // 注意别导错包
import static org.bytedeco.pytorch.global.torch.*;
// 定义网络结构
class Net extends Module {
// 定义层
private Conv2dImpl conv1, conv2;
private LinearImpl fc1, fc2, fc3;
public Net() {
// 1. 初始化并注册层 (注意:JavaCPP 中使用 Options 对象配置参数)
// Conv2d(输入通道, 输出通道, 卷积核大小)
conv1 = register_module("conv1", new Conv2dImpl(new Conv2dOptions(1, 6, 3)));
conv2 = register_module("conv2", new Conv2dImpl(new Conv2dOptions(6, 16, 3)));
// Linear(输入特征数, 输出特征数)
// 16*6*6 是根据输入图像大小推算出的展平后的特征数
fc1 = register_module("fc1", new LinearImpl(new LinearOptions(16 * 6 * 6, 120)));
fc2 = register_module("fc2", new LinearImpl(new LinearOptions(120, 84)));
fc3 = register_module("fc3", new LinearImpl(new LinearOptions(84, 10)));
}
// 前向传播
public Tensor forward(Tensor x) {
// 第一层卷积 -> ReLU -> 2x2 最大池化
x = max_pool2d(relu(conv1.forward(x)), new long[]{2, 2});
// 第二层卷积 -> ReLU -> 2x2 最大池化
x = max_pool2d(relu(conv2.forward(x)), new long[]{2, 2});
// 展平张量 (flatten),-1 表示自动推导 batch 维度
x = x.view(new long[]{-1, 16 * 6 * 6});
// 全连接层 -> ReLU
x = relu(fc1.forward(x));
x = relu(fc2.forward(x));
// 输出层
x = fc3.forward(x);
return x;
}
}
4. 运行模型(Main 方法)
最后,我们将一切串联起来,创建一个网络实例并进行一次前向计算。
import org.bytedeco.pytorch.Adam;
import org.bytedeco.pytorch.BCELossImpl;
import org.bytedeco.pytorch.SGD;
import org.bytedeco.pytorch.Tensor;
import org.bytedeco.pytorch.global.torch;
public class Main {
public static void main(String[] args) {
// 实例化网络
Net net = new Net();
System.out.println("Network structure initialized.");
// 创建一个模拟输入:1张图像,1个通道,32x32 分辨率
// 注意 模拟数据 代码中 6x6 的特征图推导通常对应 32x32 的输入
Tensor input = torch.rand(new long[]{1, 1, 32, 32});
Tensor target = torch.rand(new long[]{1, 1}); // 模拟二分类标签
Adam optimizer = new Adam(net.parameters());
BCELossImpl lossFn = new BCELossImpl();
// 前向传播
Tensor output = net.forward(input);
optimizer.zero_grad();
var loss = lossFn.forward(output, target);
optimizer.step();
System.out.println("Initial loss: " + loss.item_double());
System.out.println("Output Tensor:");
System.out.println(output);
// 打印输出形状
System.out.println("Output sizes: " + java.util.Arrays.toString(output.sizes().vec().get()));
}
}