matplotlib中有生成散点图矩阵的函数吗?
散点图矩阵的例子
在matplotlib.pyplot中有没有这样的功能?
5 个回答
17
你也可以使用 Seaborn库中的 pairplot
函数:
import seaborn as sns
sns.set()
df = sns.load_dataset("iris")
sns.pairplot(df, hue="species")
124
对于那些不想自己定义函数的人,Python里有一个很棒的数据分析库,叫做 Pandas。在这个库里,你可以找到 scatter_matrix() 这个方法:
from pandas.plotting import scatter_matrix
df = pd.DataFrame(np.random.randn(1000, 4), columns = ['a', 'b', 'c', 'd'])
scatter_matrix(df, alpha = 0.2, figsize = (6, 6), diagonal = 'kde')
32
一般来说,matplotlib这个库通常不提供可以同时处理多个坐标轴(比如子图)的绘图函数。它的设计是希望你能自己写一个简单的函数,把你想要的内容组合在一起。
我不太清楚你的数据是什么样子的,但其实从头开始写一个函数来实现这个功能是很简单的。如果你总是处理结构化的数据或者记录数组,那你可以稍微简化一下这个过程。也就是说,每组数据都有一个名字,所以你可以省去指定名字的步骤。
举个例子:
import itertools
import numpy as np
import matplotlib.pyplot as plt
def main():
np.random.seed(1977)
numvars, numdata = 4, 10
data = 10 * np.random.random((numvars, numdata))
fig = scatterplot_matrix(data, ['mpg', 'disp', 'drat', 'wt'],
linestyle='none', marker='o', color='black', mfc='none')
fig.suptitle('Simple Scatterplot Matrix')
plt.show()
def scatterplot_matrix(data, names, **kwargs):
"""Plots a scatterplot matrix of subplots. Each row of "data" is plotted
against other rows, resulting in a nrows by nrows grid of subplots with the
diagonal subplots labeled with "names". Additional keyword arguments are
passed on to matplotlib's "plot" command. Returns the matplotlib figure
object containg the subplot grid."""
numvars, numdata = data.shape
fig, axes = plt.subplots(nrows=numvars, ncols=numvars, figsize=(8,8))
fig.subplots_adjust(hspace=0.05, wspace=0.05)
for ax in axes.flat:
# Hide all ticks and labels
ax.xaxis.set_visible(False)
ax.yaxis.set_visible(False)
# Set up ticks only on one side for the "edge" subplots...
if ax.is_first_col():
ax.yaxis.set_ticks_position('left')
if ax.is_last_col():
ax.yaxis.set_ticks_position('right')
if ax.is_first_row():
ax.xaxis.set_ticks_position('top')
if ax.is_last_row():
ax.xaxis.set_ticks_position('bottom')
# Plot the data.
for i, j in zip(*np.triu_indices_from(axes, k=1)):
for x, y in [(i,j), (j,i)]:
axes[x,y].plot(data[x], data[y], **kwargs)
# Label the diagonal subplots...
for i, label in enumerate(names):
axes[i,i].annotate(label, (0.5, 0.5), xycoords='axes fraction',
ha='center', va='center')
# Turn on the proper x or y axes ticks.
for i, j in zip(range(numvars), itertools.cycle((-1, 0))):
axes[j,i].xaxis.set_visible(True)
axes[i,j].yaxis.set_visible(True)
return fig
main()