尝试计算然后显示函数的梯度向量

2024-06-16 08:26:36 发布

您现在位置:Python中文网/ 问答频道 /正文

我能够制作一个程序,显示一个2变量函数的3d图形,然后是函数梯度的向量场,但是我想让它计算梯度本身,但是我一直从plt.quiver()中得到isinfinite错误。我觉得部分原因是因为我来回使用了x和y的numpy和sympy符号,但我不知道在这种情况下该怎么办

def z_func(x,y):
    return (x**2+y**2)

def show_graph():
    x,y = np.meshgrid(np.linspace(-15,15,20),np.linspace(-15,15,20))
    z = z_func(x,y)

    fig = plt.figure(2)
    ax = fig.gca( projection='3d')
    surf = ax.plot_surface(x,y,z,rstride=1,cstride=1)


    ax.set_xlabel('X', fontweight = 'bold', fontsize = 14)
    ax.set_ylabel('Y', fontweight = 'bold', fontsize = 14)
    ax.set_zlabel('Z', fontweight = 'bold', fontsize = 14)

plt.title('Ahem', fontweight = 'bold', fontsize = 16)

def get_grad():
    x = sy.Symbol('x')
    y= sy.Symbol('y')
    f = z_func(x,y)
    gradi = sy.diff(f,x)
    gradj = sy.diff(f,y)
    show_vector(gradi,gradj)

def show_vector(gradi,gradj):
    a = sy.Symbol('x')
    b = sy.Symbol('y')
    u = gradi
    v = gradj

print('[{0},{1}]'.format(u,v))
a,b = np.meshgrid(np.linspace(-10,10,10),np.linspace(-10,10,10))
print('[{0},{1}]'.format(u,v))

figv = plt.figure(1)    
plt.xlabel('X')
plt.ylabel('Y')
plt.quiver(a,b,u,v)

def lazy():
    get_grad()
    show_graph()
    plt.show()

lazy()

Tags: defshownppltaxsymbolfuncbold
1条回答
网友
1楼 · 发布于 2024-06-16 08:26:36

如果要在symphy之外使用symphy表达式,则需要^{}

以下代码是否符合您的预期

from matplotlib import pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import numpy as np
import sympy as sy

def z_func(x, y):
    return (x ** 2 + y ** 2)

x = sy.Symbol('x')
y = sy.Symbol('y')
f = z_func(x, y)
gradi = sy.diff(f, x)
gradj = sy.diff(f, y)
np_gradi = sy.lambdify(x, gradi, 'numpy')
np_gradj = sy.lambdify(y, gradj, 'numpy')

a, b = np.meshgrid(np.linspace(-10, 10, 10), np.linspace(-10, 10, 10))
u = np_gradi(a)
v = np_gradj(b)

x, y = np.meshgrid(np.linspace(-15, 15, 20), np.linspace(-15, 15, 20))
z = z_func(x, y)

fig = plt.figure(2)
ax = fig.gca(projection='3d')
surf = ax.plot_surface(x, y, z, rstride=1, cstride=1)

ax.set_xlabel('X', fontweight='bold', fontsize=14)
ax.set_ylabel('Y', fontweight='bold', fontsize=14)
ax.set_zlabel('Z', fontweight='bold', fontsize=14)

figv = plt.figure(1)
plt.xlabel('X')
plt.ylabel('Y')
plt.quiver(a, b, u, v)
plt.show()

resulting plot

相关问题 更多 >