Pytorch:从网格传递到图像坐标约定,用于网格\u示例

2024-03-28 17:43:08 发布

您现在位置:Python中文网/ 问答频道 /正文

我需要在PyTorch中插入一些变形网格,并决定使用函数grid_sample(see doc)。我需要按照以下惯例将网格重塑为图像:

  • N批量大小
  • D栅格深度(用于3D图像)
  • H网格高度
  • W网格宽度
  • d网格维度(=2或3)

图像格式为2D中的(N,2,H,W)(分别为3D中的(N,3,D,H,W)) 当网格格式为(N,H,W,2)(在3D中分别为(N,D,H,W,3))时

我不能使用reshapeview,因为它们没有按照我的意愿排列数据。我需要(例如)

grid_in_grid_convention[0,:,:,0] == grid_in_image_convention[0,0,:,:]

我提出了这些函数,以使重塑工作良好,但我相信有一个更紧凑/快速的方法来做到这一点。你教什么

def grid2im(grid):
    """Reshape a grid tensor into an image tensor
        2D  [T,H,W,2] -> [T,2,H,W]
        3D  [T,D,H,W,2] -> [T,D,H,W,3]
    """
    if grid.shape[0] == 1 and grid.shape[-1] == 2: # 2D case, batch =1
        return torch.stack((grid[0,:,:,0],grid[0,:,:,1]),dim = 0).unsqueeze(0)

    elif grid.shape[0] == 1 and grid.shape[-1] == 3: # 3D case, batch =1 
        return torch.stack((grid[0,:,:,:,0],grid[0,:,:,:,1],grid[0,:,:,:,2]),
                            dim = 0).unsqueeze(0)
    
    elif grid.shape[-1] == 2:
        N,H,W,d = grid.shape
        temp = torch.zeros((N,H,W,d))
        for n in range(N):
            temp[n,:,:,:] = torch.stack((grid[n,:,:,0],grid[n,:,:,1]),dim = 0).unsqueeze(0)
        return temp
    
    elif grid.shape[-1] == 3:
        N,D,H,W,d =grid.shape
        temp = torch.zeros((N,D,H,W,d))
        for n in range(N):
            temp[n,:,:,:,:] = torch.stack((grid[n,:,:,:,0],
                                           grid[n,:,:,:,1],
                                           grid[n,:,:,:,2]),
                            dim = 0).unsqueeze(0)
    else:
        raise ValueError("input argument expected is [N,H,W,2] or [N,D,H,W,3]",
                         "got "+str(grid.shape)+" instead.")
def im2grid(image):
    """Reshape an image tensor into a grid tensor
        2D case [T,2,H,W]   ->  [T,H,W,2]
        3D case [T,3,D,H,W] ->  [T,D,H,W,3]
    """
    # No batch 
    if image.shape[0:2] == (1,2):
        return torch.stack((image[0,0,:,:],image[0,1,:,:]),dim= 2).unsqueeze(0)
    elif image.shape[0:2] == (1,3):
        return torch.stack((image[0,0,:,:],image[0,1,:,:],image[0,2,:,:]),
                           dim = 2).unsqueeze(0)
    # Batch size > 1
    elif image.shape[0] > 0 and image.shape[1] == 2 :
        N,d,H,W = image.shape
        temp = torch.zeros((N,H,W,d))
        for n in range(N):
            temp[n,:,:,:] = torch.stack((image[n,0,:,:],image[n,1,:,:]),dim= 2).unsqueeze(0)
        return temp
    elif image.shape[0] > 0 and image.shape[1] == 3 :
        N,d,D,H,W = image.shape
        temp = torch.zeros((N,D,H,W,d))
        for n in range(N):
            temp[n,:,:,:] = torch.stack((image[n,0,:,:],
                                         image[n,1,:,:],
                                         image[n,2,:,:]),
                           dim = 2).unsqueeze(0)
        return temp
    else:
        raise ValueError("input argument expected is [1,2,H,W] or [1,3,D,H,W]",
                         "got "+str(image.shape)+" instead.")

Tags: andinimage网格returnstacktorchtemp
2条回答

ihdv的答案是正确的。谢谢大家!

我在这里完成了以上两个函数的答案,并以一种更加高效和优雅的方式进行了更正

def grid2im(grid):
    """Reshape a grid tensor into an image tensor
        2D  [T,H,W,2] -> [T,2,H,W]
        3D  [T,D,H,W,2] -> [T,D,H,W,3]
    """
    if grid.shape[-1] == 2: # 2D case
        return grid.transpose(2,3).transpose(1,2)

    elif grid.shape[-1] == 3: # 3D case
        return grid.transpose(3,4).transpose(2,3).transpose(1,2)
    else:
        raise ValueError("input argument expected is [N,H,W,2] or [N,D,H,W,3]",
                         "got "+str(grid.shape)+" instead.")



def im2grid(image):
    """Reshape an image tensor into a grid tensor
        2D case [T,2,H,W]   ->  [T,H,W,2]
        3D case [T,3,D,H,W] ->  [T,D,H,W,3]
    """
    # No batch
    if image.shape[1] == 2:
        return image.transpose(1,2).transpose(2,3)
    elif image.shape[1] == 3:
        return image.transpose(1,2).transpose(2,3).transpose(3,4)
    else:
        raise ValueError("input argument expected is [1,2,H,W] or [1,3,D,H,W]",
                         "got "+str(image.shape)+" instead.")

您可以使用^{}。例如,如果你想用[T,H,W,2] -> [T,2,H,W]表示张量,比如说grid,你可以这样做

grid = grid.transpose(2,3).transpose(1,2)

相关问题 更多 >