浅谈生成器函数与yield语句

341 阅读8分钟

摘要

本文思路是从迭代器讲到生成器再到生成器的高级用法,结合自己的实践,也记录下自己的感受。

使用迭代器

迭代是访问集合元素的一种方式。迭代器是一个可以记住遍历的位置的对象。迭代器对象从集合的第一个元素开始访问,直到所有的元素被访问完结束。

可迭代对象

我们已经知道可以对list、tuple、str等类型的数据使用for...in...的循环语法从其中依次拿到数据进行使用,我们把这样的过程称为遍历,也叫迭代,这个数据就是可迭代的。

>>> alist=[1,2,3,4,5]
>>> for ele in alist:
	print(ele)
	
1
2
3
4
5

但是,是否所有的数据类型都可以放到for...in...的语句中,然后每次从中取出一条数据供我们使用,即供我们迭代吗? 我们可以用isinstance()语句来判断:

>>> isinstance({},Iterable)
True
>>> isinstance('abc',Iterable)
True
>>> isinstance(100,Iterable)
False

返回值为True的数据为可迭代的,为False的为不可迭代的。 那么我们要怎样来构造一个迭代器呢? 参考一下python迭代器协议的定义。

含有 __ iter __() 方法或 __ getitem __() 方法的对象称之为可迭代对象

因此,我们可以重写__iter__()方法来自定义一个迭代器类。

>>> class MyList(object):
    def __init__(self):
        self.container = []

    def add(self, item):
        self.container.append(item)

    def __iter__(self):
        """返回一个迭代器"""
        # 我们暂时忽略如何构造一个迭代器对象
        pass
>>> isinstance(MyList(),Iterable)
True

可以看到MyList类的对象已经是一个可迭代的了。

迭代过程的本质

当然,仅凭上面的代码还是无法进行迭代的。如下:

>>> mylist=MyList()
>>> mylist.add(1)
>>> mylist.add(2)
>>> mylist.container
[1, 2]
>>> for i in mylist:
	print(i)
		
Traceback (most recent call last):
  File "<pyshell#14>", line 1, in <module>
    for i in mylist:
TypeError: iter() returned non-iterator of type 'NoneType'

错误提示表示没有返回一个迭代器对象,说明我们还需要一个迭代器对象,那么是时候分析迭代的过程了。在数据迭代的过程中,我们存在一个工具来帮我们标记现在遍历的位置,直到数据被迭代完成。这个工具就是我们的迭代器 (Iterator)可迭代对象的本质就是可以向我们提供一个这样的中间“人”即迭代器帮助我们对其进行迭代遍历使用。 实际上就是,可迭代对象通过__iter__()方法,向我们提供了迭代器。然后利用迭代器,我们再对对象进行迭代。在for x in something:语句中,我们首先调用something__iter__()方法来获取到一个该对象的迭代器,而后for循环会调用迭代器的__next__()方法,获取迭代器的下一个对象,赋值给x。

迭代器

显然,在迭代的过程中,迭代器是必不可少的,同样的我们也可以自定义一个迭代器。python迭代器协议中规定:

迭代器协议(iterator protocol)是指要实现对象的__iter()__ 和 next() 方法(注意:Python3 要实现__next__() 方法),其中,__iter __() 方法返回迭代器对象本身,next() 方法返回容器的下一个元素,在没有后续元素时抛出 StopIteration 异常。

所以,我们需要重写两个方法,一个是__iter__(),一个是__next__()方法。并且可迭代对象的__iter__()方法要返回我们的迭代器。

class MyList(object):
    """自定义的一个可迭代对象"""
    def __init__(self):
        self.items = []

    def add(self, val):
        self.items.append(val)

    def __iter__(self):
    	#返回迭代器
        myiterator = MyIterator(self)
        return myiterator


class MyIterator(object):
    """自定义的供上面可迭代对象使用的一个迭代器"""
    def __init__(self, mylist):
        self.mylist = mylist
        # current用来记录当前访问到的位置
        self.current = 0

    def __next__(self):
        if self.current < len(self.mylist.items):
            item = self.mylist.items[self.current]
            self.current += 1
            return item
        else:
            raise StopIteration

    def __iter__(self):
        return self

使用for...in来测试一下:

mylist = MyList()
    mylist.add(1)
    mylist.add(2)
    mylist.add(3)
    for num in mylist:
        print(num)

结果为:

1
2
3

其实还可以这样来遍历:

    mylist = MyList()
    mylist.add(1)
    mylist.add(2)
    mylist.add(3)
    iterator=MyIterator(mylist)
    while(1):
        try:
            a=iterator.__next__()
            print(a)
        except StopIteration:
            break

返回的结果都是一样的,这也说明在for ...in ...中隐性地调用了迭代器,并且将调用了__next__()方法赋值给循环元素。

使用生成器

生成器与生成器函数

如果一个函数包含 yield 表达式,那么它是一个生成器函数,调用它会返回一个生成器

生成器也是一种迭代器,在每次迭代的时候返回一个值,直到抛出StopIteration异常。

def func():
    return 1

def gen():
    yield 1

if __name__=='__main__':
    print(type(func))   # <class 'function'>
    print(type(gen))    # <class 'function'>

    print(type(func())) # <class 'int'>
    print(type(gen()))  # <class 'generator'>

可以看到生成器函数与普通函数的区别在于:

  1. 生成器函数没有return,而是yield
  2. 生成器函数返回的是一个生成器对象

那么它与一般的函数有何不同呢,它的特点在哪呢?下面就简单讲两个例子。

例一、读取文件

