[Python黑魔法] 进阶 with 上下文管理器,with语句你不知道的秘密

189 阅读11分钟

本文介绍了python with语句常见的使用场景。通过with语句进行的错误处理,contextlib中有趣的上下文管理器,以及一些关于with语句的代码设计。需要读者有一定的python水平。

With 上下文管理器常用场景

with 上下文管理器在python中最常用的使用场景是在读写文件的时候:

with open('data.log', 'r', encoding='utf8') as f:
    result = f.read()

这样读写文件的好处就是在读写文件出现异常的时候,仍然能够确保文件能够正确关闭。其等价形式为:

try:
    f = open('data.log', 'r', encoding='utf8')
    result = f.read()
finally:
    f.close()

比如在 data.log 为二进制文件时,utf-8编码无法正确解码其内容,会抛出 UnicodeDecodeError,如果不使用 try, except, finally 或者 with,这个异常会无法关闭文件导致资源泄露。

多个上下文管理器

多个上下文管理器可以放到同一行, 比较常见的场景就是复制文件:

with open('data.png', 'rb') as src, open('/pictures/data.png', 'wb') as des:
    des.write(src.read())

进入代码块的时候会依次运行上下文管理器的"__enter__" 方法, 退出会依次运行上下文管理器的"__exit__" 方法。

自定义 with 上下文管理器

with语句可以在进入代码块和出代码块的时候执行代码,这样的特性比较适合用来写计时器。

在定义类的时候,定义 "__enter__" 和 "__exit__" 方法,这个类的实例就是一个上下文管理器,在进入代码块的时候 "__enter__" 方法会执行,退出代码块的时候 "__exit__" 方法就会执行。with 语句中,as 后面的变量就是 "__enter__" 方法的返回值。

from time import perf_counter, sleep

class Timer:
    def __enter__(self):
        self.t0 = perf_counter() 
        # 返回值会赋值给 with 语句中 as 后面跟着的变量
        return self 
        
    def __exit__(self, exc_type, exc_value, traceback):
        self.time_total = perf_counter() - self.t0 

with Timer() as timer:
    sleep(1)
print(timer.time_total)

输出结果 1.0050958750071004

上面的代码,我们定义了一个简易的计时器,在进入代码块的时候,运行了 self.t0 = perf_counter() 记录的初始时间,在退出代码块的时候运行了 self.time_total = perf_counter() - self.t0, 计算了从进入代码块到退出代码块花费的总时间。

在 "with Timer() as timer:" 语句中,"__enter__" 方法返回的值,赋值给了 timer 变量,所以可以直接在后面的代码中 "print(timer.time_total)"。

"__enter__" 和 "__exit__" 也是有“异步”版本的,异步的上下文管理器,使用"__aenter__" 和 "__aexit__" 方法。

import asyncio 

class AsyncWaitOneSecond:
    
    async def __aenter__(self):
        # 进入代码
        print("正在睡觉,请勿打扰")
        await asyncio.sleep(1)
        print("睡醒了")
    
    async def __aexit__(self, exc_type, exc_value, traceback):
        # 退出时运行
        print("代码运行结束")
        
async def main():
    async with AsyncWaitOneSecond():
        print("正在运行代码。。。")
    
asyncio.run(main())

以上代码在控制台的输出如下:

正在睡觉,请勿打扰
睡醒了
正在运行代码。。。
代码运行结束

上下文管理器也可以通过标准库 contextlib 中的 contextmanager 和 asynccontextmanager 实现。 比如上面的计时器例子:

from contextlib import contextmanager
from time import perf_counter, sleep 

@contextmanager
def timeit():
    result = {}
    t0 = perf_counter()
    
    # yield 出去,进入with后面的代码块
    yield result
    
    result["time_total"] = perf_counter() - t0

with timeit() as result:
    sleep(1)

print(result)

输出结果

{'time_total': 1.0051242079935037}

上面的代码,可能需要一些生成器的知识才能看得比较明白,如果大家感兴趣,后面我会再写一篇详细解释生成器的文章。

在 yield 之前的代码其实等价于在之前 class 中 "__enter__" 的代码,yield 后面跟着的变量,其实相当于 "__enter__" 方法的返回值。 在yield之后的代码,相当于class中 "__exit__" 方法里的代码。

异步版本的上下文管理器也是类似的:

import asyncio 
from contextlib import asynccontextmanager

@asynccontextmanager
async def wait_one_second():
    print("正在睡觉,请勿打扰")
    await asyncio.sleep(1)
    print("睡醒了")
    yield 
    print("代码运行结束")
    
async def main():
    async with wait_one_second():
        print("正在运行代码。。。")
        
asyncio.run(main())

以上代码在控制台的输出如下:

