仅绘制热图的上三角/下三角

29 投票
6 回答
35209 浏览
提问于 2025-04-15 19:37

在matplotlib这个库里,你可以用imshow这个函数来创建一个热力图,显示一个相关性矩阵。简单来说,相关性矩阵是对称的,也就是说它的上半部分和下半部分是一样的,所以其实只需要展示一半就可以了。比如说:

correlation matrix
(来源: wisc.edu)

上面的例子是从这个网站上拿来的。可惜的是,我不知道怎么在matplotlib里实现这个效果。如果把矩阵的上半部分或下半部分设置为None,就会出现黑色的三角形。我在网上搜索了“matplotlib 缺失值”,但没有找到什么有用的信息。

6 个回答

11
import numpy as NP
from matplotlib import pyplot as PLT
from matplotlib import cm as CM

A = NP.random.randint(10, 100, 100).reshape(10, 10)
# create an upper triangular 'matrix' from A
A2 = NP.triu(A)
fig = PLT.figure()
ax1 = fig.add_subplot(111)
# use dir(matplotlib.cm) to get a list of the installed colormaps
# the "_r" means "reversed" and accounts for why zero values are plotted as white
cmap = CM.get_cmap('gray_r', 10)
ax1.imshow(A2, interpolation="nearest", cmap=cmap)
ax1.grid(True)
PLT.show()

plot

13

我得到的最佳答案是来自seaborn库。这个输出结果看起来平滑而简单。这个函数可以把三角形保存到本地。

def get_lower_tri_heatmap(df, output="cooc_matrix.png"):
    mask = np.zeros_like(df, dtype=np.bool)
    mask[np.triu_indices_from(mask)] = True

    # Want diagonal elements as well
    mask[np.diag_indices_from(mask)] = False

    # Set up the matplotlib figure
    f, ax = plt.subplots(figsize=(11, 9))

    # Generate a custom diverging colormap
    cmap = sns.diverging_palette(220, 10, as_cmap=True)

    # Draw the heatmap with the mask and correct aspect ratio
    sns_plot = sns.heatmap(data, mask=mask, cmap=cmap, vmax=.3, center=0,
            square=True, linewidths=.5, cbar_kws={"shrink": .5})
    # save to file
    fig = sns_plot.get_figure()
    fig.savefig(output)

下三角形

32

doug提供的答案有个问题,就是它假设颜色映射会把零值显示成白色。这就意味着,如果颜色映射里没有白色,那就没什么用。解决这个问题的关键在于cm.set_bad这个函数。你可以用None或者NumPy的掩码数组来遮住不需要的矩阵部分,然后把set_bad设置成白色,而不是默认的黑色。根据doug的例子,我们可以得到以下内容:

import numpy as NP
from matplotlib import pyplot as PLT
from matplotlib import cm as CM

A = NP.random.randint(10, 100, 100).reshape(10, 10)
mask =  NP.tri(A.shape[0], k=-1)
A = NP.ma.array(A, mask=mask) # mask out the lower triangle
fig = PLT.figure()
ax1 = fig.add_subplot(111)
cmap = CM.get_cmap('jet', 10) # jet doesn't have white color
cmap.set_bad('w') # default value is 'k'
ax1.imshow(A, interpolation="nearest", cmap=cmap)
ax1.grid(True)
PLT.show()

撰写回答