十年 C++ 后端 GAP 六个月,写了一个近 3 万行的LLM-TFFInfer推理框架项目解析

6 阅读20分钟

十年 C++ 后端 GAP 六个月,写了一个近 3 万行的LLM-TFFInfer推理框架项目解析

目录


1. 项目概述

1.1 项目简介

TFFInfer 是一个从零自主研发的大语言模型(LLM)推理框架,采用现代 C++20 标准开发,支持 CUDA GPU 加速。项目实现了完整的 LLM 推理流水线,包括:

  • 模型加载(支持 GGUF 格式)
  • 计算图构建与优化
  • 显存管理与优化
  • KV Cache 管理
  • 任务调度与并行执行
  • 高性能 CUDA Kernel 实现

1.2 项目规模

  • 代码量:近 30,000 行
  • 语言:C++20 + CUDA
  • 支持的模型:Qwen3-8B (Q8_0 量化)、GGUF 格式模型
  • 编译要求:CMake 3.18+、CUDA 12.x

1.3 核心技术栈

类别技术选型版本用途
编程语言C++20-核心逻辑、模板元编程
GPU 后端CUDA12.xGPU 算子实现
构建系统CMake3.18+跨平台构建
日志系统Google glogv0.7.1高性能日志记录
任务调度Taskflowv3.10.0DAG 任务流引擎
数学优化libdividev5.2.0快速除法运算
JSON 解析nlohmann/jsonv4.0.0配置文件解析
CUTLASSNVIDIA CUTLASS-GEMM 优化库

2. 技术架构

2.1 整体架构层次

应用层 (Application Layer)
    ↓
运行时层 (Runtime Layer) - LLMInferRuntime、TaskFlowManager、MemManager、KVCache
    ↓
核心层 (Core Layer) - Graph Engine、Tensor/Memory、Device Manager
    ↓
模型层 (Model Layer) - ModelLoader、ModelCreator、ModelDetector
    ↓
算子层 (Kernel Layer) - CUDA Kernels、CPU Kernels、OP Creators
    ↓
基础设施层 (Infrastructure Layer) - ModuleFactory、FunctionFactory、Logger

2.2 目录结构说明

tffinfer/
├── include/                    # 公共头文件(对外 API)
│   ├── ExportInc.h            # 导出宏定义
│   ├── ModuleFactory.h        # 模块工厂(核心注册机制)
│   ├── FunctionFactory.h      # 函数工厂(回调注册)
│   ├── ModelFactory.h         # 模型工厂
│   └── ModuleObject.h         # 模块对象基类
│
├── src/
│   ├── core/                  # 核心引擎层
│   │   ├── device/            # 设备管理(CPU/GPU)
│   │   ├── graph/             # 计算图引擎
│   │   ├── mem/               # 内存管理(Tensor/Memory)
│   │   ├── model/             # 模型层(Loader/Creator/Detector)
│   │   ├── runtime/           # 运行时层(Runtime/Manager)
│   │   ├── quant/             # 量化相关
│   │   └── sampler/           # 采样器
│   │
│   ├── kernel/                # 算子实现层
│   │   ├── cuda/              # CUDA Kernel(FlashAttention/GEMM/RoPE等)
│   │   ├── cpu/               # CPU Kernel(备用)
│   │   └── include/           # Kernel 接口定义
│   │
│   ├── taskgraph/             # 任务调度层
│   ├── factory/               # 工厂实现
│   ├── log/                   # 日志封装
│   └── utils/                 # 工具函数
│
├── test/                      # 测试程序
├── model/                     # 模型文件
└── docs/                      # 文档

3. 核心模块详解

3.1 基础设施层 (Infrastructure Layer)

3.1.1 ModuleFactory - 模块工厂模式

文件位置include/ModuleFactory.h

核心功能: 这是一个高度灵活的模块注册与创建工厂,支持两种创建策略:

  • PROTOTYPE: 每次创建新实例
  • SINGLETON: 单例模式,重复使用同一实例

关键设计:

// 支持多种键类型(字符串或枚举)
using ModuleKeyType = std::variant<std::string, int>;

// 创建者信息结构
template<typename Base>
struct CreatorInfo {
    std::function<std::shared_ptr<Base>()> creator;  // 创建函数
    CreationPolicy policy;                            // 创建策略
};

// 单例实例映射(弱引用,避免内存泄漏)
template<typename Base>
struct SingletonInstanceMap {
    static std::unordered_map<ModuleKeyType, std::weak_ptr<Base>, 
                              ModuleKeyHash, ModuleKeyEqual> instances;
};

注册宏:

// 原型模式注册(每次创建新对象)
REGISTER_MODULE_OBJECT_PROTOTYPE(T, Base, type, key, ...)

// 单例模式注册(全局唯一实例)
REGISTER_MODULE_OBJECT(T, Base, type, key, ...)

使用示例:

// 注册模块
REGISTER_MODULE_OBJECT(DeviceCUDA, DeviceBaseObject, 
                      DEVICE_BACKEND_FLAG, DEVICE_BACKEND_TYPE_CUDA);