正在睡觉,请勿打扰
睡醒了
正在运行代码。。。
代码运行结束

使用 With 上下文管理器 处理错误

不知道读者有没有发现,在上面代码"__exit__" 方法,中的参数有 "exc_type", "exc_value", "traceback",这三个参数,在python中,通常是用来处理异常的,分别为:异常的类,异常的实例 和 traceback实例。是的,"__exit__" 方法可以用来处理在with代码块中的异常。

在 "__exit__" 方法中如果返回值可以转换为True(即bool(x) 为 True),则该异常就会停止传播。如果返回值为可以转换为False,则该异常就会继续传播。没有用return返回实际上也会返回None,也可以转换为False。

class ThrowException:
    def __enter__(self):
        pass 
    
    def __exit__(self, exc_type, exc_value, traceback):
        return False

with ThrowException():
    raise ValueError

在上面的代码中,我在with代码块中抛出了一个ValueError,由于__exit__方法中返回值为False,这个异常将会继续传播,导致程序崩溃。

控制台输出如下:

Traceback (most recent call last):
  File "/Users/a1/projects/test/test.py", line 10, in <module>
    raise ValueError
ValueError

如果在"__exit__"方法中返回True,则不会抛出异常导致程序崩溃。

class SuppressException:
    def __enter__(self):
        pass 
    def __exit__(self, exc_type, exc_value, traceback):
        return True

with SuppressException():
    print("不会抛出错误")
    raise ValueError
    print("但是这行也不会执行")

控制台输出如下:

不会抛出错误

在exit中处理异常后,虽然不会抛出异常,但是异常后面的代码也不会执行了。

当然,你也可以指定上下文处理器对哪些异常不会抛出:

class SuppressSpecificException:
    def __init__(self, exc_class):
        self.exc_class = exc_class
        
    def __enter__(self):
        pass
    
    def __exit__(self, exc_type, exc_value, traceback):
        if isinstance(exc_value, self.exc_class):
            return True

with SuppressSpecificException(ValueError):
    print("不会抛出Value Error 异常")
    raise ValueError

with SuppressSpecificException(ValueError):
    print("只处理了ValueError, 没有处理RuntimeError, 会抛出异常")
    raise RuntimeError

控制台输出如下:

不会抛出Value Error 异常
只处理了ValueError, 没有处理RuntimeError, 会抛出异常
Traceback (most recent call last):
  File "/Users/a1/projects/test/test.py", line 19, in <module>
    raise RuntimeError
RuntimeError

contextlib 中部分有趣的上下文管理器

在python标准库 contextlib 中,有许多有意思的上下文管理器。本文由于篇幅有限,仅介绍:

  1. ContextDecorator
  2. redirect_stdout

contextlib 更多内容,请参考python文档:docs.python.org/zh-cn/3/lib…

ContextDecorator

这个类可以让你的上下文管理器迅速变为一个装饰器,相当于让被装饰器函数,在上下文管理器的with模块中运行。 示例代码:

from contextlib import ContextDecorator
from time import perf_counter, sleep

class Timer(ContextDecorator):
    
    def __enter__(self):
        self.t0 = perf_counter()
        return self
       
    def __exit__(self, exc_type, exc_value, traceback):
        self.time_total = perf_counter() - self.t0

timer = Timer()

@timer
def do_something_heavy():
    sleep(1)

do_something_heavy()
print(timer.time_total)

控制台输出结果

1.000443875003839

在上述代码中,Timer这个类继承了ContextDecorator,使得timer有了作为装饰器的能力。这个代码等价于:

with Timer() as timer:
    do_something_heavy()
print(timer.time_total)

本质上Timer是继承了ContextDecorator的__call__方法。如果Timer不继承这个ContextDecorator, 添加合适的__call__方法也可以达到同样的效果。__call__方法是类的实例被小括号调用时,所运行的方法。

from functools import wraps
from time import perf_counter, sleep

class Timer:

    def __enter__(self):
        self.t0 = perf_counter()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.time_total = perf_counter() - self.t0

    def __call__(self, func):

        @wraps(func)
        def deco(*args, **kwargs):
            with self:
                return func(*args, **kwargs)

        return deco


timer = Timer()

@timer
def do_something_heavy():
    sleep(1)

do_something_heavy()
print(timer.time_total)

在控制台输出如下:

1.0050645000010263

上述代码中的Timer类,没有继承ContextDecorator,仍然实现了相同的功能。

ContextDecorator 也有异步版本,AsyncContextDecorator,本身的原理是一样的,这里不做过多介绍了。

redirect_stdout

