如何对齐双y轴刻度的网格线?

63 投票
9 回答
75319 浏览
提问于 2025-05-01 18:08

我正在绘制两个数据集,它们的y轴单位不同。有没有办法让两个y轴的刻度和网格线对齐呢?

第一张图片展示了我现在得到的效果,第二张图片则是我想要的效果。

这是我用来绘图的代码:

import seaborn as sns
import numpy as np
import pandas as pd

np.random.seed(0)
fig = plt.figure()
ax1 = fig.add_subplot(111)
ax1.plot(pd.Series(np.random.uniform(0, 1, size=10)))
ax2 = ax1.twinx()
ax2.plot(pd.Series(np.random.uniform(10, 20, size=10)), color='r')

不想要的效果示例

想要的效果示例

暂无标签

9 个回答

2

这段代码的作用是让两个方向的网格线能够对齐,而不需要隐藏任何一组的网格线。在这个例子中,它可以让你选择对齐那些网格线更细的那一组。这是基于@Leo的想法。希望这对你有帮助!

import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import pandas as pd

fig = plt.figure()
ax1 = fig.add_subplot(111)
ax1.plot(pd.Series(np.random.uniform(0,1,size=10)))
ax2 = ax1.twinx()
ax2.plot(pd.Series(np.random.uniform(10,20,size=10)),color='r')
ax2.grid(None)

# Determine which plot has finer grid. Set pointers accordingly
l1 = len(ax1.get_yticks())
l2 = len(ax2.get_yticks())
if l1 > l2:
  a = ax1
  b = ax2
  l = l1
else:
  a = ax2
  b = ax1
  l = l2

# Respace grid of 'b' axis to match 'a' axis
b_ticks = np.linspace(b.get_yticks()[0],b.get_yticks()[-1],l)
b.set_yticks(b_ticks)

plt.show()
11

我写了一个方法,用来对齐多个y轴的刻度(可能不止两个),这些y轴的刻度范围可能不同。

下面是一个示例图: 在这里输入图片描述

图中有3个y轴,一个是左边的蓝色,右边有一个绿色和一个红色。三条曲线分别用对应的颜色绘制在y轴上。注意,它们的数值范围差别很大。

  • 左边的图:没有对齐。
  • 中间的图:大致在每个y轴的下限对齐。
  • 右边的图:在指定的值上对齐:蓝色为0,红色为2.2*1e8,绿色为44。这些值是随便选的。

我的做法是将每个y数组缩放到1到100的范围内,然后把所有缩放后的y值合并成一个数组,再用MaxNLocator创建一组新的刻度。接着,这组新的刻度会根据对应的缩放因子缩放回去,以得到每个轴的新刻度。如果需要特定的对齐,y数组会在缩放前先进行位移,缩放后再恢复。

完整代码在这里(关键函数是alignYaxes()):

import matplotlib.pyplot as plt
import numpy as np

def make_patch_spines_invisible(ax):
    '''Used for creating a 2nd twin-x axis on the right/left

    E.g.
        fig, ax=plt.subplots()
        ax.plot(x, y)
        tax1=ax.twinx()
        tax1.plot(x, y1)
        tax2=ax.twinx()
        tax2.spines['right'].set_position(('axes',1.09))
        make_patch_spines_invisible(tax2)
        tax2.spines['right'].set_visible(True)
        tax2.plot(x, y2)
    '''

    ax.set_frame_on(True)
    ax.patch.set_visible(False)
    for sp in ax.spines.values():
        sp.set_visible(False)

def alignYaxes(axes, align_values=None):
    '''Align the ticks of multiple y axes

    Args:
        axes (list): list of axes objects whose yaxis ticks are to be aligned.
    Keyword Args:
        align_values (None or list/tuple): if not None, should be a list/tuple
            of floats with same length as <axes>. Values in <align_values>
            define where the corresponding axes should be aligned up. E.g.
            [0, 100, -22.5] means the 0 in axes[0], 100 in axes[1] and -22.5
            in axes[2] would be aligned up. If None, align (approximately)
            the lowest ticks in all axes.
    Returns:
        new_ticks (list): a list of new ticks for each axis in <axes>.

        A new sets of ticks are computed for each axis in <axes> but with equal
        length.
    '''
    from matplotlib.pyplot import MaxNLocator

    nax=len(axes)
    ticks=[aii.get_yticks() for aii in axes]
    if align_values is None:
        aligns=[ticks[ii][0] for ii in range(nax)]
    else:
        if len(align_values) != nax:
            raise Exception("Length of <axes> doesn't equal that of <align_values>.")
        aligns=align_values

    bounds=[aii.get_ylim() for aii in axes]

    # align at some points
    ticks_align=[ticks[ii]-aligns[ii] for ii in range(nax)]

    # scale the range to 1-100
    ranges=[tii[-1]-tii[0] for tii in ticks]
    lgs=[-np.log10(rii)+2. for rii in ranges]
    igs=[np.floor(ii) for ii in lgs]
    log_ticks=[ticks_align[ii]*(10.**igs[ii]) for ii in range(nax)]

    # put all axes ticks into a single array, then compute new ticks for all
    comb_ticks=np.concatenate(log_ticks)
    comb_ticks.sort()
    locator=MaxNLocator(nbins='auto', steps=[1, 2, 2.5, 3, 4, 5, 8, 10])
    new_ticks=locator.tick_values(comb_ticks[0], comb_ticks[-1])
    new_ticks=[new_ticks/10.**igs[ii] for ii in range(nax)]
    new_ticks=[new_ticks[ii]+aligns[ii] for ii in range(nax)]

    # find the lower bound
    idx_l=0
    for i in range(len(new_ticks[0])):
        if any([new_ticks[jj][i] > bounds[jj][0] for jj in range(nax)]):
            idx_l=i-1
            break

    # find the upper bound
    idx_r=0
    for i in range(len(new_ticks[0])):
        if all([new_ticks[jj][i] > bounds[jj][1] for jj in range(nax)]):
            idx_r=i
            break

    # trim tick lists by bounds
    new_ticks=[tii[idx_l:idx_r+1] for tii in new_ticks]

    # set ticks for each axis
    for axii, tii in zip(axes, new_ticks):
        axii.set_yticks(tii)

    return new_ticks