// 创建模块
auto device = ModuleFactory::instance()->create_shared<DeviceBaseObject>(
    DEVICE_BACKEND_FLAG, DEVICE_BACKEND_TYPE_CUDA);

优势:

  • ✅ 解耦模块创建与使用
  • ✅ 支持运行时动态注册
  • ✅ 自动管理生命周期
  • ✅ 线程安全(通过锁保护)

3.1.2 FunctionFactory - 函数回调工厂

文件位置include/FunctionFactory.h

核心功能: 提供基于标签的函数注册与调用机制,用于解耦算子实现与调用。

三层结构:

类型 (flag) → 分组 (group) → 键 (key) → 函数 (function)

关键方法:

// 注册回调函数
register_callback<Func>(flag, key, func)

// 获取回调函数
get_callback<Func>(flag, key)

// 调用回调函数
invoke(flag, key, args...)

// 检查是否存在
has_callback(flag, key)

应用场景:

  • 算子实现的动态绑定
  • 不同后端的函数切换(CPU/GPU)
  • 插件式架构支持

3.2 设备管理层 (Device Layer)

3.2.1 DeviceManager - 设备管理器

文件位置src/core/device/DeviceManager.h

核心职责:

  • 统一管理所有计算设备(CPU/GPU)
  • 设备初始化与查询
  • 设备 ID 映射

关键数据结构:

class DeviceManager : public ModuleObject {
    // 设备列表
    std::set<std::shared_ptr<DeviceBaseObject>> _devices;
  
    // 设备 ID → 设备对象映射
    std::unordered_map<int, std::shared_ptr<DeviceBaseObject>> _device_map;
};

初始化流程:

DeviceManager() {
    // 1. 创建并初始化 GPU 设备
    auto gpu_device = ModuleFactory::instance()
        ->create_shared<DeviceBaseObject>(DEVICE_BACKEND_FLAG, 
                                          DEVICE_BACKEND_TYPE_CUDA);
    gpu_device->device_init();
    _devices.insert(gpu_device);
  
    // 2. 创建并初始化 CPU 设备
    auto cpu_device = ModuleFactory::instance()
        ->create_shared<DeviceBaseObject>(DEVICE_BACKEND_FLAG, 
                                          DEVICE_BACKEND_TYPE_CPU);
    cpu_device->device_init();
    _devices.insert(cpu_device);
}

3.2.2 MemBufferAllocator - 内存分配器

文件位置:

  • src/core/device/cuda/MemBufferAllocatorCUDA.h
  • src/core/device/cpu/MemBufferAllocatorCPU.h

设计理念:

  • 策略模式:不同设备有不同分配策略
  • 统一接口:上层无需关心具体实现
  • 资源隔离:每个设备独立管理自己的内存

3.3 内存与张量层 (Memory & Tensor Layer)

3.3.1 Tensor - 张量类

文件位置src/core/mem/Tensor.h

核心功能: Tensor 是框架中最基本的数据单元,封装了:

  • 多维数组表示
  • 数据类型管理
  • 内存生命周期
  • Stride 推导

关键属性:

class Tensor : public std::enable_shared_from_this<Tensor> {
    // 维度信息
    size_t _n_dims;                                    // 有效维度数
    std::array<int64_t, MAX_TENSOR_DIM> _shape;        // 形状
    std::array<int64_t, MAX_TENSOR_DIM> _strides;      // 步长
  
    // 类型信息
    DataType _data_type;                               // 数据类型(FP32/FP16/INT8等)
    size_t _type_size;                                 // 元素大小
    uint32_t _blk_size;                                // 块大小(量化用)
  
    // 内存管理
    MemoryType _memory_type;                           // 内存类型(权重/KV/工作区)
    std::shared_ptr<Memory> _buffer;                   // 底层内存
    bool _use_external;                                // 是否使用外部内存
    int64_t _external_memory_index;                    // 外部内存索引
  
    // 生命周期(用于内存复用优化)
    int _start;                                        // 活跃开始时间点
    int _end;                                          // 活跃结束时间点
    int _priority;                                     // 优先级
  
    // 分配器
    std::shared_ptr<MemBufferAllocatorBaseObject> _allocator;
};

Stride 推导算法:

void stride_infer() {
    _strides[0] = _type_size;
    if (_strides.size() > 1) {
        _strides[1] = _strides[0] * (_shape[0] / _blk_size);
    }
    for (int j = 2; j < _shape.size(); ++j) {
        _strides[j] = _strides[j - 1] * _shape[j - 1];
    }
}

内存分配策略:

// 内部内存(自动管理)
Tensor(DataType type, MemoryType mem_type, shapes, false, allocator);
tensor.allocate();  // 自动分配

// 外部内存(由 MemManager 统一管理)
Tensor(DataType type, MemoryType mem_type, shapes, true, nullptr);
tensor.set_buffer_data(ptr, size, mem_offset);  // 绑定外部内存

访问接口:

// 类型安全的元素访问
template<typename T, typename... Args>
T& at(Args... indices) {
    return *reinterpret_cast<T*>(_buffer->ptr() + compute_offset(indices...));
}

