语言大模型生命周期 - 推理

131 阅读5分钟

前言

名词解释

名词解释备注
LoRA、QLoRA模型微调方法,能加快微调速度降低成本的
HuggingFace是一家专注于自然语言处理(NLP)和机器学习的公司,提供了广泛的工具和资源来帮助开发者和研究人员构建、训练和部署机器学习模型
HuggingFace TransformersHugging Face 的 Transformers 库支持多种预训练模型,涵盖了各种自然语言处理(NLP)任务,如文本分类、文本生成、问答系统等
LLMLarge Language Model 语言大模型Chatgpt、Llama
WASM用于在现代 Web 浏览器中运行高性能应用程序。它是一种低级编程语言(支持 C 语言),设计用于在 Web 上执行,并且可以与 JavaScript 互操作。C 语言编译 WASM 模块
c #include <emscripten/emscripten.h> EMSCRIPTEN_KEEPALIVE int add(int a, int b) { return a + b; } 在 web 代码运行
html <!DOCTYPE html> <html> <head> <title>WebAssembly Example</title> </head> <body> <script> fetch('example.wasm').then(response => response.arrayBuffer() ).then(bytes => WebAssembly.instantiate(bytes) ).then(results => { const instance = results.instance; console.log(instance.exports.add(1, 2)); // 输出: 3 }); </script> </body> </html>

大模型主要生命周期

  1. 训练:大量的数据训练模型,使其执行特定任务
  2. 评估:使用验证集和测试集评估模型性能,计算评估指标
  3. 微调:进一步训练模型,提高性能
  4. 推理:使用训练好的模型进行预测或分类等任务
  5. 部署:将模型集成到实际应用中,提供预测或分类等服务

模型推理

主要流程

  1. 加载模型:从存储中加载训练好的模型。
  2. 预处理输入数据:将输入数据转换为模型可以接受的格式。
  3. 进行推理:使用模型对输入数据进行预测。
  4. 后处理输出数据:将模型的输出转换为用户可以理解的格式。
  5. 返回结果:将预测结果返回给用户或系统。

主流框架

主要针对 LLM 的推理部署框架

  • vLLM
    • 地址:github.com/vllm-projec…
    • 简介:适用于大批量 Prompt 输入,并对推理速度要求高的场景。吞吐量比 HuggingFace Transformers 高 14-24倍,比 HuggingFace Text Generation Inference(TGI)高 2.2x-2.5 倍,实现了 Continuous batching 和 PagedAttention 等技巧。但该框架对适配器(LoRA、QLoRA等)的支持不友好且缺少权重量化。
  • text-generation-inference
    • 地址:github.com/huggingface…
    • 简介:用于文本生成推断的Rust、Python 和 gRPC部署框架,可以监控服务器负载,实现了 flash attention 和 Paged attention,所有的依赖项都安装在 Docker 中:支持 HuggingFace 模型;但该框架对适配器(LoRA、QLoRA等)的支持不友好。
  • MLC LLM
    • 地址:github.com/mlc-ai/mlc-…
    • 简介:支持不同平台上的不同设备部署推理,包括移动设备(iOS 或 Android 设备等)的高效推理,压缩等。但对大规模批量调用相对不友好。

MLC - 机器学习编译

机器学习编译 (machine learning compilation, MLC) 是指,将机器学习算法从开发阶段,通过变换和优化算法,使其变成部署状态。

权重转换(Weight)

将模型的权重从一种格式转换为另一种格式,以便在不同的硬件或软件环境中进行高效推理或训练。

MLC 就是通过算法实现权重转换提高模型的算力。

Show me the code

运行环境

  • 操作系统:WIN11 WSL2
  • 硬件:RTX 3060 + 6G + i7 + 64G RAM
  • 模型:Llama-2-7b-chat-hf,申请访问地址点击这里
  • GPU:NVIDIA CU12.4

准备工作

  • 下载模型需要设置 git lfs fetch 和相关代理
  • 安装 WSL 的 UBUNTU 22.04 + DOCKER
  • 有 VPN 的话,docker 和 wsl 的代理要设置好转发端口之类的
  • 安装 conda 管理 python 环境

MLC 运行环境

安装 MLC LLM Python

  • 命令行脚本

具体可点击这里参考这个教程。

注意事项

  1. 有两种安装模式
    1. Prebuild 平台预构建包
    2. Build from source 自己构建,可以设置一些参数

安装 TVM 工具

构建 MLC LLM 的必备工具

  • 高性能代码生成器
  • 支持推理和训练功能
  • 高效的 python 编译器实现,MLC 的编译是使用它的 api 实现的

具体可点击这里参考这个教程。

WebLLM JS SDK

WebLLM 是一个高性能的浏览器内 LLM 推理引擎,以 AI 驱动的 Web 应用程序和代理的后端。

为 MLCEngine 的 Web 后端提供了专门的运行时,利用 WebGPU 进行本地加速,提供与 OpenAI 兼容的 API,并为 Web Worker 提供内置支持,以将繁重的计算与 UI 流程分开。

构建能在 WASM 运行的模型

构建教程

两种模式

  • 使用已有模型库
  • 构建自定义模型库

llm.mlc.ai/docs/deploy…

注意事项
  • 建议使用已有模型库 Llama,这样只需要转换个模型权重就行
上传配置和模型到 HuggingFace 和 github
  1. wasm 上传到 github
  2. 权重模型上传到 huggingface
    1. 利用 HG 的公共 api 转换

Demo

创建对话引擎
const useChatAi = ({
  engine,
  updateMessage,
  updateRuntimeStats,
} = {} as UseChatAiParams): UseChatAiReturnType => {
  const chatLoaded = useRef(false)
  const requestInProgress = useRef(false)
  const chatRequestChain = useRef<Promise<void>>(Promise.resolve())
  const [chatHistory, setChatHistory] = useState<ChatCompletionMessageParam[]>([])

  const unloadChat = useCallback(async () => {
    try {
      await engine.unload()

      chatLoaded.current = false
    } catch (err: unknown) {
      console.log(err)
    }
  }, [engine])

  const pushTask = useCallback((task: () => Promise<void>) => {
    const lastEvent = chatRequestChain.current
    const newEvent = lastEvent.then(task)

    chatRequestChain.current = newEvent
  }, [])

  const reset = useCallback(async (clearMessages: () => void) => {
    if (requestInProgress.current) {
      engine.interruptGenerate()
    }

    setChatHistory([])

    pushTask(async () => {
      await engine.resetChat()

      clearMessages()
    })

    return chatRequestChain.current
  }, [engine, pushTask])

  const asyncInit = useCallback(async () => {
    if (chatLoaded.current) return

    requestInProgress.current = true
    updateMessage('init', '', true)

    engine.setInitProgressCallback((report: { text: string }): void => {
      updateMessage('init', report.text, false)
    })

    try {
      const myAppConfig: AppConfig = {
        model_list: [
          {
            "model": "https://huggingface.co/my-hf-account/zen-huggingface-repo/resolve/main/",
            "model_id": "Llama-3.1-8B-Instruct-q4f32_1-MLC",
            "model_lib": "https://raw.githubusercontent.com/zen/my-repo/main/Llama-3.1-8B-Instruct-q4f32_1-MLC-webgpu.wasm",
            "required_features": ["shader-f16"],
          },
        ]
      }
      const selectedModel = 'Llama-3.1-8B-Instruct-q4f32_1-MLC'

      await engine.reload(selectedModel, { appConfig: appConfig })
    } catch (err: unknown) {
      updateMessage('error', 'Init error, ' + (err?.toString() ?? ''), true)

      console.log(err)

      await unloadChat()

      requestInProgress.current = false

      return
    }

    requestInProgress.current = false
    chatLoaded.current = true
  }, [engine, unloadChat, updateMessage])

  const asyncGenerate = useCallback(async (prompt: string) => {
    await asyncInit()

    requestInProgress.current = true

    if (prompt === '') {
      requestInProgress.current = false
      return
    }

    updateMessage('right', prompt, true)
    updateMessage('left', '', true)

    try {
      let curMessage = ''
      let usage: CompletionUsage | undefined = undefined
      const newChatHistory = [
        ...chatHistory,
        {
          content: prompt,
          role: 'user',
        },
      ] as ChatCompletionMessageParam[]

      setChatHistory(newChatHistory)

      const completion = await engine.chat.completions.create({
        messages: newChatHistory,
        stream: true,
        stream_options: { include_usage: true },
      })

      for await (const chunk of completion) {
        const curDelta = chunk.choices[0]?.delta.content

        if (curDelta) {
          curMessage += curDelta
        }

        updateMessage('left', curMessage, false)

        if (chunk.usage) {
          usage = chunk.usage
        }
      }

      const output = await engine.getMessage()

      setChatHistory((prevChatHistory) => [
        ...prevChatHistory,
        {
          content: output,
          role: 'assistant',
        },
      ])

      updateMessage('left', output, false)

      if (usage) {
        const runtimeStats =
          `prompt_tokens: ${String(usage.prompt_tokens)}, ` +
          `completion_tokens: ${String(usage.completion_tokens)}, ` +
          `prefill: ${usage.extra.prefill_tokens_per_s.toFixed(4)} tokens/sec, ` +
          `decoding: ${usage.extra.decode_tokens_per_s.toFixed(4)} tokens/sec`
        updateRuntimeStats(runtimeStats)
      }
    } catch (err: unknown) {
      updateMessage(
        'error',
        'Generate error, ' + (err?.toString() ?? ''),
        true,
      )

      console.log(err)

      await unloadChat()
    }

    requestInProgress.current = false
  }, [asyncInit, chatHistory, engine, unloadChat, updateMessage, updateRuntimeStats])

  const generate = useCallback(async (prompt: string) => {
    if (requestInProgress.current) {
      return
    }

    pushTask(async () => {
      await asyncGenerate(prompt)
    })

    return chatRequestChain.current
  }, [pushTask, asyncGenerate])

  return {
    asyncInit,
    generate,
    reset,
  }
}

export default useChatAi