当将lambda函数作为参数传递时pyo3的性能如何

1 投票
1 回答
80 浏览
提问于 2025-04-14 18:21

我开始研究pyo3,作为测试,我想用pyo3来封装一个rust库;不过,当我把lambda函数作为参数传递时,遇到了一些性能问题。

假设我有一个rust库,里面有一个函数,它接受一个回调函数作为参数。这个函数会多次调用这个回调,然后返回一个结果,比如:

pub fn test_function<F: Fn(f64) -> f64>(cb: F) -> f64 {
    //Not actually an implementation this trivial, it is just to execute the callback a number of times
    (0..225_000).map(|i| cb(i as f64)).sum::<f64>()
}

我试着通过执行这个函数1000次来测量它的执行时间,传入一个回调函数,然后计算平均时间。

use std::time::SystemTime;
use pyo3test::test_function;

pub fn main() {
    let mut sum = 0.0f64;
    let reps = 1_000;
    let start = SystemTime::now();
    for _ in 0..reps {
        sum += test_function(|x| x);
    }
    let end = start.elapsed().unwrap();
    println!("Result: {sum}");
    println!("Duration: {:?}", end.checked_div(reps).unwrap());
}

在我的机器上,当以发布模式运行时,执行一次test_function大约需要250微秒。

然后我尝试用pyo3来封装这个函数,方法如下:

use pyo3::{PyAny, pyfunction, pymodule, PyResult, Python, wrap_pyfunction};
use pyo3::prelude::PyModule;
use crate::test_function;

#[pymodule]
fn pyo3test(_py: Python, m: &PyModule) -> PyResult<()> {
    m.add_function(wrap_pyfunction!(py_test_function, m)?)?;
    Ok(())
}

#[pyfunction]
pub fn py_test_function(function: &PyAny) -> f64 {
    assert!(function.is_callable());
    let cb = move |x| function.call1((x, )).unwrap().extract::<f64>().unwrap();
    test_function(cb)
}

我编译了所有内容(以发布模式),在python中导入这个模块,然后再次测量它的执行时间,执行这个函数1000次并计算平均时间。

import pyo3test
from time import time

reps = 1000
cb = lambda x: x
summation = 0
start = time()
for _ in range(reps):
    summation += pyo3test.py_test_function(cb)
end = time()

duration = end-start
avg = duration/reps
print(avg)

在这种情况下,平均执行时间大约是20毫秒,几乎是纯rust情况下的100倍。我本来不指望执行时间会一样,因为有全局解释器锁(GIL),但我觉得应该接近毫秒级。

这是正常现象吗,还是我漏掉了什么?有没有可能在不改变纯rust实现的情况下改善这个问题?

我试着查看文档,虽然它提到 extract比较慢,但我觉得在这里我无能为力,因为downcast在这里不能使用。还有其他可以做的事情吗?


更新

按照@Ahmed AEK的建议,我写了另一个rust函数:

pub fn test_function_alt(values: &[f64]) -> f64 {
    values.iter().sum::<f64>()
}

我用pyo3封装了它。

#[pyfunction]
pub fn py_test_function_alt(values: Vec<f64>) -> f64 {
    test_function_alt(&values)
}

然后我写了以下python函数:

import numpy as np

def foobar(cb):
    vals = cb(np.arange(225000))
    return pyo3test.py_test_function_alt(vals)

这个函数的执行时间仍然大约是20毫秒。

1 个回答

1

这里的问题不是因为GIL(全局解释器锁),而是因为Python解释器本身比较慢。调用回调函数需要花费20毫秒,这个过程完全是在Python中进行的,没有使用任何外部接口。

你需要重新设计你的API接口,不要频繁地调用Python达到225_000次。可以考虑使用Python的数组或者numpy数组,这些都是专门为将数据传递给本地API而设计的。

撰写回答