// 计算偏移量
size_t compute_offset(Args... indices) {
    std::array<size_t, num_indices> idxs{static_cast<size_t>(indices)...};
    size_t offset = 0;
    for (size_t i = 0; i < num_indices; ++i) {
        offset += idxs[i] * _strides[i];
    }
    return offset;
}

3.4 计算图引擎 (Graph Engine)

3.4.1 Graph - DAG 计算图

文件位置src/core/graph/Graph.h

核心概念: 计算图是有向无环图(DAG),节点表示算子,边表示数据依赖。

关键数据结构:

class Graph : public std::enable_shared_from_this<Graph> {
    // 节点集合
    std::vector<std::shared_ptr<GraphNode>> _output_node;   // 输出节点
    std::vector<std::shared_ptr<GraphNode>> _leafs;         // 叶子节点
    std::vector<std::shared_ptr<GraphNode>> _nodes;         // 所有节点
    std::vector<std::shared_ptr<GraphNode>> _total_nodes;   // 完整节点列表
  
    // 执行顺序
    std::unordered_map<std::shared_ptr<GraphNode>, int> _exec_time;  // 执行时间戳
    std::unordered_map<std::shared_ptr<GraphNode>, size_t> _use_counts;  // 使用次数
  
    // 生命周期分析
    std::unordered_map<std::shared_ptr<GraphNode>, int> _first_use;  // 首次使用时间
    std::unordered_map<std::shared_ptr<GraphNode>, int> _last_use;   // 最后使用时间
  
    // 索引映射
    std::unordered_map<std::shared_ptr<GraphNode>, size_t> _node_to_index;
    std::unordered_map<std::shared_ptr<GraphNode>, size_t> _leaf_to_index;
};

建图流程:

void build_graph(std::shared_ptr<GraphNode> output_node) {
    // 1. 从输出节点反向遍历,收集所有依赖节点
    visit(output_node);
  
    // 2. 拓扑排序,确定执行顺序
    // 3. 识别叶子节点(输入节点)
    // 4. 分析节点生命周期
    analyze_lifetimes();
}

void visit(std::shared_ptr<GraphNode> node) {
    if (_visited.count(node)) return;
    _visited.insert(node);
  
    // 递归访问所有输入节点
    for (auto& input : node->inputs()) {
        visit(input);
    }
  
    // 添加到节点列表(后序遍历保证拓扑序)
    _nodes.push_back(node);
}

生命周期分析:

void analyze_lifetimes() {
    // 为每个节点分配执行时间戳
    for (size_t i = 0; i < _nodes.size(); ++i) {
        _exec_time[_nodes[i]] = i;
    }
  
    // 计算每个张量的首次和最后使用时间
    for (auto& node : _nodes) {
        for (auto& output : node->outputs()) {
            int exec_time = _exec_time[node];
            if (_first_use.find(output) == _first_use.end()) {
                _first_use[output] = exec_time;
            }
            _last_use[output] = exec_time;
        }
    }
}

3.4.2 GraphNode - 图节点

文件位置src/core/graph/GraphNode.h

节点类型:

enum class GraphNodeType {
    TFF_NODE_WEIGHT,      // 权重节点
    TFF_NODE_INPUT,       // 输入节点
    TFF_NODE_OUTPUT,      // 输出节点
    TFF_NODE_OP,          // 算子节点
    TFF_NODE_HOST,        // Host 节点(CPU)
    TFF_NODE_DEVICE,      // Device 节点(GPU)
    TFF_NODE_MEM_CPY,     // 内存拷贝节点
};

3.5 模型层 (Model Layer)

3.5.1 模型加载架构

设计模式: 策略模式 + 工厂模式

核心类层次:

ModelDetectorBase (检测器基类)
    └── GGUFDetector (GGUF 格式检测)

ModelLoaderBase (加载器基类)
    └── GGUFLoader (GGUF 格式加载)

ModelCreatorBase (创建器基类)
    └── QWen3Creator (Qwen3 模型创建)
    └── LLAMACreator (LLaMA 模型创建)

3.5.2 ModelDetector - 模型检测器

文件位置src/core/model/base/ModelDetectorBase.h

职责: 根据模型文件格式自动选择合适的加载器。

class ModelDetectorBase {
    // 判断是否匹配当前格式
    virtual bool matches(const std::string& file_format) const = 0;
  
    // 检测器名称
    virtual const char* name() const = 0;
  
    // 创建对应的加载器
    virtual std::shared_ptr<ModelLoaderBase> create_loader() = 0;
};

3.5.3 ModelLoader - 模型加载器

文件位置src/core/model/base/ModelLoaderBase.h

核心接口:

class ModelLoaderBase : public ModuleObject {
    // 从文件加载模型
    virtual ModelLoadResult load_from_file(
        const std::vector<std::string>& model_files_name,
        const ModelConfig& params) = 0;
  
    // 获取模型上下文
    virtual std::shared_ptr<ModelContext> get_model_ctx() = 0;
  
    // 获取模型架构名称
    virtual const char* get_arch_name() const = 0;
  
    // 获取模型配置
    virtual const ModelConfig& get_model_config() const = 0;
  
    // 获取权重映射
    virtual const std::unordered_map<std::string, ModelWeight>& 
        get_weight_map() const = 0;
  
