如何使用周围三维阵列的值填充numpy三维阵列(或火炬张量)

2024-06-17 09:48:44 发布

您现在位置:Python中文网/ 问答频道 /正文

我有一个形状为3,3,3的3D numpy数组,我想从它周围的数组中填充两层值,这样它就变成了5,5,5数组

enter image description here

到目前为止,我使用torch cat功能(其工作原理与numpy concat相同)对y阵列进行了如下填充:

x = torch.from_numpy(np.arange(1,28,1).reshape(3,3,3))
y = torch.from_numpy(np.arange(28,55,1).reshape(3,3,3))
z = torch.from_numpy(np.arange(55,82,1).reshape(3,3,3))

torch.cat((y,z[:,:2,:]), dim=1) #To concat z+ with 2 pads
torch.cat((x[:,1:,:],y), dim=1) #To concat z- with 2 pads

torch.cat((y,z[:,:,:2]), dim=2) #To concat x+ with 2 pads
torch.cat((x[:,1:,:],y), dim=1) #To concat x- with 2 pads

torch.cat((x,z[:2,:,:]), dim=2) #To concat y+ with 2 pads
torch.cat((x[1:,:,:],y), dim=1) #To concat y- with 2 pads

但它没有给我正确的值。我怎样才能做到这一点


Tags: tofrom功能numpywithnptorch数组
1条回答
网友
1楼 · 发布于 2024-06-17 09:48:44

如果我理解正确,您需要的不是常规数组,因为每个维度都有不同的范围,这取决于观察到的轴(即,它不能几何表示为立方体-您的图片不是(n,n,n)数组)

无论如何,在下面的片段中,我们创建了一个(5,5,5)测试3D数组,可以从中对(3,3,3)数组进行采样。然后我们连续连接以获得原始数组,之后我们屏蔽不需要的单元格,以便输出如图片所示。注意,当使用布尔Numpy数组时,可以用+*替换逻辑操作

import numpy as np

# Define dummy 3D field
n = 5
xx, yy, zz = np.ogrid[0:n, 0:n, 0:n]
field = np.sin(xx) + np.cos(yy) + np.tan(zz)

# Indices of 3 innermost elements - to form (3,3,3) array
i1, i2 = n//2 - 1, n//2 + 1
# Inner 3D array
subfield = np.copy(field)[i1:i2+1, i1:i2+1, i1:i2+1]

# Indices of "inferior" pads
x1 = y1 = z1 = np.arange(i1 - 1, i1)
# Indices of "superior" pads
x2 = y2 = z2 = np.arange(i2 + 1, i2 + 2)

# Padding in axis 0 (x)
padded = np.concatenate((field[x1, i1:i2+1, i1:i2+1], subfield))
padded = np.concatenate((padded, field[x2, i1:i2+1, i1:i2+1]))
# Padding in axis 1 (y)
padded = np.concatenate((field[i1-1:i2+2, y1, i1:i2+1], padded), axis = 1)
padded = np.concatenate((padded, field[i1-1:i2+2, y2, i1:i2+1]), axis = 1)
# Padding in axis 2 (z)
padded = np.concatenate((field[i1-1:i2+2, i1-1:i2+2, z1], padded), axis = 2)
padded = np.concatenate((padded, field[i1-1:i2+2, i1-1:i2+2, z2]), axis = 2)

# Check padded array is equal to original 3D array
print(np.all(padded == field))

## We now mask unwanted cells
indices = np.indices(padded.shape)
idx1, idx2 = indices == 0, indices == n - 1

xi, xf = idx1[0], idx2[0]
yi, yf = idx1[1], idx2[1]
zi, zf = idx1[2], idx2[2]

# Logical operations to mask proper slices
xm = (xi + xf) * (yi + yf + zi + zf)            # masking in axis 0
ym = (yi + yf) * (xi + xf + zi + zf)            # masking in axis 1
zm = (zi + zf) * (yi + yf + xi + xf)            # masking in axis 2

mask = xm + ym + zm
# Masked (5,5,5) array
masked_padded = np.ma.masked_where(mask, padded)

顺便说一句,一定有更优雅的方法来实现同样的结果,但我没有太多地使用Numpy的高级索引:p

相关问题 更多 >