Python上下文管理器:资源管理的优雅之道

27 阅读11分钟

想象一下,你每次使用完一个工具后,它都能自动清理并放回原处;打开一扇门后,它能确保门被正确关闭。在编程世界中,这种"用了就自动清理"的机制正是上下文管理器的用武之地。它让资源管理变得自动化、安全且优雅,彻底告别资源泄漏的烦恼。

上下文管理器的三大应用场景

  • 文件操作自动化
  • 数据库连接管理
  • 锁和同步控制

实战代码:构建智能资源管理系统

基础上下文管理器

import sqlite3
import threading
from contextlib import contextmanager
from typing import Any, Iterator

class FileManager:
    """文件管理的上下文管理器"""
    
    def __init__(self, filename: str, mode: str = 'r', encoding: str = 'utf-8'):
        self.filename = filename
        self.mode = mode
        self.encoding = encoding
        self.file = None
    
    def __enter__(self) -> Any:
        """进入上下文时调用"""
        print(f"打开文件: {self.filename}")
        self.file = open(self.filename, self.mode, encoding=self.encoding)
        return self.file
    
    def __exit__(self, exc_type, exc_val, exc_tb) -> bool:
        """退出上下文时调用"""
        if self.file:
            print(f"关闭文件: {self.filename}")
            self.file.close()
        
        # 如果发生异常,返回True表示已处理,False会重新抛出异常
        return False

class DatabaseConnection:
    """数据库连接的上下文管理器"""
    
    def __init__(self, db_path: str):
        self.db_path = db_path
        self.connection = None
    
    def __enter__(self) -> sqlite3.Connection:
        """建立数据库连接"""
        print(f"连接到数据库: {self.db_path}")
        self.connection = sqlite3.connect(self.db_path)
        self.connection.row_factory = sqlite3.Row  # 返回字典样式的行
        return self.connection
    
    def __exit__(self, exc_type, exc_val, exc_tb) -> bool:
        """关闭数据库连接,处理事务"""
        if self.connection:
            if exc_type is None:
                # 没有异常,提交事务
                print("提交事务并关闭连接")
                self.connection.commit()
            else:
                # 发生异常,回滚事务
                print(f"发生异常 {exc_type.__name__},回滚事务")
                self.connection.rollback()
            
            self.connection.close()
            print("数据库连接已关闭")
        
        # 返回False,让异常继续传播
        return False

class ThreadLock:
    """线程锁的上下文管理器"""
    
    def __init__(self, lock: threading.Lock):
        self.lock = lock
        self.acquired = False
    
    def __enter__(self) -> 'ThreadLock':
        """获取锁"""
        print("等待获取锁...")
        self.lock.acquire()
        self.acquired = True
        print("锁已获取")
        return self
    
    def __exit__(self, exc_type, exc_val, exc_tb) -> bool:
        """释放锁"""
        if self.acquired:
            self.lock.release()
            print("锁已释放")
        return False

