Pandas scattermatrix中的类标签

17 投票
2 回答
22845 浏览
提问于 2025-04-18 01:54

这个问题之前有人问过,关于散点矩阵中多个数据,但没有得到答案。

我想制作一个散点矩阵,类似于pandas文档中的示例,但是我希望不同类别的数据点用不同的颜色来表示。比如说,我希望某些点是绿色的,而另一些是蓝色的,这取决于某一列的值(或者一个单独的列表)。

这里有一个使用鸢尾花数据集的例子。点的颜色代表鸢尾花的种类——Setosa、Versicolor或Virginica。

带有类别标签的鸢尾花散点矩阵

pandas(或者matplotlib)有没有办法制作这样的图表呢?

2 个回答

19

你也可以这样使用pandas里的scattermatrix:

pd.scatter_matrix(df,color=colors)

这里的colors是一个列表,长度和df一样,里面包含了颜色信息。

29

更新:这个功能现在已经在Seaborn的最新版本中了。这里有个例子

以下是我临时使用的解决办法:

def factor_scatter_matrix(df, factor, palette=None):
    '''Create a scatter matrix of the variables in df, with differently colored
    points depending on the value of df[factor].
    inputs:
        df: pandas.DataFrame containing the columns to be plotted, as well 
            as factor.
        factor: string or pandas.Series. The column indicating which group 
            each row belongs to.
        palette: A list of hex codes, at least as long as the number of groups.
            If omitted, a predefined palette will be used, but it only includes
            9 groups.
    '''
    import matplotlib.colors
    import numpy as np
    from pandas.tools.plotting import scatter_matrix
    from scipy.stats import gaussian_kde

    if isinstance(factor, basestring):
        factor_name = factor #save off the name
        factor = df[factor] #extract column
        df = df.drop(factor_name,axis=1) # remove from df, so it 
        # doesn't get a row and col in the plot.

    classes = list(set(factor))

    if palette is None:
        palette = ['#e41a1c', '#377eb8', '#4eae4b', 
                   '#994fa1', '#ff8101', '#fdfc33', 
                   '#a8572c', '#f482be', '#999999']

    color_map = dict(zip(classes,palette))

    if len(classes) > len(palette):
        raise ValueError('''Too many groups for the number of colors provided.
We only have {} colors in the palette, but you have {}
groups.'''.format(len(palette), len(classes)))

    colors = factor.apply(lambda group: color_map[group])
    axarr = scatter_matrix(df,figsize=(10,10),marker='o',c=colors,diagonal=None)


    for rc in xrange(len(df.columns)):
        for group in classes:
            y = df[factor == group].icol(rc).values
            gkde = gaussian_kde(y)
            ind = np.linspace(y.min(), y.max(), 1000)
            axarr[rc][rc].plot(ind, gkde.evaluate(ind),c=color_map[group])

    return axarr, color_map

作为例子,我们将使用和问题中相同的数据集,可以在这里找到。

>>> import pandas as pd
>>> iris = pd.read_csv('iris.csv')
>>> axarr, color_map = factor_scatter_matrix(iris,'Name')
>>> color_map
{'Iris-setosa': '#377eb8',
 'Iris-versicolor': '#4eae4b',
 'Iris-virginica': '#e41a1c'}

iris_scatter_matrix

希望这对你有帮助!

撰写回答