Cython Gibbs采样器比numpy稍慢

2024-04-25 21:21:43 发布

您现在位置:Python中文网/ 问答频道 /正文

我实现了一个吉布斯采样器来生成纹理图像。根据beta参数(形状数组(4)),我们可以生成各种纹理。在

下面是我使用Numpy的初始函数:

def gibbs_sampler(img_label, betas, burnin, nb_samples):
    nb_iter = burnin + nb_samples

    lst_samples = []

    labels = np.unique(img)

    M, N = img.shape
    img_flat = img.flatten()

    # build neighborhood array by means of numpy broadcasting:
    m, n = np.ogrid[0:M, 0:N]

    top_left, top, top_right =   m[0:-2, :]*N + n[:, 0:-2], m[0:-2, :]*N + n[:, 1:-1]  , m[0:-2, :]*N + n[:, 2:]
    left, pix, right = m[1:-1, :]*N + n[:, 0:-2],  m[1:-1, :]*N + n[:, 1:-1], m[1:-1, :]*N + n[:, 2:]
    bottom_left, bottom, bottom_right = m[2:, :]*N + n[:, 0:-2],  m[2:, :]*N + n[:, 1:-1], m[2:, :]*N + n[:, 2:]

    mat_neigh = np.dstack([pix, top, bottom, left, right, top_left, bottom_right, bottom_left, top_right])

    mat_neigh = mat_neigh.reshape((-1, 9))    
    ind = np.arange((M-2)*(N-2))  

    # loop over iterations
    for iteration in np.arange(nb_iter):

        np.random.shuffle(ind)

        # loop over pixels
        for i in ind:                  

            truc = map(functools.partial(lambda label, img_flat, mat_neigh : 1-np.equal(label, img_flat[mat_neigh[i, 1:]]).astype(np.uint), img_flat=img_flat, mat_neigh=mat_neigh), labels)
            # bidule is of shape (4, 2, labels.size)
            bidule = np.array(truc).T.reshape((-1, 2, labels.size))

            # theta is of shape (labels.size, 4) 
            theta = np.sum(bidule, axis=1).T
            # prior is thus an array of shape (labels.size)
            prior = np.exp(-np.dot(theta, betas))

            # sample from the posterior
            drawn_label = np.random.choice(labels, p=prior/np.sum(prior))

            img_flat[(i//(N-2) + 1)*N + i%(N-2) + 1] = drawn_label


        if iteration >= burnin:
            print('Iteration %i --> sample' % iteration)
            lst_samples.append(copy.copy(img_flat.reshape(M, N)))

        else:
            print('Iteration %i --> burnin' % iteration)

    return lst_samples

我们不能摆脱任何循环,因为它是一个迭代算法。因此,我尝试使用Cython(带静态类型)来加快速度:

^{pr2}$

然而,我最终得到了几乎相同的计算时间,numpy版本比Cython版本稍快一些。在

因此,我试图改进Cython代码。在

编辑:

对于两个函数(Cython和no Cython): 我替换了:

truc = map(functools.partial(lambda label, img_flat, mat_neigh : 1-np.equal(label, img_flat[mat_neigh[i, 1:]]).astype(np.uint), img_flat=img_flat, mat_neigh=mat_neigh), labels)

广播方式:

truc = 1-np.equal(labels[:, None], img_flat[mat_neigh[i, 1:]][None, :])

所有的np.arange都是range,先验值的计算现在由Divakar建议的np.einsum来完成。在

这两个函数都比以前快,但是Python函数仍然比Cython函数稍快。在


Tags: 函数rightimglabelstopnpleftlabel
3条回答

如果您希望加快NumPy代码的速度,我们可以改进最内层循环的性能,希望这会转化为一些整体的改进。在

因此,我们有:

theta = np.sum(bidule, axis=1).T
prior = np.exp(-np.dot(theta, betas))

将求和和和与矩阵乘法结合在一起,我们可以-

^{pr2}$

现在,这包括沿着一个轴求和,然后在元素乘法后求和。在许多工具中,我们有np.einsum来帮助我们,特别是因为我们可以一次性地执行这些缩减,比如-

^{3}$

运行时测试-

In [98]: # Setup
    ...: N = 100
    ...: bidule = np.random.rand(4,2,N)
    ...: betas = np.random.rand(4)
    ...: 

In [99]: %timeit np.dot(np.sum(bidule, axis=1).T, betas)
100000 loops, best of 3: 12.4 µs per loop

In [100]: %timeit np.einsum('ijk,i->k',bidule,betas)
100000 loops, best of 3: 4.05 µs per loop

In [101]: # Setup
     ...: N = 10000
     ...: bidule = np.random.rand(4,2,N)
     ...: betas = np.random.rand(4)
     ...: 

In [102]: %timeit np.dot(np.sum(bidule, axis=1).T, betas)
10000 loops, best of 3: 157 µs per loop

In [103]: %timeit np.einsum('ijk,i->k',bidule,betas)
10000 loops, best of 3: 90.9 µs per loop

所以,希望在运行多次迭代时,加速会很明显。在

我已经在您的源代码上运行了Cython in annotated mode,并查看了结果。也就是说,将它保存在q.pyx中之后,我运行

cython -a q.pyx
firefox q.html

(当然,可以使用任何浏览器)。在

代码被涂成深黄色,这表明,就Cython而言,代码远不是静态类型的。这可分为两类。在

在某些情况下,最好静态键入代码:

  1. for iteration in np.arange(nb_iter):for i in ind:中,每次迭代大约要花30条C行。请参阅here如何在Cython中有效地访问numpy数组。

  2. truc = map(functools.partial(func_for_map, img_flat=img_flat, mat_neigh=mat_neigh, i=i), labels)中,静态类型并没有给您带来任何好处。我建议您cdef函数func_for_map,并在循环中自己调用它。

在其他情况下,调用numpy向量化函数,例如theta = np.sum(bidule, axis=1).Tprior = np.exp(-np.dot(theta, betas)).astype(DOUBLETYPE)等。在这些情况下,Cython实际上没有什么好处。在

This answer很好地解释了为什么Numpy效率低下,而您仍然想使用Cython。基本上:

  • 小数组的开销(也减少了像您的np.sum(bidule, axis=1)
  • 由于中间层,大型阵列的缓存抖动。在

在这种情况下,为了从Cython中获益,您必须用Python循环替换Numpy数组操作,Cython必须能够将其转换为C代码,否则就没有意义了。这并不意味着你必须重写所有的Numpy函数,你必须聪明一点。在

例如,您应该去掉mat_neighbidule数组,只在循环中执行索引和求和。在

另一方面,您应该保留(规范化的)prior数组,并继续使用np.random.choice。真的没有一个简单的方法来解决这个问题(嗯。。见source for ^{})。不幸的是,这意味着这一部分可能会成为性能瓶颈。在

相关问题 更多 >