<p>I need to interpolate some deformed grids in PyTorch and decided to use the function <code>grid_sample</code> <a href="https://pytorch.org/docs/master/nn.functional.html#torch.nn.functional.grid_sample" rel="nofollow noreferrer">(see doc)</a>. I need to reshape grids into images following this convention:</p>
<ul>
<li><code>N</code> batch size</li>
<li><code>D</code> grid depth (for 3D images)</li>
<li><code>H</code> grid height</li>
<li><code>W</code> grid width</li>
<li><code>d</code> grid dimension (= 2 or 3)</li>
</ul>
<p>The image format is <code>(N,2,H,W)</code> in 2D (resp. <code>(N,3,D,H,W)</code> in 3D),
while the grid format is <code>(N,H,W,2)</code> (resp. <code>(N,D,H,W,3)</code> in 3D).</p>
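<p>For reference, here is a minimal 2D sketch of these shape conventions (the identity-grid construction and the <code>align_corners</code> choice are my own, just for illustration):</p>

```python
import torch
import torch.nn.functional as F

N, H, W = 1, 4, 5
image = torch.rand(N, 2, H, W)  # image convention: (N, 2, H, W)

# Identity sampling grid in the (N, H, W, 2) convention, coordinates in [-1, 1].
ys = torch.linspace(-1, 1, H)
xs = torch.linspace(-1, 1, W)
grid_y = ys.view(H, 1).expand(H, W)
grid_x = xs.view(1, W).expand(H, W)
grid = torch.stack((grid_x, grid_y), dim=-1).unsqueeze(0)  # (N, H, W, 2)

out = F.grid_sample(image, grid, align_corners=True)  # (N, 2, H, W)
print(out.shape)  # torch.Size([1, 2, 4, 5])
```

<p>With the identity grid and <code>align_corners=True</code>, <code>out</code> reproduces <code>image</code>, which makes the convention easy to check.</p>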
<p>I cannot use <code>reshape</code> or <code>view</code>, because they do not arrange the data the way I want. I need (for example)</p>
<pre><code>grid_in_grid_convention[0,:,:,0] == grid_in_image_convention[0,0,:,:]
</code></pre>
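<p>To illustrate why <code>view</code> does not give me this (a small sketch with a deterministic tensor so the slices are easy to inspect):</p>

```python
import torch

# (N, H, W, 2) grid filled with 0..39
grid = torch.arange(40.).reshape(1, 4, 5, 2)
viewed = grid.view(1, 2, 4, 5)  # regroups the flat data, does not move it

# The channel slice from view does not match the last-axis slice of the grid:
print(torch.equal(viewed[0, 0, :, :], grid[0, :, :, 0]))  # False
```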
<p>I came up with these functions to make the reshaping work, but I believe there is a more compact/faster way to do it. What do you think?</p>
<pre class="lang-py prettyprint-override"><code>def grid2im(grid):
    """Reshape a grid tensor into an image tensor.
    2D: [N,H,W,2] -> [N,2,H,W]
    3D: [N,D,H,W,3] -> [N,3,D,H,W]
    """
    if grid.shape[0] == 1 and grid.shape[-1] == 2:  # 2D case, batch = 1
        return torch.stack((grid[0,:,:,0], grid[0,:,:,1]), dim=0).unsqueeze(0)
    elif grid.shape[0] == 1 and grid.shape[-1] == 3:  # 3D case, batch = 1
        return torch.stack((grid[0,:,:,:,0], grid[0,:,:,:,1], grid[0,:,:,:,2]),
                           dim=0).unsqueeze(0)
    elif grid.shape[-1] == 2:  # 2D case, batch > 1
        N, H, W, d = grid.shape
        temp = torch.zeros((N, d, H, W))
        for n in range(N):
            temp[n,:,:,:] = torch.stack((grid[n,:,:,0], grid[n,:,:,1]), dim=0)
        return temp
    elif grid.shape[-1] == 3:  # 3D case, batch > 1
        N, D, H, W, d = grid.shape
        temp = torch.zeros((N, d, D, H, W))
        for n in range(N):
            temp[n,:,:,:,:] = torch.stack((grid[n,:,:,:,0],
                                           grid[n,:,:,:,1],
                                           grid[n,:,:,:,2]), dim=0)
        return temp
    else:
        raise ValueError("input argument expected is [N,H,W,2] or [N,D,H,W,3], "
                         "got " + str(grid.shape) + " instead.")
</code></pre>
<pre class="lang-py prettyprint-override"><code>def im2grid(image):
    """Reshape an image tensor into a grid tensor.
    2D case: [N,2,H,W] -> [N,H,W,2]
    3D case: [N,3,D,H,W] -> [N,D,H,W,3]
    """
    # No batch
    if image.shape[0:2] == (1, 2):
        return torch.stack((image[0,0,:,:], image[0,1,:,:]), dim=2).unsqueeze(0)
    elif image.shape[0:2] == (1, 3):
        return torch.stack((image[0,0,:,:,:], image[0,1,:,:,:], image[0,2,:,:,:]),
                           dim=3).unsqueeze(0)
    # Batch size > 1
    elif image.shape[1] == 2:
        N, d, H, W = image.shape
        temp = torch.zeros((N, H, W, d))
        for n in range(N):
            temp[n,:,:,:] = torch.stack((image[n,0,:,:], image[n,1,:,:]), dim=2)
        return temp
    elif image.shape[1] == 3:
        N, d, D, H, W = image.shape
        temp = torch.zeros((N, D, H, W, d))
        for n in range(N):
            temp[n,:,:,:,:] = torch.stack((image[n,0,:,:,:],
                                           image[n,1,:,:,:],
                                           image[n,2,:,:,:]), dim=3)
        return temp
    else:
        raise ValueError("input argument expected is [N,2,H,W] or [N,3,D,H,W], "
                         "got " + str(image.shape) + " instead.")
</code></pre>