【Triton 教程】triton-ops

0 阅读3分钟

Triton 是一种用于并行编程的语言和编译器。它旨在提供一个基于 Python 的编程环境,以高效编写自定义 DNN 计算内核,并能够在现代 GPU 硬件上以最大吞吐量运行。

*在线运行 Triton 学习教程 → go.hyper.ai/wS9x1

triton_language.argmax

triton.language.argmax(input, axis, tie_break_left=True, keep_dims=False)

返回沿指定 axis 的 input 张量中所有元素的最大索引。

参数**:**

  • input (Tensor) - 输入值。
  • axis (int) - 要进行归约操作的维度。
  • keep_dims (bool) - 如果为 true,则保留长度为 1 的归约维度。
  • tie_break_left (bool) - 如果为 true,在出现平局的情况下(即多个元素具有相同的最大索引值),对于非 NaN 的值返回最左边的索引。

这个函数也可作为 tensor 的成员函数调用,使用 x.argmax(...) 而不是 argmax(x, ...)

triton_language.argmin

triton.language.argmin(input, axis, tie_break_left=True, keep_dims=False)

返回沿指定 axis 的 input 张量中所有元素的最小索引。

参数**:**

  • input (Tensor) - 输入值。
  • axis (int) - 要进行归约操作的维度。
  • keep_dims (bool) - 如果为 true,则保留长度为 1 的归约维度。
  • tie_break_left (bool) - 如果为 true,在出现平局的情况下(即多个元素具有相同的最小索引值),对于非 NaN 的值返回最左边的索引。

这个函数也可作为 tensor 的成员函数调用,使用 x.argmin(...) 而不是 argmin(x, ...)

triton_language.max

triton.language.max(input, axis=None, return_indices=False, return_indices_tie_break_left=True, keep_dims=False)

返回沿指定 axis 轴上 input 张量中所有元素的最大值。

参数**:**

  • input (Tensor) - 输入值。
  • axis (int) - 要进行归约操作的维度。
  • keep_dims (bool) - 如果为 true,则保留长度为 1 的归约维度。
  • return_indices (bool) - 如果为 true,则返回对应最大值的索引。
  • return_indices_tie_break_left (bool) - 如果为 true,在出现平局的情况下(即多个元素具有相同的最大值),对于非 NaN 的值返回最左边的索引。

这个函数也可作为 tensor 的成员函数调用,使用 x.max(...) 而不是 max(x, ...)

triton_language.min

triton.language.min(input, axis=None, return_indices=False, return_indices_tie_break_left=True, keep_dims=False)

返回沿指定 axis 轴上 input 张量中所有元素的最小值。

参数**:**

  • input (Tensor) - 输入值。
  • axis (int) - 要进行归约操作的维度。
  • keep_dims (bool) - 如果为 true,则保留长度为 1 的归约维度。
  • return_indices (bool) - 如果为 true,则返回对应最小值的索引。
  • return_indices_tie_break_left (bool) - 如果为 true,在出现平局的情况下(即多个元素具有相同的最小值),对于非 NaN 的值返回最左边的索引。

这个函数也可作为 tensor 的成员函数调用,使用 x.min(...) 而不是 min(x, ...)

triton_language.reduce

triton.language.reduce(input, axis, combine_fn, keep_dims=False)

将 combine_fn 应用于沿指定 axis 轴上 input 张量中的所有元素。

参数**:**

  • input (Tensor) - 输入张量,或张量的元组。
  • axis (int | None) - 要进行归约操作的维度。如果为 None,则归约所有维度。
  • combine_fn (Callable) - 1 个用于组合 2 组标量张量的函数(必须使用 @triton.jit 标记)。
  • keep_dims (bool) - 如果为 true,保留长度为 1 的归约维度。

这个函数也可作为 reduce 的成员函数调用,使用 x.reduce(...) 而不是 reduce(x, ...)

triton_language.sum

triton.language.sum(input, axis=None, keep_dims=False)

返回 input 张量中,沿指定 axis 的所有元素的总和。

参数**:**

  • input (Tensor) - 输入值。
  • axis (int) - 要进行归约操作的维度。
  • keep_dims (bool) - 如果为 true,保留长度为 1 的归约维度。

这个函数也可作为 tensor 的成员函数调用,使用 x.sum(...) 而不是 sum(x, ...)

triton_language.xor_sum

triton.language.xor_sum(input, axis=None, keep_dims=False)

沿指定 axis 的 input 张量中所有元素的异或和。

参数**:**

  • input (Tensor) - 输入值。
  • axis (int) - 要进行归约操作的维度。
  • keep_dims (bool) - 如果为 true,保留长度为 1 的归约维度。

这个函数也可作为 tensor 的成员函数调用,使用 x.xor_sum(...) 而不是 xor_sum(x, ...)