本文介绍了python with语句常见的使用场景。通过with语句进行的错误处理,contextlib中有趣的上下文管理器,以及一些关于with语句的代码设计。需要读者有一定的python水平。
With 上下文管理器常用场景
with 上下文管理器在python中最常用的使用场景是在读写文件的时候:
with open('data.log', 'r', encoding='utf8') as f:
result = f.read()
这样读写文件的好处就是在读写文件出现异常的时候,仍然能够确保文件能够正确关闭。其等价形式为:
try:
f = open('data.log', 'r', encoding='utf8')
result = f.read()
finally:
f.close()
比如在 data.log 为二进制文件时,utf-8编码无法正确解码其内容,会抛出 UnicodeDecodeError,如果不使用 try, except, finally 或者 with,这个异常会无法关闭文件导致资源泄露。
多个上下文管理器
多个上下文管理器可以放到同一行, 比较常见的场景就是复制文件:
with open('data.png', 'rb') as src, open('/pictures/data.png', 'wb') as des:
des.write(src.read())
进入代码块的时候会依次运行上下文管理器的"__enter__" 方法, 退出会依次运行上下文管理器的"__exit__" 方法。
自定义 with 上下文管理器
with语句可以在进入代码块和出代码块的时候执行代码,这样的特性比较适合用来写计时器。
在定义类的时候,定义 "__enter__" 和 "__exit__" 方法,这个类的实例就是一个上下文管理器,在进入代码块的时候 "__enter__" 方法会执行,退出代码块的时候 "__exit__" 方法就会执行。with 语句中,as 后面的变量就是 "__enter__" 方法的返回值。
from time import perf_counter, sleep
class Timer:
def __enter__(self):
self.t0 = perf_counter()
# 返回值会赋值给 with 语句中 as 后面跟着的变量
return self
def __exit__(self, exc_type, exc_value, traceback):
self.time_total = perf_counter() - self.t0
with Timer() as timer:
sleep(1)
print(timer.time_total)
输出结果 1.0050958750071004
上面的代码,我们定义了一个简易的计时器,在进入代码块的时候,运行了 self.t0 = perf_counter() 记录的初始时间,在退出代码块的时候运行了 self.time_total = perf_counter() - self.t0, 计算了从进入代码块到退出代码块花费的总时间。
在 "with Timer() as timer:" 语句中,"__enter__" 方法返回的值,赋值给了 timer 变量,所以可以直接在后面的代码中 "print(timer.time_total)"。
"__enter__" 和 "__exit__" 也是有“异步”版本的,异步的上下文管理器,使用"__aenter__" 和 "__aexit__" 方法。
import asyncio
class AsyncWaitOneSecond:
async def __aenter__(self):
# 进入代码
print("正在睡觉,请勿打扰")
await asyncio.sleep(1)
print("睡醒了")
async def __aexit__(self, exc_type, exc_value, traceback):
# 退出时运行
print("代码运行结束")
async def main():
async with AsyncWaitOneSecond():
print("正在运行代码。。。")
asyncio.run(main())
以上代码在控制台的输出如下:
正在睡觉,请勿打扰
睡醒了
正在运行代码。。。
代码运行结束
上下文管理器也可以通过标准库 contextlib 中的 contextmanager 和 asynccontextmanager 实现。 比如上面的计时器例子:
from contextlib import contextmanager
from time import perf_counter, sleep
@contextmanager
def timeit():
result = {}
t0 = perf_counter()
# yield 出去,进入with后面的代码块
yield result
result["time_total"] = perf_counter() - t0
with timeit() as result:
sleep(1)
print(result)
输出结果
{'time_total': 1.0051242079935037}
上面的代码,可能需要一些生成器的知识才能看得比较明白,如果大家感兴趣,后面我会再写一篇详细解释生成器的文章。
在 yield 之前的代码其实等价于在之前 class 中 "__enter__" 的代码,yield 后面跟着的变量,其实相当于 "__enter__" 方法的返回值。 在yield之后的代码,相当于class中 "__exit__" 方法里的代码。
异步版本的上下文管理器也是类似的:
import asyncio
from contextlib import asynccontextmanager
@asynccontextmanager
async def wait_one_second():
print("正在睡觉,请勿打扰")
await asyncio.sleep(1)
print("睡醒了")
yield
print("代码运行结束")
async def main():
async with wait_one_second():
print("正在运行代码。。。")
asyncio.run(main())
以上代码在控制台的输出如下:
正在睡觉,请勿打扰
睡醒了
正在运行代码。。。
代码运行结束
使用 With 上下文管理器 处理错误
不知道读者有没有发现,在上面代码"__exit__" 方法,中的参数有 "exc_type", "exc_value", "traceback",这三个参数,在python中,通常是用来处理异常的,分别为:异常的类,异常的实例 和 traceback实例。是的,"__exit__" 方法可以用来处理在with代码块中的异常。
在 "__exit__" 方法中如果返回值可以转换为True(即bool(x) 为 True),则该异常就会停止传播。如果返回值为可以转换为False,则该异常就会继续传播。没有用return返回实际上也会返回None,也可以转换为False。
class ThrowException:
def __enter__(self):
pass
def __exit__(self, exc_type, exc_value, traceback):
return False
with ThrowException():
raise ValueError
在上面的代码中,我在with代码块中抛出了一个ValueError,由于__exit__方法中返回值为False,这个异常将会继续传播,导致程序崩溃。
控制台输出如下:
Traceback (most recent call last):
File "/Users/a1/projects/test/test.py", line 10, in <module>
raise ValueError
ValueError
如果在"__exit__"方法中返回True,则不会抛出异常导致程序崩溃。
class SuppressException:
def __enter__(self):
pass
def __exit__(self, exc_type, exc_value, traceback):
return True
with SuppressException():
print("不会抛出错误")
raise ValueError
print("但是这行也不会执行")
控制台输出如下:
不会抛出错误
在exit中处理异常后,虽然不会抛出异常,但是异常后面的代码也不会执行了。
当然,你也可以指定上下文处理器对哪些异常不会抛出:
class SuppressSpecificException:
def __init__(self, exc_class):
self.exc_class = exc_class
def __enter__(self):
pass
def __exit__(self, exc_type, exc_value, traceback):
if isinstance(exc_value, self.exc_class):
return True
with SuppressSpecificException(ValueError):
print("不会抛出Value Error 异常")
raise ValueError
with SuppressSpecificException(ValueError):
print("只处理了ValueError, 没有处理RuntimeError, 会抛出异常")
raise RuntimeError
控制台输出如下:
不会抛出Value Error 异常
只处理了ValueError, 没有处理RuntimeError, 会抛出异常
Traceback (most recent call last):
File "/Users/a1/projects/test/test.py", line 19, in <module>
raise RuntimeError
RuntimeError
contextlib 中部分有趣的上下文管理器
在python标准库 contextlib 中,有许多有意思的上下文管理器。本文由于篇幅有限,仅介绍:
- ContextDecorator
- redirect_stdout
contextlib 更多内容,请参考python文档:docs.python.org/zh-cn/3/lib…
ContextDecorator
这个类可以让你的上下文管理器迅速变为一个装饰器,相当于让被装饰器函数,在上下文管理器的with模块中运行。 示例代码:
from contextlib import ContextDecorator
from time import perf_counter, sleep
class Timer(ContextDecorator):
def __enter__(self):
self.t0 = perf_counter()
return self
def __exit__(self, exc_type, exc_value, traceback):
self.time_total = perf_counter() - self.t0
timer = Timer()
@timer
def do_something_heavy():
sleep(1)
do_something_heavy()
print(timer.time_total)
控制台输出结果
1.000443875003839
在上述代码中,Timer这个类继承了ContextDecorator,使得timer有了作为装饰器的能力。这个代码等价于:
with Timer() as timer:
do_something_heavy()
print(timer.time_total)
本质上Timer是继承了ContextDecorator的__call__方法。如果Timer不继承这个ContextDecorator, 添加合适的__call__方法也可以达到同样的效果。__call__方法是类的实例被小括号调用时,所运行的方法。
from functools import wraps
from time import perf_counter, sleep
class Timer:
def __enter__(self):
self.t0 = perf_counter()
return self
def __exit__(self, exc_type, exc_value, traceback):
self.time_total = perf_counter() - self.t0
def __call__(self, func):
@wraps(func)
def deco(*args, **kwargs):
with self:
return func(*args, **kwargs)
return deco
timer = Timer()
@timer
def do_something_heavy():
sleep(1)
do_something_heavy()
print(timer.time_total)
在控制台输出如下:
1.0050645000010263
上述代码中的Timer类,没有继承ContextDecorator,仍然实现了相同的功能。
ContextDecorator 也有异步版本,AsyncContextDecorator,本身的原理是一样的,这里不做过多介绍了。
redirect_stdout
redirect_stdout 是一个非常有趣的上下文管理器,它通过修改内置的一些方法,可以将sys.stdout输出到别的地方。print函数默认也是用的sys.stdout,也就是说,哪怕只是使用print也可以做一个类似日志。比如,将print中的内容发送到数据库里存起来。
只要向redirect_stdout传入一个有write和flush对象的方法,在print的时候就会调用这两个方法。
示例代码:
from contextlib import redirect_stdout
class Logger:
def __init__(self) -> None:
self._text = []
def write(self, text: str):
self._text.append(text)
def flush(self):
with open("test.log", "a", encoding="utf8") as f:
f.write("".join(self._text))
self._text.clear()
logger = Logger()
with redirect_stdout(logger):
print("控制台不会输出内容")
print("这里的内容将会写入到文件", flush=True)
print("这里的内容将会被输出到控制台")
控制台输出:
这里的内容将会被输出到控制台
此时,在文件夹下会有一个test.log记录print里的内容。 test.log 内容如下
控制台不会输出内容
这里的内容将会写入到文件
注意:如果不将print中flush参数设置为True,只有write方法会被调用,flush方法不会被调用
redirect_stdout通过修改sys.stdout实现的。我们也可以通过修改sys.stdout实现相同的功能。
import sys
class Redirect:
def __init__(self, output) -> None:
self.output = output
def __enter__(self):
self.org_stdout = sys.stdout
sys.stdout = self.output
def __exit__(self, *args):
sys.stdout = self.org_stdout
class Logger:
def __init__(self) -> None:
self._text = []
def write(self, text: str):
self._text.append(text)
def flush(self):
with open("test.log", "a", encoding="utf8") as f:
f.write("".join(self._text))
self._text.clear()
logger = Logger()
with Redirect(logger):
print("控制台不会输出内容")
print("这里的内容将会写入到文件", flush=True)
print("这里的内容将会被输出到控制台")
上述代码和redirect_stdout实现了相同的功能。
上下文管理器代码设计
渲染html文本
from typing import List
class Tag:
stack: List["Tag"] = []
def __init__(self, tag_name: str, **attributes):
self.tag_name = tag_name
self.attributes = attributes
self._children = []
if len(Tag.stack) > 0:
Tag.stack[-1]._children.append(self)
def __enter__(self):
Tag.stack.append(self)
return self
def __exit__(self, *args):
Tag.stack.pop()
def render(self, indent: int = 4, level: int = 1) -> str:
res = f"<{self.tag_name}"
for k,v in self.attributes.items():
if k in ('_class', 'class_', 'cls', 'classname', 'className'):
k = "class"
elif k == "text":
continue
k = k.replace("_", "-")
res += f' {k}="{v}"'
has_child = len(self._children) > 0
res += '>\n' if has_child else '>'
text = self.attributes.get("text")
if text:
if has_child:
res += " " * indent * level
res += text
if has_child:
res += '\n'
for child in self._children:
res += " " * indent * level
res += child.render(indent=indent, level=level+1)
if has_child:
res += " " * indent * (level - 1)
res += f'</{self.tag_name}>\n'
return res
with Tag("div", cls="text-white text-lg") as root:
Tag("div", cls="btn", disable="true")
Tag("div", cls="btn warning")
with Tag("div", cls="flex"):
with Tag("p", text="abc"):
Tag("span", text="not not not")
print(root.render())
控制台输出
<div class="text-white text-large">
<div class="btn" disable="true"></div>
<div class="btn warning"></div>
<div class="flex">
<p>
abc
<span>not not not</span>
</p>
</div>
</div>
上面的代码实现了一个简单的html渲染,最核心的设计在于,在初始化元素时,会将自己加入到stack末尾中Tag的children:
if len(Tag.stack) > 0:
Tag.stack[-1]._children.append(self)
在进入with代码块的时候,会将自己加入到stack中:
Tag.stack.append(self)
在退出with代码块的时候,将自己从stack中弹出:
Tag.stack.pop()
不过上面代码在多线程的情况下仍然是有问题的,多个线程会共用同一个stack会导致stack混乱。 我们可以进一步改良:
from typing import List, Dict
from threading import local
from concurrent.futures import ThreadPoolExecutor
class Tag:
stack_local = local()
def __init__(self, tag_name: str, **attributes):
self.tag_name = tag_name
self.attributes = attributes
self._children = []
if not hasattr(Tag.stack_local, "stack"):
Tag.stack_local.stack = []
if len(Tag.stack_local.stack) > 0:
Tag.stack_local.stack[-1]._children.append(self)
def __enter__(self):
Tag.stack_local.stack.append(self)
return self
def __exit__(self, *args):
Tag.stack_local.stack.pop()
def render(self, indent: int = 4, level: int = 1) -> str:
...
def render_html():
with Tag("div", cls="text-white text-lg") as root:
Tag("div", cls="btn", disable="true")
Tag("div", cls="btn warning")
with Tag("div", cls="flex"):
with Tag("p", text="abc"):
Tag("span", text="not not not")
return root.render()
with ThreadPoolExecutor(10) as pool:
futures = [pool.submit(render_html) for _ in range(10)]
for f in futures:
print(f.result())
这次,我们将stack放到local里,这是用来存储每个线程本地的变量。这样html文本就能正常渲染了。
简单读写锁
from threading import Condition
from typing import Any, NoReturn
from time import sleep
from concurrent.futures import ThreadPoolExecutor
class Frozen:
"""
用来保护value参数不被修改
a.b 和 a['b'] 都可以运行
以下都会报错:
a.b = 1
a.['b'] = 1
del a.b
del a['b']
"""
def __init__(self, value) -> None:
object.__setattr__(self, "__value", value)
def __getattribute__(self, name: str) -> Any:
return getattr(object.__getattribute__(self, "__value"), name)
def __setattr__(self, _, __) -> NoReturn:
raise ValueError(
"%s cannot be assigned any value!"
% object.__getattribute__(self, "__value")
)
def __delattr__(self, _: str) -> NoReturn:
raise ValueError(
"%s cannot be deleted any value!" % object.__getattribute__(self, "__value")
)
def __getitem__(self, name: str) -> Any:
return object.__getattribute__(self, "__value")[name]
__setitem__ = __setattr__
__delitem__ = __delattr__
def __str__(self):
return "<Frozen: %s>" % object.__getattribute__(self, "__value")
class ReadWriteLock:
"""
用来保护数据线程安全。被保护的数据可以由多个数据进行读取,或者只能有一个数据进行写入。
参数:
value:
被 ReadWriteLock 保护的值
使用方式:
>>> data = {"a": 1}
>>> rwLock = ReadWriteLock(data)
>>> with rwLock.read() as data_released: # 读取数据
... print(data_released)
>>> with rwLock.write() as data_released: # 写入数据
... data_released["b"] = 2
"""
def __init__(self, value: Any) -> None:
self._cond = Condition()
self._write = 0
self._read = 0
self._value = value
def read(self) -> "ReadStatus":
return ReadStatus(self)
def write(self) -> "WriteStatus":
return WriteStatus(self)
class ReadStatus:
def __init__(self, parent: ReadWriteLock) -> None:
self.parent = parent
def __enter__(self):
with self.parent._cond:
while self.parent._write > 0:
self.parent._cond.wait()
self.parent._read += 1
return Frozen(self.parent._value)
def __exit__(self, *args):
with self.parent._cond:
self.parent._read -= 1
self.parent._cond.notify_all()
class WriteStatus:
def __init__(self, parent: ReadWriteLock) -> None:
self.parent = parent
def __enter__(self):
with self.parent._cond:
while self.parent._write > 0 or self.parent._read > 0:
self.parent._cond.wait()
self.parent._write += 1
return self.parent._value
def __exit__(self, *args):
with self.parent._cond:
self.parent._write -= 1
self.parent._cond.notify_all()
def reading(data: ReadWriteLock):
with data.read() as data_released:
print(data_released)
sleep(1)
def writing(data: ReadWriteLock, i):
with data.write() as data_released:
data_released["b"] = i
sleep(1)
print(i)
print(data_released)
print(i)
data = {"a": 123}
rwlock = ReadWriteLock(data)
with ThreadPoolExecutor() as pool:
futures = [pool.submit(reading, rwlock) for _ in range(10)]
for f in futures:
f.result()
futures = [pool.submit(writing, rwlock, i) for i in range(10)]
for f in futures:
f.result()
这一段代码需要读者对 threading.Condition 比较熟悉。
读写锁,确保数据可以同时被多个线程读取,或者只被一个线程写入。 代码设计要点:
- 被保护数据与读写锁绑定
- 通过read 和 write 返回不同的状态对象
- ReadStatus 和 WriteStatus 通过 with 返回数据对象
- ReadStatus 通过 with 返回的数据对象进行了保护,使其不能被修改