Python from Beginner to Expert - Chapter 5: Functional Programming, Python's Functional Style


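Java/Kotlin developers are used to the Stream API and lambda expressions, but Python takes a different route to functional programming. Python has no checked exceptions. Its function signatures carry only optional type hints that are not enforced at runtime. A lambda can contain only a single expression. In exchange, Python offers first-class functions, decorators, generators, and the functools and itertools modules. Decorators are one of Python's most widely used metaprogramming tools. Generators are the basic building block of lazy evaluation. itertools provides ready-made combinatorial tools (Cartesian products, permutations, combinations) that Java Stream does not offer. This chapter starts from first-class functions and works through the core of functional programming in Python step by step.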


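5.1 First-Class Functions: Functions Are Objects

Java/Kotlin comparison

// Java: a lambda is an instance of a functional interface; methods themselves are not first-class values
// You cannot pass a method around directly; it must be converted to a functional-interface type
import java.util.function.*;

Function<Integer, Integer> doubleIt = x -> x * 2;
// doubleIt is an object whose type is Function<Integer,Integer>, not a "function type"
// A method reference needs a target functional-interface type, so you cannot store
// System.out::println in a Map<String, ?> as a general-purpose function

// A method reference is syntactic sugar that still produces a functional-interface instance
Consumer<String> printer = System.out::println;

// Kotlin: function types are first-class, the closest to Python
val doubleIt: (Int) -> Int = { x -> x * 2 }
val greet: (String) -> Unit = { println("Hi, $it") }

// Functions can be stored in collections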
val ops = mapOf(
    "double" to { x: Int -> x * 2 },
    "square" to { x: Int -> x * x }
)

Python implementation

# === Functions are objects of type function ===
def greet(name: str) -> str:
    return f"Hello, {name}"

print(type(greet))        # <class 'function'>
print(greet.__name__)     # greet
print(id(greet))          # in CPython, the object's memory address


# === 1. Assign to a variable ===
say_hi = greet            # no parentheses: this is a reference, not a call
print(say_hi("Alice"))    # Hello, Alice


# === 2. Store in data structures ===
ops = {
    "double": lambda x: x * 2,
    "square": lambda x: x ** 2,
    "negate": lambda x: -x,
}

print(ops["double"](5))   # 10
print(ops["square"](3))   # 9

# A list of functions, applied in order
pipeline = [
    lambda x: x.strip(),
    lambda x: x.lower(),
    lambda x: x.replace(" ", "-"),
]

text = "  Hello World  "
result = text
for fn in pipeline:
    result = fn(result)
print(result)  # 'hello-world'


# === 3. Pass as an argument ===
def apply(fn, value):
    return fn(value)

print(apply(lambda x: x * 2, 5))    # 10
print(apply(str.upper, "hello"))    # HELLO


# === 4. Return from a function ===
def power(exp):
    """Return a function that raises its argument to the power exp"""
    return lambda base: base ** exp

square = power(2)
cube = power(3)
print(square(5))   # 25
print(cube(3))     # 27


# === 5. Function attributes: function objects can carry arbitrary data ===
def important_func():
    """An important function"""
    pass

important_func.version = "1.0"
important_func.author = "Alice"
important_func.tags = ["math", "utils"]

print(important_func.version)  # 1.0
print(important_func.tags)     # ['math', 'utils']

# Java methods cannot carry custom attributes like this


# === 6. Comparing functions ===
def foo():
    pass

bar = foo
print(foo is bar)          # True: the same object
print(foo == bar)          # True

def baz():
    pass

print(foo == baz)          # False: different function objects, even with identical code

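Key differences

Feature | Java | Kotlin | Python
Functions are first-class objects | No (simulated via functional interfaces) | Yes | Yes
Functions can carry attributes | No | No | Yes
Function type | Function<T,R> and similar interfaces | (T) -> R | Callable protocol
Lambda restrictions | Must match a functional interface | Multi-statement body allowed | Single expression only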

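Common pitfalls

# Pitfall: late binding in lambdas
funcs = [lambda: i for i in range(5)]
print([f() for f in funcs])  # [4, 4, 4, 4, 4]: every call returns 4!

# Why: each lambda captures the variable i, not its value at creation time
# After the loop finishes i == 4, and all the lambdas share that same i

# Fix: capture the current value through a default argument
funcs = [lambda i=i: i for i in range(5)]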
print([f() for f in funcs])  # [0, 1, 2, 3, 4]

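When to use

  • When behavior needs to be passed around as data (strategy pattern, callbacks)
  • When building function pipelines or processing chains
  • When you need to attach metadata to a function (an advantage over Java and Kotlin)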
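The bullets above describe treating behavior as data. The following is a minimal sketch built on made-up discount strategies. `no_discount`, `ten_percent_off`, `STRATEGIES` and `checkout` are hypothetical names used only for illustration. The sketch shows a strategy table typed with `typing.Callable` and a function attribute used as metadata:

```python
from typing import Callable

# Hypothetical pricing strategies: each takes a price and returns a new price
def no_discount(price: float) -> float:
    return price

def ten_percent_off(price: float) -> float:
    return round(price * 0.9, 2)

# Attach metadata directly to the function object (a function attribute)
ten_percent_off.label = "10% off"

# Strategy table: name -> function, typed as Callable[[float], float]
STRATEGIES: dict[str, Callable[[float], float]] = {
    "none": no_discount,
    "ten": ten_percent_off,
}

def checkout(price: float, strategy: str = "none") -> float:
    """Look up a strategy by name and apply it; functions are ordinary values here."""
    fn = STRATEGIES[strategy]
    return fn(price)

print(checkout(100.0))          # 100.0
print(checkout(100.0, "ten"))   # 90.0
print(ten_percent_off.label)    # 10% off
```

Adding a new strategy only means defining a function and adding one dictionary entry. No interface or class hierarchy is needed.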

5.2 Higher-Order Functions: map, filter, sorted

Java/Kotlin comparison

// Java Stream API: chained calls, lazy evaluation
import java.util.*;
import java.util.stream.*;

List<Integer> nums = List.of(1, 2, 3, 4, 5);
List<Integer> result = nums.stream()
    .filter(n -> n % 2 == 0)
    .map(n -> n * n)
    .collect(Collectors.toList());
// [4, 16]

// sorted
List<String> names = List.of("Charlie", "Alice", "Bob");
List<String> sorted = names.stream()
    .sorted(Comparator.comparingInt(String::length))
    .collect(Collectors.toList());
// [Bob, Alice, Charlie]

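// A Java Stream can be used only once!
Stream<Integer> s = nums.stream();
s.filter(n -> n > 2);  // s has now been operated upon
s.map(n -> n * 2);     // IllegalStateException: stream has already been operated upon or closed

// Kotlin collection functions: called directly on the collection, eager (not lazy)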
val nums = listOf(1, 2, 3, 4, 5)
val result = nums.filter { it % 2 == 0 }.map { it * it }
// [4, 16]

// Lazy version: asSequence()
val result2 = nums.asSequence()
    .filter { it % 2 == 0 }
    .map { it * it }
    .toList()

Python implementation

# === map: transform each element ===
nums = [1, 2, 3, 4, 5]

# Returns a lazy iterator, not a list!
squares = map(lambda x: x ** 2, nums)
print(type(squares))          # <class 'map'>
print(list(squares))          # [1, 4, 9, 16, 25]

# Pass an existing function instead of a lambda
names = ["alice", "bob", "charlie"]
upper = map(str.upper, names)
print(list(upper))            # ['ALICE', 'BOB', 'CHARLIE']

# map over several sequences in parallel (the function takes one argument per sequence)
a = [1, 2, 3]
b = [10, 20, 30]
print(list(map(lambda x, y: x + y, a, b)))  # [11, 22, 33]

# More Pythonic: a list comprehension is usually clearer than map
squares = [x ** 2 for x in nums]
print(squares)  # [1, 4, 9, 16, 25]


# === filter: keep matching elements ===
evens = filter(lambda x: x % 2 == 0, nums)
print(list(evens))            # [2, 4]

# More Pythonic
evens = [x for x in nums if x % 2 == 0]
print(evens)                  # [2, 4]


# === sorted: returns a new list and leaves the original unchanged ===
words = ["banana", "apple", "cherry", "date"]
print(sorted(words))                    # ['apple', 'banana', 'cherry', 'date']
print(sorted(words, key=len))           # ['date', 'apple', 'banana', 'cherry']
print(sorted(words, key=lambda w: w[-1]))  # ['banana', 'apple', 'date', 'cherry']

# Descending order
print(sorted(words, reverse=True))      # ['date', 'cherry', 'banana', 'apple']

# Multi-level sort: first by age (x[1]), then by name (x[0])
data = [("alice", 30), ("bob", 25), ("charlie", 25)]
print(sorted(data, key=lambda x: (x[1], x[0])))
# [('bob', 25), ('charlie', 25), ('alice', 30)]

# Note: list.sort() sorts in place; sorted() returns a new list
words_copy = words[:]
words_copy.sort(key=len)    # modifies words_copy in place
print(words_copy)           # ['date', 'apple', 'banana', 'cherry']


# === reduce: fold a sequence into a single value ===
from functools import reduce

nums = [1, 2, 3, 4, 5]
total = reduce(lambda acc, x: acc + x, nums)
print(total)                 # 15

# With an initial value (recommended: avoids a TypeError on an empty sequence)
total = reduce(lambda acc, x: acc + x, nums, 0)
print(total)                 # 15

empty_result = reduce(lambda acc, x: acc + x, [], 0)
print(empty_result)          # 0: with an initial value, an empty sequence is fine

# More Pythonic: sum() is clearer than reduce for addition
print(sum(nums))             # 15


# === any / all: short-circuit evaluation ===
nums = [1, 2, 3, 4, 5]
print(any(x > 10 for x in nums))   # False
print(any(x > 3 for x in nums))    # True
print(all(x > 0 for x in nums))    # True
print(all(x > 2 for x in nums))    # False


# === zip: iterate in parallel ===
names = ["alice", "bob", "charlie"]
scores = [95, 87, 92]
print(list(zip(names, scores)))
# [('alice', 95), ('bob', 87), ('charlie', 92)]

# Build a dict from the pairs
print(dict(zip(names, scores)))
# {'alice': 95, 'bob': 87, 'charlie': 92}

# With unequal lengths, zip stops at the shortest input (same as Kotlin's zip)
print(list(zip([1, 2, 3], [10, 20])))  # [(1, 10), (2, 20)]


# === enumerate: iterate with an index ===
# Java: no direct equivalent; IntStream.range is the usual workaround
# Kotlin: withIndex()
for i, name in enumerate(names):
    print(f"{i}: {name}")
# 0: alice
# 1: bob
# 2: charlie

# Choose the starting index
for i, name in enumerate(names, start=1):
    print(f"{i}: {name}")
# 1: alice
# 2: bob
# 3: charlie
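For a side-by-side comparison, here is one possible Python translation of the Java Stream example at the start of this section. It is not the only idiomatic way to write it. A comprehension does filter and map in one step. `map`/`filter` stay lazy until `list()` materializes them, roughly like a terminal operation:

```python
nums = [1, 2, 3, 4, 5]

# 1) Comprehension: filter and map in one expression (eager, builds a list)
result = [n * n for n in nums if n % 2 == 0]
print(result)  # [4, 16]

# 2) filter + map: lazy iterators, materialized by list()
lazy = map(lambda n: n * n, filter(lambda n: n % 2 == 0, nums))
print(list(lazy))  # [4, 16]

# 3) Sort by length, like Comparator.comparingInt(String::length)
names = ["Charlie", "Alice", "Bob"]
print(sorted(names, key=len))  # ['Bob', 'Alice', 'Charlie']
```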

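Key differences

Feature | Java Stream | Kotlin | Python
Lazy evaluation | Yes (until a terminal operation) | Eager by default; lazy with asSequence() | map/filter return lazy iterators
Reusable | No (can be consumed only once) | Collections yes; most Sequences can be iterated again | Iterators are single-use; lists are reusable
map over several sequences | No built-in (combine by index, e.g. IntStream.range) | zip(other) { a, b -> ... } | map accepts several iterables natively
Parallelism | .parallelStream() | Coroutines/Flow | concurrent.futures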

Common pitfalls

# Pitfall: an iterator can be consumed only once
m = map(lambda x: x * 2, [1, 2, 3])
print(list(m))  # [2, 4, 6]
print(list(m))  # []: already exhausted!

# Fix: convert to a list when you need the results more than once
result = list(map(lambda x: x * 2, [1, 2, 3]))
print(result)   # [2, 4, 6]
print(result)   # [2, 4, 6]

# Pitfall: sorted() returns a new list and does not modify the original
words = ["b", "a", "c"]
sorted(words)
print(words)    # ['b', 'a', 'c']: unchanged!
words.sort()
print(words)    # ['a', 'b', 'c']: only the in-place sort changes it

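When to use

  • map/filter: fine for simple transformations, but list comprehensions are usually more Pythonic
  • sorted: when the original list must stay unchanged
  • any/all: condition checks, clearer than a hand-written loop
  • zip/enumerate: parallel and indexed iteration, used constantly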
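The multi-level sort shown earlier sorts every key in ascending order. The following sketch uses made-up data and shows two common ways to mix directions: negating a numeric key, or chaining stable sorts. It also uses `operator.itemgetter` as an alternative to a lambda:

```python
from operator import itemgetter

people = [("alice", 30), ("bob", 25), ("charlie", 25), ("dave", 30)]

# Ascending by age, then by name: itemgetter(1, 0) builds the key tuple (age, name)
print(sorted(people, key=itemgetter(1, 0)))
# [('bob', 25), ('charlie', 25), ('alice', 30), ('dave', 30)]

# Descending by age but ascending by name: negate the numeric part of the key
print(sorted(people, key=lambda p: (-p[1], p[0])))
# [('alice', 30), ('dave', 30), ('bob', 25), ('charlie', 25)]

# Same result with two stable sorts: secondary key first, then primary key
by_name = sorted(people, key=itemgetter(0))
print(sorted(by_name, key=itemgetter(1), reverse=True))
# [('alice', 30), ('dave', 30), ('bob', 25), ('charlie', 25)]
```

The two-pass version works because Python's sort is stable, even with `reverse=True`. It is handy when a key cannot be negated, for example a string that must sort in descending order.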

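5.3 Closures and nonlocal

Java/Kotlin comparison

// Java: a lambda can only capture effectively final variables
int factor = 2;
// factor = 3;  // uncommenting this line causes a compile error!
Function<Integer, Integer> multiplier = x -> x * factor;
// factor must be effectively final

// A Java lambda cannot modify a captured local variable
// There is no equivalent of nonlocal

// Kotlin: a lambda can capture and modify a mutable variable (var)
var count = 0
val inc: () -> Unit = { count++ }
inc()
inc()
println(count)  // 2: the captured variable was modified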

Python implementation

# === Closure basics ===
def make_multiplier(factor):
    """Return a multiplying function; factor is captured by the closure"""
    def multiply(x):
        return x * factor  # factor comes from the enclosing scope
    return multiply

double = make_multiplier(2)
triple = make_multiplier(3)
print(double(5))   # 10
print(triple(5))   # 15

# Inspect the variables captured by the closure
print(double.__closure__)      # (<cell at ...: int object at ...>,)
print(double.__closure__[0].cell_contents)  # 2


# === Closure pitfall: binding a loop variable ===
def make_funcs():
    funcs = []
    for i in range(3):
        def f():
            return i  # captures the variable i, not its value
        funcs.append(f)
    return funcs

funcs = make_funcs()
print([f() for f in funcs])  # [2, 2, 2]: every call returns 2!

# Fix 1: capture the value with a default argument
def make_funcs_fixed():
    funcs = []
    for i in range(3):
        def f(i=i):  # default values are evaluated at definition time
            return i
        funcs.append(f)
    return funcs

print([f() for f in make_funcs_fixed()])  # [0, 1, 2]

# Fix 2: create a separate scope with a factory function
def make_func(i):
    def f():
        return i
    return f

def make_funcs_fixed2():
    return [make_func(i) for i in range(3)]

print([f() for f in make_funcs_fixed2()])  # [0, 1, 2]


# === nonlocal: rebind a variable captured by a closure ===
def make_counter():
    count = 0  # closure variable

    def increment():
        nonlocal count  # refer to count in the enclosing scope
        count += 1
        return count

    def get_count():
        return count

    return increment, get_count

inc, get = make_counter()
print(inc())   # 1
print(inc())   # 2
print(inc())   # 3
print(get())   # 3

# What happens without nonlocal?
def broken_counter():
    count = 0

    def increment():
        # count += 1  # UnboundLocalError (the exact message varies by Python version)
        # Because += is an assignment, Python treats count as a local variable of increment
        pass

    return increment


# === nonlocal in practice: a closure with state ===
def make_accumulator(initial=0):
    total = initial

    def add(value):
        nonlocal total
        total += value
        return total

    def reset():
        nonlocal total
        total = initial
        return total

    # Return an "object": a set of closures that share the same state
    return {"add": add, "reset": reset, "get": lambda: total}

acc = make_accumulator(100)
print(acc["add"](10))    # 110
print(acc["add"](20))    # 130
print(acc["reset"]())    # 100
print(acc["add"](5))     # 105


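# === global: rebind a module-level variable (use sparingly) ===
counter = 0

def increment_global():
    global counter  # needed because counter is reassigned (counter += 1)
    counter += 1
    return counter

increment_global()
print(counter)  # 1

# Note: mutating a global object (e.g. cache[key] = value on a global dict)
# does not need global; only rebinding the name does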

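Key differences

Feature | Java | Kotlin | Python
Effectively final requirement | Enforced | Not enforced | Not enforced
Modifying captured variables | Not allowed | Allowed | Requires a nonlocal declaration
Closure binding | Captures the value | Captures the variable | Captures the variable (late binding)
Inspecting captured variables | Not supported | Not supported | __closure__ attribute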
查看闭包变量不支持不支持__closure__ 属性

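Common pitfalls

# Pitfall 1: assignment makes a variable local
x = 10
def foo():
    x = 20      # creates a local x; the outer x is untouched
    return x
print(foo())   # 20
print(x)       # 10: the outer x did not change

# Pitfall 2: reading before assigning raises UnboundLocalError
x = 10
def bar():
    # print(x)  # UnboundLocalError!
    x = 20      # because of this assignment, Python treats x as local in the whole function
    return x

# Pitfall 3: nonlocal refers to the nearest enclosing function scope, never the global scope
x = 10
def outer():
    x = 20
    def inner():
        nonlocal x  # refers to outer's x, not the global x
        x = 30
    inner()
    print(x)   # 30
outer()
print(x)       # 10: the global x did not change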

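When to use

  • Closures: when you need to encapsulate state without writing a full class (a lightweight alternative)
  • nonlocal: when a closure must rebind a variable of the enclosing function
  • global: rarely needed; module-level state is usually better managed with a class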
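As a concrete case of "a closure as a lightweight class", here is a minimal running-average sketch. `make_averager` is a hypothetical name. Both captured variables are rebound, so both must be listed in the `nonlocal` statement:

```python
def make_averager():
    """Keep a running average in closure variables instead of a class."""
    count = 0
    total = 0.0

    def add(value: float) -> float:
        nonlocal count, total  # both names are reassigned, so both need nonlocal
        count += 1
        total += value
        return total / count

    return add

avg = make_averager()
print(avg(10))  # 10.0
print(avg(20))  # 15.0
print(avg(30))  # 20.0
print(avg.__code__.co_freevars)  # ('count', 'total')
```

The equivalent class would need `__init__`, two attributes and a method. The closure keeps the state private without any of that.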

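5.4 Decorators: Python's Killer Feature

Java/Kotlin comparison

// Java: no direct equivalent
// The closest things are annotation processors (compile time) and AOP (runtime), which work very differently
// A Java annotation is a passive marker; a decorator actively transforms code

// Spring AOP: needs proxies, aspects, and substantial framework support
@Around("execution(* com.example.service.*.*(..))")
public Object log(ProceedingJoinPoint pjp) throws Throwable {
    System.out.println("Before: " + pjp.getSignature());
    Object result = pjp.proceed();
    System.out.println("After");
    return result;
}

// Kotlin: no direct equivalent either
// It can be simulated with higher-order functions, but the syntax is less elegant than a decorator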
fun <T> withLogging(fn: () -> T): T {
    println("Before")
    val result = fn()
    println("After")
    return result
}

Python implementation

import functools
import time


# === 1. Basic decorator without arguments ===
def timer(func):
    """Timing decorator"""
    @functools.wraps(func)  # preserve the original function's metadata
    def wrapper(*args, **kwargs):
        start = time.perf_counter()
        result = func(*args, **kwargs)
        elapsed = time.perf_counter() - start
        print(f"{func.__name__} took {elapsed:.4f}s")
        return result
    return wrapper

@timer
def slow_add(a, b):
    time.sleep(0.1)
    return a + b

print(slow_add(1, 2))
# slow_add took 0.1001s
# 3

# @timer is equivalent to slow_add = timer(slow_add)
# The essence of a decorator: func = decorator(func)


# === 2. Decorator with arguments ===
def retry(max_attempts=3, delay=1.0):
    """Retry decorator factory"""
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            for attempt in range(1, max_attempts + 1):
                try:
                    return func(*args, **kwargs)
                except Exception as e:
                    print(f"Attempt {attempt} failed: {e}")
                    if attempt == max_attempts:
                        raise
                    time.sleep(delay)
        return wrapper
    return decorator

@retry(max_attempts=3, delay=0.1)
def unstable_api():
    import random
    if random.random() < 0.7:
        raise ConnectionError("Network error")
    return "success"

print(unstable_api())  # may succeed after a few retries, or re-raise ConnectionError if all 3 attempts fail


# === 3. A class as a decorator ===
class CountCalls:
    """Count how many times a function is called"""
    def __init__(self, func):
        self.func = func
        self.count = 0
        functools.update_wrapper(self, func)

    def __call__(self, *args, **kwargs):
        self.count += 1
        print(f"Call #{self.count} of {self.func.__name__}")
        return self.func(*args, **kwargs)

@CountCalls
def say_hello(name):
    return f"Hello, {name}"

print(say_hello("Alice"))  # Call #1 of say_hello
print(say_hello("Bob"))    # Call #2 of say_hello
print(say_hello.count)     # 2


# === 4. Stacking decorators ===
def bold(func):
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        return f"**{func(*args, **kwargs)}**"
    return wrapper

def italic(func):
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        return f"*{func(*args, **kwargs)}*"
    return wrapper

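# Order: decorators are applied bottom-up, and their wrappers are entered top-down
@bold
@italic
def greet(name):
    return f"Hello, {name}"

print(greet("Alice"))  # ***Hello, Alice***: italic wraps the return value first (inner), bold wraps it last (outer)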


# === 5. @functools.wraps preserves metadata ===
def bad_decorator(func):
    def wrapper(*args, **kwargs):
        return func(*args, **kwargs)
    return wrapper  # no @wraps!

@bad_decorator
def my_function():
    """This is my function."""
    pass

print(my_function.__name__)   # wrapper: the original name is lost!
print(my_function.__doc__)    # None: the docstring is lost!

# Fix it with @functools.wraps
def good_decorator(func):
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        return func(*args, **kwargs)
    return wrapper

@good_decorator
def my_function2():
    """This is my function."""
    pass

print(my_function2.__name__)  # my_function2 (preserved)
print(my_function2.__doc__)   # This is my function. (preserved)


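# === 6. A parameterized decorator implemented as a class ===
class Validate:
    """Argument-validation decorator (simplified: only checks keyword arguments)"""
    def __init__(self, **rules):
        self.rules = rules  # {parameter name: (type, required)}

    def __call__(self, func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            # Simple validation; positional arguments are not inspected,
            # so create_user("Alice") would be reported as missing name
            for param_name, (expected_type, required) in self.rules.items():
                value = kwargs.get(param_name)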
                if required and value is None:
                    raise ValueError(f"Missing required parameter: {param_name}")
                if value is not None and not isinstance(value, expected_type):
                    raise TypeError(
                        f"Parameter {param_name} must be {expected_type}, got {type(value)}"
                    )
            return func(*args, **kwargs)
        return wrapper

@Validate(name=(str, True), age=(int, False))
def create_user(name, age=None):
    return {"name": name, "age": age}

print(create_user(name="Alice", age=30))  # {'name': 'Alice', 'age': 30}
# create_user(name="Alice", age="30")    # TypeError!
# create_user(age=30)                    # ValueError!


# === 7. Common built-in decorators ===

# @staticmethod: a static method that receives neither self nor cls
class MathUtils:
    @staticmethod
    def add(a, b):
        return a + b

print(MathUtils.add(1, 2))  # 3

# @classmethod: a class method that receives cls instead of self
class Database:
    _instance = None

    @classmethod
    def get_instance(cls):
        if cls._instance is None:
            cls._instance = cls()
        return cls._instance

    @classmethod
    def reset(cls):
        cls._instance = None

db1 = Database.get_instance()
db2 = Database.get_instance()
print(db1 is db2)  # True: a singleton

# @dataclass: generates __init__, __repr__ and __eq__ automatically
from dataclasses import dataclass, field

@dataclass
class Point:
    x: float
    y: float
    label: str = "origin"

    def distance_to(self, other):
        return ((self.x - other.x) ** 2 + (self.y - other.y) ** 2) ** 0.5

p1 = Point(3, 4)
p2 = Point(0, 0, "center")
print(p1)              # Point(x=3, y=4, label='origin'): dataclass does not convert 3 to 3.0
print(p1 == Point(3, 4))  # True: generated __eq__
print(p1.distance_to(p2))  # 5.0

# More advanced @dataclass usage: mutable defaults need default_factory
@dataclass
class Order:
    items: list[str] = field(default_factory=list)
    total: float = 0.0

order = Order(items=["apple", "banana"])
print(order)  # Order(items=['apple', 'banana'], total=0.0)

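Key differences

Feature | Java annotations | Python decorators
When it runs | Read at compile time or via reflection at runtime (passive) | Executed immediately when the function is defined (active)
Code transformation | Requires an annotation processor or AOP | Directly replaces the function object
Parameters | Annotation attributes | Decorator factory
Composition | Multiple annotations allowed | Can be stacked; order matters
Preserving metadata | Reflection API | @functools.wraps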

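Common pitfalls

# Pitfall 1: a decorator must return a callable
def broken_decorator(func):
    return None  # not callable!

@broken_decorator
def foo():
    pass

# foo()  # TypeError: 'NoneType' object is not callable

# Pitfall 2: the three nested levels of a decorator with arguments are easy to mix up
# Correct structure: outer level takes the arguments → middle level takes the function → inner level is the actual wrapper
def with_param(arg):
    def decorator(func):        # receives the decorated function
        @functools.wraps(func)
        def wrapper(*a, **kw):  # runs on every call
            return func(*a, **kw)
        return wrapper
    return decorator             # return the decorator

# Pitfall 3: a decorator applied to a class affects only that class, not its subclasses
# Example: a subclass of a @dataclass that adds new fields needs its own @dataclass
# To apply logic to every subclass automatically, implement __init_subclass__ in the base class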

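When to use

  • Logging, timing, caching, permission checks: cross-cutting concerns
  • Registration: e.g. Flask's @app.route, pytest's @pytest.fixture
  • Argument validation: input checks
  • Singletons/caching: design patterns
  • Principle: a decorator should be transparent and must not change the decorated function's behavior unexpectedly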
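The registration bullet can be illustrated with a minimal sketch. It only mimics the idea behind `@app.route`. `COMMANDS`, `command` and `dispatch` are hypothetical names, not Flask's API. The decorator records the function and returns it unchanged, which keeps it transparent:

```python
from typing import Callable

# Hypothetical command registry, similar in spirit to @app.route
COMMANDS: dict[str, Callable[..., str]] = {}

def command(name: str):
    """Decorator factory: register the function under `name` and return it unchanged."""
    def decorator(func: Callable[..., str]) -> Callable[..., str]:
        COMMANDS[name] = func
        return func  # returning the original function keeps the decorator transparent
    return decorator

@command("greet")
def greet(who: str) -> str:
    return f"Hello, {who}"

@command("shout")
def shout(who: str) -> str:
    return f"HELLO, {who.upper()}!"

def dispatch(name: str, *args) -> str:
    return COMMANDS[name](*args)

print(sorted(COMMANDS))            # ['greet', 'shout']
print(dispatch("greet", "Alice"))  # Hello, Alice
print(greet.__name__)              # greet
```

Registration happens when the `def` statements run, which is the "executed at definition time" property from the comparison table.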

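5.5 functools: partial, lru_cache, singledispatch

Java/Kotlin comparison

// Java: no direct equivalents
// Partial application: write a new method by hand or wrap the call in a lambda
// Caching: Guava Cache, Caffeine, etc.
// Generic functions: method overloading (dispatched at compile time)

// Kotlin: no built-in partial application; wrap the call in a lambda by hand
fun power(base: Int, exp: Int): Int = Math.pow(base.toDouble(), exp.toDouble()).toInt()
val square = { x: Int -> power(x, 2) }  // hand-written partial application

Python implementation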

import functools
from functools import partial, lru_cache, singledispatch, total_ordering, wraps


# === 1. functools.partial: partial application ===
# Fix some of a function's arguments and get back a new function

def power(base, exp):
    return base ** exp

square = partial(power, exp=2)
cube = partial(power, exp=3)
print(square(5))   # 25
print(cube(3))     # 27

# Fix a positional argument
power_of_two = partial(power, 2)   # base is fixed to 2
print(power_of_two(10))  # 1024: 2 ** 10

# Practical use: fix the common arguments of an API call

# A stand-in request function for illustration
def fetch(url, timeout, headers):
    return f"Fetching {url} with timeout={timeout}"

# Fix timeout and headers
safe_fetch = partial(fetch, timeout=30, headers={"User-Agent": "MyApp"})
print(safe_fetch("https://example.com"))
# Fetching https://example.com with timeout=30

# Attributes of a partial object
print(square.func)       # <function power>
print(square.args)       # ()
print(square.keywords)   # {'exp': 2}


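# === 2. @lru_cache: memoization (one of Python's killer features) ===
# Caches return values automatically; repeated calls with the same arguments return the cached result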

@lru_cache(maxsize=128)
def fibonacci(n):
    if n < 2:
        return n
    return fibonacci(n - 1) + fibonacci(n - 2)

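# Without a cache: exponential time
# With a cache: O(n), each n is computed only once
print(fibonacci(100))  # 354224848179261915075, returns instantly
print(fibonacci(100))  # the second call is served straight from the cache

# Inspect cache statistics
print(fibonacci.cache_info())
# CacheInfo(hits=99, misses=101, maxsize=128, currsize=101)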

# Clear the cache
fibonacci.cache_clear()


# Unbounded cache
@lru_cache(maxsize=None)
def expensive_computation(x):
    print(f"Computing for {x}...")
    return x * x

print(expensive_computation(5))  # Computing for 5... → 25
print(expensive_computation(5))  # 25, returned from the cache without printing


# Caching in practice: caching database lookups
@lru_cache(maxsize=1000)
def get_user(user_id):
    """Simulate a database query"""
    print(f"Querying DB for user {user_id}")
    # Note: every caller receives the same cached dict; mutating it changes the cached value
    return {"id": user_id, "name": f"User{user_id}"}

print(get_user(1))  # Querying DB for user 1 → {'id': 1, 'name': 'User1'}
print(get_user(1))  # {'id': 1, 'name': 'User1'}: cache hit
print(get_user(2))  # Querying DB for user 2 → {'id': 2, 'name': 'User2'}


# === 3. @functools.singledispatch: generic functions ===
# Dispatch to an implementation based on the type of the first argument

@singledispatch
def process(data):
    raise NotImplementedError(f"Cannot process {type(data)}")

@process.register(int)
def _(data):
    return f"Integer: {data * 2}"

@process.register(str)
def _(data):
    return f"String: {data.upper()}"

@process.register(list)
def _(data):
    return f"List: {len(data)} items"

@process.register(dict)
def _(data):
    return f"Dict: {len(data)} keys"

print(process(42))       # Integer: 84
print(process("hello"))  # String: HELLO
print(process([1, 2]))   # List: 2 items
print(process({"a": 1})) # Dict: 1 keys
# process(3.14)          # NotImplementedError

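# Dispatch follows the class hierarchy (MRO): a subclass uses its base class's implementation
@process.register(float)
@process.register(int)   # one implementation can be registered for several types
def _(data):             # note: this re-registers int and replaces the earlier int implementation
    return f"Number: {data}"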


# === 4. @functools.total_ordering ===
# Define __eq__ plus one ordering method; the remaining comparisons are generated

@total_ordering
class Version:
    def __init__(self, major, minor, patch):
        self.major = major
        self.minor = minor
        self.patch = patch

    def __eq__(self, other):
        return (self.major, self.minor, self.patch) == (other.major, other.minor, other.patch)

    def __lt__(self, other):
        return (self.major, self.minor, self.patch) < (other.major, other.minor, other.patch)

    def __repr__(self):
        return f"Version({self.major}.{self.minor}.{self.patch})"

v1 = Version(1, 0, 0)
v2 = Version(2, 0, 0)
v3 = Version(1, 0, 0)

print(v1 < v2)   # True
print(v2 > v1)   # True: generated by total_ordering
print(v1 <= v2)  # True: generated
print(v1 >= v3)  # True: generated
print(v1 != v2)  # True: in Python 3, != is derived from __eq__ by default, not by total_ordering


# === 5. @functools.wraps ===
# Covered in detail in 5.4; here is a more advanced use

def debug(func):
    """Debugging decorator that prints the arguments and the return value"""
    @wraps(func)
    def wrapper(*args, **kwargs):
        args_repr = [repr(a) for a in args]
        kwargs_repr = [f"{k}={v!r}" for k, v in kwargs.items()]
        signature = ", ".join(args_repr + kwargs_repr)
        print(f"Calling {func.__name__}({signature})")
        result = func(*args, **kwargs)
        print(f"{func.__name__} returned {result!r}")
        return result
    return wrapper

@debug
def add(a, b):
    return a + b

print(add(3, 4))
# Calling add(3, 4)
# add returned 7
# 7


# === 6. functools.reduce ===
# Covered in 5.2; here is a more advanced use

# Build a pipeline with reduce
def pipeline(data, *funcs):
    return functools.reduce(lambda acc, f: f(acc), funcs, data)

result = pipeline(
    "  Hello World  ",
    str.strip,
    str.lower,
    lambda s: s.replace(" ", "-"),
)
print(result)  # 'hello-world'

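Key differences

Feature | Java | Python
Partial application | Implemented by hand | functools.partial
Caching | Guava/Caffeine (needs configuration) | @lru_cache in one line
Generic functions | Method overloading (compile time) | @singledispatch (runtime)
Generating comparisons | Comparable interface | @total_ordering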

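Common pitfalls

# Pitfall 1: lru_cache arguments must be hashable
@lru_cache(maxsize=128)
def process(data):
    return data

# process([1, 2, 3])  # TypeError: unhashable type: 'list'
process((1, 2, 3))    # OK: tuples are hashable

# Pitfall 2: lru_cache can return stale results
@lru_cache(maxsize=None)
def query_db(sql):
    # If the data in the database changes, the cache is not invalidated automatically
    return f"Result of {sql}"

# lru_cache has no expiry (TTL): call query_db.cache_clear() when the data changes
# (maxsize only bounds the cache size; it does not prevent stale entries)

# Pitfall 3: singledispatch looks only at the type of the first argument
@singledispatch
def foo(a, b):
    return "default"

@foo.register(str)
def _(a, b):
    return f"str: {a}, {b}"

print(foo("hello", 42))  # str: hello, 42 (dispatched on the type of a)
print(foo(42, "hello"))  # default (the first argument is not a str)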
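Since Python 3.7, `register` can also take the type from the annotation of the first parameter, so `register(int)` is not required. The sketch below uses a hypothetical `describe` function. It also shows that dispatch follows the MRO, which is why `bool` (a subclass of `int`) reaches the `int` implementation:

```python
from functools import singledispatch

@singledispatch
def describe(value) -> str:
    # Fallback for types with no registered implementation
    return f"object: {value!r}"

@describe.register
def _(value: int) -> str:   # type taken from the annotation (Python 3.7+)
    return f"int: {value}"

@describe.register
def _(value: list) -> str:
    return f"list of {len(value)}"

print(describe(3))        # int: 3
print(describe(True))     # int: True (bool subclasses int)
print(describe([1, 2]))   # list of 2
print(describe(2.5))      # object: 2.5
```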

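When to use

  • partial: when you want to reuse a function with some of its arguments fixed
  • lru_cache: pure functions with hashable arguments, recursion with overlapping subproblems, repeated computation
  • singledispatch: when handling depends on the argument's type (replaces if-elif-isinstance chains)
  • total_ordering: when a custom class needs the full set of comparison operators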

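5.6 itertools: Infinite Iterators and Combinatorial Tools

Java/Kotlin comparison

// Java Stream: Stream.iterate() and Stream.generate() can create infinite streams,
// but there are no ready-made helpers such as count(), cycle() or repeat()
// Combinatorial operations must be written by hand or taken from a library such as Guava

// Generating permutations and combinations in Java
// requires hand-written recursion or a third-party library

// Kotlin Sequence: generateSequence() and sequence { } can produce infinite sequences
// Combinatorial operations must still be written by hand

Python implementation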

import itertools
from itertools import (
    count, cycle, repeat,
    chain, islice, takewhile, dropwhile,
    product, permutations, combinations, combinations_with_replacement,
    groupby,
)


# === 1. Infinite iterators ===

# count(start=0, step=1): count forever
for i in islice(count(1, 2), 5):  # must be bounded, e.g. with islice!
    print(i, end=" ")  # 1 3 5 7 9
print()

# cycle: loop over an iterable forever
colors = ["red", "green", "blue"]
for color in islice(cycle(colors), 7):
    print(color, end=" ")  # red green blue red green blue red
print()

# repeat: repeat a value forever
for val in islice(repeat("hello"), 3):
    print(val, end=" ")  # hello hello hello
print()

# repeat with a fixed number of repetitions
print(list(repeat(0, 5)))  # [0, 0, 0, 0, 0]


# === 2. Chaining and slicing ===

# chain: concatenate several iterables
list1 = [1, 2, 3]
list2 = [4, 5, 6]
list3 = [7, 8, 9]
print(list(chain(list1, list2, list3)))  # [1, 2, 3, 4, 5, 6, 7, 8, 9]

# chain.from_iterable: flatten one level of nesting
nested = [[1, 2], [3, 4], [5, 6]]
print(list(chain.from_iterable(nested)))  # [1, 2, 3, 4, 5, 6]

# islice: slicing for any iterator, including infinite ones
print(list(islice(count(), 5, 10)))  # [5, 6, 7, 8, 9]
print(list(islice("abcdefg", 2, 6, 2)))  # ['c', 'e']

# takewhile: take values while the condition holds
print(list(takewhile(lambda x: x < 5, [1, 3, 5, 2, 4, 6])))
# [1, 3]: stops at 5 and never looks at the remaining items

# dropwhile: skip values while the condition holds
print(list(dropwhile(lambda x: x < 5, [1, 3, 5, 2, 4, 6])))
# [5, 2, 4, 6]: skips the leading 1 and 3, then yields everything from 5 on


# === 3. Combinatorial tools ===

# product: Cartesian product (replaces nested loops)
colors = ["red", "blue"]
sizes = ["S", "M", "L"]
print(list(product(colors, sizes)))
# [('red', 'S'), ('red', 'M'), ('red', 'L'), ('blue', 'S'), ('blue', 'M'), ('blue', 'L')]

# Equivalent to
# for c in colors:
#     for s in sizes:
#         (c, s)

# Cartesian product of an iterable with itself
print(list(product([1, 2], repeat=3)))
# [(1,1,1), (1,1,2), (1,2,1), (1,2,2), (2,1,1), (2,1,2), (2,2,1), (2,2,2)]

# permutations: order matters
print(list(permutations([1, 2, 3], 2)))
# [(1,2), (1,3), (2,1), (2,3), (3,1), (3,2)]

# combinations: order does not matter
print(list(combinations([1, 2, 3, 4], 2)))
# [(1,2), (1,3), (1,4), (2,3), (2,4), (3,4)]

# combinations_with_replacement: combinations that may reuse elements
print(list(combinations_with_replacement([1, 2, 3], 2)))
# [(1,1), (1,2), (1,3), (2,2), (2,3), (3,3)]


# === 4. groupby: grouping (note: it only groups adjacent items, so sort first) ===
data = [
    ("apple", "fruit"),
    ("banana", "fruit"),
    ("carrot", "vegetable"),
    ("date", "fruit"),
    ("eggplant", "vegetable"),
]

# Group by category
for key, group in groupby(data, key=lambda x: x[1]):
    print(f"{key}: {[item[0] for item in group]}")
# fruit: ['apple', 'banana']
# vegetable: ['carrot']
# fruit: ['date']  (note: only consecutive items with the same key are grouped!)
# vegetable: ['eggplant']

# Correct usage: sort first
data_sorted = sorted(data, key=lambda x: x[1])
for key, group in groupby(data_sorted, key=lambda x: x[1]):
    print(f"{key}: {[item[0] for item in group]}")
# fruit: ['apple', 'banana', 'date']
# vegetable: ['carrot', 'eggplant']


# === 5. In practice: generate every possible numeric password ===
import string

digits = string.digits  # '0123456789'

# 2-digit numeric passwords
passwords_2 = map(
    lambda combo: "".join(combo),
    product(digits, repeat=2)
)
print(list(islice(passwords_2, 5)))  # ['00', '01', '02', '03', '04']


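# === 6. In practice: process large data in batches ===
def batched(iterable, n):
    """Split data into batches of n (Python 3.12+ has itertools.batched, which yields tuples; this version yields lists)"""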
    it = iter(iterable)
    while batch := list(islice(it, n)):
        yield batch

data = range(10)
for batch in batched(data, 3):
    print(batch)
# [0, 1, 2]
# [3, 4, 5]
# [6, 7, 8]
# [9]

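Key differences

Feature | Java Stream | Python itertools
Infinite streams | Stream.iterate / Stream.generate | count, cycle, repeat
Cartesian product | Nested flatMap | product
Permutations/combinations | Requires a third-party library | permutations, combinations
Grouping | Collectors.groupingBy | groupby (input must be sorted first)
Lazy | Yes | Yes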

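Common pitfalls

# Pitfall 1: groupby does not sort for you!
data = [1, 2, 1, 2, 1]
# Wrong: calling groupby directly
for key, group in groupby(data):
    print(key, list(group))
# 1 [1]
# 2 [2]
# 1 [1]
# 2 [2]
# 1 [1]  → five groups instead of two!

# Correct: sort first
for key, group in groupby(sorted(data)):
    print(key, list(group))
# 1 [1, 1, 1]
# 2 [2, 2]

# Pitfall 2: an infinite iterator must be bounded (e.g. with islice)
# for i in count():  # infinite loop!
#     print(i)

# Pitfall 3: each group shares the underlying iterator with groupby
# Once groupby advances to the next key, the previous group becomes empty
groups = {}
for key, group in groupby(sorted(data)):
    groups[key] = list(group)  # consume each group immediately!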

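When to use

  • count/cycle/repeat: when you need an infinite sequence (bounded with islice)
  • product/permutations/combinations: permutations, combinations, brute-force search
  • chain: merging several iterables
  • groupby: grouped aggregation over sorted data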
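Here is a sketch of "grouped aggregation over sorted data" using made-up sales figures. It sorts by the grouping key, then consumes each group immediately inside the comprehension, which avoids pitfall 3:

```python
from itertools import groupby
from operator import itemgetter

sales = [
    ("fruit", 3), ("vegetable", 5), ("fruit", 2),
    ("vegetable", 1), ("fruit", 4),
]

# groupby only merges adjacent items, so sort by the same key first
key = itemgetter(0)
totals = {
    category: sum(amount for _, amount in group)
    for category, group in groupby(sorted(sales, key=key), key=key)
}
print(totals)  # {'fruit': 9, 'vegetable': 6}
```

For unsorted data where only the totals matter, a `dict` or `collections.defaultdict` accumulator avoids the O(n log n) sort. `groupby` fits best when the input is already ordered, for example rows from a database query with `ORDER BY`.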

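5.7 Generators: yield and Lazy Evaluation

Java/Kotlin comparison

// Java: no equivalent of generators
// The closest are Stream.iterate()/Stream.generate(), which are limited
// A custom lazy sequence means implementing the Iterator interface, with a lot of boilerplate

List<Integer> naturals = Stream.iterate(0, n -> n + 1)
    .limit(5)
    .collect(Collectors.toList());
// [0, 1, 2, 3, 4]

// Kotlin: Sequence is the closest concept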
val naturals = sequence {
    var n = 0
    while (true) {
        yield(n++)
    }
}
naturals.take(5).toList()  // [0, 1, 2, 3, 4]

Python implementation

# === 1. yield basics ===
def countdown(n):
    """Countdown generator"""
    while n > 0:
        yield n
        n -= 1

# Generators are lazy: calling the function does not run its body yet
gen = countdown(5)
print(type(gen))  # <class 'generator'>

# Each next() call runs the body up to the next yield
print(next(gen))  # 5
print(next(gen))  # 4
print(next(gen))  # 3

# A for loop consumes the remaining values
for value in gen:
    print(value, end=" ")  # 2 1
print()


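# === 2. Generators vs lists: memory ===
def fibonacci_generator(limit):
    """Lazily generate the Fibonacci numbers below limit"""
    a, b = 0, 1
    while a < limit:
        yield a
        a, b = b, a + b

# Memory: a generator object has a fixed size; a list grows with its length
import sys

gen = fibonacci_generator(10**100)
print(sys.getsizeof(gen))  # small and fixed, regardless of the limit

# Compare with a list holding every value
fib_list = list(fibonacci_generator(10**100))
print(len(fib_list))             # a few hundred numbers
print(sys.getsizeof(fib_list))   # grows with len(fib_list); counts only the list's pointer array, not the ints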


# === 3. yield from: delegate to a sub-generator ===
def upper_letters():
    for c in "ABC":
        yield c

def lower_letters():
    for c in "def":
        yield c

def all_letters():
    yield from upper_letters()  # delegate
    yield from lower_letters()  # delegate
    yield from "123"            # works with any iterable

print(list(all_letters()))  # ['A', 'B', 'C', 'd', 'e', 'f', '1', '2', '3']

# Where yield from really pays off: recursive data structures
def flatten(nested):
    """Flatten a nested list"""
    for item in nested:
        if isinstance(item, (list, tuple)):
            yield from flatten(item)  # recursive delegation
        else:
            yield item

data = [1, [2, 3], [4, [5, 6]], 7]
print(list(flatten(data)))  # [1, 2, 3, 4, 5, 6, 7]


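# === 4. Generator expressions vs list comprehensions ===
nums = range(10)

# List comprehension: evaluated immediately, returns a list
squares_list = [x ** 2 for x in nums]
print(type(squares_list))  # <class 'list'>

# Generator expression: evaluated lazily, returns a generator
squares_gen = (x ** 2 for x in nums)
print(type(squares_gen))   # <class 'generator'>

# Syntax difference: generator expressions use parentheses,
# list comprehensions use square brackets

# Memory advantage of generator expressions
import sys
list_comp = [x ** 2 for x in range(1000000)]
gen_expr = (x ** 2 for x in range(1000000))
print(sys.getsizeof(list_comp))  # ~8.5MB (the list's pointer array alone)
print(sys.getsizeof(gen_expr))   # ~200 bytes

# When a generator expression is the only argument of a call, its own parentheses can be omitted
print(sum(x ** 2 for x in range(5)))  # 30
# Equivalent to sum((x ** 2 for x in range(5)))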


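# === 5. Infinite generators ===
def natural_numbers():
    """All natural numbers (infinite)"""
    n = 0
    while True:
        yield n
        n += 1

def primes():
    """Prime generator (infinite): an incremental Sieve of Eratosthenes"""
    yield 2
    composites = {}  # maps each upcoming odd composite to the primes that mark it
    n = 3
    while True:
        if n not in composites:
            yield n
            composites[n * n] = [n]  # the first composite n needs to mark
        else:
            for p in composites[n]:
                # Only odd n are checked, so step by 2p to the next odd multiple
                composites.setdefault(n + 2 * p, []).append(p)
            del composites[n]
        n += 2

# Take the first 10 primes
from itertools import islice
print(list(islice(primes(), 10)))
# [2, 3, 5, 7, 11, 13, 17, 19, 23, 29]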


# === 6. Advanced use: generator pipelines ===
def read_lines(filename):
    """Simulate reading a file line by line (filename is ignored here)"""
    lines = ["line1\n", "line2\n", "line3\n", ""]
    for line in lines:
        yield line

def strip_lines(lines):
    for line in lines:
        yield line.strip()

def filter_empty(lines):
    for line in lines:
        if line:
            yield line

def upper_lines(lines):
    for line in lines:
        yield line.upper()

# Pipeline: lines are processed lazily, one at a time, so memory use stays constant
pipeline = upper_lines(filter_empty(strip_lines(read_lines("dummy"))))
print(list(pipeline))  # ['LINE1', 'LINE2', 'LINE3']


# === 7. Generator methods ===
def gen_demo():
    yield 1
    yield 2
    yield 3

g = gen_demo()
print(next(g))   # 1
print(g.send(None))  # 2: send(None) is equivalent to next(g)
g.close()       # close the generator
# next(g)        # StopIteration

# throw: raise an exception inside the generator at the paused yield
def gen_with_throw():
    try:
        yield 1
        yield 2
        yield 3
    except ValueError:
        yield "caught!"

g = gen_with_throw()
print(next(g))                      # 1
print(g.throw(ValueError("boom")))  # caught!
# next(g)  # StopIteration: the generator finishes after the except block
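`send()` lets the caller push values into a paused generator. This is the two-way communication that Java iterators and Kotlin sequences do not offer. Here is a minimal sketch of a running average; `running_average` is a hypothetical name. The generator must be primed with `next()` before the first `send()`:

```python
def running_average():
    """Coroutine-style generator: receives numbers via send(), yields the current average."""
    total = 0.0
    count = 0
    average = None
    while True:
        value = yield average  # pauses here; send(x) resumes with value = x
        total += value
        count += 1
        average = total / count

avg = running_average()
next(avg)             # prime: run to the first yield (which yields None)
print(avg.send(10))   # 10.0
print(avg.send(20))   # 15.0
print(avg.send(60))   # 30.0
avg.close()
```

In new code, `async def` coroutines have largely replaced this pattern for concurrency. It remains useful for small stateful consumers like this one.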

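Key differences

Feature | Java | Kotlin Sequence | Python Generator
How to create | Implement Iterator | sequence { } | yield
Lazy evaluation | Yes | Yes | Yes
Delegation | By hand | yieldAll() | yield from
Generator expressions | None | None | (x for x in ...)
Two-way communication | None | None | send() / throw()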

常见陷阱

# 陷阱1:生成器只能消费一次
gen = (x for x in range(3))
print(list(gen))  # [0, 1, 2]
print(list(gen))  # [] — 已耗尽

# 陷阱2:在生成器函数中 return 的值
def gen_with_return():
    yield 1
    yield 2
    return "done"  # 触发 StopIteration,值成为异常的 value

g = gen_with_return()
print(next(g))  # 1
print(next(g))  # 2
try:
    next(g)
except StopIteration as e:
    print(e.value)  # 'done'

# Pitfall 3: don't call len() on a generator
gen = (x for x in range(5))
# len(gen)  # TypeError: object of type 'generator' has no len()
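The return value from pitfall 2 is not just a curiosity: yield from (the delegation mechanism listed in the table above) re-yields everything the inner generator produces and then evaluates to the inner generator's return value. A minimal sketch with two hypothetical generators:

```python
def inner():
    yield 1
    yield 2
    return "inner done"  # becomes the value of the yield from expression

def outer():
    # yield from forwards 1 and 2, then captures inner()'s return value
    result = yield from inner()
    yield f"inner returned: {result}"

values = list(outer())
print(values)  # [1, 2, 'inner returned: inner done']
```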

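When to Use

  • Processing large datasets (files, database results, network streams)
  • When lazy evaluation avoids unnecessary computation
  • Pipeline-style data processing
  • Infinite sequences (these must be produced lazily; a list cannot hold them)
  • Rule of thumb: if the data is large or possibly infinite, use a generator; if it is small and accessed repeatedly, use a list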
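The rule of thumb above is mostly about memory. One quick way to see it is sys.getsizeof, which reports only the container's own size, not the elements it references: a list of 100,000 squares is materialized up front, while the equivalent generator is a small fixed-size object that produces values on demand.

```python
import sys

N = 100_000
squares_list = [n * n for n in range(N)]  # all values stored now
squares_gen = (n * n for n in range(N))   # values produced on demand

list_bytes = sys.getsizeof(squares_list)
gen_bytes = sys.getsizeof(squares_gen)
print(list_bytes > 100 * gen_bytes)       # True: the list object is far larger

# Both give the same total, but afterwards the generator is exhausted
same_total = sum(squares_gen) == sum(squares_list)
print(same_total)                         # True
```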

5.8 async/await: Coroutine Basics

Java/Kotlin Comparison

// Java 21+: Virtual Threads
// Still a thread model under the hood, not coroutines
try (var executor = Executors.newVirtualThreadPerTaskExecutor()) {
    Future<String> future = executor.submit(() -> {
        Thread.sleep(1000);
        return "result";
    });
    String result = future.get();
}

// Java's classic async API is CompletableFuture
CompletableFuture.supplyAsync(() -> fetchData())
    .thenApply(data -> process(data))
    .thenAccept(result -> System.out.println(result));
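// Kotlin: coroutines are a first-class language feature
suspend fun fetchData(): String {
    delay(1000)  // suspending function: does not block the thread
    return "data"
}

// launch: start a coroutine
// async: start a coroutine and return a Deferred
// Flow: a cold stream, similar to Python's async generators

Python Implementation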

import asyncio


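# === 1. Define a coroutine with async def ===
async def hello():
    """Coroutine function"""
    print("Hello")
    await asyncio.sleep(0.1)  # suspends without blocking the event loop
    print("World")

# Note: calling a coroutine function does NOT run it!
coro = hello()
print(type(coro))  # <class 'coroutine'>
coro.close()       # discard it; otherwise Python warns "coroutine 'hello' was never awaited"

# It must be run with asyncio.run() or awaited
# asyncio.run(hello())  # Hello → World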

# === 2. await: waiting for a coroutine ===
async def fetch_data(url: str) -> str:
    """Simulates an async HTTP request"""
    print(f"Fetching {url}...")
    await asyncio.sleep(1)  # simulated network latency
    return f"Data from {url}"

async def process():
    """Sequential execution"""
    data1 = await fetch_data("/api/users")
    data2 = await fetch_data("/api/posts")
    print(data1)
    print(data2)

# asyncio.run(process())
# Total time ~2s (sequential)


# === 3. Concurrent execution: asyncio.gather ===
async def process_concurrent():
    """Concurrent execution"""
    results = await asyncio.gather(
        fetch_data("/api/users"),
        fetch_data("/api/posts"),
        fetch_data("/api/comments"),
    )
    for r in results:
        print(r)

# asyncio.run(process_concurrent())
# Total time ~1s (concurrent)


# === 4. asyncio.create_task: background tasks ===
async def background_job(name: str, seconds: float):
    print(f"{name} started")
    await asyncio.sleep(seconds)
    print(f"{name} finished after {seconds}s")
    return f"{name} result"

async def main():
    # Create background tasks (they start running as soon as this coroutine awaits)
    task1 = asyncio.create_task(background_job("Job-A", 0.5))
    task2 = asyncio.create_task(background_job("Job-B", 0.3))

    # Do other work
    print("Doing other work...")
    await asyncio.sleep(0.1)
    print("Other work done")

    # Wait for the tasks to finish
    result1 = await task1
    result2 = await task2
    print(result1, result2)

# asyncio.run(main())


# === 5. async for: asynchronous iteration ===
async def async_range(n: int):
    """Async generator"""
    for i in range(n):
        await asyncio.sleep(0.1)
        yield i

async def consume_async():
    async for value in async_range(3):
        print(value, end=" ")  # 0 1 2
    print()

# asyncio.run(consume_async())


# === 6. async with: asynchronous context managers ===
class AsyncResource:
    async def __aenter__(self):
        print("Acquiring resource...")
        await asyncio.sleep(0.1)
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        print("Releasing resource...")
        await asyncio.sleep(0.1)
        return False

async def use_resource():
    async with AsyncResource() as resource:
        print("Using resource")
        await asyncio.sleep(0.1)

# asyncio.run(use_resource())


# === 7. Timeout control ===
async def with_timeout():
    try:
        result = await asyncio.wait_for(
            fetch_data("/api/slow"),
            timeout=0.5  # 0.5-second timeout
        )
        print(result)
    except asyncio.TimeoutError:
        print("Request timed out!")

# asyncio.run(with_timeout())


# === 8. Practical example: concurrent downloads ===
async def download(url: str) -> tuple[str, str]:
    """Simulated download"""
    await asyncio.sleep(0.1 * len(url))  # simulated latency, proportional to URL length
    return url, f"content of {url}"

async def download_all(urls: list[str]):
    """Download all URLs concurrently"""
    tasks = [asyncio.create_task(download(url)) for url in urls]
    results = await asyncio.gather(*tasks)
    return dict(results)

async def main_download():
    urls = ["/a", "/bb", "/ccc", "/dddd"]
    loop = asyncio.get_running_loop()  # preferred over get_event_loop() inside a coroutine
    start = loop.time()
    result = await download_all(urls)
    elapsed = loop.time() - start
    print(f"Downloaded {len(result)} URLs in {elapsed:.2f}s")
    print(result)

# asyncio.run(main_download())
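gather in download_all starts every download at once. With hundreds of URLs you usually want a cap, and one common approach is an asyncio.Semaphore. In the sketch below, the limit of 2, the stats dict, and the simulated delay are illustrative assumptions; peak simply records how many downloads ran at the same time.

```python
import asyncio

MAX_CONCURRENT = 2  # hypothetical concurrency limit

async def download(url: str, sem: asyncio.Semaphore, stats: dict) -> tuple[str, str]:
    async with sem:  # at most MAX_CONCURRENT coroutines get past this line
        stats["running"] += 1
        stats["peak"] = max(stats["peak"], stats["running"])
        await asyncio.sleep(0.01)  # simulated network delay
        stats["running"] -= 1
        return url, f"content of {url}"

async def download_all(urls: list[str]) -> tuple[dict[str, str], int]:
    sem = asyncio.Semaphore(MAX_CONCURRENT)
    stats = {"running": 0, "peak": 0}
    results = await asyncio.gather(*(download(u, sem, stats) for u in urls))
    return dict(results), stats["peak"]

pages, peak = asyncio.run(download_all(["/a", "/bb", "/ccc", "/dddd"]))
print(len(pages), peak)  # 4 2
```

The semaphore only limits how many coroutines are inside the async with block at once; gather still creates all of them up front.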

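Key Differences

| Feature | Java Virtual Threads | Kotlin Coroutines | Python async/await |
|---|---|---|---|
| Model | Threads (lightweight) | Coroutines (suspension) | Coroutines (event loop) |
| Blocking operations | Ordinary blocking calls, run on a virtual thread | suspend functions | Require async versions of libraries |
| Concurrency primitives | ExecutorService | launch/async | create_task/gather |
| Cancellation | interrupt | cancel | cancel |
| Ecosystem | Built into the JDK | kotlinx.coroutines | asyncio (standard library) |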
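The table lists cancel for both Kotlin and Python. In asyncio, task.cancel() raises CancelledError inside the task at its current await; the task may run cleanup code but should re-raise so that it actually ends up cancelled. A minimal sketch (worker and its log are hypothetical):

```python
import asyncio

async def worker(log: list[str]) -> None:
    try:
        while True:
            await asyncio.sleep(0.01)  # cancellation is delivered at an await
            log.append("tick")
    except asyncio.CancelledError:
        log.append("cleanup")  # release resources here
        raise                  # re-raise so the task ends up cancelled

async def main() -> tuple[bool, list[str]]:
    log: list[str] = []
    task = asyncio.create_task(worker(log))
    await asyncio.sleep(0.035)  # let it tick a few times
    task.cancel()
    try:
        await task
    except asyncio.CancelledError:
        log.append("main saw cancellation")
    return task.cancelled(), log[-2:]

cancelled, tail = asyncio.run(main())
print(cancelled, tail)  # True ['cleanup', 'main saw cancellation']
```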

Common Pitfalls

# Pitfall 1: calling a blocking function inside a coroutine blocks the whole event loop
import time

async def bad():
    time.sleep(1)  # blocks! the entire event loop stalls
    return "done"

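# Fix: use asyncio.sleep, or move blocking work to a thread pool with run_in_executor
async def good():
    await asyncio.sleep(1)  # non-blocking suspension
    return "done"

# Run blocking code in the default thread pool
def blocking_work():
    time.sleep(1)
    return "done"

async def run_blocking():
    loop = asyncio.get_running_loop()
    result = await loop.run_in_executor(None, blocking_work)
    return result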

# Pitfall 2: forgetting await
async def fetch():
    return "data"

async def bad_call():
    result = fetch()  # no await! result is a coroutine object, not the data
    print(result)     # <coroutine object fetch at 0x...> (plus a "never awaited" RuntimeWarning)

async def good_call():
    result = await fetch()  # correct
    print(result)           # 'data'

# Pitfall 3: asyncio.run() cannot be nested
# Calling asyncio.run() while an event loop is already running raises RuntimeError;
# inside async code, use await, asyncio.gather(), or create_task() instead
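Since Python 3.9, asyncio.to_thread() is a shorter alternative to run_in_executor for pushing a blocking call onto a worker thread. In this sketch, blocking_io and ticker are hypothetical helpers: the ticker keeps awaiting while the blocking call sleeps in another thread, which shows that the event loop is not frozen.

```python
import asyncio
import time

def blocking_io(seconds: float) -> str:
    time.sleep(seconds)  # an ordinary blocking call
    return "done"

async def ticker(ticks: list[int]) -> None:
    for i in range(3):
        await asyncio.sleep(0.01)  # only progresses if the loop is not blocked
        ticks.append(i)

async def main() -> tuple[str, list[int]]:
    ticks: list[int] = []
    result, _ = await asyncio.gather(
        asyncio.to_thread(blocking_io, 0.1),  # runs in a worker thread
        ticker(ticks),
    )
    return result, ticks

result, ticks = asyncio.run(main())
print(result, ticks)  # done [0, 1, 2]
```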

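When to Use

  • I/O-bound work: HTTP requests, database queries, file operations
  • When you need concurrency without the complexity of multithreading
  • WebSockets, long polling, and other real-time communication
  • Don't use asyncio for CPU-bound work; use multiprocessing instead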

5.9 contextlib: Context Managers

Java/Kotlin Comparison

// Java: try-with-resources (the resource must implement AutoCloseable)
try (BufferedReader reader = new BufferedReader(new FileReader("file.txt"))) {
    String line;
    while ((line = reader.readLine()) != null) {
        System.out.println(line);
    }
} // reader.close() is called automatically

// Custom resource
public class MyResource implements AutoCloseable {
    @Override
    public void close() throws Exception {
        System.out.println("Resource closed");
    }
}
// Kotlin: the use{} extension function
File("file.txt").bufferedReader().use { reader ->
    reader.lineSequence().forEach { println(it) }
}
// close() is called automatically

Python Implementation

from contextlib import contextmanager, suppress, redirect_stdout, redirect_stderr
import io
import os  # needed at module level by the suppress examples below


# === 1. with statement basics ===
# Python's with is the counterpart of Java's try-with-resources

# Built-in example
with open("/tmp/demo.txt", "w") as f:
    f.write("hello")

# Equivalent to
f = open("/tmp/demo.txt", "w")
try:
    f.write("hello")
finally:
    f.close()


# === 2. Custom context manager (class-based) ===
class DatabaseConnection:
    def __init__(self, url: str):
        self.url = url
        self.connected = False

    def __enter__(self):
        """Called on entering the with block"""
        print(f"Connecting to {self.url}")
        self.connected = True
        return self  # the return value is bound to the variable after as

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Called on leaving the with block (including when an exception was raised)"""
        print(f"Disconnecting from {self.url}")
        self.connected = False
        # Returning True suppresses the exception; returning False (or None) lets it propagate
        if exc_type is not None:
            print(f"Exception occurred: {exc_val}")
        return False  # do not suppress the exception

    def query(self, sql: str):
        if not self.connected:
            raise RuntimeError("Not connected")
        return f"Result of: {sql}"

with DatabaseConnection("postgresql://localhost/mydb") as db:
    print(db.query("SELECT * FROM users"))
    # Connecting to postgresql://localhost/mydb
    # Result of: SELECT * FROM users
    # Disconnecting from postgresql://localhost/mydb


# === 3. @contextmanager: build a context manager from a generator function ===
@contextmanager
def timer(name: str):
    """Timing context manager"""
    import time
    start = time.perf_counter()
    try:
        yield  # code before yield acts as __enter__, code after it as __exit__
    finally:
        elapsed = time.perf_counter() - start
        print(f"[{name}] {elapsed:.4f}s")

with timer("data processing"):
    import time
    time.sleep(0.1)
# [data processing] 0.1001s (the exact value varies per run)


# === 4. @contextmanager that provides a value ===
@contextmanager
def temporary_file(content: str):
    """Context manager that creates a temporary file"""
    import tempfile
    import os

    fd, path = tempfile.mkstemp()
    try:
        os.write(fd, content.encode())
        os.close(fd)
        yield path  # the file path is bound to the as variable
    finally:
        os.unlink(path)  # always delete the file

with temporary_file("hello world") as path:
    print(f"File at: {path}")
    with open(path) as f:
        print(f.read())
# File at: /tmp/tmpXXXXXX
# hello world


# === 5. Handling exceptions in @contextmanager ===
@contextmanager
def wrap_errors():
    """If the with block raises, re-raise the exception wrapped in a RuntimeError"""
    try:
        yield
    except Exception as e:
        raise RuntimeError(f"Caught in context: {e}") from e

try:
    with wrap_errors():
        raise ValueError("original error")
except RuntimeError as e:
    print(f"Caught: {e}")
    print(f"Caused by: {e.__cause__}")
# Caught: Caught in context: original error
# Caused by: original error


# === 6. contextlib.suppress: suppressing exceptions ===
# The counterpart of an empty catch block in Java, but more concise

# try/except style (like Java's empty catch)
try:
    os.remove("/tmp/nonexistent")
except FileNotFoundError:
    pass

# The suppress style
with suppress(FileNotFoundError):
    os.remove("/tmp/nonexistent")

# Suppress several exception types
with suppress(FileNotFoundError, PermissionError):
    os.remove("/tmp/nonexistent")


# === 7. contextlib.redirect_stdout / redirect_stderr ===
# Redirect standard output (very useful in tests)

f = io.StringIO()
with redirect_stdout(f):
    print("This goes to the buffer")
    print("Not to the console")

output = f.getvalue()
print(f"Captured: {output!r}")
# Captured: 'This goes to the buffer\nNot to the console\n'


# === 8. Nested context managers ===
@contextmanager
def acquire_lock(name: str):
    print(f"Lock {name} acquired")
    try:
        yield
    finally:
        print(f"Lock {name} released")

# Parenthesized form, officially supported since Python 3.10
with (
    acquire_lock("A"),
    acquire_lock("B"),
):
    print("Both locks held")
# Lock A acquired
# Lock B acquired
# Both locks held
# Lock B released
# Lock A released

# ExitStack: manage a dynamic number of contexts
from contextlib import ExitStack

with ExitStack() as stack:
    stack.enter_context(acquire_lock("X"))
    stack.enter_context(acquire_lock("Y"))
    print("X and Y held")
# Lock X acquired
# Lock Y acquired
# X and Y held
# Lock Y released
# Lock X released


# === 9. Practical example: a database transaction context ===
@contextmanager
def transaction(connection):
    """Simulated database transaction"""
    print("BEGIN TRANSACTION")
    try:
        yield connection
        print("COMMIT")
    except Exception:
        print("ROLLBACK")
        raise

class MockConnection:
    def execute(self, sql):
        print(f"  EXEC: {sql}")

conn = MockConnection()
try:
    with transaction(conn):
        conn.execute("INSERT INTO users VALUES (1, 'Alice')")
        conn.execute("INSERT INTO orders VALUES (1, 100)")
        raise ValueError("Oops!")  # triggers the rollback
except ValueError:
    pass
# BEGIN TRANSACTION
#   EXEC: INSERT INTO users VALUES (1, 'Alice')
#   EXEC: INSERT INTO orders VALUES (1, 100)
# ROLLBACK
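ExitStack earns its place when the number of resources is only known at runtime, for example a list of files coming from user input. The sketch below uses a hypothetical read_all helper and temporary files so it runs anywhere. Every file opened through enter_context is closed when the with block exits, even if a later open() fails.

```python
import os
import tempfile
from contextlib import ExitStack

def read_all(paths: list[str]) -> list[str]:
    """Hypothetical helper: open every path, read them all, close them all."""
    with ExitStack() as stack:
        files = [stack.enter_context(open(p)) for p in paths]
        return [f.read() for f in files]

with tempfile.TemporaryDirectory() as tmp:
    paths = []
    for i, text in enumerate(["alpha", "beta", "gamma"]):
        path = os.path.join(tmp, f"part{i}.txt")
        with open(path, "w") as f:
            f.write(text)
        paths.append(path)
    contents = read_all(paths)

print(contents)  # ['alpha', 'beta', 'gamma']
```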

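Key Differences

| Feature | Java try-with-resources | Kotlin use{} | Python with |
|---|---|---|---|
| Interface requirement | AutoCloseable | Closeable | __enter__/__exit__ |
| Exception handling | catch block | try/catch | Return value of __exit__ |
| How to create | Class | Class | Class or @contextmanager |
| Dynamic management | Not supported | Not supported | ExitStack |
| Suppressing exceptions | Empty catch | Empty catch | suppress() |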

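Common Pitfalls

# Pitfall 1: in a @contextmanager function, yield must execute exactly once
@contextmanager
def bad():
    yield 1
    yield 2  # RuntimeError: generator didn't stop (raised when the with block exits)

# Pitfall 2: the return value of __exit__
class SilentError:
    def __enter__(self):
        return self

    def __exit__(self, *args):
        return True  # suppresses ALL exceptions!

with SilentError():
    raise ValueError("This is silently ignored")  # never propagates!
print("No error raised")  # this line runs

# Pitfall 3: with does not create a new scope, so its variables outlive the block
with open("/tmp/demo.txt", "w") as f:
    n_written = f.write("test")
# f still exists here, but the file is already closed
# Don't use f outside the with block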
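A safer version of pitfall 2 inspects exc_type and returns True only for the exceptions it really means to swallow. The hypothetical IgnoreKeyError class below roughly mirrors what suppress(KeyError) does:

```python
class IgnoreKeyError:
    """Hypothetical example: suppress KeyError only, let everything else propagate."""

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # exc_type is None when the block finished normally
        return exc_type is not None and issubclass(exc_type, KeyError)

with IgnoreKeyError():
    {}["missing"]  # KeyError: swallowed
key_error_suppressed = True  # reached because the KeyError was suppressed

try:
    with IgnoreKeyError():
        raise ValueError("not a KeyError")
    value_error_propagated = False
except ValueError:
    value_error_propagated = True

print(key_error_suppressed, value_error_propagated)  # True True
```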

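When to Use

  • Resource management: files, database connections, network connections, locks
  • Temporary state changes: environment variables, working directory, log level
  • Testing: redirect_stdout, temporary files, mocks
  • Transaction management: database transactions, atomic operations
  • Rule of thumb: use a context manager for any "acquire, use, release" pattern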
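As a sketch of the "temporary state change" use case, here is a hypothetical temp_env helper built with @contextmanager. It sets an environment variable for the duration of the block and restores the previous state afterwards, even if the block raises. (For the working-directory case, Python 3.11 added contextlib.chdir.)

```python
import os
from contextlib import contextmanager

@contextmanager
def temp_env(name: str, value: str):
    """Hypothetical helper: set os.environ[name] temporarily."""
    old = os.environ.get(name)  # None if the variable was not set
    os.environ[name] = value
    try:
        yield
    finally:
        # restore the previous state, whether or not the block raised
        if old is None:
            del os.environ[name]
        else:
            os.environ[name] = old

os.environ.pop("DEMO_APP_MODE", None)  # start from a clean state
with temp_env("DEMO_APP_MODE", "test"):
    inside = os.environ["DEMO_APP_MODE"]
after = os.environ.get("DEMO_APP_MODE")

print(inside, after)  # test None
```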

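Chapter Summary

| Concept | JVM equivalent | Python advantage | Key pitfall |
|---|---|---|---|
| First-class functions | Kotlin function types | Functions can carry attributes | Late binding in lambdas |
| Higher-order functions | Stream API | Native zip/enumerate | Iterators can be consumed only once |
| Closures | Kotlin lambdas | Explicit nonlocal declaration | Loop-variable binding |
| Decorators | No direct equivalent | Powerful metaprogramming tool | Three-level nesting |
| functools | Guava/Caffeine | One-line caching with @lru_cache | Cached arguments must be hashable |
| itertools | Limited | Infinite iterators + combinators | groupby requires pre-sorted input |
| Generators | Kotlin Sequence | yield/yield from | Can be consumed only once |
| async/await | Virtual Threads / Kotlin coroutines | Built-in event loop | Blocking calls freeze the loop |
| Context managers | try-with-resources | @contextmanager | Return value of __exit__ |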

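Core principle: functional programming in Python is not about forcing you to write purely functional code; it gives you more powerful tools for expressing intent. Decorators and generators are Python's two most distinctive features, and mastering them marks the dividing line between "writing Python" and "writing good Python".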