在Python类中实现带__iter__方法的递归函数

2 投票
2 回答
677 浏览
提问于 2025-04-30 20:02

我正在解决一个问题,需要创建一个Python类来生成一个列表的所有排列组合。现在我遇到了以下几个问题:

  1. 我可以很简单地用一个递归函数来完成这个任务,但作为一个类,我觉得应该使用iter方法。我的方法调用了一个递归函数(list_all),这个函数和我的iter几乎一模一样,这让我觉得很不安。请问我该如何修改我的递归函数,以符合iter的最佳实践呢?
  2. 我写了这段代码,看到它能工作,但我感觉自己并没有完全理解它!我尝试逐行跟踪代码在一个测试案例中的表现,但对我来说,列表中的第一个元素每次都像是被冻结了,剩下的列表则被打乱了。然而,输出的结果却是意想不到的顺序。我有些搞不懂!

谢谢!

class permutations():
  def __init__(self, ls):
    self.list = ls

  def __iter__(self):
    ls = self.list
    length = len(ls)
    if length <= 1:
      yield ls
    else:
      for p in self.list_all(ls[1:]):
        for x in range(length):
          yield p[:x] + ls[0:1] + p[x:]  

  def list_all(self, ls):
    length = len(ls)
    if length <= 1:
      yield ls
    else:
      for p in self.list_all(ls[1:]):
        for x in range(length):
          yield p[:x] + ls[0:1] + p[x:]
暂无标签

2 个回答

0

你的 list_all 方法已经是一个生成器了,所以你可以直接在 __iter__ 中返回它:

class permutations():
    def __init__(self, ls):
        self.list = ls

    def __iter__(self):
        return self.list_all(self.list)

    def list_all(self, ls):
        length = len(ls)
        if length <= 1:
            yield ls
        else:
            for p in self.list_all(ls[1:]):
                for x in range(length):
                    yield p[:x] + ls[0:1] + p[x:]

这样写起来更简洁,运行速度也更快。

你还可以选择在 __iter__ 里面定义 list_all

class permutations2():
    def __init__(self, ls):
        self.list = ls

    def __iter__(self):
        def list_all(ls):
            length = len(ls)
            if length <= 1:
                yield ls
            else:
                for p in list_all(ls[1:]):
                    for x in range(length):
                        yield p[:x] + ls[0:1] + p[x:]
                    
        return list_all(self.list)

对比 permutations 和我的 permutations2 的运行时间,结果几乎是一样的。

2

只需要在 __iter__ 里面调用 self.list_all 就可以了:

class permutations():
  def __init__(self, ls):
    self.list = ls

  def __iter__(self):
    for item in self.list_all(self.list):
      yield item

  def list_all(self, ls):
    length = len(ls)
    if length <= 1:
      yield ls
    else:
      for p in self.list_all(ls[1:]):
        for x in range(length):
          yield p[:x] + ls[0:1] + p[x:]

撰写回答