🐍 NumPy 版三元表达式 | 向量化一行秒掉 if-else 🔥

59 阅读2分钟

Python 三元已经够快了?
NumPy 直接 “广播” 帮你秒到飞起!
今天教你 5 招,彻底告别慢吞吞的 for 循环~

微信图片_20251014151033_10_20.jpg

1️⃣ 入门武器:np.where ⚔️

语法np.where(条件, 条件为真取值, 条件为假取值)

import numpy as np

score = np.array([55, 82, 91, 67])
level = np.where(score >= 60, 'Pass', 'Fail')

print(level)          # ['Pass' 'Pass' 'Pass' 'Pass']

2️⃣ 多条件嵌套:继续套娃 🪆

# 90+ 优秀,60-89 及格,<60 不及格
level = np.where(score >= 90, '优秀',
                 np.where(score >= 60, '及格', '不及格'))

print(level)          # ['及格' '及格' '优秀' '及格']

提示:嵌套超过两层建议换 select(见下)👇


3️⃣ 终极多分支:np.select 🎚️

conditions = [score >= 90, score >= 60, score < 60]
choices    = ['优秀',      '及格',      '不及格']

level = np.select(conditions, choices, default='未知')
print(level)          # ['及格' '及格' '优秀' '及格']
参数说明
conditions条件列表,按先后顺序匹配
choices与 conditions 一一对应的结果
default全都不满足时的兜底值

4️⃣ 数值运算式三元:广播秀操作 ⚡️

a = np.array([3, -2, 7, -5])
# 负数变 0,其他保持
b = np.where(a < 0, 0, a)

print(b)              # [3 0 7 0]

再升级:

# 负数绝对值,正数平方
c = np.where(a < 0, -a, a**2)
print(c)              # [9 2 49 5]

5️⃣ 性能对比:for 循环被按在地上摩擦 🏎️

import timeit

big = np.random.randn(1_000_000)

# np.where 向量化
def np_way():
    return np.where(big > 0, big, 0)

# 纯 Python 循环
def py_way():
    return [x if x > 0 else 0 for x in big]

print(timeit.timeit(np_way, number=10))  # ≈ 0.18 s
print(timeit.timeit(py_way, number=10))  # ≈ 4.5  s

25 倍提速不是梦!⚡️

6️⃣ 高维数组同样适用:RGB 图像示例 🖼️

img = np.random.randint(0, 256, (480, 640, 3), dtype=np.uint8)

# 把灰度值<50 的像素直接变黑
dark_mask = img.mean(axis=2, keepdims=True) < 50
img_black = np.where(dark_mask, 0, img)

7️⃣ 易错点提醒 ⚠️

错误示例正确做法
维度不一致np.where(a>0, 1, [1,2,3])保证广播规则✅
类型混用np.where(a>0, 1.0, 'small')返回同类 dtype,先统一类型✅

🎯 总结口诀(背它!)

三元太慢?上 np.where
多分支?np.select 排排坐!
广播加持,百万数据一行代码秒掉!🎈