为了执行衍生,我开发了以下代码:
import matplotlib.pyplot as plt
import numpy as np
from math import *
xi = jnp.linspace(-3,3)
def f(x):
a = x**3+5
return a
g1i = jax.vmap(jax.grad(f))(xi)
g2i = jax.vmap(jax.grad(jax.grad(f)))(xi)
g3i = jax.vmap(jax.grad(jax.grad(jax.grad(f))))(xi)
plt.plot(xi,yi, label = "f")
plt.plot(xi,g1i, label = "f'")
plt.plot(xi,g2i, label = "f''")
plt.plot(xi,g3i, label = "f'''")
plt.legend()
此代码有效,但现在我感兴趣的是应用以下代码计算买入价相对于标的资产(即delta)的一阶导数,尝试使用以下代码,但不起作用:
import scipy.stats as si
import sympy as sy
import sys
xi = jnp.linspace(1,1.5)
def analytical_call(s0):
T=1.
q=0.
r=0.
k=1.
sigma=0.4
Kt = k*exp((q-r)*T)
d = (log(Kt/s0)+(sigma**2)/2*T)/sigma
result = (Kt * si.norm.cdf((d / sqrt(T)), 0.0, 1.0) - s0 * si.norm.cdf(((d - sigma * T) / sqrt(T)), 0.0, 1.0) ) * exp(-q * T) + exp(-q * T) * (s0 - Kt)
return result
print(analytical_call(1))
g1i = jax.vmap(jax.grad(analytical_call))(xi)
g2i = jax.vmap(jax.grad(jax.grad(analytical_call)))(xi)
plt.plot(xi,yi, label = "f")
plt.plot(xi,g1i, label = "f'")
plt.legend()
你有什么提示吗?提前谢谢
正如在注释中已经提到的,您不能在jax库之外使用像} 。类似地,用jax等价物
scipy.stats.norm.cdf
这样的方法。改用^{jnp.exp
和jnp.sqrt
替换exp
和sqrt
:然后,您可以计算
g(xi)
和h(xi)
相关问题 更多 >
编程相关推荐