Python中的"yield"关键字是什么作用?

12886 投票
50 回答
3325125 浏览
提问于 2025-04-11 19:07

yield 这个关键词在 Python 中有什么功能呢?

比如,我想理解这段代码1

def _get_child_candidates(self, distance, min_dist, max_dist):
    if self._leftchild and distance - max_dist < self._median:
        yield self._leftchild
    if self._rightchild and distance + max_dist >= self._median:
        yield self._rightchild  

这是调用者的部分:

result, candidates = [], [self]
while candidates:
    node = candidates.pop()
    distance = node._get_dist(obj)
    if distance <= max_dist and distance >= min_dist:
        result.extend(node._values)
    candidates.extend(node._get_child_candidates(distance, min_dist, max_dist))
return result

当调用 _get_child_candidates 这个方法时,会发生什么呢?

是返回一个列表吗?还是返回一个单独的元素?这个方法会再次被调用吗?后续的调用什么时候会停止呢?


1. 这段代码是 Jochen Schulz (jrschulz) 写的,他做了一个很棒的 Python 库,用于度量空间。完整的源代码链接是:Module mspace

50 个回答

772

可以这样理解:

迭代器其实就是一个听起来很复杂的词,指的是一个有 next() 方法的对象。所以,当你用 yield 的时候,实际上就是在做这样的事情:

原始版本:

def some_function():
    for i in xrange(4):
        yield i

for i in some_function():
    print i

这基本上就是 Python 解释器对上面代码的处理方式:

class it:
    def __init__(self):
        # Start at -1 so that we get 0 when we add 1 below.
        self.count = -1

    # The __iter__ method will be called once by the 'for' loop.
    # The rest of the magic happens on the object returned by this method.
    # In this case it is the object itself.
    def __iter__(self):
        return self

    # The next method will be called repeatedly by the 'for' loop
    # until it raises StopIteration.
    def next(self):
        self.count += 1
        if self.count < 4:
            return self.count
        else:
            # A StopIteration exception is raised
            # to signal that the iterator is done.
            # This is caught implicitly by the 'for' loop.
            raise StopIteration

def some_func():
    return it()

for i in some_func():
    print i

如果想更深入了解背后的原理,可以把 for 循环改写成这样:

iterator = some_func()
try:
    while 1:
        print iterator.next()
except StopIteration:
    pass

这样说是不是更容易理解,还是让你更困惑了呢? :)

我得说明一下,这其实是为了说明问题而做的简单化处理。 :)

2544

快速理解 yield

当你看到一个包含 yield 的函数时,可以用这个简单的方法来理解它的作用:

  1. 在函数开头插入一行 result = []
  2. 把每个 yield expr 替换成 result.append(expr)
  3. 在函数底部插入一行 return result
  4. 太好了——没有 yield 语句了!可以阅读并理解代码。
  5. 把这个函数和原来的定义进行比较。

这个方法可以让你大致了解函数的逻辑,但实际上,使用 yield 的效果和用列表的方式是有很大不同的。在很多情况下,使用 yield 的方法会更节省内存,并且速度更快。在其他情况下,这个方法可能会让你陷入无限循环,尽管原来的函数运行得很好。继续阅读以了解更多……

不要混淆可迭代对象、迭代器和生成器

首先,了解一下 迭代器协议——当你写下:

for x in mylist:
    ...loop body...

Python 会执行以下两个步骤:

  1. 获取 mylist 的迭代器:

    调用 iter(mylist) ——这会返回一个带有 next() 方法(在 Python 3 中是 __next__())的对象。

    [这是大多数人忘记告诉你的步骤]

  2. 使用迭代器遍历项目:

    不断调用步骤 1 中返回的迭代器的 next() 方法。next() 的返回值会被赋值给 x,然后执行循环体。如果在 next() 中抛出 StopIteration 异常,说明迭代器中没有更多值了,循环就会结束。

实际上,Python 每次想要遍历一个对象的内容时,都会执行上述两个步骤——这可以是一个 for 循环,也可以是像 otherlist.extend(mylist) 这样的代码(其中 otherlist 是一个 Python 列表)。

这里 mylist 是一个 可迭代对象,因为它实现了迭代器协议。在用户自定义的类中,你可以实现 __iter__() 方法,使你的类的实例变得可迭代。这个方法应该返回一个 迭代器。迭代器是一个带有 next() 方法的对象。可以在同一个类中实现 __iter__()next(),并让 __iter__() 返回 self。这在简单情况下是可行的,但如果你想让两个迭代器同时遍历同一个对象,就不行了。

所以这就是迭代器协议,许多对象都实现了这个协议:

  1. 内置的列表、字典、元组、集合和文件。
  2. 实现了 __iter__() 的用户自定义类。
  3. 生成器。

注意,for 循环并不知道它正在处理什么类型的对象——它只是遵循迭代器协议,开心地一个接一个地获取项目,调用 next()。内置的列表一个一个地返回它们的项目,字典一个一个地返回 ,文件一个一个地返回 ,等等。而生成器返回的……这就是 yield 的用武之地:

