【Triton 教程】triton_language.expand_dims

56 阅读1分钟

Triton 是一种用于并行编程的语言和编译器。它旨在提供一个基于 Python 的编程环境,以高效编写自定义 DNN 计算内核,并能够在现代 GPU 硬件上以最大吞吐量运行。

更多 Triton 中文文档可访问 →triton.hyper.ai/

triton.language.expand_dims(input, axis)

通过插入新的长度为 1 的维度来扩展张量的形状。

轴索引是相对于生成的张量而言的,因此对于每个轴,result.shape[axis] 将为 1。

参数

  • input (tl.tensor) - 输入张量。
  • axis (int | Sequence[int] ) - 要添加新轴的索引。

该函数也可作为 tensor 的成员函数调用,使用 x.expand_dims(...) 而不是 expand_dims(x, ...)