    // 支持 mmap
    virtual bool supports_mmap() const { return true; }
};

3.5.4 ModelCreator - 模型创建器

文件位置src/core/model/base/ModelCreatorBase.h

核心职责: 根据模型架构构建计算图。

GraphContext - 计算图上下文:

struct GraphContext {
    // 模型超参数
    int _n_layer;           // 层数
    int _n_head;            // 注意力头数
    int _n_head_kv;         // KV 头数(支持 GQA)
    int _n_embd_head;       // 每个头的维度
    int _n_tokens;          // Token 数量
  
    // RoPE 配置
    float _rope_freq_base;
    float _rope_freq_scale;
    RopeType _rope_type;
  
    // 归一化配置
    float _f_norm_rms_eps;
    TFFNormType _norm_type;
  
    // 精度配置
    bool _use_fp16;
    bool _is_fuse;          // 是否融合算子
  
    // 运行时状态
    bool _is_prefill;       // 是否是预填充阶段
    int _seq_id;            // 序列 ID
  
    // 关键张量
    std::shared_ptr<Tensor> _logits;      // 输出 logits
    std::shared_ptr<Tensor> _rope_table;  // RoPE 表
    std::shared_ptr<Tensor> _mask;        // Attention Mask
  
    // 依赖组件
    std::shared_ptr<ModelLoaderBase> _model_loader;
    std::shared_ptr<LLMMemManager> _mem_manager_ptr;
    std::unordered_map<int, std::shared_ptr<LLMKVCache>> _kv_cache_ptr;
};

计算图构建接口:

class ModelCreatorBase {
    // 构建主计算图
    virtual void build_graph(...) = 0;
  
    // 构建 IO 计算图(内存管理)
    virtual void build_mem_graph(...) = 0;
  
    // 设置模型上下文
    virtual void build_model_context(const GraphContext& ctx) = 0;
  
    // 辅助构建方法
    std::shared_ptr<GraphNode> build_norm(...);
    std::shared_ptr<GraphNode> build_mul_mat_node(...);
    std::shared_ptr<GraphNode> build_attn(...);
    std::shared_ptr<GraphNode> build_ffn(...);
    std::shared_ptr<GraphNode> build_kv_cache_store_node(...);
    std::shared_ptr<GraphNode> build_kv_cache_load_node(...);
};

3.6 运行时层 (Runtime Layer)

3.6.1 LLMInferRuntime - 推理运行时(核心入口)

文件位置src/core/runtime/InferRuntime.h

核心职责: 协调整个推理流程,是所有组件的总控制器。

主要成员:

class LLMInferRuntime {
    // 模型信息
    ModelConfig _model_config;
    std::shared_ptr<LLMVocabulary> _vocabulary_ptr;
    std::shared_ptr<ModelLoaderBase> _model_loader;
    std::shared_ptr<ModelCreatorBase> _model_creator;
  
    // 计算图
    std::shared_ptr<Graph> _prefill_graph_ptr;       // 预填充图
    std::shared_ptr<Graph> _decode_graph_ptr;        // 解码图
    std::shared_ptr<Graph> _mem_graph_ptr;           // 内存管理图
  
    // 优化器与管理器
    std::shared_ptr<GraphOptimizer> _graph_optimizer;
    std::shared_ptr<LLMMemManager> _mem_manager_ptr;
    std::shared_ptr<LLMTaskFlowManager> _task_manager;
    std::shared_ptr<LLMBatchManager> _llm_batch_manager_ptr;
    std::shared_ptr<DeviceManager> _device_manager;
  
    // KV Cache
    std::unordered_map<int, std::shared_ptr<LLMKVCache>> _kv_cache_ptr;
};

初始化流程:

LLMInferRuntime() {
    // 1. 初始化设备
    init_device();
  
    // 2. 创建内存管理器
    _mem_manager_ptr = ModuleFactory::instance()
        ->create_shared<ModuleObject>(WEIGHT_MEM_BUFFER_MANAGER_FLAG, ...);
  
    // 3. 创建批处理管理器
    _llm_batch_manager_ptr = ModuleFactory::instance()
        ->create_shared<ModuleObject>(BATCH_MANAGER_FLAG, ...);
  
    // 4. 创建任务流管理器
    _task_manager = ModuleFactory::instance()
        ->create_shared<ModuleObject>(TASK_FLOW_MANAGER_FLAG, ...);
  
    // 5. 创建图优化器
    _graph_optimizer = ModuleFactory::instance()
        ->create_shared<ModuleObject>(GRAPH_OPTIMIZER_FALG, ...);
}

推理流程:

bool infer(int n_predict, std::vector<std::string>& generate_str) {
    // 1. Prefill 阶段(处理 prompt)
    prefill(batch);
  
    // 2. Decode 阶段(逐个生成 token)
    for (int i = 0; i < n_predict; ++i) {
        decode(batch, n_predict, generate_str);
    
        // 3. 采样下一个 token
        int32_t token = sample_token();
    
        // 4. 更新 batch
        update_batch(token);
    }
  
    return true;
}

3.6.2 TaskFlowManager - 任务流管理器

文件位置src/core/runtime/TaskFlowManager.h

