seaborn clustermap 中的下三角遮罩 bug

1 投票
1 回答
31 浏览
提问于 2025-04-12 12:05

要重现这个错误,你可以使用这个数据框(是用 .to_clipboard 方法从 pandas.DataFrame 对象创建的,这样你可以很方便地导入它)。

    01:01   01:02   01:03   01:04   01:05   01:06   01:07   01:10   01:12   01:21   02:01   03:01   03:02   03:03   04:01   04:02   04:04   05:01   05:02   05:03   05:05   05:08   05:09   05:11   05:13   06:01
01:01   0   0   0   1   1   0   1   0   1   0   4   4   4   4   5   5   5   6   6   6   6   6   6   6   6   5
01:02   0   0   0   1   1   0   1   0   1   0   4   4   4   4   5   5   5   6   6   6   6   6   6   6   6   5
01:03   0   0   0   1   1   0   1   0   1   0   4   4   4   4   5   5   5   6   6   6   6   6   6   6   6   5
01:04   1   1   1   0   0   1   0   1   0   1   5   5   5   5   6   6   6   7   7   7   7   7   7   7   7   6
01:05   1   1   1   0   0   1   0   1   0   1   5   5   5   5   6   6   6   7   7   7   7   7   7   7   7   6
01:06   0   0   0   1   1   0   1   0   1   0   4   4   4   4   5   5   5   6   6   6   6   6   6   6   6   5
01:07   1   1   1   0   0   1   0   1   0   1   5   5   5   5   6   6   6   7   7   7   7   7   7   7   7   6
01:10   0   0   0   1   1   0   1   0   1   0   4   4   4   4   5   5   5   6   6   6   6   6   6   6   6   5
01:12   1   1   1   0   0   1   0   1   0   1   5   5   5   5   6   6   6   7   7   7   7   7   7   7   7   6
01:21   0   0   0   1   1   0   1   0   1   0   4   4   4   4   5   5   5   6   6   6   6   6   6   6   6   5
02:01   4   4   4   5   5   4   5   4   5   4   0   2   2   2   3   3   3   4   4   4   4   4   4   4   4   3
03:01   4   4   4   5   5   4   5   4   5   4   2   0   0   0   3   3   3   4   4   4   4   4   4   4   4   3
03:02   4   4   4   5   5   4   5   4   5   4   2   0   0   0   3   3   3   4   4   4   4   4   4   4   4   3
03:03   4   4   4   5   5   4   5   4   5   4   2   0   0   0   3   3   3   4   4   4   4   4   4   4   4   3
04:01   5   5   5   6   6   5   6   5   6   5   3   3   3   3   0   0   0   1   1   1   1   1   1   1   1   0
04:02   5   5   5   6   6   5   6   5   6   5   3   3   3   3   0   0   0   1   1   1   1   1   1   1   1   0
04:04   5   5   5   6   6   5   6   5   6   5   3   3   3   3   0   0   0   1   1   1   1   1   1   1   1   0
05:01   6   6   6   7   7   6   7   6   7   6   4   4   4   4   1   1   1   0   0   0   0   0   0   0   0   1
05:02   6   6   6   7   7   6   7   6   7   6   4   4   4   4   1   1   1   0   0   0   0   0   0   0   0   1
05:03   6   6   6   7   7   6   7   6   7   6   4   4   4   4   1   1   1   0   0   0   0   0   0   0   0   1
05:05   6   6   6   7   7   6   7   6   7   6   4   4   4   4   1   1   1   0   0   0   0   0   0   0   0   1
05:08   6   6   6   7   7   6   7   6   7   6   4   4   4   4   1   1   1   0   0   0   0   0   0   0   0   1
05:09   6   6   6   7   7   6   7   6   7   6   4   4   4   4   1   1   1   0   0   0   0   0   0   0   0   1
05:11   6   6   6   7   7   6   7   6   7   6   4   4   4   4   1   1   1   0   0   0   0   0   0   0   0   1
05:13   6   6   6   7   7   6   7   6   7   6   4   4   4   4   1   1   1   0   0   0   0   0   0   0   0   1
06:01   5   5   5   6   6   5   6   5   6   5   3   3   3   3   0   0   0   1   1   1   1   1   1   1   1   0

这是一个26x26的矩阵。我参考了这个StackOverflow帖子中的代码(使用seaborn的聚类图制作下三角形遮罩)来写的:

import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

matrix = matrix.astype(int)

# Generate a clustermap
cg = sns.clustermap(matrix, annot=True, cmap = "Blues", cbar_pos=(.09, .6, .05, .2))

# Mask the lower triangle
mask = np.tril(np.ones_like(matrix))
values = cg.ax_heatmap.collections[0].get_array().reshape(matrix.shape)
new_values = np.ma.array(values, mask=mask)

cg.ax_heatmap.collections[0].set_array(new_values)
cg.ax_row_dendrogram.set_visible(False)
cg.ax_col_dendrogram.set_visible(False)
cg.savefig("dqa_eplet_distances_abv.png", dpi=600)

但是我得到的结果是:

这里输入图片描述

为什么这不是我预期的矩阵的三角形版本呢?原帖子中的最小可重现示例和我的代码几乎一样,我搞不懂哪里出错了。

我使用的是 seaborn-0.13.2Python 3.11.1

1 个回答

1

这个代码可以很好地隐藏数值,但对注释不起作用

一种解决方法是遍历文本,隐藏那些在对角线以下的部分。

使用Python可以这样做:

n = matrix.shape[0]

for idx, t in enumerate(cg.ax_heatmap.texts):
    if idx//n >= idx%n:
        t.set_visible(False)

或者用numpy也可以:

for idx in np.ravel_multi_index(np.tril_indices_from(matrix), matrix.shape):
    cg.ax_heatmap.texts[idx].set_visible(False)

在这里输入图片描述

撰写回答