def f123():
    yield 1
    yield 2
    yield 3

for item in f123():
    print item

如果在 f123() 中有三个 return 语句,只有第一个会被执行,函数就会退出。但 f123() 不是普通的函数。当调用 f123() 时,它 不会 返回任何 yield 语句中的值!它返回的是一个生成器对象。此外,函数并没有真正退出——它进入了一种挂起状态。当 for 循环尝试遍历生成器对象时,函数会从之前返回的 yield 后的下一行继续执行,执行下一行代码,这里是一个 yield 语句,并将其作为下一个项目返回。这种情况会一直持续到函数退出,此时生成器会抛出 StopIteration,循环也会结束。

所以生成器对象有点像一个适配器——一端它遵循迭代器协议,提供 __iter__()next() 方法,让 for 循环开心。而另一端,它只运行函数到获取下一个值的程度,然后又将其放回挂起状态。

为什么使用生成器?

通常,你可以编写不使用生成器但实现相同逻辑的代码。一个选择是使用我之前提到的临时列表“技巧”。但这并不适用于所有情况,比如如果你有无限循环,或者当你有一个非常长的列表时,它可能会低效地使用内存。另一种方法是实现一个新的可迭代类 SomethingIter,它在实例成员中保持状态,并在其 next()(或 Python 3 中的 __next__())方法中执行下一个逻辑步骤。根据逻辑的不同,next() 方法中的代码可能会变得非常复杂,并容易出错。在这种情况下,生成器提供了一个干净且简单的解决方案。

17987
>>> class Bank(): # Let's create a bank, building ATMs
...    crisis = False
...    def create_atm(self):
...        while not self.crisis:
...            yield "$100"
>>> hsbc = Bank() # When everything's ok the ATM gives you as much as you want
>>> corner_street_atm = hsbc.create_atm()
>>> print(corner_street_atm.next())
$100
>>> print(corner_street_atm.next())
$100
>>> print([corner_street_atm.next() for cash in range(5)])
['$100', '$100', '$100', '$100', '$100']
>>> hsbc.crisis = True # Crisis is coming, no more money!
>>> print(corner_street_atm.next())
<type 'exceptions.StopIteration'>
>>> wall_street_atm = hsbc.create_atm() # It's even true for new ATMs
>>> print(wall_street_atm.next())
<type 'exceptions.StopIteration'>
>>> hsbc.crisis = False # The trouble is, even post-crisis the ATM remains empty
>>> print(corner_street_atm.next())
<type 'exceptions.StopIteration'>
>>> brand_new_atm = hsbc.create_atm() # Build a new one to get back in business
>>> for cash in brand_new_atm:
...    print cash
$100
$100
$100
$100
$100
$100
$100
$100
$100
...

要理解 yield 的作用,首先你得明白什么是 生成器。而在理解生成器之前,你需要先了解 可迭代对象

可迭代对象

当你创建一个列表时,你可以一个一个地读取它的元素。一个一个读取元素的过程叫做迭代:

>>> mylist = [1, 2, 3]
>>> for i in mylist:
...    print(i)
1
2
3

mylist 是一个 可迭代对象。当你使用列表推导式时,你创建了一个列表,因此也创建了一个可迭代对象:

>>> mylist = [x*x for x in range(3)]
>>> for i in mylist:
...    print(i)
0
1
4

任何你可以用 "for... in..." 来操作的东西都是可迭代对象,比如 列表字符串、文件等等。

这些可迭代对象很方便,因为你可以随意读取它们,但它们会把所有的值都存储在内存中,当你有很多值时,这并不总是你想要的。

生成器

生成器是一种 迭代器,它是一种只能迭代一次的可迭代对象。生成器不会把所有的值都存储在内存中,而是 动态生成值

>>> mygenerator = (x*x for x in range(3))
>>> for i in mygenerator:
...    print(i)
0
1
4

它和可迭代对象的写法差不多,只是你用 () 而不是 []。但是,你 不能 第二次使用 for i in mygenerator,因为生成器只能用一次:它计算出第一个值后就忘记了,然后计算第二个,依此类推,直到计算完最后一个值。

Yield

yield 是一个关键字,使用方式类似于 return,不过这个函数会返回一个生成器。

>>> def create_generator():
...    mylist = range(3)
...    for i in mylist:
...        yield i*i
...
>>> mygenerator = create_generator() # create a generator
>>> print(mygenerator) # mygenerator is an object!
<generator object create_generator at 0xb7555c34>
>>> for i in mygenerator:
...     print(i)
0
1
4

这里有个没什么用的例子,但当你知道你的函数会返回一大堆值,而你只需要读取一次时,它就很有用了。

要掌握 yield,你必须明白 当你调用这个函数时,函数体里的代码并不会立即执行。 函数只会返回生成器对象,这一点有点复杂。

然后,每次用 for 调用生成器时,代码会从上次停止的地方继续执行。

