Transformers.js 的基本使用

203 阅读3分钟

使用背景

目前 WebApp 使用 AI 模型一般是调用后台接口(模型的云服务或者在自己的服务器上使用模型),这样的方案存在服务器成本、网络稳定性、用户隐私等问题。所以,在这种情况下就可以使用 Transformers.js 在本地进行 AI 推理。

什么是 Transformers.js

Transformers.js 是由 Hugging Face 开发的一个 JavaScript 库,旨在让用户能够直接在浏览器中运行最先进的机器学习模型,而无需服务器支持。该库与 Hugging Face 的 Python 版 transformers 库功能等效,支持多种预训练模型,涵盖自然语言处理、计算机视觉和语音识别等任务。

Transformers.js 使用 ONNX Runtime 运行模型,支持在 CPU 和 WebGPU 上执行,提供了高效的模型转换和量化工具,方便用户将 PyTorch、TensorFlow 或 JAX 模型转换为 ONNX 格式并在浏览器中运行。

功能列表

  • 自然语言处理:文本分类、命名实体识别、问答、语言建模、摘要、翻译、多项选择和文本生成。
  • 计算机视觉:图像分类、对象检测、分割和深度估计。
  • 语音识别:自动语音识别、音频分类和文本转语音。
  • 多模态任务:嵌入、零镜头音频分类、零镜头图像分类和零镜头对象检测。

Transformers.js 在 v3 版本可以利用 WebGPU 进行高性能推理,速度相比 wasm 方案有了极大的提升。huggingface.co/blog/transf…

使用

线上 DEMO:transformers-js-basic-use.vercel.app/

快速安装

# npm
npm i @huggingface/transformers

# pnpm
pnpm add @huggingface/transformers

接口使用

import { pipeline } from "@huggingface/transformers";

const generator = await pipeline('summarization', 'Xenova/distilbart-cnn-6-6', config);
const text = 'xxx';
const output = await generator(text, {
  max_new_tokens: 100,
}); // [{ summary_text: 'xxx' }]

使用 pipeline 接口(pipeline 简化了模型的下载、加载和使用)加载模型,第一个参数传递的是 task 类型,这里使用的是 'summarization',第二个参数是模型名字(如果忽略的话会使用默认模型),第三个参数是配置信息(包含进度、缓存、设置等),这里主要了解下 config.progress_callbackconfig.deviceconfig.dtype 的使用。

  • config.progress_callback :进度回调。这里的进度包含初始化、下载、加载和准备阶段

              progress_callback: data => {
                switch (data.status) {
                  // 模型开始初始化
                  case "initiate":
                    {
                      const { name, file } = data;
                      console.log("initiate", name, file);
                    }
                    break;
                  // 模型开始下载
                  case "download":
                    {
                      const { name, file } = data;
                      console.log("download", name, file);
                    }
                    break;
                  // 模型下载进度
                  case "progress":
                    {
                      const { name, file, progress, loaded, total } = data;
                      console.log("progress", name, file, progress, loaded, total);
                    }
                    break;
                  // 模型下载完成
                  case "done":
                    {
                      const { name, file } = data;
                      console.log("done", name, file);
                    }
                    break;
                  // 模型准备完成
                  case "ready":
                    {
                      const { task, model } = data;
                      console.log("ready", task, model);
                    }
                    break;
                }
              },
    
  • config.device :设置推理的设置,默认使用的是 wasm,如果条件允许的话尽量使用 webgpu,推理速度会更快。完整的 device 类型如下:

    /**
     * The list of devices supported by Transformers.js
     */
    export const DEVICE_TYPES = Object.freeze({
        auto: 'auto', // Auto-detect based on device and environment
        gpu: 'gpu', // Auto-detect GPU
        cpu: 'cpu', // CPU
        wasm: 'wasm', // WebAssembly
        webgpu: 'webgpu', // WebGPU
        cuda: 'cuda', // CUDA
        dml: 'dml', // DirectML
    
        webnn: 'webnn', // WebNN (default)
        'webnn-npu': 'webnn-npu', // WebNN NPU
        'webnn-gpu': 'webnn-gpu', // WebNN GPU
        'webnn-cpu': 'webnn-cpu', // WebNN CPU
    });
    
  • config.dtype :设置推理的精度,如果忽略的话会根据 device 使用不同的 dtype,如果 devicewasm 的话 dtypeq8,否则为 fp32 。需要根据具体的应用场景和模型使用参数,一般来说 webgpu 使用 fp32 或者 fp16wasm 使用 q8 或者 q4

如果使用的模型没有在内置的 pipeline 中,那么需要自己手动处理模型的逻辑

// 使用 pipline
await pipeline("summarization", "Xenova/distilbart-cnn-6-6");
const summary = await generatorRef.current(input);
const output = summary[0].summary_text;

// 不适用 pipline
const model = await AutoModelForSeq2SeqLM.from_pretrained("Xenova/distilbart-cnn-6-6");
const tokenizer = await AutoTokenizer.from_pretrained("Xenova/distilbart-cnn-6-6");
const inputs = await tokenizerRef.current([input], {
    truncation: true,
    return_tensors: true,
});
const modelOutputs = await model .generate(inputs);
const summary = await tokenizer .batch_decode(modelOutputs, {
    skip_special_tokens: true,
});
const output = summary[0];

参考

  1. www.aisharenet.com/transformer…
  2. blog.csdn.net/shebao3333/…
  3. blog.csdn.net/m0_38015699…