Pandas的线性组合函数

2024-05-15 08:49:58 发布

您现在位置:Python中文网/ 问答频道 /正文

给定一个像这样的数据帧:

from datetime import datetime
test = pd.DataFrame([
    {'id': 1, 'date': datetime.fromisoformat('2016-01-01'), 'a': 1}, 
    {'id': 1, 'date': datetime.fromisoformat('2016-01-02'), 'a': 2}, 
    {'id': 1, 'date': datetime.fromisoformat('2016-01-03'), 'a': 3}]
)

我使用的是一个线性组合Python函数:

def lin_comb(v1, v2, beta=0.9): 
    return beta*v1 + (1-beta)*v2

要基于具有以下值的列a生成列lin_comb,请执行以下操作:

    id  date        a   lin_comb
0   1   2016-01-01  1   1.000000
1   1   2016-01-02  2   1.099609
2   1   2016-01-03  3   1.290039

例如,通过以下表达式计算上面最后一行的值:

(1 * 0.9 + 2 * 0.1) * 0.9 + 3 * 0.1 = 1.29

以下是完整的可执行代码:

def lin_comb(v1, v2, beta=0.9): return beta*v1 + (1-beta)*v2

from datetime import datetime
test = pd.DataFrame([
    {'id': 1, 'date': datetime.fromisoformat('2016-01-01'), 'a': 1}, 
    {'id': 1, 'date': datetime.fromisoformat('2016-01-02'), 'a': 2}, 
    {'id': 1, 'date': datetime.fromisoformat('2016-01-03'), 'a': 3}]
)

lin_com_list = []
c = 0.
for a in test['a']:
    c = lin_comb(c or a, a, 0.9)
    lin_com_list.append(c)

test['lin_comb'] = lin_com_list

我的问题:Pandas中是否有一个内置函数可以生成与上述相同的输出

我问的原因主要是性能。在数百万条记录上执行此函数时,此代码相当慢


Tags: 函数fromtestimportcomiddatetimedate
2条回答

我认为pandas中没有用于这种递归操作的内置函数。但我认为这是一个很好的例子。我是新手,所以可能有更好的方法,但想法是:

from numba import jit

@jit
def numba_comb(arr_in, beta=0.9): 
    arr_out = np.zeros_like(arr_in)
    c = 0.
    for i in range(arr_in.shape[0]):
        a = arr_in[i]
        c = beta*(c or a) + (1-beta)*a
        arr_out[i] = c
    return arr_out

比较

def lin_comb(v1, v2, beta=0.9): return beta*v1 + (1-beta)*v2

def list_comb (ser, beta=0.9):
    lin_com_list = []
    c = 0.
    for a in ser:
        c = lin_comb(c or a, a, beta)
        lin_com_list.append(c)
    return lin_com_list

然后给出:

test = pd.DataFrame({'a':range(1, 10000)})

# list solution
%timeit list_comb (test['a'], 0.9)
#3.51 ms ± 148 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

# numba
%timeit numba_comb(test['a'].to_numpy().astype(float), 0.9)
#63.8 µs ± 990 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)

#same result
print ((np.array(list_comb (test['a'])) 
        == numba_comb(test['a'].to_numpy().astype(float), 0.9)).all())
#True

事实上,熊猫支持指数加权平均值,这或多或少就是我想要的

对于我的特定问题,我最终使用了带有平均值的Pandasewm函数。本质上,这是一行代码,帮助我计算特定天数内的指数移动平均数:

dt[esmean_col] = grouped_sales.transform(lambda x : x.ewm(alpha=1/win, adjust=False).mean())

win变量是窗口中的天数,在我的例子中是7

这个实现的性能非常好,因为我可以在30秒内处理4400万条记录

有关Pandasewm函数here的更多信息

相关问题 更多 >