如何在Scala中近似Python的或运算符进行集合比较?

6 投票
5 回答
591 浏览
提问于 2025-04-15 16:16

在听完最新的Stack Overflow播客后,彼得·诺维格的紧凑型Python拼写检查器让我很感兴趣,所以我决定用Scala来实现它,看看能不能用函数式编程的方式表达得更好,同时也想知道需要多少行代码。

这是整个问题。(我们暂时不比较代码行数。)

(有两点说明:如果你愿意,可以在Scala解释器中运行这个。如果你需要big.txt的副本,或者整个项目,可以在GitHub上找到。)

import scala.io.Source

val alphabet = "abcdefghijklmnopqrstuvwxyz"

def train(text:String) = {
  "[a-z]+".r.findAllIn(text).foldLeft(Map[String, Int]() withDefaultValue 1)
    {(a, b) => a(b) = a(b) + 1}
}

val NWORDS = train(Source.fromFile("big.txt").getLines.mkString.toLowerCase)

def known(words:Set[String]) =
  {Set.empty ++ (for(w <- words if NWORDS contains w) yield w)}

def edits1(word:String) = {
  Set.empty ++
  (for (i <- 0 until word.length) // Deletes
    yield (word take i) + (word drop (i + 1))) ++ 
  (for (i <- 0 until word.length - 1) // Transposes
    yield (word take i) + word(i + 1) + word(i) + (word drop (i + 2))) ++ 
  (for (i <- 0 until word.length; j <- alphabet) // Replaces
    yield (word take i) + j + (word drop (i+1))) ++ 
  (for (i <- 0 until word.length; j <- alphabet) // Inserts
    yield (word take i) + j + (word drop i))
}

def known_edits2(word:String) = {Set.empty ++ (for (e1 <- edits1(word);
  e2 <- edits1(e1) if NWORDS contains e2) yield e2)}

def correct(word:String) = {
  val options = Seq(() => known(Set(word)), () => known(edits1(word)),
    () => known_edits2(word), () => Set(word))
  val candidates = options.foldLeft(Set[String]())
    {(a, b) => if (a.isEmpty) b() else a}
  candidates.foldLeft("") {(a, b) => if (NWORDS(a) > NWORDS(b)) a else b}
}

具体来说,我在想有没有什么更简洁的方法来处理correct函数。在原来的Python实现中,代码看起来要干净一些:

def correct(word):
    candidates = known([word]) or known(edits1(word)) or
      known_edits2(word) or [word]
    return max(candidates, key=NWORDS.get)

显然在Python中,一个空集合会被判断为布尔值False,所以只有第一个返回非空集合的候选项会被计算,这样可以避免可能耗时的edits1known_edits2的调用。

我能想到的唯一解决方案就是你现在看到的这个版本,在这里,匿名函数的Seq会被调用,直到其中一个返回非空的Set,而最后一个是一定会做到的。

所以,经验丰富的Scala高手们,有没有更简洁或更好的方法来实现这个呢?提前谢谢你们!

5 个回答

4

这样做可以吗?这里的 _ 语法是一个部分应用的函数,通过使用一个(懒惰的) Stream,我确保在 reduceLeft 中的计算(我觉得这里用 reduceLeftfoldLeft 更合适)只有在需要的时候才会发生!

def correct(word:String) = {
  Stream(known(Set(word)) _, 
         known(edits1(word)) _, 
         known_edits2(word) _, 
         Set(word) _
         ).find( !_().isEmpty ) match {
    case Some(candidates) =>
      candidates.reduceLeft {(res, n) => if (NWORDS(res) > NWORDS(n)) res else n}
    case _ => "" //or some other value

}

我可能在这里犯了一些语法错误,但我觉得 Stream 的方法是有效的。

6

我不太明白你为什么要对known使用懒惰求值,而不是像oxbow_lakes那样直接使用流。其实,有一种更好的方法来实现他所做的事情:

def correct(word: String) = {
  import Stream._

  val str = cons(known(Set(word)), 
            cons(known(edits1(word)),
            cons(known_edits2(word),
            cons(Set(word), empty))))

  str find { !_.isEmpty } match {
    case Some(candidates) =>
      candidates.foldLeft(Set[String]()) { (res, n) =>
        if (NWORDS(res) > NWORDS(n)) res else n
      }

    case None => Set()
  }
}

这个方法利用了Stream.cons本身就是懒惰的特点,所以我们不需要把所有东西都包裹在一个延迟执行的函数里。

不过,如果你真的想要更好看的语法,我们可以给这些cons加点语法糖:

implicit def streamSyntax[A](tail: =>Stream[A]) = new {
  def #::(hd: A) = Stream.cons(hd, tail)
}

现在我们之前那个丑陋的str定义变得更简洁了:

def correct(word: String) = {
  val str = known(Set(word)) #:: known(edits1(word)) #::
            known_edits2(word) #:: Set(word) #:: Stream.empty

  ...
}
2

迭代器是懒惰的(虽然它们不是很灵活,因为你只能遍历一次)。所以,你可以这样做:

  def correct(word: String) = {
    val sets = List[String => Set[String]](
      x => known(Set(x)), x => known(edits1(x)), known_edits2
    ).elements.map(_(word))

    sets find { !_.isEmpty } match {
      case Some(candidates: Set[String]) => candidates.reduceLeft { (res, n) => if (NWORDS(res) > NWORDS(n)) res else n }
      case None => word
    }
  }

另外,迭代器的find()方法不会强制计算下一个元素。

撰写回答