每天一个高级前端知识 - Day 8

2 阅读3分钟

每天一个高级前端知识 - Day 8

今日主题:前端AI - 在浏览器中运行大语言模型,使用WebGPU + Transformers.js实现本地AI助手

核心概念:AI推理正在从云端走向本地

得益于WebGPU和量化技术,现在可以在浏览器中直接运行小规模LLM(如Phi-3、Gemma-2B、Qwen-1.8B)。

性能数据(MacBook M2):

Phi-3-mini (3.8B 4-bit): 15-20 tokens/秒
Gemma-2B (2.6B 4-bit):   25-30 tokens/秒
Zephyr-3B (3B 4-bit):    20-25 tokens/秒

🔬 浏览器AI架构

┌─────────────────────────────────────────────┐
│          JavaScript (控制层)                 │
│  - 文本预处理/后处理                          │
│  - KV Cache管理                              │
│  - 流式输出                                  │
└─────────────────────────────────────────────┘
                    ↓
┌─────────────────────────────────────────────┐
│         Transformers.js (ONNX Runtime)      │
│  - 模型加载与解析                             │
│  - Tokenizer (BPE/SentencePiece)            │
│  - 张量操作抽象层                             │
└─────────────────────────────────────────────┘
                    ↓
┌─────────────────────────────────────────────┐
│              WebGPU 后端                     │
│  - 矩阵乘法加速 (GEMM)                       │
│  - Attention计算                            │
│  - 量化推理 (INT4/INT8)                     │
└─────────────────────────────────────────────┘

🚀 完整实现:本地AI助手

// 1. 安装依赖
// npm install @xenova/transformers

import { pipeline, env, AutoModel, AutoTokenizer } from '@xenova/transformers';

// 配置WebGPU后端
env.backends.onnx.wasm.numThreads = navigator.hardwareConcurrency;
env.useBrowserCache = true;
env.localModelPath = '/models/'; // 本地缓存模型

class LocalLLM {
  constructor(config) {
    this.modelName = config.modelName || 'Xenova/phi-3-mini-4k-instruct';
    this.model = null;
    this.tokenizer = null;
    this.isLoading = false;
    this.isGenerating = false;
    this.maxTokens = config.maxTokens || 2048;
    this.temperature = config.temperature || 0.7;
    this.topK = config.topK || 50;
    this.topP = config.topP || 0.95;
    this.kvCache = null;
    
    // 进度回调
    this.onProgress = config.onProgress || (() => {});
    this.onToken = config.onToken || (() => {});
    this.onComplete = config.onComplete || (() => {});
  }
  
  async load() {
    if (this.model) return;
    
    this.isLoading = true;
    this.onProgress({ status: 'loading', progress: 0 });
    
    try {
      // 加载Tokenizer(使用Web Worker避免阻塞UI)
      this.tokenizer = await AutoTokenizer.from_pretrained(this.modelName, {
        progress_callback: (progress) => {
          this.onProgress({ 
            status: 'loading-tokenizer', 
            progress: progress.progress 
          });
        }
      });
      
      // 加载模型(WebGPU后端自动选择)
      this.model = await AutoModel.from_pretrained(this.modelName, {
        dtype: 'q4f16',  // 4-bit量化,减少75%内存
        device: 'webgpu', // 使用WebGPU
        progress_callback: (progress) => {
          this.onProgress({ 
            status: 'loading-model', 
            progress: progress.progress 
          });
        }
      });
      
      this.isLoading = false;
      this.onProgress({ status: 'ready', progress: 1 });
      
      return true;
    } catch (error) {
      console.error('模型加载失败:', error);
      this.isLoading = false;
      throw error;
    }
  }
  
