如何在Scala中近似Python的或运算符进行集合比较?
在听完最新的Stack Overflow播客后,彼得·诺维格的紧凑型Python拼写检查器让我很感兴趣,所以我决定用Scala来实现它,看看能不能用函数式编程的方式表达得更好,同时也想知道需要多少行代码。
这是整个问题。(我们暂时不比较代码行数。)
(有两点说明:如果你愿意,可以在Scala解释器中运行这个。如果你需要big.txt的副本,或者整个项目,可以在GitHub上找到。)
import scala.io.Source
val alphabet = "abcdefghijklmnopqrstuvwxyz"
def train(text:String) = {
"[a-z]+".r.findAllIn(text).foldLeft(Map[String, Int]() withDefaultValue 1)
{(a, b) => a(b) = a(b) + 1}
}
val NWORDS = train(Source.fromFile("big.txt").getLines.mkString.toLowerCase)
def known(words:Set[String]) =
{Set.empty ++ (for(w <- words if NWORDS contains w) yield w)}
def edits1(word:String) = {
Set.empty ++
(for (i <- 0 until word.length) // Deletes
yield (word take i) + (word drop (i + 1))) ++
(for (i <- 0 until word.length - 1) // Transposes
yield (word take i) + word(i + 1) + word(i) + (word drop (i + 2))) ++
(for (i <- 0 until word.length; j <- alphabet) // Replaces
yield (word take i) + j + (word drop (i+1))) ++
(for (i <- 0 until word.length; j <- alphabet) // Inserts
yield (word take i) + j + (word drop i))
}
def known_edits2(word:String) = {Set.empty ++ (for (e1 <- edits1(word);
e2 <- edits1(e1) if NWORDS contains e2) yield e2)}
def correct(word:String) = {
val options = Seq(() => known(Set(word)), () => known(edits1(word)),
() => known_edits2(word), () => Set(word))
val candidates = options.foldLeft(Set[String]())
{(a, b) => if (a.isEmpty) b() else a}
candidates.foldLeft("") {(a, b) => if (NWORDS(a) > NWORDS(b)) a else b}
}
具体来说,我在想有没有什么更简洁的方法来处理correct
函数。在原来的Python实现中,代码看起来要干净一些:
def correct(word):
candidates = known([word]) or known(edits1(word)) or
known_edits2(word) or [word]
return max(candidates, key=NWORDS.get)
显然在Python中,一个空集合会被判断为布尔值False
,所以只有第一个返回非空集合的候选项会被计算,这样可以避免可能耗时的edits1
和known_edits2
的调用。
我能想到的唯一解决方案就是你现在看到的这个版本,在这里,匿名函数的Seq
会被调用,直到其中一个返回非空的Set
,而最后一个是一定会做到的。
所以,经验丰富的Scala高手们,有没有更简洁或更好的方法来实现这个呢?提前谢谢你们!
5 个回答
这样做可以吗?这里的 _
语法是一个部分应用的函数,通过使用一个(懒惰的) Stream
,我确保在 reduceLeft
中的计算(我觉得这里用 reduceLeft
比 foldLeft
更合适)只有在需要的时候才会发生!
def correct(word:String) = {
Stream(known(Set(word)) _,
known(edits1(word)) _,
known_edits2(word) _,
Set(word) _
).find( !_().isEmpty ) match {
case Some(candidates) =>
candidates.reduceLeft {(res, n) => if (NWORDS(res) > NWORDS(n)) res else n}
case _ => "" //or some other value
}
我可能在这里犯了一些语法错误,但我觉得 Stream
的方法是有效的。
我不太明白你为什么要对known
使用懒惰求值,而不是像oxbow_lakes那样直接使用流。其实,有一种更好的方法来实现他所做的事情:
def correct(word: String) = {
import Stream._
val str = cons(known(Set(word)),
cons(known(edits1(word)),
cons(known_edits2(word),
cons(Set(word), empty))))
str find { !_.isEmpty } match {
case Some(candidates) =>
candidates.foldLeft(Set[String]()) { (res, n) =>
if (NWORDS(res) > NWORDS(n)) res else n
}
case None => Set()
}
}
这个方法利用了Stream.cons
本身就是懒惰的特点,所以我们不需要把所有东西都包裹在一个延迟执行的函数里。
不过,如果你真的想要更好看的语法,我们可以给这些cons
加点语法糖:
implicit def streamSyntax[A](tail: =>Stream[A]) = new {
def #::(hd: A) = Stream.cons(hd, tail)
}
现在我们之前那个丑陋的str
定义变得更简洁了:
def correct(word: String) = {
val str = known(Set(word)) #:: known(edits1(word)) #::
known_edits2(word) #:: Set(word) #:: Stream.empty
...
}
迭代器是懒惰的(虽然它们不是很灵活,因为你只能遍历一次)。所以,你可以这样做:
def correct(word: String) = {
val sets = List[String => Set[String]](
x => known(Set(x)), x => known(edits1(x)), known_edits2
).elements.map(_(word))
sets find { !_.isEmpty } match {
case Some(candidates: Set[String]) => candidates.reduceLeft { (res, n) => if (NWORDS(res) > NWORDS(n)) res else n }
case None => word
}
}
另外,迭代器的find()方法不会强制计算下一个元素。