填充numpy数组中的空缺

16 投票
5 回答
12397 浏览
提问于 2025-04-16 15:07

我想用最简单的方式来处理一个三维数据集的插值。线性插值、最近邻插值这些方法都可以用(因为这是为了启动一个算法,所以不需要特别精确的结果)。

在新的scipy版本中,像griddata这样的工具会很有用,但我现在只有scipy 0.8。所以我有一个“立方体”数据(data[:,:,:],大小是(NixNjxNk)),还有一个同样大小的标志数组(flags[:,:,:,],里面是TrueFalse)。我想对那些在标志数组中对应元素为False的数据进行插值,比如用最近的有效数据点,或者用一些“附近”的点的线性组合。

在这个数据集中,至少有两个维度可能会有很大的空缺。除了用kd树或类似的方法编写一个完整的最近邻算法外,我真的找不到一个通用的N维最近邻插值方法。

5 个回答

2

不久前,我为我的博士论文写了一个脚本:https://github.com/Technariumas/Inpainting

这里有个例子:http://blog.technariumas.lt/post/117630308826/healing-holes-in-python-arrays

这个脚本运行得比较慢,但能完成任务。使用高斯核是最好的选择,只需检查一下大小和sigma值。

40

使用scipy.ndimage这个库,你的问题可以用最近邻插值法在两行代码内解决:

from scipy import ndimage as nd

indices = nd.distance_transform_edt(invalid_cell_mask, return_distances=False, return_indices=True)
data = data[tuple(ind)]

现在,把它做成一个函数:

import numpy as np
from scipy import ndimage as nd

def fill(data, invalid=None):
    """
    Replace the value of invalid 'data' cells (indicated by 'invalid') 
    by the value of the nearest valid data cell

    Input:
        data:    numpy array of any dimension
        invalid: a binary array of same shape as 'data'. 
                 data value are replaced where invalid is True
                 If None (default), use: invalid  = np.isnan(data)

    Output: 
        Return a filled array. 
    """    
    if invalid is None: invalid = np.isnan(data)

    ind = nd.distance_transform_edt(invalid, 
                                    return_distances=False, 
                                    return_indices=True)
    return data[tuple(ind)]

使用示例:

def test_fill(s,d):
     # s is size of one dimension, d is the number of dimension
    data = np.arange(s**d).reshape((s,)*d)
    seed = np.zeros(data.shape,dtype=bool)
    seed.flat[np.random.randint(0,seed.size,int(data.size/20**d))] = True

    return fill(data,-seed), seed

import matplotlib.pyplot as plt
data,seed  = test_fill(500,2)
data[nd.binary_dilation(seed,iterations=2)] = 0   # draw (dilated) seeds in black
plt.imshow(np.mod(data,42))                       # show cluster

结果: 在这里输入图片描述

16

你可以设置一个类似晶体生长的算法,交替沿着每个轴移动视图,只替换那些标记为 False 但有 True 邻居的数据。这会产生一种“最近邻”的效果(不过不是在欧几里得或曼哈顿距离下——我觉得如果你在计算像素,算上所有连接的像素和公共角落的话,可能会算作最近邻)。这样做在使用 NumPy 时应该会比较高效,因为它只遍历轴和收敛迭代,而不是小块数据。

简单、快速且稳定。我想这正是你想要的:

import numpy as np
# -- setup --
shape = (10,10,10)
dim = len(shape)
data = np.random.random(shape)
flag = np.zeros(shape, dtype=bool)
t_ct = int(data.size/5)
flag.flat[np.random.randint(0, flag.size, t_ct)] = True
# True flags the data
# -- end setup --

slcs = [slice(None)]*dim

while np.any(~flag): # as long as there are any False's in flag
    for i in range(dim): # do each axis
        # make slices to shift view one element along the axis
        slcs1 = slcs[:]
        slcs2 = slcs[:]
        slcs1[i] = slice(0, -1)
        slcs2[i] = slice(1, None)

        # replace from the right
        repmask = np.logical_and(~flag[slcs1], flag[slcs2])
        data[slcs1][repmask] = data[slcs2][repmask]
        flag[slcs1][repmask] = True

        # replace from the left
        repmask = np.logical_and(~flag[slcs2], flag[slcs1])
        data[slcs2][repmask] = data[slcs1][repmask]
        flag[slcs2][repmask] = True

为了更好地说明,这里有一个(二维)可视化,展示了最初标记为 True 的数据所生成的区域。

在这里输入图片描述

撰写回答