Python中带有差异的错误消息

2024-06-16 09:04:49 发布

您现在位置:Python中文网/ 问答频道 /正文

我使用蒙特卡罗方法计算这些衍生工具的一般看涨期权。我对这个组合导数感兴趣(关于S和Sigma)。通过算法微分,我得到了一个可以在页面末尾看到的错误。可能的解决办法是什么?为了解释与代码有关的内容,我将附上用于计算以下代码中“X”的公式:

enter image description here

from jax import jit, grad, vmap
import jax.numpy as jnp
from jax import random
Underlying_asset = jnp.linspace(1.1,1.4,100)
volatilities = jnp.linspace(0.5,0.6,100)
def second_derivative_mc(S,vol):
    N = 100
    j,T,q,r,k = 10000,1.,0,0,1.
    S0 = jnp.array([S]).T #(Nx1) vector underlying asset
    C = jnp.identity(N)*vol    #matrix of volatilities with 0 outside diagonal 
    e = jnp.array([jnp.full(j,1.)])#(1xj) vector of "1"
    Rand = np.random.RandomState()
    Rand.seed(10)
    U= Rand.normal(0,1,(N,j)) #Random number for Brownian Motion
    sigma2 = jnp.array([vol**2]).T #Vector of variance Nx1

    first = jnp.dot(sigma2,e) #First part equation
    second = jnp.dot(C,U)     #Second part equation

    X = -0.5*first+jnp.sqrt(T)*second

    St = jnp.exp(X)*S0

    P = jnp.maximum(St-k,0)
    payoff = jnp.average(P, axis=-1)*jnp.exp(-q*T)
    return payoff 


greek = vmap(grad(grad(second_derivative_mc, argnums=1), argnums=0)(Underlying_asset,volatilities)

这是错误消息:

> UnfilteredStackTrace                      Traceback (most recent call
> last) <ipython-input-78-0cc1da97ae0c> in <module>()
>      25 
> ---> 26 greek = vmap(grad(grad(second_derivative_mc, argnums=1), argnums=0))(Underlying_asset,volatilities)
> 
> 18 frames UnfilteredStackTrace: TypeError: Gradient only defined for
> scalar-output functions. Output had shape: (100,).

下面的堆栈跟踪不包括JAX内部帧。 前面是发生的原始异常,未经修改


上述异常是以下异常的直接原因:

> TypeError                                 Traceback (most recent call
> last) /usr/local/lib/python3.7/dist-packages/jax/_src/api.py in
> _check_scalar(x)
>     894     if isinstance(aval, ShapedArray):
>     895       if aval.shape != ():
> --> 896         raise TypeError(msg(f"had shape: {aval.shape}"))
>     897     else:
>     898       raise TypeError(msg(f"had abstract value {aval}"))

> TypeError: Gradient only defined for scalar-output functions. Output had shape: (100,).

Tags: importassetsecondshapehadjaxtypeerrorvmap
1条回答
网友
1楼 · 发布于 2024-06-16 09:04:49

如错误消息所示,只能为返回标量的函数计算梯度。函数返回一个向量:

print(len(second_derivative_mc(1.1, 0.5)))
# 100

对于向量值函数,可以计算雅可比矩阵(类似于多维梯度)。这就是你的想法吗

from jax import jacobian
greek = vmap(jacobian(jacobian(second_derivative_mc, argnums=1), argnums=0))(Underlying_asset,volatilities)

此外,这不是您所问的问题,但即使您解决了问题中的问题,上述函数也可能无法按您的意愿工作。NumpyRandomState对象是有状态的,因此通常无法正确使用jax转换,如gradjitvmap等,这些转换需要无副作用的代码(请参见Stateful Computations In JAX)。您可以尝试使用jax.random;有关更多信息,请参见JAX: Random Numbers

相关问题 更多 >