def demo_basic_context_managers():
    """演示基础上下文管理器的使用"""
    print("=== 基础上下文管理器演示 ===")
    
    # 文件管理器演示
    print("\n1. 文件管理器:")
    try:
        with FileManager('example.txt', 'w') as file:
            file.write("Hello, Context Manager!\n")
            file.write("这行文字会被自动保存并关闭文件。\n")
            # 这里如果发生异常,文件也会被正确关闭
    except Exception as e:
        print(f"文件操作错误: {e}")
    
    # 读取刚才写入的文件
    with FileManager('example.txt', 'r') as file:
        content = file.read()
        print(f"文件内容: {content}")
    
    # 数据库连接管理器演示
    print("\n2. 数据库连接管理器:")
    try:
        with DatabaseConnection(':memory:') as conn:
            # 创建表
            conn.execute('''
                CREATE TABLE users (
                    id INTEGER PRIMARY KEY,
                    name TEXT NOT NULL,
                    email TEXT UNIQUE NOT NULL
                )
            ''')
            
            # 插入数据
            conn.execute(
                "INSERT INTO users (name, email) VALUES (?, ?)",
                ("张三", "zhang@example.com")
            )
            conn.execute(
                "INSERT INTO users (name, email) VALUES (?, ?)", 
                ("李四", "li@example.com")
            )
            
            # 查询数据
            cursor = conn.execute("SELECT * FROM users")
            users = cursor.fetchall()
            
            print("数据库用户:")
            for user in users:
                print(f"  ID: {user['id']}, 姓名: {user['name']}, 邮箱: {user['email']}")
            
            # 模拟一个错误(事务会自动回滚)
            # conn.execute("INSERT INTO nonexistent_table VALUES (1)")
            
    except Exception as e:
        print(f"数据库操作错误: {e}")
    
    # 线程锁管理器演示
    print("\n3. 线程锁管理器:")
    lock = threading.Lock()
    
    def worker(thread_id):
        with ThreadLock(lock):
            print(f"线程 {thread_id} 正在处理关键任务...")
            import time
            time.sleep(1)  # 模拟工作
            print(f"线程 {thread_id} 完成工作")
    
    # 创建多个线程测试锁
    threads = []
    for i in range(3):
        thread = threading.Thread(target=worker, args=(i,))
        threads.append(thread)
        thread.start()
    
    for thread in threads:
        thread.join()

# 运行演示
# demo_basic_context_managers()

使用contextlib的高级技巧

import time
import tempfile
import os
from contextlib import contextmanager, redirect_stdout, redirect_stderr
from io import StringIO
from typing import Any, Iterator

@contextmanager
def timer(operation_name: str) -> Iterator[None]:
    """计时上下文管理器"""
    start_time = time.time()
    print(f"开始 {operation_name}...")
    
    try:
        yield
    finally:
        end_time = time.time()
        duration = end_time - start_time
        print(f"{operation_name} 完成,耗时: {duration:.4f}秒")

@contextmanager
def temporary_directory() -> Iterator[str]:
    """临时目录上下文管理器"""
    temp_dir = tempfile.mkdtemp()
    print(f"创建临时目录: {temp_dir}")
    
    try:
        yield temp_dir
    finally:
        # 清理临时目录
        import shutil
        print(f"清理临时目录: {temp_dir}")
        shutil.rmtree(temp_dir)

@contextmanager
def change_directory(new_path: str) -> Iterator[None]:
    """切换工作目录的上下文管理器"""
    original_path = os.getcwd()
    print(f"从 {original_path} 切换到 {new_path}")
    
    try:
        os.chdir(new_path)
        yield
    finally:
        os.chdir(original_path)
        print(f"切换回原目录: {original_path}")

@contextmanager
def suppress_exceptions(*exception_types) -> Iterator[None]:
    """抑制指定类型异常的上下文管理器"""
    try:
        yield
    except exception_types as e:
        print(f"抑制异常: {type(e).__name__}: {e}")

