想象一下,你有一部功能完整的手机,现在想要给它添加保护壳、贴膜和挂绳。你不需要拆开手机重新制造,只需要在外层添加这些配件。Python装饰器正是这样的"函数配件",它能在不修改原函数代码的情况下,为函数添加新功能,让代码既保持简洁又功能强大。
装饰器的三种境界
- 基础装饰器:函数的包装器
- 带参数装饰器:灵活的配置器
- 类装饰器:面向对象的装饰
实战代码:装饰器的艺术
基础装饰器
import time
import functools
from typing import Callable, Any
def timer(func: Callable) -> Callable:
"""计时装饰器:测量函数执行时间"""
@functools.wraps(func)
def wrapper(*args, **kwargs) -> Any:
start_time = time.time()
result = func(*args, **kwargs)
end_time = time.time()
print(f"函数 {func.__name__} 执行耗时: {end_time - start_time:.4f}秒")
return result
return wrapper
def logger(func: Callable) -> Callable:
"""日志装饰器:记录函数调用信息"""
@functools.wraps(func)
def wrapper(*args, **kwargs) -> Any:
print(f"开始调用: {func.__name__}")
print(f"参数: args={args}, kwargs={kwargs}")
result = func(*args, **kwargs)
print(f"函数 {func.__name__} 执行完成")
print(f"返回值: {result}")
return result
return wrapper
def retry(max_attempts: int = 3, delay: float = 1.0):
"""重试装饰器:在失败时自动重试"""
def decorator(func: Callable) -> Callable:
@functools.wraps(func)
def wrapper(*args, **kwargs) -> Any:
last_exception = None
for attempt in range(max_attempts):
try:
result = func(*args, **kwargs)
if attempt > 0:
print(f"第{attempt + 1}次尝试成功!")
return result
except Exception as e:
last_exception = e
print(f"第{attempt + 1}次尝试失败: {e}")
if attempt < max_attempts - 1:
print(f"等待{delay}秒后重试...")
time.sleep(delay)
print(f"所有{max_attempts}次尝试均失败")
raise last_exception
return wrapper
return decorator
# 使用基础装饰器
@timer
@logger
def calculate_sum(n: int) -> int:
"""计算1到n的和"""
return sum(range(1, n + 1))
@retry(max_attempts=3, delay=0.5)
def unreliable_operation():
"""不可靠的操作,可能随机失败"""
import random
if random.random() < 0.7:
raise ValueError("随机失败!")
return "操作成功"
def demo_basic_decorators():
"""演示基础装饰器的使用"""
print("=== 基础装饰器演示 ===")
# 测试计时和日志装饰器
result = calculate_sum(1000)
print(f"计算结果: {result}")
print("\n" + "="*40 + "\n")
# 测试重试装饰器
try:
outcome = unreliable_operation()
print(f"最终结果: {outcome}")
except Exception as e:
print(f"最终失败: {e}")
# 运行演示
# demo_basic_decorators()
高级装饰器应用
import functools
from typing import Callable, Any, Dict
import time
class Cache:
"""缓存装饰器类"""
def __init__(self, max_size: int = 100, ttl: int = 300):
self.max_size = max_size
self.ttl = ttl # 生存时间(秒)
self._cache: Dict = {}
self._access_times: Dict = {}
def __call__(self, func: Callable) -> Callable:
@functools.wraps(func)
def wrapper(*args, **kwargs) -> Any:
# 生成缓存键
cache_key = self._make_key(func.__name__, args, kwargs)
# 检查缓存是否存在且未过期
if self._is_valid(cache_key):
print(f"缓存命中: {func.__name__}{args}")
self._access_times[cache_key] = time.time()
return self._cache[cache_key]
# 清理过期缓存
self._cleanup()
# 执行函数并缓存结果
result = func(*args, **kwargs)
self._cache[cache_key] = result
self._access_times[cache_key] = time.time()
print(f"缓存新增: {func.__name__}{args} -> {result}")
return result
return wrapper
def _make_key(self, func_name: str, args: tuple, kwargs: dict) -> str:
"""生成缓存键"""
key_parts = [func_name, str(args), str(sorted(kwargs.items()))]
return hash(tuple(key_parts))
def _is_valid(self, cache_key: str) -> bool:
"""检查缓存是否有效"""
if cache_key not in self._cache:
return False
current_time = time.time()
cache_time = self._access_times.get(cache_key, 0)
# 检查是否过期
if current_time - cache_time > self.ttl:
del self._cache[cache_key]
del self._access_times[cache_key]
return False
return True
def _cleanup(self):
"""清理缓存"""
current_time = time.time()
# 移除过期缓存
expired_keys = [
key for key, access_time in self._access_times.items()
if current_time - access_time > self.ttl
]
for key in expired_keys:
del self._cache[key]
del self._access_times[key]
# 如果缓存仍然太大,移除最久未使用的
if len(self._cache) > self.max_size:
sorted_keys = sorted(
self._access_times.items(),
key=lambda x: x[1]
)[:len(self._cache) - self.max_size]
for key, _ in sorted_keys:
del self._cache[key]
del self._access_times[key]
def validate_input(*validators: Callable):
"""输入验证装饰器"""
def decorator(func: Callable) -> Callable:
@functools.wraps(func)
def wrapper(*args, **kwargs) -> Any:
# 验证位置参数
for i, (arg, validator) in enumerate(zip(args, validators)):
if not validator(arg):
raise ValueError(f"参数{i}验证失败: {arg}")
# 这里可以扩展关键字参数的验证
print(f"输入验证通过: {func.__name__}{args}")
return func(*args, **kwargs)
return wrapper
return decorator
def rate_limit(max_calls: int, period: float):
"""限流装饰器:限制函数调用频率"""
def decorator(func: Callable) -> Callable:
calls = []
lock = threading.Lock()
@functools.wraps(func)
def wrapper(*args, **kwargs) -> Any:
nonlocal calls
current_time = time.time()
with lock:
# 移除过期的调用记录
calls = [call_time for call_time in calls
if current_time - call_time < period]
# 检查是否超过限制
if len(calls) >= max_calls:
wait_time = period - (current_time - calls[0])
raise Exception(f"频率限制,请等待{wait_time:.2f}秒")
# 记录本次调用
calls.append(current_time)
return func(*args, **kwargs)
return wrapper
return decorator
# 验证器函数
def is_positive(x):
return x > 0
def is_even(x):
return x % 2 == 0
def is_string(x):
return isinstance(x, str)
# 使用高级装饰器
cache = Cache(max_size=10, ttl=10) # 10秒TTL,最多缓存10个结果
@cache
@validate_input(is_positive, is_even)
def expensive_operation(n: int) -> int:
"""模拟耗时操作"""
print(f"执行耗时计算: {n}")
time.sleep(1) # 模拟计算耗时
return n * n
@rate_limit(max_calls=3, period=5.0)
def api_call(endpoint: str):
"""模拟API调用"""
print(f"调用API: {endpoint}")
return f"响应来自 {endpoint}"
def demo_advanced_decorators():
"""演示高级装饰器的使用"""
print("=== 高级装饰器演示 ===")
# 测试缓存和验证装饰器
print("第一次调用(计算并缓存):")
result1 = expensive_operation(4)
print(f"结果: {result1}")
print("\n第二次调用(使用缓存):")
result2 = expensive_operation(4)
print(f"结果: {result2}")
print("\n不同参数的调用:")
expensive_operation(6)
expensive_operation(8)
print("\n" + "="*40 + "\n")
# 测试限流装饰器
print("限流测试:")
for i in range(5):
try:
result = api_call(f"/api/data/{i}")
print(f"调用 {i+1}: {result}")
except Exception as e:
print(f"调用 {i+1} 失败: {e}")
time.sleep(1)
# 运行演示需要导入threading
import threading
# demo_advanced_decorators()
类装饰器和属性装饰器
import functools
from datetime import datetime
from typing import Any, Callable
class Singleton:
"""单例装饰器类"""
def __init__(self, cls):
self.cls = cls
self.instance = None
def __call__(self, *args, **kwargs):
if self.instance is None:
self.instance = self.cls(*args, **kwargs)
return self.instance
class Deprecated:
"""标记过时装饰器"""
def __init__(self, message: str = ""):
self.message = message
def __call__(self, func: Callable) -> Callable:
@functools.wraps(func)
def wrapper(*args, **kwargs) -> Any:
print(f"警告: 函数 {func.__name__} 已过时!")
if self.message:
print(f"说明: {self.message}")
print(f"建议使用替代方案")
return func(*args, **kwargs)
return wrapper
def property_cache(func: Callable) -> property:
"""属性缓存装饰器"""
cache_name = f"_{func.__name__}_cached"
cache_value_name = f"_{func.__name__}_value"
@property
@functools.wraps(func)
def wrapper(self):
if not hasattr(self, cache_name) or not getattr(self, cache_name):
# 计算并缓存值
value = func(self)
setattr(self, cache_value_name, value)
setattr(self, cache_name, True)
return value
return getattr(self, cache_value_name)
@wrapper.setter
def wrapper(self, value):
setattr(self, cache_value_name, value)
setattr(self, cache_name, True)
@wrapper.deleter
def wrapper(self):
if hasattr(self, cache_name):
setattr(self, cache_name, False)
return wrapper
class Observable:
"""可观察属性装饰器"""
def __init__(self, on_change: Callable = None):
self.on_change = on_change
def __call__(self, func: Callable) -> property:
prop_name = func.__name__
private_name = f"_{prop_name}"
@property
@functools.wraps(func)
def wrapper(self):
return getattr(self, private_name, None)
@wrapper.setter
def wrapper(self, value):
old_value = getattr(self, private_name, None)
setattr(self, private_name, value)
# 触发变更回调
if self.on_change and old_value != value:
self.on_change(self, prop_name, old_value, value)
return wrapper
# 使用类装饰器
@Singleton
class DatabaseConnection:
"""数据库连接类(单例)"""
def __init__(self):
print("创建数据库连接")
self.connected_at = datetime.now()
def query(self, sql: str):
print(f"执行查询: {sql}")
return f"结果: {sql}"
@Deprecated("请使用 new_calculation_method 代替")
def old_calculation(x: int) -> int:
"""过时的计算方法"""
return x * 2
class DataProcessor:
"""数据处理类"""
def __init__(self, data: list):
self._data = data
self._change_count = 0
def _on_data_change(self, obj, prop_name, old_value, new_value):
"""属性变更回调"""
self._change_count += 1
print(f"属性 {prop_name} 从 {old_value} 变更为 {new_value}")
print(f"总计变更次数: {self._change_count}")
@property_cache
def processed_data(self):
"""处理后的数据(计算属性,带缓存)"""
print("计算 processed_data...")
return [x * 2 for x in self._data]
@Observable(on_change=_on_data_change)
def threshold(self):
"""可观察的阈值属性"""
return getattr(self, '_threshold', 0)
def demo_class_decorators():
"""演示类装饰器和属性装饰器"""
print("=== 类装饰器和属性装饰器演示 ===")
# 测试单例装饰器
print("单例模式测试:")
db1 = DatabaseConnection()
db2 = DatabaseConnection()
print(f"是同一个实例: {db1 is db2}")
print("\n" + "="*40 + "\n")
# 测试过时装饰器
print("过时函数测试:")
result = old_calculation(5)
print(f"结果: {result}")
print("\n" + "="*40 + "\n")
# 测试属性装饰器
print("属性装饰器测试:")
processor = DataProcessor([1, 2, 3, 4, 5])
print("第一次访问 processed_data:")
print(processor.processed_data)
print("第二次访问 processed_data(应该使用缓存):")
print(processor.processed_data)
print("清空缓存后访问:")
del processor.processed_data
print(processor.processed_data)
print("\n可观察属性测试:")
processor.threshold = 10
processor.threshold = 20
processor.threshold = 20 # 相同的值不会触发回调
# 运行演示
# demo_class_decorators()
装饰器在Web框架中的应用
from typing import Callable, Dict, Any
import json
class WebFramework:
"""简单的Web框架模拟"""
def __init__(self):
self.routes: Dict[str, Callable] = {}
self.middlewares: List[Callable] = []
def route(self, path: str, methods: list = None):
"""路由装饰器"""
if methods is None:
methods = ['GET']
def decorator(func: Callable) -> Callable:
def handler(request: Dict) -> Dict:
# 执行中间件
for middleware in self.middlewares:
request = middleware(request)
# 执行路由处理函数
response = func(request)
return response
# 注册路由
for method in methods:
route_key = f"{method}:{path}"
self.routes[route_key] = handler
return handler
return decorator
def middleware(self, func: Callable) -> Callable:
"""中间件装饰器"""
self.middlewares.append(func)
return func
def require_auth(self, roles: list = None):
"""认证装饰器"""
def decorator(func: Callable) -> Callable:
@functools.wraps(func)
def wrapper(request: Dict) -> Dict:
user = request.get('user')
if not user:
return {
'status': 401,
'body': '未认证'
}
if roles and user.get('role') not in roles:
return {
'status': 403,
'body': '权限不足'
}
return func(request)
return wrapper
return decorator
def validate_schema(self, schema: Dict[str, Any]):
"""请求验证装饰器"""
def decorator(func: Callable) -> Callable:
@functools.wraps(func)
def wrapper(request: Dict) -> Dict:
body = request.get('body', {})
for field, field_type in schema.items():
if field not in body:
return {
'status': 400,
'body': f'缺少字段: {field}'
}
if not isinstance(body[field], field_type):
return {
'status': 400,
'body': f'字段 {field} 类型错误'
}
return func(request)
return wrapper
return decorator
# 创建Web应用实例
app = WebFramework()
# 定义中间件
@app.middleware
def log_middleware(request: Dict) -> Dict:
"""日志中间件"""
print(f"请求: {request['method']} {request['path']}")
return request
@app.middleware
def parse_json_middleware(request: Dict) -> Dict:
"""JSON解析中间件"""
if 'body' in request and isinstance(request['body'], str):
try:
request['body'] = json.loads(request['body'])
except json.JSONDecodeError:
pass
return request
# 定义路由
@app.route('/api/users', methods=['GET'])
@app.require_auth(roles=['admin', 'user'])
def get_users(request: Dict) -> Dict:
"""获取用户列表"""
return {
'status': 200,
'body': ['user1', 'user2', 'user3']
}
@app.route('/api/users', methods=['POST'])
@app.require_auth(roles=['admin'])
@validate_schema({
'name': str,
'email': str,
'age': int
})
def create_user(request: Dict) -> Dict:
"""创建用户"""
body = request['body']
return {
'status': 201,
'body': f"创建用户: {body['name']}"
}
@app.route('/api/public', methods=['GET'])
def public_info(request: Dict) -> Dict:
"""公开信息(无需认证)"""
return {
'status': 200,
'body': '这是公开信息'
}
def demo_web_framework():
"""演示Web框架中的装饰器使用"""
print("=== Web框架装饰器演示 ===")
# 模拟请求
requests = [
{
'method': 'GET',
'path': '/api/public',
'user': None
},
{
'method': 'GET',
'path': '/api/users',
'user': {'id': 1, 'role': 'user'}
},
{
'method': 'POST',
'path': '/api/users',
'user': {'id': 1, 'role': 'admin'},
'body': '{"name": "张三", "email": "zhang@example.com", "age": 25}'
},
{
'method': 'POST',
'path': '/api/users',
'user': {'id': 1, 'role': 'user'},
'body': '{"name": "李四", "email": "li@example.com"}'
}
]
# 处理请求
for i, request in enumerate(requests):
print(f"\n请求 {i+1}:")
route_key = f"{request['method']}:{request['path']}"
if route_key in app.routes:
handler = app.routes[route_key]
response = handler(request)
print(f"响应: 状态码 {response['status']}, 内容: {response['body']}")
else:
print(f"404: 路由 {route_key} 不存在")
# 运行演示
# demo_web_framework()
装饰器开发原则
- 保持透明性
- 单一职责原则
- 性能考虑
- 错误处理
- 装饰器的核心价值在于:通过包装和组合,无侵入地增强函数功能