想象一下,你每次使用完一个工具后,它都能自动清理并放回原处;打开一扇门后,它能确保门被正确关闭。在编程世界中,这种"用了就自动清理"的机制正是上下文管理器的用武之地。它让资源管理变得自动化、安全且优雅,彻底告别资源泄漏的烦恼。
上下文管理器的三大应用场景
- 文件操作自动化
- 数据库连接管理
- 锁和同步控制
实战代码:构建智能资源管理系统
基础上下文管理器
import sqlite3
import threading
from contextlib import contextmanager
from typing import Any, Iterator
class FileManager:
"""文件管理的上下文管理器"""
def __init__(self, filename: str, mode: str = 'r', encoding: str = 'utf-8'):
self.filename = filename
self.mode = mode
self.encoding = encoding
self.file = None
def __enter__(self) -> Any:
"""进入上下文时调用"""
print(f"打开文件: {self.filename}")
self.file = open(self.filename, self.mode, encoding=self.encoding)
return self.file
def __exit__(self, exc_type, exc_val, exc_tb) -> bool:
"""退出上下文时调用"""
if self.file:
print(f"关闭文件: {self.filename}")
self.file.close()
# 如果发生异常,返回True表示已处理,False会重新抛出异常
return False
class DatabaseConnection:
"""数据库连接的上下文管理器"""
def __init__(self, db_path: str):
self.db_path = db_path
self.connection = None
def __enter__(self) -> sqlite3.Connection:
"""建立数据库连接"""
print(f"连接到数据库: {self.db_path}")
self.connection = sqlite3.connect(self.db_path)
self.connection.row_factory = sqlite3.Row # 返回字典样式的行
return self.connection
def __exit__(self, exc_type, exc_val, exc_tb) -> bool:
"""关闭数据库连接,处理事务"""
if self.connection:
if exc_type is None:
# 没有异常,提交事务
print("提交事务并关闭连接")
self.connection.commit()
else:
# 发生异常,回滚事务
print(f"发生异常 {exc_type.__name__},回滚事务")
self.connection.rollback()
self.connection.close()
print("数据库连接已关闭")
# 返回False,让异常继续传播
return False
class ThreadLock:
"""线程锁的上下文管理器"""
def __init__(self, lock: threading.Lock):
self.lock = lock
self.acquired = False
def __enter__(self) -> 'ThreadLock':
"""获取锁"""
print("等待获取锁...")
self.lock.acquire()
self.acquired = True
print("锁已获取")
return self
def __exit__(self, exc_type, exc_val, exc_tb) -> bool:
"""释放锁"""
if self.acquired:
self.lock.release()
print("锁已释放")
return False
def demo_basic_context_managers():
"""演示基础上下文管理器的使用"""
print("=== 基础上下文管理器演示 ===")
# 文件管理器演示
print("\n1. 文件管理器:")
try:
with FileManager('example.txt', 'w') as file:
file.write("Hello, Context Manager!\n")
file.write("这行文字会被自动保存并关闭文件。\n")
# 这里如果发生异常,文件也会被正确关闭
except Exception as e:
print(f"文件操作错误: {e}")
# 读取刚才写入的文件
with FileManager('example.txt', 'r') as file:
content = file.read()
print(f"文件内容: {content}")
# 数据库连接管理器演示
print("\n2. 数据库连接管理器:")
try:
with DatabaseConnection(':memory:') as conn:
# 创建表
conn.execute('''
CREATE TABLE users (
id INTEGER PRIMARY KEY,
name TEXT NOT NULL,
email TEXT UNIQUE NOT NULL
)
''')
# 插入数据
conn.execute(
"INSERT INTO users (name, email) VALUES (?, ?)",
("张三", "zhang@example.com")
)
conn.execute(
"INSERT INTO users (name, email) VALUES (?, ?)",
("李四", "li@example.com")
)
# 查询数据
cursor = conn.execute("SELECT * FROM users")
users = cursor.fetchall()
print("数据库用户:")
for user in users:
print(f" ID: {user['id']}, 姓名: {user['name']}, 邮箱: {user['email']}")
# 模拟一个错误(事务会自动回滚)
# conn.execute("INSERT INTO nonexistent_table VALUES (1)")
except Exception as e:
print(f"数据库操作错误: {e}")
# 线程锁管理器演示
print("\n3. 线程锁管理器:")
lock = threading.Lock()
def worker(thread_id):
with ThreadLock(lock):
print(f"线程 {thread_id} 正在处理关键任务...")
import time
time.sleep(1) # 模拟工作
print(f"线程 {thread_id} 完成工作")
# 创建多个线程测试锁
threads = []
for i in range(3):
thread = threading.Thread(target=worker, args=(i,))
threads.append(thread)
thread.start()
for thread in threads:
thread.join()
# 运行演示
# demo_basic_context_managers()
使用contextlib的高级技巧
import time
import tempfile
import os
from contextlib import contextmanager, redirect_stdout, redirect_stderr
from io import StringIO
from typing import Any, Iterator
@contextmanager
def timer(operation_name: str) -> Iterator[None]:
"""计时上下文管理器"""
start_time = time.time()
print(f"开始 {operation_name}...")
try:
yield
finally:
end_time = time.time()
duration = end_time - start_time
print(f"{operation_name} 完成,耗时: {duration:.4f}秒")
@contextmanager
def temporary_directory() -> Iterator[str]:
"""临时目录上下文管理器"""
temp_dir = tempfile.mkdtemp()
print(f"创建临时目录: {temp_dir}")
try:
yield temp_dir
finally:
# 清理临时目录
import shutil
print(f"清理临时目录: {temp_dir}")
shutil.rmtree(temp_dir)
@contextmanager
def change_directory(new_path: str) -> Iterator[None]:
"""切换工作目录的上下文管理器"""
original_path = os.getcwd()
print(f"从 {original_path} 切换到 {new_path}")
try:
os.chdir(new_path)
yield
finally:
os.chdir(original_path)
print(f"切换回原目录: {original_path}")
@contextmanager
def suppress_exceptions(*exception_types) -> Iterator[None]:
"""抑制指定类型异常的上下文管理器"""
try:
yield
except exception_types as e:
print(f"抑制异常: {type(e).__name__}: {e}")
def demo_contextlib_managers():
"""演示contextlib创建的管理器"""
print("=== contextlib高级技巧演示 ===")
# 计时管理器
print("\n1. 计时管理器:")
with timer("数据计算操作"):
# 模拟耗时操作
result = sum(i * i for i in range(1000000))
print(f"计算结果: {result}")
# 临时目录管理器
print("\n2. 临时目录管理器:")
with temporary_directory() as temp_dir:
# 在临时目录中创建文件
temp_file = os.path.join(temp_dir, "test.txt")
with open(temp_file, 'w') as f:
f.write("临时文件内容")
print(f"在 {temp_dir} 中创建了文件")
print(f"临时目录存在: {os.path.exists(temp_dir)}")
# 临时目录已被自动清理
print(f"临时目录已清理: {not os.path.exists(temp_dir)}")
# 切换目录管理器
print("\n3. 切换目录管理器:")
original_dir = os.getcwd()
with change_directory("/tmp" if os.name != 'nt' else "C:\\Windows\\Temp"):
print(f"当前目录: {os.getcwd()}")
# 在这里进行目录相关操作
print(f"回到原目录: {os.getcwd() == original_dir}")
# 异常抑制管理器
print("\n4. 异常抑制管理器:")
# 正常情况会抛出异常
try:
with suppress_exceptions(ValueError, ZeroDivisionError):
result = 1 / 0 # 这会抛出ZeroDivisionError,但被抑制
print("这行不会执行")
except:
print("异常未被抑制")
# 抑制特定异常
with suppress_exceptions(ValueError):
raise ValueError("这个异常会被抑制")
print("程序继续执行,异常已被抑制")
# 输出重定向
print("\n5. 输出重定向:")
# 捕获标准输出
output = StringIO()
with redirect_stdout(output):
print("这行输出被重定向")
print("不会被显示在控制台")
captured_output = output.getvalue()
print(f"捕获的输出: {captured_output}")
# 运行演示
# demo_contextlib_managers()
数据库事务管理器
import sqlite3
from contextlib import contextmanager
from typing import Iterator, Tuple, Any
class TransactionManager:
"""高级数据库事务管理器"""
def __init__(self, db_path: str):
self.db_path = db_path
self.connection = None
@contextmanager
def transaction(self, isolation_level: str = None) -> Iterator[sqlite3.Connection]:
"""数据库事务上下文管理器"""
try:
self.connection = sqlite3.connect(self.db_path)
if isolation_level:
self.connection.isolation_level = isolation_level
self.connection.row_factory = sqlite3.Row
print("开始数据库事务")
yield self.connection
# 如果没有异常,提交事务
self.connection.commit()
print("事务提交成功")
except Exception as e:
if self.connection:
self.connection.rollback()
print(f"事务回滚,原因: {e}")
raise e
finally:
if self.connection:
self.connection.close()
print("数据库连接关闭")
class BatchOperationManager:
"""批量操作管理器"""
def __init__(self, commit_every: int = 1000):
self.commit_every = commit_every
self.operation_count = 0
@contextmanager
def batch_operation(self, connection: sqlite3.Connection) -> Iterator['BatchOperationManager']:
"""批量操作上下文管理器"""
try:
# 开始批量操作
self.operation_count = 0
yield self
# 提交剩余的操作
if self.operation_count > 0:
connection.commit()
print(f"批量操作完成,共处理 {self.operation_count} 条记录")
except Exception as e:
connection.rollback()
print(f"批量操作失败,已回滚: {e}")
raise
def execute_operation(self, cursor: sqlite3.Cursor, sql: str, params: Tuple = ()) -> None:
"""执行单个操作,自动处理批量提交"""
cursor.execute(sql, params)
self.operation_count += 1
# 达到提交阈值时自动提交
if self.operation_count % self.commit_every == 0:
cursor.connection.commit()
print(f"已提交 {self.operation_count} 条记录")
def demo_database_managers():
"""演示数据库事务管理器"""
print("=== 数据库事务管理器演示 ===")
# 创建内存数据库
db_path = ":memory:"
# 初始化数据库
with sqlite3.connect(db_path) as conn:
conn.execute('''
CREATE TABLE products (
id INTEGER PRIMARY KEY,
name TEXT NOT NULL,
price REAL NOT NULL,
category TEXT NOT NULL
)
''')
conn.commit()
# 使用事务管理器
transaction_mgr = TransactionManager(db_path)
print("\n1. 正常事务操作:")
try:
with transaction_mgr.transaction() as conn:
# 插入一些测试数据
conn.execute(
"INSERT INTO products (name, price, category) VALUES (?, ?, ?)",
("笔记本电脑", 5999.0, "电子产品")
)
conn.execute(
"INSERT INTO products (name, price, category) VALUES (?, ?, ?)",
("智能手机", 3999.0, "电子产品")
)
# 查询验证
cursor = conn.execute("SELECT COUNT(*) as count FROM products")
count = cursor.fetchone()['count']
print(f"成功插入 {count} 条记录")
except Exception as e:
print(f"事务失败: {e}")
print("\n2. 带异常的事务(自动回滚):")
try:
with transaction_mgr.transaction() as conn:
conn.execute(
"INSERT INTO products (name, price, category) VALUES (?, ?, ?)",
("平板电脑", 2999.0, "电子产品")
)
# 故意制造一个错误
raise ValueError("模拟的业务逻辑错误")
except ValueError as e:
print(f"捕获到预期异常: {e}")
# 验证数据是否回滚
with transaction_mgr.transaction() as conn:
cursor = conn.execute("SELECT COUNT(*) as count FROM products")
count = cursor.fetchone()['count']
print(f"回滚后记录数: {count} (平板电脑插入被回滚)")
print("\n3. 批量操作管理器:")
batch_mgr = BatchOperationManager(commit_every=50)
with transaction_mgr.transaction() as conn:
with batch_mgr.batch_operation(conn) as batch:
cursor = conn.cursor()
# 模拟批量插入100条记录
for i in range(100):
batch.execute_operation(
cursor,
"INSERT INTO products (name, price, category) VALUES (?, ?, ?)",
(f"产品{i}", i * 100.0, "测试类别")
)
# 验证批量插入结果
with transaction_mgr.transaction() as conn:
cursor = conn.execute("SELECT COUNT(*) as count FROM products")
total_count = cursor.fetchone()['count']
print(f"批量操作后总记录数: {total_count}")
# 运行演示
# demo_database_managers()
网络连接管理器
import socket
import requests
from contextlib import contextmanager
from typing import Iterator, Optional
from urllib.parse import urlparse
class SocketConnection:
"""Socket连接上下文管理器"""
def __init__(self, host: str, port: int, timeout: float = 10.0):
self.host = host
self.port = port
self.timeout = timeout
self.socket = None
def __enter__(self) -> socket.socket:
"""建立Socket连接"""
print(f"连接到 {self.host}:{self.port}")
self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self.socket.settimeout(self.timeout)
self.socket.connect((self.host, self.port))
return self.socket
def __exit__(self, exc_type, exc_val, exc_tb) -> bool:
"""关闭Socket连接"""
if self.socket:
print(f"关闭到 {self.host}:{self.port} 的连接")
self.socket.close()
return False
class HTTPConnection:
"""HTTP连接上下文管理器"""
def __init__(self, base_url: str, timeout: int = 30):
self.base_url = base_url
self.timeout = timeout
self.session = None
def __enter__(self) -> 'HTTPConnection':
"""创建HTTP会话"""
print(f"创建HTTP会话: {self.base_url}")
self.session = requests.Session()
return self
def __exit__(self, exc_type, exc_val, exc_tb) -> bool:
"""关闭HTTP会话"""
if self.session:
print("关闭HTTP会话")
self.session.close()
return False
def get(self, endpoint: str, **kwargs) -> requests.Response:
"""发送GET请求"""
url = f"{self.base_url}{endpoint}"
print(f"GET请求: {url}")
return self.session.get(url, timeout=self.timeout, **kwargs)
def post(self, endpoint: str, data: Optional[dict] = None, **kwargs) -> requests.Response:
"""发送POST请求"""
url = f"{self.base_url}{endpoint}"
print(f"POST请求: {url}")
return self.session.post(url, json=data, timeout=self.timeout, **kwargs)
@contextmanager
def http_request(url: str, method: str = 'GET', **kwargs) -> Iterator[requests.Response]:
"""HTTP请求上下文管理器"""
session = requests.Session()
try:
print(f"发送{method}请求到: {url}")
response = session.request(method, url, **kwargs)
response.raise_for_status() # 如果状态码不是200,抛出异常
yield response
except requests.RequestException as e:
print(f"HTTP请求失败: {e}")
raise
finally:
session.close()
print("HTTP会话已关闭")
def demo_network_managers():
"""演示网络连接管理器"""
print("=== 网络连接管理器演示 ===")
print("\n1. HTTP连接管理器:")
# 使用HTTP连接管理器
with HTTPConnection("https://httpbin.org") as http:
# 发送多个请求,共享同一个会话
response1 = http.get("/get")
print(f"GET请求状态: {response1.status_code}")
response2 = http.post("/post", data={"message": "Hello World"})
print(f"POST请求状态: {response2.status_code}")
print("\n2. HTTP请求上下文管理器:")
# 使用HTTP请求上下文管理器
try:
with http_request("https://httpbin.org/json", 'GET') as response:
data = response.json()
print(f"获取到JSON数据: {data['slideshow']['title']}")
except Exception as e:
print(f"请求失败: {e}")
print("\n3. Socket连接管理器(模拟):")
# 注意:这里使用一个不存在的地址来演示,实际使用时需要有效的地址
try:
# 这个会失败,用于演示错误处理
with SocketConnection("invalid_host", 80, timeout=1.0) as sock:
sock.send(b"GET / HTTP/1.1\r\n\r\n")
response = sock.recv(1024)
print(f"收到响应: {response}")
except Exception as e:
print(f"Socket连接失败(预期中): {e}")
# 演示本地Socket服务器
print("\n4. 本地Socket服务器演示:")
def start_test_server():
"""启动一个简单的测试服务器"""
import threading
def server_thread():
server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
server_socket.bind(('localhost', 9999))
server_socket.listen(1)
conn, addr = server_socket.accept()
data = conn.recv(1024)
conn.send(b"Hello from test server")
conn.close()
server_socket.close()
thread = threading.Thread(target=server_thread, daemon=True)
thread.start()
return thread
# 启动测试服务器
server_thread = start_test_server()
import time
time.sleep(0.1) # 给服务器启动时间
# 连接到测试服务器
try:
with SocketConnection('localhost', 9999, timeout=5.0) as client_socket:
client_socket.send(b"Test message")
response = client_socket.recv(1024)
print(f"从测试服务器收到: {response.decode()}")
except Exception as e:
print(f"测试服务器连接失败: {e}")
# 运行演示
# demo_network_managers()
资源池管理器
import threading
from contextlib import contextmanager
from typing import List, Iterator, Any
from queue import Queue, Empty
class ConnectionPool:
"""数据库连接池管理器"""
def __init__(self, create_connection, max_connections: int = 5):
self.create_connection = create_connection
self.max_connections = max_connections
self._pool = Queue(max_connections)
self._lock = threading.Lock()
self._created_count = 0
# 预先创建一些连接
for _ in range(min(2, max_connections)):
self._create_and_add_connection()
def _create_and_add_connection(self):
"""创建新连接并添加到池中"""
with self._lock:
if self._created_count < self.max_connections:
conn = self.create_connection()
self._pool.put(conn)
self._created_count += 1
@contextmanager
def get_connection(self, timeout: float = 5.0) -> Iterator[Any]:
"""从池中获取连接"""
conn = None
try:
# 尝试获取连接
try:
conn = self._pool.get(timeout=timeout)
except Empty:
# 池为空,创建新连接
self._create_and_add_connection()
conn = self._pool.get(timeout=timeout)
print(f"获取连接,池中剩余: {self._pool.qsize()}")
yield conn
finally:
# 将连接返回池中
if conn:
self._pool.put(conn)
print(f"归还连接,池中现有: {self._pool.qsize()}")
class ThreadPoolManager:
"""线程池上下文管理器"""
def __init__(self, max_workers: int = None):
if max_workers is None:
max_workers = min(32, (os.cpu_count() or 1) + 4)
self.max_workers = max_workers
self.executor = None
def __enter__(self) -> 'ThreadPoolManager':
"""创建线程池"""
from concurrent.futures import ThreadPoolExecutor
print(f"创建线程池,最大工作线程: {self.max_workers}")
self.executor = ThreadPoolExecutor(max_workers=self.max_workers)
return self
def __exit__(self, exc_type, exc_val, exc_tb) -> bool:
"""关闭线程池"""
if self.executor:
print("关闭线程池...")
self.executor.shutdown(wait=True)
print("线程池已关闭")
return False
def submit(self, func, *args, **kwargs):
"""提交任务到线程池"""
return self.executor.submit(func, *args, **kwargs)
def demo_resource_pools():
"""演示资源池管理器"""
print("=== 资源池管理器演示 ===")
# 模拟数据库连接创建函数
def create_mock_connection():
"""创建模拟数据库连接"""
conn_id = f"conn_{threading.get_ident()}_{time.time()}"
print(f"创建新连接: {conn_id}")
return {"id": conn_id, "created_at": time.time()}
print("\n1. 连接池管理器:")
# 创建连接池
pool = ConnectionPool(create_mock_connection, max_connections=3)
def worker(worker_id):
"""工作线程函数"""
with pool.get_connection() as conn:
print(f"工作线程 {worker_id} 使用连接: {conn['id']}")
time.sleep(1) # 模拟工作
return f"工作线程 {worker_id} 完成"
# 使用线程池执行多个任务
with ThreadPoolManager(max_workers=5) as thread_pool:
futures = []
# 提交多个任务,测试连接池
for i in range(8):
future = thread_pool.submit(worker, i)
futures.append(future)
# 等待所有任务完成
for future in futures:
try:
result = future.result(timeout=10)
print(f"任务结果: {result}")
except Exception as e:
print(f"任务失败: {e}")
print("\n2. 线程池管理器独立使用:")
def cpu_intensive_task(n):
"""CPU密集型任务"""
print(f"开始计算任务 {n}")
result = sum(i * i for i in range(n))
print(f"任务 {n} 完成,结果: {result}")
return result
# 使用线程池执行计算任务
with ThreadPoolManager(max_workers=2) as pool:
futures = [
pool.submit(cpu_intensive_task, 1000000),
pool.submit(cpu_intensive_task, 2000000),
pool.submit(cpu_intensive_task, 1500000),
]
# 获取结果
for i, future in enumerate(futures):
try:
result = future.result()
print(f"任务 {i} 最终结果: {result}")
except Exception as e:
print(f"任务 {i} 出错: {e}")
# 运行演示
# demo_resource_pools()
上下文管理器开发原则
- 资源安全原则
- 异常传播原则
- 性能优化原则
- 用户体验原则
- 上下文管理器的核心价值在于:确保资源被自动且正确地管理,让代码更安全、更简洁