在Python中从嵌套循环中获取插值
我有一个在 (xz) 平面上有特定形状的二维结构。为了简单起见,我这里设定它是圆形的。我基本上需要围绕 z 轴旋转这个结构,我的想法是用一个插值函数来实现这个旋转。RegularGridInterpolator
(文档链接)听起来很合适,因为我打算把这个结构在 (xz) 平面上的数据作为插值器的输入,然后在围绕 z 轴旋转时,计算每个位置的 sqrt(x^2 + y^2)(从上面看,也就是沿着 z 轴看),对应于原来的 x 坐标,而 z 坐标则保持不变。
代码是可以运行的,但当数组变大时(每个方向最多有 1000 个点),速度非常慢。这很可能是因为插值计算中使用了嵌套的 for 循环。我考虑过用列表推导式来优化这个,但在这里无法与插值器配合使用。所以我的问题是,如何至少去掉一个 for 循环,甚至更多?
以下是代码:
import matplotlib.pyplot as plt
from mayavi import mlab
import numpy as np
import scipy.interpolate as interp
def make_simple_2Dplot( data2plot, xVals, yVals, N_contLevels=8 ):
fig, ax = plt.subplots()
# the necessity for a transposed arrays confuses me...
ax.contourf(x_arr, z_arr, data2plot.T)
ax.set_aspect('equal')
ax.set_xlabel('x')
ax.set_ylabel('z')
plt.show()
def make_simple_3Dplot( data2plot, xVals, yVals, zVals, N_contLevels=8 ):
contLevels = np.linspace( np.amin(data2plot),
np.amax(data2plot),
N_contLevels)[1:].tolist()
fig1 = mlab.figure( bgcolor=(1,1,1), fgcolor=(0,0,0),size=(800,600))
contPlot = mlab.contour3d( data2plot, contours=contLevels,
transparent=True, opacity=.4,
figure=fig1
)
mlab.xlabel('x')
mlab.ylabel('y')
mlab.zlabel('z')
mlab.show()
x_min, z_min = 0, 0
x_max, z_max = 10, 10
Nx = 100
Nz = 50
x_arr = np.linspace(x_min, x_max, Nx)
z_arr = np.linspace(z_min, z_max, Nz)
# center of circle in 2D
xc, zc = 5, 5
# radius of circle
rc = 2
# make 2D circle
data_2D = np.zeros( (Nx,Nz) )
for ii in range(Nx):
for kk in range(Nz):
if np.sqrt((x_arr[ii]-xc)**2 + (z_arr[kk]-zc)**2) < rc:
data_2D[ii,kk] = 1
# interpolation function to make 3D object
circle_xz = interp.RegularGridInterpolator( (x_arr,z_arr), data_2D,
bounds_error=False,
fill_value=0
)
# coordinate arrays for 3D data
y_min = -x_max
y_max = x_max
Ny = 100
x_arr_3D = np.linspace(-x_max, x_max, Nx)
y_arr_3D = np.linspace(y_min, y_max, Ny)
z_arr_3D = np.linspace(z_min, z_max, Nz)
# make 3D circle
data_3D = np.zeros( (Nx, Ny, Nz) )
for ii in range(Nx):
for jj in range(Ny):
# calculate R corresponding to x in (xz) plane
R = np.sqrt(x_arr_3D[ii]**2 + y_arr_3D[jj]**2)
for kk in range(Nz):
# hiding the interpolator deep in the nested for loop
# is probably not very clever
data_3D[ii,jj,kk] = circle_xz( (R, z_arr_3D[kk]) )
make_simple_2Dplot( data_2D, x_arr, z_arr, N_contLevels=8 )
make_simple_3Dplot( data_3D, x_arr_3D, y_arr_3D, z_arr_3D )
从下面的二维和三维输出可以看出,它是有效的,但速度很慢。
1 个回答
我的问题是,怎么能去掉至少一个for循环,甚至更多?
如果你使用np.meshgrid
,就可以完全去掉所有的循环,把数组变成2D或3D的形式。
# make 2D circle
x, z = np.meshgrid(x_arr, z_arr, indexing='ij')
data_2D = np.sqrt((x-xc)**2 + (z-zc)**2) < rc
# data_2D = data_2D.astype(np.float64) if you want the plot to look the same
为了更好地理解,可以看看x
和z
的前3 x 3的小数组:
x[:3, :3]
# array([[0. , 0. , 0. ],
# [0.1010101, 0.1010101, 0.1010101],
# [0.2020202, 0.2020202, 0.2020202]])
z[:3, :3]
# array([[0. , 0.20408163, 0.40816327],
# [0. , 0.20408163, 0.40816327],
# [0. , 0.20408163, 0.40816327]])
这些2D数组展示了x_arr
和z_arr
中每个值的所有组合,然后对生成的数组进行逐个元素的操作。第二行会生成一个布尔数组(比如用True
代替1
),但如果你想让图表看起来和之前一样,可以把它转换成float64
格式。
你可以用同样的方法来做3D图;meshgrid
可以接受任意数量的参数。
# make 3D circle
x, y, z = np.meshgrid(x_arr_3D, y_arr_3D, z_arr_3D, indexing='ij')
R = np.sqrt(x**2 + y**2)
points = np.vstack((R.ravel(), z.ravel())).T
data_3D = circle_xz(points).reshape(R.shape)
ravel
和vstack
被使用的原因是因为RegularGridInterpolator
的__call__
方法需要接受m
x n
的点数组来进行插值,其中m
是点的数量(500000
),n
是维度的数量(2
)。经过计算后,我们再把结果调整回原来的形状。