核心职责: 基于 Taskflow 库实现 DAG 任务调度。

class LLMTaskFlowManager : public ModuleObject {
    std::shared_ptr<HybridScheduler> _task_scheduler;
  
    // 构建任务调度
    bool build_task_schedule(const TaskType& type, 
                            const std::shared_ptr<Graph>& graph_ptr);
  
    // 运行任务
    void run(const TaskType& type) {
        _task_scheduler->run(type);
    }
};

3.6.3 MemManager - 内存管理器

文件位置src/core/runtime/MemManager.h

核心职责:

  • 统一管理设备内存池
  • 基于生命周期的内存分配与回收
  • 内存复用优化

关键数据结构:

class LLMMemManager : public ModuleObject {
    // 内存块
    struct MemoryBlock {
        int64_t _offset;      // 偏移量
        int64_t _size;        // 大小
        int _free_time;       // 释放时间
        int _ref_count;       // 引用计数
        std::vector<std::shared_ptr<DeviceEvent>> _pending_events;
    };
  
    // 空闲块
    struct FreeBlock {
        int64_t _offset;
        int64_t _size;
    };
  
    // 内存池
    std::unordered_map<int, void*> _mem_buffer_map;              // 设备 → 基址
    std::unordered_map<int, int64_t> _current_offset;            // 设备 → 当前偏移
    std::unordered_map<int, std::set<FreeBlock>> _free_set;      // 设备 → 空闲块集合
    std::unordered_map<int, priority_queue<MemoryBlock>> _memory_heap; // 待释放堆
};

内存分配策略:

std::pair<int64_t, void*> allocate_memory(int64_t size, 
                                          const int& device_id,
                                          const MemoryType& type,
                                          const std::shared_ptr<DeviceEvent>& event) {
    std::lock_guard<std::mutex> lock(_mutex);
  
    // 1. 查找合适的空闲块
    auto& free_set = _free_set[device_id];
    auto it = find_best_fit(free_set, size);
  
    if (it != free_set.end()) {
        // 2. 复用空闲块
        int64_t offset = it->_offset;
        free_set.erase(it);
        void* ptr = static_cast<char*>(_mem_buffer_map[device_id]) + offset;
    
        // 3. 记录引用
        aquire_memory(device_id, offset, size, event);
    
        return {offset, ptr};
    } else {
        // 4. 扩展内存池
        int64_t offset = _current_offset[device_id];
        _current_offset[device_id] += align_up(size, _alignment);
    
        void* ptr = static_cast<char*>(_mem_buffer_map[device_id]) + offset;
        return {offset, ptr};
    }
}

延迟回收机制:

void release_memory(const int& device_id, const int64_t& offset) {
    // 不立即释放,而是加入待释放队列
    // 等待 CUDA Event 完成后真正释放
    _memory_heap[device_id].push(MemoryBlock(offset, size, current_time));
}

void collect(const int& device_id) {
    // 检查已完成的 Event,回收内存
    while (!_memory_heap[device_id].empty()) {
        auto& block = _memory_heap[device_id].top();
        if (all_events_completed(block._pending_events)) {
            _free_set[device_id].insert(FreeBlock{block._offset, block._size});
            _memory_heap[device_id].pop();
        } else {
            break;
        }
    }
}

3.6.4 KVCache - KV 缓存管理

文件位置src/core/runtime/KVCache.h

核心设计: Paged Attention

关键组件:

1. KVPage - 缓存页:

struct KVPage {
    std::shared_ptr<Tensor> _k;        // Key 张量
    std::shared_ptr<Tensor> _v;        // Value 张量
    int _n_tokens = 0;                 // 当前 token 数
    bool _is_used = false;             // 是否已使用
};

2. PageManager - 页管理器:

class PageManager {
    std::vector<std::shared_ptr<KVPage>> _pages;  // 所有页
    std::vector<PageID> _free_list;                // 空闲页列表
  
    // 分配页
    PageID allocate() {
        if (_free_list.empty()) return INVALID_PAGE_ID;
        PageID id = _free_list.back();
        _free_list.pop_back();
        _pages[id]->_is_used = true;
        return id;
    }
  
    // 释放页
    void free(PageID id) {
        _pages[id]->_is_used = false;
        _free_list.push_back(id);
    }
  
    // 获取 K/V 张量
    std::shared_ptr<Tensor> get_k(PageID id, ...);
    std::shared_ptr<Tensor> get_v(PageID id, ...);
};

3. LayerKVContext - 层上下文:

class LayerKVContext {
    int _seq_id;                    // 序列 ID
    int _layer_id;                  // 层 ID
    std::vector<PageID> _page_table; // 页表(逻辑页 → 物理页)
    int _num_tokens = 0;            // token 数量
  
    // 添加 token
    bool append_token() {
        if (_num_tokens == get_max_tokens()) {
            PageID new_page = _page_manager->allocate();
            if (new_page == INVALID_PAGE_ID) return false;
            _page_table.push_back(new_page);
        }
        _num_tokens++;
        return true;
    }
  
    // 获取 token 位置
    std::pair<PageID, int> get_location(int token_idx) {
        int page_id = token_idx / PAGE_SIZE;
        int offset = token_idx % PAGE_SIZE;
        return {_page_table[page_id], offset};
    }
};

