NumPy:快速逻辑遍历np数组

168 阅读1分钟

基础代码

import numpy as np

x = [1, 2, 3]
xa = np.asarray(x) # 将x转换为ndarray

# y = (x < 2)
ya = (xa < 2)
print(ya)
print(type(ya))

普通列表是不支持遍历判断x<2的,而ndarray数组支持,输出结果ya为一个布尔表,满足条件的索引为True,反之False。

使用实例

在mnist数据集中,想要单独训练5和非5的二元分类器,需要将y_train和y_test根据是否为5设置T/F。

y_train_5 = (y_train == 5) # 所有5标签为T,反之为F
y_test_5 = (y_test == 5)