我能否修正R中paste0()的使用,使其运行速度与原始Python示例相同?

8 投票
2 回答
3168 浏览
提问于 2025-04-18 08:08

我最近在玩一些R语言的代码,这些代码模仿了Norvig的拼写检查器,它是用Python写的。特别是,我想搞清楚如何在R中正确实现edit2这个函数:

def splits(word):
    return [(word[:i], word[i:]) 
            for i in range(len(word)+1)]

def edits1(word):
    pairs      = splits(word)
    deletes    = [a+b[1:]           for (a, b) in pairs if b]
    transposes = [a+b[1]+b[0]+b[2:] for (a, b) in pairs if len(b) > 1]
    replaces   = [a+c+b[1:]         for (a, b) in pairs for c in alphabet if b]
    inserts    = [a+c+b             for (a, b) in pairs for c in alphabet]
    return set(deletes + transposes + replaces + inserts)

def edits2(word):
    return set(e2 for e1 in edits1(word) for e2 in edits1(e1))

不过,在我的测试中,使用paste0(或者stringr里的str_c,或者stringi里的stri_join)在R中生成成千上万的小字符串,结果发现这些代码的速度大约比Norvig展示的Python实现慢10倍(或者大约100倍,或者50倍)。(是的,使用stringr和stringi的函数,速度甚至比用paste0还要慢。) 我有几个问题(其中第3个是我最想解决的):

  1. 我这样做是对的吗(代码“正确”吗)?

  2. 如果是的话,这是不是R的一个已知问题(字符串连接极其慢)?

  3. 有没有什么办法可以让这个显著加快速度(至少快一个或多个数量级),而不需要把整个函数重写成Rcpp11或者类似的东西?

这是我为edit2函数写的R代码:

# 1. generate a list of all binary splits of a word
binary.splits <- function(w) {
  n <- nchar(w)
  lapply(0:n, function(x)
         c(stri_sub(w, 0, x), stri_sub(w, x + 1, n)))
}

# 2. generate a list of all bigrams for a word
bigram.unsafe <- function(word)
  sapply(2:nchar(word), function(i) substr(word, i-1, i))
bigram <- function(word)
  if (nchar(word) > 1) bigram.unsafe(word) else word

# 3. four edit types: deletion, transposition, replacement, and insertion
alphabet = letters
deletions <- function(splits) if (length(splits) > 1) {
   sapply(1:(length(splits)-1), function(i)
          paste0(splits[[i]][1], splits[[i+1]][2]), simplify=FALSE) 
} else {
    splits[[1]][2]
}   
transpositions <- function(splits) if (length(splits) > 2) {
  swaps <- rev(bigram.unsafe(stri_reverse(splits[[1]][2])))
  sapply(1:length(swaps), function(i)
         paste0(splits[[i]][1], swaps[i], splits[[i+2]][2]), simplify=FALSE)
} else {
  stri_reverse(splits[[1]][2])
} 
replacements <- function(splits) if (length(splits) > 1) {
  sapply(1:(length(splits)-1), function(i)
         lapply(alphabet, function(symbol)
                paste0(splits[[i]][1], symbol, splits[[i+1]][2])))
} else {
  alphabet
} 
insertions <- function(splits)
  sapply(splits, function(pair)
         lapply(alphabet, function(symbol)
                paste0(pair[1], symbol, pair[2]))) 

# 4. create a vector of all words at edit distance 1 given the input word
edit.1 <- function(word) {
  splits <- binary.splits(word)
  unique(unlist(c(deletions(splits),
                  transpositions(splits),
                  replacements(splits),
                  insertions(splits))))
}

# 5. create a simple function to generate all words of edit distance 1 and 2
edit.2 <- function(word) { 
  e1 <- edit.1(word)
  unique(c(unlist(lapply(e1, edit.1)), e1))
} 

如果你开始分析这段代码,你会发现replacementsinsertions有嵌套的“lapplies”,似乎比deletionstranspositions慢10倍,因为它们生成的拼写变体要多得多。

library(rbenchmark)
benchmark(edit.2('abcd'), replications=20)

在我的Core i5 MacBook Air上,这段代码大约需要8秒,而对应的Python基准测试(运行对应的edit2函数20次)只需要大约0.6秒,也就是说,Python快了大约10到15倍!