def demo_contextlib_managers():
    """演示contextlib创建的管理器"""
    print("=== contextlib高级技巧演示 ===")
    
    # 计时管理器
    print("\n1. 计时管理器:")
    with timer("数据计算操作"):
        # 模拟耗时操作
        result = sum(i * i for i in range(1000000))
        print(f"计算结果: {result}")
    
    # 临时目录管理器
    print("\n2. 临时目录管理器:")
    with temporary_directory() as temp_dir:
        # 在临时目录中创建文件
        temp_file = os.path.join(temp_dir, "test.txt")
        with open(temp_file, 'w') as f:
            f.write("临时文件内容")
        
        print(f"在 {temp_dir} 中创建了文件")
        print(f"临时目录存在: {os.path.exists(temp_dir)}")
    
    # 临时目录已被自动清理
    print(f"临时目录已清理: {not os.path.exists(temp_dir)}")
    
    # 切换目录管理器
    print("\n3. 切换目录管理器:")
    original_dir = os.getcwd()
    
    with change_directory("/tmp" if os.name != 'nt' else "C:\\Windows\\Temp"):
        print(f"当前目录: {os.getcwd()}")
        # 在这里进行目录相关操作
    
    print(f"回到原目录: {os.getcwd() == original_dir}")
    
    # 异常抑制管理器
    print("\n4. 异常抑制管理器:")
    
    # 正常情况会抛出异常
    try:
        with suppress_exceptions(ValueError, ZeroDivisionError):
            result = 1 / 0  # 这会抛出ZeroDivisionError,但被抑制
            print("这行不会执行")
    except:
        print("异常未被抑制")
    
    # 抑制特定异常
    with suppress_exceptions(ValueError):
        raise ValueError("这个异常会被抑制")
    
    print("程序继续执行,异常已被抑制")
    
    # 输出重定向
    print("\n5. 输出重定向:")
    
    # 捕获标准输出
    output = StringIO()
    with redirect_stdout(output):
        print("这行输出被重定向")
        print("不会被显示在控制台")
    
    captured_output = output.getvalue()
    print(f"捕获的输出: {captured_output}")

# 运行演示
# demo_contextlib_managers()

数据库事务管理器

import sqlite3
from contextlib import contextmanager
from typing import Iterator, Tuple, Any

class TransactionManager:
    """高级数据库事务管理器"""
    
    def __init__(self, db_path: str):
        self.db_path = db_path
        self.connection = None
    
    @contextmanager
    def transaction(self, isolation_level: str = None) -> Iterator[sqlite3.Connection]:
        """数据库事务上下文管理器"""
        try:
            self.connection = sqlite3.connect(self.db_path)
            
            if isolation_level:
                self.connection.isolation_level = isolation_level
            
            self.connection.row_factory = sqlite3.Row
            print("开始数据库事务")
            
            yield self.connection
            
            # 如果没有异常,提交事务
            self.connection.commit()
            print("事务提交成功")
            
        except Exception as e:
            if self.connection:
                self.connection.rollback()
                print(f"事务回滚,原因: {e}")
            raise e
        
        finally:
            if self.connection:
                self.connection.close()
                print("数据库连接关闭")

class BatchOperationManager:
    """批量操作管理器"""
    
    def __init__(self, commit_every: int = 1000):
        self.commit_every = commit_every
        self.operation_count = 0
    
    @contextmanager
    def batch_operation(self, connection: sqlite3.Connection) -> Iterator['BatchOperationManager']:
        """批量操作上下文管理器"""
        try:
            # 开始批量操作
            self.operation_count = 0
            yield self
            
            # 提交剩余的操作
            if self.operation_count > 0:
                connection.commit()
                print(f"批量操作完成,共处理 {self.operation_count} 条记录")
                
        except Exception as e:
            connection.rollback()
            print(f"批量操作失败,已回滚: {e}")
            raise
    
    def execute_operation(self, cursor: sqlite3.Cursor, sql: str, params: Tuple = ()) -> None:
        """执行单个操作,自动处理批量提交"""
        cursor.execute(sql, params)
        self.operation_count += 1
        
        # 达到提交阈值时自动提交
        if self.operation_count % self.commit_every == 0:
            cursor.connection.commit()
            print(f"已提交 {self.operation_count} 条记录")

