基础代码
import numpy as np
x = [1, 2, 3]
xa = np.asarray(x) # 将x转换为ndarray
# y = (x < 2)
ya = (xa < 2)
print(ya)
print(type(ya))
普通列表是不支持遍历判断x<2的,而ndarray数组支持,输出结果ya为一个布尔表,满足条件的索引为True,反之False。
使用实例
在mnist数据集中,想要单独训练5和非5的二元分类器,需要将y_train和y_test根据是否为5设置T/F。
y_train_5 = (y_train == 5) # 所有5标签为T,反之为F
y_test_5 = (y_test == 5)