带matplotlib的对角线热图

2024-04-25 14:53:19 发布

您现在位置:Python中文网/ 问答频道 /正文

我有一张用熊猫制作的热图:

tukey = tukey.set_index('index')
 
fix,ax = plt.subplots(figsize=(12,6))
ax.set_title(str(date)+' '+ str(hour)+':'+'00',fontsize=14)
heatmap_args = {'linewidths': 0.35, 'linecolor': '0.5', 'clip_on': False, 'square': True, 'cbar_ax_bbox': [0.75, 0.35, 0.04, 0.3]}
sp.sign_plot(tukey, **heatmap_args)

enter image description here

我曾尝试使用seaborn进行此操作,但未获得所需的输出:

# Generate a mask for the upper triangle
mask = np.triu(np.ones_like(tukey, dtype=bool))
# Set up the matplotlib figure
f, ax = plt.subplots(figsize=(12, 6))
# Generate a custom diverging colormap
cmap = sns.diverging_palette(230, 20, as_cmap=True)
# Draw the heatmap with the mask and correct aspect ratio
sns.heatmap(tukey, mask=mask, cmap=cmap, vmax=.3, center=0,
                square=True, linewidths=.5, cbar_kws={"shrink": .5})

enter image description here

如图所示,它仍然显示了应该被屏蔽的正方形,显然cbar是不同的

我的问题是,是否有办法不使用seaborn使其成为对角线?或者至少是为了摆脱重复部分?

编辑:我的数据帧示例(tukey):

>>>     1_a    1_b      1_c     1_d      1_e    1_f
index
1_a     1.00    0.900  0.75      0.736    0.900  0.400
1_b     0.9000  1.000  0.72      0.715    0.900  0.508
1_c     0.756   0.342  1.000     0.005    0.124  0.034
1_d     0.736   0.715  0.900     1.000    0.081  0.030 
1_e     0.900   0.900  0.804     0.793    1.000  0.475
1_f     0.400   0.508  0.036     0.030    0.475  1.000

*我可能有打字错误,对角线两边应该相等

编辑: 进口:

import scikit_posthocs as sp
import pandas as pd
import numpy as np
import statsmodels.api as sm
import scipy.stats as stats
from statsmodels.formula.api import ols

import matplotlib.pyplot as plt
import scipy.stats as stats

import seaborn as sns

Tags: theimporttrueindexasstatsnpplt
1条回答
网友
1楼 · 发布于 2024-04-25 14:53:19

scikit_posthocssign_plot()似乎创建了一个QuadMesh(就像sns.heatmap)。为此类网格设置边颜色将显示网格全宽和全高的水平线和垂直线。要使边缘在“空”区域中不可见,可以将其着色为与背景相同的颜色(例如白色)。通过将单个单元格的值设置为NaN,可以使其不可见,如下面的代码所示

删除列和行(例如tukey.drop('1_f', axis=1, inplace=True)tukey.drop('1_a', axis=0, inplace=True)),这无助于使绘图变得更小,因为sign_plot会自动将它们添加回绘图中

import matplotlib.pyplot as plt
import scikit_posthocs as sp
import pandas as pd
import numpy as np
from io import StringIO

data_str = '''     1_a    1_b      1_c     1_d      1_e    1_f
1_a     1.00    0.900  0.75      0.736    0.900  0.400
1_b     0.9000  1.000  0.72      0.715    0.900  0.508
1_c     0.756   0.342  1.000     0.005    0.124  0.034
1_d     0.736   0.715  0.900     1.000    0.081  0.030 
1_e     0.900   0.900  0.804     0.793    1.000  0.475
1_f     0.400   0.508  0.036     0.030    0.475  1.000'''
tukey = pd.read_csv(StringIO(data_str), delim_whitespace=True)

cols = tukey.columns
for i in range(len(cols)):
    for j in range(i, len(cols)):
        tukey.iloc[i, j] = np.nan

fix, ax = plt.subplots(figsize=(12, 6))
heatmap_args = {'linewidths': 0.35, 'linecolor': 'white', 'clip_on': False, 'square': True,
                'cbar_ax_bbox': [0.75, 0.35, 0.04, 0.3]}
sp.sign_plot(tukey, **heatmap_args)
plt.show()

resulting plot

相关问题 更多 >