redirect_stdout 是一个非常有趣的上下文管理器,它通过修改内置的一些方法,可以将sys.stdout输出到别的地方。print函数默认也是用的sys.stdout,也就是说,哪怕只是使用print也可以做一个类似日志。比如,将print中的内容发送到数据库里存起来。

只要向redirect_stdout传入一个有write和flush对象的方法,在print的时候就会调用这两个方法。

示例代码:

from contextlib import redirect_stdout


class Logger:

    def __init__(self) -> None:
        self._text = []

    def write(self, text: str):
        self._text.append(text)

    def flush(self):
        with open("test.log", "a", encoding="utf8") as f:
            f.write("".join(self._text))
        self._text.clear()


logger = Logger()
with redirect_stdout(logger):
    print("控制台不会输出内容")
    print("这里的内容将会写入到文件", flush=True)
print("这里的内容将会被输出到控制台")

控制台输出:

这里的内容将会被输出到控制台

此时,在文件夹下会有一个test.log记录print里的内容。 test.log 内容如下

控制台不会输出内容
这里的内容将会写入到文件

注意:如果不将print中flush参数设置为True,只有write方法会被调用,flush方法不会被调用

redirect_stdout通过修改sys.stdout实现的。我们也可以通过修改sys.stdout实现相同的功能。

import sys


class Redirect:

    def __init__(self, output) -> None:
        self.output = output

    def __enter__(self):
        self.org_stdout = sys.stdout
        sys.stdout = self.output

    def __exit__(self, *args):
        sys.stdout = self.org_stdout


class Logger:

    def __init__(self) -> None:
        self._text = []

    def write(self, text: str):
        self._text.append(text)

    def flush(self):
        with open("test.log", "a", encoding="utf8") as f:
            f.write("".join(self._text))
        self._text.clear()


logger = Logger()
with Redirect(logger):
    print("控制台不会输出内容")
    print("这里的内容将会写入到文件", flush=True)
print("这里的内容将会被输出到控制台")

上述代码和redirect_stdout实现了相同的功能。

上下文管理器代码设计

渲染html文本

from typing import List


class Tag:
    stack: List["Tag"] = []
    
    def __init__(self, tag_name: str, **attributes):
        self.tag_name = tag_name 
        self.attributes = attributes
        self._children = []
        
        if len(Tag.stack) > 0:
            Tag.stack[-1]._children.append(self)
    
    def __enter__(self):
        Tag.stack.append(self)
        return self 
    
    def __exit__(self, *args):
        Tag.stack.pop()
    
    def render(self, indent: int = 4, level: int = 1) -> str:
        res = f"<{self.tag_name}"
        for k,v in self.attributes.items():
            if k in ('_class', 'class_', 'cls', 'classname', 'className'):
                k = "class"
            elif k == "text":
                continue 
            k = k.replace("_", "-")
            res += f' {k}="{v}"'
        
        has_child = len(self._children) > 0 
        res += '>\n' if has_child else '>'
        text = self.attributes.get("text")
        
        if text:
            if has_child:
                res += " " * indent * level 
            res += text
            if has_child:
                res += '\n'
        
        for child in self._children:
            res += " " * indent * level 
            res += child.render(indent=indent, level=level+1)
        if has_child:
            res += " " * indent * (level - 1)
        res += f'</{self.tag_name}>\n'
        return res
        
with Tag("div", cls="text-white text-lg") as root:
    Tag("div", cls="btn", disable="true")
    Tag("div", cls="btn warning")
    with Tag("div", cls="flex"):
        with Tag("p", text="abc"):
            Tag("span", text="not not not")
        
print(root.render())

控制台输出

<div class="text-white text-large">
    <div class="btn" disable="true"></div>
    <div class="btn warning"></div>
    <div class="flex">
        <p>
            abc
            <span>not not not</span>
        </p>
    </div>
</div>

上面的代码实现了一个简单的html渲染,最核心的设计在于,在初始化元素时,会将自己加入到stack末尾中Tag的children:

if len(Tag.stack) > 0:
    Tag.stack[-1]._children.append(self)

在进入with代码块的时候,会将自己加入到stack中:

Tag.stack.append(self)

在退出with代码块的时候,将自己从stack中弹出:

Tag.stack.pop()

不过上面代码在多线程的情况下仍然是有问题的,多个线程会共用同一个stack会导致stack混乱。 我们可以进一步改良:

from typing import List, Dict
from threading import local
from concurrent.futures import ThreadPoolExecutor


class Tag:
    stack_local = local()

    def __init__(self, tag_name: str, **attributes):
        self.tag_name = tag_name
        self.attributes = attributes
        self._children = []

        if not hasattr(Tag.stack_local, "stack"):
            Tag.stack_local.stack = []

        if len(Tag.stack_local.stack) > 0:
            Tag.stack_local.stack[-1]._children.append(self)

    def __enter__(self):
        Tag.stack_local.stack.append(self)
        return self

    def __exit__(self, *args):
        Tag.stack_local.stack.pop()

    def render(self, indent: int = 4, level: int = 1) -> str:
        ...


