Scipy -- 3d griddata -- 为什么需要将griddata的xi参数转换为元组?

15 投票
2 回答
8328 浏览
提问于 2025-04-18 02:23

为什么下面这个调用griddata会失败呢?

import scipy.interpolate
import numpy as np
grid_vals = np.meshgrid(*([np.linspace(-1,1,200)] * 3))
interp_vals = scipy.interpolate.griddata(np.random.randn(50,3), np.random.randn(50), grid_vals, 'linear')

出现了以下错误:

ValueError: xi的维度数量和x不匹配

如果我把xi(grid_vals)这个参数转换成元组:

interp_vals = scipy.interpolate.griddata(np.random.randn(50,3), np.random.randn(50), tuple(grid_vals), 'linear') 

错误就消失了。为什么会这样呢?

2 个回答

1

这个问题出现是因为底层的scipy插值模块的代码能够正确处理多维的元组,但不能处理多维的列表。

你可以在这里查看相关代码

我已经创建了问题 4123请求 4124,希望能解决这个问题,至少是针对列表的。

7

简单来说,griddata这个函数会把pointsxi都传递给一个叫points = _ndim_coords_from_arrays(points)的函数。这个函数的说明文档里提到:

Convert a tuple of coordinate arrays to a (..., ndim)-shaped array.

对元组的关键操作是:

p = np.broadcast_arrays(*points)

其他任何东西,包括列表,都会被转换成数组:

points = np.asanyarray(points)

实际的插值操作需要数组的最后一个维度是'3d'。

所以你有的3个(200,200,200)数组会变成一个形状为(3,200,200,200)的数组。但是你的points数组是(50,3)。这个number of dimensions in xi does not match x的错误信息就是因为2003不匹配。

griddata的文档对points的说明很清楚,但对xi的说明就没那么详细了。不过它的例子使用了(x, Y),这些数组是从mgrid生成的。

所以这样做是可以的:

 X, Y, Z = np.meshgrid(*([np.linspace(-1,1,200)] * 3))
 interp_vals = scipy.interpolate.griddata(points, values, (X,Y,Z), 'linear')

生成所需数组的另一种方法是把你的meshgrid列表变成一个数组,然后把第一个维度滚动一下。

grid_vals = np.rollaxis(np.array(grid_vals),0,4)

生成网格的另一种方法是使用np.ix_,它会返回一个以元组形式表示的开放网格。这样的开放网格确实需要广播。

单个点的插值可以用以下任意一种方式进行:

interpolate.griddata(points,values,[[[[0,0,0]]]],'linear')
interpolate.griddata(points,values,([0],[0],[0]),'linear')

关于约翰的4123拉取请求的反应中有更多关于原因的讨论。

撰写回答