Numpy np.where()的简单用法

182 阅读1分钟

本文已参与「新人创作礼」活动,一起开启掘金创作之路。

第一种用法是寻找符合条件的数的索引:

import numpy as np

arr = np.array([1, 2, 3, 4, 5])
idx = np.where(arr > 3)
print(idx)

输出:

(array([3, 4], dtype=int64),)

第二种用法是寻找符合条件的数并进行修改:

import numpy as np

arr = np.array([1, 2, 3, 4, 5])
new_arr = np.where(arr > 3, 1, 0)
print(new_arr)

输出:

[0 0 0 1 1]

其实就相当于第一种方法拿到索引后再修改:

import numpy as np

arr = np.array([1, 2, 3, 4, 5])
idx1 = np.where(arr > 3)
idx2 = np.where(arr <= 3)
arr[idx1[0]] = 1
arr[idx2[0]] = 0
print(arr)

还有其他的用法,不过本文不再列出。