如何在高维numpy数组中使用最近邻插值

3 投票
2 回答
5962 浏览
提问于 2025-04-18 12:23

我正在用Python编程,使用scipy和numpy库。我有一个数据查找表(LUT),我通过以下方式访问它:

self.lut_data[n_iter][m_iter][l_iter][k_iter][j_iter][i_iter] 

在这里,*_iter索引对应的是我保存在字典中的一组值。例如,i_iter索引对应的是光的波长,所以我有一个标签和数值的字典,可以通过以下方式获取:

labels['wavelength']

这样就能返回一个数组,里面是每个i_iter对应的波长。这在我直接查找时非常有用。如果我想要500纳米的lut_data,我首先在labels['wavelength']中找到对应的索引,然后用这个索引去查找

lut_data[][][][][][wavelength_index]

我对其他维度也做了同样的事情,比如观察角度等,它们对应其他的*_iters。

我需要做的是在查找表中的值之间找到值,并且我希望这个方法在我事先不知道查找表的维度时也能工作。如果我知道维度,那我就可以用循环来解决每个维度的问题。但是如果我不知道LUT有多少维度,那我就不知道要嵌套多少个循环。

我觉得我可以用cKDTree来做到这一点,但我就是搞不懂怎么让它工作。如果能给我一个和我的结构相似的例子,我会非常感激。

谢谢!

2 个回答

1

scipy.interpolate.RegularGridInterpolator这个工具非常适合解决这个问题。不过,它只在Scipy 0.14版本中可用(截至目前,这是最新的版本)。

如果你的*_iter数据存储在变量里,你可以这样做:

from scipy.interpolate import RegularGridInterpolator

points = tuple([n_iter, m_iter, l_iter, k_iter, j_iter, i_iter])
interpolator = RegularGridInterpolator(points, lut_data, method='nearest')

或者你可以从你的字典中获取points

keys = ['k1', 'k2', 'k3', 'k4', 'k5', 'wavelength']
points = tuple([labels[key] for key in keys])

一旦你有了插值器,你就可以使用它的__call__方法来进行插值。这基本上意味着你可以像调用函数一样调用你创建的这个类的实例:

point_of interest = tuple([x1, x2, x3, x4, x5, some_wavelength])
interp_value = interpolator(point_of_interest)

这个插值器还允许一次插值多个值(也就是说,可以传入一个Numpy数组的点),如果你的代码需要这样做,这样会大大提高效率。

1

如果你有一整套信息可以用来插值,线性插值其实并不难。它只是稍微耗时一点,但如果你的数据能放进内存里,处理起来也就几秒钟的事。

关键是线性插值可以逐个轴进行。对于每个轴,你需要:

  • 找到最近的两个点来进行插值
  • 计算这两个点之间的相对距离(d = 0..1),比如说如果你有540和550纳米的数据,而你想要在548纳米的数据,d 就是0.8。
  • 对所有轴重复这个过程;每次都会减少一个维度

就像这样:

import numpy as np

def ndim_interp(A, ranges, p):
    # A: array with n dimensions
    # ranges: list of n lists or numpy arrays of values along each dimension
    # p: vector of values to find (n elements)

    # iterate through all dimensions
    for i in range(A.ndim):
        # check if we are overrange; if we are, use the edgemost values
        if p[i] <= ranges[i][0]:
            A = A[0]
            continue
        if p[i] >= ranges[i][-1]:
            A = A[-1]
            continue

        # find the nearest values
        right = np.searchsorted(ranges[i], p[i])
        left = right - 1

        # find the relative distance
        d = (p[i] - ranges[i][left]) / (ranges[i][right] - ranges[i][left])

        # calculate the interpolation
        A = (1 - d) * A[left] + d * A[right]            

    return A

举个例子:

# data axis points
arng = [1, 2, 3]
brng = [100, 200]
crng = [540, 550, 560]

# some data
A = np.array([
    [[1., 2., 3.], [2., 3., 4.]],
    [[0.5, 1.5, 2.], [1.5, 2.0, 3.0]],
    [[0., 0.5, 1.], [1., 1., 1.]]])

# lookup:
print ndim_interp(A, (arng, brng, crng), (2.3, 130., 542.))

如果你想做一些更复杂的事情(比如三次样条插值等),你可以使用 scipy.ndimage.interpolation.map_coordinates。这样的话,步骤就会变成:

import numpy as np
import scipy.ndimage.interpolation

def ndim_interp(A, ranges, p):
    # A: array with n dimensions
    # ranges: list of n lists or numpy arrays of values along each dimension
    # p: vector of values to find (n elements)

    # calculate the coordinates into array positions in each direction
    p_arr = []
    # iterate through all dimensions
    for i in range(A.ndim):
        # check if we are overrange; if we are, use the edgemost values
        if p[i] <= ranges[i][0]:
            p_arr.append(0)
            continue
        if p[i] >= ranges[i][-1]:
            p_arr.append(A.shape[i] - 1)
            continue

        # find the nearest values to the left
        right = np.searchsorted(ranges[i], p[i])
        left = right - 1

        # find the relative distance
        d = (p[i] - ranges[i][left]) / (ranges[i][right] - ranges[i][left])

        # append the position
        p_arr.append(left + d)

    coords = np.array(p_arr).reshape(A.ndim, -1)
    return scipy.ndimage.interpolation.map_coordinates(A, coords, order=1, mode='nearest')[0]

当然,使用最简单的设置(order=1 就是线性插值)是没有意义的,但即使是三次样条插值,自己写插值算法也不简单。

当然,如果你的网格在所有方向上都是均匀分布的,那么代码会更简单,因为你不需要先插值来找到正确的位置(简单的除法就可以了)。

撰写回答