Landing the NPU-specific interface, the low-level adaptation, and the optimization-strategy registration comes down to "three layers working together, aligned with the hardware":
- Upper-layer interface: hides the details and keeps the API easy to use; it only validates parameters and forwards them, and is not tied to any particular hardware;
- Low-level adaptation: hooks into TVM and maps onto the hardware; it implements the scheduler and code generator, and is where the performance work happens;
- Optimization strategies: register rules and let the TVM compiler select them automatically, lowering the barrier to entry for users (a minimal sketch follows this list).
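To make the third layer concrete, here is a minimal sketch of a "register rules, auto-select" mechanism. The registry, the `register_strategy` decorator, and the predicate-based selection are hypothetical illustrations of the idea, not TVM's actual strategy API; they only show how rules could be registered once and then picked automatically from a configuration object such as the `NPUConv2DConfig` defined below.

```python
from typing import Callable, Dict, Tuple

# Hypothetical registry: strategy name -> (applicability predicate, transform).
_STRATEGY_REGISTRY: Dict[str, Tuple[Callable, Callable]] = {}

def register_strategy(name: str, applies_to: Callable):
    """Register an optimization rule under `name` with a predicate on the config."""
    def wrapper(fn):
        _STRATEGY_REGISTRY[name] = (applies_to, fn)
        return fn
    return wrapper

@register_strategy("weight_preload", applies_to=lambda cfg: cfg.enable_weight_preload)
def apply_weight_preload(schedule):
    # Placeholder: a real backend would rewrite the TVM schedule here to stage
    # weights into on-chip memory ahead of the compute loop.
    return schedule

def select_strategies(cfg, schedule):
    """Automatically apply every registered rule whose predicate matches the config."""
    for name, (applies_to, fn) in _STRATEGY_REGISTRY.items():
        if applies_to(cfg):
            schedule = fn(schedule)
    return schedule
```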
Encapsulating the NPU-specific upper-layer interface (an extension of the existing TVM Conv2D interface)
```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class NPUConv2DConfig:
    """NPU-specific configuration for the Conv2D operator."""
    # Basic hardware configuration
    chip_model: str = "default_npu"            # NPU chip model
    compute_level: str = "high_perf"           # Compute level: high_perf / low_power
    # Memory optimization configuration
    cache_allocation: str = "weight_priority"  # Cache allocation strategy: weight_priority / input_priority
    enable_onchip_memory: bool = True          # Enable the on-chip high-speed cache
    # NPU-specific optimizations
    enable_weight_preload: bool = True         # Enable weight preloading (the core NPU optimization)
    enable_pipeline: bool = True               # Enable pipelined computation
    # Quantization (only used by the int8 interface)
    quant_mode: Optional[str] = None           # int8_sym / int8_asym

    def validate(self):
        """Check that the configuration is legal (avoid passing invalid parameters to the lower layer)."""
        valid_compute_levels = ["high_perf", "low_power"]
        valid_cache_strategies = ["weight_priority", "input_priority"]
        valid_quant_modes = [None, "int8_sym", "int8_asym"]
        if self.compute_level not in valid_compute_levels:
            raise ValueError(f"Invalid compute level, supported values: {valid_compute_levels}")
        if self.cache_allocation not in valid_cache_strategies:
            raise ValueError(f"Invalid cache strategy, supported values: {valid_cache_strategies}")
        if self.quant_mode not in valid_quant_modes:
            raise ValueError(f"Invalid quantization mode, supported values: {valid_quant_modes}")
```
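A quick usage sketch of the configuration class above. The `npu_conv2d` wrapper is a hypothetical stand-in for the upper-layer entry point described earlier (validate, then forward); only `NPUConv2DConfig` and `validate()` come from the code above.

```python
def npu_conv2d(data, weight, config: NPUConv2DConfig):
    """Hypothetical upper-layer entry point: validate the config, then forward it."""
    config.validate()  # reject invalid parameters before they reach the low-level adaptation layer
    raise NotImplementedError("forward to the NPU scheduler / code generator here")

cfg = NPUConv2DConfig(compute_level="low_power", quant_mode="int8_sym")
cfg.validate()         # passes: both values are in the allowed sets

bad = NPUConv2DConfig(compute_level="ultra")
try:
    bad.validate()
except ValueError as e:
    print(e)           # Invalid compute level, supported values: ['high_perf', 'low_power']
```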