【Triton 教程】triton_language.cast

62 阅读1分钟

Triton 是一种用于并行编程的语言和编译器。它旨在提供一个基于 Python 的编程环境,以高效编写自定义 DNN 计算内核,并能够在现代 GPU 硬件上以最大吞吐量运行。

更多 Triton 中文文档可访问 →triton.hyper.ai/

triton.language.cast(input, dtype: dtype, fp_downcast_rounding: str | None = None, bitcast: bool = False)

将张量转换为指定的 dtype

参数

  • dtype (tl.dtype) - 目标数据类型。
  • fp_downcast_rounding (stroptional) - 向下转换浮点值的舍入模式。仅当 self 是浮点张量且 dtype 是比特宽度较小的浮点类型时使用。支持的值为 "rtne"(四舍五入到最接近的偶数)和 "rtz"(向零舍入)。
  • bitcast (booloptional) - 如果为 true,则将张量位转换为给定的 dtype,而不是进行数值转换。

此函数也可以作为 tensor 上的成员函数调用,作为 x.cast (...) 而不是 cast (x,...)