PyTorch项目实战07——Tensor的比较操作

140 阅读3分钟

1 常用的比较运算

函数功能
lt/le/ne/eq/ge/gt小于/小于等于/不等于/等于/大于等于/大于
topk最大的前K个数
sort排序
max/min最大值/最小值

2 应用

2.1 创建张量

使用随机数创建两个2行3列的张量。

image.png

2.2 比较大小

lt() 方法,是进行逐元素比较,即两个张量中,所有对应位置上的数据进行两两比较,并会将其比较结果生成为一个包含 True/False 的布尔值的张量。 image.png

lt() 方法等同于 <,计算效果相同

image.png

其它方法也可以替换为相同功能和数学符号:

方法符号
lt<
le<=
ne!=
eq==
ge>=
gt>

2.3 张量的相等比较

有一种和 eq 类似的比较符是 equal,但是其作用不同,equal 并不是逐元素比较,而是整个张量进行比较。所以它的结果只有一个布尔值。

image.png

2.4 最大前K个数

如果我们想要得出张量中,按某一个维度最大的数据,或最大前K个数,就可以使用 topk() 方法。

2.4.1 按行比较

a.topk(k) 默认会按行统计出最大的前K个数,并按照从大到小的顺序,从左到右列出,并在下边给出元素的索引。

如 a.topk(1)

image.png

上边一行 values 给出每行最大的元素;

下边的 indices 则给出每行最大元素所在列的索引值,如 0.4848 在张量a中第1行第0个元素,1.2878 则在第二行中第1个元素。

再如 a.topk(2),取出张量 a 中按行比较后每行最大的前2个元素,并从左到右输出,如第二行最大的元素是 1.2878,其索引在第二行中下标为1,而第二大元素是 -0.0059,其索引在第二行中下标为0。

image.png

2.4.3 按列比较

在 topk() 方法中设置 dim 参数,指定要比较的维度,dim=0 是按列比较。

values 中给出每列排在最前边的元素内容,indices 中给出该元素在张量中的索引位置。

​a.topk(1, dim=0)​​取出每列中最大的元素,第一列最大的元素为 0.4848,该元素在第一列的下标为0,第二列最大的元素为 1.2878,该元素在第二列的下标为1,依此类推。

image.png

2.5 元素排序

使用 sort() 方法,可以对张量中的元素进行排序。如果不指定排序维度时,默认按行的方向对元素数据按从小到大,从左到右的顺序,对元素进行排序。

image.png

values 中给出每行排序后元素的内容,indices 中给出该元素在张量中的索引位置。

当在 sort() 方法中,显示指定了排序的维度,如 dim=0,程序将按照列的方向对元素进行排序。

image.png

2.6 最大值

2.6.1 张量之间的比较

两个张量之间,使用 max() 方法进行比较,是取出两个张量相同索引的元素数值,经比较后,取出最大者放入到生成后张量的同一索引位置中。

image.png

张量 a 中的第1行下标为0的元素为 0.4848,张量 b 中相同索引位置的元素为 0.4033,两者中 0.4848 较大,因此将 0.4848 放入到生成的张量中第1行下标为0的索引位置。其它数据依此类推。

2.6.2 张量中所有元素的比较

如果想要获取一个张量中最大的元素数据,可以使用不带任何参数的 max() 方法。

image.png

2.6.3 张量中按维度进行的比较

如果想要以某一维度,获得张量中最大的元素数据,可以在 max() 方法中输入 dim 参数,比如要获取每列最大的元素,就可以使用 max(dim=0) 即可。

values 中给出每列中最大的元素数据,indices 中给出该元素在张量当前列中的索引位置。

image.png

2.7 注意

需要注意的是,张量中所有比较两个张量之间的比较,都是建议在维度值相同的前提下。

如2行3列的张量,只能与2行3列的张量进行比较运算,如果是2行4列则会报错。

image.png

张量可以直接和标量进行比较运算。

image.png