Python装饰器:优雅的函数增强术

33 阅读9分钟

想象一下,你有一部功能完整的手机,现在想要给它添加保护壳、贴膜和挂绳。你不需要拆开手机重新制造,只需要在外层添加这些配件。Python装饰器正是这样的"函数配件",它能在不修改原函数代码的情况下,为函数添加新功能,让代码既保持简洁又功能强大。

装饰器的三种境界

  1. 基础装饰器:函数的包装器
  2. 带参数装饰器:灵活的配置器
  3. 类装饰器:面向对象的装饰

实战代码:装饰器的艺术

基础装饰器

import time
import functools
from typing import Callable, Any

def timer(func: Callable) -> Callable:
    """计时装饰器:测量函数执行时间"""
    @functools.wraps(func)
    def wrapper(*args, **kwargs) -> Any:
        start_time = time.time()
        result = func(*args, **kwargs)
        end_time = time.time()
        print(f"函数 {func.__name__} 执行耗时: {end_time - start_time:.4f}秒")
        return result
    return wrapper

def logger(func: Callable) -> Callable:
    """日志装饰器:记录函数调用信息"""
    @functools.wraps(func)
    def wrapper(*args, **kwargs) -> Any:
        print(f"开始调用: {func.__name__}")
        print(f"参数: args={args}, kwargs={kwargs}")
        
        result = func(*args, **kwargs)
        
        print(f"函数 {func.__name__} 执行完成")
        print(f"返回值: {result}")
        return result
    return wrapper

def retry(max_attempts: int = 3, delay: float = 1.0):
    """重试装饰器:在失败时自动重试"""
    def decorator(func: Callable) -> Callable:
        @functools.wraps(func)
        def wrapper(*args, **kwargs) -> Any:
            last_exception = None
            
            for attempt in range(max_attempts):
                try:
                    result = func(*args, **kwargs)
                    
                    if attempt > 0:
                        print(f"第{attempt + 1}次尝试成功!")
                    
                    return result
                    
                except Exception as e:
                    last_exception = e
                    print(f"第{attempt + 1}次尝试失败: {e}")
                    
                    if attempt < max_attempts - 1:
                        print(f"等待{delay}秒后重试...")
                        time.sleep(delay)
            
            print(f"所有{max_attempts}次尝试均失败")
            raise last_exception
        return wrapper
    return decorator

# 使用基础装饰器
@timer
@logger
def calculate_sum(n: int) -> int:
    """计算1到n的和"""
    return sum(range(1, n + 1))

@retry(max_attempts=3, delay=0.5)
def unreliable_operation():
    """不可靠的操作,可能随机失败"""
    import random
    if random.random() < 0.7:
        raise ValueError("随机失败!")
    return "操作成功"

def demo_basic_decorators():
    """演示基础装饰器的使用"""
    print("=== 基础装饰器演示 ===")
    
    # 测试计时和日志装饰器
    result = calculate_sum(1000)
    print(f"计算结果: {result}")
    
    print("\n" + "="*40 + "\n")
    
    # 测试重试装饰器
    try:
        outcome = unreliable_operation()
        print(f"最终结果: {outcome}")
    except Exception as e:
        print(f"最终失败: {e}")

# 运行演示
# demo_basic_decorators()

高级装饰器应用

import functools
from typing import Callable, Any, Dict
import time

class Cache:
    """缓存装饰器类"""
    
    def __init__(self, max_size: int = 100, ttl: int = 300):
        self.max_size = max_size
        self.ttl = ttl  # 生存时间(秒)
        self._cache: Dict = {}
        self._access_times: Dict = {}
    
    def __call__(self, func: Callable) -> Callable:
        @functools.wraps(func)
        def wrapper(*args, **kwargs) -> Any:
            # 生成缓存键
            cache_key = self._make_key(func.__name__, args, kwargs)
            
            # 检查缓存是否存在且未过期
            if self._is_valid(cache_key):
                print(f"缓存命中: {func.__name__}{args}")
                self._access_times[cache_key] = time.time()
                return self._cache[cache_key]
            
            # 清理过期缓存
            self._cleanup()
            
            # 执行函数并缓存结果
            result = func(*args, **kwargs)
            self._cache[cache_key] = result
            self._access_times[cache_key] = time.time()
            
            print(f"缓存新增: {func.__name__}{args} -> {result}")
            return result
        
        return wrapper
    
    def _make_key(self, func_name: str, args: tuple, kwargs: dict) -> str:
        """生成缓存键"""
        key_parts = [func_name, str(args), str(sorted(kwargs.items()))]
        return hash(tuple(key_parts))
    
    def _is_valid(self, cache_key: str) -> bool:
        """检查缓存是否有效"""
        if cache_key not in self._cache:
            return False
        
        current_time = time.time()
        cache_time = self._access_times.get(cache_key, 0)
        
        # 检查是否过期
        if current_time - cache_time > self.ttl:
            del self._cache[cache_key]
            del self._access_times[cache_key]
            return False
        
        return True
    
    def _cleanup(self):
        """清理缓存"""
        current_time = time.time()
        
        # 移除过期缓存
        expired_keys = [
            key for key, access_time in self._access_times.items()
            if current_time - access_time > self.ttl
        ]
        
        for key in expired_keys:
            del self._cache[key]
            del self._access_times[key]
        
        # 如果缓存仍然太大,移除最久未使用的
        if len(self._cache) > self.max_size:
            sorted_keys = sorted(
                self._access_times.items(),
                key=lambda x: x[1]
            )[:len(self._cache) - self.max_size]
            
            for key, _ in sorted_keys:
                del self._cache[key]
                del self._access_times[key]

