使用字典创建线图

2024-05-26 21:53:59 发布

您现在位置:Python中文网/ 问答频道 /正文

我有一本字典如下:

my_dict={2:((4, 0.56),(8, 0.75)), 6:((3, 0.05),(5, 0.46)), 7: ((4, 0.99),(1, 0.56))}

我想创建带有错误条的线图: 字典键位于x轴,值位于y轴。此外,每个值的第一个元组的项需要在一行中,每个值的第二个元组的项需要在另一行中

类似于下面的内容

import matplotlib.pyplot as plt
my_dict={2:((4, 0.56),(8, 0.75)), 6:((3, 0.05),(5, 0.46)), 7: ((4, 0.99),(1, 0.56))}

x_vals=[2,6,7]
line1=[4,3,4]
line2=[8,5,1]
errorbar1=[0.56,0.05,0.99]
errorbar2=[0.75,0.46,0.56]
plt.plot(x_vals, line1, linestyle='dotted')
plt.plot(x_vals, line2, linestyle='dotted')
plt.errorbar(x_vals, line1, yerr=errorbar1, fmt=' ')
plt.errorbar(x_vals, line2, yerr=errorbar2, fmt=' ')
plt.xlabel('x axis')
plt.ylabel('yaxis')
plt.show()

enter image description here


Tags: 字典plotmypltdict元组dottedvals
2条回答

我想你可以简单地用列表理解来代替硬编码的列表

x_vals=[2,6,7]
line1=[4,3,4]
line2=[8,5,1]
errorbar1=[0.56,0.05,0.99]
errorbar2=[0.46,0.75,0.56]

x_vals=sorted(my_dict.keys())
line1=[my_dict[k][0][0] for k in x_vals] # [4,3,4]
line2=[my_dict[k][1][0] for k in x_vals] # [8,5,1]
errorbar1=[my_dict[k][0][1] for k in x_vals] # [0.56,0.05,0.99]
errorbar2=[my_dict[k][1][1] for k in x_vals] # [0.46,0.75,0.56]
  • 您需要解压缩keysvaluesmy_dict
  • Frompython >=3.6{}是insertion ordered,因此提取的值不需要排序
  • python 3.8pandas 1.3.1matplotlib 3.4.2seaborn 0.11.1中测试
my_dict = {21: ((0.667, 0.126), (0.63, 0.068)), 52: ((0.679, 0.059), (0.637, 0.078)), 73: ((0.612, 0.211), (0.519, 0.143)), 94: ((0.709, 0.09), (0.711, 0.097))}

x_vals = my_dict.keys()

line1 = [v[0][0] for v in my_dict.values()]
line2 = [v[1][0] for v in my_dict.values()]

errorbar1 = [v[0][1] for v in my_dict.values()]
errorbar2 = [v[1][1] for v in my_dict.values()]

plt.plot(x_vals, line1, linestyle='dotted')
plt.plot(x_vals, line2, linestyle='dotted')
plt.errorbar(x_vals, line1, yerr=errorbar1, fmt=' ')
plt.errorbar(x_vals, line2, yerr=errorbar2, fmt=' ')
plt.xlabel('x axis')
plt.ylabel('yaxis')
plt.show()

enter image description here

使用pandas

  • pandas 1.3.1matplotlib 3.4.2测试
  • 这将代码从13行减少到7行
import pandas as pd

my_dict = {21: ((0.667, 0.126), (0.63, 0.068)), 52: ((0.679, 0.059), (0.637, 0.078)), 73: ((0.612, 0.211), (0.519, 0.143)), 94: ((0.709, 0.09), (0.711, 0.097))}

# load the dictionary into pandas
df = pd.DataFrame.from_dict(my_dict, orient='index', columns=['line1', 'line2'])

# display(df)
             line1           line2
21  (0.667, 0.126)   (0.63, 0.068)
52  (0.679, 0.059)  (0.637, 0.078)
73  (0.612, 0.211)  (0.519, 0.143)
94   (0.709, 0.09)  (0.711, 0.097)

# separate the tuples to columns
for i, col in enumerate(df.columns, 1):
    df[[col, f'errorbar{i}']] = pd.DataFrame(df[col].tolist(), index= df.index)

# display(df)
    line1  line2  errorbar1  errorbar2
21  0.667  0.630      0.126      0.068
52  0.679  0.637      0.059      0.078
73  0.612  0.519      0.211      0.143
94  0.709  0.711      0.090      0.097

# plot
ax = df.plot(y=['line1', 'line2'], linestyle='dotted', ylabel='y-axis', xlabel='x-axis', title='title', figsize=(8, 6))
ax.errorbar(df.index, 'line1', yerr='errorbar1', data=df, fmt=' ')
ax.errorbar(df.index, 'line2', yerr='errorbar2', data=df, fmt=' ')

enter image description here

更新

  • 我意识到前面的所有代码都是为了适应您的其他question
  • 这两个问题中的所有内容都可以忽略,绘图都是为了使数据符合绘图API
  • 如果使用另一个问题中的rl,则可以直接转换为长形式,并用^{}绘制。
    • 正如您在前面的绘图中所看到的,错误条重叠,使得绘图更难读取。此处dodge用于稍微偏移点,因此错误条不会重叠
import seaborn as sns
import pandas
import matplotlib.pyplot as plt

# using rl from the other question convert to a long form
rl = [{21: (0.5714285714285714, 0.6888888888888889), 52: (0.6153846153846154, 0.7111111111111111), 73: (0.7123287671232876, 0.6222222222222222), 94: (0.7127659574468085, 0.6)}, {21: (0.6190476190476191, 0.6444444444444445), 52: (0.6923076923076923, 0.6444444444444445), 73: (0.3698630136986301, 0.35555555555555557), 94: (0.7978723404255319, 0.7777777777777778)}, {21: (0.8095238095238095, 0.5555555555555556), 52: (0.7307692307692307, 0.5555555555555556), 73: (0.7534246575342466, 0.5777777777777777), 94: (0.6170212765957447, 0.7555555555555555)}]
df = pd.DataFrame(rl).melt(var_name='Sample Size')

df[['Train', 'Test']] = pd.DataFrame(df['value'].tolist(), index= df.index)
df.drop('value', axis=1, inplace=True)

df = df.melt(id_vars='Sample Size', var_name='Score')

# display(df)
   Sample Size  Score     value
0           21  Train  0.571429
1           21  Train  0.619048
2           21  Train  0.809524
3           52  Train  0.615385
4           52  Train  0.692308

# plot
fig = plt.figure(figsize=(8, 5))
p = sns.pointplot(data=df, x='Sample Size', y='value', hue='Score', ci='sd', dodge=0.25, linestyles='dotted')
p.set(ylabel='Mean of Trials', title='Score Metrics')

enter image description here

# given df, you can still see the metrics with
dfg = df.groupby(['Sample Size', 'Score']).agg(['mean', 'std'])

# display(dfg)
                      value          
                       mean       std
Sample Size Score                    
21          Test   0.629630  0.067890
            Train  0.666667  0.125988
52          Test   0.637037  0.078042
            Train  0.679487  0.058751
73          Test   0.518519  0.142869
            Train  0.611872  0.210591
94          Test   0.711111  0.096864
            Train  0.709220  0.090478

相关问题 更多 >

    热门问题