第11章: 性能优化 — 理解Python的快与慢
Java/Kotlin 开发者习惯了 JVM 的 JIT 优化,Python 的性能模型完全不同。
理解"为什么慢"比盲目优化更重要。
11.1 CPython 执行模型
Java/Kotlin 对比
public class Perf {
    public static int add(int a, int b) {
        return a + b;
    }

    public static void main(String[] args) {
        for (int i = 0; i < 1_000_000; i++) {
            add(i, i + 1);
        }
    }
}

fun add(a: Int, b: Int): Int = a + b

fun main() {
    repeat(1_000_000) { add(it, it + 1) }
}
Python 实现
import dis

def add(a, b):
    return a + b

print("=== add(a, b) 字节码 ===")
dis.dis(add)

def sum_loop(n):
    total = 0
    for i in range(n):
        total += i
    return total

print("\n=== sum_loop(n) 字节码 ===")
dis.dis(sum_loop)

def direct_add(a, b):
    return a + b

def call_add(a, b):
    return direct_add(a, b)

print("\n=== direct_add 字节码 ===")
dis.dis(direct_add)
print("\n=== call_add 字节码 ===")
dis.dis(call_add)
import timeit

setup = """
def add(a, b):
    return a + b
a, b = 1, 2
"""
direct = "a + b"
func_call = "add(a, b)"
t_direct = timeit.timeit(direct, setup="a, b = 1, 2", number=10_000_000)
t_func = timeit.timeit(func_call, setup=setup, number=10_000_000)
print(f"直接加法: {t_direct:.4f}s")
print(f"函数调用加法: {t_func:.4f}s")
print(f"函数调用开销: {t_func / t_direct:.1f}x 慢")
import dis

def func(v):
    return v * 2

def demo_instructions():
    x = 42            # 常量加载 + 局部变量存储
    y = [1, 2, 3]     # 构建列表
    z = x + 1         # 二元运算
    result = func(z)  # 函数调用
    return result

print("=== demo_instructions 字节码 ===")
dis.dis(demo_instructions)
核心差异
| 维度 | JVM (Java/Kotlin) | CPython |
|---|---|---|
| 字节码执行 | 解释 + JIT 编译为机器码 | 纯解释执行(3.13 实验性 JIT 除外) |
| 热点优化 | 有(方法/循环计数器 → C1 → C2) | 无(3.11+ 有轻量自适应解释器) |
| 内联 | 支持(C2 激进内联) | 不支持 |
| 逃逸分析 | 支持(标量替换、锁消除) | 不支持 |
| 启动速度 | 慢(JVM 预热) | 快 |
| 稳态性能 | 快(JIT 优化后) | 慢(纯解释) |
| 内存模型 | 堆 + 栈 + JIT 代码缓存 | 堆(引用计数)+ 栈 |
| 字节码格式 | 栈式 + 寄存器混合 | 纯栈式 |
常见陷阱
# 陷阱: CPython 没有 JIT,重复调用同一函数不会"热身"后变快
def compute():
    total = 0
    for i in range(1_000_000):
        total += i
    return total

for _ in range(1000):
    compute()  # 第 1000 次和第 1 次一样慢

# 陷阱: 热循环中的属性访问比局部变量慢
class Point:
    def __init__(self, x, y):
        self.x = x
        self.y = y

def distance(p):
    return (p.x ** 2 + p.y ** 2) ** 0.5

def distance_fast(p):
    x, y = p.x, p.y  # 提取为局部变量
    return (x ** 2 + y ** 2) ** 0.5

t1 = timeit.timeit('distance(p)', setup='from __main__ import distance, Point; p=Point(3,4)', number=1_000_000)
t2 = timeit.timeit('distance_fast(p)', setup='from __main__ import distance_fast, Point; p=Point(3,4)', number=1_000_000)
print(f"属性访问: {t1:.4f}s")
print(f"局部变量: {t2:.4f}s")
print(f"加速比: {t1/t2:.2f}x")
何时使用
- 理解 CPython 模型是所有性能优化的基础
- 用 dis 模块分析热点代码的字节码开销
- 函数调用开销大:避免在热循环中频繁调用小函数
- 局部变量访问比全局/属性访问快:在热循环中提取为局部变量
11.2 3.11+ 特化自适应解释器 (PEP 659)
Java/Kotlin 对比
Python 实现
import timeit

def add_numbers(a, b):
    # 3.11+ 会在多次执行后把 BINARY_OP 特化为 BINARY_OP_ADD_INT 等专用指令
    return a + b

def stable_add(n):
    """类型稳定: total 始终是 int,特化持续生效"""
    total = 0
    for i in range(n):
        total += i
    return total

def unstable_add(n):
    """类型不稳定: int/float 交替,特化反复失效回退"""
    total = 0
    for i in range(n):
        if i % 2 == 0:
            total += i
        else:
            total += float(i)
    return total

N = 100_000
t_stable = timeit.timeit(lambda: stable_add(N), number=100)
t_unstable = timeit.timeit(lambda: unstable_add(N), number=100)
print(f"类型稳定: {t_stable:.4f}s")
print(f"类型不稳定: {t_unstable:.4f}s")
print(f"稳定更快: {t_unstable/t_stable:.2f}x")

def hot_loop():
    """3.11 自适应解释器对这种纯计算循环提速最大"""
    result = 0
    for i in range(1000):
        for j in range(100):
            result += i * j
    return result

t = timeit.timeit(hot_loop, number=100)
print(f"热循环 100 次: {t:.4f}s")
print(f"平均每次: {t/100:.4f}s")

class Foo:
    """属性布局固定的类: LOAD_ATTR 可被特化为快速路径"""
    def __init__(self):
        self.name = "foo"
        self.value = 42

f = Foo()
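特化的效果可以直接观察: 3.11+ 的 dis 模块接受 adaptive 参数,预热后反汇编即可看到专用指令。以下是一个示意(read_attr、P 为本文自拟的名字),具体指令名(如 LOAD_ATTR_INSTANCE_VALUE)随版本变化。

```python
import dis
import sys

class P:
    def __init__(self, x):
        self.x = x

def read_attr(p):
    return p.x

# 预热: 多次执行同一代码路径,触发特化(阈值很低,几十次即可)
p = P(1)
for _ in range(1000):
    read_attr(p)

if sys.version_info >= (3, 11):
    # adaptive=True 显示特化后的指令,而非通用字节码
    dis.dis(read_attr, adaptive=True)
else:
    dis.dis(read_attr)  # 3.10 及以前没有特化指令可看
```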
核心差异
| 维度 | JVM JIT | Python 3.11+ 自适应 |
|---|---|---|
| 优化方式 | 编译为原生机器码 | 特化字节码指令 |
| 内联 | 支持(深度内联) | 不支持 |
| 逃逸分析 | 支持(标量替换) | 不支持 |
| 循环展开 | 支持 | 不支持 |
| 内存开销 | 较大(代码缓存可达数百MB) | 极小(每个缓存条目几十字节) |
| 速度提升 | 10-100x(vs 纯解释) | 1.25-1.6x(vs 3.10) |
| 去优化 | 支持(回退到解释) | 支持(回退到通用字节码) |
| 编译暂停 | 有(C2 编译可能停顿几十ms) | 无(解释器内完成) |
常见陷阱
def slow_search(lst, target):
    """O(n) 搜索 — 即使有自适应优化也慢"""
    for item in lst:
        if item == target:
            return True
    return False

def fast_search(s, target):
    """O(1) 哈希查找 — 3.10 和 3.11 都快"""
    return target in s
何时使用
- 升级到 Python 3.11+ 即可免费获得性能提升
- 对纯 Python 代码有效,对 C 扩展调用无帮助
- 保持类型稳定以获得最大收益
- 不要依赖自适应解释器弥补算法差距
11.3 3.13+ 实验性 JIT 编译器 (PEP 744)
Java/Kotlin 对比
Python 实现
import sys

# 3.13+ 才包含实验性 JIT;是否可用还取决于构建选项(--enable-experimental-jit)
if sys.version_info >= (3, 13):
    pass  # 以 PYTHON_JIT=1 启动进程可开启

def compute():
    total = 0
    for i in range(1_000_000):
        total += i * i
    return total
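PYTHON_JIT 是环境变量开关,下面的草图演示如何带着它在子进程中运行同一段计算;在不支持 JIT 的构建上该变量会被忽略,计算结果不变,因此可以安全地在测试环境试用。

```python
import os
import subprocess
import sys

code = "print(sum(i * i for i in range(1000)))"

# 带 PYTHON_JIT=1 启动子进程;非 JIT 构建会直接忽略该变量
env = dict(os.environ, PYTHON_JIT="1")
out = subprocess.run(
    [sys.executable, "-c", code],
    env=env, capture_output=True, text=True,
)
print(out.stdout.strip())  # 332833500
```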
核心差异
| 维度 | JVM JIT | Python 3.13 JIT |
|---|---|---|
| 成熟度 | 生产级(30年) | 实验性 |
| 默认开启 | 是 | 否 |
| 编译策略 | 多层 (C1/C2) | copy-and-patch |
| 优化深度 | 极深(内联、逃逸分析等) | 浅(简单操作) |
| 代码缓存 | 数百MB | 极小 |
| 性能提升 | 10-100x | ~5%(当前) |
| 去优化 | 完善 | 基础 |
何时使用
- 关注但不依赖,等稳定后再用于生产
- 可以在测试环境体验: PYTHON_JIT=1
- 长期来看,Python JIT 是缩小与 JVM 性能差距的关键
- 当前 3.11+ 的自适应解释器已经提供了更显著的提升
11.4 profiling 工具链
Java/Kotlin 对比
import java.util.concurrent.TimeUnit;
import org.openjdk.jmh.annotations.*;

@BenchmarkMode(Mode.AverageTime)
@OutputTimeUnit(TimeUnit.MICROSECONDS)
public class MyBenchmark {
    @Benchmark
    public void testMethod() {
        // 被测代码
    }
}
Python 实现
4.1 timeit: 微基准测试
import timeit
t = timeit.timeit('sum(range(1000))', number=10000)
print(f"sum(range(1000)) x10000: {t:.4f}s")
times = timeit.repeat('sum(range(1000))', number=10000, repeat=5)
print(f"5 轮耗时: {[f'{t:.4f}' for t in times]}")
print(f"最小值: {min(times):.4f}s (最可信)")
print(f"最大值: {max(times):.4f}s")
print(f"平均值: {sum(times)/len(times):.4f}s")
def my_func():
    return [i * 2 for i in range(1000)]
t = timeit.timeit(my_func, number=10000)
print(f"my_func x10000: {t:.4f}s")
t = timeit.timeit(
'data.sort()',
setup='import random; data = [random.random() for _ in range(10000)]',
number=1000
)
print(f"排序 10000 个元素 x1000: {t:.4f}s")
import time

# 不推荐: time.time() 是墙钟,精度低且可能被系统调整
start = time.time()
sum(range(1000))
elapsed = time.time() - start

# 推荐: time.perf_counter() 单调且高精度,timeit 内部用的也是它
start = time.perf_counter()
sum(range(1000))
elapsed = time.perf_counter() - start
t = timeit.timeit('sum(range(1000))', number=100000)
print(f"timeit: {t/100000*1e6:.2f} usec per call")
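number 取多大合适?不想手动猜时可以用 Timer.autorange():它会自动成倍增加循环次数,直到单轮总耗时超过 0.2 秒,返回 (number, 总耗时)。

```python
import timeit

timer = timeit.Timer('sum(range(1000))')
# autorange 自动选择循环次数,返回 (number, 该次数下的总耗时)
number, elapsed = timer.autorange()
print(f"自动选择 number={number}, 总耗时 {elapsed:.4f}s")
print(f"单次: {elapsed / number * 1e6:.2f} usec")
```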
4.2 cProfile: 函数级分析
import cProfile
import pstats
import io
import random

def slow_search(data, target):
    """O(n) 线性搜索"""
    for i, item in enumerate(data):
        if item == target:
            return i
    return -1

def fast_search(data, target):
    """O(1) 哈希查找"""
    data_set = set(data)
    return target in data_set

def process_data(data, targets):
    """处理数据: 对每个 target 执行搜索"""
    results = []
    for target in targets:
        idx = slow_search(data, target)
        results.append(idx)
    return results

def load_data():
    """模拟加载数据"""
    return [random.randint(0, 100_000) for _ in range(10_000)]

def main():
    data = load_data()
    targets = [random.randint(0, 100_000) for _ in range(1000)]
    for _ in range(10):
        process_data(data, targets)
profiler = cProfile.Profile()
profiler.enable()
main()
profiler.disable()
profiler.print_stats(sort='cumulative')
s = io.StringIO()
ps = pstats.Stats(profiler, stream=s).sort_stats('tottime')
ps.print_stats(10)
print(s.getvalue())
profiler = cProfile.Profile()
profiler.enable()
process_data(list(range(10000)), list(range(1000)))
profiler.disable()
ps = pstats.Stats(profiler)
ps.strip_dirs()
ps.sort_stats('tottime')
ps.print_stats(5)
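分析结果还可以落盘,事后离线加载或交给可视化工具查看。以下是一个示意(work、profile.out 均为自拟的名字):

```python
import cProfile
import pstats
import os
import tempfile

def work():
    return sum(i * i for i in range(100_000))

profiler = cProfile.Profile()
profiler.enable()
work()
profiler.disable()

# 把统计数据写入文件,之后可随时用 pstats.Stats(path) 重新加载
path = os.path.join(tempfile.mkdtemp(), "profile.out")
profiler.dump_stats(path)

stats = pstats.Stats(path)
stats.strip_dirs().sort_stats('cumulative').print_stats(5)
```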
4.3 line_profiler: 行级分析
def example_for_line_profiler():
    """
    用 kernprof -l -v 运行此文件来查看行级分析结果
    (需先给函数加 @profile 装饰器)
    """
    import random
    import math
    data = [random.random() for _ in range(10000)]
    total = 0.0
    for x in data:
        total += math.sqrt(x)
        total += math.log(x + 0.001)
    result = total / len(data)
    return result
from line_profiler import LineProfiler

def target_function():
    import math
    total = 0
    for i in range(100000):
        total += math.sin(i) * math.cos(i)
    return total

lp = LineProfiler()
lp_wrapper = lp(target_function)
lp_wrapper()
lp.print_stats()
4.4 py-spy: 采样分析
4.5 memray: 内存分析
def memory_heavy_function():
    """演示内存分配"""
    data = []
    for i in range(100000):
        data.append({'id': i, 'value': i * 2, 'name': f'item_{i}'})
    return data
4.6 tracemalloc: 轻量内存跟踪
import tracemalloc
import linecache
tracemalloc.start()
data = [list(range(1000)) for _ in range(1000)]
current, peak = tracemalloc.get_traced_memory()
print(f"当前内存: {current / 1024 / 1024:.2f} MB")
print(f"峰值内存: {peak / 1024 / 1024:.2f} MB")
tracemalloc.stop()
tracemalloc.start()
snapshot1 = tracemalloc.take_snapshot()
leaky = []
for i in range(100000):
leaky.append([i] * 100)
snapshot2 = tracemalloc.take_snapshot()
top_stats = snapshot2.compare_to(snapshot1, 'lineno')
print("=== 内存增长 Top 10 ===")
for stat in top_stats[:10]:
print(stat)
tracemalloc.start()
data = [list(range(1000)) for _ in range(1000)]
snapshot = tracemalloc.take_snapshot()
top_stats = snapshot.statistics('lineno')
print("\n=== 内存分配 Top 5 ===")
for stat in top_stats[:5]:
print(stat)
profiling 决策树
需要分析性能?
├── 微基准测试(对比两种实现)
│ └── timeit(自动多次运行,取最小值)
│ $ python -m timeit -n 10000 -r 5 "code_here"
│
├── 找函数级瓶颈
│ └── cProfile
│ $ python -m cProfile -s tottime script.py
│ 或代码中: cProfile.Profile()
│
├── 找函数内哪一行慢
│ └── line_profiler
│ $ pip install line_profiler
│ $ kernprof -l -v script.py
│
├── 生产环境无侵入分析
│ └── py-spy
│ $ py-spy top --pid <PID>
│ $ py-spy record -o flame.svg -- python script.py
│
├── 内存分析
│ ├── 轻量跟踪 → tracemalloc(标准库)
│ └── 深度分析 → memray(第三方)
│ $ memray run -o output.bin script.py
│
└── 持续监控
└── py-spy top(实时采样)
或 memray live
核心差异
| 工具 | 用途 | Java 对应物 | 侵入性 |
|---|---|---|---|
| timeit | 微基准测试 | JMH | 无 |
| cProfile | 函数级分析 | VisualVM CPU | 低 |
| line_profiler | 行级分析 | JFR 事件 | 低 |
| py-spy | 采样分析(无侵入) | async-profiler | 零 |
| memray | 内存分析 | VisualVM 内存 | 低 |
| tracemalloc | 内存跟踪 | - | 低 |
常见陷阱
import time

# 陷阱: 用 time.time() 做微基准 — 墙钟精度不足,还可能被系统校时干扰
start = time.time()
# ... 被测代码 ...
elapsed = time.time() - start

# 正确: 用 time.perf_counter(),或直接用 timeit
start = time.perf_counter()
# ... 被测代码 ...
elapsed = time.perf_counter() - start
何时使用
- timeit: 对比两种实现的性能
- cProfile: 找到程序的热点函数
- line_profiler: 找到函数内最慢的代码行
- py-spy: 生产环境无侵入分析
- memray/tracemalloc: 内存问题排查
11.5 时间复杂度与 benchmark
Java/Kotlin 对比
import java.util.*;
Python 实现
5.1 完整时间复杂度表
5.2 实际 benchmark: O(1) vs O(n)
import timeit
N = 10000
setup_list = f'data = list(range({N}))'
setup_set = f'data = set(range({N}))'
t_list = timeit.timeit('9999 in data', setup=setup_list, number=10000)
t_set = timeit.timeit('9999 in data', setup=setup_set, number=10000)
print(f"=== set vs list 成员检查 (n={N}) ===")
print(f"list: {t_list:.4f}s")
print(f"set: {t_set:.4f}s")
print(f"set 快 {t_list/t_set:.0f}x")
N = 10000
t_append = timeit.timeit(
    'lst.append(0)',
    setup='lst = []',
    number=N
)
t_insert = timeit.timeit(
    'lst.insert(0, 0)',
    setup='lst = []',
    number=N
)
print(f"\n=== list.append vs list.insert(0) (n={N}) ===")
print(f"append: {t_append:.4f}s (O(1))")
print(f"insert(0): {t_insert:.4f}s (O(n))")
print(f"insert(0) 慢 {t_insert/t_append:.0f}x")

from collections import deque

t_deque = timeit.timeit(
    'dq.appendleft(0)',
    setup='from collections import deque; dq = deque()',
    number=N
)
print(f"\n=== deque.appendleft vs list.insert(0) (n={N}) ===")
print(f"deque.appendleft: {t_deque:.4f}s (O(1))")
print(f"list.insert(0): {t_insert:.4f}s (O(n))")
print(f"deque 快 {t_insert/t_deque:.0f}x")
5.3 str.join() vs + 拼接
import timeit
N = 10000
t_join = timeit.timeit(
"''.join(str(i) for i in range(N))",
setup=f'N = {N}',
number=100
)
t_plus = timeit.timeit(
"s = ''; [s := s + str(i) for i in range(N)]",
setup=f'N = {N}',
number=100
)
print(f"=== str.join vs + 拼接 (n={N}) ===")
print(f"join: {t_join:.4f}s")
print(f"+: {t_plus:.4f}s")
print(f"join 快 {t_plus/t_join:.0f}x")
for n in [100, 1000, 5000, 10000]:
t_j = timeit.timeit(
"''.join(str(i) for i in range(N))",
setup=f'N = {n}', number=100
)
t_p = timeit.timeit(
"s = ''; [s := s + str(i) for i in range(N)]",
setup=f'N = {n}', number=100
)
print(f"n={n:5d}: join={t_j:.4f}s +={t_p:.4f}s +/join={t_p/t_j:.1f}x")
5.4 dict.get() vs try/except
import timeit
N = 100000
setup = f"""
d = {{i: i*2 for i in range({N})}}
key = {N - 1} # 一定存在
"""
t_get = timeit.timeit('d.get(key)', setup=setup, number=1_000_000)
t_try = timeit.timeit(
'try:\n d[key]\nexcept KeyError:\n pass',
setup=setup, number=1_000_000
)
print(f"=== dict.get vs try/except (key 存在) ===")
print(f"get(): {t_get:.4f}s")
print(f"try/except: {t_try:.4f}s")
print(f"get() 快 {t_try/t_get:.2f}x")
setup_miss = f"""
d = {{i: i*2 for i in range({N})}}
key = -1 # 一定不存在
default = 0
"""
t_get_miss = timeit.timeit('d.get(key, default)', setup=setup_miss, number=1_000_000)
t_try_miss = timeit.timeit(
'try:\n v = d[key]\nexcept KeyError:\n v = default',
setup=setup_miss, number=1_000_000
)
print(f"\n=== dict.get vs try/except (key 不存在) ===")
print(f"get(key, default): {t_get_miss:.4f}s")
print(f"try/except: {t_try_miss:.4f}s")
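"miss 后写默认值"这类场景(如计数)还有第三种选择: collections.defaultdict 把默认值构造放到 C 层,免去每次调用 get()/捕获异常。一个最小示例:

```python
from collections import defaultdict

words = ["a", "b", "a", "c", "a", "b"]

# defaultdict: 键不存在时自动调用工厂函数(这里是 int → 0)
counts = defaultdict(int)
for w in words:
    counts[w] += 1

print(dict(counts))  # {'a': 3, 'b': 2, 'c': 1}
```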
5.5 locals() vs globals() 查找速度
import timeit
g = 42

t_local = timeit.timeit(
    'x',
    setup='x = 42',      # setup 里的赋值与语句同在一个函数内 → LOAD_FAST
    number=10_000_000
)
t_global = timeit.timeit(
    'g',
    globals=globals(),   # 语句里未赋值的模块级名字 → LOAD_GLOBAL
    number=10_000_000
)
print(f"=== locals vs globals 查找 ===")
print(f"局部变量 (LOAD_FAST): {t_local:.4f}s")
print(f"全局变量 (LOAD_GLOBAL): {t_global:.4f}s")
print(f"局部变量快 {t_global/t_local:.1f}x")
import math

def slow_version(n):
    total = 0
    for i in range(n):
        total += math.sqrt(i)  # 每次循环: 查全局 math,再查属性 sqrt
    return total

def fast_version(n):
    sqrt = math.sqrt  # 提前缓存为局部变量
    total = 0
    for i in range(n):
        total += sqrt(i)
    return total
N = 100000
t_slow = timeit.timeit(lambda: slow_version(N), number=10)
t_fast = timeit.timeit(lambda: fast_version(N), number=10)
print(f"\n=== 全局函数 vs 局部缓存 ===")
print(f"全局查找: {t_slow:.4f}s")
print(f"局部缓存: {t_fast:.4f}s")
print(f"局部缓存快 {t_slow/t_fast:.2f}x")
核心差异
| Python 操作 | 复杂度 | Java 对应 | 复杂度 |
|---|---|---|---|
| list.append | O(1) 均摊 | ArrayList.add | O(1) 均摊 |
| list.insert(0, x) | O(n) | ArrayList.add(0, x) | O(n) |
| list[i] | O(1) | ArrayList.get | O(1) |
| list.pop() | O(1) | ArrayList.remove(last) | O(1) |
| list.pop(0) | O(n) | ArrayList.remove(0) | O(n) |
| dict[key] | O(1) 均摊 | HashMap.get | O(1) 均摊 |
| dict 插入顺序 | O(1) 3.7+ | LinkedHashMap | O(1) |
| set.add | O(1) 均摊 | HashSet.add | O(1) 均摊 |
| x in list | O(n) | ArrayList.contains | O(n) |
| x in set | O(1) | HashSet.contains | O(1) |
| str.join | O(n) | String.join | O(n) |
| str + str | O(n) | String + String | O(n) |
常见陷阱
names = ["Alice", "Bob", "Charlie"]
if "David" in names:
pass
name_set = {"Alice", "Bob", "Charlie"}
if "David" in name_set:
pass
items = []
for i in range(1000):
items.insert(0, i)
items = []
for i in range(1000):
items.append(i)
items.reverse()
from collections import deque
items = deque()
for i in range(1000):
items.appendleft(i)
result = ""
for s in strings:
result += s
result = "".join(strings)
何时使用
- 成员检查: set > dict > list
- 频繁头插: deque > list
- 字符串拼接: join > + (大量拼接时)
- 缓存全局函数为局部变量: 在热循环中
- EAFP vs LBYL: key 大概率存在用 try/except,否则用 get()
11.6 __slots__: 内存与速度
Java/Kotlin 对比
data class Point(val x: Int, val y: Int)
Python 实现
import sys
import tracemalloc
import timeit
class Point:
    """普通 Python 对象: 每个实例有一个 __dict__"""
    def __init__(self, x, y):
        self.x = x
        self.y = y

class SlotPoint:
    """使用 __slots__: 禁止 __dict__,固定属性"""
    __slots__ = ('x', 'y')

    def __init__(self, x, y):
        self.x = x
        self.y = y
p1 = Point(1, 2)
p2 = SlotPoint(1, 2)
print(f"=== 单个对象大小 ===")
print(f"普通对象: {sys.getsizeof(p1)} bytes")
print(f"slots对象: {sys.getsizeof(p2)} bytes")
N = 100000
tracemalloc.start()
points1 = [Point(i, i) for i in range(N)]
current1, peak1 = tracemalloc.get_traced_memory()
tracemalloc.stop()
tracemalloc.start()
points2 = [SlotPoint(i, i) for i in range(N)]
current2, peak2 = tracemalloc.get_traced_memory()
tracemalloc.stop()
print(f"\n=== {N} 个对象的内存占用 ===")
print(f"普通对象: {peak1 / 1024 / 1024:.2f} MB")
print(f"slots对象: {peak2 / 1024 / 1024:.2f} MB")
print(f"节省: {(1 - peak2/peak1) * 100:.1f}%")
setup_normal = """
from __main__ import Point
p = Point(3, 4)
"""
setup_slots = """
from __main__ import SlotPoint
p = SlotPoint(3, 4)
"""
t_normal = timeit.timeit('p.x', setup=setup_normal, number=10_000_000)
t_slots = timeit.timeit('p.x', setup=setup_slots, number=10_000_000)
print(f"\n=== 属性访问速度 ===")
print(f"普通对象 (__dict__查找): {t_normal:.4f}s")
print(f"slots对象 (描述符访问): {t_slots:.4f}s")
print(f"slots 快 {t_normal/t_slots:.2f}x")
t_write_normal = timeit.timeit('p.x = 42', setup=setup_normal, number=10_000_000)
t_write_slots = timeit.timeit('p.x = 42', setup=setup_slots, number=10_000_000)
print(f"\n=== 属性写入速度 ===")
print(f"普通对象: {t_write_normal:.4f}s")
print(f"slots对象: {t_write_slots:.4f}s")
print(f"slots 快 {t_write_normal/t_write_slots:.2f}x")
# __slots__ 继承: 每层只声明新增的属性
class Base:
    __slots__ = ('x',)

class Child(Base):
    __slots__ = ('y',)

c = Child()
c.x = 1
c.y = 2

# 陷阱: 子类不声明 __slots__ 就会重新获得 __dict__
class GrandChild(Child):
    pass

gc = GrandChild()
gc.x = 1
gc.y = 2
gc.z = 3  # 动态属性又可以了 — __slots__ 的内存优势丢失

# 正确: 子类继续声明 __slots__(没有新增属性就用空元组)
class GrandChild2(Child):
    __slots__ = ('z',)

gc2 = GrandChild2()
gc2.x = 1
gc2.y = 2
gc2.z = 3
class User:
    __slots__ = ('name', 'age')

u = User()
u.name = "Alice"
u.age = 30
try:
    u.email = "a@b.com"  # 不在 __slots__ 中 → AttributeError
except AttributeError as e:
    print(f"动态属性报错: {e}")

# 折中: 把 '__dict__' 加进 __slots__,恢复动态属性(但失去内存优势)
class FlexibleUser:
    __slots__ = ('name', 'age', '__dict__')

fu = FlexibleUser()
fu.name = "Bob"
fu.email = "b@c.com"
import pickle

class Data:
    __slots__ = ('x', 'y')

    def __init__(self, x, y):
        self.x = x
        self.y = y

d = Data(1, 2)
serialized = pickle.dumps(d)
restored = pickle.loads(serialized)
print(f"\npickle 序列化/反序列化: x={restored.x}, y={restored.y}")
核心差异
| 维度 | 普通对象 | __slots__ 对象 |
|---|---|---|
| __dict__ | 有(动态属性) | 无 |
| 内存 | 大 | 小(节省 40-60%) |
| 属性访问 | 略慢(字典查找) | 略快(描述符) |
| 动态属性 | 支持 | 不支持 |
| 序列化 | pickle 正常 | 需注意 |
| weakref | 自动支持 | 需显式声明 __weakref__ |
常见陷阱
# 陷阱: 忘记在子类声明 __slots__
class Parent:
    __slots__ = ('x',)

class Child(Parent):
    pass  # Child 实例仍有 __dict__,Parent 的 __slots__ 白费了

# 正确:
class Child(Parent):
    __slots__ = ('y',)

# 陷阱: __slots__ 对象默认不支持弱引用
import weakref

class Node:
    __slots__ = ('value',)

n = Node()
n.value = 42
try:
    weakref.ref(n)
except TypeError as e:
    print(f"弱引用报错: {e}")

# 正确: 显式声明 '__weakref__'
class WeakNode:
    __slots__ = ('value', '__weakref__')

wn = WeakNode()
wn.value = 42
wr = weakref.ref(wn)
何时使用
- 创建大量小对象(10万+)时: 必须用 __slots__
- 数据模型/DTO 类: 推荐使用
- 不需要动态属性的场景: 推荐使用
- 需要动态属性: 不要用 __slots__(或在 __slots__ 中加 '__dict__')
- 需要弱引用: 在 __slots__ 中加 '__weakref__'
11.7 缓存策略
Java/Kotlin 对比
import com.github.benmanes.caffeine.cache.Cache;
import com.github.benmanes.caffeine.cache.Caffeine;
import java.util.concurrent.TimeUnit;

Cache<String, String> cache = Caffeine.newBuilder()
    .maximumSize(10_000)
    .expireAfterWrite(10, TimeUnit.MINUTES)
    .expireAfterAccess(5, TimeUnit.MINUTES)
    .recordStats()
    .build();
cache.put("key", "value");
String value = cache.getIfPresent("key");
String value2 = cache.get("key", k -> expensiveCompute(k));
System.out.println(cache.stats());

class LruCache<K, V>(private val maxSize: Int) {
    private val cache = LinkedHashMap<K, V>(maxSize, 0.75f, true)

    fun get(key: K): V? = cache[key]

    fun put(key: K, value: V) {
        cache[key] = value
        if (cache.size > maxSize) {
            cache.remove(cache.keys.first())
        }
    }
}
Python 实现
@lru_cache 的函数式编程视角(偏函数、单分派)详见 5.5 functools,本节聚焦缓存性能优化。
from functools import lru_cache, cache, cached_property
import time
import timeit
@lru_cache(maxsize=128)
def fibonacci(n):
    """经典: 无缓存 O(2^n),有缓存 O(n)"""
    if n < 2:
        return n
    return fibonacci(n - 1) + fibonacci(n - 2)

print(f"fibonacci(100) = {fibonacci(100)}")
print(f"缓存信息: {fibonacci.cache_info()}")
fibonacci.cache_clear()
print(f"清除后: {fibonacci.cache_info()}")

@lru_cache(maxsize=128, typed=True)
def add(a, b):
    """typed=True: 1+2 和 1.0+2.0 视为不同调用"""
    return a + b

add(1, 2)
add(1.0, 2.0)
print(f"typed=True 缓存: {add.cache_info()}")

@cache
def expensive_computation(x):
    """无大小限制的缓存 — 注意内存泄漏风险"""
    time.sleep(0.001)
    return x * x

t1 = timeit.timeit(lambda: expensive_computation(42), number=1000)
t2 = timeit.timeit(lambda: expensive_computation(43), number=1000)
print(f"\n缓存命中 (42): {t1:.4f}s (1000次)")
print(f"缓存未命中 (43): {t2:.4f}s (1000次, 第一次慢)")
class DataProcessor:
    def __init__(self, data):
        self.data = data

    @cached_property
    def processed(self):
        """只在第一次访问时计算,之后返回缓存值"""
        print("  计算中...")
        return [x * 2 for x in self.data]

    @cached_property
    def summary(self):
        """依赖其他 cached_property"""
        print("  计算摘要...")
        return sum(self.processed)

dp = DataProcessor([1, 2, 3])
print("第一次访问 processed:")
print(dp.processed)
print("第二次访问 processed:")
print(dp.processed)
print("访问 summary:")
print(dp.summary)
del dp.processed
print("清除缓存后再次访问:")
print(dp.processed)
def no_cache_fib(n):
    if n < 2:
        return n
    return no_cache_fib(n - 1) + no_cache_fib(n - 2)

t_no_cache = timeit.timeit(lambda: no_cache_fib(30), number=1)
fibonacci.cache_clear()  # 清空之前的缓存,公平对比
t_with_cache = timeit.timeit(lambda: fibonacci(30), number=1)
print(f"\n=== 缓存性能对比 (fib(30)) ===")
print(f"无缓存: {t_no_cache:.4f}s")
print(f"有缓存: {t_with_cache:.6f}s")
print(f"缓存快 {t_no_cache/t_with_cache:.0f}x")
# 第三方 cachetools 提供现成的 TTLCache/LRUCache:
# from cachetools import cached, TTLCache, LRUCache
# 不想引入依赖时,可以自己实现一个最小版本:
import time

class SimpleCache:
    """简易 TTL 缓存"""
    def __init__(self, ttl=60):
        self._cache = {}
        self._ttl = ttl

    def get(self, key, default=None):
        if key in self._cache:
            value, timestamp = self._cache[key]
            if time.time() - timestamp < self._ttl:
                return value
            else:
                del self._cache[key]
        return default

    def set(self, key, value):
        self._cache[key] = (value, time.time())

    def invalidate(self, key):
        self._cache.pop(key, None)

    def clear(self):
        self._cache.clear()

cache = SimpleCache(ttl=5)
cache.set("user:1", {"name": "Alice"})
print(cache.get("user:1"))
time.sleep(6)
print(cache.get("user:1"))  # None — 已过期
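同样的思路也可以包装成装饰器,得到一个"带 TTL 的 lru_cache"的最小草图(ttl_cache 是本文自拟的名字,标准库没有;无大小上限,生产环境建议用 cachetools):

```python
import functools
import time

def ttl_cache(ttl=60):
    """装饰器: 缓存结果,ttl 秒后失效(最小草图,无淘汰上限)"""
    def decorator(func):
        store = {}  # key: 位置参数元组 → (结果, 写入时间)

        @functools.wraps(func)
        def wrapper(*args):
            now = time.monotonic()
            if args in store:
                value, ts = store[args]
                if now - ts < ttl:
                    return value  # 未过期,命中
            value = func(*args)
            store[args] = (value, now)
            return value
        return wrapper
    return decorator

calls = 0

@ttl_cache(ttl=0.2)
def expensive(x):
    global calls
    calls += 1
    return x * x

expensive(3)
expensive(3)      # 命中缓存,不重新计算
time.sleep(0.3)
expensive(3)      # 已过期,重新计算
print(calls)      # 2
```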
核心差异
| 特性 | @lru_cache | @cache | @cached_property | Java Caffeine |
|---|---|---|---|---|
| 大小限制 | 有 | 无 | 无 | 有 |
| 过期策略 | 无 | 无 | 无 | TTL/TTF |
| 线程安全 | 是 | 是 | 是 | 是 |
| 统计信息 | cache_info() | cache_info() | 无 | stats() |
| LRU 淘汰 | 有 | 无 | 无 | 有 |
| 适用场景 | 纯函数 | 纯函数 | 计算属性 | 通用缓存 |
常见陷阱
# 陷阱: 参数必须可哈希 — 传 list 会抛 TypeError
@lru_cache(maxsize=128)
def process(items):
    return sum(items)

try:
    process([1, 2, 3])
except TypeError as e:
    print(f"不可哈希参数: {e}")

# 正确: 传 tuple
print(process((1, 2, 3)))

# 陷阱: 缓存返回可变对象 — 调用方的修改会污染缓存
@lru_cache(maxsize=128)
def get_config():
    return {'debug': True, 'port': 8080}

config = get_config()
config['debug'] = False
config2 = get_config()
print(config2['debug'])  # False! 拿到的是同一个被改过的 dict
何时使用
- 递归函数(如 fibonacci): @lru_cache
- 耗时纯函数: @lru_cache 或 @cache
- 计算属性: @cached_property
- 需要 TTL 过期: 用 cachetools 第三方库
- 需要缓存失效通知: 用自定义缓存类
11.8 字符串优化
Java/Kotlin 对比
String s = "Hello" + " " + "World";
String.format("Hello, %s! You are %d years old.", name, age);
"Hello, " + name + "! You are " + age + " years old.";
val s = "Hello, $name! You are $age years old."
Python 实现
import timeit
N = 10000
t_plus = timeit.timeit(
"s = ''\nfor i in range(N): s += str(i)",
setup=f'N = {N}',
number=100
)
t_join = timeit.timeit(
"''.join(str(i) for i in range(N))",
setup=f'N = {N}',
number=100
)
t_list_join = timeit.timeit(
"parts = [str(i) for i in range(N)]; ''.join(parts)",
setup=f'N = {N}',
number=100
)
t_stringio = timeit.timeit(
    """
buf = io.StringIO()
for i in range(N):
    buf.write(str(i))
s = buf.getvalue()
""",
    setup=f'import io; N = {N}',
    number=100
)
print(f"=== 字符串拼接性能 (n={N}) ===")
print(f"+ 拼接: {t_plus:.4f}s (O(n²))")
print(f"join: {t_join:.4f}s (O(n))")
print(f"列表+join: {t_list_join:.4f}s (O(n))")
print(f"StringIO: {t_stringio:.4f}s (O(n))")
name = "World"
age = 42
t_fstring = timeit.timeit(
"f'Hello, {name}! Age: {age}'",
setup="name='World'; age=42",
number=1_000_000
)
t_percent = timeit.timeit(
"'Hello, %s! Age: %d' % (name, age)",
setup="name='World'; age=42",
number=1_000_000
)
t_format = timeit.timeit(
"'Hello, {}! Age: {}'.format(name, age)",
setup="name='World'; age=42",
number=1_000_000
)
from string import Template
t_template = timeit.timeit(
"Template('Hello, $name! Age: $age').substitute(name=name, age=age)",
setup="from string import Template; name='World'; age=42",
number=1_000_000
)
print(f"\n=== 字符串格式化性能 ===")
print(f"f-string: {t_fstring:.4f}s (最快)")
print(f"% 格式化: {t_percent:.4f}s")
print(f".format(): {t_format:.4f}s")
print(f"Template: {t_template:.4f}s (最慢)")
t_fstring_complex = timeit.timeit(
"f'{name:>10} | {age:05d} | {score:.2f}'",
setup="name='Alice'; age=42; score=95.678",
number=1_000_000
)
t_format_complex = timeit.timeit(
"'{name:>10} | {age:05d} | {score:.2f}'.format(name=name, age=age, score=score)",
setup="name='Alice'; age=42; score=95.678",
number=1_000_000
)
print(f"\n=== 复杂格式化 ===")
print(f"f-string: {t_fstring_complex:.4f}s")
print(f".format(): {t_format_complex:.4f}s")
print(f"f-string 快 {t_format_complex/t_fstring_complex:.2f}x")
import sys

# 标识符形式的短字符串字面量会被自动驻留
a = "hello"
b = "hello"
print(f"短字符串: a is b = {a is b}")

# 运行期构造的长字符串不会自动驻留
# (字面量 "a" * 1000 会被编译期常量折叠,故用 join 在运行期构造)
c = "".join("a" for _ in range(1000))
d = "".join("a" for _ in range(1000))
print(f"长字符串: c is d = {c is d}")

# sys.intern 手动驻留: 内容相同的字符串指向同一对象
e = sys.intern("".join("a" for _ in range(1000)))
f = sys.intern("".join("a" for _ in range(1000)))
print(f"手动驻留: e is f = {e is f}")
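驻留的实际收益在于比较: CPython 的字符串 == 会先做对象身份的快速路径检查,驻留后相同内容的比较不必逐字符扫描。一个示意(耗时差异取决于字符串长度与命中率):

```python
import sys
import timeit

# 构造两个内容相同但对象不同的长字符串
s1 = "".join("x" for _ in range(10_000))
s2 = "".join("x" for _ in range(10_000))
assert s1 == s2 and s1 is not s2

# 驻留后二者是同一对象,== 走身份快速路径
i1 = sys.intern(s1)
i2 = sys.intern(s2)
assert i1 is i2

t_plain = timeit.timeit('s1 == s2', globals=globals(), number=100_000)
t_intern = timeit.timeit('i1 == i2', globals=globals(), number=100_000)
print(f"未驻留: {t_plain:.4f}s")
print(f"已驻留: {t_intern:.4f}s")
```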
import timeit
t_str = timeit.timeit(
"''.join(chr(i % 128) for i in range(N))",
setup='N = 10000',
number=1000
)
t_bytes = timeit.timeit(
"bytes(i % 128 for i in range(N))",
setup='N = 10000',
number=1000
)
print(f"\n=== bytes vs str 性能 ===")
print(f"str: {t_str:.4f}s")
print(f"bytes: {t_bytes:.4f}s")
print(f"bytes 快 {t_str/t_bytes:.2f}x")
text = "a" * 1000000 + "needle" + "a" * 1000000
t_find = timeit.timeit(
"text.find('needle')",
setup="text = 'a' * 1000000 + 'needle' + 'a' * 1000000",
number=1000
)
t_index = timeit.timeit(
"text.index('needle')",
setup="text = 'a' * 1000000 + 'needle' + 'a' * 1000000",
number=1000
)
t_in = timeit.timeit(
"'needle' in text",
setup="text = 'a' * 1000000 + 'needle' + 'a' * 1000000",
number=1000
)
print(f"\n=== 字符串查找 ===")
print(f"find(): {t_find:.4f}s")
print(f"index(): {t_index:.4f}s")
print(f"in: {t_in:.4f}s")
核心差异
| 操作 | Python | Java | 说明 |
|---|---|---|---|
| 少量拼接 | + 或 f-string | + (编译器优化) | 都很快 |
| 大量拼接 | join | StringBuilder | Python 无自动优化 |
| 格式化 | f-string | String.format | f-string 更快 |
| 字符串驻留 | sys.intern | String.intern | 机制类似 |
| 不可变性 | 不可变 | 不可变 | 一致 |
常见陷阱
# 陷阱: 循环 += 拼接大量字符串,O(n²)
large_list = ["chunk"] * 10000
result = ""
for chunk in large_list:
    result += chunk

# 正确: join 一次完成,O(n)
result = "".join(large_list)

# 字符串比较: 用 == 比较内容
if name == "Alice":
    pass
# 陷阱: is 比较的是对象身份,依赖驻留,结果不可靠(还会触发 SyntaxWarning)
if name is "Alice":
    pass

# 纯 ASCII/二进制数据可用 bytes,比 str 更紧凑
data = b"hello world"
text = "hello world"
何时使用
- 少量拼接: + 或 f-string
- 大量拼接: str.join()
- 格式化: f-string(最快且最可读)
- 纯 ASCII 处理: 考虑 bytes
- 字符串比较: ==(不要用 is)
11.9 加速方案对比
Java/Kotlin 对比
public native int compute(int[] data);
Python 实现
9.1 numpy 向量化
import timeit
import random
import numpy as np

def python_sum(n):
    total = 0.0
    for i in range(n):
        total += i * i + i * 0.5
    return total

def numpy_sum(n):
    arr = np.arange(n, dtype=np.float64)
    return np.sum(arr * arr + arr * 0.5)
N = 1_000_000
t_python = timeit.timeit(lambda: python_sum(N), number=10)
t_numpy = timeit.timeit(lambda: numpy_sum(N), number=10)
print(f"=== 纯 Python vs numpy 向量化 (n={N}) ===")
print(f"纯 Python: {t_python:.4f}s")
print(f"numpy: {t_numpy:.4f}s")
print(f"numpy 快 {t_python/t_numpy:.0f}x")
def python_matrix_multiply(n):
    """纯 Python 矩阵乘法 — 极慢"""
    A = [[random.random() for _ in range(n)] for _ in range(n)]
    B = [[random.random() for _ in range(n)] for _ in range(n)]
    C = [[0.0] * n for _ in range(n)]
    for i in range(n):
        for j in range(n):
            for k in range(n):
                C[i][j] += A[i][k] * B[k][j]
    return C

def numpy_matrix_multiply(n):
    """numpy 矩阵乘法 — 极快"""
    A = np.random.rand(n, n)
    B = np.random.rand(n, n)
    return A @ B
t_py_mat = timeit.timeit(lambda: python_matrix_multiply(100), number=1)
t_np_mat = timeit.timeit(lambda: numpy_matrix_multiply(100), number=100)
print(f"\n=== 矩阵乘法 100x100 ===")
print(f"纯 Python (1次): {t_py_mat:.4f}s")
print(f"numpy (100次): {t_np_mat:.4f}s")
print(f"numpy 每次: {t_np_mat/100:.6f}s")
print(f"numpy 快 {t_py_mat/(t_np_mat/100):.0f}x")
def python_normalize(data):
    """纯 Python: 归一化"""
    mean = sum(data) / len(data)
    std = (sum((x - mean) ** 2 for x in data) / len(data)) ** 0.5
    return [(x - mean) / std for x in data]

def numpy_normalize(data):
    """numpy: 归一化"""
    arr = np.array(data)
    return (arr - arr.mean()) / arr.std()
data = [random.random() for _ in range(100000)]
t_py_norm = timeit.timeit(lambda: python_normalize(data), number=10)
t_np_norm = timeit.timeit(lambda: numpy_normalize(data), number=100)
print(f"\n=== 归一化 (n=100000) ===")
print(f"纯 Python (10次): {t_py_norm:.4f}s")
print(f"numpy (100次): {t_np_norm:.4f}s")
print(f"numpy 快 {t_py_norm/10/(t_np_norm/100):.0f}x")
9.2 Numba: JIT 编译数值计算
import random
import timeit
from numba import jit

def monte_carlo_pi_python(n):
    """蒙特卡洛方法计算 pi"""
    count = 0
    for i in range(n):
        x = random.random()
        y = random.random()
        if x * x + y * y <= 1.0:
            count += 1
    return 4.0 * count / n

@jit(nopython=True)
def monte_carlo_pi_numba(n):
    """蒙特卡洛方法计算 pi — Numba JIT 编译"""
    count = 0
    for i in range(n):
        x = random.random()
        y = random.random()
        if x * x + y * y <= 1.0:
            count += 1
    return 4.0 * count / n

N = 1_000_000
_ = monte_carlo_pi_numba(1000)  # 预热: 首次调用触发编译,不计入测量
t_python = timeit.timeit(lambda: monte_carlo_pi_python(N), number=10)
t_numba = timeit.timeit(lambda: monte_carlo_pi_numba(N), number=10)
print(f"=== 蒙特卡洛 pi (n={N}) ===")
print(f"纯 Python: {t_python:.4f}s")
print(f"Numba: {t_numba:.4f}s")
print(f"Numba 快 {t_python/t_numba:.0f}x")
from numba import prange

@jit(nopython=True, parallel=True)
def parallel_sum(arr):
    """并行求和"""
    total = 0
    for i in prange(len(arr)):
        total += arr[i]
    return total
arr = np.random.rand(10_000_000)
t_serial = timeit.timeit(lambda: np.sum(arr), number=10)
t_parallel = timeit.timeit(lambda: parallel_sum(arr), number=10)
print(f"\n=== 并行求和 (n=10M) ===")
print(f"numpy 串行: {t_serial:.4f}s")
print(f"Numba 并行: {t_parallel:.4f}s")
9.3 Cython: Python → C 编译
9.4 C 扩展: ctypes/cffi
import ctypes
import timeit
libc = ctypes.CDLL(None)
abs_func = libc.abs
abs_func.argtypes = [ctypes.c_int]
abs_func.restype = ctypes.c_int
print(f"abs(-42) = {abs_func(-42)}")
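浮点参数必须显式声明 argtypes/restype,否则 ctypes 默认按 int 传参,得到垃圾值。下面以 libm 的 sqrt 为例(POSIX 假设: CDLL(None) 能看到进程已加载的 libc/libm 符号;Windows 下需用 ctypes.util.find_library 定位 DLL):

```python
import ctypes

# POSIX: 打开进程自身的符号表
libm = ctypes.CDLL(None)
sqrt = libm.sqrt
sqrt.argtypes = [ctypes.c_double]  # 不声明会按 int 传参 → 结果错误
sqrt.restype = ctypes.c_double     # 不声明会按 int 解释返回值

print(sqrt(9.0))  # 3.0
```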
9.5 PyPy: 替代解释器
加速方案选择决策树
需要加速 Python 代码?
├── 能用内置函数/库?
│ └── 用 sum(), max(), sorted() 等 C 实现 → 通常快 10-100x
│
├── 数值计算/科学计算?
│ └── numpy/pandas 向量化 → 快 50-500x
│
├── 循环密集型数值计算?
│ ├── 简单: Numba @jit → 快 50-200x,改动最小
│ └── 复杂: Cython → 快 50-150x,需要编译
│
├── 需要调用 C/C++ 库?
│ ├── 简单调用: ctypes (标准库)
│ ├── 嵌入 C 代码: cffi
│ └── 深度集成: Cython
│
├── 纯 Python 逻辑密集?
│ └── PyPy 替代解释器 → 快 3-10x
│
├── CPU 密集型?
│ └── multiprocessing 多进程 → 利用多核
│
└── 不确定?
└── 先 profile → 找到瓶颈 → 针对性优化
核心差异
| 方案 | 提速 | 改动成本 | 适用场景 | vs Java/Kotlin |
|---|---|---|---|---|
| 内置函数 | 10-100x | 极低 | 通用 | 类似 JIT 内联 |
| numpy | 50-500x | 低 | 数值计算 | 类似 Stream API |
| Numba | 50-200x | 低 | 数值循环 | 类似 JIT |
| Cython | 50-150x | 中 | 性能关键 | 类似 JNI |
| ctypes/cffi | 取决于 C 代码 | 中 | 调用 C 库 | 类似 JNI |
| PyPy | 3-10x | 极低 | 纯 Python | 类似 JIT |
| multiprocessing | N 核 | 低 | CPU 密集 | 类似 parallel stream |
常见陷阱
# 陷阱: 小数据上 numpy 的数组构造与跨界调用开销可能盖过向量化收益
import numpy as np
import timeit

t_py = timeit.timeit('sum(x*x for x in data)', setup='data=list(range(100))', number=10000)
t_np = timeit.timeit('np.sum(np.array(data)**2)', setup='import numpy as np; data=list(range(100))', number=10000)
print(f"小数据: Python={t_py:.4f}s, numpy={t_np:.4f}s")
何时使用
- 优先用内置函数和标准库(C 实现)
- 数值计算: numpy/pandas
- 循环密集: Numba 或 Cython
- 调用 C 库: ctypes/cffi
- 纯 Python 加速: PyPy
- 多核并行: multiprocessing
11.10 内存管理: 引用计数 + 分代GC
Java/Kotlin 对比
import java.lang.ref.WeakReference;
Object obj = new Object();
WeakReference<Object> ref = new WeakReference<>(obj);
obj = null;
System.gc();
Object recovered = ref.get();
Python 实现
10.1 引用计数机制
import sys
import gc
a = [1, 2, 3]
# 注意: getrefcount 的参数本身会临时增加一个引用,读数比"真实"引用数多 1
print(f"创建后: refcount = {sys.getrefcount(a)}")
b = a
print(f"赋值后: refcount = {sys.getrefcount(a)}")
c = [a, a, a]
print(f"列表引用: refcount = {sys.getrefcount(a)}")
del b
print(f"del b后: refcount = {sys.getrefcount(a)}")
del c
print(f"del c后: refcount = {sys.getrefcount(a)}")
10.2 循环引用问题
import gc
import weakref
class Node:
    def __init__(self, value):
        self.value = value
        self.next = None

    def __del__(self):
        print(f"  Node {self.value} 被回收")

print("=== 循环引用演示 ===")
a = Node(1)
b = Node(2)
a.next = b
b.next = a  # a ↔ b 循环引用,引用计数无法归零

print("删除外部引用...")
del a
del b
print("手动触发 GC...")
collected = gc.collect()  # 分代 GC 能打破循环
print(f"GC 回收了 {collected} 个对象")
print(f"\n=== GC 配置 ===")
print(f"GC 阈值: {gc.get_threshold()}")
print(f"GC 计数: {gc.get_count()}")
print(f"GC 是否启用: {gc.isenabled()}")
gc.set_debug(gc.DEBUG_STATS)
print(f"\n被跟踪的对象数: {len(gc.get_objects())}")
refs = [obj for obj in gc.get_objects() if isinstance(obj, list)]
print(f"list 对象数: {len(refs)}")
gc.set_debug(0)
10.3 弱引用
import weakref
import gc
class MyClass:
    def __init__(self, name):
        self.name = name

obj = MyClass("test")
ref = weakref.ref(obj)
print(f"弱引用存在: {ref() is not None}")
print(f"对象名: {ref().name}")
del obj
print(f"删除后: {ref() is None}")

class Parent:
    def __init__(self, name):
        self.name = name
        self._children = []

    def add_child(self, child):
        self._children.append(weakref.ref(child))

    def get_children(self):
        return [ref() for ref in self._children if ref() is not None]

class Child:
    def __init__(self, name, parent):
        self.name = name
        self.parent = weakref.ref(parent)
parent = Parent("root")
child1 = Child("a", parent)
child2 = Child("b", parent)
parent.add_child(child1)
parent.add_child(child2)
print(f"\n子节点: {[c.name for c in parent.get_children()]}")
del child1, child2
print(f"删除后子节点: {parent.get_children()}")
cache = weakref.WeakValueDictionary()
class Data:
    def __init__(self, key):
        self.key = key
d1 = Data("key1")
cache["key1"] = d1
print(f"\n缓存大小: {len(cache)}")
del d1
print(f"删除后缓存大小: {len(cache)}")
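除了 weakref.ref,weakref.finalize 可以注册"对象被回收时"的清理回调,常用来替代不可靠的 __del__。一个最小示意(Resource、events 为本文自拟的名字;回调不能引用对象本身,否则对象永远不会被回收):

```python
import weakref

class Resource:
    def __init__(self, name):
        self.name = name

events = []
r = Resource("conn")
# 对象被回收时调用 events.append("conn closed")
fin = weakref.finalize(r, events.append, "conn closed")
print(fin.alive)   # True

del r              # CPython 引用计数归零 → 立即触发 finalize
print(events)      # ['conn closed']
print(fin.alive)   # False
```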
10.4 __del__ 的正确用法和陷阱
import gc
class Resource:
    def __init__(self, name):
        self.name = name
        print(f"  {name}: created")

    def __del__(self):
        print(f"  {self.name}: destroyed")

print("=== __del__ basics ===")
r = Resource("r1")
del r

class A:
    def __init__(self):
        self.b = None
        print("  A: created")

    def __del__(self):
        print("  A: destroyed")

class B:
    def __init__(self):
        self.a = None
        print("  B: created")

    def __del__(self):
        print("  B: destroyed")

print("\n=== __del__ and cyclic references ===")
a = A()
b = B()
a.b = b
b.a = a
del a, b
print("triggering GC...")
# Since Python 3.4 (PEP 442) the cycle collector can finalize objects that
# define __del__; before that, such cycles went to gc.garbage and leaked.
gc.collect()

class SafeResource:
    def __init__(self, name):
        self.name = name
        print(f"  {name}: resource opened")

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        print(f"  {self.name}: resource closed")
        return False

    def do_work(self):
        print(f"  {self.name}: working")

print("\n=== context manager (recommended) ===")
with SafeResource("db_conn") as r:
    r.do_work()
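The same `with`-based cleanup can be written more compactly with `contextlib.contextmanager`, which turns a generator into a context manager. A sketch that records events into a list so the ordering is visible:

```python
from contextlib import contextmanager

events = []

@contextmanager
def safe_resource(name):
    events.append(f"{name}: open")       # runs on entering the with block
    try:
        yield name                       # the value bound by `as`
    finally:
        events.append(f"{name}: close")  # always runs on exit, even on exceptions

with safe_resource("db_conn") as r:
    events.append(f"{r}: working")

print(events)
```

The `try/finally` around the `yield` is what guarantees cleanup; without it an exception in the body would skip the close step.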
10.5 Memory Profiling Tools
import tracemalloc
import sys
import gc
tracemalloc.start(25)  # keep up to 25 frames per allocation traceback
data = [list(range(1000)) for _ in range(1000)]
snapshot = tracemalloc.take_snapshot()
top_stats = snapshot.statistics('lineno')
print("=== top 5 allocations ===")
for stat in top_stats[:5]:
    print(stat)
tracemalloc.stop()

tracemalloc.start(25)
snapshot1 = tracemalloc.take_snapshot()
leaky_list = []
for i in range(10000):
    leaky_list.append({'data': list(range(100))})
snapshot2 = tracemalloc.take_snapshot()
top_diffs = snapshot2.compare_to(snapshot1, 'lineno')
print("\n=== top 5 memory growth ===")
for stat in top_diffs[:5]:
    print(stat)
tracemalloc.stop()

print("\n=== live object counts ===")
print(f"list objects: {sum(1 for obj in gc.get_objects() if isinstance(obj, list))}")
print(f"dict objects: {sum(1 for obj in gc.get_objects() if isinstance(obj, dict))}")
print(f"tuple objects: {sum(1 for obj in gc.get_objects() if isinstance(obj, tuple))}")

print("\n=== object sizes ===")
print(f"empty list: {sys.getsizeof([])} bytes")
print(f"empty dict: {sys.getsizeof({})} bytes")
print(f"empty set: {sys.getsizeof(set())} bytes")
print(f"empty tuple: {sys.getsizeof(())} bytes")
print(f"empty str: {sys.getsizeof('')} bytes")
print(f"int 0: {sys.getsizeof(0)} bytes")
print(f"int 2**30: {sys.getsizeof(2**30)} bytes")
print(f"float 1.0: {sys.getsizeof(1.0)} bytes")
print(f"bool True: {sys.getsizeof(True)} bytes")
print(f"None: {sys.getsizeof(None)} bytes")
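`sys.getsizeof` also makes the list-versus-generator tradeoff concrete: a generator's footprint is constant no matter how many items it will produce. A quick sketch:

```python
import sys

# A list materializes all million results up front; the generator holds only
# a suspended frame and produces values on demand.
as_list = [i * i for i in range(1_000_000)]
as_gen = (i * i for i in range(1_000_000))

print(f"list:      {sys.getsizeof(as_list):,} bytes")   # megabytes of pointers alone
print(f"generator: {sys.getsizeof(as_gen)} bytes")      # a couple hundred bytes
```

Note that `getsizeof` on the list counts only the pointer array, not the int objects it references; the true gap is even larger.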
Core differences
| Dimension | JVM GC | Python GC |
|---|---|---|
| Primary mechanism | generational GC | reference counting |
| Auxiliary mechanism | none | generational GC (for cycles) |
| Reclamation timing | GC cycles (non-deterministic) | immediate, when refcount hits zero |
| STW pauses | yes (during GC) | yes (during cycle detection) |
| Manual control | System.gc() (a hint) | gc.collect() |
| Cyclic references | handled automatically | need the auxiliary GC |
| Weak references | WeakReference | weakref.ref |
| Soft references | SoftReference | none (use cachetools) |
| Resource cleanup | try-with-resources / AutoCloseable | with / __enter__ + __exit__ |
Common pitfalls

# Pitfall: relying on __del__ for resource cleanup. If the object lands in a
# reference cycle or is still alive at interpreter shutdown, __del__ may run
# late, or with parts of the runtime already torn down.
class Danger:
    def __init__(self):
        self.file = open('/tmp/test', 'w')

    def __del__(self):
        self.file.close()  # not guaranteed to run promptly

# Better: a context manager makes the close deterministic.
class Safe:
    def __enter__(self):
        self.file = open('/tmp/test', 'w')
        return self

    def __exit__(self, *args):
        self.file.close()

# Pitfall: a closure keeps everything it captures alive. The returned function
# holds big_data (a million ints) for as long as f exists.
def create_processor():
    big_data = list(range(1_000_000))

    def process(x):
        return x + big_data[0]

    return process

f = create_processor()  # big_data cannot be freed until f is dropped
When to use what
- Understanding reference counting is the foundation for reasoning about Python memory behavior
- Cyclic references: use weakref, or make sure gc.collect() gets a chance to run
- Many short-lived objects: consider an object pool or __slots__
- Resource management: use with statements; do not rely on __del__
- Hunting memory leaks: tracemalloc + gc.get_objects()
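The `__slots__` advice deserves numbers. Declaring `__slots__` removes the per-instance `__dict__`, which is where most of a small object's memory goes. A sketch comparing the two layouts (class names are illustrative):

```python
import sys

class PointDict:
    def __init__(self, x, y):
        self.x = x
        self.y = y

class PointSlots:
    __slots__ = ("x", "y")      # fixed attribute layout, no per-instance __dict__
    def __init__(self, x, y):
        self.x = x
        self.y = y

pd = PointDict(1.0, 2.0)
ps = PointSlots(1.0, 2.0)
dict_total = sys.getsizeof(pd) + sys.getsizeof(pd.__dict__)
print(f"dict-based instance: {dict_total} bytes (object + its __dict__)")
print(f"slotted instance:    {sys.getsizeof(ps)} bytes")
```

The tradeoff: slotted instances cannot grow new attributes at runtime, and multiple inheritance with `__slots__` needs care.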
11.11 Choosing a Concurrency Model
Java/Kotlin comparison
ExecutorService executor = Executors.newFixedThreadPool(4);
List<Future<Integer>> futures = new ArrayList<>();
for (int i = 0; i < 10; i++) {
    futures.add(executor.submit(() -> heavyComputation()));
}
int result = 0;
for (Future<Integer> f : futures) {
    result += f.get();
}

import kotlinx.coroutines.*

suspend fun cpuBound() = withContext(Dispatchers.Default) {
    heavyComputation()
}

suspend fun ioBound() = withContext(Dispatchers.IO) {
    httpClient.get(url)
}

// must be called from a suspend context
val results = coroutineScope {
    listOf(async { task1() }, async { task2() }).awaitAll()
}
Python implementation
11.11.1 CPU-bound: multiprocessing vs threading vs asyncio
import time
import timeit
import multiprocessing
import threading
import asyncio

def cpu_task(n):
    """CPU-bound: pure computation"""
    total = 0
    for i in range(n):
        total += i * i
    return total

N = 5_000_000

def serial():
    results = []
    for _ in range(4):
        results.append(cpu_task(N))
    return results

def threading_test():
    results = [None] * 4
    threads = []
    for i in range(4):
        t = threading.Thread(target=lambda idx=i: results.__setitem__(idx, cpu_task(N)))
        threads.append(t)
        t.start()
    for t in threads:
        t.join()
    return results

def multiprocessing_test():
    # on spawn-based platforms (Windows, macOS) this must run under
    # `if __name__ == "__main__":`
    with multiprocessing.Pool(4) as pool:
        results = pool.map(cpu_task, [N] * 4)
    return results

async def async_cpu_task(n):
    """asyncio is the wrong tool for CPU-bound work (it blocks the event loop)"""
    return cpu_task(n)

async def asyncio_test():
    tasks = [async_cpu_task(N) for _ in range(4)]
    return await asyncio.gather(*tasks)

print("=== CPU-bound benchmark ===")
print("serial:")
t = timeit.timeit(serial, number=3)
print(f"  elapsed: {t:.4f}s")
print("threading (4 threads):")
t = timeit.timeit(threading_test, number=3)
print(f"  elapsed: {t:.4f}s")
print("multiprocessing (4 processes):")
t = timeit.timeit(multiprocessing_test, number=3)
print(f"  elapsed: {t:.4f}s")
print("asyncio (4 coroutines):")
t = timeit.timeit(lambda: asyncio.run(asyncio_test()), number=3)
print(f"  elapsed: {t:.4f}s")
11.11.2 I/O-bound: threading vs asyncio
import time
import threading
import asyncio
import urllib.request

def io_task(url):
    """real I/O (an HTTP request)"""
    try:
        resp = urllib.request.urlopen(url, timeout=5)
        return len(resp.read())
    except Exception:
        return 0

def io_task_simulated(delay):
    """simulated I/O latency"""
    time.sleep(delay)
    return delay

async def async_io_task_simulated(delay):
    """asyncio version"""
    await asyncio.sleep(delay)
    return delay

URLS = ['https://httpbin.org/delay/0.1'] * 8

def serial_io():
    results = []
    for url in URLS:
        results.append(io_task(url))
    return results

def threading_io():
    results = [None] * len(URLS)
    threads = []
    for i, url in enumerate(URLS):
        t = threading.Thread(target=lambda u=url, idx=i: results.__setitem__(idx, io_task(u)))
        threads.append(t)
        t.start()
    for t in threads:
        t.join()
    return results

async def asyncio_io():
    # blocking urllib calls go to the default thread pool; get_running_loop()
    # replaces the deprecated get_event_loop() inside coroutines
    loop = asyncio.get_running_loop()
    tasks = [loop.run_in_executor(None, io_task, url) for url in URLS]
    return await asyncio.gather(*tasks)

N_IO = 8
DELAY = 0.1

print("\n=== I/O-bound benchmark ===")
start = time.perf_counter()
for _ in range(N_IO):
    io_task_simulated(DELAY)
t_serial = time.perf_counter() - start
print(f"serial: {t_serial:.4f}s")

start = time.perf_counter()
threads = [threading.Thread(target=io_task_simulated, args=(DELAY,)) for _ in range(N_IO)]
for t in threads:
    t.start()
for t in threads:
    t.join()
t_threading = time.perf_counter() - start
print(f"threading: {t_threading:.4f}s")

start = time.perf_counter()

async def run_all():
    # gather needs a running event loop, so wrap it in a coroutine for asyncio.run
    return await asyncio.gather(*[async_io_task_simulated(DELAY) for _ in range(N_IO)])

asyncio.run(run_all())
t_asyncio = time.perf_counter() - start
print(f"asyncio: {t_asyncio:.4f}s")
11.11.3 concurrent.futures
import time
import concurrent.futures

def cpu_work(n):
    total = 0
    for i in range(n):
        total += i * i
    return total

def io_work(delay):
    time.sleep(delay)
    return delay

def process_pool_demo():
    with concurrent.futures.ProcessPoolExecutor(max_workers=4) as executor:
        futures = [executor.submit(cpu_work, 5_000_000) for _ in range(8)]
        results = [f.result() for f in concurrent.futures.as_completed(futures)]
    return results

def thread_pool_demo():
    with concurrent.futures.ThreadPoolExecutor(max_workers=8) as executor:
        futures = [executor.submit(io_work, 0.1) for _ in range(8)]
        results = [f.result() for f in concurrent.futures.as_completed(futures)]
    return results

def map_demo():
    with concurrent.futures.ProcessPoolExecutor(max_workers=4) as executor:
        results = list(executor.map(cpu_work, [5_000_000] * 8))
    return results

print("=== concurrent.futures benchmark ===")
start = time.perf_counter()
process_pool_demo()
t_proc = time.perf_counter() - start
print(f"ProcessPoolExecutor (8 CPU tasks, 4 processes): {t_proc:.4f}s")

start = time.perf_counter()
thread_pool_demo()
t_thread = time.perf_counter() - start
print(f"ThreadPoolExecutor (8 I/O tasks, 8 threads): {t_thread:.4f}s")

start = time.perf_counter()
map_demo()
t_map = time.perf_counter() - start
print(f"ProcessPoolExecutor.map (8 CPU tasks, 4 processes): {t_map:.4f}s")
Decision tree for picking a concurrency model

Need concurrency?
├── CPU-bound (computation)
│   └── multiprocessing / ProcessPoolExecutor
│       - bypasses the GIL: true parallelism
│       - inter-process communication has overhead (pickle serialization)
│       - good for: data processing, numerics, machine learning
│
├── I/O-bound (network / files / database)
│   ├── simple cases → threading / ThreadPoolExecutor
│   │   - straightforward, no code restructuring needed
│   │   - good for: a modest number of concurrent I/O calls
│   │
│   ├── high concurrency → asyncio
│   │   - single-threaded, no locking issues
│   │   - good for: web servers, API clients, WebSocket
│   │   - requires: async/await through the whole call chain
│   │
│   └── mixed → asyncio + run_in_executor
│       - CPU parts in a thread/process pool
│       - I/O parts in asyncio
│
├── Need shared state?
│   ├── a little → threading + Lock/Queue
│   └── a lot → multiprocessing + SharedMemory/Queue
│
└── Not sure?
    └── start with concurrent.futures (uniform interface, easy to switch)
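For the "a little shared state" branch, the standard pattern is a `queue.Queue` to hand out work plus a `Lock` around the shared result. A minimal sketch (worker count and task values are arbitrary):

```python
import threading
import queue

tasks = queue.Queue()
total_lock = threading.Lock()
total = 0

def worker():
    global total
    while True:
        item = tasks.get()
        if item is None:           # sentinel: this worker is done
            return
        with total_lock:           # serialize updates to the shared counter
            total += item

workers = [threading.Thread(target=worker) for _ in range(4)]
for w in workers:
    w.start()
for i in range(100):
    tasks.put(i)
for _ in workers:
    tasks.put(None)                # one sentinel per worker
for w in workers:
    w.join()
print(total)   # sum(range(100)) == 4950
```

`Queue` is itself thread-safe, so only the shared counter needs the explicit lock; without it, `total += item` is a read-modify-write race.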
Core differences
| Dimension | Java/Kotlin | Python |
|---|---|---|
| CPU parallelism | threads (truly parallel) | multiprocessing (bypasses the GIL) |
| I/O concurrency | threads / coroutines | threading / asyncio |
| Thread cost | light (OS threads) | light (OS threads, but GIL-serialized for CPU work) |
| Coroutines | Kotlin coroutines | asyncio |
| GIL | none | yes (CPython) |
| Shared memory | natural | multiprocessing requires serialization |
| Thread pool | ExecutorService | ThreadPoolExecutor |
| Process pool | none built in (manual via ProcessBuilder) | ProcessPoolExecutor |
Common pitfalls

# Pitfall: CPU-bound work inside a coroutine blocks the entire event loop
async def bad():
    result = sum(i * i for i in range(10_000_000))  # nothing else runs meanwhile
    return result

# Fix: push the computation into an executor and await it
async def good():
    loop = asyncio.get_running_loop()  # get_event_loop() is deprecated inside coroutines
    # a thread keeps the loop responsive, but the GIL means no speedup;
    # pass a ProcessPoolExecutor instead of None for true parallelism
    result = await loop.run_in_executor(None, lambda: sum(i * i for i in range(10_000_000)))
    return result
When to use what
- CPU-bound: multiprocessing / ProcessPoolExecutor
- I/O-bound (simple): threading / ThreadPoolExecutor
- I/O-bound (high concurrency): asyncio
- Mixed: asyncio + run_in_executor
- Quick prototyping: concurrent.futures (uniform interface)
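For the mixed case, Python 3.9+ also offers `asyncio.to_thread` as a shorthand for `loop.run_in_executor(None, ...)`. A sketch wrapping a blocking call (`blocking_read` and the file names are hypothetical stand-ins for a library without async support):

```python
import asyncio
import time

def blocking_read(path):
    # stands in for a blocking library call (e.g. a driver with no async API)
    time.sleep(0.05)
    return f"contents of {path}"

async def main():
    # to_thread runs each blocking call in the default thread pool, so the
    # event loop stays responsive while both reads overlap
    return await asyncio.gather(
        asyncio.to_thread(blocking_read, "a.txt"),
        asyncio.to_thread(blocking_read, "b.txt"),
    )

res = asyncio.run(main())
print(res)
```

Because the work happens in threads, this helps only with blocking I/O; CPU-bound work still wants a process pool.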
Chapter summary: a performance-optimization decision tree

Need to optimize?
│
├── Step 1: profile first! Never guess at the bottleneck
│   ├── micro-benchmarks → timeit
│   ├── function-level profiling → cProfile
│   ├── line-level profiling → line_profiler
│   ├── production → py-spy
│   └── memory profiling → tracemalloc / memray
│
├── Algorithm / data-structure problem?
│   ├── O(n) → O(1): set instead of list lookup
│   ├── O(n²) → O(n log n): sort + binary search
│   ├── O(n²) → O(n): join instead of + concatenation
│   └── deque instead of list.insert(0)
│
├── I/O bottleneck?
│   ├── modest concurrency → threading
│   ├── high concurrency → asyncio
│   └── mixed → asyncio + run_in_executor
│
├── CPU bottleneck?
│   ├── can you use builtins? → sum(), max(), sorted() (implemented in C)
│   ├── can you cache? → @lru_cache / @cache
│   ├── numeric work? → numpy vectorization (50-500x)
│   ├── tight loops? → Numba @jit (50-200x)
│   ├── willing to compile? → Cython (50-150x)
│   ├── multiple cores available? → multiprocessing
│   └── pure Python? → upgrade to 3.11+ (a free 25-60%)
│
├── Memory bottleneck?
│   ├── many small objects? → __slots__ (saves 40-60%)
│   ├── large datasets? → generators (yield)
│   ├── cyclic references? → weakref
│   ├── string concatenation? → join
│   └── cache bloat? → lru_cache(maxsize=...)
│
├── Function-call overhead?
│   ├── avoid tiny function calls in hot loops
│   ├── hoist attribute lookups into locals
│   ├── cache global functions in locals
│   └── inspect bytecode with the dis module
│
└── Still not fast enough?
    ├── PyPy (3-10x for pure Python)
    ├── C extensions (ctypes/cffi)
    ├── rewrite in C/C++/Rust
    └── switch languages (usually unnecessary)
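The "set instead of list lookup" branch is easy to measure with timeit. A sketch (sizes and the needle chosen arbitrarily, worst case for the list scan):

```python
import timeit

items = list(range(100_000))
as_set = set(items)
needle = 99_999                    # last element: a full linear scan for the list

t_list = timeit.timeit(lambda: needle in items, number=1_000)   # O(n) per lookup
t_set = timeit.timeit(lambda: needle in as_set, number=1_000)   # O(1) hash lookup

print(f"list membership: {t_list:.4f}s")
print(f"set membership:  {t_set:.4f}s")
```

The one-time cost of building the set pays off after a handful of lookups; for a single lookup, the list scan is fine.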
Optimization priorities
1. Pick the right algorithm and data structure → potential: 100-10000x
2. Use builtins and the standard library → potential: 10-100x
3. numpy/pandas vectorization → potential: 50-500x
4. Numba/Cython compilation → potential: 50-200x
5. Upgrade to Python 3.11+ → potential: 1.25-1.6x
6. multiprocessing across cores → potential: ~N cores
7. Caching → potential: depends on the hit rate
8. __slots__ → potential: 40-60% less memory
9. asyncio/threading for I/O → potential: depends on the I/O wait ratio
10. The PyPy interpreter → potential: 3-10x
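Priority 7 (caching) in action: `functools.lru_cache` collapses an exponential recursion to linear work whenever arguments repeat. A sketch that counts actual executions:

```python
from functools import lru_cache

calls = 0

@lru_cache(maxsize=None)
def fib(n):
    global calls
    calls += 1                 # counts real executions, not cache hits
    return n if n < 2 else fib(n - 1) + fib(n - 2)

print(fib(30))   # 832040
print(calls)     # 31: each n in 0..30 computed exactly once
```

Without the decorator the same call tree would execute `fib` well over a million times; remember that `lru_cache` requires hashable arguments and keeps results alive, so bound `maxsize` for long-running processes.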
One-sentence takeaway
The core principles of Python performance work: profile before optimizing; fix algorithms before reaching for tools; prefer builtins before extensions.
Understanding the CPython execution model (interpretation, the GIL, reference counting) is the foundation for every optimization.
Most performance problems are solved by choosing the right data structures and using the standard library well, with no heavyweight acceleration machinery required.