如何更新sympy图的matplotlib元素?
下面的代码展示了一个使用matplotlib绘制的函数图像,函数是y > 5/x,用户可以在图上平移或缩放时填充图形。
import numpy as np
import matplotlib.pyplot as plt
fig, ax = plt.subplots()
x_range = np.linspace(-10, 10, 400)
y_range = 5 / x_range
line, = ax.plot(x_range, y_range, 'r', linewidth=2, linestyle='--')
ax.fill_between(x_range, y_range, y_range.max(), alpha=0.3, color='gray')
ax.set_xlabel('x')
ax.set_ylabel('y')
ax.set_title('Inequality: y > 5 / x')
ax.axhline(0, color='black',linewidth=3)
ax.axvline(0, color='black',linewidth=3)
ax.grid(color='gray', linestyle='--', linewidth=0.5)
def update_limits(event):
xlim = ax.get_xlim()
x_range = np.linspace(xlim[0], xlim[1], max(200, int(200 * (xlim[1] - xlim[0]))))
y_range = 5 / x_range
line.set_data(x_range, y_range)
for collection in ax.collections:
collection.remove()
ax.fill_between(x_range, y_range, max(y_range.max(), ax.get_ylim()[1]), alpha=0.3, color='gray')
plt.draw()
fig.canvas.mpl_connect('button_release_event', update_limits)
plt.show()
我一直在尝试把这个概念用代码实现,使用的是sympy模块(通过访问Matplotlib的后端),因为它对代数函数的控制更好。不过,下面的代码在用户平移时似乎没有填充图形。这是为什么呢?我该怎么解决这个问题呢?
import sympy as sp
from sympy.plotting.plot import MatplotlibBackend
x, y = sp.symbols('x y')
# Define the implicit plot
p1 = sp.plot_implicit(sp.And(y > 5 / x), (x, -10, 10), (y, -10, 10), show=False)
mplPlot = MatplotlibBackend(p1)
mplPlot.process_series()
mplPlot.fig.tight_layout()
mplPlot.ax[0].set_xlabel("x-axis")
mplPlot.ax[0].set_ylabel("y-axis")
def update_limits(event):
global mplPlot
xmin, xmax = mplPlot.ax[0].get_xlim()
ymin, ymax = mplPlot.ax[0].get_ylim()
p1 = sp.plot_implicit(sp.And(y > 5 / x), (x, xmin, xmax), (y, ymin, ymax), show=False)
mplPlot = MatplotlibBackend(p1)
mplPlot.process_series()
mplPlot.fig.tight_layout()
mplPlot.ax[0].set_xlabel("x-axis")
mplPlot.ax[0].set_ylabel("y-axis")
mplPlot.plt.draw()
mplPlot.fig.canvas.mpl_connect('button_release_event', update_limits)
mplPlot.plt.show()
更新:经过一些调试(使用sympy代码),我发现xmax
和xmin
这两个变量在update_limits
函数中只改变了一次,之后在程序运行的整个过程中都保持不变。如果可以的话,我也想知道这是为什么。
更新 2:如果不运行mplPlot.plt.draw()
,而是运行mplPlot.plt.show()
,会创建一个新窗口,显示正确的图形。但这不是我想要的,因为我希望所有的变化都在同一个窗口中显示。当我这样做时,还发现了另一个问题,就是当我在图的第四象限平移得足够远时,图形似乎变得无响应。这并不是每次都会发生,我也找不到原因。如果有人知道这是为什么,请随时补充到你的回答中!
1 个回答
0
我打算使用SymPy绘图后端的模块,因为它已经写了很多代码来支持交互功能。在写这个回答的时候,版本3.1.1已经发布,但它还没有实现平移/缩放等事件。不过,我们可以很容易地自己实现这些功能。
你可以先尝试一个隐式绘图:
%matplotlib widget
from sympy import *
from spb import *
from matplotlib.colors import ListedColormap
var("x, y")
g = graphics(
implicit_2d(
y>5/x, (x, -10, 10), (y, -200, 200), n=500,
rendering_kw={"alpha": 0.3, "cmap": ListedColormap(["#ffffff00", "gray"])},
border_kw={"linestyles": "--", "cmap": ListedColormap(["r", "r"])}
),
xlabel="x", ylabel="y", title="y > 5/x", show=False
)
def _update_axis_limits(event):
xlim = g.ax.get_xlim()
ylim = g.ax.get_ylim()
limits = [xlim, ylim]
all_params = {}
for s in g.series:
new_ranges = []
for r, l in zip(s.ranges, limits):
new_ranges.append((r[0], *l))
s.ranges = new_ranges
# inform the data series that they must generate new data on next update
s._interactive_ranges = True
s.is_interactive = True
# extract any existing parameters
all_params = g.merge({}, all_params)
# create new data and update the plot
g.update_interactive(all_params)
g.show(block=False)
g.fig.canvas.mpl_connect('button_release_event', _update_axis_limits)
这个可视化有个问题:在x=0的时候,函数是未定义的,所以不应该有那条垂直的红色虚线。隐式绘图依赖于Matplotlib的contour
功能:在等高线绘图中实现未定义的“点/线”是非常困难的。
我们可以通过使用fill_between
来改善可视化,就像你最初的方法那样。然而,绘图模块并没有实现这个功能。我们必须自己创建它,像这样:
import numpy as np
from sympy import *
from spb import *
from spb.series import LineOver1DRangeSeries
from spb.backends.matplotlib.renderers.renderer import MatplotlibRenderer
# Things to know in order to extend the plotting module capabilities:
#
# 1. data is generated by some data series.
# 2. each data series is mapped to a particular renderer
# Let's create a data series.
# Matplotlib's `fill_between`: https://matplotlib.org/stable/api/_as_gen/matplotlib.pyplot.fill_between.html
# requires the coordinates of the first curve and second curve.
# The first curve is just an ordinary line corresponding to the y-values of your expressions.
# The second curve is the maximum value of the y-coordinate of the first curve.
# So, the data is identical to the one we would use to render a line.
class FillBetweenSeries(LineOver1DRangeSeries):
pass
# Next, we must create a renderer, which must have two methods:
#
# 1. a `draw` method, where the initial handle will be created.
# 2. an `update` method, where the handle will be updated with new data.
#
# More information can be found here: https://sympy-plot-backends.readthedocs.io/en/latest/tutorials/tut-6.html
def draw(renderer, data):
ax = renderer.plot.ax
s = renderer.series
merge = renderer.plot.merge
x, y = data
rkw = merge({}, {"y2": np.nanmax(y)}, s.rendering_kw)
handles = [
ax.fill_between(x, y, **rkw)
]
return handles
def update(renderer, data, handles):
for h in handles:
h.remove()
handles[0] = draw(renderer, data)[0]
class FillBetweenRenderer(MatplotlibRenderer):
draw_update_map = {
draw: update
}
# Next, we need to inform the backend that when a FillBetweenSeries is
# ecountered, it will be rendered with FillBetweenRenderer
MB.renderers_map[FillBetweenSeries] = FillBetweenRenderer
# Next, let's create a function similar to the ones exposed by the plotting module.
# Note that it returns a list of series. Depending on the visualization
# that you are trying to create, you may need more than one data
# series...
def fill_between(expr, range, label="", rendering_kw={}, **kwargs):
"""
Parameters
----------
expr : Expr
The symbolic expression
ramge : (symbol, min, max)
Initial range in which to plot the expression
label : str
Eventual label to be shown on the legend
rendering_kw : dict
A dictionary containing keys/values which are going to
be passed to matplotlib.fill_between.
**kwargs :
Keyword arguments related to `line()`.
"""
series = [
FillBetweenSeries(expr, range, label, rendering_kw=rendering_kw, **kwargs)
]
return series
# Finally, we are ready to use it
var("x")
g = graphics(
fill_between(
5/x, (x, -10, 10), rendering_kw={"color": "gray", "alpha": 0.3},
exclude=[0] # exclude this point, where the function is undefined
),
line(
5/x, (x, -10, 10), rendering_kw={"color": "r", "linestyle": "--"},
exclude=[0] # exclude this point, where the function is undefined
),
xlabel="x", ylabel="y", title="y > 5/x", show=False, legend=False
)
def _update_axis_limits(event):
xlim = g.ax.get_xlim()
ylim = g.ax.get_ylim()
limits = [xlim, ylim]
all_params = {}
for s in g.series:
new_ranges = []
for r, l in zip(s.ranges, limits):
new_ranges.append((r[0], *l))
s.ranges = new_ranges
# inform the data series that they must generate new data on next update
s._interactive_ranges = True
s.is_interactive = True
# extract any existing parameters
all_params = g.merge({}, all_params)
# create new data and update the plot
g.update_interactive(all_params)
g.show(block=False)
g.fig.canvas.mpl_connect('button_release_event', _update_axis_limits)