如何使用numba加速多维logsumexp和softmax

2 投票
1 回答
61 浏览
提问于 2025-04-13 14:28

我想用多项式逻辑回归来计算softmax(也就是概率),同时使用logsumexp来避免溢出。使用numba可以让我速度提高2到3倍。我还能做得更好吗?另外,当我使用fastmath=True时,似乎并没有加快速度,那我是不是把numba的循环写错了?

import numba
import numpy as np
def get_p_4d(a, lamda):
    """Return softmax probabilities along axis 2 of ``a`` scaled by ``lamda``.

    Every slice ``a[i, j]`` is multiplied by the scalar ``lamda[i, j]`` and a
    max-shifted log-sum-exp softmax is taken over axis 2, so large scores do
    not overflow ``exp``.

    Args:
        a: 4-D array of scores.
        lamda: 2-D array indexed by the first two axes of ``a``.

    Returns:
        Array shaped like ``a`` whose entries sum to 1 along axis 2.
    """
    # Broadcast one multiplier per (i, j) over the two trailing axes.
    scaled = a * lamda[:, :, None, None]
    # Per-column maximum, kept as a length-1 axis so it broadcasts back.
    peak = scaled.max(axis=2, keepdims=True)
    shifted_sum = np.exp(scaled - peak).sum(axis=2, keepdims=True)
    log_norm = peak + np.log(shifted_sum)
    return np.exp(scaled - log_norm)

@numba.njit()
def get_p_4d_nb(a, lamda, num_code, num_draw, num_action):
    """Numba-compiled softmax over axis 2 of ``a`` scaled by ``lamda``.

    For each ``(i, j)`` the slice ``a[i, j]`` is multiplied by ``lamda[i, j]``
    and a max-shifted log-sum-exp softmax is taken over its first axis, whose
    length is hard-coded to 3.

    Args:
        a: 4-D array indexed as ``a[i, j, l, k]`` with ``l`` in ``0..2``.
        lamda: 2-D array of per-``(i, j)`` multipliers.
        num_code: Number of ``i`` indices to process.
        num_draw: Number of ``j`` indices to process.
        num_action: Number of ``k`` indices to process.

    Returns:
        Array of shape ``(num_code, num_draw, 3, num_action)``; it is a
        transposed view of the buffer allocated inside the function.
    """
    out = np.empty((num_code, num_draw, num_action, 3))
    # Swap the last two axes so the three softmax terms sit on the last axis.
    swapped = a.transpose(0, 1, 3, 2)
    for code in range(num_code):
        for draw in range(num_draw):
            scale = lamda[code, draw]
            for act in range(num_action):
                # Scaled scores are parked in the output buffer first.
                out[code, draw, act, 0] = swapped[code, draw, act, 0] * scale
                out[code, draw, act, 1] = swapped[code, draw, act, 1] * scale
                out[code, draw, act, 2] = swapped[code, draw, act, 2] * scale

                # Shift by the largest score before exponentiating.
                peak = max(
                    max(out[code, draw, act, 0], out[code, draw, act, 1]),
                    out[code, draw, act, 2],
                )
                exp_sum = (
                    np.exp(out[code, draw, act, 0] - peak)
                    + np.exp(out[code, draw, act, 1] - peak)
                    + np.exp(out[code, draw, act, 2] - peak)
                )
                log_norm = np.log(exp_sum) + peak

                # Overwrite the scores with their normalized probabilities.
                out[code, draw, act, 0] = np.exp(out[code, draw, act, 0] - log_norm)
                out[code, draw, act, 1] = np.exp(out[code, draw, act, 1] - log_norm)
                out[code, draw, act, 2] = np.exp(out[code, draw, act, 2] - log_norm)

    return out.transpose(0, 1, 3, 2)

# Example inputs: all-ones scores and one uniform random multiplier per (i, j).
a = np.ones((112, 1000, 3, 3))
lamda = np.random.uniform(0., 1., size=(112, 1000))

