Scala中相当于Python生成器的是什么?

61 投票
5 回答
23694 浏览
提问于 2025-04-15 18:28

在Scala中,能不能实现类似Python中yield语句的功能?也就是说,它能记住函数的局部状态,每次被调用时都能“产出”下一个值?

我想要这样的功能,目的是把一个递归函数转换成一个迭代器。大概是这样的:

# this is python
def foo(i):
  yield i
  if i > 0:
    for j in foo(i - 1):
      yield j

for i in foo(5):
  print i

不过,foo可能会更复杂,并且会在一些无环的对象图中递归。

补充说明: 让我再举个更复杂(但仍然简单)的例子: 我可以写一个简单的递归函数,边走边打印东西:

// this is Scala
def printClass(clazz:Class[_], indent:String=""): Unit = {
  clazz match {
    case null =>
    case _ =>
      println(indent + clazz)
      printClass(clazz.getSuperclass, indent + "  ")
      for (c <- clazz.getInterfaces) {
        printClass(c, indent + "  ")
      }
  }
}

理想情况下,我希望有一个库,可以让我轻松地修改几行代码,让它变成一个迭代器:

// this is not Scala
def yieldClass(clazz:Class[_]): Iterator[Class[_]] = {
  clazz match {
    case null =>
    case _ =>
      sudoYield clazz
      for (c <- yieldClass(clazz.getSuperclass)) sudoYield c
      for (c <- clazz.getInterfaces; d <- yieldClasss(c)) sudoYield d
  }
}

看起来继续执行(continuations)可以做到这一点,但我就是不太理解shift/reset的概念。继续执行最终会被纳入主编译器吗?能不能把复杂的部分提取到一个库里?

编辑 2: 可以查看Rich的回答,在那个其他的讨论中。

5 个回答

4

要做到这一点,我觉得你需要用到continuations 插件

下面是一个简单的实现(手写的,没有编译或检查):

def iterator = new {
  private[this] var done = false

  // Define your yielding state here
  // This generator yields: 3, 13, 0, 1, 3, 6, 26, 27
  private[this] var state: Unit=>Int = reset {
    var x = 3
    giveItUp(x)
    x += 10
    giveItUp(x)
    x = 0
    giveItUp(x)
    List(1,2,3).foreach { i => x += i; giveItUp(x) }
    x += 20
    giveItUp(x)
    x += 1
    done = true
    x
  }

  // Well, "yield" is a keyword, so how about giveItUp?
  private[this] def giveItUp(i: Int) = shift { k: (Unit=>Int) =>
    state = k
    i
  }

  def hasNext = !done
  def next = state()
}

这里发生的事情是,任何对shift的调用都会捕捉到从调用它的地方到它所在的reset块结束的控制流。这部分信息会作为k参数传递给shift函数。

所以,在上面的例子中,每次调用giveItUp(x)都会返回x的值(到目前为止),并把剩下的计算过程保存在state变量中。这个过程是通过外部的hasNextnext方法来驱动的。

请轻点,这显然不是实现这个功能的好方法。但这是我晚上没带编译器时能想到的最好办法。

12

这是另一种基于“继续”插件的解决方案,这次使用了一个相对封装的生成器类型。

import scala.continuations._
import scala.continuations.ControlContext._

object Test {

  def loopWhile(cond: =>Boolean)(body: =>(Unit @suspendable)): Unit @suspendable = {
    if (cond) {
      body
      loopWhile(cond)(body)
    } else ()
  }

  abstract class Generator[T] {
    var producerCont : (Unit => Unit) = null
    var consumerCont : (T => Unit) = null

    protected def body : Unit @suspendable

    reset {
      body
    }

    def generate(t : T) : Unit @suspendable =
      shift {
        (k : Unit => Unit) => {
          producerCont = k
          if (consumerCont != null)
            consumerCont(t)
        }
      }

    def next : T @suspendable =
      shift {
        (k : T => Unit) => {
          consumerCont = k
          if (producerCont != null)
            producerCont()
        }
      }
  }

  def main(args: Array[String]) {
    val g = new Generator[Int] {
      def body = {
        var i = 0
        loopWhile(i < 10) {
          generate(i)
          i += 1
        }
      }
    }

    reset {
      loopWhile(true) {
        println("Generated: "+g.next)
      }
    }
  }
}
34

虽然Python的生成器很酷,但在Scala中尝试复制它们并不是最好的方法。比如,下面的代码可以完成你想要的功能:

def classStream(clazz: Class[_]): Stream[Class[_]] = clazz match {
  case null => Stream.empty
  case _ => (
    clazz 
    #:: classStream(clazz.getSuperclass) 
    #::: clazz.getInterfaces.toStream.flatMap(classStream) 
    #::: Stream.empty
  )
}

在这段代码中,流是懒加载的,也就是说,它不会处理任何元素,直到你请求它,这一点你可以通过运行以下代码来验证:

def classStream(clazz: Class[_]): Stream[Class[_]] = clazz match {
  case null => Stream.empty
  case _ => (
    clazz 
    #:: { println(clazz.toString+": super"); classStream(clazz.getSuperclass) } 
    #::: { println(clazz.toString+": interfaces"); clazz.getInterfaces.toStream.flatMap(classStream) } 
    #::: Stream.empty
  )
}

结果可以通过简单地调用结果的Stream上的.iterator来转换成一个Iterator

def classIterator(clazz: Class[_]): Iterator[Class[_]] = classStream(clazz).iterator

使用Streamfoo定义可以这样写:

scala> def foo(i: Int): Stream[Int] = i #:: (if (i > 0) foo(i - 1) else Stream.empty)
foo: (i: Int)Stream[Int]

scala> foo(5) foreach println
5
4
3
2
1
0

另一种选择是将不同的迭代器连接起来,注意不要提前计算它们。这里有一个例子,里面还有调试信息,帮助你跟踪执行过程:

def yieldClass(clazz: Class[_]): Iterator[Class[_]] = clazz match {
  case null => println("empty"); Iterator.empty
  case _ =>
    def thisIterator = { println("self of "+clazz); Iterator(clazz) }
    def superIterator = { println("super of "+clazz); yieldClass(clazz.getSuperclass) }
    def interfacesIterator = { println("interfaces of "+clazz); clazz.getInterfaces.iterator flatMap yieldClass }
    thisIterator ++ superIterator ++ interfacesIterator
}

这和你的代码很接近。不同的是,我用的是定义,而不是sudoYield,然后我可以根据需要将它们连接起来。

所以,虽然这不是直接的答案,但我觉得你走错方向了。在Scala中写Python的代码肯定不会有效率。你应该更努力地学习Scala的用法,这样才能实现相同的目标。

撰写回答