带colorbar图例的交互式散点图

2024-04-16 11:40:00 发布

您现在位置:Python中文网/ 问答频道 /正文

我有一个3D数组(3,10)。

我有一个二维散点图,其中标记的颜色和大小取决于第三列。我被两件事困住了:试图让图例显示颜色和/或大小,以及尝试使标签交互以便我可以移动它们(这样它们就不会重叠)。这很难解释;我希望我下面的代码能带来一些澄清。

代码

import numpy as np
import matplotlib.pyplot as plt
import pandas as pd

aha = [0.1872, 0.0101, 0.0166, 0.0164, 0.0164, 0.0170, 0.0187, 0.0188, 0.0652, 0.0102]
ahaa = [0.2872, 0.0301, 0.0466, 0.0364, 0.0564, 0.0670, 0.0287, 0.0888, 0.0852, 0.0502]
dist = [0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2]
mod = ['One', 'Another', 'Other', 'That', 'This', 'Two', 'Three', 'Four', 'Five', 'Six']
N = 10
data = np.vstack((aha, ahaa, dist))
data = np.transpose(data)
labels = [mod[i].format(i) for i in range (N)]
plt.subplots_adjust(bottom = 0.1)
plt.scatter(
    data[:, 0], data[:, 1], marker = 'o', c = data[:, 0], s = data[:, 2]*1500,
    cmap = plt.get_cmap('Spectral'))
for label, x, y in zip(labels, data[:, 0], data[:, 1]):
    plt.annotate(
        label, 
        xy = (x, y), xytext = (-20, 20),
        textcoords = 'offset points', ha = 'right', va = 'bottom',
        bbox = dict(boxstyle = 'round,pad=0.5', fc = 'yellow', alpha = 0.5),
        arrowprops = dict(arrowstyle = '->', connectionstyle = 'arc3,rad=0'))


plt.show()

我是Python新手,如果有任何帮助,我将非常感激。

问候 乔尔

*编辑

根据反馈,我现在也有这个代码

^{pr2}$

标签边界框的位置看起来不错,但是我遇到的问题是标签本身不在边界框中。标签位于每个点的正上方,而箭头和边界框位于不同的位置。

再次感谢您的帮助。


Tags: 代码importmodfordatalabels颜色dist
1条回答
网友
1楼 · 发布于 2024-04-16 11:40:00

多亏了@DavidG我才得以解决这个问题。在

我将张贴我的代码,以防将来它可以帮助某人

import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from numpy.random import *
from matplotlib.colors import ListedColormap

def get_text_positions(x_data, y_data, txt_width, txt_height):
    a = zip(y_data, x_data)
    text_positions = y_data.copy()
    for index, (y, x) in enumerate(a):
        local_text_positions = [i for i in a if i[0] > (x - txt_height) 
                            and (abs(i[1] - x) < txt_width * 2) and i != (y,x)]
        if local_text_positions:
            sorted_ltp = sorted(local_text_positions)
            if abs(sorted_ltp[0][0] - y) < txt_height: #True == collision
                differ = np.diff(sorted_ltp, axis=0)
                a[index] = (sorted_ltp[-1][0] + txt_height, a[index][1])
                text_positions[index] = sorted_ltp[-1][0] + txt_height
                for k, (j, m) in enumerate(differ):
                    #j is the vertical distance between words
                    if j > txt_height * 2: #if True then room to fit a word in
                        a[index] = (sorted_ltp[k][0] + txt_height, a[index][1])
                        text_positions[index] = sorted_ltp[k][0] + txt_height
                        break
    return text_positions

def text_plotter(x_data, y_data, z_data, text_positions, axis,txt_width,txt_height):
    for x,y,z,t in zip(x_data, y_data, z_data, text_positions):
        axis.text(x - txt_width, 1.03*t, '%s'%z,rotation=0, color='blue')
        if y != t:
            axis.arrow(x, t,0,y-t, color='red',alpha=0.5, width=txt_width*0.1, 
                       head_width=txt_width, head_length=txt_height*0.5, 
                       zorder=0,length_includes_head=True)




aha = [0.1872, 0.0101, 0.0166, 0.0164, 0.0164, 0.0170, 0.0187, 0.0188, 0.0652, 0.0102]
ahaa = [0.2872, 0.0301, 0.0466, 0.0364, 0.0564, 0.0670, 0.0287, 0.0888, 0.0852, 0.0502]
dist = [0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2]
mod = ['One', 'Another', 'Other', 'That', 'This', 'Two', 'Three', 'Four', 'Five', 'Six']
N = 10
data = np.vstack((aha, ahaa, dist))
data = np.transpose(data)

fig1 = plt.figure()
ax2 = fig1.add_subplot(111)
labels = [mod[i].format(i) for i in range (N)]
plt.subplots_adjust(bottom = 0.1)
CS = ax2.scatter(
    data[:, 0], data[:, 1], marker = 'o', c = data[:, 0], s = data[:, 2]*1500, alpha = 0.5,
    cmap = plt.cm.Spectral)



##################################
x_data = data[:,0]
y_data = data[:,1]
z_data = mod
#set the bbox for the text. Increase txt_width for wider text.
txt_height = 0.03*(plt.ylim()[1] - plt.ylim()[0])
txt_width = 0.02*(plt.xlim()[1] - plt.xlim()[0])
#Get the corrected text positions, then write the text.
text_positions = get_text_positions(x_data, y_data, txt_width, txt_height)
text_plotter(x_data, y_data, z_data, text_positions, ax2, txt_width, txt_height)

print text_positions

cbar = plt.colorbar(CS)
cbar.ax.set_ylabel('Legend Label')


plt.show()

相关问题 更多 >