def demo_database_managers():
    """演示数据库事务管理器"""
    print("=== 数据库事务管理器演示 ===")
    
    # 创建内存数据库
    db_path = ":memory:"
    
    # 初始化数据库
    with sqlite3.connect(db_path) as conn:
        conn.execute('''
            CREATE TABLE products (
                id INTEGER PRIMARY KEY,
                name TEXT NOT NULL,
                price REAL NOT NULL,
                category TEXT NOT NULL
            )
        ''')
        conn.commit()
    
    # 使用事务管理器
    transaction_mgr = TransactionManager(db_path)
    
    print("\n1. 正常事务操作:")
    try:
        with transaction_mgr.transaction() as conn:
            # 插入一些测试数据
            conn.execute(
                "INSERT INTO products (name, price, category) VALUES (?, ?, ?)",
                ("笔记本电脑", 5999.0, "电子产品")
            )
            conn.execute(
                "INSERT INTO products (name, price, category) VALUES (?, ?, ?)",
                ("智能手机", 3999.0, "电子产品")
            )
            
            # 查询验证
            cursor = conn.execute("SELECT COUNT(*) as count FROM products")
            count = cursor.fetchone()['count']
            print(f"成功插入 {count} 条记录")
    
    except Exception as e:
        print(f"事务失败: {e}")
    
    print("\n2. 带异常的事务(自动回滚):")
    try:
        with transaction_mgr.transaction() as conn:
            conn.execute(
                "INSERT INTO products (name, price, category) VALUES (?, ?, ?)",
                ("平板电脑", 2999.0, "电子产品")
            )
            
            # 故意制造一个错误
            raise ValueError("模拟的业务逻辑错误")
            
    except ValueError as e:
        print(f"捕获到预期异常: {e}")
    
    # 验证数据是否回滚
    with transaction_mgr.transaction() as conn:
        cursor = conn.execute("SELECT COUNT(*) as count FROM products")
        count = cursor.fetchone()['count']
        print(f"回滚后记录数: {count} (平板电脑插入被回滚)")
    
    print("\n3. 批量操作管理器:")
    batch_mgr = BatchOperationManager(commit_every=50)
    
    with transaction_mgr.transaction() as conn:
        with batch_mgr.batch_operation(conn) as batch:
            cursor = conn.cursor()
            
            # 模拟批量插入100条记录
            for i in range(100):
                batch.execute_operation(
                    cursor,
                    "INSERT INTO products (name, price, category) VALUES (?, ?, ?)",
                    (f"产品{i}", i * 100.0, "测试类别")
                )
    
    # 验证批量插入结果
    with transaction_mgr.transaction() as conn:
        cursor = conn.execute("SELECT COUNT(*) as count FROM products")
        total_count = cursor.fetchone()['count']
        print(f"批量操作后总记录数: {total_count}")

# 运行演示
# demo_database_managers()

网络连接管理器

import socket
import requests
from contextlib import contextmanager
from typing import Iterator, Optional
from urllib.parse import urlparse

class SocketConnection:
    """Socket连接上下文管理器"""
    
    def __init__(self, host: str, port: int, timeout: float = 10.0):
        self.host = host
        self.port = port
        self.timeout = timeout
        self.socket = None
    
    def __enter__(self) -> socket.socket:
        """建立Socket连接"""
        print(f"连接到 {self.host}:{self.port}")
        self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        self.socket.settimeout(self.timeout)
        self.socket.connect((self.host, self.port))
        return self.socket
    
    def __exit__(self, exc_type, exc_val, exc_tb) -> bool:
        """关闭Socket连接"""
        if self.socket:
            print(f"关闭到 {self.host}:{self.port} 的连接")
            self.socket.close()
        return False

class HTTPConnection:
    """HTTP连接上下文管理器"""
    
    def __init__(self, base_url: str, timeout: int = 30):
        self.base_url = base_url
        self.timeout = timeout
        self.session = None
    
    def __enter__(self) -> 'HTTPConnection':
        """创建HTTP会话"""
        print(f"创建HTTP会话: {self.base_url}")
        self.session = requests.Session()
        return self
    
    def __exit__(self, exc_type, exc_val, exc_tb) -> bool:
        """关闭HTTP会话"""
        if self.session:
            print("关闭HTTP会话")
            self.session.close()
        return False
    
    def get(self, endpoint: str, **kwargs) -> requests.Response:
        """发送GET请求"""
        url = f"{self.base_url}{endpoint}"
        print(f"GET请求: {url}")
        return self.session.get(url, timeout=self.timeout, **kwargs)
    
    def post(self, endpoint: str, data: Optional[dict] = None, **kwargs) -> requests.Response:
        """发送POST请求"""
        url = f"{self.base_url}{endpoint}"
        print(f"POST请求: {url}")
        return self.session.post(url, json=data, timeout=self.timeout, **kwargs)

