NumPy .argmax() 函数返回数组中沿指定轴的最大值的索引。
语法
numpy.argmax(array, axis=None, out=None, keepdims=<no value>)
参数:
array:用于查找最大值的输入数组。axis(可选):整数,指定要沿其查找最大值的轴。默认情况下 () ,将计算展平数组的索引。Noneout(可选):如果提供,结果将插入到此数组中。它应该具有适当的 shape 和 dtype。keepdims(可选):如果 ,则减少的轴将作为大小为 1 的维度保留在结果中,从而允许结果针对输入数组正确广播。True
返回值:
将索引数组返回到数组中。它与输入数组的形状相同,但删除了沿指定轴的维度,除非设置为 。keepdims``True
import numpy as np
# Create a 2D array representing scores of 4 students in 3 subjects
scores = np.array([
[85, 90, 78], # Student 1
[88, 76, 92], # Student 2
[79, 85, 88], # Student 3
[91, 89, 84] # Student 4
])
#找到每个科目中最好的学生(轴=0表示找到每个列的最大索引)
best_students = np.argmax(scores, axis=0)
#学生科目中最高分数 (轴=1表示找到每个行的最大索引)
best_score = np.argmax(scores, axis=1)
# Subject names for reference
subjects = ["Math", "Science", "English"]
# Display results
for i, subject in enumerate(subjects):
print(f"Top scorer in {subject}: Student {best_students[i] + 1}")
for i in range(len(best_score)):
print(f"Student {i+1} scored the highest in {subjects[best_score[i]]}")
再看torch也是一样
import torch
x = torch.randn(3, 5)
print(x)
'''tensor([[ 0.7961, 0.0925, -0.1900, 1.6187, -1.1678],
[-3.8870, 0.1368, 0.1494, 0.9987, 0.0694],
[-0.2666, -0.2431, 0.2670, -0.3807, 0.6139]])'''
print(torch.argmax(x)) # 默认展平后找全局最大值的索引
print(torch.argmax(x, dim=0)) # 沿列(垂直方向)逐列找最大值索引
print(torch.argmax(x, dim=-2)) # dim=-2 等价于 dim=0(倒数第二个维度)
print(torch.argmax(x, dim=1)) # 沿行(水平方向)逐行找最大值索引
print(torch.argmax(x, dim=-1)) # dim=-1 等价于 dim=1(最后一个维度)
'''tensor(3)
tensor([0, 1, 2, 0, 2])
tensor([0, 1, 2, 0, 2])
tensor([3, 3, 4])
tensor([3, 3, 4])'''