  async generate(prompt, options = {}) {
    if (!this.model) {
      throw new Error('模型未加载,请先调用load()');
    }
    
    if (this.isGenerating) {
      console.warn('已有生成任务在进行中');
      return;
    }
    
    this.isGenerating = true;
    const startTime = performance.now();
    
    // 构建对话模板(Phi-3格式)
    const messages = [
      { role: 'system', content: 'You are a helpful AI assistant.' },
      { role: 'user', content: prompt }
    ];
    
    const formattedPrompt = this.applyChatTemplate(messages);
    
    // Tokenize输入
    const inputIds = this.tokenizer.encode(formattedPrompt);
    let currentIds = inputIds;
    
    // KV Cache(关键优化:避免重复计算)
    if (!this.kvCache) {
      this.kvCache = this.model.createKVCache();
    }
    
    let generatedTokens = [];
    let tokenCount = 0;
    const maxNewTokens = options.maxTokens || this.maxTokens;
    
    // 自回归生成循环
    for (let step = 0; step < maxNewTokens; step++) {
      // 前向传播
      const outputs = await this.model.forward({
        input_ids: currentIds,
        past_key_values: this.kvCache,
        use_cache: true
      });
      
      // 更新KV Cache
      this.kvCache = outputs.past_key_values;
      
      // 获取下一个token的logits
      const nextTokenLogits = outputs.logits.slice(-1)[0];
      
      // 采样策略
      let nextTokenId = this.sampleToken(nextTokenLogits, {
        temperature: options.temperature || this.temperature,
        topK: options.topK || this.topK,
        topP: options.topP || this.topP
      });
      
      // 检查结束符
      if (nextTokenId === this.tokenizer.eos_token_id) {
        break;
      }
      
      // 解码token
      const token = this.tokenizer.decode([nextTokenId]);
      generatedTokens.push(token);
      tokenCount++;
      
      // 流式输出回调
      this.onToken(token, {
        tokenCount,
        timestamp: performance.now() - startTime
      });
      
      // 准备下一步输入
      currentIds = [[nextTokenId]];
    }
    
    // 生成完成
    const fullText = generatedTokens.join('');
    this.isGenerating = false;
    
    this.onComplete(fullText, {
      tokenCount,
      totalTime: performance.now() - startTime,
      tokensPerSecond: tokenCount / ((performance.now() - startTime) / 1000)
    });
    
    return fullText;
  }
  
  sampleToken(logits, options) {
    // Top-K采样
    const logitsArray = Array.from(logits);
    
    if (options.topK && options.topK < logitsArray.length) {
      // 获取Top-K索引
      const indices = logitsArray
        .map((val, idx) => ({ val, idx }))
        .sort((a, b) => b.val - a.val)
        .slice(0, options.topK);
      
      // 重置其他logits为-Infinity
      const filtered = new Array(logitsArray.length).fill(-Infinity);
      indices.forEach(({ idx, val }) => {
        filtered[idx] = val;
      });
      
      // 应用Temperature
      const scaled = filtered.map(v => v / options.temperature);
      return this.softmaxSample(scaled);
    }
    
    // 普通Temperature采样
    const scaled = logitsArray.map(v => v / options.temperature);
    return this.softmaxSample(scaled);
  }
  
  softmaxSample(logits) {
    // 计算softmax
    const maxLogit = Math.max(...logits);
    const exp = logits.map(x => Math.exp(x - maxLogit));
    const sum = exp.reduce((a, b) => a + b, 0);
    const probs = exp.map(x => x / sum);
    
    // 随机采样
    const random = Math.random();
    let cumulative = 0;
    for (let i = 0; i < probs.length; i++) {
      cumulative += probs[i];
      if (random < cumulative) return i;
    }
    return probs.length - 1;
  }
  
  applyChatTemplate(messages) {
    // Phi-3 chat template
    let result = '';
    for (const msg of messages) {
      if (msg.role === 'system') {
        result += `<|system|>\n${msg.content}<|end|>\n`;
      } else if (msg.role === 'user') {
        result += `<|user|>\n${msg.content}<|end|>\n`;
      } else if (msg.role === 'assistant') {
        result += `<|assistant|>\n${msg.content}<|end|>\n`;
      }
    }
    result += '<|assistant|>\n';
    return result;
  }
  
  // 模型卸载,释放内存
  async unload() {
    this.model = null;
    this.tokenizer = null;
    this.kvCache = null;
    this.isGenerating = false;
    
    // 触发垃圾回收(提示浏览器)
    if (window.gc) window.gc();
  }
  
  // 获取模型信息
  getInfo() {
    return {
      modelName: this.modelName,
      isLoaded: !!this.model,
      isLoading: this.isLoading,
      isGenerating: this.isGenerating,
      maxTokens: this.maxTokens,
      temperature: this.temperature
    };
  }
}

💡 高级优化技术

// 1. 使用Web Worker避免UI阻塞
class WorkerLLM {
  constructor() {
    this.worker = new Worker('llm-worker.js');
    this.callbacks = new Map();
    this.requestId = 0;
  }
  
  async generate(prompt) {
    const id = this.requestId++;
    
    return new Promise((resolve) => {
      this.callbacks.set(id, resolve);
      this.worker.postMessage({ id, type: 'generate', prompt });
    });
  }
}