def validate_input(*validators: Callable):
    """输入验证装饰器"""
    def decorator(func: Callable) -> Callable:
        @functools.wraps(func)
        def wrapper(*args, **kwargs) -> Any:
            # 验证位置参数
            for i, (arg, validator) in enumerate(zip(args, validators)):
                if not validator(arg):
                    raise ValueError(f"参数{i}验证失败: {arg}")
            
            # 这里可以扩展关键字参数的验证
            print(f"输入验证通过: {func.__name__}{args}")
            return func(*args, **kwargs)
        return wrapper
    return decorator

def rate_limit(max_calls: int, period: float):
    """限流装饰器:限制函数调用频率"""
    def decorator(func: Callable) -> Callable:
        calls = []
        lock = threading.Lock()
        
        @functools.wraps(func)
        def wrapper(*args, **kwargs) -> Any:
            nonlocal calls
            current_time = time.time()
            
            with lock:
                # 移除过期的调用记录
                calls = [call_time for call_time in calls 
                        if current_time - call_time < period]
                
                # 检查是否超过限制
                if len(calls) >= max_calls:
                    wait_time = period - (current_time - calls[0])
                    raise Exception(f"频率限制,请等待{wait_time:.2f}秒")
                
                # 记录本次调用
                calls.append(current_time)
            
            return func(*args, **kwargs)
        return wrapper
    return decorator

# 验证器函数
def is_positive(x):
    return x > 0

def is_even(x):
    return x % 2 == 0

def is_string(x):
    return isinstance(x, str)

# 使用高级装饰器
cache = Cache(max_size=10, ttl=10)  # 10秒TTL,最多缓存10个结果

@cache
@validate_input(is_positive, is_even)
def expensive_operation(n: int) -> int:
    """模拟耗时操作"""
    print(f"执行耗时计算: {n}")
    time.sleep(1)  # 模拟计算耗时
    return n * n

@rate_limit(max_calls=3, period=5.0)
def api_call(endpoint: str):
    """模拟API调用"""
    print(f"调用API: {endpoint}")
    return f"响应来自 {endpoint}"

def demo_advanced_decorators():
    """演示高级装饰器的使用"""
    print("=== 高级装饰器演示 ===")
    
    # 测试缓存和验证装饰器
    print("第一次调用(计算并缓存):")
    result1 = expensive_operation(4)
    print(f"结果: {result1}")
    
    print("\n第二次调用(使用缓存):")
    result2 = expensive_operation(4)
    print(f"结果: {result2}")
    
    print("\n不同参数的调用:")
    expensive_operation(6)
    expensive_operation(8)
    
    print("\n" + "="*40 + "\n")
    
    # 测试限流装饰器
    print("限流测试:")
    for i in range(5):
        try:
            result = api_call(f"/api/data/{i}")
            print(f"调用 {i+1}: {result}")
        except Exception as e:
            print(f"调用 {i+1} 失败: {e}")
        time.sleep(1)

# 运行演示需要导入threading
import threading
# demo_advanced_decorators()

类装饰器和属性装饰器

import functools
from datetime import datetime
from typing import Any, Callable

class Singleton:
    """单例装饰器类"""
    
    def __init__(self, cls):
        self.cls = cls
        self.instance = None
    
    def __call__(self, *args, **kwargs):
        if self.instance is None:
            self.instance = self.cls(*args, **kwargs)
        return self.instance

