pandas: conditionally updating columns with np.where()

Posted 2024-04-20 08:37:33

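I'm trying to build a trading backtest in pandas and I'm having trouble using np.where() as an 'if' statement to conditionally update other columns.

My initial df is below. The signal column indicates whether to buy, sell, or do nothing (1/-1/0), and based on these signals I want to update the Cash, Hold, Value, and Total columns.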

                        open         high        low        close   change  signal  Cash  Hold Value Total 
time                                        
2017-09-09 03:01:00 4255.000000 4256.799805 4233.600098 4252.799805 -0.000065   0   10000.0 0.0 0.0 10000.0
2017-09-09 03:02:00 4251.399902 4258.500000 4247.500000 4258.399902 0.002046    1   10000.0 0.0 0.0 10000.0
2017-09-09 03:03:00 4256.500000 4289.299805 4256.500000 4273.700195 0.001262    1   10000.0 0.0 0.0 10000.0
2017-09-09 03:04:00 4273.100098 4299.899902 4262.580566 4284.100098 0.001905    1   10000.0 0.0 0.0 10000.0
2017-09-09 03:05:00 4291.200195 4299.799805 4284.200195 4289.899902 -0.000854   -1  10000.0 0.0 0.0 10000.0
2017-09-09 03:06:00 4295.000000 4298.799805 4279.500000 4279.500000 -0.000047   0   10000.0 0.0 0.0 10000.0
2017-09-09 03:07:00 4278.000000 4278.299805 4277.000000 4277.799805 -0.000244   0   10000.0 0.0 0.0 10000.0

I can call each of the following functions manually according to the signal:

^{pr2}$

This produces:

                        open         high        low        close   change  signal  Cash        Hold       Value      Total 
time                                                        
2017-09-09 03:01:00 4255.000000 4256.799805 4233.600098 4252.799805 -0.000065   0   10000.00000 0.000000    0.000000    10000.000000
2017-09-09 03:02:00 4251.399902 4258.500000 4247.500000 4258.399902 0.002046    1   9900.00000  0.023483    100.000000  10000.000000
2017-09-09 03:03:00 4256.500000 4289.299805 4256.500000 4273.700195 0.001262    1   9800.00000  0.046882    200.359297  10000.359297
2017-09-09 03:04:00 4273.100098 4299.899902 4262.580566 4284.100098 0.001905    1   9700.00000  0.070224    300.846864  10000.846864
2017-09-09 03:05:00 4291.200195 4299.799805 4284.200195 4289.899902 -0.000854   -1  10001.25415 0.000000    0.000000    10001.254150
2017-09-09 03:06:00 4295.000000 4298.799805 4279.500000 4279.500000 -0.000047   0   10001.25415 0.000000    0.000000    10001.254150
2017-09-09 03:07:00 4278.000000 4278.299805 4277.000000 4277.799805 -0.000244   0   10001.25415 0.000000    0.000000    10001.254150

I thought np.where() would call the correct function based on the signal column, but I'm not having any luck. The loop below overwrites every row.

for i in range(len(pf)):
    np.where(pf['signal'].iloc[i] == -1, sell_update(i), np.where(pf['signal'].iloc[i] == 1, buy_update(i), no_action(i)))
    print(i)

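I think it currently calls every function: sell, then buy, then none (each one overwriting the last). It also raises a SettingWithCopyWarning.

A for loop over every row is also obviously very slow. Is there a way to vectorize this?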
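The suspicion above matches how Python evaluates function arguments: every argument expression passed to np.where() is evaluated before np.where() itself runs, so all three update functions execute on every row regardless of the signal. The sketch below illustrates this with stand-in functions. Their bodies are hypothetical (they only record that they were called) and are not the original update functions.

```python
import numpy as np

calls = []  # records which stand-in function actually ran

# Hypothetical stand-ins: they only log the call and return a marker value.
def sell_update(i):
    calls.append("sell")
    return -1

def buy_update(i):
    calls.append("buy")
    return 1

def no_action(i):
    calls.append("none")
    return 0

signal = 1

# All three functions run, because arguments are evaluated eagerly.
picked = np.where(signal == -1, sell_update(0),
                  np.where(signal == 1, buy_update(0), no_action(0)))
eager_calls = list(calls)   # ['sell', 'buy', 'none']

# A plain if/elif chain runs only the matching branch.
def dispatch(signal, i):
    if signal == -1:
        return sell_update(i)
    if signal == 1:
        return buy_update(i)
    return no_action(i)

calls.clear()
dispatch(signal, 0)
lazy_calls = list(calls)    # ['buy']
```

np.where() is meant for choosing between values that are already computed, so it is not a substitute for control flow when the branches have side effects.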

1 answer
User
#1 · Posted 2024-04-20 08:37:33

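When the calculation logic gets complicated, it is hard to vectorize. Because element-by-element processing in pandas is slow, you can convert the DataFrame to a list of dicts and do the calculation on that. Here is an example using cytoolz: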

import io
import pandas as pd

text="""time                        open         high        low        close   change  signal  Cash  Hold Value Total 
2017-09-09 03:01:00 4255.000000 4256.799805 4233.600098 4252.799805 -0.000065   0   10000.0 0.0 0.0 10000.0
2017-09-09 03:02:00 4251.399902 4258.500000 4247.500000 4258.399902 0.002046    1   10000.0 0.0 0.0 10000.0
2017-09-09 03:03:00 4256.500000 4289.299805 4256.500000 4273.700195 0.001262    1   10000.0 0.0 0.0 10000.0
2017-09-09 03:04:00 4273.100098 4299.899902 4262.580566 4284.100098 0.001905    1   10000.0 0.0 0.0 10000.0
2017-09-09 03:05:00 4291.200195 4299.799805 4284.200195 4289.899902 -0.000854   -1  10000.0 0.0 0.0 10000.0
2017-09-09 03:06:00 4295.000000 4298.799805 4279.500000 4279.500000 -0.000047   0   10000.0 0.0 0.0 10000.0
2017-09-09 03:07:00 4278.000000 4278.299805 4277.000000 4277.799805 -0.000244   0   10000.0 0.0 0.0 10000.0"""
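# sep=r"\s+" splits on whitespace (delim_whitespace is deprecated in recent pandas).
# Data rows have one more field than the header, so the date becomes the index.
df = pd.read_csv(io.StringIO(text), sep=r"\s+")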
trade_size = 100

import cytoolz

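def f(p, c):
    # p is the previous, already-updated row; c is the current raw row.
    signal = c["signal"]
    if signal == 1:        # buy a fixed cash amount at the close
        cash = p["Cash"] - trade_size
        hold = p["Hold"] + trade_size / c["close"]
    elif signal == -1:     # sell the whole position at the close
        cash = p["Cash"] + p["Hold"] * c["close"]
        hold = 0.0
    else:                  # no action: carry the previous state forward
        cash = p["Cash"]
        hold = p["Hold"]
    value = hold * c["close"]
    return cytoolz.merge(c, {"Cash": cash, "Hold": hold,
                             "Value": value, "Total": cash + value})

# accumulate passes the first row through unchanged, then folds f over the rest.
result = pd.DataFrame(list(cytoolz.accumulate(f, df.to_dict("records"))))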
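As an alternative to the row-by-row accumulate above, this particular rule set looks simple enough to vectorize with cumulative sums. The sketch below assumes that every buy spends a fixed trade_size, that every sell liquidates the whole position, and that there is no check for insufficient cash. The function name vectorized_backtest and its parameters are made up for illustration. If the rules become path-dependent (for example, skipping a buy when cash runs out), this approach would no longer apply directly.

```python
import pandas as pd

def vectorized_backtest(close, signal, start_cash=10000.0, trade_size=100.0):
    """Hypothetical helper: cumulative-sum version of the buy/sell rules."""
    close = pd.Series(close, dtype=float).reset_index(drop=True)
    signal = pd.Series(signal).reset_index(drop=True)
    buy = signal == 1
    sell = signal == -1

    # Units bought on each row (zero on non-buy rows).
    units = (trade_size / close).where(buy, 0.0)
    # Each sell starts a new block, so holdings reset to zero on the sell row.
    block = sell.cumsum()
    hold = units.groupby(block).cumsum()

    # A sell converts the previous row's holdings to cash at the current close.
    proceeds = (hold.shift(1, fill_value=0.0) * close).where(sell, 0.0)
    cash = start_cash - trade_size * buy.cumsum() + proceeds.cumsum()

    value = hold * close
    return pd.DataFrame({"Cash": cash, "Hold": hold,
                         "Value": value, "Total": cash + value})

closes = [4252.799805, 4258.399902, 4273.700195, 4284.100098,
          4289.899902, 4279.500000, 4277.799805]
signals = [0, 1, 1, 1, -1, 0, 0]
out = vectorized_backtest(closes, signals)
```

On the sample data this should agree with the expected table in the question, up to floating-point rounding.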