# Run the NumPy and the numba implementations once on the same inputs.
get_p_4d(a, lamda)
get_p_4d_nb(a, lamda, a.shape[0], a.shape[1], a.shape[3])

1 个回答

0

你可以尝试把这个任务进行并行处理(我还稍微简化了一下代码,使用了切片 0:3;注意输出数组的 dtype 也改成了 float32):

@numba.njit(parallel=True)
def get_p_4d_nb_parallel(a, lamda, num_code, num_draw, num_action):
    """Softmax over axis 2 of ``a`` scaled by ``lamda``, with a ``prange`` outer loop.

    The outermost loop over ``i`` uses ``numba.prange``; the three softmax
    terms of each ``(i, j, k)`` are handled together through a ``0:3`` slice.
    The output buffer is allocated as float32, so stored values are single
    precision.

    Args:
        a: 4-D array indexed as ``a[i, j, l, k]`` with ``l`` in ``0..2``.
        lamda: 2-D array of per-``(i, j)`` multipliers.
        num_code: Number of ``i`` indices to process.
        num_draw: Number of ``j`` indices to process.
        num_action: Number of ``k`` indices to process.

    Returns:
        float32 array of shape ``(num_code, num_draw, 3, num_action)``; it is
        a transposed view of the buffer allocated inside the function.
    """
    out = np.empty((num_code, num_draw, num_action, 3), dtype="float32")
    # Swap the last two axes so the softmax terms sit on the last axis.
    swapped = a.transpose(0, 1, 3, 2)
    for code in numba.prange(num_code):
        for draw in range(num_draw):
            scale = lamda[code, draw]
            for act in range(num_action):
                # Scaled scores are parked in the output buffer first.
                out[code, draw, act, 0:3] = swapped[code, draw, act, 0:3] * scale
                scores = out[code, draw, act, 0:3]
                # Shift by the largest score before exponentiating.
                peak = scores.max()
                shifted = np.exp(scores - peak)
                log_norm = np.log(shifted.sum()) + peak
                out[code, draw, act, 0:3] = np.exp(scores - log_norm)
    return out.transpose(0, 1, 3, 2)

基准测试:

from timeit import timeit

import numba
import numpy as np


def get_p_4d(a, lamda):
    """Compute a numerically stable softmax over axis 2 of ``a * lamda``.

    ``lamda[i, j]`` scales the whole slice ``a[i, j]``; the per-column maximum
    is subtracted before exponentiating so the log-sum-exp cannot overflow.

    Args:
        a: 4-D array of scores.
        lamda: 2-D array indexed by the first two axes of ``a``.

    Returns:
        Array shaped like ``a`` whose entries sum to 1 along axis 2.
    """
    # One multiplier per (i, j), broadcast across the trailing two axes.
    logits = a * lamda[:, :, np.newaxis, np.newaxis]
    offset = np.max(logits, axis=2, keepdims=True)
    partition = np.sum(np.exp(logits - offset), axis=2, keepdims=True)
    # log-sum-exp = offset + log(sum(exp(logits - offset)))
    return np.exp(logits - (offset + np.log(partition)))