def plotLines(x, y1, y2, y3, ax):

    ax.plot(x, y1, 'b-')
    ax.tick_params('y',colors='b')

    tax1=ax.twinx()
    tax1.plot(x, y2, 'r-')
    tax1.tick_params('y',colors='r')

    tax2=ax.twinx()
    tax2.spines['right'].set_position(('axes',1.2))
    make_patch_spines_invisible(tax2)
    tax2.spines['right'].set_visible(True)
    tax2.plot(x, y3, 'g-')
    tax2.tick_params('y',colors='g')

    ax.grid(True, axis='both')

    return ax, tax1, tax2

#-------------Main---------------------------------
if __name__=='__main__':

    # craft some data to plot
    x=np.arange(20)
    y1=np.sin(x)
    y2=x/1000+np.exp(x)
    y3=x+x**2/3.14

    figure=plt.figure(figsize=(12,4),dpi=100)

    ax1=figure.add_subplot(1, 3, 1)
    axes1=plotLines(x, y1, y2, y3, ax1)
    ax1.set_title('No alignment')

    ax2=figure.add_subplot(1, 3, 2)
    axes2=plotLines(x, y1, y2, y3, ax2)
    alignYaxes(axes2)
    ax2.set_title('Default alignment')

    ax3=figure.add_subplot(1, 3, 3)
    axes3=plotLines(x, y1, y2, y3, ax3)
    alignYaxes(axes3, [0, 2.2*1e8, 44])
    ax3.set_title('Specified alignment')

    figure.tight_layout()
    figure.show()
16

我写了一个函数,这个函数需要用到Matplotlib的坐标轴对象ax1和ax2,还有两个浮点数minresax1和minresax2:

def align_y_axis(ax1, ax2, minresax1, minresax2):
    """ Sets tick marks of twinx axes to line up with 7 total tick marks

    ax1 and ax2 are matplotlib axes
    Spacing between tick marks will be a factor of minresax1 and minresax2"""

    ax1ylims = ax1.get_ybound()
    ax2ylims = ax2.get_ybound()
    ax1factor = minresax1 * 6
    ax2factor = minresax2 * 6
    ax1.set_yticks(np.linspace(ax1ylims[0],
                               ax1ylims[1]+(ax1factor -
                               (ax1ylims[1]-ax1ylims[0]) % ax1factor) %
                               ax1factor,
                               7))
    ax2.set_yticks(np.linspace(ax2ylims[0],
                               ax2ylims[1]+(ax2factor -
                               (ax2ylims[1]-ax2ylims[0]) % ax2factor) %
                               ax2factor,
                               7))

这个函数会计算并设置坐标轴上的刻度,确保总共有七个刻度。最小的刻度对应当前的最小刻度,然后增加最大的刻度,使得每个刻度之间的间隔是minresax1或minresax2的整数倍。

为了让这个函数更通用,你可以通过把代码中所有的7改成你想要的总刻度数,和把6改成总刻度数减去1,来设置你想要的刻度数量。

我已经提交了一个请求,希望把这些功能加入到matplotlib.ticker.LinearLocator中:

https://github.com/matplotlib/matplotlib/issues/6142

在未来的版本中(可能是Matplotlib 2.0?),你可以试试:

import matplotlib.ticker
nticks = 11
ax1.yaxis.set_major_locator(matplotlib.ticker.LinearLocator(nticks))
ax2.yaxis.set_major_locator(matplotlib.ticker.LinearLocator(nticks))

这样应该就能正常工作,并为两个y轴选择合适的刻度。

16

我通过在其中一个坐标轴上关闭 ax.grid(None) 来解决这个问题:

import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import pandas as pd

fig = plt.figure()
ax1 = fig.add_subplot(111)
ax1.plot(pd.Series(np.random.uniform(0, 1, size=10)))
ax2 = ax1.twinx()
ax2.plot(pd.Series(np.random.uniform(10, 20, size=10)), color='r')
ax2.grid(None)

plt.show()

结果图

34

我不确定这样做是不是最漂亮的方法,但这一行代码确实解决了问题:

import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import pandas as pd

np.random.seed(0)
fig = plt.figure()
ax1 = fig.add_subplot(111)
ax1.plot(pd.Series(np.random.uniform(0, 1, size=10)))
ax2 = ax1.twinx()
ax2.plot(pd.Series(np.random.uniform(10, 20, size=10)), color='r')

# ADD THIS LINE
ax2.set_yticks(np.linspace(ax2.get_yticks()[0], ax2.get_yticks()[-1], len(ax1.get_yticks())))

plt.show()

撰写回答