class Deprecated:
    """标记过时装饰器"""
    
    def __init__(self, message: str = ""):
        self.message = message
    
    def __call__(self, func: Callable) -> Callable:
        @functools.wraps(func)
        def wrapper(*args, **kwargs) -> Any:
            print(f"警告: 函数 {func.__name__} 已过时!")
            if self.message:
                print(f"说明: {self.message}")
            print(f"建议使用替代方案")
            return func(*args, **kwargs)
        return wrapper

def property_cache(func: Callable) -> property:
    """属性缓存装饰器"""
    cache_name = f"_{func.__name__}_cached"
    cache_value_name = f"_{func.__name__}_value"
    
    @property
    @functools.wraps(func)
    def wrapper(self):
        if not hasattr(self, cache_name) or not getattr(self, cache_name):
            # 计算并缓存值
            value = func(self)
            setattr(self, cache_value_name, value)
            setattr(self, cache_name, True)
            return value
        return getattr(self, cache_value_name)
    
    @wrapper.setter
    def wrapper(self, value):
        setattr(self, cache_value_name, value)
        setattr(self, cache_name, True)
    
    @wrapper.deleter
    def wrapper(self):
        if hasattr(self, cache_name):
            setattr(self, cache_name, False)
    
    return wrapper

class Observable:
    """可观察属性装饰器"""
    
    def __init__(self, on_change: Callable = None):
        self.on_change = on_change
    
    def __call__(self, func: Callable) -> property:
        prop_name = func.__name__
        private_name = f"_{prop_name}"
        
        @property
        @functools.wraps(func)
        def wrapper(self):
            return getattr(self, private_name, None)
        
        @wrapper.setter
        def wrapper(self, value):
            old_value = getattr(self, private_name, None)
            setattr(self, private_name, value)
            
            # 触发变更回调
            if self.on_change and old_value != value:
                self.on_change(self, prop_name, old_value, value)
        
        return wrapper

# 使用类装饰器
@Singleton
class DatabaseConnection:
    """数据库连接类(单例)"""
    
    def __init__(self):
        print("创建数据库连接")
        self.connected_at = datetime.now()
    
    def query(self, sql: str):
        print(f"执行查询: {sql}")
        return f"结果: {sql}"

@Deprecated("请使用 new_calculation_method 代替")
def old_calculation(x: int) -> int:
    """过时的计算方法"""
    return x * 2

class DataProcessor:
    """数据处理类"""
    
    def __init__(self, data: list):
        self._data = data
        self._change_count = 0
    
    def _on_data_change(self, obj, prop_name, old_value, new_value):
        """属性变更回调"""
        self._change_count += 1
        print(f"属性 {prop_name}{old_value} 变更为 {new_value}")
        print(f"总计变更次数: {self._change_count}")
    
    @property_cache
    def processed_data(self):
        """处理后的数据(计算属性,带缓存)"""
        print("计算 processed_data...")
        return [x * 2 for x in self._data]
    
    @Observable(on_change=_on_data_change)
    def threshold(self):
        """可观察的阈值属性"""
        return getattr(self, '_threshold', 0)

def demo_class_decorators():
    """演示类装饰器和属性装饰器"""
    print("=== 类装饰器和属性装饰器演示 ===")
    
    # 测试单例装饰器
    print("单例模式测试:")
    db1 = DatabaseConnection()
    db2 = DatabaseConnection()
    print(f"是同一个实例: {db1 is db2}")
    
    print("\n" + "="*40 + "\n")
    
    # 测试过时装饰器
    print("过时函数测试:")
    result = old_calculation(5)
    print(f"结果: {result}")
    
    print("\n" + "="*40 + "\n")
    
    # 测试属性装饰器
    print("属性装饰器测试:")
    processor = DataProcessor([1, 2, 3, 4, 5])
    
    print("第一次访问 processed_data:")
    print(processor.processed_data)
    
    print("第二次访问 processed_data(应该使用缓存):")
    print(processor.processed_data)
    
    print("清空缓存后访问:")
    del processor.processed_data
    print(processor.processed_data)
    
    print("\n可观察属性测试:")
    processor.threshold = 10
    processor.threshold = 20
    processor.threshold = 20  # 相同的值不会触发回调

# 运行演示
# demo_class_decorators()

装饰器在Web框架中的应用

from typing import Callable, Dict, Any
import json

