填充numpy数组中的空缺
我想用最简单的方式来处理一个三维数据集的插值。线性插值、最近邻插值这些方法都可以用(因为这是为了启动一个算法,所以不需要特别精确的结果)。
在新的scipy版本中,像griddata这样的工具会很有用,但我现在只有scipy 0.8。所以我有一个“立方体”数据(data[:,:,:]
,大小是(NixNjxNk)),还有一个同样大小的标志数组(flags[:,:,:,]
,里面是True
或False
)。我想对那些在标志数组中对应元素为False的数据进行插值,比如用最近的有效数据点,或者用一些“附近”的点的线性组合。
在这个数据集中,至少有两个维度可能会有很大的空缺。除了用kd树或类似的方法编写一个完整的最近邻算法外,我真的找不到一个通用的N维最近邻插值方法。
5 个回答
不久前,我为我的博士论文写了一个脚本:https://github.com/Technariumas/Inpainting
这里有个例子:http://blog.technariumas.lt/post/117630308826/healing-holes-in-python-arrays
这个脚本运行得比较慢,但能完成任务。使用高斯核是最好的选择,只需检查一下大小和sigma值。
使用scipy.ndimage这个库,你的问题可以用最近邻插值法在两行代码内解决:
from scipy import ndimage as nd
indices = nd.distance_transform_edt(invalid_cell_mask, return_distances=False, return_indices=True)
data = data[tuple(ind)]
现在,把它做成一个函数:
import numpy as np
from scipy import ndimage as nd
def fill(data, invalid=None):
"""
Replace the value of invalid 'data' cells (indicated by 'invalid')
by the value of the nearest valid data cell
Input:
data: numpy array of any dimension
invalid: a binary array of same shape as 'data'.
data value are replaced where invalid is True
If None (default), use: invalid = np.isnan(data)
Output:
Return a filled array.
"""
if invalid is None: invalid = np.isnan(data)
ind = nd.distance_transform_edt(invalid,
return_distances=False,
return_indices=True)
return data[tuple(ind)]
使用示例:
def test_fill(s,d):
# s is size of one dimension, d is the number of dimension
data = np.arange(s**d).reshape((s,)*d)
seed = np.zeros(data.shape,dtype=bool)
seed.flat[np.random.randint(0,seed.size,int(data.size/20**d))] = True
return fill(data,-seed), seed
import matplotlib.pyplot as plt
data,seed = test_fill(500,2)
data[nd.binary_dilation(seed,iterations=2)] = 0 # draw (dilated) seeds in black
plt.imshow(np.mod(data,42)) # show cluster
结果:
你可以设置一个类似晶体生长的算法,交替沿着每个轴移动视图,只替换那些标记为 False
但有 True
邻居的数据。这会产生一种“最近邻”的效果(不过不是在欧几里得或曼哈顿距离下——我觉得如果你在计算像素,算上所有连接的像素和公共角落的话,可能会算作最近邻)。这样做在使用 NumPy 时应该会比较高效,因为它只遍历轴和收敛迭代,而不是小块数据。
简单、快速且稳定。我想这正是你想要的:
import numpy as np
# -- setup --
shape = (10,10,10)
dim = len(shape)
data = np.random.random(shape)
flag = np.zeros(shape, dtype=bool)
t_ct = int(data.size/5)
flag.flat[np.random.randint(0, flag.size, t_ct)] = True
# True flags the data
# -- end setup --
slcs = [slice(None)]*dim
while np.any(~flag): # as long as there are any False's in flag
for i in range(dim): # do each axis
# make slices to shift view one element along the axis
slcs1 = slcs[:]
slcs2 = slcs[:]
slcs1[i] = slice(0, -1)
slcs2[i] = slice(1, None)
# replace from the right
repmask = np.logical_and(~flag[slcs1], flag[slcs2])
data[slcs1][repmask] = data[slcs2][repmask]
flag[slcs1][repmask] = True
# replace from the left
repmask = np.logical_and(~flag[slcs2], flag[slcs1])
data[slcs2][repmask] = data[slcs1][repmask]
flag[slcs2][repmask] = True
为了更好地说明,这里有一个(二维)可视化,展示了最初标记为 True
的数据所生成的区域。