在matplotlib中绘制相关图

13 投票
2 回答
40743 浏览
提问于 2025-04-17 06:27

假设我有一个包含离散向量的数据集,n=2,也就是说这个数据集有两个变量:

DATA = [
    ('a', 4),
    ('b', 5),
    ('c', 5),
    ('d', 4),
    ('e', 2),
    ('f', 5),
]

我该如何使用matplotlib这个工具来绘制这个数据集,以便能看到这两个变量之间是否有关系呢?

如果能给出一些简单的代码示例就太好了。

2 个回答

7

我有点困惑……有几种方法可以做到这一点。首先想到的两种是简单的茎叶图或者散点图。

你是想用像这样的茎叶图来绘制数据吗?

import matplotlib.pyplot as plt
data = [
    ('a', 4),
    ('b', 5),
    ('c', 5),
    ('d', 4),
    ('e', 2),
    ('f', 5),
]
labels, y = zip(*data)

x = range(len(y))
plt.stem(x, y)
plt.xticks(x, labels)
plt.axis([-1, 6, 0, 6])
plt.show()

在这里输入图片描述

还是像这样的散点图:

import matplotlib.pyplot as plt
data = [
    ('a', 4),
    ('b', 5),
    ('c', 5),
    ('d', 4),
    ('e', 2),
    ('f', 5),
]
labels, y = zip(*data)

x = range(len(y))
plt.plot(x, y, 'o')
plt.xticks(x, labels)
plt.axis([-1, 6, 0, 6])
plt.show()

在这里输入图片描述

还是完全不同的东西呢?

20

Joe Kington 给出的答案是正确的,但你的 DATA 可能比看起来的要复杂。它可能在 'a' 这个位置有多个值。Joe 构建 x 轴值的方法很快,但只适用于一组独特的值。可能还有更快的方法来实现这个,但这是我完成的方式:

import matplotlib.pyplot as plt

def assignIDs(list):
    '''Take a list of strings, and for each unique value assign a number.
    Returns a map for "unique-val"->id.
    '''
    sortedList = sorted(list)

    #taken from
    #http://stackoverflow.com/questions/480214/how-do-you-remove-duplicates-from-a-list-in-python-whilst-preserving-order/480227#480227
    seen = set()
    seen_add = seen.add
    uniqueList =  [ x for x in sortedList if x not in seen and not seen_add(x)]

    return  dict(zip(uniqueList,range(len(uniqueList))))

def plotData(inData,color):
    x,y = zip(*inData)

    xMap = assignIDs(x)
    xAsInts = [xMap[i] for i in x]


    plt.scatter(xAsInts,y,color=color)
    plt.xticks(xMap.values(),xMap.keys())


DATA = [
    ('a', 4),
    ('b', 5),
    ('c', 5),
    ('d', 4),
    ('e', 2),
    ('f', 5),
]


DATA2 = [
    ('a', 3),
    ('b', 4),
    ('c', 4),
    ('d', 3),
    ('e', 1),
    ('f', 4),
    ('a', 5),
    ('b', 7),
    ('c', 7),
    ('d', 6),
    ('e', 4),
    ('f', 7),
]

plotData(DATA,'blue')
plotData(DATA2,'red')

plt.gcf().savefig("correlation.png")

我的 DATA2 数据集中,每个 x 轴值都有两个值。下面的图用红色表示:

enter image description here

编辑

你问的问题很广泛。我搜索了“相关性”,在 维基百科 上找到了关于皮尔逊积矩相关系数的不错讨论,这个系数用来描述线性拟合的斜率。请记住,这个值只是一个参考,不能预测线性拟合是否合理,具体可以查看上面页面关于 相关性和线性 的说明。这里是一个更新后的 plotData 方法,它使用 numpy.linalg.lstsq 来进行线性回归,并用 numpy.corrcoef 来计算皮尔逊相关系数 R:

import matplotlib.pyplot as plt
import numpy as np

def plotData(inData,color):
    x,y = zip(*inData)

    xMap = assignIDs(x)
    xAsInts = np.array([xMap[i] for i in x])

    pearR = np.corrcoef(xAsInts,y)[1,0]
    # least squares from:
    # http://docs.scipy.org/doc/numpy/reference/generated/numpy.linalg.lstsq.html
    A = np.vstack([xAsInts,np.ones(len(xAsInts))]).T
    m,c = np.linalg.lstsq(A,np.array(y))[0]

    plt.scatter(xAsInts,y,label='Data '+color,color=color)
    plt.plot(xAsInts,xAsInts*m+c,color=color,
             label="Fit %6s, r = %6.2e"%(color,pearR))
    plt.xticks(xMap.values(),xMap.keys())
    plt.legend(loc=3)

新的图形是:

enter image description here

另外,将每个方向的数据展平并查看各自的分布可能会很有用,关于 在 matplotlib 中这样做的例子 有很多:

enter image description here

如果线性近似是有用的,你可以通过观察拟合效果来判断,可能在展平 y 方向之前先减去这个趋势会更好。这有助于显示你有一个围绕线性趋势的高斯随机分布。

撰写回答