如何在 matplotlib 中制作以密度为色的散点图?

2024-06-08 22:48:02 发布

您现在位置:Python中文网/ 问答频道 /正文

我想做一个散点图,每个点都用附近点的空间密度来着色。

我遇到了一个非常类似的问题,它展示了一个使用R的例子:

R Scatter Plot: symbol color represents number of overlapping points

使用matplotlib在python中实现类似功能的最佳方法是什么?


Tags: of功能numberplotmatplotlib空间symbolpoints
3条回答

此外,如果点数使KDE计算太慢,则可以在np.histogram2d中插值颜色[根据注释更新:如果要显示颜色栏,请使用plt.scatter()而不是ax.scatter(),然后使用plt.colorbar()]:

import numpy as np
import matplotlib.pyplot as plt
from scipy.interpolate import interpn

def density_scatter( x , y, ax = None, sort = True, bins = 20, **kwargs )   :
    """
    Scatter plot colored by 2d histogram
    """
    if ax is None :
        fig , ax = plt.subplots()
    data , x_e, y_e = np.histogram2d( x, y, bins = bins)
    z = interpn( ( 0.5*(x_e[1:] + x_e[:-1]) , 0.5*(y_e[1:]+y_e[:-1]) ) , data , np.vstack([x,y]).T , method = "splinef2d", bounds_error = False )

    # Sort the points by density, so that the densest points are plotted last
    if sort :
        idx = z.argsort()
        x, y, z = x[idx], y[idx], z[idx]

    ax.scatter( x, y, c=z, **kwargs )
    return ax


if "__main__" == __name__ :

    x = np.random.normal(size=100000)
    y = x * 3 + np.random.normal(size=100000)
    density_scatter( x, y, bins = [30,30] )

除了@askewchan建议的hist2dhexbin之外,您还可以使用与链接到的问题中接受的答案相同的方法。

如果你想这么做:

import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import gaussian_kde

# Generate fake data
x = np.random.normal(size=1000)
y = x * 3 + np.random.normal(size=1000)

# Calculate the point density
xy = np.vstack([x,y])
z = gaussian_kde(xy)(xy)

fig, ax = plt.subplots()
ax.scatter(x, y, c=z, s=100, edgecolor='')
plt.show()

enter image description here

如果希望按密度顺序绘制点,以便最密集的点始终位于顶部(类似于链接的示例),只需按z值对它们进行排序。我还将使用更小的标记大小,因为它看起来更好:

import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import gaussian_kde

# Generate fake data
x = np.random.normal(size=1000)
y = x * 3 + np.random.normal(size=1000)

# Calculate the point density
xy = np.vstack([x,y])
z = gaussian_kde(xy)(xy)

# Sort the points by density, so that the densest points are plotted last
idx = z.argsort()
x, y, z = x[idx], y[idx], z[idx]

fig, ax = plt.subplots()
ax.scatter(x, y, c=z, s=50, edgecolor='')
plt.show()

enter image description here

你可以做一个直方图:

import numpy as np
import matplotlib.pyplot as plt

# fake data:
a = np.random.normal(size=1000)
b = a*3 + np.random.normal(size=1000)

plt.hist2d(a, b, (50, 50), cmap=plt.cm.jet)
plt.colorbar()

2dhist

相关问题 更多 >