现在,难点来了:

第一次用 for 调用从你的函数创建的生成器对象时,它会从头开始执行函数里的代码,直到遇到 yield,然后返回循环的第一个值。接下来每次调用都会执行函数里的下一次循环,返回下一个值。这会一直持续到生成器被认为是空的,也就是函数执行完毕而没有再遇到 yield。这可能是因为循环结束了,或者因为某个 "if/else" 条件不再满足。


你的代码解释

生成器:

# Here you create the method of the node object that will return the generator
def _get_child_candidates(self, distance, min_dist, max_dist):

    # Here is the code that will be called each time you use the generator object:

    # If there is still a child of the node object on its left
    # AND if the distance is ok, return the next child
    if self._leftchild and distance - max_dist < self._median:
        yield self._leftchild

    # If there is still a child of the node object on its right
    # AND if the distance is ok, return the next child
    if self._rightchild and distance + max_dist >= self._median:
        yield self._rightchild

    # If the function arrives here, the generator will be considered empty
    # There are no more than two values: the left and the right children

调用者:

# Create an empty list and a list with the current object reference
result, candidates = list(), [self]

# Loop on candidates (they contain only one element at the beginning)
while candidates:

    # Get the last candidate and remove it from the list
    node = candidates.pop()

    # Get the distance between obj and the candidate
    distance = node._get_dist(obj)

    # If the distance is ok, then you can fill in the result
    if distance <= max_dist and distance >= min_dist:
        result.extend(node._values)

    # Add the children of the candidate to the candidate's list
    # so the loop will keep running until it has looked
    # at all the children of the children of the children, etc. of the candidate
    candidates.extend(node._get_child_candidates(distance, min_dist, max_dist))

return result

这段代码有几个聪明的地方:

  • 循环在一个列表上迭代,但这个列表在迭代的过程中还在扩展。这是一种简洁的方式来遍历所有这些嵌套数据,尽管有点危险,因为你可能会陷入无限循环。在这个例子中,candidates.extend(node._get_child_candidates(distance, min_dist, max_dist)) 会耗尽生成器的所有值,但 while 仍然会创建新的生成器对象,这些对象会产生不同于之前的值,因为它们不是在同一个节点上应用的。

  • extend() 方法是列表对象的方法,它期望一个可迭代对象,并将其值添加到列表中。

通常,我们会传递一个列表给它:

>>> a = [1, 2]
>>> b = [3, 4]
>>> a.extend(b)
>>> print(a)
[1, 2, 3, 4]

但在你的代码中,它接收的是一个生成器,这很好,因为:

  1. 你不需要读取值两次。
  2. 你可能有很多子节点,而你不想把它们全部存储在内存中。

而且它能正常工作,因为 Python 不在乎方法的参数是不是列表。Python 期望的是可迭代对象,所以它可以处理字符串、列表、元组和生成器!这被称为鸭子类型,是 Python 很酷的原因之一。但这又是另一个故事,另一个问题……

你可以在这里停下,或者再多读一点,看看生成器的高级用法:

控制生成器的耗尽

注意: 对于 Python 3,使用 print(corner_street_atm.__next__())print(next(corner_street_atm))

这在控制资源访问等各种场景中可能会很有用。

Itertools,你最好的朋友

itertools 模块包含了一些特殊函数,用于操作可迭代对象。有没有想过要复制一个生成器?将两个生成器连接起来?用一行代码将值分组到嵌套列表中?Map / Zip 而不创建另一个列表?

那就只需 import itertools

举个例子?让我们看看四匹马比赛的到达顺序:

>>> horses = [1, 2, 3, 4]
>>> races = itertools.permutations(horses)
>>> print(races)
<itertools.permutations object at 0xb754f1dc>
>>> print(list(itertools.permutations(horses)))
[(1, 2, 3, 4),
 (1, 2, 4, 3),
 (1, 3, 2, 4),
 (1, 3, 4, 2),
 (1, 4, 2, 3),
 (1, 4, 3, 2),
 (2, 1, 3, 4),
 (2, 1, 4, 3),
 (2, 3, 1, 4),
 (2, 3, 4, 1),
 (2, 4, 1, 3),
 (2, 4, 3, 1),
 (3, 1, 2, 4),
 (3, 1, 4, 2),
 (3, 2, 1, 4),
 (3, 2, 4, 1),
 (3, 4, 1, 2),
 (3, 4, 2, 1),
 (4, 1, 2, 3),
 (4, 1, 3, 2),
 (4, 2, 1, 3),
 (4, 2, 3, 1),
 (4, 3, 1, 2),
 (4, 3, 2, 1)]

理解迭代的内部机制

迭代是一个涉及可迭代对象(实现 __iter__() 方法)和迭代器(实现 __next__() 方法)的过程。可迭代对象是你可以从中获取迭代器的任何对象。迭代器是让你能够对可迭代对象进行迭代的对象。

关于这个内容,你可以在这篇文章中了解更多,讲述了 如何使用 for 循环

撰写回答