def render_html():
    with Tag("div", cls="text-white text-lg") as root:
        Tag("div", cls="btn", disable="true")
        Tag("div", cls="btn warning")
        with Tag("div", cls="flex"):
            with Tag("p", text="abc"):
                Tag("span", text="not not not")
    return root.render()


with ThreadPoolExecutor(10) as pool:
    futures = [pool.submit(render_html) for _ in range(10)]

    for f in futures:
        print(f.result())

这次,我们将stack放到local里,这是用来存储每个线程本地的变量。这样html文本就能正常渲染了。

简单读写锁

from threading import Condition
from typing import Any, NoReturn
from time import sleep
from concurrent.futures import ThreadPoolExecutor


class Frozen:
    """
    用来保护value参数不被修改
    a.b 和 a['b'] 都可以运行
    以下都会报错:
        a.b = 1
        a.['b'] = 1
        del a.b
        del a['b']
    """

    def __init__(self, value) -> None:
        object.__setattr__(self, "__value", value)

    def __getattribute__(self, name: str) -> Any:
        return getattr(object.__getattribute__(self, "__value"), name)

    def __setattr__(self, _, __) -> NoReturn:
        raise ValueError(
            "%s cannot be assigned any value!"
            % object.__getattribute__(self, "__value")
        )

    def __delattr__(self, _: str) -> NoReturn:
        raise ValueError(
            "%s cannot be deleted any value!" % object.__getattribute__(self, "__value")
        )

    def __getitem__(self, name: str) -> Any:
        return object.__getattribute__(self, "__value")[name]

    __setitem__ = __setattr__
    __delitem__ = __delattr__

    def __str__(self):
        return "<Frozen: %s>" % object.__getattribute__(self, "__value")


class ReadWriteLock:
    """
    用来保护数据线程安全。被保护的数据可以由多个数据进行读取,或者只能有一个数据进行写入。

    参数:
        value:
            被 ReadWriteLock 保护的值

    使用方式:
    >>> data = {"a": 1}
    >>> rwLock = ReadWriteLock(data)
    >>> with rwLock.read() as data_released: # 读取数据
    ...     print(data_released)
    >>> with rwLock.write() as data_released: # 写入数据
    ...     data_released["b"] = 2
    """

    def __init__(self, value: Any) -> None:
        self._cond = Condition()
        self._write = 0
        self._read = 0
        self._value = value

    def read(self) -> "ReadStatus":
        return ReadStatus(self)

    def write(self) -> "WriteStatus":
        return WriteStatus(self)


class ReadStatus:

    def __init__(self, parent: ReadWriteLock) -> None:
        self.parent = parent

    def __enter__(self):
        with self.parent._cond:
            while self.parent._write > 0:
                self.parent._cond.wait()
            self.parent._read += 1
            return Frozen(self.parent._value)

    def __exit__(self, *args):
        with self.parent._cond:
            self.parent._read -= 1
            self.parent._cond.notify_all()


class WriteStatus:

    def __init__(self, parent: ReadWriteLock) -> None:
        self.parent = parent

    def __enter__(self):
        with self.parent._cond:
            while self.parent._write > 0 or self.parent._read > 0:
                self.parent._cond.wait()
            self.parent._write += 1
            return self.parent._value

    def __exit__(self, *args):
        with self.parent._cond:
            self.parent._write -= 1
            self.parent._cond.notify_all()


def reading(data: ReadWriteLock):
    with data.read() as data_released:
        print(data_released)
        sleep(1)


def writing(data: ReadWriteLock, i):
    with data.write() as data_released:
        data_released["b"] = i
        sleep(1)
        print(i)
    print(data_released)
    print(i)


data = {"a": 123}
rwlock = ReadWriteLock(data)
with ThreadPoolExecutor() as pool:
    futures = [pool.submit(reading, rwlock) for _ in range(10)]
    for f in futures:
        f.result()

    futures = [pool.submit(writing, rwlock, i) for i in range(10)]
    for f in futures:
        f.result()

这一段代码需要读者对 threading.Condition 比较熟悉。

读写锁,确保数据可以同时被多个线程读取,或者只被一个线程写入。 代码设计要点:

  1. 被保护数据与读写锁绑定
  2. 通过read 和 write 返回不同的状态对象
  3. ReadStatus 和 WriteStatus 通过 with 返回数据对象
  4. ReadStatus 通过 with 返回的数据对象进行了保护,使其不能被修改