在Cython中遍历字节/Unicode字符串的最佳方法

5 投票
1 回答
2636 浏览
提问于 2025-04-17 18:40

我刚开始学习Cython,发现网上找Cython相关的资料也挺难的,所以提前说声抱歉。

我正在用Cython重新实现一个Python函数。这个函数在Python中的样子大概是这样的:

def func(s, numbers=None):
    if numbers:
         some_dict = numbers
    else:
         some_dict = default
    return sum(some_dict[c] for c in s)

在Python 2和3上都能正常工作。但是如果我尝试给变量 sc 赋值,至少在一个Python版本上就会出错。我尝试了:

def func(char *s, numbers=None):
    if numbers:
         some_dict = numbers
    else:
         some_dict = default
    cdef char c
    cdef double m = 0.0
    for c in s:
        m += some_dict[<bytes>c]
    return m

老实说,这是我唯一能让它工作的方法,它在Python 2上速度提升不错,但在Python 3上就出问题了。我看了这篇Cython文档,以为下面的代码在Python 3上能行:

def func(unicode s, numbers=None):
    if numbers:
         some_dict = numbers
    else:
         some_dict = default
    cdef double m = 0.0
    for c in s:
        m += some_dict[c]
    return m

但实际上它抛出了一个 KeyError 错误,看起来 c 还是个 char(如果 s'P' 开头,缺失的键是 80),但当我用 print(type(c)) 打印时,它显示 <class 'str'>

需要注意的是,原来的未指定类型的代码在两个版本下都能工作,但速度大约是Python 2上工作类型版本的两倍慢。

那么,我该怎么才能让它在Python 3上工作呢?然后又该如何让它在两个Python版本上都能用?我可以/应该在类型声明中加上版本检查吗?还是说我应该写两个函数,然后根据条件把其中一个赋值给一个公开可用的名字?

附言:如果只允许字符串中的ASCII字符,我也没问题,但我怀疑这是否重要,因为Cython似乎更倾向于明确的编码/解码。


编辑:我也尝试了明确的编码和遍历字节字符串,这样做是有道理的,但以下代码:

def func(s, numbers=None):
    if numbers:
         some_dict = numbers
    else:
         some_dict = default
    cdef double m = 0.0
    cdef bytes bs = s.encode('ascii')
    cdef char c
    for c in bs:
        m += some_dict[(<bytes>c).decode('ascii')]
    return m

在Python 2上比我第一次尝试的慢了3倍(接近纯Python函数的速度),在Python 3上几乎慢了2倍。

1 个回答

0

foo.h

// #include <unistd.h>;  // for ssize_t
double foo(char * str, ssize_t str_len, double weights[256]){
    double output = 0.0;
    int i;
    for(i = 0; i < str_len; ++i){
        output += weights[str[i]];
    }
    return output;
}

from cpython.string cimport PyString_GET_SIZE, PyString_Check, PyString_AS_STRING

cdef extern from "foo.h":
    double foo(char * str, ssize_t str_len, double weights[256])   

cdef class Numbers:
    cdef double nums[256]

    def __cinit__(self, py_numbers):
        for x in range(256):
            self.nums[i] = py_numbers[i]

def py_foo(my_str, Numbers nums_inst):
    cdef:
        double res
    # check here my_str is BYTEstring
    if not PyString_Check(my_str):
        raise TypeError("bytestring expected got %s instead" % type(my_str))
    res = foo(PyString_AS_STRING(my_str), PyString_GET_SIZE(my_str), nums_inst.nums)
    return res

(未经测试)

撰写回答