通常我们可以使用生成器来读取文件,比如csv文件。现在,我们想要计算csv文件的行数时,可以分别使用如下两种代码来完成,结果如下:

from memory_profiler import profile

@profile
def csv_func(file_path):
    file = open(file_path)
    csv_gen = file.read().split("\n")
    row_count = 0

    for row in csv_gen:
        row_count += 1

    print(f"Row count is {row_count}")

csv_func('E:\\pycharmProject\\webserverProject\\techcrunch.csv')

这份代码中,我们打开文件,并将数据读取到列表csv_gen中,然后每读取一个元素,row_count就加一。然后,我们使用生成器来达到我们的目标,代码如下:

def csv_test(file_path):
    for row in open(file_path, "r"):
        yield row

@profile
def csv_generator(file_path):
    csv_gen = csv_test(file_path)
    row_count = 0

    for row in csv_gen:
        row_count += 1

    print(f"Row count is {row_count}")

csv_generator('E:\\pycharmProject\\webserverProject\\techcrunch.csv')

这两份代码中,我们还导入了memory_profiler库,来查看运行中的内存的使用情况。那么结果如下:

#以下是用列表的方式读取数据的结果
Row count is 504051

Line #    Mem usage    Increment   Line Contents
================================================
     9     16.5 MiB     16.5 MiB   @profile
    10                             def csv_func(file_path):
    11     16.5 MiB      0.0 MiB       file = open(file_path)
    12     79.7 MiB     63.2 MiB       csv_gen = file.read().split("\n")
    13     79.7 MiB      0.0 MiB       row_count = 0
    14                             
    15     79.7 MiB      0.0 MiB       for row in csv_gen:
    16     79.7 MiB      0.0 MiB           row_count += 1
    17                             
    18     79.7 MiB      0.0 MiB       print(f"Row count is {row_count}")

#以下是用生成器读取数据的结果
Row count is 504050

Line #    Mem usage    Increment   Line Contents
================================================
    20     16.5 MiB     16.5 MiB   @profile
    21                             def csv_generator(file_path):
    22     16.5 MiB      0.0 MiB       csv_gen = csv_test(file_path)
    23     16.5 MiB      0.0 MiB       row_count = 0
    24                             
    25     16.7 MiB      0.1 MiB       for row in csv_gen:
    26     16.7 MiB      0.0 MiB           row_count += 1
    27                             
    28     16.7 MiB      0.0 MiB       print(f"Row count is {row_count}")

对比结果,看第一部分的12行csv_gen = file.read().split("\n"),说明用列表读取数据的话,内存占用增量为63.2MiB;看第二部分25行,可以知道用生成器读取数据的话,内存的占用增量为0.1MiB,显然,使用生成器可以有效减少内存占用。 其实也很好理解,当我们存在一个容器时,想要遍历其中的值,有两种做法:

  1. 先将容器中的所有值都取出来,然后进行遍历
  2. 从头开始,边取值边遍历

显然,第二种是更加节省内存空间的。 实际上,当文件大到一定程度时,只有第二种代码才能读取到数据内容,使用第一种代码的话,会报错MemoryError

例二、产生一个无限序列

首先,先着眼于有限序列,我们可以使用range()函数来生成一个有限序列。

>>> a = range(5)
>>> list(a)
[0, 1, 2, 3, 4]

然而,要产生一个无限序列的话,我们就需要使用到生成器了,因为我们的内存是有限的。代码如下:

import time
def infinite_sequence():
    num=0
    while True:
        yield num
        num+=1

if __name__=='__main__':
    for i in infinite_sequence():
        print(i,end=' ')
        time.sleep(0.1)

代码逻辑很简单:当infinite_sequence()函数执行时,yield生成一个数值并返回,且保留函数当时的num值,然后生成下一个数。所以通过它可以产生一个无限序列。 也可以不使用for循环,而是先通过生成器函数得到生成器,然后调用next()方法作用于生成器对象,来获取下一个生成的值。

>>> def infinite_sequence():
...    num=0
...    while True:
...        yield num
...        num+=1
...        
>>> gen=infinite_sequence()
>>> type(gen)
<class 'generator'>
>>> next(gen)
0
>>> next(gen)
1
>>> gen.__next__()
3
>>> gen.__next__()
4

那么简单分析一下:

  1. 调用 infinite_sequence()函数不会立即执行代码,而返回了一个生成器对象gen
  2. 当使用 next() (在 for 循环中会自动调用 next()) 作用于返回的生成器对象时,函数开始执行,在遇到 yield 的时候会【暂停】,并返回当前的迭代值;
  3. 当再次使用 next() 的时候,函数会被唤醒,从原来【暂停】的地方继续执行,直到遇到 yield 语句,如果没有 yield 语句,则抛出异常;
  4. 当使用 yield 时,它会自动创建__iter__()__next()__方法,即简单高效地生成了迭代器,gen.__ next__()返回值和使用next()的效果相同也说明了这一点。

生成器函数【暂停】时,会保留当时的上下文环境(即位置和变量);

小结

结合两个例子,总结一下:

  1. yield语句把函数变为一个生成器函数,函数返回值是生成器。
  2. 生成器函数通过yield语句可以简洁地生成__iter__()__next()__方法,即简洁地生成一个迭代器。
  3. 相比于一般函数,使用生成器可以节省内存开销 。
  4. 生成器函数的执行过程看起来就是不断地 执行->中断->执行->中断 的过程。yield语句暂停函数返回迭代值,next()语句唤醒函数继续执行。

参考文章

Python 中的黑暗角落(一):理解 yield 关键字