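Chapter 5: Functional Programming, the Python Way
Java/Kotlin developers are used to the Stream API and lambda expressions, but functional programming in Python takes a very different route. Python has no dedicated function types (only optional `Callable` type hints) and no checked exceptions, and a lambda body must be a single expression. In return, Python offers first-class functions, decorators, generators, functools, and itertools, none of which has a direct counterpart on the JVM. Decorators are Python's most powerful metaprogramming mechanism. Generators are the underlying primitive for lazy evaluation. itertools ships composable building blocks (infinite iterators, Cartesian products, permutations, combinations) that the Java Stream API does not provide out of the box. This chapter starts from first-class functions and works through the core of Python functional programming step by step.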
5.1 First-Class Functions: Functions Are Objects
Java/Kotlin comparison
// Java: functions are modeled with functional interfaces
import java.util.function.*;
Function<Integer, Integer> doubleIt = x -> x * 2;
Consumer<String> printer = System.out::println;
// Kotlin: function types are part of the language
val doubleIt: (Int) -> Int = { x -> x * 2 }
val greet: (String) -> Unit = { println("Hi, $it") }
val ops = mapOf(
    "double" to { x: Int -> x * 2 },
    "square" to { x: Int -> x * x }
)
Python implementation
def greet(name: str) -> str:
    return f"Hello, {name}"
print(type(greet))
print(greet.__name__)
print(id(greet))
say_hi = greet
print(say_hi("Alice"))
ops = {
    "double": lambda x: x * 2,
    "square": lambda x: x ** 2,
    "negate": lambda x: -x,
}
print(ops["double"](5))
print(ops["square"](3))
pipeline = [
    lambda x: x.strip(),
    lambda x: x.lower(),
    lambda x: x.replace(" ", "-"),
]
text = " Hello World "
result = text
for fn in pipeline:
    result = fn(result)
print(result)
def apply(fn, value):
    return fn(value)
print(apply(lambda x: x * 2, 5))
print(apply(str.upper, "hello"))
def power(exp):
    """Return a function that raises its argument to the power exp"""
    return lambda base: base ** exp
square = power(2)
cube = power(3)
print(square(5))
print(cube(3))
def important_func():
    """An important function"""
    pass
important_func.version = "1.0"
important_func.author = "Alice"
important_func.tags = ["math", "utils"]
print(important_func.version)
print(important_func.tags)
def foo():
    pass
bar = foo
print(foo is bar)
print(foo == bar)
def baz():
    pass
print(foo == baz)
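Key differences
| Feature | Java | Kotlin | Python |
|---|---|---|---|
| Functions are first-class objects | No (emulated via functional interfaces) | Yes | Yes |
| Functions can carry attributes | No | No | Yes |
| Function types | Interfaces such as Function<T,R> | (T) -> R | Callable |
| Lambda restrictions | Must match a functional interface | None | Body must be a single expression |
Common pitfalls
# Late binding: every lambda looks up i when it is called, after the loop has finished
funcs = [lambda: i for i in range(5)]
print([f() for f in funcs])  # [4, 4, 4, 4, 4]
# Fix: bind the current value as a default argument
funcs = [lambda i=i: i for i in range(5)]
print([f() for f in funcs])  # [0, 1, 2, 3, 4]
When to use
- When behavior needs to be passed around as data (strategy pattern, callbacks)
- When building function pipelines or processing chains
- When a function needs attached metadata (an advantage unique to Python)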
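As the table notes, Python's counterpart to `Function<T,R>` and `(T) -> R` is the `Callable` type hint. Hints are not enforced at runtime, but they document what a higher-order function expects and let a type checker such as mypy verify callers. A minimal sketch; `compose`, `add_one` and `double` are illustrative names, not library functions:

```python
from typing import Callable

def compose(f: Callable[[int], int], g: Callable[[int], int]) -> Callable[[int], int]:
    """Return a new function computing f(g(x)), i.e. g runs first."""
    def composed(x: int) -> int:
        return f(g(x))
    return composed

def add_one(x: int) -> int:
    return x + 1

def double(x: int) -> int:
    return x * 2

add_then_double = compose(double, add_one)
print(add_then_double(3))  # (3 + 1) * 2
print(callable(add_then_double))
```

The returned `composed` is an ordinary function object, so it can be stored, passed on, or composed again like any other value.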
5.2 Higher-Order Functions: map, filter, sorted
Java/Kotlin comparison
import java.util.*;
import java.util.stream.*;
List<Integer> nums = List.of(1, 2, 3, 4, 5);
List<Integer> result = nums.stream()
        .filter(n -> n % 2 == 0)
        .map(n -> n * n)
        .collect(Collectors.toList());
List<String> names = List.of("Charlie", "Alice", "Bob");
List<String> sorted = names.stream()
        .sorted(Comparator.comparingInt(String::length))
        .collect(Collectors.toList());
// A Stream can be consumed only once
Stream<Integer> s = nums.stream();
s.filter(n -> n > 2);
s.map(n -> n * 2);  // IllegalStateException: stream has already been operated upon or closed
// Kotlin: collection operations are eager; asSequence() makes them lazy
val nums = listOf(1, 2, 3, 4, 5)
val result = nums.filter { it % 2 == 0 }.map { it * it }
val result2 = nums.asSequence()
    .filter { it % 2 == 0 }
    .map { it * it }
    .toList()
Python implementation
nums = [1, 2, 3, 4, 5]
squares = map(lambda x: x ** 2, nums)
print(type(squares))
print(list(squares))
names = ["alice", "bob", "charlie"]
upper = map(str.upper, names)
print(list(upper))
a = [1, 2, 3]
b = [10, 20, 30]
print(list(map(lambda x, y: x + y, a, b)))
squares = [x ** 2 for x in nums]
print(squares)
evens = filter(lambda x: x % 2 == 0, nums)
print(list(evens))
evens = [x for x in nums if x % 2 == 0]
print(evens)
words = ["banana", "apple", "cherry", "date"]
print(sorted(words))
print(sorted(words, key=len))
print(sorted(words, key=lambda w: w[-1]))
print(sorted(words, reverse=True))
data = [("alice", 30), ("bob", 25), ("charlie", 25)]
print(sorted(data, key=lambda x: (x[1], x[0])))
words_copy = words[:]
words_copy.sort(key=len)
print(words_copy)
from functools import reduce
nums = [1, 2, 3, 4, 5]
total = reduce(lambda acc, x: acc + x, nums)
print(total)
total = reduce(lambda acc, x: acc + x, nums, 0)
print(total)
empty_result = reduce(lambda acc, x: acc + x, [], 0)  # without the initial 0, an empty list raises TypeError
print(empty_result)
print(sum(nums))
nums = [1, 2, 3, 4, 5]
print(any(x > 10 for x in nums))
print(any(x > 3 for x in nums))
print(all(x > 0 for x in nums))
print(all(x > 2 for x in nums))
names = ["alice", "bob", "charlie"]
scores = [95, 87, 92]
print(list(zip(names, scores)))
print(dict(zip(names, scores)))
print(list(zip([1, 2, 3], [10, 20])))  # [(1, 10), (2, 20)]: zip stops at the shortest input
for i, name in enumerate(names):
    print(f"{i}: {name}")
for i, name in enumerate(names, start=1):
    print(f"{i}: {name}")
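Key differences
| Feature | Java Stream | Kotlin | Python |
|---|---|---|---|
| Lazy evaluation | Yes (until a terminal operation) | Eager by default; lazy with asSequence() | map/filter return lazy iterators |
| Reusable | No (single consumption) | Collections are reusable; some Sequences can be iterated only once | Iterators can be consumed once; lists are reusable |
| Multi-argument map | No built-in zip (typically IntStream.range + mapToObj) | zip with a transform lambda | map accepts several iterables natively |
| Parallelism | .parallelStream() | Coroutines/Flow | concurrent.futures |
Common pitfalls
m = map(lambda x: x * 2, [1, 2, 3])
print(list(m))  # [2, 4, 6]
print(list(m))  # [] (the iterator is already exhausted)
# Materialize into a list if the results are needed more than once
result = list(map(lambda x: x * 2, [1, 2, 3]))
print(result)
print(result)
# sorted() returns a new list and leaves the original unchanged
words = ["b", "a", "c"]
sorted(words)  # return value discarded
print(words)  # ['b', 'a', 'c']
words.sort()  # sorts in place and returns None
print(words)  # ['a', 'b', 'c']
When to use
- map/filter: fine for simple transformations, though list comprehensions are usually more Pythonic
- sorted: when the original list must stay unchanged (list.sort() sorts in place)
- any/all: condition checks, clearer than a hand-written loop
- zip/enumerate: parallel iteration and indexed iteration, used constantly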
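The sort keys above are written as lambdas. The standard-library `operator` module provides ready-made key functions for the common cases. The sketch below (the sample records are made up for illustration) shows `itemgetter` as a multi-field sort key, and the Java `filter(...).map(...)` pipeline written as a comprehension and as a lazy generator expression:

```python
from operator import itemgetter

# Hypothetical sample records: (name, age)
people = [("alice", 30), ("bob", 25), ("charlie", 25)]

# itemgetter(1, 0) returns (age, name), the same key as lambda p: (p[1], p[0])
by_age_then_name = sorted(people, key=itemgetter(1, 0))
print(by_age_then_name)

# The Stream filter/map pipeline as a single list comprehension
nums = [1, 2, 3, 4, 5]
even_squares = [n * n for n in nums if n % 2 == 0]
print(even_squares)

# A generator expression is lazy, like a Stream, and sum() consumes it
total = sum(n * n for n in nums if n % 2 == 0)
print(total)
```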
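5.3 Closures and nonlocal
Java/Kotlin comparison
// Java: captured local variables must be effectively final
int factor = 2;
Function<Integer, Integer> multiplier = x -> x * factor;
// Kotlin: a lambda may modify a captured var
var count = 0
val inc: () -> Unit = { count++ }
inc()
inc()
println(count)  // 2
Python implementation
def make_multiplier(factor):
    """Return a multiplying function; factor is captured by the closure"""
    def multiply(x):
        return x * factor
    return multiply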
double = make_multiplier(2)
triple = make_multiplier(3)
print(double(5))
print(triple(5))
print(double.__closure__)
print(double.__closure__[0].cell_contents)
def make_funcs():
    funcs = []
    for i in range(3):
        def f():
            return i  # looked up when f is called, by which time i == 2
        funcs.append(f)
    return funcs
funcs = make_funcs()
print([f() for f in funcs])  # [2, 2, 2]
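def make_funcs_fixed():
    funcs = []
    for i in range(3):
        def f(i=i):  # the default value is evaluated at definition time
            return i
        funcs.append(f)
    return funcs
print([f() for f in make_funcs_fixed()])  # [0, 1, 2]
def make_func(i):
    def f():
        return i
    return f
def make_funcs_fixed2():
    # each make_func call has its own scope, so each f captures a different i
    return [make_func(i) for i in range(3)]
print([f() for f in make_funcs_fixed2()])  # [0, 1, 2]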
def make_counter():
    count = 0
    def increment():
        nonlocal count
        count += 1
        return count
    def get_count():
        return count  # only reading the enclosing variable, so no nonlocal is needed
    return increment, get_count
inc, get = make_counter()
print(inc())
print(inc())
print(inc())
print(get())
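def broken_counter():
    count = 0
    def increment():
        # count += 1  # without nonlocal this raises UnboundLocalError:
        #             # the assignment makes count a local variable of increment()
        pass
    return increment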
def make_accumulator(initial=0):
    total = initial
    def add(value):
        nonlocal total
        total += value
        return total
    def reset():
        nonlocal total
        total = initial
        return total
    return {"add": add, "reset": reset, "get": lambda: total}
acc = make_accumulator(100)
print(acc["add"](10))
print(acc["add"](20))
print(acc["reset"]())
print(acc["add"](5))
cache = {}
def memoize_global(key, value):
    global cache  # strictly only needed to rebind cache; mutating cache[key] works without it
    cache[key] = value
    return value
memoize_global("a", 1)
print(cache)
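Key differences
| Feature | Java | Kotlin | Python |
|---|---|---|---|
| Captured variables must be effectively final | Enforced | Not enforced | Not enforced |
| Modifying captured variables | Not allowed | Allowed | Requires a nonlocal declaration |
| Closure variable binding | Captures values | Captures references | Captures variables (late binding) |
| Inspecting closure variables | Not supported | Not supported | __closure__ attribute |
Common pitfalls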
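x = 10
def foo():
    x = 20  # assignment creates a new local x; the global x is untouched
    return x
print(foo())  # 20
print(x)  # 10
x = 10
def bar():
    global x  # rebinding a module-level name requires global
    x = 20
    return x
x = 10
def outer():
    x = 20
    def inner():
        nonlocal x  # binds to outer's x, not the global x
        x = 30
    inner()
    print(x)  # 30
outer()
print(x)  # 10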
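When to use
- Closures: to encapsulate state without writing a full class (a lightweight alternative)
- nonlocal: when a closure must rebind a variable of the enclosing function
- global: rarely needed; module-level state is usually better managed with a class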
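To see what a closure actually captured, CPython exposes the captured names on the code object (`__code__.co_freevars`) alongside the cells in `__closure__`. A small sketch reusing a counter like `make_counter` above; it also shows that each call of the factory creates an independent cell:

```python
def make_counter():
    count = 0

    def increment():
        nonlocal count  # rebind the enclosing variable instead of creating a local
        count += 1
        return count

    return increment

c1 = make_counter()
c2 = make_counter()  # a fresh cell for count, not shared with c1

print(c1(), c1(), c2())  # c1 and c2 keep separate state
print(c1.__code__.co_freevars)  # names of the captured variables
print(c1.__closure__[0].cell_contents)  # current value stored in the cell
```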
5.4 Decorators: Python's Killer Feature
Java/Kotlin comparison
// Java: AspectJ-style @Around advice (e.g. with Spring AOP)
@Around("execution(* com.example.service.*.*(..))")
public Object log(ProceedingJoinPoint pjp) throws Throwable {
    System.out.println("Before: " + pjp.getSignature());
    Object result = pjp.proceed();
    System.out.println("After");
    return result;
}
// Kotlin: a higher-order function that wraps a call
fun <T> withLogging(fn: () -> T): T {
    println("Before")
    val result = fn()
    println("After")
    return result
}
Python implementation
import functools
import time
def timer(func):
    """Timing decorator"""
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        start = time.perf_counter()
        result = func(*args, **kwargs)
        elapsed = time.perf_counter() - start
        print(f"{func.__name__} took {elapsed:.4f}s")
        return result
    return wrapper
@timer
def slow_add(a, b):
    time.sleep(0.1)
    return a + b
print(slow_add(1, 2))
def retry(max_attempts=3, delay=1.0):
    """Retry decorator factory"""
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            for attempt in range(1, max_attempts + 1):
                try:
                    return func(*args, **kwargs)
                except Exception as e:
                    print(f"Attempt {attempt} failed: {e}")
                    if attempt == max_attempts:
                        raise
                    time.sleep(delay)
        return wrapper
    return decorator
@retry(max_attempts=3, delay=0.1)
def unstable_api():
    import random
    if random.random() < 0.7:
        raise ConnectionError("Network error")
    return "success"
try:
    print(unstable_api())
except ConnectionError:
    # All 3 attempts can fail (probability 0.7 ** 3, about 34%); retry then re-raises
    print("Gave up after 3 attempts")
class CountCalls:
    """Counts how many times the function has been called"""
    def __init__(self, func):
        self.func = func
        self.count = 0
        functools.update_wrapper(self, func)
    def __call__(self, *args, **kwargs):
        self.count += 1
        print(f"Call #{self.count} of {self.func.__name__}")
        return self.func(*args, **kwargs)
@CountCalls
def say_hello(name):
    return f"Hello, {name}"
print(say_hello("Alice"))
print(say_hello("Bob"))
print(say_hello.count)
def bold(func):
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        return f"**{func(*args, **kwargs)}**"
    return wrapper
def italic(func):
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        return f"*{func(*args, **kwargs)}*"
    return wrapper
# Decorators are applied bottom-up: greet = bold(italic(greet))
@bold
@italic
def greet(name):
    return f"Hello, {name}"
print(greet("Alice"))  # ***Hello, Alice***
def bad_decorator(func):
    def wrapper(*args, **kwargs):
        return func(*args, **kwargs)
    return wrapper
@bad_decorator
def my_function():
    """This is my function."""
    pass
print(my_function.__name__)  # wrapper (the original metadata is lost)
print(my_function.__doc__)  # None
def good_decorator(func):
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        return func(*args, **kwargs)
    return wrapper
@good_decorator
def my_function2():
    """This is my function."""
    pass
print(my_function2.__name__)  # my_function2
print(my_function2.__doc__)  # This is my function.
class Validate:
    """Parameter-validation decorator (checks keyword arguments only)"""
    def __init__(self, **rules):
        self.rules = rules
    def __call__(self, func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            # Only kwargs are inspected; arguments passed positionally are not validated
            for param_name, (expected_type, required) in self.rules.items():
                value = kwargs.get(param_name)
                if required and value is None:
                    raise ValueError(f"Missing required parameter: {param_name}")
                if value is not None and not isinstance(value, expected_type):
                    raise TypeError(
                        f"Parameter {param_name} must be {expected_type}, got {type(value)}"
                    )
            return func(*args, **kwargs)
        return wrapper
@Validate(name=(str, True), age=(int, False))
def create_user(name, age=None):
    return {"name": name, "age": age}
print(create_user(name="Alice", age=30))
class MathUtils:
    @staticmethod
    def add(a, b):
        return a + b
print(MathUtils.add(1, 2))
class Database:
    _instance = None
    @classmethod
    def get_instance(cls):
        if cls._instance is None:
            cls._instance = cls()
        return cls._instance
    @classmethod
    def reset(cls):
        cls._instance = None
db1 = Database.get_instance()
db2 = Database.get_instance()
print(db1 is db2)
from dataclasses import dataclass, field
@dataclass
class Point:
    x: float
    y: float
    label: str = "origin"
    def distance_to(self, other):
        return ((self.x - other.x) ** 2 + (self.y - other.y) ** 2) ** 0.5
p1 = Point(3, 4)
p2 = Point(0, 0, "center")
print(p1)
print(p1 == Point(3, 4))
print(p1.distance_to(p2))
@dataclass
class Order:
    items: list[str] = field(default_factory=list)  # a mutable default needs default_factory
    total: float = 0.0
order = Order(items=["apple", "banana"])
print(order)
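Besides copying `__name__` and `__doc__`, `functools.wraps` also sets a `__wrapped__` attribute pointing at the undecorated function, which is handy in tests or when debugging a stack of decorators. A minimal sketch; `shout` is a hypothetical decorator written only for this illustration:

```python
import functools

def shout(func):
    """Hypothetical decorator that upper-cases a string result."""
    @functools.wraps(func)  # copies __name__, __doc__, ... and sets __wrapped__
    def wrapper(*args, **kwargs):
        return func(*args, **kwargs).upper()
    return wrapper

@shout
def greet(name):
    """Return a greeting."""
    return f"Hello, {name}"

print(greet("Alice"))
print(greet.__name__, greet.__doc__)
# __wrapped__ gives access to the original, undecorated function
print(greet.__wrapped__("Alice"))
```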
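Key differences
| Feature | Java annotations | Python decorators |
|---|---|---|
| When it runs | Processed at compile time or read via reflection at runtime (passive) | Executed as soon as the function is defined (active) |
| Code transformation | Requires an annotation processor or AOP | Replaces the function object directly |
| Parameterization | Annotation attributes | Decorator factories |
| Composition | Multiple annotations can be stacked | Can be stacked; applied bottom-up |
| Preserving metadata | Reflection API | @functools.wraps |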
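Common pitfalls
# A decorator must return a callable; returning None replaces foo with None
def broken_decorator(func):
    return None
@broken_decorator
def foo():
    pass
# foo() would now raise TypeError: 'NoneType' object is not callable
# A parameterized decorator needs three levels: factory -> decorator -> wrapper
def with_param(arg):
    def decorator(func):
        def wrapper(*a, **kw):
            return func(*a, **kw)
        return wrapper
    return decorator
# Apply it as @with_param(...); a bare @with_param would receive the function as arg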
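When to use
- Logging, timing, caching, permission checks: cross-cutting concerns
- Registration: @app.route, @pytest.fixture
- Parameter validation: input checks
- Singletons/caching: design patterns
- Principle: a decorator should be transparent and must not change the decorated function's behavior in surprising ways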
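Registration decorators such as `@app.route` follow a simple pattern: the decorator records the function somewhere and returns it unchanged. The sketch below is a framework-free illustration under that assumption; `ROUTES`, `route` and `dispatch` are hypothetical names and not Flask's actual API:

```python
from typing import Callable

# Hypothetical registry mapping a path to its handler function
ROUTES: dict[str, Callable[[], str]] = {}

def route(path: str):
    """Decorator factory: register the decorated function under `path`."""
    def decorator(func: Callable[[], str]) -> Callable[[], str]:
        ROUTES[path] = func
        return func  # return the original function so it stays directly callable
    return decorator

@route("/")
def index() -> str:
    return "home page"

@route("/about")
def about() -> str:
    return "about page"

def dispatch(path: str) -> str:
    handler = ROUTES.get(path)
    return handler() if handler else "404 Not Found"

print(dispatch("/about"))
print(dispatch("/missing"))
```

Because registration happens when the module is imported (decorators run at definition time), the registry is complete before the first `dispatch` call.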
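5.5 functools: partial, lru_cache, singledispatch
Java/Kotlin comparison
// Kotlin has no built-in partial application; fix an argument by wrapping the call in a lambda
fun power(base: Int, exp: Int): Int = Math.pow(base.toDouble(), exp.toDouble()).toInt()
val square = { x: Int -> power(x, 2) }
Python implementation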
import functools
from functools import partial, lru_cache, singledispatch, total_ordering, wraps
def power(base, exp):
    return base ** exp
square = partial(power, exp=2)
cube = partial(power, exp=3)
print(square(5))
print(cube(3))
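two_to_the = partial(power, 2)  # a positional argument fixes base=2, so this computes 2 ** exp
print(two_to_the(10))  # 1024
def fetch(url, timeout, headers):
    # Simulated fetch; no network request is made
    return f"Fetching {url} with timeout={timeout}"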
safe_fetch = partial(fetch, timeout=30, headers={"User-Agent": "MyApp"})
print(safe_fetch("https://example.com"))
print(square.func)
print(square.args)
print(square.keywords)
@lru_cache(maxsize=128)
def fibonacci(n):
    if n < 2:
        return n
    return fibonacci(n - 1) + fibonacci(n - 2)
print(fibonacci(100))
print(fibonacci(100))
print(fibonacci.cache_info())
fibonacci.cache_clear()
@lru_cache(maxsize=None)
def expensive_computation(x):
    print(f"Computing for {x}...")
    return x * x
print(expensive_computation(5))
print(expensive_computation(5))
@lru_cache(maxsize=1000)
def get_user(user_id):
    """Simulated database query"""
    print(f"Querying DB for user {user_id}")
    return {"id": user_id, "name": f"User{user_id}"}
print(get_user(1))
print(get_user(1))
print(get_user(2))
@singledispatch
def process(data):
    raise NotImplementedError(f"Cannot process {type(data)}")
@process.register(int)
def _(data):
    return f"Integer: {data * 2}"
@process.register(str)
def _(data):
    return f"String: {data.upper()}"
@process.register(list)
def _(data):
    return f"List: {len(data)} items"
@process.register(dict)
def _(data):
    return f"Dict: {len(data)} keys"
print(process(42))
print(process("hello"))
print(process([1, 2]))
print(process({"a": 1}))
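# Stacking register() attaches one implementation to several types
# (this also replaces the earlier int implementation)
@process.register(float)
@process.register(int)
def _(data):
    return f"Number: {data}"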
@total_ordering
class Version:
    # Only __eq__ and __lt__ are written; @total_ordering derives <=, > and >=
    def __init__(self, major, minor, patch):
        self.major = major
        self.minor = minor
        self.patch = patch
    def __eq__(self, other):
        return (self.major, self.minor, self.patch) == (other.major, other.minor, other.patch)
    def __lt__(self, other):
        return (self.major, self.minor, self.patch) < (other.major, other.minor, other.patch)
    def __repr__(self):
        return f"Version({self.major}.{self.minor}.{self.patch})"
v1 = Version(1, 0, 0)
v2 = Version(2, 0, 0)
v3 = Version(1, 0, 0)
print(v1 < v2)
print(v2 > v1)
print(v1 <= v2)
print(v1 >= v3)
print(v1 != v2)
def debug(func):
    """Debugging decorator: prints the arguments and the return value"""
    @wraps(func)
    def wrapper(*args, **kwargs):
        args_repr = [repr(a) for a in args]
        kwargs_repr = [f"{k}={v!r}" for k, v in kwargs.items()]
        signature = ", ".join(args_repr + kwargs_repr)
        print(f"Calling {func.__name__}({signature})")
        result = func(*args, **kwargs)
        print(f"{func.__name__} returned {result!r}")
        return result
    return wrapper
@debug
def add(a, b):
    return a + b
print(add(3, 4))
def pipeline(data, *funcs):
    # Apply funcs left to right: reduce feeds each result into the next function
    return functools.reduce(lambda acc, f: f(acc), funcs, data)
result = pipeline(
    " Hello World ",
    str.strip,
    str.lower,
    lambda s: s.replace(" ", "-"),
)
print(result)
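Key differences
| Feature | Java | Python |
|---|---|---|
| Partial application | Written by hand | functools.partial |
| Caching | Guava/Caffeine (needs configuration) | @lru_cache in one line |
| Generic functions | Method overloading (resolved at compile time) | @singledispatch (dispatched at runtime) |
| Generated comparison methods | Comparable interface | @total_ordering |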
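Common pitfalls
# lru_cache uses the arguments as dictionary keys, so they must be hashable
@lru_cache(maxsize=128)
def process(data):
    return data
process((1, 2, 3))  # OK: tuples are hashable
# process([1, 2, 3])  # TypeError: unhashable type: 'list'
# An unbounded cache grows forever and keeps returning stale results if the data changes
@lru_cache(maxsize=None)
def query_db(sql):
    return f"Result of {sql}"
# singledispatch dispatches on the type of the first argument only
@singledispatch
def foo(a, b):
    return "default"
@foo.register(str)
def _(a, b):
    return f"str: {a}, {b}"
print(foo("hello", 42))  # str: hello, 42
print(foo(42, "hello"))  # default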
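When to use
- partial: to reuse a function with some of its arguments fixed
- lru_cache: pure functions, recursion, repeated computation; a good default when arguments are hashable and results cannot go stale
- singledispatch: when processing depends on the argument's type (replaces if-elif-isinstance chains)
- total_ordering: when a custom class needs the full set of comparison operators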
5.6 itertools: Infinite Iterators and Combinatoric Tools
Java/Kotlin comparison
Python implementation
import itertools
from itertools import (
count, cycle, repeat,
chain, islice, takewhile, dropwhile,
product, permutations, combinations, combinations_with_replacement,
groupby,
)
for i in islice(count(1, 2), 5):
    print(i, end=" ")
print()
colors = ["red", "green", "blue"]
for color in islice(cycle(colors), 7):
    print(color, end=" ")
print()
for val in islice(repeat("hello"), 3):
    print(val, end=" ")
print()
print(list(repeat(0, 5)))
list1 = [1, 2, 3]
list2 = [4, 5, 6]
list3 = [7, 8, 9]
print(list(chain(list1, list2, list3)))
nested = [[1, 2], [3, 4], [5, 6]]
print(list(chain.from_iterable(nested)))
print(list(islice(count(), 5, 10)))
print(list(islice("abcdefg", 2, 6, 2)))
print(list(takewhile(lambda x: x < 5, [1, 3, 5, 2, 4, 6])))
print(list(dropwhile(lambda x: x < 5, [1, 3, 5, 2, 4, 6])))
colors = ["red", "blue"]
sizes = ["S", "M", "L"]
print(list(product(colors, sizes)))
print(list(product([1, 2], repeat=3)))
print(list(permutations([1, 2, 3], 2)))
print(list(combinations([1, 2, 3, 4], 2)))
print(list(combinations_with_replacement([1, 2, 3], 2)))
data = [
    ("apple", "fruit"),
    ("banana", "fruit"),
    ("carrot", "vegetable"),
    ("date", "fruit"),
    ("eggplant", "vegetable"),
]
# Unsorted input: "fruit" and "vegetable" each appear as two separate groups
for key, group in groupby(data, key=lambda x: x[1]):
    print(f"{key}: {[item[0] for item in group]}")
# Sort by the same key first to get exactly one group per category
data_sorted = sorted(data, key=lambda x: x[1])
for key, group in groupby(data_sorted, key=lambda x: x[1]):
    print(f"{key}: {[item[0] for item in group]}")
import string
digits = string.digits
passwords_2 = map(
    lambda combo: "".join(combo),
    product(digits, repeat=2)
)
print(list(islice(passwords_2, 5)))
def batched(iterable, n):
    """Split data into batches of n items (Python 3.12+ has itertools.batched, which yields tuples; this is a hand-written version yielding lists)"""
    it = iter(iterable)
    while batch := list(islice(it, n)):
        yield batch
data = range(10)
for batch in batched(data, 3):
    print(batch)
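Key differences
| Feature | Java Stream | Python itertools |
|---|---|---|
| Infinite streams | Stream.iterate / Stream.generate, bounded with limit() | count, cycle, repeat, bounded with islice |
| Cartesian product | Nested flatMap | product |
| Permutations and combinations | Third-party libraries | permutations, combinations |
| Grouping | Collectors.groupingBy | groupby (input must be sorted by the key first) |
| Lazy | Yes | Yes |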
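Common pitfalls
data = [1, 2, 1, 2, 1]
# groupby only groups consecutive equal elements: five groups here
for key, group in groupby(data):
    print(key, list(group))
# After sorting: 1 [1, 1, 1] and 2 [2, 2]
for key, group in groupby(sorted(data)):
    print(key, list(group))
# Each group shares the underlying iterator, so materialize it before moving to the next key
groups = {}
for key, group in groupby(sorted(data)):
    groups[key] = list(group)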
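When to use
- count/cycle/repeat: when you need an infinite sequence (combined with islice)
- product/permutations/combinations: permutations, combinations, brute-force search
- chain: concatenating several iterables
- groupby: grouping and aggregating sorted data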
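The consecutive-only behavior of groupby is a pitfall for grouping, but it is exactly what run-length encoding needs, where each run of equal characters becomes one `(char, count)` pair. A small sketch; the function names are illustrative:

```python
from itertools import groupby

def run_length_encode(text: str) -> list[tuple[str, int]]:
    # groupby merges consecutive equal characters, i.e. exactly one "run"
    return [(char, sum(1 for _ in run)) for char, run in groupby(text)]

def run_length_decode(pairs: list[tuple[str, int]]) -> str:
    return "".join(char * n for char, n in pairs)

encoded = run_length_encode("aaabccdddd")
print(encoded)
print(run_length_decode(encoded))
```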
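5.7 Generators: yield and Lazy Evaluation
Java/Kotlin comparison
// Java: an infinite Stream bounded with limit()
List<Integer> naturals = Stream.iterate(0, n -> n + 1)
        .limit(5)
        .collect(Collectors.toList());  // [0, 1, 2, 3, 4]
// Kotlin: the sequence builder with yield
val naturals = sequence {
    var n = 0
    while (true) {
        yield(n++)
    }
}
naturals.take(5).toList()  // [0, 1, 2, 3, 4]
Python implementation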
def countdown(n):
    """Countdown generator"""
    while n > 0:
        yield n
        n -= 1
gen = countdown(5)
print(type(gen))
print(next(gen))
print(next(gen))
print(next(gen))
for value in gen:  # resumes where the next() calls left off
    print(value, end=" ")
print()
def fibonacci_generator(limit):
    """Lazily generate the Fibonacci numbers below limit"""
    a, b = 0, 1
    while a < limit:
        yield a
        a, b = b, a + b
import sys
gen = fibonacci_generator(10**100)
print(sys.getsizeof(gen))
small_list = list(fibonacci_generator(10000))
print(sys.getsizeof(small_list))
def upper_letters():
    for c in "ABC":
        yield c
def lower_letters():
    for c in "def":
        yield c
def all_letters():
    yield from upper_letters()
    yield from lower_letters()
    yield from "123"
print(list(all_letters()))
def flatten(nested):
    """Flatten an arbitrarily nested list"""
    for item in nested:
        if isinstance(item, (list, tuple)):
            yield from flatten(item)
        else:
            yield item
data = [1, [2, 3], [4, [5, 6]], 7]
print(list(flatten(data)))
nums = range(10)
squares_list = [x ** 2 for x in nums]
print(type(squares_list))
squares_gen = (x ** 2 for x in nums)
print(type(squares_gen))
import sys
list_comp = [x ** 2 for x in range(1000000)]
gen_expr = (x ** 2 for x in range(1000000))
print(sys.getsizeof(list_comp))
print(sys.getsizeof(gen_expr))
print(sum(x ** 2 for x in range(5)))
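def natural_numbers():
    """All natural numbers (infinite)"""
    n = 0
    while True:
        yield n
        n += 1
def primes():
    """Infinite prime generator (incremental Sieve of Eratosthenes over odd numbers)"""
    yield 2
    composites = {}  # maps an upcoming odd composite to the primes that divide it
    n = 3
    while True:
        if n not in composites:
            yield n
            composites[n * n] = [n]
        else:
            for p in composites[n]:
                # next odd multiple of p: n + p is even and never visited, so step by 2 * p
                composites.setdefault(n + 2 * p, []).append(p)
            del composites[n]
        n += 2
from itertools import islice
print(list(islice(primes(), 10)))  # [2, 3, 5, 7, 11, 13, 17, 19, 23, 29]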
def read_lines(filename):
    """Simulates reading a file line by line (filename is ignored)"""
    lines = ["line1\n", "line2\n", "line3\n", ""]
    for line in lines:
        yield line
def strip_lines(lines):
    for line in lines:
        yield line.strip()
def filter_empty(lines):
    for line in lines:
        if line:
            yield line
def upper_lines(lines):
    for line in lines:
        yield line.upper()
# Each stage pulls one line at a time from the previous stage
pipeline = upper_lines(filter_empty(strip_lines(read_lines("dummy"))))
print(list(pipeline))  # ['LINE1', 'LINE2', 'LINE3']
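def gen_demo():
    yield 1
    yield 2
    yield 3
g = gen_demo()
print(next(g))  # 1
print(g.send(None))  # 2 (send(None) is equivalent to next())
g.close()  # raises GeneratorExit inside the generator and finishes it
def gen_with_throw():
    try:
        yield 1
        yield 2
        yield 3
    except ValueError:
        yield "caught!"
g = gen_with_throw()
print(next(g))  # 1
print(g.throw(ValueError))  # caught! (the exception is raised at the paused yield)
print(next(g, "exhausted"))  # the generator has finished; next() without a default would raise StopIteration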
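Key differences
| Feature | Java | Kotlin Sequence | Python generator |
|---|---|---|---|
| How to create | Implement Iterator | sequence { } | yield |
| Lazy evaluation | Yes | Yes | Yes |
| Delegation | Manual | yieldAll() | yield from |
| Generator expressions | None | None | (x for x in ...) |
| Two-way communication | None | None | send() / throw() |
Common pitfalls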
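gen = (x for x in range(3))
print(list(gen))  # [0, 1, 2]
print(list(gen))  # [] (a generator can be consumed only once)
# A generator's return value travels in StopIteration.value
def gen_with_return():
    yield 1
    yield 2
    return "done"
g = gen_with_return()
print(next(g))
print(next(g))
try:
    next(g)
except StopIteration as e:
    print(e.value)  # done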
gen = (x for x in range(5))
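Catching StopIteration by hand is rarely needed: inside another generator, `yield from` re-yields every item of the subgenerator and then evaluates to its return value. A small sketch; `delegator` and the `log` list are illustrative names:

```python
def gen_with_return():
    yield 1
    yield 2
    return "done"

def delegator(log):
    # yield from passes 1 and 2 through, then evaluates to "done"
    result = yield from gen_with_return()
    log.append(result)
    yield 3

log = []
print(list(delegator(log)))
print(log)
```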
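When to use
- When processing large data sets (files, database results, network streams)
- When lazy evaluation avoids unnecessary computation
- For pipeline-style data processing
- For infinite sequences (a generator is required)
- Principle: if the data is large or possibly infinite, use a generator; if it is small and accessed more than once, use a list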
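The send() method listed in the table turns a generator into a small stateful coroutine: each `send(value)` resumes the paused `yield` with `value` and returns the next yielded result. A sketch of a running average; `running_average` is an illustrative name:

```python
def running_average():
    """Receive numbers via send() and yield the mean of everything seen so far."""
    total = 0.0
    count = 0
    average = None
    while True:
        value = yield average  # pause here; send(value) resumes with value
        total += value
        count += 1
        average = total / count

avg = running_average()
next(avg)  # prime the generator: run it to the first yield
print(avg.send(10))
print(avg.send(20))
print(avg.send(30))
```

The priming `next()` call is required because a freshly created generator cannot accept a non-None value until it has reached its first `yield`.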
5.8 async/await Coroutine Basics
Java/Kotlin comparison
// Java 21+: virtual threads
try (var executor = Executors.newVirtualThreadPerTaskExecutor()) {
    Future<String> future = executor.submit(() -> {
        Thread.sleep(1000);
        return "result";
    });
    String result = future.get();
}
// Java: chaining with CompletableFuture
CompletableFuture.supplyAsync(() -> fetchData())
        .thenApply(data -> process(data))
        .thenAccept(result -> System.out.println(result));
// Kotlin: a suspend function
suspend fun fetchData(): String {
    delay(1000)
    return "data"
}
Python implementation
import asyncio
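async def hello():
    """A coroutine function"""
    print("Hello")
    await asyncio.sleep(0.1)
    print("World")
coro = hello()  # calling a coroutine function only creates a coroutine object
print(type(coro))  # <class 'coroutine'>
coro.close()  # discard it unrun to avoid a "coroutine ... was never awaited" warning
asyncio.run(hello())  # asyncio.run() starts an event loop and runs the coroutine to completion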
async def fetch_data(url: str) -> str:
    """Simulated asynchronous HTTP request"""
    print(f"Fetching {url}...")
    await asyncio.sleep(1)
    return f"Data from {url}"
async def process():
    """Sequential: each request finishes before the next starts (about 2 s)"""
    data1 = await fetch_data("/api/users")
    data2 = await fetch_data("/api/posts")
    print(data1)
    print(data2)
async def process_concurrent():
    """Concurrent: gather runs all three requests at once (about 1 s in total)"""
    results = await asyncio.gather(
        fetch_data("/api/users"),
        fetch_data("/api/posts"),
        fetch_data("/api/comments"),
    )
    for r in results:
        print(r)
asyncio.run(process())
asyncio.run(process_concurrent())
async def background_job(name: str, seconds: float):
    print(f"{name} started")
    await asyncio.sleep(seconds)
    print(f"{name} finished after {seconds}s")
    return f"{name} result"
async def main():
    # create_task schedules the coroutines right away; they run while main continues
    task1 = asyncio.create_task(background_job("Job-A", 0.5))
    task2 = asyncio.create_task(background_job("Job-B", 0.3))
    print("Doing other work...")
    await asyncio.sleep(0.1)
    print("Other work done")
    result1 = await task1
    result2 = await task2
    print(result1, result2)
asyncio.run(main())
async def async_range(n: int):
    """Asynchronous generator"""
    for i in range(n):
        await asyncio.sleep(0.1)
        yield i
async def consume_async():
    async for value in async_range(3):
        print(value, end=" ")
    print()
asyncio.run(consume_async())
class AsyncResource:
    async def __aenter__(self):
        print("Acquiring resource...")
        await asyncio.sleep(0.1)
        return self
    async def __aexit__(self, exc_type, exc_val, exc_tb):
        print("Releasing resource...")
        await asyncio.sleep(0.1)
        return False
async def use_resource():
    async with AsyncResource() as resource:
        print("Using resource")
        await asyncio.sleep(0.1)
asyncio.run(use_resource())
async def with_timeout():
    try:
        result = await asyncio.wait_for(
            fetch_data("/api/slow"),
            timeout=0.5
        )
        print(result)
    except asyncio.TimeoutError:
        print("Request timed out!")
asyncio.run(with_timeout())  # fetch_data sleeps 1 s, so this prints "Request timed out!"
async def download(url: str) -> tuple[str, str]:
    """Simulated download"""
    await asyncio.sleep(0.1 * len(url))
    return url, f"content of {url}"
async def download_all(urls: list[str]):
    """Download all URLs concurrently"""
    tasks = [asyncio.create_task(download(url)) for url in urls]
    results = await asyncio.gather(*tasks)
    return dict(results)
async def main_download():
    urls = ["/a", "/bb", "/ccc", "/dddd"]
    loop = asyncio.get_running_loop()
    start = loop.time()
    result = await download_all(urls)
    elapsed = loop.time() - start
    # The total is roughly the slowest download, not the sum of all of them
    print(f"Downloaded {len(result)} URLs in {elapsed:.2f}s")
    print(result)
asyncio.run(main_download())
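Key differences
| Feature | Java virtual threads | Kotlin coroutines | Python async/await |
|---|---|---|---|
| Model | Threads (lightweight) | Coroutines (suspension) | Coroutines (event loop) |
| Blocking operations | Cheap when run on virtual threads | Use suspend functions | Need async-aware libraries |
| Concurrency primitives | ExecutorService | launch/async | create_task/gather |
| Cancellation | interrupt | cancel | cancel |
| Ecosystem | Built into the JDK | kotlinx.coroutines | asyncio (standard library) |
Common pitfalls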
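import time
# time.sleep() blocks the whole event loop: no other task can run meanwhile
async def bad():
    time.sleep(1)
    return "done"
# asyncio.sleep() suspends only this coroutine
async def good():
    await asyncio.sleep(1)
    return "done"
# An unavoidable blocking call can be moved to a thread pool
async def run_blocking():
    loop = asyncio.get_running_loop()
    result = await loop.run_in_executor(None, lambda: time.sleep(1) or "done")
    return result
async def fetch():
    return "data"
# Forgetting await: result is a coroutine object, not "data" (plus a RuntimeWarning)
async def bad_call():
    result = fetch()
    print(result)
async def good_call():
    result = await fetch()
    print(result)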
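When to use
- I/O-bound tasks: HTTP requests, database queries, file operations
- When you need concurrency without the complexity of multithreading
- WebSocket, long polling and other real-time communication
- Do not use asyncio for CPU-bound work; use multiprocessing instead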
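`asyncio.gather` starts every coroutine at once, which can overwhelm a server when the URL list is long. A common way to cap concurrency is an `asyncio.Semaphore`. A sketch under stated assumptions: `fetch` is a hypothetical stand-in for a real HTTP call, and the `stats` dictionary exists only to record the peak number of simultaneous requests:

```python
import asyncio

async def fetch(url: str, sem: asyncio.Semaphore, stats: dict) -> str:
    """Hypothetical fetch; the semaphore limits how many run at the same time."""
    async with sem:
        stats["running"] += 1
        stats["peak"] = max(stats["peak"], stats["running"])
        await asyncio.sleep(0.05)  # stand-in for real network I/O
        stats["running"] -= 1
        return f"data from {url}"

async def fetch_all(urls: list[str], limit: int) -> tuple[list[str], int]:
    sem = asyncio.Semaphore(limit)  # created inside the running event loop
    stats = {"running": 0, "peak": 0}
    # gather still starts every coroutine, but at most `limit` get past the semaphore
    results = await asyncio.gather(*(fetch(u, sem, stats) for u in urls))
    return list(results), stats["peak"]

urls = [f"/item/{i}" for i in range(10)]
results, peak = asyncio.run(fetch_all(urls, limit=3))
print(len(results), peak)
```

gather returns results in the order of its arguments, regardless of which request finishes first.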
5.9 contextlib: Context Managers
Java/Kotlin comparison
// Java: try-with-resources
try (BufferedReader reader = new BufferedReader(new FileReader("file.txt"))) {
    String line;
    while ((line = reader.readLine()) != null) {
        System.out.println(line);
    }
}
// Any AutoCloseable works with try-with-resources
public class MyResource implements AutoCloseable {
    @Override
    public void close() throws Exception {
        System.out.println("Resource closed");
    }
}
// Kotlin: use {}
File("file.txt").bufferedReader().use { reader ->
    reader.lineSequence().forEach { println(it) }
}
Python implementation
from contextlib import contextmanager, suppress, redirect_stdout, redirect_stderr
import io
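# with guarantees that f.close() runs, even if an exception is raised
with open("/tmp/demo.txt", "w") as f:
    f.write("hello")
# ...which is roughly equivalent to:
f = open("/tmp/demo.txt", "w")
try:
    f.write("hello")
finally:
    f.close()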
class DatabaseConnection:
    def __init__(self, url: str):
        self.url = url
        self.connected = False
    def __enter__(self):
        """Called when entering the with block"""
        print(f"Connecting to {self.url}")
        self.connected = True
        return self
    def __exit__(self, exc_type, exc_val, exc_tb):
        """Called when leaving the with block, including when an exception was raised"""
        print(f"Disconnecting from {self.url}")
        self.connected = False
        if exc_type is not None:
            print(f"Exception occurred: {exc_val}")
        return False  # False: do not suppress the exception
    def query(self, sql: str):
        if not self.connected:
            raise RuntimeError("Not connected")
        return f"Result of: {sql}"
with DatabaseConnection("postgresql://localhost/mydb") as db:
    print(db.query("SELECT * FROM users"))
@contextmanager
def timer(name: str):
    """Timing context manager"""
    import time
    start = time.perf_counter()
    try:
        yield
    finally:
        elapsed = time.perf_counter() - start
        print(f"[{name}] {elapsed:.4f}s")
with timer("data processing"):
    import time
    time.sleep(0.1)
@contextmanager
def temporary_file(content: str):
    """Context manager that creates a temporary file and deletes it afterwards"""
    import tempfile
    import os
    fd, path = tempfile.mkstemp()
    try:
        os.write(fd, content.encode())
        os.close(fd)
        yield path
    finally:
        os.unlink(path)
with temporary_file("hello world") as path:
    print(f"File at: {path}")
    with open(path) as f:
        print(f.read())
@contextmanager
def assert_no_exception():
    """If the with block raises, re-raise the error wrapped in a RuntimeError"""
    try:
        yield
    except Exception as e:
        raise RuntimeError(f"Caught in context: {e}") from e
try:
    with assert_no_exception():
        raise ValueError("original error")
except RuntimeError as e:
    print(f"Caught: {e}")
    print(f"Caused by: {e.__cause__}")
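import os
# The long way: catch the exception and ignore it
try:
    os.remove("/tmp/nonexistent")
except FileNotFoundError:
    pass
# The same with suppress()
with suppress(FileNotFoundError):
    os.remove("/tmp/nonexistent")
# suppress() accepts several exception types
with suppress(FileNotFoundError, PermissionError):
    os.remove("/tmp/nonexistent")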
f = io.StringIO()
with redirect_stdout(f):
    print("This goes to the buffer")
    print("Not to the console")
output = f.getvalue()
print(f"Captured: {output!r}")
@contextmanager
def acquire_lock(name: str):
    print(f"Lock {name} acquired")
    try:
        yield
    finally:
        print(f"Lock {name} released")
# Parenthesized multi-line with statements require Python 3.10+
with (
    acquire_lock("A"),
    acquire_lock("B"),
):
    print("Both locks held")
from contextlib import ExitStack
# ExitStack manages a set of context managers decided at runtime
with ExitStack() as stack:
    stack.enter_context(acquire_lock("X"))
    stack.enter_context(acquire_lock("Y"))
    print("X and Y held")
@contextmanager
def transaction(connection):
    """Simulated database transaction"""
    print("BEGIN TRANSACTION")
    try:
        yield connection
        print("COMMIT")
    except Exception:
        print("ROLLBACK")
        raise
class MockConnection:
    def execute(self, sql):
        print(f" EXEC: {sql}")
conn = MockConnection()
try:
    with transaction(conn):
        conn.execute("INSERT INTO users VALUES (1, 'Alice')")
        conn.execute("INSERT INTO orders VALUES (1, 100)")
        raise ValueError("Oops!")
except ValueError:
    pass
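Key differences
| Feature | Java try-with-resources | Kotlin use {} | Python with |
|---|---|---|---|
| Required interface | AutoCloseable | Closeable | __enter__/__exit__ |
| Exception handling | catch block | try/catch | Return value of __exit__ |
| How to create | Class | Class | Class or @contextmanager |
| Dynamic management | Not supported | Not supported | ExitStack |
| Suppressing exceptions | Empty catch | Empty catch | suppress() |
Common pitfalls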
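# A @contextmanager generator must yield exactly once
@contextmanager
def bad():
    yield 1
    yield 2  # leaving the with block raises RuntimeError: generator didn't stop
# An __exit__ that returns True swallows every exception
class SilentError:
    def __enter__(self):
        return self
    def __exit__(self, *args):
        return True
with SilentError():
    raise ValueError("This is silently ignored")
print("No error raised")
# f.write() returns the number of characters written, not the content
with open("/tmp/demo.txt", "w") as f:
    content = f.write("test")  # content == 4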
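When to use
- Resource management: files, database connections, network connections, locks
- Temporary state changes: environment variables, working directory, log level
- Testing: redirect_stdout, temporary files, mocks
- Transaction management: database transactions, atomic operations
- Principle: use a context manager for any "acquire, use, release" pattern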
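For the "temporary state change" case, the pattern is to save the old state before `yield` and restore it in `finally`, so the restore also happens when the block raises. A minimal sketch for an environment variable; `temp_env` and the variable name `DEMO_APP_MODE` are hypothetical:

```python
import os
from contextlib import contextmanager

@contextmanager
def temp_env(name: str, value: str):
    """Set an environment variable inside a with block, then restore the old state."""
    old = os.environ.get(name)  # None means the variable was not set
    os.environ[name] = value
    try:
        yield
    finally:
        if old is None:
            os.environ.pop(name, None)  # it did not exist before, so remove it
        else:
            os.environ[name] = old

with temp_env("DEMO_APP_MODE", "test"):
    print(os.environ["DEMO_APP_MODE"])
print(os.environ.get("DEMO_APP_MODE"))
```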
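Chapter summary
| Concept | JVM equivalent | Python advantage | Key pitfall |
|---|---|---|---|
| First-class functions | Kotlin function types | Functions can carry attributes | Late binding in lambdas |
| Higher-order functions | Stream API | Built-in zip/enumerate | Iterators can be consumed only once |
| Closures | Kotlin lambdas | Explicit nonlocal declaration | Loop-variable binding |
| Decorators | No direct equivalent | A powerful metaprogramming tool | Three-level nesting for parameterized decorators |
| functools | Guava/Caffeine | One-line caching with @lru_cache | Cached arguments must be hashable |
| itertools | Limited | Infinite iterators plus combinatoric tools | groupby requires pre-sorted input |
| Generators | Kotlin Sequence | yield/yield from | Single consumption |
| async/await | Virtual threads/Kotlin coroutines | Built-in event loop | Blocking calls stall the event loop |
| Context managers | try-with-resources | @contextmanager | The return value of __exit__ |
Core principle: functional programming in Python is not about writing purely functional code; it gives you more powerful tools for expressing intent. Decorators and generators are Python's two most distinctive features, and mastering them is the dividing line between "can write Python" and "writes good Python".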