matplotlib中的字符串数组散点图

10 投票
3 回答
12154 浏览
提问于 2025-04-17 20:21

这看起来应该是个简单的问题,但我搞不明白。我有一个 pandas 数据框,想用其中的三列做一个三维散点图。可是,X 和 Y 列不是数字,它们是字符串,但我觉得这不应该是个问题。

X= myDataFrame.columnX.values #string
Y= myDataFrame.columnY.values #string
Z= myDataFrame.columnY.values #float

fig = pl.figure()
ax = fig.add_subplot(111, projection='3d')
ax.scatter(X, Y, np.log10(Z), s=20, c='b')
pl.show()

难道没有简单的方法可以做到这一点吗?谢谢。

3 个回答

2

试着把字符转换成数字来绘图,然后再用字符作为坐标轴的标签。

使用哈希

你可以用hash这个函数来进行转换;

from mpl_toolkits.mplot3d import Axes3D
xlab = myDataFrame.columnX.values
ylab = myDataFrame.columnY.values

X =[hash(l) for l in xlab] 
Y =[hash(l) for l in xlab] 

Z= myDataFrame.columnY.values #float

fig = figure()
ax = fig.add_subplot(111, projection='3d')
ax.scatter(X, Y, np.log10(Z), s=20, c='b')
ax.set_xticks(X)
ax.set_xticklabels(xlab)
ax.set_yticks(Y)
ax.set_yticklabels(ylab)
show()

正如M4rtini在评论中提到的,字符串坐标的间距或缩放方式并不明确;使用hash函数可能会导致意想不到的间距。

均匀间距

如果你想让点之间的间距均匀,那你需要用不同的转换方法。比如你可以使用:

X =[i for i in range(len(xlab))]

不过这样会导致每个点都有一个独特的x位置,即使标签是相同的,如果你对Y使用相同的方法,x和y的点也会相关联。

简化的均匀间距

第三种选择是先获取xlab中的唯一成员(可以用set),然后用这个唯一集合来映射每个xlab到一个位置;例如:

xmap = dict((sn, i)for i,sn in enumerate(set(xlab)))
X = [xmap[l] for l in xlab]
3

现在,Scatter这个功能已经自动处理这些事情了(从matplotlib 2.1.0版本开始就这样了):

plt.scatter(['A', 'B', 'B', 'C'], [0, 1, 2, 1])   

散点图

11

你可以使用 np.unique(..., return_inverse=True) 来为每个字符串获取一个代表性的整数。比如说,

In [117]: uniques, X = np.unique(['foo', 'baz', 'bar', 'foo', 'baz', 'bar'], return_inverse=True)

In [118]: X
Out[118]: array([2, 1, 0, 2, 1, 0])

需要注意的是,X 的数据类型是 int32,因为 np.unique 最多只能处理 2**31 个不同的字符串。


import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import mpl_toolkits.mplot3d.axes3d as axes3d

N = 12
arr = np.arange(N*2).reshape(N,2)
words = np.array(['foo', 'bar', 'baz', 'quux', 'corge'])
df = pd.DataFrame(words[arr % 5], columns=list('XY'))
df['Z'] = np.linspace(1, 1000, N)
Z = np.log10(df['Z'])
Xuniques, X = np.unique(df['X'], return_inverse=True)
Yuniques, Y = np.unique(df['Y'], return_inverse=True)

fig = plt.figure()
ax = fig.add_subplot(1, 1, 1, projection='3d')
ax.scatter(X, Y, Z, s=20, c='b')
ax.set(xticks=range(len(Xuniques)), xticklabels=Xuniques,
       yticks=range(len(Yuniques)), yticklabels=Yuniques) 
plt.show()

enter image description here

撰写回答