我尝试使用expand.grid来去掉内部的lapply,但这反而让代码变得更慢,而我知道用lapply替代sapply可以让我的代码稍微快一点,但我觉得没有必要为了小幅度的速度提升而使用“错误”的函数(我想要返回一个向量)。不过,也许在纯R中生成edit.2函数的结果可以变得更快呢?

2 个回答

1

根据@LukeTierney在问题评论中关于将paste0调用向量化的建议,我对函数进行了修改,使其能够正确地进行向量化。我还按照@MartinMorgan在他的回答中提到的额外修改进行了调整:使用单个后缀来删除项目,而不是使用选择范围(也就是说,用"[-1]"代替"[2:n]"等;但需要注意的是:对于多个后缀,在transpositions中使用时,这实际上会更慢),特别是使用rep进一步向量化replacementsinsertions中的paste0调用。

这样做的结果是实现了在R中执行edit.2的最佳答案(到目前为止?感谢Luke和Martin!)。换句话说,在Luke提供的主要提示和Martin后续的一些改进下,R的实现速度大约是Python的一半(但请参见Martin在他下面的回答中的最终评论)。(函数edit.1edit.2bigram.unsafe保持不变,如上所示。)

binary.splits <- function(w) {
  n <- nchar(w)
  list(left=stri_sub(w, rep(0, n + 1), 0:n),
       right=stri_sub(w, 1:(n + 1), rep(n, n + 1)))
}

deletions <- function(splits) {
  n <- length(splits$left)
  if (n > 1) paste0(splits$left[-n], splits$right[-1])
  else splits$right[1]
}
transpositions <- function(splits) if (length(splits$left) > 2) {
  swaps <- rev(bigram.unsafe(stri_reverse(splits$right[1])))
  paste0(splits$left[1:length(swaps)], swaps,
         splits$right[3:length(splits$right)])
} else {
  stri_reverse(splits$right[1])
}
replacements <- function(splits) {
  n <- length(splits$left)
  if (n > 1) paste0(splits$left[-n],
                    rep(alphabet, each=n-1),
                    splits$right[-1])
  else alphabet
}
insertions <- function(splits)
  paste0(splits$left,
         rep(alphabet, each=length(splits$left)),
         splits$right)

总体来说,结束这个练习,Luke和Martin的建议使得R的实现速度大约是最开始展示的Python代码的一半,提升了我原始代码的效率大约6倍。不过,最后让我更担心的是两个不同的问题:(1)R代码似乎要冗长得多(行数,可能还需要稍微优化一下),以及(2)即使是稍微偏离“正确向量化”,R代码的表现就会很糟糕,而在Python中,稍微偏离“正确的Python”通常不会有如此极端的影响。尽管如此,我会继续努力编写高效的R代码 - 感谢所有参与的人!

4

R的paste0和Python的''.join性能比较

最开始的问题是问R中的paste0是否比Python的字符串连接慢10倍。如果真是这样,那在R中写一个依赖于字符串连接的算法就没希望能和Python的对应算法一样快了。

我有

> R.version.string
[1] "R version 3.1.0 Patched (2014-05-31 r65803)"

>>> sys.version '3.4.0 (default, Apr 11 2014, 13:05:11) \n[GCC 4.8.2]'

这是第一次比较

> library(microbenchmark)
> microbenchmark(paste0("a", "b"), times=1e6)
Unit: nanoseconds
             expr min   lq median   uq      max neval
 paste0("a", "b") 951 1071   1162 1293 21794972 1e+06

(所以所有重复的时间大约是1秒)和

>>> import timeit
>>> timeit.timeit("''.join(x)", "x=('a', 'b')", number=int(1e6))
0.119668865998392

我想这就是最初提问者观察到的10倍性能差异。不过,R在处理向量时表现更好,而这个算法本身就是在处理单词的向量,所以我们可能会对以下比较感兴趣

> x = y = sample(LETTERS, 1e7, TRUE); system.time(z <- paste0(x, y))
   user  system elapsed 
  1.479   0.009   1.488 