// 2. 投机解码(Speculative Decoding)- 加速2-3倍
class SpeculativeDecoder {
  constructor(draftModel, targetModel) {
    this.draftModel = draftModel; // 小模型,快速生成
    this.targetModel = targetModel; // 大模型,验证
  }
  
  async generate(prompt, numTokens) {
    let tokens = [];
    
    while (tokens.length < numTokens) {
      // 草稿模型生成5个候选token
      const draftTokens = await this.draftModel.generateFast(prompt, 5);
      
      // 目标模型并行验证
      const validations = await this.targetModel.verify(prompt, draftTokens);
      
      // 接受验证通过的tokens
      const accepted = this.findAccepted(draftTokens, validations);
      tokens.push(...accepted);
      
      if (accepted.length < draftTokens.length) break;
    }
    
    return tokens;
  }
}

// 3. 上下文缓存 - 重复对话不重复计算
class ContextCache {
  constructor() {
    this.cache = new Map();
    this.maxSize = 10;
  }
  
  getOrCompute(sessionId, computeFn) {
    if (this.cache.has(sessionId)) {
      return this.cache.get(sessionId);
    }
    
    const result = computeFn();
    if (this.cache.size >= this.maxSize) {
      const firstKey = this.cache.keys().next().value;
      this.cache.delete(firstKey);
    }
    this.cache.set(sessionId, result);
    return result;
  }
}

🎯 今日挑战

实现一个浏览器内运行的RAG系统

要求:

  1. 使用本地LLM + 向量数据库(LanceDB)
  2. PDF文档上传和向量化(使用Transformers.js embedding模型)
  3. 实现相似度搜索 + LLM回答生成
  4. 完全离线运行
class LocalRAG {
  constructor() {
    this.llm = new LocalLLM({ modelName: 'Xenova/phi-3-mini-4k-instruct' });
    this.embedder = null;
    this.vectorDB = null;
    this.chunks = [];
  }
  
  async init() {
    await this.llm.load();
    
    // 加载embedding模型
    this.embedder = await pipeline('feature-extraction', 'Xenova/all-MiniLM-L6-v2');
    
    // 初始化向量数据库(IndexedDB存储)
    this.vectorDB = await this.initVectorDB();
  }
  
  async addDocument(text, metadata) {
    // 分块
    const chunks = this.chunkText(text, 512);
    
    for (const chunk of chunks) {
      // 生成向量
      const embedding = await this.embedder(chunk, { pooling: 'mean' });
      
      // 存储
      await this.vectorDB.add({
        vector: embedding.data,
        text: chunk,
        metadata
      });
    }
  }
  
  async query(question, topK = 3) {
    // 问题向量化
    const questionEmbedding = await this.embedder(question, { pooling: 'mean' });
    
    // 相似度搜索
    const similar = await this.vectorDB.search(questionEmbedding.data, topK);
    
    // 构建prompt
    const context = similar.map(s => s.text).join('\n\n');
    const prompt = `基于以下上下文回答问题:\n\n上下文:${context}\n\n问题:${question}\n\n回答:`;
    
    // LLM生成
    return await this.llm.generate(prompt);
  }
  
  chunkText(text, chunkSize) {
    // 智能分块(按句子边界)
    const sentences = text.split(/[.!?]+/);
    const chunks = [];
    let current = '';
    
    for (const sentence of sentences) {
      if ((current + sentence).length > chunkSize) {
        chunks.push(current);
        current = sentence;
      } else {
        current += sentence;
      }
    }
    if (current) chunks.push(current);
    
    return chunks;
  }
}

📊 性能优化清单

  • WebGPU后端(比WebGL快5-10倍)
  • 4-bit量化(减少75%内存,保持95%精度)
  • KV Cache(避免重复计算,加速2-3倍)
  • Speculative Decoding(加速2-3倍)
  • 批处理推理(同时处理多个请求)
  • 模型分片加载(降低首屏时间)
  • IndexedDB持久化(二次加载快10倍)

🚀 在线演示模型

模型大小内存速度质量
Phi-3-mini-4k2.2GB2.8GB⭐⭐⭐⭐⭐⭐⭐⭐⭐
Gemma-2B1.5GB1.9GB⭐⭐⭐⭐⭐⭐⭐⭐⭐
Qwen-1.8B1.1GB1.4GB⭐⭐⭐⭐⭐⭐⭐⭐
TinyLlama-1.1B0.7GB0.9GB⭐⭐⭐⭐⭐⭐⭐

明日预告:编译原理在前端的应用 - 实现一个属于你自己的编译器(从AST到代码生成)

💡 核心洞察:本地AI不仅是隐私保护的解决方案,更是低延迟实时交互的关键!