当将lambda函数作为参数传递时pyo3的性能如何
我开始研究pyo3,作为测试,我想用pyo3来封装一个rust库;不过,当我把lambda函数作为参数传递时,遇到了一些性能问题。
假设我有一个rust库,里面有一个函数,它接受一个回调函数作为参数。这个函数会多次调用这个回调,然后返回一个结果,比如:
pub fn test_function<F: Fn(f64) -> f64>(cb: F) -> f64 {
//Not actually an implementation this trivial, it is just to execute the callback a number of times
(0..225_000).map(|i| cb(i as f64)).sum::<f64>()
}
我试着通过执行这个函数1000次来测量它的执行时间,传入一个回调函数,然后计算平均时间。
use std::time::SystemTime;
use pyo3test::test_function;
pub fn main() {
let mut sum = 0.0f64;
let reps = 1_000;
let start = SystemTime::now();
for _ in 0..reps {
sum += test_function(|x| x);
}
let end = start.elapsed().unwrap();
println!("Result: {sum}");
println!("Duration: {:?}", end.checked_div(reps).unwrap());
}
在我的机器上,当以发布模式运行时,执行一次test_function
大约需要250微秒。
然后我尝试用pyo3来封装这个函数,方法如下:
use pyo3::{PyAny, pyfunction, pymodule, PyResult, Python, wrap_pyfunction};
use pyo3::prelude::PyModule;
use crate::test_function;
#[pymodule]
fn pyo3test(_py: Python, m: &PyModule) -> PyResult<()> {
m.add_function(wrap_pyfunction!(py_test_function, m)?)?;
Ok(())
}
#[pyfunction]
pub fn py_test_function(function: &PyAny) -> f64 {
assert!(function.is_callable());
let cb = move |x| function.call1((x, )).unwrap().extract::<f64>().unwrap();
test_function(cb)
}
我编译了所有内容(以发布模式),在python中导入这个模块,然后再次测量它的执行时间,执行这个函数1000次并计算平均时间。
import pyo3test
from time import time
reps = 1000
cb = lambda x: x
summation = 0
start = time()
for _ in range(reps):
summation += pyo3test.py_test_function(cb)
end = time()
duration = end-start
avg = duration/reps
print(avg)
在这种情况下,平均执行时间大约是20毫秒,几乎是纯rust情况下的100倍。我本来不指望执行时间会一样,因为有全局解释器锁(GIL),但我觉得应该接近毫秒级。
这是正常现象吗,还是我漏掉了什么?有没有可能在不改变纯rust实现的情况下改善这个问题?
我试着查看文档,虽然它提到 extract
比较慢,但我觉得在这里我无能为力,因为downcast
在这里不能使用。还有其他可以做的事情吗?
更新
按照@Ahmed AEK的建议,我写了另一个rust函数:
pub fn test_function_alt(values: &[f64]) -> f64 {
values.iter().sum::<f64>()
}
我用pyo3封装了它。
#[pyfunction]
pub fn py_test_function_alt(values: Vec<f64>) -> f64 {
test_function_alt(&values)
}
然后我写了以下python函数:
import numpy as np
def foobar(cb):
vals = cb(np.arange(225000))
return pyo3test.py_test_function_alt(vals)
这个函数的执行时间仍然大约是20毫秒。