.argmax()

169 阅读2分钟

NumPy  .argmax()  函数返回数组中沿指定轴的最大值的索引。

语法

numpy.argmax(array, axis=None, out=None, keepdims=<no value>)

参数:

  • array:用于查找最大值的输入数组。
  • axis(可选):整数,指定要沿其查找最大值的轴。默认情况下 () ,将计算展平数组的索引。None
  • out(可选):如果提供,结果将插入到此数组中。它应该具有适当的 shape 和 dtype。
  • keepdims(可选):如果 ,则减少的轴将作为大小为 1 的维度保留在结果中,从而允许结果针对输入数组正确广播。True

返回值:

将索引数组返回到数组中。它与输入数组的形状相同,但删除了沿指定轴的维度,除非设置为 。keepdims``True

import numpy as np

# Create a 2D array representing scores of 4 students in 3 subjects
scores = np.array([
    [85, 90, 78],  # Student 1
    [88, 76, 92],  # Student 2
    [79, 85, 88],  # Student 3
    [91, 89, 84]   # Student 4
])

#找到每个科目中最好的学生(轴=0表示找到每个列的最大索引)
best_students = np.argmax(scores, axis=0)

#学生科目中最高分数 (轴=1表示找到每个行的最大索引)
best_score = np.argmax(scores, axis=1)

# Subject names for reference
subjects = ["Math", "Science", "English"]

# Display results
for i, subject in enumerate(subjects):
    print(f"Top scorer in {subject}: Student {best_students[i] + 1}")

for i in range(len(best_score)):
    print(f"Student {i+1} scored the highest in {subjects[best_score[i]]}")

再看torch也是一样

import torch
x = torch.randn(3, 5)
print(x)

'''tensor([[ 0.7961,  0.0925, -0.1900,  1.6187, -1.1678],
        [-3.8870,  0.1368,  0.1494,  0.9987,  0.0694],
        [-0.2666, -0.2431,  0.2670, -0.3807,  0.6139]])'''
        

print(torch.argmax(x))          # 默认展平后找全局最大值的索引

print(torch.argmax(x, dim=0))   # 沿列(垂直方向)逐列找最大值索引
print(torch.argmax(x, dim=-2))  # dim=-2 等价于 dim=0(倒数第二个维度)

print(torch.argmax(x, dim=1))   # 沿行(水平方向)逐行找最大值索引
print(torch.argmax(x, dim=-1))  # dim=-1 等价于 dim=1(最后一个维度)

'''tensor(3)
tensor([0, 1, 2, 0, 2])
tensor([0, 1, 2, 0, 2])
tensor([3, 3, 4])
tensor([3, 3, 4])'''