Scipy -- 3d griddata -- 为什么需要将griddata的xi参数转换为元组?
为什么下面这个调用griddata会失败呢?
import scipy.interpolate
import numpy as np
grid_vals = np.meshgrid(*([np.linspace(-1,1,200)] * 3))
interp_vals = scipy.interpolate.griddata(np.random.randn(50,3), np.random.randn(50), grid_vals, 'linear')
出现了以下错误:
ValueError: xi的维度数量和x不匹配
如果我把xi(grid_vals)这个参数转换成元组:
interp_vals = scipy.interpolate.griddata(np.random.randn(50,3), np.random.randn(50), tuple(grid_vals), 'linear')
错误就消失了。为什么会这样呢?
2 个回答
1
这个问题出现是因为底层的scipy插值模块的代码能够正确处理多维的元组,但不能处理多维的列表。
7
简单来说,griddata
这个函数会把points
和xi
都传递给一个叫points = _ndim_coords_from_arrays(points)
的函数。这个函数的说明文档里提到:
Convert a tuple of coordinate arrays to a (..., ndim)-shaped array.
对元组的关键操作是:
p = np.broadcast_arrays(*points)
其他任何东西,包括列表,都会被转换成数组:
points = np.asanyarray(points)
实际的插值操作需要数组的最后一个维度是'3d'。
所以你有的3个(200,200,200)
数组会变成一个形状为(3,200,200,200)
的数组。但是你的points
数组是(50,3)
。这个number of dimensions in xi does not match x
的错误信息就是因为200
和3
不匹配。
griddata
的文档对points
的说明很清楚,但对xi
的说明就没那么详细了。不过它的例子使用了(x, Y)
,这些数组是从mgrid
生成的。
所以这样做是可以的:
X, Y, Z = np.meshgrid(*([np.linspace(-1,1,200)] * 3))
interp_vals = scipy.interpolate.griddata(points, values, (X,Y,Z), 'linear')
生成所需数组的另一种方法是把你的meshgrid
列表变成一个数组,然后把第一个维度滚动一下。
grid_vals = np.rollaxis(np.array(grid_vals),0,4)
生成网格的另一种方法是使用np.ix_
,它会返回一个以元组形式表示的开放网格。这样的开放网格确实需要广播。
单个点的插值可以用以下任意一种方式进行:
interpolate.griddata(points,values,[[[[0,0,0]]]],'linear')
interpolate.griddata(points,values,([0],[0],[0]),'linear')
关于约翰的4123
拉取请求的反应中有更多关于原因的讨论。