@contextmanager
def http_request(url: str, method: str = 'GET', **kwargs) -> Iterator[requests.Response]:
    """HTTP请求上下文管理器"""
    session = requests.Session()
    
    try:
        print(f"发送{method}请求到: {url}")
        response = session.request(method, url, **kwargs)
        response.raise_for_status()  # 如果状态码不是200,抛出异常
        
        yield response
        
    except requests.RequestException as e:
        print(f"HTTP请求失败: {e}")
        raise
    
    finally:
        session.close()
        print("HTTP会话已关闭")

def demo_network_managers():
    """演示网络连接管理器"""
    print("=== 网络连接管理器演示 ===")
    
    print("\n1. HTTP连接管理器:")
    
    # 使用HTTP连接管理器
    with HTTPConnection("https://httpbin.org") as http:
        # 发送多个请求,共享同一个会话
        response1 = http.get("/get")
        print(f"GET请求状态: {response1.status_code}")
        
        response2 = http.post("/post", data={"message": "Hello World"})
        print(f"POST请求状态: {response2.status_code}")
    
    print("\n2. HTTP请求上下文管理器:")
    
    # 使用HTTP请求上下文管理器
    try:
        with http_request("https://httpbin.org/json", 'GET') as response:
            data = response.json()
            print(f"获取到JSON数据: {data['slideshow']['title']}")
    
    except Exception as e:
        print(f"请求失败: {e}")
    
    print("\n3. Socket连接管理器(模拟):")
    
    # 注意:这里使用一个不存在的地址来演示,实际使用时需要有效的地址
    try:
        # 这个会失败,用于演示错误处理
        with SocketConnection("invalid_host", 80, timeout=1.0) as sock:
            sock.send(b"GET / HTTP/1.1\r\n\r\n")
            response = sock.recv(1024)
            print(f"收到响应: {response}")
    
    except Exception as e:
        print(f"Socket连接失败(预期中): {e}")
    
    # 演示本地Socket服务器
    print("\n4. 本地Socket服务器演示:")
    
    def start_test_server():
        """启动一个简单的测试服务器"""
        import threading
        
        def server_thread():
            server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            server_socket.bind(('localhost', 9999))
            server_socket.listen(1)
            
            conn, addr = server_socket.accept()
            data = conn.recv(1024)
            conn.send(b"Hello from test server")
            conn.close()
            server_socket.close()
        
        thread = threading.Thread(target=server_thread, daemon=True)
        thread.start()
        return thread
    
    # 启动测试服务器
    server_thread = start_test_server()
    import time
    time.sleep(0.1)  # 给服务器启动时间
    
    # 连接到测试服务器
    try:
        with SocketConnection('localhost', 9999, timeout=5.0) as client_socket:
            client_socket.send(b"Test message")
            response = client_socket.recv(1024)
            print(f"从测试服务器收到: {response.decode()}")
    
    except Exception as e:
        print(f"测试服务器连接失败: {e}")

# 运行演示
# demo_network_managers()

资源池管理器

import threading
from contextlib import contextmanager
from typing import List, Iterator, Any
from queue import Queue, Empty

class ConnectionPool:
    """数据库连接池管理器"""
    
    def __init__(self, create_connection, max_connections: int = 5):
        self.create_connection = create_connection
        self.max_connections = max_connections
        self._pool = Queue(max_connections)
        self._lock = threading.Lock()
        self._created_count = 0
        
        # 预先创建一些连接
        for _ in range(min(2, max_connections)):
            self._create_and_add_connection()
    
    def _create_and_add_connection(self):
        """创建新连接并添加到池中"""
        with self._lock:
            if self._created_count < self.max_connections:
                conn = self.create_connection()
                self._pool.put(conn)
                self._created_count += 1
    
    @contextmanager
    def get_connection(self, timeout: float = 5.0) -> Iterator[Any]:
        """从池中获取连接"""
        conn = None
        
        try:
            # 尝试获取连接
            try:
                conn = self._pool.get(timeout=timeout)
            except Empty:
                # 池为空,创建新连接
                self._create_and_add_connection()
                conn = self._pool.get(timeout=timeout)
            
            print(f"获取连接,池中剩余: {self._pool.qsize()}")
            yield conn
            
        finally:
            # 将连接返回池中
            if conn:
                self._pool.put(conn)
                print(f"归还连接,池中现有: {self._pool.qsize()}")

