如何解决AttributeError:“numpy.ndarray”对象在绘制子图时没有属性“get_figure”

2024-05-13 20:10:00 发布

您现在位置:Python中文网/ 问答频道 /正文

当我试图运行一个代码单元时,我遇到了一个问题。 我试图为数据帧中的每个变量绘制散点图,但遇到了一个我不太确定的错误。你能帮忙吗

我的代码:

fig, axes = plt.subplots(nrows=3, ncols=7, figsize=(12,10))
for xcol, ax in zip(df[df.columns], axes):
    df.plot(kind='scatter', x=xcol, y='price', ax=ax, alpha=0.5, color='r')

返回错误: AttributeError:'numpy.ndarray'对象没有属性'get\u figure'

enter image description here


Tags: 数据代码df错误fig绘制pltax
1条回答
网友
1楼 · 发布于 2024-05-13 20:10:00
  • fig, axes = plt.subplots(nrows=3, ncols=7, figsize=(12,10))创建3组7AxesSubplot对象
array([[<AxesSubplot:>, <AxesSubplot:>, <AxesSubplot:>, <AxesSubplot:>, <AxesSubplot:>, <AxesSubplot:>, <AxesSubplot:>],
       [<AxesSubplot:>, <AxesSubplot:>, <AxesSubplot:>, <AxesSubplot:>, <AxesSubplot:>, <AxesSubplot:>, <AxesSubplot:>],
       [<AxesSubplot:>, <AxesSubplot:>, <AxesSubplot:>, <AxesSubplot:>, <AxesSubplot:>, <AxesSubplot:>, <AxesSubplot:>]], dtype=object)
  • 通过使用zip(df[df.columns], axes)压缩,您将得到如下结果:
    • 这是错误的根源;如您所见,循环中的axarray,而不是AxesSubplot
[('col1', array([<AxesSubplot:>, <AxesSubplot:>, <AxesSubplot:>, <AxesSubplot:>, <AxesSubplot:>, <AxesSubplot:>, <AxesSubplot:>], dtype=object)),
 ('col2', array([<AxesSubplot:>, <AxesSubplot:>, <AxesSubplot:>, <AxesSubplot:>, <AxesSubplot:>, <AxesSubplot:>, <AxesSubplot:>], dtype=object)),
 ('col3', array([<AxesSubplot:>, <AxesSubplot:>, <AxesSubplot:>, <AxesSubplot:>, <AxesSubplot:>, <AxesSubplot:>, <AxesSubplot:>], dtype=object))]
  • 您想要的是,将一列压缩到一个子图,这可以通过解包所有轴子图来完成,使用列表理解,或者使用axes.ravel(),然后将它们压缩到列名。
    • ^{}返回平坦数组
  • 使用df.columns而不是df[df.columns]获取列名
# the list comprehension unpacks all the axes
zip(df.columns, [x for v in axes for x in v])

# which results in one column name per subplot
[('col1', <AxesSubplot:>),
 ('col2', <AxesSubplot:>),
 ('col3', <AxesSubplot:>),
 ('col4', <AxesSubplot:>),
 ('col5', <AxesSubplot:>),
 ('col6', <AxesSubplot:>),
 ('col7', <AxesSubplot:>),
 ('col8', <AxesSubplot:>),
 ('col9', <AxesSubplot:>),
 ('col10', <AxesSubplot:>),
 ('col11', <AxesSubplot:>),
 ('col12', <AxesSubplot:>),
 ('col13', <AxesSubplot:>),
 ('col14', <AxesSubplot:>),
 ('col15', <AxesSubplot:>),
 ('col16', <AxesSubplot:>),
 ('col17', <AxesSubplot:>),
 ('col18', <AxesSubplot:>),
 ('col19', <AxesSubplot:>),
 ('col20', <AxesSubplot:>),
 ('col21', <AxesSubplot:>)]

范例

import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

# load sample data
df = sns.load_dataset('car_crashes')

# setup figure
fig, axes = plt.subplots(nrows=2, ncols=3, figsize=(12, 10))

# iterate and plot subplots
for xcol, ax in zip(df.columns[1:-1], [x for v in axes for x in v]):
    df.plot.scatter(x=xcol, y='speeding', ax=ax, alpha=0.5, color='r')

enter image description here

相关问题 更多 >