4. LLMKVCache - KV 缓存主类:

class LLMKVCache {
    struct KVConfig {
        uint32_t _n_layer;          // 层数
        uint32_t _n_head;           // 头数
        uint32_t _n_embd_head;      // 头维度
        int _total_pages;           // 总页数
        int _page_size = 32;        // 每页 token 数
    };
  
    std::shared_ptr<PageManager> _page_manager;
    std::unordered_map<int, std::unique_ptr<LayerKVContext>> _seq_contexts;
  
    // 获取上下文
    LayerKVContext* get_context(int seq_id, int layer_id) {
        int key = make_key(seq_id, layer_id);
        auto it = _seq_contexts.find(key);
        if (it != _seq_contexts.end()) {
            return it->second.get();
        }
    
        auto ctx = std::make_unique<LayerKVContext>(seq_id, layer_id, _page_manager);
        LayerKVContext* ptr = ctx.get();
        _seq_contexts[key] = std::move(ctx);
        return ptr;
    }
  
    // 获取 K/V
    std::shared_ptr<Tensor> get_k(int seq_id, int layer_id, PageID page_id, ...) {
        auto ctx = get_context(seq_id, layer_id);
        return ctx->_page_manager->get_k(page_id, ...);
    }
};

Paged Attention 优势:

  • ✅ 消除内存碎片
  • ✅ 支持可变长度序列
  • ✅ 高效的内存利用率
  • ✅ 支持 Continuous Batching(未来扩展)

3.7 算子层 (Kernel Layer)

3.7.1 TFFOPCreator - 算子创建器

文件位置src/kernel/include/TFFOPCreator.h

核心功能: 定义所有支持的算子类型及其参数构建器。

支持的算子:

// 数据传输
Map2Cpu, MemCpy

// Embedding
Embedding

// 线性变换
Mul, MatMul, QuantMatMul

// 激活函数
Unary (SwiGLU, GeLU, etc.)

// 归一化
Norm, NormW (RMSNorm, LayerNorm)

// 位置编码
Rope

// 注意力
FlashAttn, PagedFlashAttn

// 量化
Quant, DeQuant, QuantAligned, QuantReshape

// 形状操作
Reshape, View

// 二元操作
Add, Binary

// 内存操作
MemRef, Gather, SetRow, GetRow

参数构建器模式:

class FlashAttnBuilder : public OpParamBuilderBase<FlashAttnBuilder> {
    struct Params {
        static constexpr const char* Q = "q_tensor";
        static constexpr const char* K = "k_tensor";
        static constexpr const char* V = "v_tensor";
        static constexpr const char* Mask = "mask";
        static constexpr const char* Out = "out";
    };
  
public:
    template<typename T>
    FlashAttnBuilder& q(T&& value) {
        return set(Params::Q, value);
    }
  
    template<typename T>
    FlashAttnBuilder& k(T&& value) {
        return set(Params::K, value);
    }
  
    // ... 其他参数
};

使用示例:

auto builder = FlashAttnBuilder()
    .q(q_tensor)
    .k(k_tensor)
    .v(v_tensor)
    .mask(mask_tensor)
    .out(output_tensor);

// 创建算子节点
auto node = create_op_node<FlashAttn>(builder);

3.7.2 Flash Attention 实现

文件位置src/kernel/cuda/flash_attention.cu

核心优化技术:

1. Shared Memory Tiling:

// 将 K/V 分块加载到 shared memory
__shared__ half2 k_sm[BLOCK_DIM_LD][BLOCK_DIM_K / 2 + PAD_SIZE];
__shared__ half2 v_sm[BLOCK_DIM_K][BLOCK_DIM_V / 2];

2. Async Copy (Hopper) :

// 异步拷贝全局内存到共享内存
cp_async_cg_16<64>(dst_addr, src_addr);
cp_async_commit();
cp_async_wait_all();

3. Swizzle 避免 Bank Conflict:

template<int B, int M, int S = B>
__device__ int swizzle(const int& offset) {
    const int bit_msk = (1 << B) - 1;
    const int yyy_msk = bit_msk << (M + max(0, S));
    if constexpr (S >= 0) {
        return offset ^ ((offset & yyy_msk) >> S);
    } else {
        return offset ^ ((offset & yyy_msk) << -S);
    }
}

4. RoPE 融合:

// 在加载 K/Q 时直接应用 RoPE
template<const int ELEMENTS_PER_LOAD>
__device__ void rote(const int addr, half2* sm, float2* cos_sin_table) {
    cp_async_wait_all();
    float4 sm_val = val_vec[0];
    half2* rot_value = reinterpret_cast<half2*>(&sm_val);
  
    for (int i = 0; i < ELEMENTS_PER_LOAD / 2; ++i) {
        half2 val = rot_value[i];
        rot_value[i] = complex_mul_half2(val, __float22half2_rn(cos_sin_table[i]));
    }
}

5. Online Softmax:

// 在线计算 softmax,避免额外 pass
float m_prev = m;
float m_new = max(m, row_max);
float l_prev = l * exp(m_prev - m_new);
float l_new = l_prev + exp(row_max - m_new);