class ThreadPoolManager:
    """线程池上下文管理器"""
    
    def __init__(self, max_workers: int = None):
        if max_workers is None:
            max_workers = min(32, (os.cpu_count() or 1) + 4)
        
        self.max_workers = max_workers
        self.executor = None
    
    def __enter__(self) -> 'ThreadPoolManager':
        """创建线程池"""
        from concurrent.futures import ThreadPoolExecutor
        
        print(f"创建线程池,最大工作线程: {self.max_workers}")
        self.executor = ThreadPoolExecutor(max_workers=self.max_workers)
        return self
    
    def __exit__(self, exc_type, exc_val, exc_tb) -> bool:
        """关闭线程池"""
        if self.executor:
            print("关闭线程池...")
            self.executor.shutdown(wait=True)
            print("线程池已关闭")
        return False
    
    def submit(self, func, *args, **kwargs):
        """提交任务到线程池"""
        return self.executor.submit(func, *args, **kwargs)

def demo_resource_pools():
    """演示资源池管理器"""
    print("=== 资源池管理器演示 ===")
    
    # 模拟数据库连接创建函数
    def create_mock_connection():
        """创建模拟数据库连接"""
        conn_id = f"conn_{threading.get_ident()}_{time.time()}"
        print(f"创建新连接: {conn_id}")
        return {"id": conn_id, "created_at": time.time()}
    
    print("\n1. 连接池管理器:")
    
    # 创建连接池
    pool = ConnectionPool(create_mock_connection, max_connections=3)
    
    def worker(worker_id):
        """工作线程函数"""
        with pool.get_connection() as conn:
            print(f"工作线程 {worker_id} 使用连接: {conn['id']}")
            time.sleep(1)  # 模拟工作
            return f"工作线程 {worker_id} 完成"
    
    # 使用线程池执行多个任务
    with ThreadPoolManager(max_workers=5) as thread_pool:
        futures = []
        
        # 提交多个任务,测试连接池
        for i in range(8):
            future = thread_pool.submit(worker, i)
            futures.append(future)
        
        # 等待所有任务完成
        for future in futures:
            try:
                result = future.result(timeout=10)
                print(f"任务结果: {result}")
            except Exception as e:
                print(f"任务失败: {e}")
    
    print("\n2. 线程池管理器独立使用:")
    
    def cpu_intensive_task(n):
        """CPU密集型任务"""
        print(f"开始计算任务 {n}")
        result = sum(i * i for i in range(n))
        print(f"任务 {n} 完成,结果: {result}")
        return result
    
    # 使用线程池执行计算任务
    with ThreadPoolManager(max_workers=2) as pool:
        futures = [
            pool.submit(cpu_intensive_task, 1000000),
            pool.submit(cpu_intensive_task, 2000000),
            pool.submit(cpu_intensive_task, 1500000),
        ]
        
        # 获取结果
        for i, future in enumerate(futures):
            try:
                result = future.result()
                print(f"任务 {i} 最终结果: {result}")
            except Exception as e:
                print(f"任务 {i} 出错: {e}")

# 运行演示
# demo_resource_pools()

上下文管理器开发原则

  1. 资源安全原则
  2. 异常传播原则
  3. 性能优化原则
  4. 用户体验原则
  • 上下文管理器的核心价值在于:确保资源被自动且正确地管理,让代码更安全、更简洁