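Today we want to talk about some recent breakthroughs. What model file formats are in use today? The most basic ones are zip, PMML, JSON, bin, pickle, pt, pth and npz. Around PyTorch there are also pt2, pte, AOT, TorchScript, GGUF, GGML and others, and PyTorch on Java can already load all of them correctly. For large models, though, the most popular format today is Hugging Face's safetensors. The format itself is not complicated: it is essentially a JSON header describing the tensors, followed by the raw tensor bytes. We already implemented reading and writing safetensors last year. This year we want to go deeper: load safetensors models directly in Java and run inference and fine-tuning on them. If that works, we effectively have a Java version of transformers. There are several hard parts: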
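At the byte level, a safetensors file is laid out as follows:

- an unsigned little-endian 64-bit header length N;
- N bytes of UTF-8 JSON;
- the tensor data. Each tensor's `data_offsets` in the JSON are relative to the start of this data section.

The sketch below locates the header and the data section using only the JDK. `SafetensorsHeader` is a hypothetical helper name. Parsing the JSON itself is left out; Gson, which the config class below uses, would do.

```java
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.charset.StandardCharsets;

/** Hypothetical helper: locate the JSON header and data section of a safetensors file held in memory. */
public final class SafetensorsHeader {
    /** Returns the JSON header text. */
    public static String readHeaderJson(byte[] file) {
        long n = headerLength(file);
        return new String(file, 8, (int) n, StandardCharsets.UTF_8);
    }

    /** Absolute offset where tensor data starts; data_offsets in the header are relative to it. */
    public static long dataSectionStart(byte[] file) {
        return 8L + headerLength(file);
    }

    private static long headerLength(byte[] file) {
        if (file.length < 8) throw new IllegalArgumentException("file too short");
        // First 8 bytes: header length as a little-endian u64
        long n = ByteBuffer.wrap(file).order(ByteOrder.LITTLE_ENDIAN).getLong(0);
        if (n < 0 || n > file.length - 8L) {
            throw new IllegalArgumentException("bad header length: " + n);
        }
        return n;
    }
}
```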
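1. Reading the file is doable, but with hundreds of millions of parameters spread over hundreds of layers, how do you arrange and organize them?
2. Loading is easy, but then how do you use the result? Is it an array or a tensor? How do you organize these weights and turn them into a model, and what kind of model?
3. How do you run inference with the loaded model?
4. For fine-tuning, should the model extend Module or JitModule?
5. Memory. Even a 0.5B LLM takes at least 800 MB, a 2B model about 4 GB, and a 32B model about 60 GB. Without JVM tuning, the JVM crashes at around 2 GB, and loading a 4 GB model into a 4 GB heap in one go is also error-prone. How should loading work? Can a safetensors file be turned directly into a javacpp-pytorch Tensor without an intermediate Java array, using zero-copy or direct buffers?
6. The implementation should stay as close as possible to the Python transformers design: extensible, reusable, and as simple as possible.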
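For point 5, the JDK already offers an off-heap, no-copy path. `FileChannel.map` returns a `MappedByteBuffer`, a direct buffer over a memory-mapped region of the file rather than an array on the Java heap. A 4 GB model therefore does not need a 4 GB heap. One caveat: a single mapping is capped at `Integer.MAX_VALUE` bytes (about 2 GB), so mapping tensor by tensor is simpler than mapping the whole file.

The sketch below assumes the caller already knows two things: the data-section start (8 + header length) and the tensor's `data_offsets` from the header. `TensorMapper` is a hypothetical name. Handing the buffer's address to a native tensor, for example through JavaCPP, is not shown.

```java
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.MappedByteBuffer;
import java.nio.channels.FileChannel;
import java.nio.file.Path;
import java.nio.file.StandardOpenOption;

/** Hypothetical helper: map one tensor's bytes from a safetensors file without copying them onto the heap. */
public final class TensorMapper {
    /**
     * @param dataStart 8 + header length (start of the data section)
     * @param begin     data_offsets[0] from the header, relative to dataStart
     * @param end       data_offsets[1] from the header, exclusive
     */
    public static ByteBuffer mapTensor(Path file, long dataStart, long begin, long end) throws IOException {
        long size = end - begin;
        // One mapping cannot exceed Integer.MAX_VALUE bytes, so map per tensor, not per file
        if (size < 0 || size > Integer.MAX_VALUE) {
            throw new IllegalArgumentException("tensor too large for one mapping: " + size);
        }
        try (FileChannel ch = FileChannel.open(file, StandardOpenOption.READ)) {
            MappedByteBuffer mb = ch.map(FileChannel.MapMode.READ_ONLY, dataStart + begin, size);
            // The mapping stays valid after the channel is closed; safetensors data is little-endian
            return mb.order(ByteOrder.LITTLE_ENDIAN);
        }
    }
}
```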
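With these goals in mind, we actually went and built it. Building on several months of research last year and this year's AI-assisted programming, we got it working. We tried loading and running inference on the Java side with three powerful models: qwen3-vl, qwen3.5-vl-embedding and jina-vl-embedding-v4. Strictly speaking, this was an experiment as well as a breakthrough. It was very successful and gave us a lot of confidence, and we plan to invest more effort in this direction.
First, let's look at the run log: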
```console
╔═══════════════════════════════════════════════════════════════════╗
║ Qwen3-VL-2B-Instruct Full Test V2 ║
╚════════════════════════════════════════════════════════════════════╝
Cache directory: /Users/mullerzhang/IdeaProjects/lanceScala/./cache_qwen3vl_instruct
[Stage 1] Download and load model...
WARNING: A restricted method in java.lang.System has been called
WARNING: java.lang.System::load has been called by org.bytedeco.javacpp.Loader in an unnamed module (file:/Users/mullerzhang/Library/Caches/Coursier/v1/https/repo1.maven.org/maven2/org/bytedeco/javacpp/1.5.13/javacpp-1.5.13.jar)
WARNING: Use --enable-native-access=ALL-UNNAMED to avoid a warning for callers in this module
WARNING: Restricted methods will be blocked in a future release unless native access is enabled
[Device] Using MPS (Apple GPU)
======================================================================
[Step 1] Download config files
======================================================================
✓ config.json → Qwen_Qwen3-VL-2B-Instruct_config.json
✓ tokenizer.json → Qwen_Qwen3-VL-2B-Instruct_tokenizer.json
✓ tokenizer_config.json → Qwen_Qwen3-VL-2B-Instruct_tokenizer_config.json
✓ merges.txt → Qwen_Qwen3-VL-2B-Instruct_merges.txt
✓ vocab.json → Qwen_Qwen3-VL-2B-Instruct_vocab.json
✓ preprocessor_config.json → Qwen_Qwen3-VL-2B-Instruct_preprocessor_config.json
⚠ chat_template.jinja download failed: Failed to download https://hf-mirror.com/Qwen/Qwen3-VL-2B-Instruct/resolve/main/chat_template.jinja HTTP 404
⚠ special_tokens_map.json download failed: Failed to download https://hf-mirror.com/Qwen/Qwen3-VL-2B-Instruct/resolve/main/special_tokens_map.json HTTP 404
======================================================================
[Step 2] Download model weights
======================================================================
[ModelFetcher] Optional file not found, skip: https://hf-mirror.com/Qwen/Qwen3-VL-2B-Instruct/resolve/main/model.safetensors.index.json HTTP 404
✓ 1 safetensors file in total
======================================================================
[Step 3] Parse model config
======================================================================
Qwen3VLInstructConfig{
text: hidden=2048, layers=28, heads=16/8, headDim=128, inter=6144, vocab=151936
rope: eps=1e-06, theta=5000000, type=default, interleaved=true, mrope=[24, 20, 20]
tokens: bos=151643, eos=151645, img=151655, vid=151656, vis_start=151652, vis_end=151653
vision: depth=24, hidden=1024, heads=16, patch=16, merge=2, out=2048
deepstack=[5, 11, 17], tieEmbed=true
}
======================================================================
[Step 4] Load tokenizer
======================================================================
Loading tokenizer: ./cache_qwen3vl_instruct/Qwen_Qwen3-VL-2B-Instruct_tokenizer.json
SLF4J(W): Class path contains multiple SLF4J providers.
SLF4J(W): Found provider [org.slf4j.impl.JBossSlf4jServiceProvider@44afefd5]
SLF4J(W): Found provider [ch.qos.logback.classic.spi.LogbackServiceProvider@9a7a808]
SLF4J(W): See https://www.slf4j.org/codes.html#multiple_bindings for an explanation.
SLF4J(I): Actual provider is of type [org.slf4j.impl.JBossSlf4jServiceProvider@44afefd5]
Verify: "Hello, world! 你好世界" → 7 tokens → "Hello, world! 你好世界"
======================================================================
[Step 5] Load model weights (zero-copy)
======================================================================
Loading: Qwen_Qwen3-VL-2B-Instruct_model.safetensors
[TorchOps] BF16 Zero-copy probe result: false
Progress: 100/625
Progress: 200/625
(large tensor skipped)[safe mode] model.language_model.embed_tokens.weight size=622329856 bytes - set LANCE_ALLOW_LARGE_TENSORS=1 to force load
Progress: 300/625
Progress: 400/625
Progress: 500/625
Progress: 600/625
✓ Loaded: 624/625 weights (8268ms)
⚠ Failed: 1 weight:
✗ [TEXT] model.language_model.embed_tokens.weight (skipped-large)
Failures by category: text=1 vision=0 other=0
======================================================================
[Step 6] Build model
======================================================================
[Qwen3VLInstruct] Model initialized
Architecture: Qwen3VLForConditionalGeneration
Layers: 28
Hidden size: 2048
Attention heads: 16 Q, 8 KV
Vocabulary: 151936
Weights: 624
✓ Qwen3-VL-2B-Instruct model ready
✓ Model loaded (24137ms)
======================================================================
[Test 1] Tokenizer encode/decode
======================================================================
✓ "Hello, world!" → 4 tokens → "Hello, world!"
✓ "你好,世界!" → 4 tokens → "你好,世界!"
✓ "What is artificial intelligence?" → 5 tokens → "What is artificial intelligence?"
✓ "2 + 2 = 4" → 7 tokens → "2 + 2 = 4"
✓ "The quick brown fox jumps over the lazy dog." → 10 tokens → "The quick brown fox jumps over the lazy dog."
✓ "<|im_start|>system
You are a helpful assistant.<|i..." → 10 tokens → "system
You are a helpful assistant."
Result: 6/6 passed
======================================================================
[Test 2] Weight check
======================================================================
Total weights: 624
--- First 20 weight names ---
model.language_model.layers.18.input_layernorm.weight [2048] org.bytedeco.pytorch.TypeMeta[address=0xaf51f9830,position=0,limit=1,capacity=1,deallocator=org.bytedeco.javacpp.Pointer$NativeDeallocator[ownerAddress=0xaf51f9830,deallocatorAddress=0x16f6136e0]]
model.language_model.layers.18.mlp.gate_proj.weight [6144, 2048] org.bytedeco.pytorch.TypeMeta[address=0xaf51f9840,position=0,limit=1,capacity=1,deallocator=org.bytedeco.javacpp.Pointer$NativeDeallocator[ownerAddress=0xaf51f9840,deallocatorAddress=0x16f6136e0]]
model.language_model.layers.20.self_attn.k_norm.weight [128] org.bytedeco.pytorch.TypeMeta[address=0xaf51f9850,position=0,limit=1,capacity=1,deallocator=org.bytedeco.javacpp.Pointer$NativeDeallocator[ownerAddress=0xaf51f9850,deallocatorAddress=0x16f6136e0]]
model.visual.blocks.18.mlp.linear_fc1.weight [4096, 1024] org.bytedeco.pytorch.TypeMeta[address=0xaf51f9860,position=0,limit=1,capacity=1,deallocator=org.bytedeco.javacpp.Pointer$NativeDeallocator[ownerAddress=0xaf51f9860,deallocatorAddress=0x16f6136e0]]
model.visual.deepstack_merger_list.0.norm.bias [4096] org.bytedeco.pytorch.TypeMeta[address=0xaf51f9870,position=0,limit=1,capacity=1,deallocator=org.bytedeco.javacpp.Pointer$NativeDeallocator[ownerAddress=0xaf51f9870,deallocatorAddress=0x16f6136e0]]
model.visual.patch_embed.proj.weight [1024, 3, 2, 16, 16] org.bytedeco.pytorch.TypeMeta[address=0xaf51f98b0,position=0,limit=1,capacity=1,deallocator=org.bytedeco.javacpp.Pointer$NativeDeallocator[ownerAddress=0xaf51f98b0,deallocatorAddress=0x16f6136e0]]
model.language_model.layers.24.post_attention_layernorm.weight [2048] org.bytedeco.pytorch.TypeMeta[address=0xaf51f9880,position=0,limit=1,capacity=1,deallocator=org.bytedeco.javacpp.Pointer$NativeDeallocator[ownerAddress=0xaf51f9880,deallocatorAddress=0x16f6136e0]]
model.language_model.layers.25.self_attn.q_norm.weight [128] org.bytedeco.pytorch.TypeMeta[address=0xaf51f9890,position=0,limit=1,capacity=1,deallocator=org.bytedeco.javacpp.Pointer$NativeDeallocator[ownerAddress=0xaf51f9890,deallocatorAddress=0x16f6136e0]]
model.visual.blocks.17.norm1.bias [1024] org.bytedeco.pytorch.TypeMeta[address=0xaf51f98a0,position=0,limit=1,capacity=1,deallocator=org.bytedeco.javacpp.Pointer$NativeDeallocator[ownerAddress=0xaf51f98a0,deallocatorAddress=0x16f6136e0]]
model.visual.merger.linear_fc2.bias [2048] org.bytedeco.pytorch.TypeMeta[address=0xaf51f9730,position=0,limit=1,capacity=1,deallocator=org.bytedeco.javacpp.Pointer$NativeDeallocator[ownerAddress=0xaf51f9730,deallocatorAddress=0x16f6136e0]]
model.visual.blocks.5.mlp.linear_fc2.bias [1024] org.bytedeco.pytorch.TypeMeta[address=0xaf51f9720,position=0,limit=1,capacity=1,deallocator=org.bytedeco.javacpp.Pointer$NativeDeallocator[ownerAddress=0xaf51f9720,deallocatorAddress=0x16f6136e0]]
model.language_model.layers.19.mlp.down_proj.weight [2048, 6144] org.bytedeco.pytorch.TypeMeta[address=0xaf51f9710,position=0,limit=1,capacity=1,deallocator=org.bytedeco.javacpp.Pointer$NativeDeallocator[ownerAddress=0xaf51f9710,deallocatorAddress=0x16f6136e0]]
model.language_model.layers.0.input_layernorm.weight [2048] org.bytedeco.pytorch.TypeMeta[address=0xaf51f9700,position=0,limit=1,capacity=1,deallocator=org.bytedeco.javacpp.Pointer$NativeDeallocator[ownerAddress=0xaf51f9700,deallocatorAddress=0x16f6136e0]]
model.language_model.layers.20.self_attn.q_norm.weight [128] org.bytedeco.pytorch.TypeMeta[address=0xaf51f96f0,position=0,limit=1,capacity=1,deallocator=org.bytedeco.javacpp.Pointer$NativeDeallocator[ownerAddress=0xaf51f96f0,deallocatorAddress=0x16f6136e0]]
model.language_model.layers.26.self_attn.k_proj.weight [1024, 2048] org.bytedeco.pytorch.TypeMeta[address=0xaf51f96e0,position=0,limit=1,capacity=1,deallocator=org.bytedeco.javacpp.Pointer$NativeDeallocator[ownerAddress=0xaf51f96e0,deallocatorAddress=0x16f6136e0]]
model.visual.blocks.4.attn.proj.weight [1024, 1024] org.bytedeco.pytorch.TypeMeta[address=0xaf51f96d0,position=0,limit=1,capacity=1,deallocator=org.bytedeco.javacpp.Pointer$NativeDeallocator[ownerAddress=0xaf51f96d0,deallocatorAddress=0x16f6136e0]]
model.visual.blocks.18.attn.proj.weight [1024, 1024] org.bytedeco.pytorch.TypeMeta[address=0xaf51f96c0,position=0,limit=1,capacity=1,deallocator=org.bytedeco.javacpp.Pointer$NativeDeallocator[ownerAddress=0xaf51f96c0,deallocatorAddress=0x16f6136e0]]
model.visual.blocks.22.norm2.weight [1024] org.bytedeco.pytorch.TypeMeta[address=0xaf51f96b0,position=0,limit=1,capacity=1,deallocator=org.bytedeco.javacpp.Pointer$NativeDeallocator[ownerAddress=0xaf51f96b0,deallocatorAddress=0x16f6136e0]]
model.language_model.layers.24.mlp.up_proj.weight [6144, 2048] org.bytedeco.pytorch.TypeMeta[address=0xaf51f96a0,position=0,limit=1,capacity=1,deallocator=org.bytedeco.javacpp.Pointer$NativeDeallocator[ownerAddress=0xaf51f96a0,deallocatorAddress=0x16f6136e0]]
model.visual.blocks.18.mlp.linear_fc2.bias [1024] org.bytedeco.pytorch.TypeMeta[address=0xaf51f9690,position=0,limit=1,capacity=1,deallocator=org.bytedeco.javacpp.Pointer$NativeDeallocator[ownerAddress=0xaf51f9690,deallocatorAddress=0x16f6136e0]]
... (624 total)
✗ model.language_model.embed_tokens.weight (missing)
✓ model.language_model.norm.weight [2048]
✓ model.language_model.layers.0.self_attn.q_proj.weight [2048, 2048]
✓ model.language_model.layers.0.self_attn.k_proj.weight [1024, 2048]
✓ model.language_model.layers.0.self_attn.v_proj.weight [1024, 2048]
✓ model.language_model.layers.0.self_attn.o_proj.weight [2048, 2048]
✓ model.language_model.layers.0.self_attn.q_norm.weight [128]
✓ model.language_model.layers.0.self_attn.k_norm.weight [128]
✓ model.language_model.layers.0.mlp.gate_proj.weight [6144, 2048]
✓ model.language_model.layers.0.mlp.up_proj.weight [6144, 2048]
✓ model.language_model.layers.0.mlp.down_proj.weight [2048, 6144]
✓ model.language_model.layers.0.input_layernorm.weight [2048]
✓ model.language_model.layers.0.post_attention_layernorm.weight [2048]
Max layer index: 27 (expected: 27)
Result: 12/13 key text weights present
--- Vision module weights ---
✓ model.visual.patch_embed.proj.weight [1024, 3, 2, 16, 16]
✓ model.visual.blocks.0.attn.qkv.weight [3072, 1024]
✓ model.visual.blocks.0.attn.proj.weight [1024, 1024]
✓ model.visual.blocks.0.mlp.linear_fc1.weight [4096, 1024]
✓ model.visual.blocks.0.mlp.linear_fc2.weight [1024, 4096]
✓ model.visual.blocks.0.norm1.weight [1024]
✓ model.visual.blocks.0.norm2.weight [1024]
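Vision layers: 24 (expected: 24)
Result: 7/7 key vision weights present
⚠ Model/weights detected as currently unsafe for local native inference (may trigger a native crash).
Suggestion: fix the BF16/from_blob/Device paths in lance.pytorch.TorchOps and TensorDataTorchBridge, or run in a stronger environment (MPS/GPU).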
```
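Two numbers in this log can be checked by hand.

- **The skipped tensor's size.** `embed_tokens.weight` has shape [vocab_size, hidden_size] = [151936, 2048] at 2 bytes per element. That gives 151936 × 2048 × 2 = 622,329,856 bytes, exactly the size reported in safe mode.
- **The BF16 probe.** The BF16 zero-copy probe returned false, so BF16 data presumably has to be widened before use. A BF16 value is just the upper 16 bits of an IEEE-754 float32, so widening is a 16-bit shift.

A JDK-only sketch of both follows. `Bf16Util` and its methods are hypothetical names.

```java
import java.nio.ByteBuffer;
import java.nio.ByteOrder;

/** Hypothetical helpers for sizing tensors and widening BF16 data to float32. */
public final class Bf16Util {
    /** Bytes occupied by a tensor of the given shape and element size (overflow-checked). */
    public static long byteSize(long[] shape, int bytesPerElement) {
        long n = 1;
        for (long d : shape) n = Math.multiplyExact(n, d);
        return Math.multiplyExact(n, (long) bytesPerElement);
    }

    /** BF16 is the high half of a float32, so shift the 16 bits into the top of an int. */
    public static float bf16ToFloat(short bits) {
        return Float.intBitsToFloat((bits & 0xFFFF) << 16);
    }

    /** Widen a little-endian BF16 buffer (for example a mapped tensor slice) into a float array. */
    public static float[] widen(ByteBuffer bf16) {
        ByteBuffer le = bf16.duplicate().order(ByteOrder.LITTLE_ENDIAN);
        float[] out = new float[le.remaining() / 2];
        for (int i = 0; i < out.length; i++) {
            out[i] = bf16ToFloat(le.getShort());
        }
        return out;
    }
}
```

Note that widening BF16 to float32 doubles the memory footprint. The same embedding table takes about 1.2 GB once converted.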
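Now for the code.
Loading a safetensors model requires at least four modules: a model, a config, a loader and a tokenizer. For multimodal models and future extensions you also need a processor and a pipeline. All of these are modeled closely on Python transformers. Here is the concrete code: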
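One way to get the extensibility of transformers' `AutoModel` is to dispatch on the `architectures` field of config.json. The log above prints `Qwen3VLForConditionalGeneration` for this model. The sketch below is a hypothetical registry, not the project's actual loader. The config and model types are generic so that it compiles without javacpp-pytorch.

```java
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.Function;

/** Hypothetical AutoModel-style registry keyed by config.json "architectures" names. */
public final class ModelRegistry<C, M> {
    private final Map<String, Function<C, M>> factories = new ConcurrentHashMap<>();

    /** Register a factory that builds a model of one architecture from a parsed config. */
    public ModelRegistry<C, M> register(String architecture, Function<C, M> factory) {
        factories.put(architecture, factory);
        return this;
    }

    /** Build a model for the first listed architecture that has a registered factory. */
    public M create(Iterable<String> architectures, C config) {
        for (String arch : architectures) {
            Function<C, M> f = factories.get(arch);
            if (f != null) return f.apply(config);
        }
        throw new IllegalArgumentException("no model registered for " + architectures);
    }
}
```

Adding a new model family then means registering one more factory, without touching the loader.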
```java
package lance.pytorch;
import com.google.gson.*;
import java.io.*;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
/**
* Qwen3-VL-2B-Instruct model configuration.
* Parses config.json and provides type-safe accessors.
*/
public class Qwen3VLInstructConfig {
private final JsonObject root;
private final JsonObject textConfig;
private final JsonObject visionConfig;
// Text model params
public final int hiddenSize;
public final int numHiddenLayers;
public final int numAttentionHeads;
public final int numKeyValueHeads;
public final int headDim;
public final int intermediateSize;
public final int vocabSize;
public final double rmsNormEps;
public final double ropeTheta;
public final int[] mropeSections; // e.g. [24, 20, 20]
public final boolean tieWordEmbeddings;
public final int maxPositionEmbeddings;
public final int bosTokenId;
public int eosTokenId;
public final int imageTokenId;
public final int videoTokenId;
public final int visionStartTokenId;
public final int visionEndTokenId;
// Rope
public final String ropeType;
public final boolean mropeInterleaved;
// Vision model params
public final int visionDepth;
public final int visionHiddenSize;
public final int visionNumHeads;
public final int visionIntermediateSize;
public final int patchSize;
public final int spatialMergeSize;
public final int temporalPatchSize;
public final int outHiddenSize;
public final int inChannels;
public final int[] deepstackVisualIndexes;
public Qwen3VLInstructConfig(Path configPath) throws IOException {
String content = Files.readString(configPath, StandardCharsets.UTF_8);
root = JsonParser.parseString(content).getAsJsonObject();
textConfig = root.has("text_config") ? root.getAsJsonObject("text_config") : root;
visionConfig = root.has("vision_config") ? root.getAsJsonObject("vision_config") : new JsonObject();
// Parse text config
hiddenSize = getInt(textConfig, "hidden_size", 2048);
numHiddenLayers = getInt(textConfig, "num_hidden_layers", 28);
numAttentionHeads = getInt(textConfig, "num_attention_heads", 16);
numKeyValueHeads = getInt(textConfig, "num_key_value_heads", 8);
headDim = getInt(textConfig, "head_dim", 128);
intermediateSize = getInt(textConfig, "intermediate_size", 6144);
vocabSize = getInt(textConfig, "vocab_size", 151936);
rmsNormEps = getDouble(textConfig, "rms_norm_eps", 1e-6);
ropeTheta = getDouble(textConfig, "rope_theta", 5000000.0);
maxPositionEmbeddings = getInt(textConfig, "max_position_embeddings", 262144);
tieWordEmbeddings = getBool(root, "tie_word_embeddings", true);
bosTokenId = getInt(textConfig, "bos_token_id", 151643);
eosTokenId = getInt(textConfig, "eos_token_id", 151645);
imageTokenId = getInt(root, "image_token_id", 151655);
videoTokenId = getInt(root, "video_token_id", 151656);
visionStartTokenId = getInt(root, "vision_start_token_id", 151652);
visionEndTokenId = getInt(root, "vision_end_token_id", 151653);
// Parse rope_scaling
mropeSections = parseMropeSections(textConfig);
ropeType = parseRopeType(textConfig);
mropeInterleaved = parseRopeInterleaved(textConfig);
// Parse vision config
visionDepth = getInt(visionConfig, "depth", 24);
visionHiddenSize = getInt(visionConfig, "hidden_size", 1024);
visionNumHeads = getInt(visionConfig, "num_heads", 16);
visionIntermediateSize = getInt(visionConfig, "intermediate_size", 4096);
patchSize = getInt(visionConfig, "patch_size", 16);
spatialMergeSize = getInt(visionConfig, "spatial_merge_size", 2);
temporalPatchSize = getInt(visionConfig, "temporal_patch_size", 2);
outHiddenSize = getInt(visionConfig, "out_hidden_size", 2048);
inChannels = getInt(visionConfig, "in_channels", 3);
deepstackVisualIndexes = parseIntArray(visionConfig, "deepstack_visual_indexes", new int[]{5, 11, 17});
}
public void setEosTokenId(long id) {
this.eosTokenId = (int) id;
}
private int[] parseMropeSections(JsonObject tc) {
try {
if (tc.has("rope_scaling")) {
JsonObject ropeScaling = tc.getAsJsonObject("rope_scaling");
if (ropeScaling != null && ropeScaling.has("mrope_section")) {
JsonArray arr = ropeScaling.getAsJsonArray("mrope_section");
int[] s = new int[arr.size()];
for (int i = 0; i < arr.size(); i++) s[i] = arr.get(i).getAsInt();
return s;
}
}
} catch (Exception ignore) {}
return new int[]{24, 20, 20};
}
private String parseRopeType(JsonObject tc) {
try {
JsonObject ropeScaling = tc.getAsJsonObject("rope_scaling");
if (ropeScaling != null && ropeScaling.has("rope_type")) {
return ropeScaling.get("rope_type").getAsString();
}
} catch (Exception ignore) {}
return "default";
}
private boolean parseRopeInterleaved(JsonObject tc) {
try {
JsonObject ropeScaling = tc.getAsJsonObject("rope_scaling");
if (ropeScaling != null && ropeScaling.has("mrope_interleaved")) {
return ropeScaling.get("mrope_interleaved").getAsBoolean();
}
} catch (Exception ignore) {}
return true;
}
private int[] parseIntArray(JsonObject obj, String key, int[] def) {
try {
if (obj.has(key)) {
JsonArray arr = obj.getAsJsonArray(key);
int[] result = new int[arr.size()];
for (int i = 0; i < arr.size(); i++) result[i] = arr.get(i).getAsInt();
return result;
}
} catch (Exception ignore) {}
return def;
}
private static int getInt(JsonObject obj, String key, int def) {
try {
if (obj.has(key)) return obj.get(key).getAsInt();
} catch (Exception ignore) {}
return def;
}
private static double getDouble(JsonObject obj, String key, double def) {
try {
if (obj.has(key)) return obj.get(key).getAsDouble();
} catch (Exception ignore) {}
return def;
}
private static boolean getBool(JsonObject obj, String key, boolean def) {
try {
if (obj.has(key)) return obj.get(key).getAsBoolean();
} catch (Exception ignore) {}
return def;
}
@Override
public String toString() {
return String.format(
"Qwen3VLInstructConfig{\n" +
" text: hidden=%d, layers=%d, heads=%d/%d, headDim=%d, inter=%d, vocab=%d\n" +
" rope: eps=%.0e, theta=%.0f, type=%s, interleaved=%b, mrope=%s\n" +
" tokens: bos=%d, eos=%d, img=%d, vid=%d, vis_start=%d, vis_end=%d\n" +
" vision: depth=%d, hidden=%d, heads=%d, patch=%d, merge=%d, out=%d\n" +
" deepstack=%s, tieEmbed=%b\n" +
"}",
hiddenSize, numHiddenLayers, numAttentionHeads, numKeyValueHeads, headDim,
intermediateSize, vocabSize,
rmsNormEps, ropeTheta, ropeType, mropeInterleaved,
java.util.Arrays.toString(mropeSections),
bosTokenId, eosTokenId, imageTokenId, videoTokenId,
visionStartTokenId, visionEndTokenId,
visionDepth, visionHiddenSize, visionNumHeads, patchSize, spatialMergeSize, outHiddenSize,
java.util.Arrays.toString(deepstackVisualIndexes), tieWordEmbeddings
);
}
}
```
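The config values determine the weight shapes the loader should see, which gives a cheap sanity check before the model is built. This config has 16 query heads, 8 KV heads and head_dim 128:

- `q_proj` maps 2048 → 16 × 128 = 2048.
- `k_proj` and `v_proj` map 2048 → 8 × 128 = 1024.
- This matches the [2048, 2048] and [1024, 2048] shapes in the log.
- Each KV head serves 16 / 8 = 2 query heads (grouped-query attention).

A hypothetical JDK-only helper that derives these shapes:

```java
import java.util.LinkedHashMap;
import java.util.Map;

/** Hypothetical helper: expected [out, in] shapes of one text decoder layer's projection weights. */
public final class ExpectedShapes {
    public static Map<String, long[]> decoderLayer(int hidden, int heads, int kvHeads, int headDim, int inter) {
        if (heads % kvHeads != 0) {
            throw new IllegalArgumentException("heads must be a multiple of kvHeads for GQA");
        }
        Map<String, long[]> m = new LinkedHashMap<>();
        // Linear weights are stored as [out_features, in_features]
        m.put("self_attn.q_proj.weight", new long[]{(long) heads * headDim, hidden});
        m.put("self_attn.k_proj.weight", new long[]{(long) kvHeads * headDim, hidden});
        m.put("self_attn.v_proj.weight", new long[]{(long) kvHeads * headDim, hidden});
        m.put("self_attn.o_proj.weight", new long[]{hidden, (long) heads * headDim});
        m.put("mlp.gate_proj.weight", new long[]{inter, hidden});
        m.put("mlp.up_proj.weight", new long[]{inter, hidden});
        m.put("mlp.down_proj.weight", new long[]{hidden, inter});
        return m;
    }
}
```

Comparing these shapes against the loaded map (for example the `model.language_model.layers.0.*` entries) would flag a mismatched config before any forward pass runs.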
```java
package lance.pytorch;
import org.bytedeco.pytorch.*;
import org.bytedeco.pytorch.Module;
import org.bytedeco.pytorch.global.torch;
import lance.pytorch.tokenizer.LanceTokenizer;
import java.util.*;
/**
* Qwen3-VL-2B-Instruct model implementation
*
* Architecture: Qwen3VLForConditionalGeneration
* - Embedding (tied with LM head)
* - 28 Transformer decoder layers:
* - RMSNorm → GQA Self-Attention (16 Q heads, 8 KV heads) → residual
* - RMSNorm → SwiGLU MLP → residual
* - Final RMSNorm
* - LM Head (tied to embedding)
*
* Loads safetensors weights directly through javacpp-pytorch and runs inference on them.
* All weights are kept off-heap as native Tensors.
*/
public class Qwen3VLInstructModel extends Module {
private final Map<String, Tensor> weights;
private final Qwen3VLInstructConfig config;
// tokenizer can be either LanceTokenizer (pure-java) or an external DJL HuggingFaceTokenizer
private final Object tokenizerObj;
private final LanceTokenizer javaTokenizer; // non-null when a pure-java tokenizer is provided
// Cached embedding weight for fast lookup
private Tensor cachedEmbedWeight;
public Qwen3VLInstructModel(Map<String, Tensor> weights, Qwen3VLInstructConfig config, Object tokenizer) {
super();
this.weights = weights;
this.config = config;
this.tokenizerObj = tokenizer;
this.javaTokenizer = (tokenizer instanceof LanceTokenizer) ? (LanceTokenizer) tokenizer : null;
this.cachedEmbedWeight = resolveEmbedWeight();
System.out.println("[Qwen3VLInstruct] Model initialized");
System.out.println(" Architecture: Qwen3VLForConditionalGeneration");
System.out.println(" Layers: " + config.numHiddenLayers);
System.out.println(" Hidden size: " + config.hiddenSize);
System.out.println(" Attention heads: " + config.numAttentionHeads + " Q, " + config.numKeyValueHeads + " KV");
System.out.println(" Vocabulary: " + config.vocabSize);
System.out.println(" Weights: " + weights.size());
// Avoid calling into native tokenizers during construction. Only use pure-java tokenizer here.
if (this.javaTokenizer != null) {
long eosId = this.javaTokenizer.getEosTokenId();
if (eosId != -1) {
config.setEosTokenId(eosId);
System.out.println(" ✓ EOS token ID set to: " + eosId);
} else {
System.err.println(" ⚠ Could not get the EOS token ID from the java tokenizer");
}
} else if (this.tokenizerObj != null) {
System.err.println(" ⚠ External tokenizer instance provided (possibly DJL native); its methods are not called during construction to avoid native conflicts");
}
}
/**
* Typed accessor expected by existing tests: returns the pure-java LanceTokenizer when available, otherwise null.
* Callers should avoid invoking native-backed tokenizers from within native inference paths.
*/
public lance.pytorch.tokenizer.LanceTokenizer getTokenizer() {
if (this.javaTokenizer != null) return this.javaTokenizer;
if (this.tokenizerObj instanceof lance.pytorch.tokenizer.LanceTokenizer) {
return (lance.pytorch.tokenizer.LanceTokenizer) this.tokenizerObj;
}
return null;
}
// Raw accessor for callers that need the original tokenizer object (could be DJL native tokenizer)
public Object getTokenizerObj() {
return tokenizerObj;
}
// ======================== Text generation API ========================
/**
* Text generation: full autoregressive decoding for a single user prompt.
* The prompt is wrapped with the chat template before tokenization.
*/
public String generate(String prompt, int maxNewTokens, float temperature, float topP) {
return generateFromFormatted(applyChatTemplate(prompt), maxNewTokens, temperature, topP);
}
/**
* Text generation for multi-turn conversations.
* The messages are rendered with the chat template exactly once and then decoded directly;
* passing the rendered text back into generate(String, ...) would wrap it in the template a second time.
*/
public String generate(List<Map<String, String>> messages, int maxNewTokens, float temperature, float topP) {
return generateFromFormatted(applyChatTemplate(messages), maxNewTokens, temperature, topP);
}
/**
* Shared decoding loop; expects text that has already been chat-templated.
*/
private String generateFromFormatted(String formatted, int maxNewTokens, float temperature, float topP) {
if (javaTokenizer == null) throw new IllegalStateException("Native inference requires a pure-java LanceTokenizer; model was constructed with a native/external tokenizer or none.");
System.out.println("\n[Generate] Input: " + formatted.substring(0, Math.min(80, formatted.length())) + (formatted.length() > 80 ? "..." : ""));
// 1. Tokenize
long[] encodedIds = javaTokenizer.encode(formatted);
List<Integer> inputIds = new ArrayList<>();
for (long id : encodedIds) {
inputIds.add((int) id);
}
System.out.println("[Tokenize] token count: " + inputIds.size());
if (inputIds.size() > 0) {
System.out.println("[Tokenize] first 5 tokens: " + inputIds.subList(0, Math.min(5, inputIds.size())));
String decodedVerify = javaTokenizer.decode(encodedIds);
System.out.println("[Tokenize] decode check: " + decodedVerify.substring(0, Math.min(60, decodedVerify.length())));
}
// Disable autograd for inference
try (var noGrad = new NoGradGuard()) {
// 2. Prefill - forward all input tokens at once
long startMs = System.currentTimeMillis();
Tensor logits = forwardPrefill(inputIds);
if (logits == null) {
System.err.println("[Generate] Prefill failed");
return "";
}
// Diagnostic: check logits shape and values
long[] logitsShape = logits.sizes().vec().get();
System.out.println("[Logits] shape=" + java.util.Arrays.toString(logitsShape));
float logitsMin = logits.min().item_float();
float logitsMax = logits.max().item_float();
float logitsMean = logits.mean().item_float();
System.out.println("[Logits] min=" + logitsMin + " max=" + logitsMax + " mean=" + logitsMean);
if (Float.isNaN(logitsMean) || Float.isInfinite(logitsMean)) {
System.err.println("[Generate] ⚠ Logits contain NaN/Inf; the forward pass is broken");
}
// 3. Autoregressive generation
List<Long> generatedIds = new ArrayList<>();
int eosId = config.eosTokenId;
for (int step = 0; step < maxNewTokens; step++) {
// Sample next token from logits
int nextToken = sampleTopP(logits, temperature, topP);
// Check EOS
if (nextToken == eosId) {
System.out.println("[Generate] EOS at step " + step);
break;
}
generatedIds.add((long) nextToken);
// Decode incrementally for debugging (first few tokens only)
if (step < 5) {
String partial = javaTokenizer.decode(new long[]{nextToken});
System.out.println("[Step " + step + "] token=" + nextToken + " → \"" + partial + "\"");
}
// Forward single new token
logits = forwardSingleToken(nextToken);
if (logits == null) {
System.err.println("[Generate] forward failed at step " + step + ", stopping generation");
break;
}
}
long elapsed = System.currentTimeMillis() - startMs;
// 4. Decode
String result = javaTokenizer.decode(generatedIds.stream().mapToLong(l -> l).toArray());
System.out.println("[Generate] done: " + generatedIds.size() + " tokens, " + elapsed + "ms (" +
String.format("%.1f", generatedIds.size() * 1000.0 / Math.max(1, elapsed)) + " tok/s)");
return result;
}
}
private String applyChatTemplate(List<Map<String, String>> messages) {
StringBuilder sb = new StringBuilder();
for (Map<String, String> message : messages) {
sb.append("<|im_start|>")
.append(message.get("role"))
.append("\n")
.append(message.get("content"))
.append("<|im_end|>\n");
}
sb.append("<|im_start|>assistant\n");
return sb.toString();
}
// ======================== Forward pass ========================
/**
* Prefill: process the full input sequence.
* Returns the logits of the last position, [1, vocab_size].
*/
private Tensor forwardPrefill(List<Integer> tokenIds) {
try {
Tensor embedWeight = getEmbedWeight();
if (embedWeight == null) {
System.err.println("[Forward] 找不到 embedding 权重");
return null;
}
// Force everything to Float32 for numerical stability
Tensor embedWeightF32 = embedWeight.to(torch.ScalarType.Float);
long hiddenDim = embedWeightF32.size(1);
int seqLen = tokenIds.size();
// 1. Embedding lookup
long[] ids = new long[seqLen];
for (int i = 0; i < seqLen; i++) ids[i] = tokenIds.get(i);
Tensor inputIdsTensor = torch.tensor(ids);
Tensor hidden = embedWeightF32.index_select(0, inputIdsTensor);
hidden = hidden.reshape(1, seqLen, hiddenDim);
// hidden is now [1, seqLen, hiddenDim] Float32
// 2. Transformer layers
int layersFailed = 0;
for (int layer = 0; layer < config.numHiddenLayers; layer++) {
Tensor prev = hidden;
hidden = transformerLayer(hidden, layer);
// Sanity check: if transformerLayer returns same tensor, layer was skipped
if (hidden == prev && layer == 0) {
System.err.println("[Forward] ⚠ Layer 0 was no-op (weights missing?)");
}
}
// 3. Final RMS norm
Tensor finalNormW = findWeight("model.norm.weight");
if (finalNormW != null) {
hidden = rmsNorm(hidden, finalNormW);
}
// Ensure Float32
hidden = hidden.to(torch.ScalarType.Float);
// 4. Extract last position → LM head
Tensor lastHidden = hidden.slice(1, new LongOptional(seqLen - 1), new LongOptional(seqLen), 1);
lastHidden = lastHidden.reshape(1, hiddenDim);
// LM Head = embed_weight.T (tie_word_embeddings)
return torch.mm(lastHidden, embedWeightF32.transpose(0, 1));
} catch (Exception e) {
System.err.println("[Forward] Prefill错误: " + e.getMessage());
e.printStackTrace();
return null;
}
}
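/**
* Single-token forward pass (decode phase).
* Returns [1, vocab_size] logits.
* Note: no KV cache is kept here, so in every decoder layer the new token attends only to itself;
* the prompt and previously generated tokens are not visible to attention in this path.
*/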
private Tensor forwardSingleToken(int tokenId) {
try {
Tensor embedWeight = getEmbedWeight();
if (embedWeight == null) return null;
Tensor embedWeightF32 = embedWeight.to(torch.ScalarType.Float);
long hiddenDim = embedWeightF32.size(1);
// Embedding lookup for single token
Tensor idTensor = torch.tensor(new long[]{tokenId});
Tensor hidden = embedWeightF32.index_select(0, idTensor);
hidden = hidden.reshape(1, 1, hiddenDim);
// Transformer layers
for (int layer = 0; layer < config.numHiddenLayers; layer++) {
hidden = transformerLayer(hidden, layer);
}
// Final norm
Tensor finalNormW = findWeight("model.norm.weight");
if (finalNormW != null) {
hidden = rmsNorm(hidden, finalNormW);
}
// Ensure Float32
hidden = hidden.to(torch.ScalarType.Float);
// LM head
hidden = hidden.reshape(1, hiddenDim);
return torch.mm(hidden, embedWeightF32.transpose(0, 1));
} catch (Exception e) {
System.err.println("[Forward] Decode error: " + e.getMessage());
return null;
}
}
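// ======================== Transformer layer ========================
/**
* A single Transformer decoder layer:
* input_layernorm → GQA self-attention → residual → post_attention_layernorm → SwiGLU MLP → residual
* Note: rotary position embeddings (Qwen3-VL's interleaved M-RoPE) are not applied to q/k in this method.
*/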
private Tensor transformerLayer(Tensor hidden, int layerIdx) {
String pfx = "model.layers." + layerIdx;
try {
// Ensure hidden is Float32
hidden = hidden.to(torch.ScalarType.Float);
long bsz = hidden.size(0);
long seqLen = hidden.size(1);
long hDim = hidden.size(2);
// === Self-Attention ===
Tensor normW = findWeightF32(pfx + ".input_layernorm.weight");
Tensor normed = (normW != null) ? rmsNorm(hidden, normW) : hidden;
Tensor qW = findWeightF32(pfx + ".self_attn.q_proj.weight");
Tensor kW = findWeightF32(pfx + ".self_attn.k_proj.weight");
Tensor vW = findWeightF32(pfx + ".self_attn.v_proj.weight");
Tensor oW = findWeightF32(pfx + ".self_attn.o_proj.weight");
if (qW != null && kW != null && vW != null && oW != null) {
Tensor normed2d = normed.reshape(bsz * seqLen, hDim);
// Q, K, V projections
Tensor q = torch.mm(normed2d, qW.transpose(0, 1)); // [seq, numHeads*headDim]
Tensor k = torch.mm(normed2d, kW.transpose(0, 1)); // [seq, numKVHeads*headDim]
Tensor v = torch.mm(normed2d, vW.transpose(0, 1)); // [seq, numKVHeads*headDim]
int numHeads = config.numAttentionHeads;
int numKVHeads = config.numKeyValueHeads;
int headDim = config.headDim;
int groupSize = numHeads / numKVHeads; // 2 for GQA
// Apply q_norm and k_norm (Qwen3-VL specific)
Tensor qNormW = findWeightF32(pfx + ".self_attn.q_norm.weight");
Tensor kNormW = findWeightF32(pfx + ".self_attn.k_norm.weight");
if (qNormW != null) {
q = q.reshape(bsz * seqLen * numHeads, headDim);
q = rmsNorm2d(q, qNormW);
q = q.reshape(bsz * seqLen, (long) numHeads * headDim);
}
if (kNormW != null) {
k = k.reshape(bsz * seqLen * numKVHeads, headDim);
k = rmsNorm2d(k, kNormW);
k = k.reshape(bsz * seqLen, (long) numKVHeads * headDim);
}
// Reshape for multi-head attention
q = q.reshape(bsz, seqLen, numHeads, headDim).transpose(1, 2);
k = k.reshape(bsz, seqLen, numKVHeads, headDim).transpose(1, 2);
v = v.reshape(bsz, seqLen, numKVHeads, headDim).transpose(1, 2);
// GQA: expand K,V to match Q heads by repeating
if (groupSize > 1) {
k = k.unsqueeze(2).expand(bsz, numKVHeads, groupSize, seqLen, headDim)
.reshape(bsz, numHeads, seqLen, headDim);
v = v.unsqueeze(2).expand(bsz, numKVHeads, groupSize, seqLen, headDim)
.reshape(bsz, numHeads, seqLen, headDim);
}
// Scaled dot-product attention
Tensor scores = torch.matmul(q, k.transpose(2, 3));
scores = torch.div(scores, new Scalar(Math.sqrt(headDim)));
// Causal mask: add -1e9 (a large negative stand-in for -inf) above the diagonal
if (seqLen > 1) {
Tensor causalMask = torch.ones(seqLen, seqLen);
Tensor triu = torch.triu(causalMask, 1L);
Tensor maskValue = torch.mul(triu, new Scalar(-1e9));
scores = torch.add(scores, maskValue);
}
Tensor attnWeights = torch.softmax(scores, -1);
Tensor attnOut = torch.matmul(attnWeights, v);
// Merge heads: [bsz, seqLen, numHeads*headDim]
attnOut = attnOut.transpose(1, 2).reshape(bsz * seqLen, (long) numHeads * headDim);
// Output projection
Tensor projected = torch.mm(attnOut, oW.transpose(0, 1));
projected = projected.reshape(bsz, seqLen, hDim);
// Residual
hidden = torch.add(hidden, projected);
} else if (layerIdx < 3) {
// Only log missing weights for first few layers
System.err.println("[Layer " + layerIdx + "] ⚠ 注意力权重缺失: q=" + (qW!=null) +
" k=" + (kW!=null) + " v=" + (vW!=null) + " o=" + (oW!=null));
}
// === MLP (SwiGLU) ===
Tensor postNormW = findWeightF32(pfx + ".post_attention_layernorm.weight");
Tensor postNormed = (postNormW != null) ? rmsNorm(hidden, postNormW) : hidden;
Tensor gateW = findWeightF32(pfx + ".mlp.gate_proj.weight");
Tensor upW = findWeightF32(pfx + ".mlp.up_proj.weight");
Tensor downW = findWeightF32(pfx + ".mlp.down_proj.weight");
if (gateW != null && upW != null && downW != null) {
long bsz2 = postNormed.size(0);
long seqLen2 = postNormed.size(1);
long hDim2 = postNormed.size(2);
Tensor normed2d = postNormed.reshape(bsz2 * seqLen2, hDim2);
Tensor gate = torch.mm(normed2d, gateW.transpose(0, 1));
Tensor up = torch.mm(normed2d, upW.transpose(0, 1));
Tensor activated = torch.mul(torch.silu(gate), up);
Tensor mlpOut = torch.mm(activated, downW.transpose(0, 1));
mlpOut = mlpOut.reshape(bsz2, seqLen2, hDim2);
hidden = torch.add(hidden, mlpOut);
} else if (layerIdx < 3) {
System.err.println("[Layer " + layerIdx + "] ⚠ MLP权重缺失: gate=" + (gateW!=null) +
" up=" + (upW!=null) + " down=" + (downW!=null));
}
} catch (Exception e) {
// Log the error for first few layers
if (layerIdx < 3) {
System.err.println("[Layer " + layerIdx + "] 错误: " + e.getMessage());
e.printStackTrace();
}
}
return hidden;
}
/**
* Find weight by name and convert to Float32.
* Note: for BF16 weights this allocates a new Float32 copy on every call.
*/
private Tensor findWeightF32(String name) {
Tensor t = findWeight(name);
if (t == null) return null;
return t.to(torch.ScalarType.Float);
}
// ======================== RMS Norm ========================
/**
* RMS normalization for 3D+ tensors (standard layer norm path)
*/
private Tensor rmsNorm(Tensor x, Tensor weight) {
try {
Tensor xf = x.to(torch.ScalarType.Float);
Tensor wf = weight.to(torch.ScalarType.Float);
long[] shape = xf.sizes().vec().get();
long lastDim = shape[shape.length - 1];
long rows = 1;
for (int i = 0; i < shape.length - 1; i++) rows *= shape[i];
Tensor flat = xf.reshape(rows, lastDim);
Tensor normalized = rmsNormCore(flat, lastDim);
Tensor scaled = torch.mul(normalized, wf);
// Keep as Float32 — don't convert back to original dtype
return scaled.reshape(shape);
} catch (Exception e) {
return x;
}
}
/**
* RMS normalization for 2D [N, dim] tensors (q_norm, k_norm)
*/
private Tensor rmsNorm2d(Tensor x, Tensor weight) {
try {
Tensor xf = x.to(torch.ScalarType.Float);
Tensor wf = weight.to(torch.ScalarType.Float);
long lastDim = xf.size(1);
Tensor normalized = rmsNormCore(xf, lastDim);
return torch.mul(normalized, wf);
} catch (Exception e) {
return x;
}
}
/**
* Core RMS normalization: x / sqrt(mean(x^2) + eps) for [rows, dim] tensor
*/
private Tensor rmsNormCore(Tensor flat, long dim) {
Tensor sq = torch.mul(flat, flat);
Tensor rowMean = sq.sum(-1L);
rowMean = torch.div(rowMean, new Scalar((double) dim));
Tensor rms = torch.sqrt(torch.add(rowMean, new Scalar(config.rmsNormEps)));
Tensor rmsExp = rms.reshape(flat.size(0), 1);
return torch.div(flat, rmsExp);
}
// ======================== Sampling ========================
/**
* Temperature + top-p (nucleus) sampling; a non-positive temperature falls back to greedy.
*/
private int sampleTopP(Tensor logits, float temperature, float topP) {
try {
Tensor l = logits.squeeze(0); // [vocab_size]
// temperature <= 0: greedy decoding (also avoids dividing by zero)
if (temperature <= 0f) {
return (int) torch.argmax(l, new LongOptional(-1), false).item_long();
}
// Apply temperature
if (temperature != 1.0f) {
l = torch.div(l, new Scalar(temperature));
}
// Convert to probabilities
Tensor probs = torch.softmax(l, -1);
// Top-p (nucleus) sampling
if (topP < 1.0f) {
// Sort once in descending order: get0() = probabilities, get1() = token ids
var sorted = torch.sort(probs, -1, true);
Tensor sortedProbs = sorted.get0();
Tensor sortedIndices = sorted.get1();
// Compute cumulative probabilities
Tensor cumsum = torch.cumsum(sortedProbs, -1);
// Keep every token whose cumulative probability is still below topP, plus the one
// that crosses the threshold, so at least one token is always kept.
long cutoffIdx = torch.lt(cumsum, new Scalar(topP)).sum().item_long() + 1;
cutoffIdx = Math.min(cutoffIdx, sortedProbs.size(0));
// Get top-p indices and their probabilities
Tensor topPIndices = sortedIndices.slice(0, new LongOptional(0), new LongOptional(cutoffIdx), 1);
Tensor topPProbs = sortedProbs.slice(0, new LongOptional(0), new LongOptional(cutoffIdx), 1);
topPProbs = torch.div(topPProbs, torch.sum(topPProbs)); // renormalize
// Sample inside the nucleus, then map back to the vocabulary id
Tensor sampledIdx = torch.multinomial(topPProbs, 1, false, new GeneratorOptional());
long selectedIdx = sampledIdx.item_long();
return (int) topPIndices.slice(0, new LongOptional(selectedIdx), new LongOptional(selectedIdx + 1), 1).item_long();
} else {
// Simple multinomial sampling
return (int) torch.multinomial(probs, 1, false, new GeneratorOptional()).item_long();
}
} catch (Exception e) {
System.err.println("Sampling failed: " + e.getMessage());
// Fallback to greedy
return (int) torch.argmax(logits.squeeze(0), new LongOptional(-1), false).item_long();
}
}
// ======================== Chat Template ========================
private String applyChatTemplate(String userMessage) {
// Qwen3-VL Instruct chat format:
// <|im_start|>system\nYou are a helpful assistant.<|im_end|>\n
// <|im_start|>user\n{message}<|im_end|>\n
// <|im_start|>assistant\n
return "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n" +
"<|im_start|>user\n" + userMessage + "<|im_end|>\n" +
"<|im_start|>assistant\n";
}
// ======================== Weight Utilities ========================
private Tensor resolveEmbedWeight() {
String[] names = {
"model.embed_tokens.weight",
"model.language_model.embed_tokens.weight",
"language_model.model.embed_tokens.weight"
};
for (String n : names) {
Tensor t = weights.get(n);
if (t != null) {
System.out.println("[权重] Embedding: " + n + " [" + t.size(0) + ", " + t.size(1) + "]");
return t;
}
}
// Fuzzy search
for (Map.Entry<String, Tensor> e : weights.entrySet()) {
if (e.getKey().contains("embed_tokens") && e.getKey().endsWith(".weight")) {
System.out.println("[权重] Embedding (fuzzy): " + e.getKey());
return e.getValue();
}
}
System.err.println("[权重] ✗ 未找到 embedding 权重!");
return null;
}
private Tensor getEmbedWeight() {
return cachedEmbedWeight;
}
/**
* Find weight by name, trying multiple prefixes.
* Qwen3-VL weights may use:
* model.layers.X... (standard)
* model.language_model.layers.X... (some exports)
* language_model.model.layers.X... (some exports)
*/
private Tensor findWeight(String name) {
// Direct lookup
Tensor t = weights.get(name);
if (t != null) return t;
// If name starts with "model.", try the prefixes used by other exports
if (name.startsWith("model.")) {
String stripped = name.substring("model.".length());
// Try "model.language_model." prefix
t = weights.get("model.language_model." + stripped);
if (t != null) return t;
// Try "language_model.model." prefix (this is the same key as "language_model." + name)
t = weights.get("language_model.model." + stripped);
if (t != null) return t;
}
return null;
}
public Qwen3VLInstructConfig getConfig() { return config; }
public Map<String, Tensor> getWeights() { return weights; }
public int getWeightCount() { return weights.size(); }
/**
* Convert List to long[]
*/
private long[] toLongArray(List<Integer> list) {
long[] arr = new long[list.size()];
for (int i = 0; i < list.size(); i++) {
arr[i] = list.get(i);
}
return arr;
}
}
```
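The nucleus cutoff is the subtle part of `sampleTopP`, so it can help to check the same logic on plain Java arrays, away from libtorch. The sketch below is a hypothetical stand-alone helper: `TopPSketch` and its methods are illustrative names, not part of the project. It applies temperature, sorts probabilities in descending order, keeps every token whose cumulative probability is still below `topP` plus the token that crosses it, renormalizes, and samples by inverse CDF.

```java
import java.util.Arrays;
import java.util.Comparator;
import java.util.Random;

/** Hypothetical stand-alone sketch of temperature + top-p sampling on a logits array. */
public final class TopPSketch {

    /** Softmax of logits / temperature; subtracts the max for numerical stability. */
    public static double[] softmax(float[] logits, float temperature) {
        double max = Double.NEGATIVE_INFINITY;
        for (float l : logits) max = Math.max(max, l / temperature);
        double[] p = new double[logits.length];
        double sum = 0;
        for (int i = 0; i < logits.length; i++) {
            p[i] = Math.exp(logits[i] / temperature - max);
            sum += p[i];
        }
        for (int i = 0; i < p.length; i++) p[i] /= sum;
        return p;
    }

    /** Tokens kept: those whose cumulative probability is still below topP, plus the one crossing it. */
    public static int nucleusSize(double[] sortedDesc, double topP) {
        double cum = 0;
        int below = 0;
        for (double p : sortedDesc) {
            cum += p;
            if (cum >= topP) break;
            below++;
        }
        return Math.min(below + 1, sortedDesc.length);
    }

    /** Samples a token id; temperature <= 0 means greedy argmax. */
    public static int sample(float[] logits, float temperature, float topP, Random rng) {
        if (temperature <= 0f) {
            int best = 0;
            for (int i = 1; i < logits.length; i++) if (logits[i] > logits[best]) best = i;
            return best;
        }
        double[] probs = softmax(logits, temperature);
        // Token ids ordered by descending probability
        Integer[] order = new Integer[probs.length];
        for (int i = 0; i < order.length; i++) order[i] = i;
        Arrays.sort(order, Comparator.comparingDouble((Integer id) -> probs[id]).reversed());
        double[] sorted = new double[probs.length];
        for (int i = 0; i < sorted.length; i++) sorted[i] = probs[order[i]];
        int keep = topP < 1f ? nucleusSize(sorted, topP) : sorted.length;
        double mass = 0;
        for (int i = 0; i < keep; i++) mass += sorted[i];
        // Inverse-CDF draw from the renormalized nucleus
        double r = rng.nextDouble() * mass;
        for (int i = 0; i < keep; i++) {
            r -= sorted[i];
            if (r < 0) return order[i];
        }
        return order[keep - 1]; // guards against floating-point round-off
    }
}
```

The kept tokens always form a prefix of the sorted list, so the cutoff is a count: the number of `true` entries in `cumsum < topP`, plus one. An `argmax` over that mask would instead return the position of its first `true` entry, which is 0 whenever the top token alone is below `topP`.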
```java
package lance.pytorch;
import lance.pytorch.tokenizer.DJLTokenizer;
import lance.pytorch.tokenizer.LanceTokenizer;
import lance.dtype.TensorData;
import lance.safetensors.ModelFetcher;
import lance.safetensors.SafeTensorSupport;
import org.bytedeco.pytorch.Device;
import org.bytedeco.pytorch.Tensor;
import org.bytedeco.pytorch.global.torch;
import com.google.gson.*;
import java.io.*;
import java.nio.charset.StandardCharsets;
import java.nio.file.*;
import java.util.*;
/**
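* Loader for the Qwen3-VL-2B-Instruct model.
*
* Responsibilities:
* 1. Download model files and configs from HuggingFace via ModelFetcher
* 2. Parse config.json to build the model structure
* 3. Load tokenizer.json to initialize the tokenizer
* 4. Load safetensors weights (zero-copy mmap)
* 5. Assemble a Qwen3VLInstructModel for inference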
*/
public class Qwen3VLInstructLoader {
private static final String REPO_ID = "Qwen/Qwen3-VL-2B-Instruct";
// Files to download
private static final String[] CONFIG_FILES = {
"config.json",
"tokenizer.json",
"tokenizer_config.json",
"merges.txt",
"vocab.json",
"preprocessor_config.json",
"chat_template.jinja",
"special_tokens_map.json"
};
/**
* Download the model from HuggingFace and load it (full pipeline).
*
* @param cacheDir local cache directory
* @return the loaded Qwen3VLInstructModel instance
*/
public static Qwen3VLInstructModel load(Path cacheDir) throws Exception {
// If LANCE_SMALL_LOAD is set, load only that many tensors (safe dev testing).
String small = System.getenv("LANCE_SMALL_LOAD");
if (small != null && !small.isEmpty()) {
try {
int n = Integer.parseInt(small);
System.out.println("[Loader] LANCE_SMALL_LOAD=" + n + " -> running limited load for dev safety");
return load(cacheDir, detectDevice(), n);
} catch (NumberFormatException e) {
System.out.println("[Loader] Invalid LANCE_SMALL_LOAD value: " + small + " -> ignoring");
}
}
return load(cacheDir, detectDevice());
}
public static Qwen3VLInstructModel load(Path cacheDir, Device device) throws Exception {
// default: no tensor-count limit (the large-tensor guard in the loading loop still applies)
return load(cacheDir, device, Integer.MAX_VALUE);
}
/**
* Overload with a small-scale test mode: limits how many tensors are actually
* converted into native torch Tensors. This avoids creating massive on-device tensors
* during debugging and prevents triggering native crashes while iterating on loading code.
*
* @param cacheDir cache directory
* @param device target Device
* @param maxTensorsToLoad stop after this many tensors (use Integer.MAX_VALUE for full)
*/
public static Qwen3VLInstructModel load(Path cacheDir, Device device, int maxTensorsToLoad) throws Exception {
Files.createDirectories(cacheDir);
ModelFetcher fetcher = new ModelFetcher(cacheDir);
String prefix = REPO_ID.replace('/', '_') + "_";
// ==================== Step 1: Download config files ====================
System.out.println("\n" + "=".repeat(70));
System.out.println("[Step 1] 下载配置文件");
System.out.println("=".repeat(70));
for (String f : CONFIG_FILES) {
try {
// ModelFetcher currently provides fetch(repoId, filename) -> Path
Path p = fetcher.fetch(REPO_ID, f);
if (p != null && Files.exists(p)) {
System.out.println(" ✓ " + f + " → " + p.getFileName());
}
} catch (FileNotFoundException fnf) {
// Some files like special_tokens_map.json or added_tokens.json may not exist for all repos.
System.out.println(" ⚠ Optional file not found, skip: " + f + " -> " + fnf.getMessage());
} catch (IOException ioe) {
// Network/IO issues should be reported but non-fatal for optional files
System.out.println(" ⚠ " + f + " 下载失败: " + ioe.getMessage());
} catch (Exception e) {
System.out.println(" ⚠ " + f + " 下载失败 (可选): " + e.getMessage());
}
}
// ==================== Step 2: Download safetensors ====================
System.out.println("\n" + "=".repeat(70));
System.out.println("[Step 2] 下载模型权重");
System.out.println("=".repeat(70));
// Check if sharded or single file
List<Path> safetensorPaths = downloadSafetensors(fetcher, cacheDir, prefix);
System.out.println(" ✓ 共 " + safetensorPaths.size() + " 个 safetensors 文件");
// ==================== Step 3: Parse config ====================
System.out.println("\n" + "=".repeat(70));
System.out.println("[Step 3] 解析模型配置");
System.out.println("=".repeat(70));
Path configPath = cacheDir.resolve(prefix + "config.json");
Qwen3VLInstructConfig config = new Qwen3VLInstructConfig(configPath);
System.out.println(" " + config);
// ==================== Step 4: Load tokenizer ====================
System.out.println("\n" + "=".repeat(70));
System.out.println("[Step 4] 加载 Tokenizer");
System.out.println("=".repeat(70));
Path tokenizerJsonPath = cacheDir.resolve(prefix + "tokenizer.json");
if (!Files.exists(tokenizerJsonPath)) {
throw new FileNotFoundException("Tokenizer file not found: " + tokenizerJsonPath);
}
System.out.println(" 加载 tokenizer: " + tokenizerJsonPath);
LanceTokenizer tokenizer = new DJLTokenizer(tokenizerJsonPath);
// Verify tokenizer roundtrip
String testStr = "Hello, world! 你好世界";
long[] encoded = tokenizer.encode(testStr);
String decoded = tokenizer.decode(encoded);
System.out.println(" 验证: \"" + testStr + "\" → " + encoded.length + " tokens → \"" + decoded + "\"");
// ==================== Step 5: Load weights ====================
System.out.println("\n" + "=".repeat(70));
System.out.println("[Step 5] 加载模型权重 (零拷贝)");
System.out.println("=".repeat(70));
long loadStart = System.currentTimeMillis();
Map<String, Tensor> allWeights = new LinkedHashMap<>();
int totalTensors = 0;
int loadedTensors = 0;
int failedTensors = 0;
List<String> failedNames = new ArrayList<>();
// threshold: tensors larger than this are skipped unless LANCE_ALLOW_LARGE_TENSORS is enabled
final long SKIP_IF_LARGER_THAN_BYTES = 64L * 1024L * 1024L; // 64MB
outer: for (Path stPath : safetensorPaths) {
System.out.println(" 加载: " + stPath.getFileName());
Map<String, TensorData> tensorDataMap = SafeTensorSupport.loadLazy(stPath.toFile());
totalTensors += tensorDataMap.size();
for (Map.Entry<String, TensorData> e : tensorDataMap.entrySet()) {
if (loadedTensors >= maxTensorsToLoad) {
// Stop early for small-scale testing
System.out.println(" (测试模式) 达到 maxTensorsToLoad=" + maxTensorsToLoad + ", 停止加载更多权重");
break outer;
}
try {
TensorData td = e.getValue();
// Byte size of this tensor (-1 if the TensorData implementation cannot report it)
long sizeBytes;
try {
sizeBytes = td.sizeBytes();
} catch (Exception sizeErr) {
sizeBytes = -1L;
}
boolean isLarge = (sizeBytes >= SKIP_IF_LARGER_THAN_BYTES);
// Safety guard: loading very large tensors (e.g. hundreds of MBs / GBs, BF16 mmap) via
// zero-copy into native torch tensors can trigger native crashes on some platforms
// (especially MPS / macOS). By default we skip tensors larger than SKIP_IF_LARGER_THAN_BYTES
// so the loader can run safely on developer machines. To force full loading, set the
// environment variable LANCE_ALLOW_LARGE_TENSORS=1 (or "true").
// Boolean.parseBoolean("1") is false, so both spellings are checked explicitly.
String allowEnv = System.getenv("LANCE_ALLOW_LARGE_TENSORS");
boolean allowLarge = "1".equals(allowEnv) || "true".equalsIgnoreCase(allowEnv);
if (isLarge && !allowLarge) {
System.out.println(" (跳过大张量)[安全模式] " + e.getKey() + " 大小=" + (sizeBytes <= 0 ? "?" : Long.toString(sizeBytes)) + " bytes - set LANCE_ALLOW_LARGE_TENSORS=1 to force load");
failedTensors++;
failedNames.add(e.getKey() + " (skipped-large)");
continue;
}
Tensor t = null;
// Wrap conversion in try/catch — we still avoid calling native bridge for large tensors
try {
t = TensorDataTorchBridge.toTorchTensor(td, device);
} catch (Throwable bridgeErr) {
// If conversion fails, don't crash JVM: log and mark this tensor as failed.
System.err.println(" ⚠ 转换权重失败: " + e.getKey() + " -> " + bridgeErr.getClass().getName() + ": " + bridgeErr.getMessage());
bridgeErr.printStackTrace(System.err);
failedTensors++;
failedNames.add(e.getKey() + " (convert-failed)");
continue;
}
allWeights.put(e.getKey(), t);
loadedTensors++;
if (loadedTensors % 100 == 0) {
System.out.println(" 进度: " + loadedTensors + "/" + totalTensors);
}
} catch (Throwable ex) {
// Catch Throwable to avoid failing noisily in Java; native crashes still possible
failedTensors++;
failedNames.add(e.getKey());
if (failedTensors <= 5) {
System.err.println(" ⚠ " + e.getKey() + ": " + ex.getMessage());
ex.printStackTrace(System.err);
}
}
}
}
long loadMs = System.currentTimeMillis() - loadStart;
System.out.println(" ✓ 加载完成: " + loadedTensors + "/" + totalTensors + " 权重 (" + loadMs + "ms)");
if (failedTensors > 0) {
System.out.println(" ⚠ 失败: " + failedTensors + " 个权重:");
// Categorize failed weights
int textFailed = 0, visionFailed = 0, otherFailed = 0;
for (String name : failedNames) {
if (name.contains("model.layers.") || name.contains("embed_tokens") || name.equals("model.norm.weight")) {
textFailed++;
System.out.println(" ✗ [TEXT] " + name);
} else if (name.contains("visual") || name.contains("vision")) {
visionFailed++;
System.out.println(" ✗ [VISION] " + name);
} else {
otherFailed++;
System.out.println(" ✗ [OTHER] " + name);
}
}
System.out.println(" 失败分类: text=" + textFailed + " vision=" + visionFailed + " other=" + otherFailed);
}
// ==================== Step 6: Build model ====================
System.out.println("\n" + "=".repeat(70));
System.out.println("[Step 6] 构建模型");
System.out.println("=".repeat(70));
Qwen3VLInstructModel model = new Qwen3VLInstructModel(allWeights, config, tokenizer);
System.out.println(" ✓ Qwen3-VL-2B-Instruct 模型就绪");
return model;
}
/**
* Download safetensors files (handles both single and sharded models)
*/
private static List<Path> downloadSafetensors(ModelFetcher fetcher, Path cacheDir, String prefix) throws IOException {
List<Path> paths = new ArrayList<>();
try {
Path indexPath = fetcher.fetch(REPO_ID, "model.safetensors.index.json", true);
if (indexPath != null && Files.exists(indexPath)) {
String indexContent = Files.readString(indexPath, StandardCharsets.UTF_8);
JsonObject indexJson = JsonParser.parseString(indexContent).getAsJsonObject();
JsonObject weightMap = indexJson.getAsJsonObject("weight_map");
Set<String> shardFiles = new LinkedHashSet<>();
for (Map.Entry<String, JsonElement> e : weightMap.entrySet()) {
shardFiles.add(e.getValue().getAsString());
}
System.out.println(" 发现分片模型: " + shardFiles.size() + " 个分片");
for (String shard : shardFiles) {
Path shardPath = fetcher.fetch(REPO_ID, shard);
paths.add(shardPath);
System.out.println(" ✓ " + shard);
}
return paths;
}
} catch (Exception e) {
System.out.println(" 分片索引读取失败,尝试单文件: " + e.getMessage());
}
try {
Path modelPath = fetcher.fetch(REPO_ID, "model.safetensors");
paths.add(modelPath);
return paths;
} catch (IOException e) {
throw new IOException("无法下载模型权重文件: " + e.getMessage(), e);
}
}
/**
* Detect best available compute device
*/
public static Device detectDevice() {
String os = System.getProperty("os.name").toLowerCase();
if (os.contains("mac")) {
try {
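// Note: constructing a Device only names the target; it does not verify that MPS is
// usable (e.g. on Intel Macs), so such errors typically surface on the first tensor transfer.
Device mps = new Device(torch.DeviceType.MPS);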
System.out.println("[Device] 使用 MPS (Apple GPU)");
return mps;
} catch (Exception e) {
System.out.println("[Device] MPS 不可用,使用 CPU");
}
}
try {
if (torch.cuda_is_available()) {
System.out.println("[Device] 使用 CUDA");
return new Device(torch.DeviceType.CUDA);
}
} catch (Exception ignored) {}
System.out.println("[Device] 使用 CPU");
return new Device(torch.DeviceType.CPU);
}
/**
* Small CLI runner for quick testing. Use --small N to limit tensors loaded (safe dev mode).
*/
public static void main(String[] args) throws Exception {
Path cache = Paths.get("./cache_qwen3vl_instruct");
int small = 0;
for (int i = 0; i < args.length; i++) {
if ("--cache".equals(args[i]) && i + 1 < args.length) {
cache = Paths.get(args[++i]);
} else if ("--small".equals(args[i]) && i + 1 < args.length) {
small = Integer.parseInt(args[++i]);
}
}
Device dev = detectDevice();
if (small > 0) {
System.out.println("Running in small-test mode, maxTensorsToLoad=" + small);
load(cache, dev, small);
} else {
System.out.println("Running full load (may crash on native code) - use --small to avoid");
load(cache, dev, Integer.MAX_VALUE);
}
}
}
```
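Step 5 relies on the safetensors layout to avoid copying weights through Java arrays. The format is simple: an 8-byte little-endian header length, a JSON header whose entries give `dtype`, `shape` and `data_offsets` (relative to the end of the header), then the raw tensor bytes. Below is a minimal sketch using only `java.nio`. `SafetensorsLayoutSketch` and its methods are hypothetical and are not the `SafeTensorSupport.loadLazy` API used above. The sketch shows how a tensor becomes a view into the mapped file without a copy, and how a BF16 value widens to `float` (BF16 is the upper 16 bits of an IEEE-754 float32).

```java
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.channels.FileChannel;
import java.nio.charset.StandardCharsets;
import java.nio.file.Path;
import java.nio.file.StandardOpenOption;

/**
 * Hypothetical sketch of the safetensors layout (not the SafeTensorSupport API):
 *   [u64 little-endian header size N][N bytes UTF-8 JSON header][raw tensor bytes]
 * "data_offsets": [begin, end] in the JSON are relative to the first raw byte.
 */
public final class SafetensorsLayoutSketch {

    /** Maps a whole file read-only; one MappedByteBuffer cannot exceed Integer.MAX_VALUE bytes. */
    public static ByteBuffer mapFile(Path file) throws IOException {
        try (FileChannel ch = FileChannel.open(file, StandardOpenOption.READ)) {
            // The mapping stays valid after the channel is closed
            return ch.map(FileChannel.MapMode.READ_ONLY, 0, ch.size());
        }
    }

    /** JSON header length, stored in the first 8 bytes. */
    public static long headerSize(ByteBuffer file) {
        return file.duplicate().order(ByteOrder.LITTLE_ENDIAN).getLong(0);
    }

    /** The raw JSON header; parse it with any JSON library (e.g. Gson). */
    public static String headerJson(ByteBuffer file) {
        byte[] json = new byte[Math.toIntExact(headerSize(file))];
        file.duplicate().position(8).get(json);
        return new String(json, StandardCharsets.UTF_8);
    }

    /** View of one tensor's bytes [begin, end): no copy, it shares memory with `file`. */
    public static ByteBuffer tensorBytes(ByteBuffer file, long begin, long end) {
        int dataStart = Math.toIntExact(8 + headerSize(file));
        ByteBuffer view = file.duplicate();
        view.position(dataStart + Math.toIntExact(begin));
        view.limit(dataStart + Math.toIntExact(end));
        // slice() resets the byte order to big-endian; safetensors data is little-endian
        return view.slice().order(ByteOrder.LITTLE_ENDIAN);
    }

    /** BF16 keeps the top 16 bits of an IEEE-754 float32, so widening is a shift. */
    public static float bf16ToFloat(short bits) {
        return Float.intBitsToFloat((bits & 0xFFFF) << 16);
    }
}
```

One caveat for large models: a single `MappedByteBuffer` is limited to `Integer.MAX_VALUE` bytes (about 2 GB), so this sketch only handles shards below that size. For bigger shards you would map each tensor's byte range separately. Alternatively, on JDK 22+, you could use the FFM overload `FileChannel.map(mode, offset, size, arena)`, which returns a `MemorySegment` without that limit. Handing the bytes to libtorch without a copy is a separate step that this sketch does not cover.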
The BPE tokenizer (`HfBpeTokenizer`):
```java
package lance.pytorch.tokenizer;
import com.google.gson.Gson;
import com.google.gson.reflect.TypeToken;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.*;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
/**
* A Java implementation of the Hugging Face BPE (Byte-Pair Encoding) Tokenizer.
* This class is designed to load a tokenizer.json file and perform encoding and decoding,
* similar to how Hugging Face's tokenizers work.
*/
public class HfBpeTokenizer implements LanceTokenizer {
private final Map<String, Long> vocab;
private final Map<Long, String> reversedVocab;
private final Map<Pair, Integer> merges;
private final Map<String, Long> specialTokens;
private final Map<Long, String> reversedSpecialTokens;
private final Pattern pattern;
private final Map<Byte, Character> byteToUnicodeMap = createByteToUnicodeMap();
private final Map<Character, Byte> unicodeToByteMap = byteToUnicodeMap.entrySet().stream()
.collect(Collectors.toMap(Map.Entry::getValue, Map.Entry::getKey));
public HfBpeTokenizer(Path tokenizerJsonPath, Path vocabPath, Path mergesPath) throws IOException {
// Load tokenizer.json for overall structure and special tokens
String tokenizerContent = Files.readString(tokenizerJsonPath, StandardCharsets.UTF_8);
Gson gson = new Gson();
Map<String, Object> tokenizerData = gson.fromJson(tokenizerContent, new TypeToken<Map<String, Object>>() {}.getType());
// Load vocab.json
String vocabContent = Files.readString(vocabPath, StandardCharsets.UTF_8);
this.vocab = gson.fromJson(vocabContent, new TypeToken<Map<String, Long>>() {}.getType());
// Load merges.txt: one "left right" pair per line; an earlier line means a higher merge priority
List<String> mergeList = Files.readAllLines(mergesPath, StandardCharsets.UTF_8);
this.merges = new HashMap<>();
for (int i = 0; i < mergeList.size(); i++) {
String line = mergeList.get(i).trim();
// Skip blank lines and the "#version" header
if (line.isEmpty() || line.startsWith("#version")) continue;
String[] parts = line.split(" ");
if (parts.length == 2) {
this.merges.put(new Pair(parts[0], parts[1]), i);
}
}
// Extract special tokens and pattern from tokenizer.json
List<Map<String, Object>> addedTokensList = (List<Map<String, Object>>) tokenizerData.get("added_tokens");
this.specialTokens = new HashMap<>();
if (addedTokensList != null) {
for (Map<String, Object> token : addedTokensList) {
this.specialTokens.put((String) token.get("content"), ((Double) token.get("id")).longValue());
}
}
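String splitPattern = "'s|'t|'re|'ve|'m|'ll|'d| ?[\\p{L}]+| ?[\\p{N}]+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)|\\s+";
// tokenizer.json keeps the split regex as {"pattern": {"Regex": "..."}}, either on the
// pre_tokenizer itself or on one child of a "Sequence" pre_tokenizer, whose
// "pretokenizers" entry is a JSON list (not an object). A literal {"String": ...} pattern
// is not used, because encodeChunk treats the pattern as a token matcher.
Object preTok = tokenizerData.get("pre_tokenizer");
if (preTok instanceof Map<?, ?> preMap) {
Object children = preMap.get("pretokenizers");
List<?> candidates = (children instanceof List<?> list) ? list : List.of(preMap);
for (Object c : candidates) {
if (c instanceof Map<?, ?> cm && cm.get("pattern") instanceof Map<?, ?> pat
&& pat.get("Regex") instanceof String regex) {
splitPattern = regex;
break;
}
}
}
this.pattern = Pattern.compile(splitPattern, Pattern.UNICODE_CHARACTER_CLASS);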
// Setup reverse maps
this.reversedVocab = new HashMap<>();
for (Map.Entry<String, Long> entry : this.vocab.entrySet()) {
this.reversedVocab.put(entry.getValue(), entry.getKey());
}
this.reversedSpecialTokens = new HashMap<>();
if (this.specialTokens != null) {
for (Map.Entry<String, Long> entry : this.specialTokens.entrySet()) {
this.reversedSpecialTokens.put(entry.getValue(), entry.getKey());
}
}
}
private HfBpeTokenizer(Map<String, Long> vocab, Map<Pair, Integer> merges, Map<String, Long> specialTokens, String splitPattern) {
this.vocab = vocab;
this.merges = merges;
this.specialTokens = specialTokens != null ? specialTokens : new HashMap<>();
this.pattern = Pattern.compile(splitPattern, Pattern.UNICODE_CHARACTER_CLASS);
this.reversedVocab = new HashMap<>();
for (Map.Entry<String, Long> entry : vocab.entrySet()) {
this.reversedVocab.put(entry.getValue(), entry.getKey());
}
this.reversedSpecialTokens = new HashMap<>();
if (this.specialTokens != null) {
for (Map.Entry<String, Long> entry : this.specialTokens.entrySet()) {
this.reversedSpecialTokens.put(entry.getValue(), entry.getKey());
}
}
}
/**
* Loads a tokenizer from a tokenizer.json file.
*
* @param tokenizerPath Path to the tokenizer.json file.
* @return A new instance of HfBpeTokenizer.
* @throws IOException If the file cannot be read.
*/
public static HfBpeTokenizer fromFile(Path tokenizerPath) throws IOException {
String content = Files.readString(tokenizerPath, StandardCharsets.UTF_8);
Gson gson = new Gson();
Map<String, Object> tokenizerData = gson.fromJson(content, new TypeToken<Map<String, Object>>() {}.getType());
Map<String, Object> modelData = (Map<String, Object>) tokenizerData.get("model");
Map<String, Long> vocab = ((Map<String, Double>) modelData.get("vocab")).entrySet().stream()
.collect(Collectors.toMap(Map.Entry::getKey, e -> e.getValue().longValue()));
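List<?> mergeList = (List<?>) modelData.get("merges");
Map<Pair, Integer> merges = new HashMap<>();
for (int i = 0; i < mergeList.size(); i++) {
Object m = mergeList.get(i);
// Older tokenizer.json files store merges as "left right" strings, newer ones as ["left", "right"]
if (m instanceof String s) {
String[] parts = s.split(" ");
if (parts.length == 2) merges.put(new Pair(parts[0], parts[1]), i);
} else if (m instanceof List<?> pair && pair.size() == 2) {
merges.put(new Pair(String.valueOf(pair.get(0)), String.valueOf(pair.get(1))), i);
}
}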
List<Map<String, Object>> addedTokensList = (List<Map<String, Object>>) tokenizerData.get("added_tokens");
Map<String, Long> specialTokens = new HashMap<>();
if (addedTokensList != null) {
for (Map<String, Object> token : addedTokensList) {
specialTokens.put((String) token.get("content"), ((Double) token.get("id")).longValue());
}
}
String splitPattern = "'s|'t|'re|'ve|'m|'ll|'d| ?[\\p{L}]+| ?[\\p{N}]+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)|\\s+";
// The split regex sits under {"pattern": {"Regex": "..."}}, on the pre_tokenizer itself or
// on one child of a "Sequence" pre_tokenizer ("pretokenizers" is a JSON list).
Object preTok = tokenizerData.get("pre_tokenizer");
if (preTok instanceof Map<?, ?> preMap) {
Object children = preMap.get("pretokenizers");
List<?> candidates = (children instanceof List<?> list) ? list : List.of(preMap);
for (Object c : candidates) {
if (c instanceof Map<?, ?> cm && cm.get("pattern") instanceof Map<?, ?> pat
&& pat.get("Regex") instanceof String regex) {
splitPattern = regex;
break;
}
}
}
return new HfBpeTokenizer(vocab, merges, specialTokens, splitPattern);
}
@Override
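public long[] encode(String text) {
if (specialTokens.isEmpty()) {
// No special tokens: an empty alternation would match "" at every position
return encodeChunk(text).stream().mapToLong(Long::longValue).toArray();
}
List<Long> ids = new ArrayList<>();
// Longest tokens first: regex alternation takes the first alternative that matches,
// so a special token must not be shadowed by a shorter token that is its prefix
String specialTokenRegex = specialTokens.keySet().stream()
.sorted(Comparator.comparingInt(String::length).reversed())
.map(Pattern::quote)
.collect(Collectors.joining("|"));
Pattern specialTokenPattern = Pattern.compile(specialTokenRegex);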
Matcher matcher = specialTokenPattern.matcher(text);
int lastEnd = 0;
while (matcher.find()) {
if (matcher.start() > lastEnd) {
ids.addAll(encodeChunk(text.substring(lastEnd, matcher.start())));
}
ids.add(specialTokens.get(matcher.group()));
lastEnd = matcher.end();
}
if (lastEnd < text.length()) {
ids.addAll(encodeChunk(text.substring(lastEnd)));
}
return ids.stream().mapToLong(l -> l).toArray();
}
private List<Long> encodeChunk(String text) {
List<Long> ids = new ArrayList<>();
Matcher matcher = pattern.matcher(text);
while (matcher.find()) {
String token = matcher.group();
byte[] bytes = token.getBytes(StandardCharsets.UTF_8);
List<String> parts = new ArrayList<>();
for (byte b : bytes) {
parts.add(byteToUnicode(b));
}
while (parts.size() > 1) {
Pair bestPair = findBestPair(parts);
if (bestPair == null) {
break;
}
parts = merge(parts, bestPair);
}
for (String part : parts) {
if (vocab.containsKey(part)) {
ids.add(vocab.get(part));
}
}
}
return ids;
}
@Override
public String decode(long[] ids) {
StringBuilder sb = new StringBuilder();
List<Byte> byteBuffer = new ArrayList<>();
for (long id : ids) {
if (reversedSpecialTokens.containsKey(id)) {
if (!byteBuffer.isEmpty()) {
sb.append(decodeBytes(byteBuffer));
byteBuffer.clear();
}
sb.append(reversedSpecialTokens.get(id));
} else {
String token = reversedVocab.get(id);
if (token != null) {
for (char c : token.toCharArray()) {
byteBuffer.add(unicodeToByte(c));
}
}
}
}
if (!byteBuffer.isEmpty()) {
sb.append(decodeBytes(byteBuffer));
}
return sb.toString();
}
@Override
public long getEosTokenId() {
// Common names for end-of-sentence token
String[] eosNames = {"<|endoftext|>", "<|im_end|>", ""};
for (String name : eosNames) {
if (specialTokens.containsKey(name)) {
return specialTokens.get(name);
}
if (vocab.containsKey(name)) {
return vocab.get(name);
}
}
return -1; // Not found
}
@Override
public long getBosTokenId() {
return 0;
}
@Override
public String getChatTemplate() {
return "";
}
@Override
public void close() {
}
private String decodeBytes(List<Byte> byteBuffer) {
byte[] bytes = new byte[byteBuffer.size()];
for (int i = 0; i < byteBuffer.size(); i++) {
bytes[i] = byteBuffer.get(i);
}
return new String(bytes, StandardCharsets.UTF_8);
}
private Pair findBestPair(List<String> parts) {
Pair bestPair = null;
int minRank = Integer.MAX_VALUE;
for (int i = 0; i < parts.size() - 1; i++) {
Pair pair = new Pair(parts.get(i), parts.get(i + 1));
if (merges.containsKey(pair)) {
int rank = merges.get(pair);
if (rank < minRank) {
minRank = rank;
bestPair = pair;
}
}
}
return bestPair;
}
private List<String> merge(List<String> parts, Pair pairToMerge) {
List<String> newParts = new ArrayList<>();
int i = 0;
while (i < parts.size()) {
if (i < parts.size() - 1 && parts.get(i).equals(pairToMerge.first) && parts.get(i + 1).equals(pairToMerge.second)) {
newParts.add(pairToMerge.first + pairToMerge.second);
i += 2;
} else {
newParts.add(parts.get(i));
i++;
}
}
return newParts;
}
private static Map<Byte, Character> createByteToUnicodeMap() {
Map<Byte, Character> map = new HashMap<>();
int i = 0;
for (int b = 0; b < 256; b++) {
if ((b >= '!' && b <= '~') || (b >= '¡' && b <= '¬') || (b >= '®' && b <= 'ÿ')) {
map.put((byte) b, (char) b);
} else {
map.put((byte) b, (char) (256 + i++));
}
}
return Collections.unmodifiableMap(map);
}
private String byteToUnicode(byte b) {
return String.valueOf(byteToUnicodeMap.get(b));
}
private byte unicodeToByte(char c) {
return unicodeToByteMap.get(c);
}
@Override
public Map<String, Long> getVocab() {
return Collections.unmodifiableMap(vocab);
}
public Map<String, Long> getSpecialTokens() {
return Collections.unmodifiableMap(specialTokens);
}
private static class Pair {
final String first;
final String second;
Pair(String first, String second) {
this.first = first;
this.second = second;
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
Pair pair = (Pair) o;
return first.equals(pair.first) && second.equals(pair.second);
}
@Override
public int hashCode() {
return Objects.hash(first, second);
}
}
}
```
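The core of `encodeChunk` is the rank-based merge loop in `findBestPair`/`merge`. Below is a toy, self-contained version of that loop, together with a `merges.txt` parser that splits each line on a single space and skips the `#version` header. `BpeMergeSketch` is a hypothetical name used for illustration only. Pairs are keyed as `"left right"`, which works for byte-level BPE because raw spaces are remapped (to `Ġ`) before merging.

```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/** Hypothetical toy version of the rank-based BPE merge loop (findBestPair + merge). */
public final class BpeMergeSketch {

    /** merges.txt lines -> rank table keyed "left right"; lower rank = merged earlier. */
    public static Map<String, Integer> ranksFromMergesTxt(List<String> lines) {
        Map<String, Integer> ranks = new HashMap<>();
        int rank = 0;
        for (String raw : lines) {
            String line = raw.trim();
            if (line.isEmpty() || line.startsWith("#version")) continue; // header
            String[] parts = line.split(" ");
            if (parts.length != 2) continue; // malformed line
            ranks.putIfAbsent(parts[0] + " " + parts[1], rank++);
        }
        return ranks;
    }

    /** Repeatedly merges the adjacent pair with the lowest rank until no pair is ranked. */
    public static List<String> bpe(List<String> symbols, Map<String, Integer> ranks) {
        List<String> parts = new ArrayList<>(symbols);
        while (parts.size() > 1) {
            int bestRank = Integer.MAX_VALUE;
            String left = null, right = null;
            for (int i = 0; i + 1 < parts.size(); i++) {
                Integer r = ranks.get(parts.get(i) + " " + parts.get(i + 1));
                if (r != null && r < bestRank) {
                    bestRank = r;
                    left = parts.get(i);
                    right = parts.get(i + 1);
                }
            }
            if (left == null) break; // no mergeable pair left
            // Merge every occurrence of (left, right), scanning left to right
            List<String> merged = new ArrayList<>();
            for (int i = 0; i < parts.size(); i++) {
                if (i + 1 < parts.size() && parts.get(i).equals(left) && parts.get(i + 1).equals(right)) {
                    merged.add(left + right);
                    i++; // skip the right half of the pair
                } else {
                    merged.add(parts.get(i));
                }
            }
            parts = merged;
        }
        return parts;
    }
}
```

With the merges `l o`, `lo w` and `e r`, the symbols of `lower` end up as `low` + `er`: `l o` is merged first, then `lo w`, then `e r`, after which no adjacent pair is ranked.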
Now let's look at our test case:
```java
package lance.test;
import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer;
import lance.pytorch.*;
import lance.pytorch.tokenizer.LanceTokenizer;
import java.io.File;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
/**
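* Qwen3-VL-2B-Instruct full model test, V2
*
* Uses the DJL HuggingFace tokenizers library instead of the in-house tokenizer.
*
* Covers:
* 1. Downloading model files from HuggingFace (config, tokenizer, safetensors)
* 2. DJL tokenizer encode/decode checks
* 3. Loading and inspecting model weights
* 4. Text-generation inference
* 5. Multi-turn conversation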
public class Qwen3VLInstructTestV2 {
public static void main(String[] args) {
// Enable heap fallback for BF16 on macOS because native from_blob is unstable
System.setProperty("LANCE_ALLOW_HEAP_BF16_FALLBACK", "true");
// lance.pytorch.Bf16RuntimeConfig.ALLOW_HEAP_BF16_FALLBACK = true;
try {
printBanner("Qwen3-VL-2B-Instruct 完整测试 V2");
// Cache directory (kept separate from other models)
Path cacheDir = new File("./cache_qwen3vl_instruct").toPath();
System.out.println("缓存目录: " + cacheDir.toAbsolutePath());
// ==================== Load model ====================
System.out.println("\n[阶段1] 下载并加载模型...");
long startMs = System.currentTimeMillis();
Qwen3VLInstructModel model = Qwen3VLInstructLoader.load(cacheDir);
long loadMs = System.currentTimeMillis() - startMs;
System.out.println("\n✓ 模型加载完成 (" + loadMs + "ms)");
// ==================== Test 1: Tokenizer ====================
testTokenizer(model);
// ==================== Test 2: Weight check ====================
testWeights(model);
// Before running native inference, perform safety checks to avoid known
// issues (zero-copy BF16 failures, huge tensors that would force heap
// conversions, or device support absent). If unsafe, skip heavy native
// inference and print diagnostics so the user can address TorchOps/TorchBridge.
if (!isSafeForInference(model)) {
System.out.println("\n⚠ 检测到模型/权重目前不安全用于本地原生推理 (可能触发 native 崩溃)。");
System.out.println(" 建议:修复 lance.pytorch.TorchOps 与 TensorDataTorchBridge 的 BF16/from_blob/Device 路径,或在更强的环境(MPS/GPU)上运行。\n");
printBanner("跳过本地推理测试(已输出 tokenizer 与权重检查)");
return;
}
// ==================== Test 3: Text generation ====================
testTextGeneration(model);
// ==================== Test 4: Multi-turn conversation ====================
testMultiTurn(model);
// ==================== 完成 ====================
printBanner("所有测试完成 ✓");
} catch (Exception e) {
System.err.println("\n✗ 测试失败: " + e.getMessage());
e.printStackTrace();
System.exit(1);
}
}
// ==================== Test 1: Tokenizer ====================
private static void testTokenizer(Qwen3VLInstructModel model) {
System.out.println("\n" + "=".repeat(70));
System.out.println("[测试1] Tokenizer 编码/解码");
System.out.println("=".repeat(70));
try {
Object tokenizer = model.getTokenizerObj();
if (tokenizer == null) {
// fallback to old accessor if new one not found or returns null
tokenizer = model.getTokenizer();
}
if (tokenizer == null) {
throw new IllegalStateException("模型未返回 tokenizer 实例");
}
String[] tests = {
"Hello, world!",
"你好,世界!",
"What is artificial intelligence?",
"2 + 2 = 4",
"The quick brown fox jumps over the lazy dog.",
"<|im_start|>system\nYou are a helpful assistant.<|im_end|>",
};
int passed = 0;
for (String text : tests) {
try {
long[] ids = encodeTokenizer(tokenizer, text);
String decoded = decodeTokenizer(tokenizer, ids);
boolean ok = ids.length > 0 && decoded != null && !decoded.isEmpty();
String status = ok ? "✓" : "✗";
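// Guard against a null decode result so the log line itself cannot throw
String shownDecoded = decoded == null ? "" : decoded;
System.out.println(" " + status + " \"" + text.substring(0, Math.min(50, text.length())) +
(text.length() > 50 ? "..." : "") +
"\" → " + ids.length + " tokens → \"" + shownDecoded.substring(0, Math.min(50, shownDecoded.length())) +
(shownDecoded.length() > 50 ? "..." : "") + "\"");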
if (ok) passed++;
} catch (Exception e) {
System.err.println(" ✗ 处理 '" + text + "' 时出错: " + e.getMessage());
}
}
System.out.println(" 结果: " + passed + "/" + tests.length + " 通过");
} catch (Exception e) {
System.err.println(" ✗ Tokenizer 测试失败: " + e.getMessage());
e.printStackTrace();
}
}
private static long[] encodeTokenizer(Object tokenizer, String text) {
if (tokenizer instanceof HuggingFaceTokenizer djlTok) {
// DJL returns an Encoding object; convert to ids
var enc = djlTok.encode(text);
return enc.getIds();
}
if (tokenizer instanceof LanceTokenizer lanceTok) {
return lanceTok.encode(text);
}
throw new IllegalStateException("Unsupported tokenizer: " + tokenizer.getClass().getName());
}
private static String decodeTokenizer(Object tokenizer, long[] ids) {
if (tokenizer instanceof HuggingFaceTokenizer djlTok) {
// DJL has decode taking long[]
return djlTok.decode(ids, true);
}
if (tokenizer instanceof LanceTokenizer lanceTok) {
return lanceTok.decode(ids);
}
throw new IllegalStateException("Unsupported tokenizer: " + tokenizer.getClass().getName());
}
// ==================== Test 2: Weight Check ====================
private static void testWeights(Qwen3VLInstructModel model) {
System.out.println("\n" + "=".repeat(70));
System.out.println("[测试2] 权重检查");
System.out.println("=".repeat(70));
Map<String, org.bytedeco.pytorch.Tensor> weights = model.getWeights();
System.out.println(" 总权重数: " + weights.size());
// Print the first 20 weight names to check the naming convention
System.out.println("\n --- 前20个权重名称 ---");
int count = 0;
for (String key : weights.keySet()) {
if (count++ >= 20) break; // no need to walk the remaining keys
org.bytedeco.pytorch.Tensor t = weights.get(key);
long[] shape = t.sizes().vec().get();
System.out.println(" " + key + " " + java.util.Arrays.toString(shape) + " " + t.dtype());
}
if (weights.size() > 20) {
System.out.println(" ... (共 " + weights.size() + " 个)");
}
Qwen3VLInstructConfig cfg = model.getConfig();
// Check critical weights exist (text model)
String[] critical = {
"model.language_model.embed_tokens.weight",
"model.language_model.norm.weight",
"model.language_model.layers.0.self_attn.q_proj.weight",
"model.language_model.layers.0.self_attn.k_proj.weight",
"model.language_model.layers.0.self_attn.v_proj.weight",
"model.language_model.layers.0.self_attn.o_proj.weight",
"model.language_model.layers.0.self_attn.q_norm.weight",
"model.language_model.layers.0.self_attn.k_norm.weight",
"model.language_model.layers.0.mlp.gate_proj.weight",
"model.language_model.layers.0.mlp.up_proj.weight",
"model.language_model.layers.0.mlp.down_proj.weight",
"model.language_model.layers.0.input_layernorm.weight",
"model.language_model.layers.0.post_attention_layernorm.weight",
};
// Also check vision tower weights
String[] visionCritical = {
"model.visual.patch_embed.proj.weight",
"model.visual.blocks.0.attn.qkv.weight",
"model.visual.blocks.0.attn.proj.weight",
"model.visual.blocks.0.mlp.linear_fc1.weight",
"model.visual.blocks.0.mlp.linear_fc2.weight",
"model.visual.blocks.0.norm1.weight",
"model.visual.blocks.0.norm2.weight",
};
int found = 0;
for (String name : critical) {
org.bytedeco.pytorch.Tensor t = weights.get(name);
if (t != null) {
found++;
long[] shape = t.sizes().vec().get();
System.out.println(" ✓ " + name + " " + java.util.Arrays.toString(shape));
} else {
System.out.println(" ✗ " + name + " (缺失)");
}
}
// Check layer count
int maxLayer = -1;
for (String key : weights.keySet()) {
if (key.startsWith("model.language_model.layers.")) {
try {
int idx = Integer.parseInt(key.split("\\.")[3]);
maxLayer = Math.max(maxLayer, idx);
} catch (Exception ignore) {}
}
}
System.out.println(" 最大层索引: " + maxLayer + " (期望: " + (cfg.numHiddenLayers - 1) + ")");
System.out.println(" 结果: " + found + "/" + critical.length + " 关键文本权重存在");
// Check vision weights
System.out.println("\n --- 视觉模块权重 ---");
int vFound = 0;
for (String name : visionCritical) {
org.bytedeco.pytorch.Tensor t = weights.get(name);
if (t != null) {
vFound++;
long[] shape = t.sizes().vec().get();
System.out.println(" ✓ " + name + " " + java.util.Arrays.toString(shape));
} else {
System.out.println(" ✗ " + name + " (缺失)");
}
}
// Count vision blocks
int maxVisBlock = -1;
for (String key : weights.keySet()) {
if (key.startsWith("model.visual.blocks.")) {
try {
int idx = Integer.parseInt(key.split("\\.")[3]);
maxVisBlock = Math.max(maxVisBlock, idx);
} catch (Exception ignore) {}
}
}
if (maxVisBlock >= 0) {
System.out.println(" 视觉层数: " + (maxVisBlock + 1) + " (期望: " + cfg.visionDepth + ")");
}
System.out.println(" 结果: " + vFound + "/" + visionCritical.length + " 关键视觉权重存在");
}
// ==================== Test 3: Text Generation ====================
private static void testTextGeneration(Qwen3VLInstructModel model) {
System.out.println("\n" + "=".repeat(70));
System.out.println("[测试3] 文本生成");
System.out.println("=".repeat(70));
String[] prompts = {
"什么是人工智能?",
"Please write a haiku about spring.",
"2 + 2 equals?",
};
for (String prompt : prompts) {
System.out.println("\n [输入] " + prompt);
try {
long start = System.currentTimeMillis();
String result = model.generate(prompt, 64, 0.7f, 0.9f);
long elapsed = System.currentTimeMillis() - start;
String preview = result.length() > 200
? result.substring(0, 200) + "..."
: result;
System.out.println(" [输出] " + preview);
System.out.println(" [耗时] " + elapsed + "ms");
} catch (Exception e) {
System.err.println(" ✗ 生成失败: " + e.getMessage());
}
}
}
// ==================== Test 4: Multi-turn ====================
private static void testMultiTurn(Qwen3VLInstructModel model) {
System.out.println("\n" + "=".repeat(70));
System.out.println("[测试4] 多轮对话模拟");
System.out.println("=".repeat(70));
String[] turns = {
"你好!",
"请介绍一下你自己。",
"Can you count from 1 to 5?",
};
List<Map<String, String>> messages = new ArrayList<>();
messages.add(Map.of("role", "system", "content", "You are a helpful assistant."));
for (int i = 0; i < turns.length; i++) {
System.out.println("\n [轮次 " + (i + 1) + "] 用户: " + turns[i]);
messages.add(Map.of("role", "user", "content", turns[i]));
try {
// The generate method should handle the full chat history
String result = model.generate(messages, 32, 0.7f, 0.9f);
String preview = result.length() > 100 ? result.substring(0, 100) + "..." : result;
System.out.println(" [助手] " + preview);
messages.add(Map.of("role", "assistant", "content", result));
} catch (Exception e) {
System.err.println(" ✗ 失败: " + e.getMessage());
e.printStackTrace();
}
}
}
// ==================== Utility ====================
private static void printBanner(String text) {
System.out.println("\n╔" + "═".repeat(68) + "╗");
System.out.println("║ " + String.format("%-66s", text) + "║");
System.out.println("╚" + "═".repeat(68) + "╝");
}
// Safety check helper: returns false (skip heavy native inference) when the
// embedding weight is missing, when the tokenizer is DJL's native
// HuggingFaceTokenizer, when the embedding is BF16, or when it has more than
// 100M elements.
private static boolean isSafeForInference(Qwen3VLInstructModel model) {
try {
var weights = model.getWeights();
var embed = weights.get("model.language_model.embed_tokens.weight");
if (embed == null) {
System.err.println("[Safety] 未找到 embedding 权重,跳过推理");
return false;
}
// Check dtype
String dtype = "unknown";
try {
if (embed.dtype() != null && embed.dtype().name() != null) {
dtype = embed.dtype().name().getString();
}
} catch (Throwable ignored) {
// fallback
}
// Additional safety: if the tokenizer is DJL's HuggingFaceTokenizer (native,
// JNI/JNA backed), do not run native model inference here, because on some
// macOS setups using DJL's native tokenizer together with our native torch
// bridge has crashed. Use the DJL tokenizer only for standalone
// encoding/decoding, or switch to Lance's pure-Java tokenizer.
Object tok = model.getTokenizer();
if (tok instanceof HuggingFaceTokenizer) {
System.err.println("[Safety] Tokenizer is DJL HuggingFaceTokenizer (native). Skipping local native inference to avoid potential native crashes.");
return false;
}
// Strict BF16 safety: if dtype indicates BF16/BFloat16, refuse local native inference
if (dtype != null && dtype.toLowerCase().contains("bfloat")) {
System.err.println("[Safety] Detected BF16 embedding dtype (" + dtype + "). Zero-copy BF16 path currently unstable on this platform; skipping native inference.");
return false;
}
// Check size - avoid huge tensors that would require heap conversions
long[] shape = embed.sizes().vec().get();
long elems = 1L;
for (long s : shape) elems = Math.multiplyExact(elems, s);
// More than 100M elements is treated as unsafe here
if (elems > 100_000_000L) {
System.err.println("[Safety] embedding is very large (" + elems + " elems). Zero-copy path required but may be failing; skipping native inference.");
return false;
}
// Otherwise assume safe
return true;
} catch (Throwable t) {
System.err.println("[Safety] 推理安全检查失败: " + t.getMessage());
return false;
}
}
}
```
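The `LANCE_ALLOW_HEAP_BF16_FALLBACK` property set at the top of `main` turns on a path that decodes BF16 weights on the Java heap instead of wrapping native memory with `from_blob`. That code is not shown in this post. As a minimal sketch of the underlying idea only (the class `Bf16HeapDecoder` is hypothetical, not Lance's API), widening BF16 to float32 is a 16-bit shift, because BF16 is simply the upper half of an IEEE-754 float32. The safetensors format stores tensor bytes little-endian.

```java
import java.nio.ByteBuffer;
import java.nio.ByteOrder;

/** Hypothetical helper: decodes raw BF16 bytes (e.g. a safetensors slice) into float32 on the heap. */
final class Bf16HeapDecoder {
    private Bf16HeapDecoder() {}

    /** BF16 keeps the top 16 bits of a float32, so widening is a left shift by 16. */
    static float toFloat(short bf16Bits) {
        return Float.intBitsToFloat((bf16Bits & 0xFFFF) << 16);
    }

    /** Decodes every 2-byte BF16 element between position and limit; the source buffer is not moved. */
    static float[] decode(ByteBuffer raw) {
        // duplicate() shares the bytes but has its own position and byte order
        ByteBuffer buf = raw.duplicate().order(ByteOrder.LITTLE_ENDIAN);
        int n = buf.remaining() / 2;
        float[] out = new float[n];
        for (int i = 0; i < n; i++) {
            out[i] = toFloat(buf.getShort());
        }
        return out;
    }
}
```

The cost of this path is that the float32 copy needs twice the bytes of the BF16 original, all of it on the JVM heap. That is why a zero-copy route matters for multi-gigabyte checkpoints.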
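In `testWeights`, the layer check only compares the highest index (`maxLayer`, `maxVisBlock`) with the config. A checkpoint that lost an intermediate layer, for example because one shard failed to load, would still pass. Here is a sketch that lists the missing indices instead. `LayerIndexCheck` is a hypothetical helper, and the key prefixes are the ones already used in the test.

```java
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.TreeSet;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

/** Hypothetical helper: finds which numbered blocks actually have weights. */
final class LayerIndexCheck {
    private LayerIndexCheck() {}

    /** Distinct N in keys shaped like prefix + N + "." (e.g. "model.language_model.layers.0.mlp..."). */
    static TreeSet<Integer> indices(Collection<String> keys, String prefix) {
        Pattern p = Pattern.compile(Pattern.quote(prefix) + "(\\d+)\\.");
        TreeSet<Integer> found = new TreeSet<>();
        for (String key : keys) {
            Matcher m = p.matcher(key);
            if (m.lookingAt()) { // the prefix must match at the start of the key
                found.add(Integer.parseInt(m.group(1)));
            }
        }
        return found;
    }

    /** Indices in [0, expected) that have no weights at all. */
    static List<Integer> missing(Collection<String> keys, String prefix, int expected) {
        TreeSet<Integer> found = indices(keys, prefix);
        List<Integer> gaps = new ArrayList<>();
        for (int i = 0; i < expected; i++) {
            if (!found.contains(i)) {
                gaps.add(i);
            }
        }
        return gaps;
    }
}
```

Inside `testWeights` this would be called as `LayerIndexCheck.missing(weights.keySet(), "model.language_model.layers.", cfg.numHiddenLayers)`, and the same call with `"model.visual.blocks."` and `cfg.visionDepth` covers the vision tower.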
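`testTextGeneration` calls `model.generate(prompt, 64, 0.7f, 0.9f)`. We read the three numeric arguments as max new tokens, temperature and top-p, but that is an assumption, since the signature is not shown here. Under that reading, each decoding step picks the next token roughly as in this sketch of standard temperature plus nucleus (top-p) sampling. `NucleusSampler` is hypothetical and works on a plain `float[]` of logits. A real implementation would more likely do this on the device tensor.

```java
import java.util.Arrays;
import java.util.Comparator;
import java.util.Random;

/** Hypothetical helper: temperature + top-p (nucleus) sampling over one step's logits. */
final class NucleusSampler {
    private NucleusSampler() {}

    /** temperature <= 0 means greedy decoding; topP in (0, 1] bounds the candidate mass. */
    static int sample(float[] logits, float temperature, float topP, Random rng) {
        if (temperature <= 0f) {
            return argmax(logits);
        }
        int n = logits.length;
        // Softmax of logits / temperature, shifted by the max for numerical stability
        double max = Double.NEGATIVE_INFINITY;
        for (float l : logits) {
            max = Math.max(max, l / temperature);
        }
        double[] probs = new double[n];
        double sum = 0.0;
        for (int i = 0; i < n; i++) {
            probs[i] = Math.exp(logits[i] / temperature - max);
            sum += probs[i];
        }
        // Token ids sorted by descending probability (stable for ties)
        Integer[] order = new Integer[n];
        for (int i = 0; i < n; i++) {
            order[i] = i;
        }
        Arrays.sort(order, Comparator.comparingDouble((Integer id) -> probs[id]).reversed());
        // Keep the shortest prefix whose mass reaches topP, and always at least one token
        double kept = 0.0;
        int cut = 0;
        while (cut < n && (cut == 0 || kept / sum < topP)) {
            kept += probs[order[cut]];
            cut++;
        }
        // Draw from the kept prefix in proportion to its renormalized probabilities
        double r = rng.nextDouble() * kept;
        for (int k = 0; k < cut; k++) {
            r -= probs[order[k]];
            if (r <= 0.0) {
                return order[k];
            }
        }
        return order[cut - 1]; // floating-point round-off guard
    }

    static int argmax(float[] xs) {
        int best = 0;
        for (int i = 1; i < xs.length; i++) {
            if (xs[i] > xs[best]) {
                best = i;
            }
        }
        return best;
    }
}
```

A lower temperature sharpens the distribution before the top-p cut. With `topP = 0.9`, the long tail of unlikely tokens is never sampled, which keeps short 64-token answers on topic.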
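`testMultiTurn` keeps the whole history in `messages` and passes it to `generate` on every turn, and the tokenizer test already contains a raw `<|im_start|>system ... <|im_end|>` string. That suggests the model side renders the history in Qwen's ChatML layout. Here is a minimal sketch of that rendering, assuming plain-text messages only. `ChatMlPrompt` is hypothetical, and Qwen3-VL's real chat template (shipped with the model) also handles images, tools and other cases that this ignores.

```java
import java.util.List;
import java.util.Map;

/** Hypothetical helper: renders role/content messages in ChatML form for text-only prompts. */
final class ChatMlPrompt {
    private ChatMlPrompt() {}

    static String render(List<Map<String, String>> messages, boolean addGenerationPrompt) {
        StringBuilder sb = new StringBuilder();
        for (Map<String, String> msg : messages) {
            sb.append("<|im_start|>").append(msg.get("role")).append('\n')
              .append(msg.get("content")).append("<|im_end|>\n");
        }
        if (addGenerationPrompt) {
            // Open an assistant turn so the model continues speaking as the assistant
            sb.append("<|im_start|>assistant\n");
        }
        return sb.toString();
    }
}
```

Because each assistant reply is appended back into `messages`, every turn re-renders and re-encodes the full conversation. Without a KV cache carried across calls, prompt cost grows with the number of turns.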
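`printBanner` pads with `String.format("%-66s", text)`, which counts characters, not terminal columns. In terminals that draw CJK characters two columns wide, the title `Qwen3-VL-2B-Instruct 完整测试 V2` (four CJK characters) pushes the right border four columns too far. Below is a rough width-aware alternative. `DisplayWidth` is hypothetical, and it covers only a subset of the Unicode East Asian Wide/Fullwidth ranges. Emoji and ambiguous-width characters such as the box-drawing `═` are still counted as one column.

```java
/** Hypothetical helper: pads text by approximate terminal width instead of char count. */
final class DisplayWidth {
    private DisplayWidth() {}

    /** Rough check for double-width code points (subset of East Asian Wide/Fullwidth). */
    static boolean isWide(int cp) {
        return (cp >= 0x1100 && cp <= 0x115F)   // Hangul Jamo leading consonants
            || (cp >= 0x2E80 && cp <= 0x303E)   // CJK radicals, symbols and punctuation
            || (cp >= 0x3041 && cp <= 0x33FF)   // Hiragana, Katakana, CJK compatibility
            || (cp >= 0x3400 && cp <= 0x4DBF)   // CJK Extension A
            || (cp >= 0x4E00 && cp <= 0x9FFF)   // CJK Unified Ideographs
            || (cp >= 0xAC00 && cp <= 0xD7A3)   // Hangul syllables
            || (cp >= 0xF900 && cp <= 0xFAFF)   // CJK compatibility ideographs
            || (cp >= 0xFF01 && cp <= 0xFF60)   // Fullwidth ASCII variants
            || (cp >= 0xFFE0 && cp <= 0xFFE6);  // Fullwidth signs
    }

    /** Approximate number of terminal columns the string occupies. */
    static int width(String s) {
        return s.codePoints().map(cp -> isWide(cp) ? 2 : 1).sum();
    }

    /** Right-pads with spaces to {@code columns}; longer text is returned unchanged. */
    static String padRight(String s, int columns) {
        return s + " ".repeat(Math.max(0, columns - width(s)));
    }
}
```

In `printBanner`, the middle line would then become `System.out.println("║ " + DisplayWidth.padRight(text, 66) + "║");`.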
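`isSafeForInference` refuses embeddings larger than 100M elements because, without a working zero-copy path, the tensor would have to be materialized on the heap. A quick back-of-the-envelope helper makes that threshold concrete. `HeapCopyEstimate` is hypothetical, and the array-length bound below is the conservative value commonly used in the JDK. The exact JVM limit is implementation-dependent but always just under `Integer.MAX_VALUE`, since Java arrays are `int`-indexed.

```java
/** Hypothetical helper: estimates what materializing a tensor on the Java heap would cost. */
final class HeapCopyEstimate {
    private HeapCopyEstimate() {}

    /** Conservative maximum Java array length (actual limit is JVM-dependent). */
    static final long MAX_ARRAY_LENGTH = Integer.MAX_VALUE - 8L;

    /** Element count = product of the shape; throws ArithmeticException on overflow. */
    static long numel(long[] shape) {
        long n = 1L;
        for (long dim : shape) {
            n = Math.multiplyExact(n, dim);
        }
        return n;
    }

    /** Whether the whole tensor can live in a single Java array. */
    static boolean fitsInOneJavaArray(long[] shape) {
        return numel(shape) <= MAX_ARRAY_LENGTH;
    }

    /** Bytes needed for a float32 heap copy (4 bytes per element), e.g. after widening BF16. */
    static long float32HeapBytes(long[] shape) {
        return Math.multiplyExact(numel(shape), 4L);
    }
}
```

At the 100M-element cutoff, a float32 heap copy is already 400,000,000 bytes (about 400 MB) for a single tensor. That is why the test prefers to skip native inference rather than risk the JVM running out of heap.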