@numba.njit
def get_p_4d_nb(a, lamda, num_code, num_draw, num_action):
    """Explicit-loop softmax over axis 2 of ``a`` scaled by ``lamda``.

    Each slice ``a[i, j]`` is multiplied by ``lamda[i, j]`` and a max-shifted
    log-sum-exp softmax is taken over its first axis, which is hard-coded to
    length 3.

    Args:
        a: 4-D array indexed as ``a[i, j, l, k]`` with ``l`` in ``0..2``.
        lamda: 2-D array of per-``(i, j)`` multipliers.
        num_code: Number of ``i`` indices to process.
        num_draw: Number of ``j`` indices to process.
        num_action: Number of ``k`` indices to process.

    Returns:
        Array of shape ``(num_code, num_draw, 3, num_action)``; it is a
        transposed view of the buffer allocated inside the function.
    """
    probs = np.empty((num_code, num_draw, num_action, 3))
    # Move the softmax axis to the end so its three terms are indexed last.
    a_t = a.transpose(0, 1, 3, 2)
    for i in range(num_code):
        for j in range(num_draw):
            weight = lamda[i, j]
            for k in range(num_action):
                # Store the scaled scores in the output buffer.
                for t in range(3):
                    probs[i, j, k, t] = a_t[i, j, k, t] * weight

                # Largest of the three scores, used as the stabilizing shift.
                hi = max(max(probs[i, j, k, 0], probs[i, j, k, 1]), probs[i, j, k, 2])

                denom = (
                    np.exp(probs[i, j, k, 0] - hi)
                    + np.exp(probs[i, j, k, 1] - hi)
                    + np.exp(probs[i, j, k, 2] - hi)
                )
                log_z = np.log(denom) + hi

                # Replace each score by its probability.
                for t in range(3):
                    probs[i, j, k, t] = np.exp(probs[i, j, k, t] - log_z)

    return probs.transpose(0, 1, 3, 2)


@numba.njit(parallel=True)
def get_p_4d_nb_parallel(a, lamda, num_code, num_draw, num_action):
    """Slice-based softmax over axis 2 of ``a * lamda`` with a ``prange`` outer loop.

    The loop over ``i`` is a ``numba.prange``. For every ``(i, j, k)`` the
    three scores are scaled, shifted by their maximum and normalized through
    a ``0:3`` slice. The result buffer is float32, so stored values are
    single precision.

    Args:
        a: 4-D array indexed as ``a[i, j, l, k]`` with ``l`` in ``0..2``.
        lamda: 2-D array of per-``(i, j)`` multipliers.
        num_code: Number of ``i`` indices to process.
        num_draw: Number of ``j`` indices to process.
        num_action: Number of ``k`` indices to process.

    Returns:
        float32 array of shape ``(num_code, num_draw, 3, num_action)``; it is
        a transposed view of the buffer allocated inside the function.
    """
    probs = np.empty((num_code, num_draw, num_action, 3), dtype="float32")
    # Move the softmax axis to the end so its terms form a contiguous slice index.
    a_t = a.transpose(0, 1, 3, 2)
    for i in numba.prange(num_code):
        for j in range(num_draw):
            weight = lamda[i, j]
            for k in range(num_action):
                # Store the scaled scores in the output buffer.
                probs[i, j, k, 0:3] = a_t[i, j, k, 0:3] * weight
                row = probs[i, j, k, 0:3]
                # Stabilizing shift: the largest of the stored scores.
                hi = np.max(row)
                denom = np.exp(row - hi).sum()
                log_z = np.log(denom) + hi
                probs[i, j, k, 0:3] = np.exp(row - log_z)
    return probs.transpose(0, 1, 3, 2)


# Benchmark inputs: all-ones scores and one uniform random multiplier per (i, j).
a = np.ones((112, 1000, 3, 3))
lamda = np.random.uniform(low=0.0, high=1.0, size=(112, 1000))

# Call every variant once before timing, so the timed runs below are not the
# first invocation of the numba-decorated functions.
x = get_p_4d(a, lamda)
y = get_p_4d_nb(a, lamda, a.shape[0], a.shape[1], a.shape[3])
z = get_p_4d_nb_parallel(a, lamda, a.shape[0], a.shape[1], a.shape[3])

# Both numba variants must agree with the NumPy reference.
assert np.allclose(x, y)
assert np.allclose(x, z)


# Time a single call of each variant.
t1 = timeit(stmt="get_p_4d(a, lamda)", globals=globals(), number=1)
t2 = timeit(stmt="get_p_4d_nb(a, lamda, 112, 1000, 3)", globals=globals(), number=1)
t3 = timeit(stmt="get_p_4d_nb_parallel(a, lamda, 112, 1000, 3)", globals=globals(), number=1)

print("\n".join(map(str, (t1, t2, t3))))

在我的电脑上(AMD 5700x)打印的结果:

0.032106522005051374
0.010540901996137109
0.0014921170004527085

撰写回答