我需要在PyTorch中插入一些变形网格,并决定使用函数grid_sample
(see doc)。我需要按照以下惯例将网格重塑为图像:
N
批量大小D
栅格深度(用于3D图像)H
网格高度W
网格宽度d
网格维度(=2或3)图像格式为2D中的(N,2,H,W)
(分别为3D中的(N,3,D,H,W)
)
当网格格式为(N,H,W,2)
(在3D中分别为(N,D,H,W,3)
)时
我不能使用reshape
或view
,因为它们没有按照我的意愿排列数据。我需要(例如)
grid_in_grid_convention[0,:,:,0] == grid_in_image_convention[0,0,:,:]
我提出了这些函数,以使重塑工作良好,但我相信有一个更紧凑/快速的方法来做到这一点。你教什么
def grid2im(grid):
"""Reshape a grid tensor into an image tensor
2D [T,H,W,2] -> [T,2,H,W]
3D [T,D,H,W,2] -> [T,D,H,W,3]
"""
if grid.shape[0] == 1 and grid.shape[-1] == 2: # 2D case, batch =1
return torch.stack((grid[0,:,:,0],grid[0,:,:,1]),dim = 0).unsqueeze(0)
elif grid.shape[0] == 1 and grid.shape[-1] == 3: # 3D case, batch =1
return torch.stack((grid[0,:,:,:,0],grid[0,:,:,:,1],grid[0,:,:,:,2]),
dim = 0).unsqueeze(0)
elif grid.shape[-1] == 2:
N,H,W,d = grid.shape
temp = torch.zeros((N,H,W,d))
for n in range(N):
temp[n,:,:,:] = torch.stack((grid[n,:,:,0],grid[n,:,:,1]),dim = 0).unsqueeze(0)
return temp
elif grid.shape[-1] == 3:
N,D,H,W,d =grid.shape
temp = torch.zeros((N,D,H,W,d))
for n in range(N):
temp[n,:,:,:,:] = torch.stack((grid[n,:,:,:,0],
grid[n,:,:,:,1],
grid[n,:,:,:,2]),
dim = 0).unsqueeze(0)
else:
raise ValueError("input argument expected is [N,H,W,2] or [N,D,H,W,3]",
"got "+str(grid.shape)+" instead.")
def im2grid(image):
"""Reshape an image tensor into a grid tensor
2D case [T,2,H,W] -> [T,H,W,2]
3D case [T,3,D,H,W] -> [T,D,H,W,3]
"""
# No batch
if image.shape[0:2] == (1,2):
return torch.stack((image[0,0,:,:],image[0,1,:,:]),dim= 2).unsqueeze(0)
elif image.shape[0:2] == (1,3):
return torch.stack((image[0,0,:,:],image[0,1,:,:],image[0,2,:,:]),
dim = 2).unsqueeze(0)
# Batch size > 1
elif image.shape[0] > 0 and image.shape[1] == 2 :
N,d,H,W = image.shape
temp = torch.zeros((N,H,W,d))
for n in range(N):
temp[n,:,:,:] = torch.stack((image[n,0,:,:],image[n,1,:,:]),dim= 2).unsqueeze(0)
return temp
elif image.shape[0] > 0 and image.shape[1] == 3 :
N,d,D,H,W = image.shape
temp = torch.zeros((N,D,H,W,d))
for n in range(N):
temp[n,:,:,:] = torch.stack((image[n,0,:,:],
image[n,1,:,:],
image[n,2,:,:]),
dim = 2).unsqueeze(0)
return temp
else:
raise ValueError("input argument expected is [1,2,H,W] or [1,3,D,H,W]",
"got "+str(image.shape)+" instead.")
ihdv的答案是正确的。谢谢大家!
我在这里完成了以上两个函数的答案,并以一种更加高效和优雅的方式进行了更正
您可以使用^{} 。例如,如果你想用
[T,H,W,2] -> [T,2,H,W]
表示张量,比如说grid
,你可以这样做相关问题 更多 >
编程相关推荐