class WebFramework:
    """简单的Web框架模拟"""
    
    def __init__(self):
        self.routes: Dict[str, Callable] = {}
        self.middlewares: List[Callable] = []
    
    def route(self, path: str, methods: list = None):
        """路由装饰器"""
        if methods is None:
            methods = ['GET']
        
        def decorator(func: Callable) -> Callable:
            def handler(request: Dict) -> Dict:
                # 执行中间件
                for middleware in self.middlewares:
                    request = middleware(request)
                
                # 执行路由处理函数
                response = func(request)
                return response
            
            # 注册路由
            for method in methods:
                route_key = f"{method}:{path}"
                self.routes[route_key] = handler
            
            return handler
        return decorator
    
    def middleware(self, func: Callable) -> Callable:
        """中间件装饰器"""
        self.middlewares.append(func)
        return func
    
    def require_auth(self, roles: list = None):
        """认证装饰器"""
        def decorator(func: Callable) -> Callable:
            @functools.wraps(func)
            def wrapper(request: Dict) -> Dict:
                user = request.get('user')
                
                if not user:
                    return {
                        'status': 401,
                        'body': '未认证'
                    }
                
                if roles and user.get('role') not in roles:
                    return {
                        'status': 403,
                        'body': '权限不足'
                    }
                
                return func(request)
            return wrapper
        return decorator
    
    def validate_schema(self, schema: Dict[str, Any]):
        """请求验证装饰器"""
        def decorator(func: Callable) -> Callable:
            @functools.wraps(func)
            def wrapper(request: Dict) -> Dict:
                body = request.get('body', {})
                
                for field, field_type in schema.items():
                    if field not in body:
                        return {
                            'status': 400,
                            'body': f'缺少字段: {field}'
                        }
                    
                    if not isinstance(body[field], field_type):
                        return {
                            'status': 400,
                            'body': f'字段 {field} 类型错误'
                        }
                
                return func(request)
            return wrapper
        return decorator

# 创建Web应用实例
app = WebFramework()

# 定义中间件
@app.middleware
def log_middleware(request: Dict) -> Dict:
    """日志中间件"""
    print(f"请求: {request['method']} {request['path']}")
    return request

@app.middleware
def parse_json_middleware(request: Dict) -> Dict:
    """JSON解析中间件"""
    if 'body' in request and isinstance(request['body'], str):
        try:
            request['body'] = json.loads(request['body'])
        except json.JSONDecodeError:
            pass
    return request

# 定义路由
@app.route('/api/users', methods=['GET'])
@app.require_auth(roles=['admin', 'user'])
def get_users(request: Dict) -> Dict:
    """获取用户列表"""
    return {
        'status': 200,
        'body': ['user1', 'user2', 'user3']
    }

@app.route('/api/users', methods=['POST'])
@app.require_auth(roles=['admin'])
@validate_schema({
    'name': str,
    'email': str,
    'age': int
})
def create_user(request: Dict) -> Dict:
    """创建用户"""
    body = request['body']
    return {
        'status': 201,
        'body': f"创建用户: {body['name']}"
    }

@app.route('/api/public', methods=['GET'])
def public_info(request: Dict) -> Dict:
    """公开信息(无需认证)"""
    return {
        'status': 200,
        'body': '这是公开信息'
    }

def demo_web_framework():
    """演示Web框架中的装饰器使用"""
    print("=== Web框架装饰器演示 ===")
    
    # 模拟请求
    requests = [
        {
            'method': 'GET',
            'path': '/api/public',
            'user': None
        },
        {
            'method': 'GET', 
            'path': '/api/users',
            'user': {'id': 1, 'role': 'user'}
        },
        {
            'method': 'POST',
            'path': '/api/users',
            'user': {'id': 1, 'role': 'admin'},
            'body': '{"name": "张三", "email": "zhang@example.com", "age": 25}'
        },
        {
            'method': 'POST',
            'path': '/api/users', 
            'user': {'id': 1, 'role': 'user'},
            'body': '{"name": "李四", "email": "li@example.com"}'
        }
    ]
    
    # 处理请求
    for i, request in enumerate(requests):
        print(f"\n请求 {i+1}:")
        route_key = f"{request['method']}:{request['path']}"
        
        if route_key in app.routes:
            handler = app.routes[route_key]
            response = handler(request)
            print(f"响应: 状态码 {response['status']}, 内容: {response['body']}")
        else:
            print(f"404: 路由 {route_key} 不存在")

# 运行演示
# demo_web_framework()

装饰器开发原则

  1. 保持透明性
  2. 单一职责原则
  3. 性能考虑
  4. 错误处理
  • 装饰器的核心价值在于:通过包装和组合,无侵入地增强函数功能