How to pivot a DataFrame in Python

Asked 2025-04-13 00:28

I have a dataset that looks like this:

StID          SubClassID   PhOrder     PhAmt        LbOrder      LbAmt
4326          200572288    Anti1       23.3         Asprin       13.7
4326          200572288    Anti2       39.3         Morphin      2.2
4326          200572288    NULL        NULL         Medine       30.5
4326          200572288    Anti3       13.5         Kabomin      20.3
4326          200572288    NULL        NULL         Zorifin      0.2
0993          200348299    Anti1       8.4          Zorifin      9.7
0993          200348299    Anti2       10.9         Zorifin      95.6
0993          200348299    Anti4       48.9         NULL         NULL

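I want to simplify this table so that there is one row per StID/SubClassID and one column per order name. I've seen this described as "one-hot encoding", although here the new columns hold the amounts rather than 0/1 flags. The result should look like this: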

StID          SubClassID   Anti1 Anti2 Anti3 Anti4 Asprin  Morphin Medine Kabomin Zorifin
4326          200572288    23.3  39.3  13.5  NULL  13.7    2.2     30.5   20.3    0.2
0993          200348299    8.4   10.9  NULL  48.9  NULL    NULL    NULL   NULL    52.65

Here is the code I wrote, but it generates a lot of duplicate columns, and the values in them are wrong:

import pandas as pd

data = {
    'StId': [4326, 4326, 4326, 4326, 4326, 993, 993, 993],
    'SubClassID': [200572288, 200572288, 200572288, 200572288, 200572288, 200348299, 200348299, 200348299],
    'PhOrder': ['Anti1', 'Anti2', 'NULL', 'Anti3', 'NULL', 'Anti1', 'Anti2', 'Anti4'],
    'PhAmt': [23.3, 39.3, None, 13.5, None, 8.4, 10.9, 48.9],
    'LbOrder': ['Asprin', 'Morphin', 'Medine', 'Kabomin', 'Zorifin', 'Zorifin', 'Zorifin', None],
    'LbAmt': [13.7, 2.2, 30.5, 20.3, 0.2, 9.7, 95.6, None]
}

df = pd.DataFrame(data)
# Pivot on both name columns at once, keeping the first non-missing amount in each cell
df_pivot = df.pivot_table(index=['StId', 'SubClassID'], columns=['PhOrder', 'LbOrder'],
                          values=['PhAmt', 'LbAmt'], aggfunc='first')
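To see where the duplicated columns in the attempt above come from, it can help to inspect df_pivot.columns directly. The sketch below rebuilds the question's data so it runs on its own and, assuming a recent pandas version, prints the column labels. Each label is a (value column, PhOrder, LbOrder) triple, so the same name can appear in several columns. The 'NULL' strings are treated as ordinary labels, and the Anti4 row disappears entirely because its LbOrder is missing.

```python
import pandas as pd

# The question's data, rebuilt here so the sketch is self-contained
data = {
    'StId': [4326, 4326, 4326, 4326, 4326, 993, 993, 993],
    'SubClassID': [200572288] * 5 + [200348299] * 3,
    'PhOrder': ['Anti1', 'Anti2', 'NULL', 'Anti3', 'NULL', 'Anti1', 'Anti2', 'Anti4'],
    'PhAmt': [23.3, 39.3, None, 13.5, None, 8.4, 10.9, 48.9],
    'LbOrder': ['Asprin', 'Morphin', 'Medine', 'Kabomin', 'Zorifin', 'Zorifin', 'Zorifin', None],
    'LbAmt': [13.7, 2.2, 30.5, 20.3, 0.2, 9.7, 95.6, None],
}
df = pd.DataFrame(data)

# The original attempt: both name columns are used as column keys at once
df_pivot = df.pivot_table(index=['StId', 'SubClassID'],
                          columns=['PhOrder', 'LbOrder'],
                          values=['PhAmt', 'LbAmt'], aggfunc='first')

# Three levels: value column, PhOrder, LbOrder.
# A name such as 'Anti1' can therefore appear in several columns, and the
# 'NULL' strings are ordinary labels to pandas, so they become column keys too.
# The Anti4 row is gone because its LbOrder key is missing, and pivot_table's
# grouping skips missing keys by default.
print(df_pivot.columns.nlevels)
for label in df_pivot.columns:
    print(label)
```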

2 Answers


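One general way to think about this operation is to first stack the data into long format with lreshape, so that PhOrder/LbOrder end up in a single name column and PhAmt/LbAmt in a single value column, and then pivot that with pivot_table: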

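out = (pd.lreshape(df, {'col': ['PhOrder', 'LbOrder'],    # names from both pairs -> one column
                        'value': ['PhAmt', 'LbAmt']})      # amounts from both pairs -> one column
         .pivot_table(index=['StId', 'SubClassID'], columns='col',
                      values='value', aggfunc='sum')       # one column per name; 'sum' adds repeats
         .reset_index().rename_axis(columns=None)          # turn the index back into columns
      )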

Output:

   StId  SubClassID  Anti1  Anti2  Anti3  Anti4  Asprin  Kabomin  Medine  Morphin  Zorifin
0   993   200348299    8.4   10.9    NaN   48.9     NaN      NaN     NaN      NaN    105.3
1  4326   200572288   23.3   39.3   13.5    NaN    13.7     20.3    30.5      2.2      0.2
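One detail worth noting: with aggfunc='sum' the two Zorifin amounts for StId 993 are added (9.7 + 95.6 = 105.3), while the question's expected output shows 52.65, which is their mean. The sketch below assumes the average is what is wanted. It runs the same two steps with aggfunc='mean' and also prints the intermediate long table produced by lreshape, so the reshaping step is visible.

```python
import pandas as pd

# The question's data, rebuilt here so the sketch is self-contained
data = {
    'StId': [4326, 4326, 4326, 4326, 4326, 993, 993, 993],
    'SubClassID': [200572288] * 5 + [200348299] * 3,
    'PhOrder': ['Anti1', 'Anti2', 'NULL', 'Anti3', 'NULL', 'Anti1', 'Anti2', 'Anti4'],
    'PhAmt': [23.3, 39.3, None, 13.5, None, 8.4, 10.9, 48.9],
    'LbOrder': ['Asprin', 'Morphin', 'Medine', 'Kabomin', 'Zorifin', 'Zorifin', 'Zorifin', None],
    'LbAmt': [13.7, 2.2, 30.5, 20.3, 0.2, 9.7, 95.6, None],
}
df = pd.DataFrame(data)

# Step 1: stack the two (name, amount) pairs into one 'col' and one 'value' column.
# With the default dropna=True, lreshape drops stacked rows that have a missing
# entry, which removes the 'NULL' PhOrder rows (their PhAmt is NaN) and the
# empty LbOrder/LbAmt row.
long_df = pd.lreshape(df, {'col': ['PhOrder', 'LbOrder'],
                           'value': ['PhAmt', 'LbAmt']})
print(long_df)

# Step 2: pivot the long table; 'mean' averages repeated names
# (Zorifin for StId 993 becomes (9.7 + 95.6) / 2)
out = (long_df.pivot_table(index=['StId', 'SubClassID'], columns='col',
                           values='value', aggfunc='mean')
              .reset_index()
              .rename_axis(columns=None))
print(out)
```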

Try this: pivot each name/amount pair separately, then concatenate the two results side by side:

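# PhAmt spread into one column per PhOrder name (pivot_table's default aggfunc is 'mean')
df1 = df.pivot_table(index=["StId", "SubClassID"], columns="PhOrder", values="PhAmt")
# LbAmt spread into one column per LbOrder name; the two Zorifin rows of 993 are averaged
df2 = df.pivot_table(index=["StId", "SubClassID"], columns="LbOrder", values="LbAmt")
# Place the two wide tables side by side, then turn the index back into columns
print(pd.concat([df1, df2], axis=1).reset_index())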

The output is:

   StId  SubClassID  Anti1  Anti2  Anti3  Anti4  Asprin  Kabomin  Medine  Morphin  Zorifin
0   993   200348299    8.4   10.9    NaN   48.9     NaN      NaN     NaN      NaN    52.65
1  4326   200572288   23.3   39.3   13.5    NaN    13.7     20.3    30.5      2.2     0.20
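The same idea generalizes to any number of name/amount column pairs. The helper below, pivot_pairs, is a hypothetical name introduced here for illustration (it is not a pandas function). It is a sketch that pivots each pair separately and concatenates the results, defaulting to 'mean' like pivot_table itself.

```python
import pandas as pd

# The question's data, rebuilt here so the sketch is self-contained
data = {
    'StId': [4326, 4326, 4326, 4326, 4326, 993, 993, 993],
    'SubClassID': [200572288] * 5 + [200348299] * 3,
    'PhOrder': ['Anti1', 'Anti2', 'NULL', 'Anti3', 'NULL', 'Anti1', 'Anti2', 'Anti4'],
    'PhAmt': [23.3, 39.3, None, 13.5, None, 8.4, 10.9, 48.9],
    'LbOrder': ['Asprin', 'Morphin', 'Medine', 'Kabomin', 'Zorifin', 'Zorifin', 'Zorifin', None],
    'LbAmt': [13.7, 2.2, 30.5, 20.3, 0.2, 9.7, 95.6, None],
}
df = pd.DataFrame(data)


def pivot_pairs(frame, index, pairs, aggfunc='mean'):
    """Hypothetical helper: pivot each (name column, amount column) pair on its
    own, then place the resulting wide tables side by side."""
    parts = [
        frame.pivot_table(index=index, columns=name_col,
                          values=amount_col, aggfunc=aggfunc)
        for name_col, amount_col in pairs
    ]
    # pd.concat aligns the parts on their shared index; keys missing from a
    # part become NaN in that part's columns
    return pd.concat(parts, axis=1).reset_index()


wide = pivot_pairs(df, ['StId', 'SubClassID'],
                   [('PhOrder', 'PhAmt'), ('LbOrder', 'LbAmt')])
print(wide)
```

One design difference from the lreshape answer: if the same name appeared in both PhOrder and LbOrder, this helper would produce two columns with that label, whereas stacking with lreshape first would merge them into a single column.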