// 更新输出
acc = acc * (l_prev / l_new) + attn * (exp(row_max - m_new) / l_new);
m = m_new;
l = l_new;

3.8 任务调度层 (Task Scheduling Layer)

3.8.1 HybridScheduler - 混合调度器

文件位置src/taskgraph/include/TaskFlowSchedule.h

核心功能: 基于 Taskflow 实现 CPU-GPU 混合任务调度。

关键特性:

class HybridScheduler : public ModuleObject {
    // Taskflow 执行器
    std::unordered_map<TaskType, tf::Executor> _executor;
    std::unordered_map<TaskType, tf::Taskflow> _task_flow;
  
    // CUDA Graph 支持
    cudaStream_t _capture_stream;
    cudaGraph_t _graph;
    cudaGraphExec_t _graph_exec;
    bool _use_cuda_graph;
  
    // 多流支持(未来扩展)
    std::vector<cudaStream_t> _gpu_streams;
    std::vector<cudaEvent_t> _sync_events;
};

添加任务:

template<typename F, typename... Args>
tf::Task add_task(const TaskType& type, const std::string& name, 
                  F&& f, Args&&... args) {
    auto bound = [func = std::forward<F>(f),
                  tup = std::make_tuple(std::forward<Args>(args)...)]() mutable {
        std::apply(func, tup);
    };
    return _task_flow[type].emplace(std::move(bound)).name(name);
}

运行任务:

void run(const TaskType& type) {
    if (_use_cuda_graph && _graph_exec) {
        // 使用 CUDA Graph 执行
        cudaGraphLaunch(_graph_exec, _capture_stream);
    } else {
        // 使用 Taskflow 执行
        _future[type] = _executor[type].run(_task_flow[type]);
    }
}

void wait_until_completion(const TaskType& type) {
    if (!_use_cuda_graph) {
        _future[type].wait();
    } else {
        cudaStreamSynchronize(_capture_stream);
    }
}

4. 设计模式与架构思想

4.1 使用的设计模式

4.1.1 工厂模式 (Factory Pattern)
  • ModuleFactory: 模块对象的创建与管理
  • FunctionFactory: 函数回调的注册与调用
  • ModelFactory: 模型加载器的选择与创建

优势:

  • 解耦对象创建与使用
  • 支持运行时动态注册
  • 便于扩展新模块

4.1.2 策略模式 (Strategy Pattern)
  • DeviceBaseObject: CPU/GPU 设备策略
  • MemBufferAllocator: 不同设备的内存分配策略
  • ModelDetector: 不同格式的模型检测策略

优势:

  • 算法可互换
  • 符合开闭原则
  • 易于测试

4.1.3 观察者模式 (Observer Pattern)
  • GraphNode 依赖关系: 节点间的输入输出依赖
  • DeviceEvent: 异步事件通知

4.1.4 单例模式 (Singleton Pattern)
  • ModuleFactory::instance()
  • FunctionFactory::instance()
  • DeviceManager(通过工厂创建单例)

4.1.5 建造者模式 (Builder Pattern)
  • OpParamBuilderBase: 算子参数构建
  • 各种 Builder(Map2CpuBuilder, FlashAttnBuilder, etc.)

优势:

  • 参数设置清晰
  • 支持链式调用
  • 编译期类型检查

4.2 架构思想

4.2.1 分层架构 (Layered Architecture)
应用层 → 运行时层 → 核心层 → 模型层 → 算子层 → 基础设施层

优势:

  • 职责清晰
  • 降低耦合
  • 便于维护和测试

4.2.2 依赖注入 (Dependency Injection)

通过工厂模式和构造函数注入依赖:

LLMInferRuntime() {
    _mem_manager_ptr = ModuleFactory::instance()
        ->create_shared<...>(...);
    _task_manager = ModuleFactory::instance()
        ->create_shared<...>(...);
}

4.2.3 面向接口编程 (Programming to Interface)

大量使用虚函数和抽象基类:

class ModelLoaderBase { /* 纯虚接口 */ };
class DeviceBaseObject { /* 纯虚接口 */ };
class ModelCreatorBase { /* 纯虚接口 */ };

4.2.4 RAII (Resource Acquisition Is Initialization)
  • std::shared_ptr 管理资源生命周期
  • std::lock_guard 管理锁
  • Tensor/Memory 析构时自动释放资源

5. 数据流与执行流程

5.1 完整推理流程

1. 初始化阶段
   ├─ LLMInferRuntime 构造
   ├─ 设备初始化 (DeviceManager)
   ├─ 管理器创建 (MemManager, TaskFlowManager, BatchManager)
   └─ 图优化器创建
   
2. 模型加载阶段
   ├─ 加载模型文件 (ModelLoader)
   ├─ 解析元数据 (架构、超参数)
   ├─ 构建层映射 (Layer Map)
   ├─ 加载词表 (Vocabulary)
   └─ 加载权重 (支持 mmap)
   
