在未记录所有数据点的情况下,如何在matplotlib散点图中添加y=x线

40 投票
4 回答
92527 浏览
提问于 2025-04-18 18:31

这里有一段代码,它使用matplotlib库绘制了多个不同系列的散点图,并且在图上添加了一条y=x的线:

import numpy as np, matplotlib.pyplot as plt, matplotlib.cm as cm, pylab

nseries = 10
colors = cm.rainbow(np.linspace(0, 1, nseries))

all_x = []
all_y = []
for i in range(nseries):
    x = np.random.random(12)+i/10.0
    y = np.random.random(12)+i/5.0
    plt.scatter(x, y, color=colors[i])
    all_x.extend(x)
    all_y.extend(y)

# Could I somehow do the next part (add identity_line) if I haven't been keeping track of all the x and y values I've seen?
identity_line = np.linspace(max(min(all_x), min(all_y)),
                            min(max(all_x), max(all_y)))
plt.plot(identity_line, identity_line, color="black", linestyle="dashed", linewidth=3.0)

plt.show()

为了实现这个效果,我需要记录所有散点图中的x和y值,这样我才能知道identity_line应该从哪里开始,到哪里结束。有没有办法让我在没有所有绘制点的列表的情况下,依然能显示y=x这条线呢?我觉得matplotlib中应该有某种方法可以让我在绘制完成后获取到所有的点,但我还没找到获取这个列表的方法。

4 个回答

9

如果你把scalex和scaley都设置为False,这样可以节省一些管理上的麻烦。这是我最近用来叠加y=x的方式:

xpoints = ypoints = plt.xlim()
plt.plot(xpoints, ypoints, linestyle='--', color='k', lw=3, scalex=False, scaley=False)

或者如果你有一个坐标轴的话:

xpoints = ypoints = ax.get_xlim()
ax.plot(xpoints, ypoints, linestyle='--', color='k', lw=3, scalex=False, scaley=False)

当然,这样做不会让你的图形保持正方形的比例。如果你在意这个问题,可以参考Paul H的解决方案。

15

一句话总结:

ax.plot([0,1],[0,1], transform=ax.transAxes)

不需要调整x轴或y轴的范围。

73

你其实不需要对你的数据有太多了解。你只需要关注matplotlib的Axes对象会告诉你关于数据的信息就可以了。

下面看看:

import numpy as np
import matplotlib.pyplot as plt

# random data 
N = 37
x = np.random.normal(loc=3.5, scale=1.25, size=N)
y = np.random.normal(loc=3.4, scale=1.5, size=N)
c = x**2 + y**2

# now sort it just to make it look like it's related
x.sort()
y.sort()

fig, ax = plt.subplots()
ax.scatter(x, y, s=25, c=c, cmap=plt.cm.coolwarm, zorder=10)

这部分很不错:

lims = [
    np.min([ax.get_xlim(), ax.get_ylim()]),  # min of both axes
    np.max([ax.get_xlim(), ax.get_ylim()]),  # max of both axes
]

# now plot both limits against eachother
ax.plot(lims, lims, 'k-', alpha=0.75, zorder=0)
ax.set_aspect('equal')
ax.set_xlim(lims)
ax.set_ylim(lims)
fig.savefig('/Users/paul/Desktop/so.png', dpi=300)

看!就是这样

在这里输入图片描述

27

从matplotlib 3.3版本开始,绘制直线变得非常简单,只需要用到axline这个方法,它只需要一个点和一个斜率。比如,要画出x=y的直线:

ax.axline((0, 0), slope=1)

你不需要查看你的数据就可以使用这个方法,因为你指定的那个点(比如这里的(0,0))实际上并不需要在你的数据范围内或者绘图范围内。

撰写回答