>>> setup = '''
import random
import string
y = x = [random.choice(string.ascii_uppercase) for _ in range(10000000)]
'''
>>> timeit.Timer("map(''.join, zip(x, y))", setup=setup).repeat(1)
[0.362522566007101]

这表明,如果我们的R算法运行速度是Python的1/4,那我们就走在正确的道路上;而提问者发现的10倍差异,看来还有改进的空间。

R中的迭代与向量化

提问者使用了迭代(lapply等),而不是向量化。我们可以将向量版本与各种迭代方法进行比较,使用以下代码

f0 = paste0

f1 = function(x, y) 
   vapply(seq_along(x), function(i, x, y) paste0(x[i], y[i]), character(1), x, y)

f2 = function(x, y) Map(paste0, x, y)

f3 = function(x, y) {
    z = character(length(x))
    for (i in seq_along(x)) 
        z[i] = paste0(x[i], y[i])
    z 
}

f3c = compiler::cmpfun(f3)    # explicitly compile

f4 = function(x, y) {
    z = character()
    for (i in seq_along(x)) 
        z[i] = paste0(x[i], y[i])
    z 
}

将数据缩放回去,把“向量化”的解决方案定义为f0,并比较这些方法

> x = y = sample(LETTERS, 100000, TRUE)
> library(microbenchmark)
> microbenchmark(f0(x, y), f1(x, y), f2(x, y), f3(x, y), f3c(x, y), times=5)
Unit: milliseconds
      expr       min        lq    median        uq       max neval
  f0(x, y)  14.69877  14.70235  14.75409  14.98777  15.14739     5
  f1(x, y) 241.34212 250.19018 268.21613 279.01582 292.21065     5
  f2(x, y) 198.74594 199.07489 214.79558 229.50684 271.77853     5
  f3(x, y) 250.64388 251.88353 256.09757 280.04688 296.29095     5
 f3c(x, y) 174.15546 175.46522 200.09589 201.18543 214.18290     5

其中f4慢得让人痛苦,所以不包括在内

> system.time(f4(x, y))
   user  system elapsed 
 24.325   0.000  24.330 

从这里可以看出,Tierney博士的建议是,向量化那些lapply调用可能会有好处。

进一步向量化更新后的原始代码

@fnl对原始代码进行了改进,部分展开了循环。还有更多机会可以继续这样做,比如

replacements <- function(splits) if (length(splits$left) > 1) {
  lapply(1:(length(splits$left)-1), function(i)
         paste0(splits$left[i], alphabet, splits$right[i+1]))
} else {
  splits$right[1]
}

可以修改为执行一次paste调用,依赖于参数回收(短向量会重复使用,直到它们的长度与长向量匹配)

replacements1 <- function(splits) if (length(splits$left) > 1) {
    len <- length(splits$left)
    paste0(splits$left[-len], rep(alphabet, each = len - 1), splits$right[-1])
} else {
  splits$right[1]
}

虽然值的顺序不同,但这对算法并不重要。去掉下标(前面加-)可能会更节省内存。同样

deletions1 <- function(splits) if (length(splits$left) > 1) {
  paste0(splits$left[-length(splits$left)], splits$right[-1])
} else {
  splits$right[1]
}

insertions1 <- function(splits)
    paste0(splits$left, rep(alphabet, each=length(splits$left)), splits$right)

我们接下来有

edit.1.1 <- function(word) {
  splits <- binary.splits(word)
  unique(c(deletions1(splits),
           transpositions(splits),
           replacements1(splits),
           insertions1(splits)))
}

并且有了一些速度提升

> identical(sort(edit.1("word")), sort(edit.1.1("word")))
[1] TRUE
> microbenchmark(edit.1("word"), edit.1.1("word"))
Unit: microseconds
             expr     min       lq   median       uq     max neval
   edit.1("word") 354.125 358.7635 362.5260 372.9185 521.337   100
 edit.1.1("word") 296.575 298.9830 300.8305 307.3725 369.419   100

提问者表示他们的原始版本比Python慢10倍,而他们的原始修改使速度提升了5倍。我们又获得了1.2倍的速度提升,所以可能达到了使用R的paste0时算法的预期性能。下一步是问问是否有其他算法或实现更高效,特别是substr可能会很有前景。

撰写回答