3. 运行时初始化
   ├─ 初始化 KV Cache (LLMKVCache)
   ├─ 构建计算图 (ModelCreator)
   │   ├─ Prefill Graph
   │   ├─ Decode Graph
   │   └─ Memory Graph
   ├─ 图优化 (GraphOptimizer)
   │   ├─ 算子融合
   │   ├─ 死代码消除
   │   └─ 设备分配
   ├─ 内存规划 (MemManager)
   │   ├─ 生命周期分析
   │   └─ 内存池分配
   └─ 任务流构建 (TaskFlowManager)
   
4. Prefill 阶段 (处理 Prompt)
   ├─ Tokenize 输入文本
   ├─ Embedding 查找
   ├─ 前向传播 (多层 Transformer)
   │   ├─ RMS Norm
   │   ├─ Self Attention (RoPE + Flash Attention)
   │   │   ├─ Q/K/V Projection
   │   │   ├─ RoPE 编码
   │   │   ├─ KV Cache Store
   │   │   └─ Flash Attention
   │   ├─ FFN (SwiGLU)
   │   │   ├─ Gate Projection
   │   │   ├─ Up Projection
   │   │   ├─ SwiGLU Activation
   │   │   └─ Down Projection
   │   └─ Residual Connection
   └─ 输出 Logits
   
5. Decode 阶段 (逐 Token 生成)
   ├─ 采样下一个 Token (Sampler)
   ├─ 更新 Batch
   ├─ 前向传播 (单层,仅最后一个 token)
   ├─ KV Cache Update
   └─ 重复直到生成足够 Token 或遇到 EOS
   
6. 清理阶段
   ├─ 释放 KV Cache
   ├─ 重置内存池
   └─ 清理临时张量

5.2 计算图执行流程

TaskFlowManager::run(TFF_TASK_TYPE_INFER)
    ↓
HybridScheduler::run()
    ↓
遍历 Graph Nodes (拓扑序)
    ↓
对于每个 Node:
    ├─ 检查依赖是否就绪
    ├─ 分配输入张量内存
    ├─ 执行 Kernel (GPU/CPU)
    │   ├─ Launch CUDA Kernel
    │   └─ Record CUDA Event
    ├─ 标记输出张量就绪
    └─ 计划内存释放 (延迟回收)
    ↓
等待所有任务完成
    ↓
返回结果

5.3 内存管理流程

Tensor 请求内存
    ↓
MemManager::allocate_memory()
    ↓
查找空闲块 (Best Fit)
    ├─ 找到 → 复用
    └─ 未找到 → 扩展内存池
    ↓
记录引用 (aquire_memory)
    ↓
返回内存指针
    ↓
... 计算使用 ...
    ↓
Tensor 释放 / 超出生命周期
    ↓
MemManager::release_memory()
    ↓
加入待释放队列 (带 Event)
    ↓
collect() 检查 Event 完成
    ↓
真正释放 → 加入空闲块集合

6. 总结与展望

6.1 项目亮点

✅ 完整的推理框架: 从模型加载到推理执行的全链路实现 ✅ 高性能 CUDA Kernel: Flash Attention、量化 GEMM 等优化 ✅ 先进的内存管理: 生命周期分析、延迟回收、内存复用 ✅ Paged KV Cache: 高效管理注意力缓存 ✅ 灵活的架构: 工厂模式、策略模式,易于扩展 ✅ 现代 C++ : C++20 特性、智能指针、模板元编程


6.2 当前限制

⚠️ 仅支持 GGUF 格式: 需要扩展更多格式 ⚠️ 单卡推理: 尚未实现多卡并行 ⚠️ 静态 Batch: 缺少 Continuous Batching ⚠️ 有限模型支持: 目前仅 Qwen3-8B


9.3 未来计划

🚀 短期目标 (3-6 个月):

  • 支持更多模型 (LLaMA-3, Mistral)
  • 实现 Continuous Batching
  • 添加更多量化格式 (INT4, FP8)
  • 完善文档和示例
  • 另外结合最近的工作内容,未来可能会优先支持LDM结构化数据大模型->与LLM,世界模型等并列的AGI三大领域之一的LDM模型

🚀 中期目标 (6-12 个月):

  • 多卡并行推理
  • 支持更多硬件后端 (AMD ROCm, Intel GPU)
  • 实现 Speculative Decoding
  • 集成 vLLM 风格的调度器

🚀 长期愿景 (1-2 年):

  • 成为各位推理爱好者的学习项目,我们一起努力

附录

A. 关键文件索引

模块关键文件行数
运行时src/core/runtime/InferRuntime.h269
计算图src/core/graph/Graph.h168
张量src/core/mem/Tensor.h485
KV Cachesrc/core/runtime/KVCache.h476
内存管理src/core/runtime/MemManager.h277
任务调度src/taskgraph/include/TaskFlowSchedule.h124
算子创建src/kernel/include/TFFOPCreator.h1645
Flash Attentionsrc/kernel/cuda/flash_attention.cu822
模块工厂include/ModuleFactory.h331

B. 参考资料

  1. Flash AttentionDao et al., 2022
  2. Paged AttentionKwon et al., 2023
  3. GGUF Formatggml-org/gguf
  4. Taskflowtaskflow/taskflow
  5. CUTLASSNVIDIA/cutlass

C. 联系方式