How to use pose heatmaps for GAN conditioning?

Posted 2024-04-18 18:11:56


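I have a question about building a pose-conditioned GAN in PyTorch. My goal is to generate images of human models conditioned only on pose (based on 17x64x64 pose heatmaps). Assuming the generator conditioning is more or less done, how do I include the pose conditioning in the discriminator? We can use the discriminator from https://github.com/NVlabs/stylegan2-ada-pytorch/blob/main/training/networks.py as an example (the forward() below is from its DiscriminatorEpilogue block). There, a simple label-based conditioning is applied in the discriminator's forward() method: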

import numpy as np
import torch
from torch_utils import misc  # helper module from the stylegan2-ada-pytorch repo

def forward(self, x, img, cmap, force_fp32=False):
    misc.assert_shape(x, [None, self.in_channels, self.resolution, self.resolution]) # [NCHW]
    # Here, cmap is just a simple class label mapping. In my case, cmap would include
    # a 17-channel pose heatmap from a certain source image.
    _ = force_fp32 # unused
    dtype = torch.float32
    memory_format = torch.contiguous_format

    # FromRGB.
    x = x.to(dtype=dtype, memory_format=memory_format)
    if self.architecture == 'skip':
        misc.assert_shape(img, [None, self.img_channels, self.resolution, self.resolution])
        img = img.to(dtype=dtype, memory_format=memory_format)
        x = x + self.fromrgb(img)

    # Main layers.
    if self.mbstd is not None:
        x = self.mbstd(x)
    x = self.conv(x)
    x = self.fc(x.flatten(1))
    x = self.out(x)

    # Conditioning.
    if self.cmap_dim > 0:
        misc.assert_shape(cmap, [None, self.cmap_dim])
        x = (x * cmap).sum(dim=1, keepdim=True) * (1 / np.sqrt(self.cmap_dim))
    assert x.dtype == dtype
    return x
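
The conditioning step at the end of this forward() (the `self.cmap_dim > 0` branch) is a projection: the output vector x is dotted with the embedding cmap and scaled by 1/sqrt(cmap_dim), giving one score per sample. Below is a minimal standalone sketch of just that step. `project_condition` is a hypothetical helper name, and the random tensors, batch size of 4, and cmap_dim = 512 are arbitrary assumptions standing in for the real features.

```python
import numpy as np
import torch


def project_condition(x: torch.Tensor, cmap: torch.Tensor) -> torch.Tensor:
    """Projection-style conditioning, mirroring the cmap_dim > 0 branch above.

    x:    [N, cmap_dim] discriminator output (result of self.out)
    cmap: [N, cmap_dim] conditioning embedding
    returns [N, 1] conditional scores
    """
    assert x.shape == cmap.shape and x.ndim == 2
    cmap_dim = x.shape[1]
    # Per-sample dot product, scaled so its magnitude does not grow with cmap_dim.
    return (x * cmap).sum(dim=1, keepdim=True) * (1 / np.sqrt(cmap_dim))


torch.manual_seed(0)
x = torch.randn(4, 512)     # stand-in for the discriminator features
cmap = torch.randn(4, 512)  # stand-in for the conditioning embedding
scores = project_condition(x, cmap)
print(scores.shape)  # torch.Size([4, 1])
```

Because both x and cmap are flat [N, cmap_dim] vectors here, a [N, 17, 64, 64] heatmap cannot be passed as cmap directly; it would first have to be mapped to a vector of length cmap_dim.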

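How can I adapt this code to my problem, where the heatmaps have shape [batch_size, 17, 64, 64]? I thought about simply flattening the heatmaps, but that would lose spatial information. Another option is to extract heatmaps xmap from the image and compute some form of distance between xmap and gmap (some kind of pixel-wise MAE?). However, I have a hard time picturing how to combine such a result with the base output x for conditioning.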
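
The two options mentioned above (flattening the heatmap, and a pixel-wise MAE between xmap and gmap) can be sketched as follows. This is an illustration under assumptions, not an answer to how they should be combined with x: `HeatmapToCmap` is a hypothetical module that uses a single linear layer (an assumption) to turn the flattened [N, 17, 64, 64] heatmap into a cmap vector of the shape the projection step expects, and `heatmap_mae` is a hypothetical helper that computes a per-sample MAE between two heatmap tensors.

```python
import torch
import torch.nn as nn


class HeatmapToCmap(nn.Module):
    """Hypothetical helper: flatten a pose heatmap into a cmap vector.

    Input:  [N, num_joints, size, size] heatmap
    Output: [N, cmap_dim] embedding for the projection step
    """

    def __init__(self, num_joints: int = 17, size: int = 64, cmap_dim: int = 512):
        super().__init__()
        self.fc = nn.Linear(num_joints * size * size, cmap_dim)

    def forward(self, heatmap: torch.Tensor) -> torch.Tensor:
        # flatten(1) keeps the batch dimension and concatenates joints and pixels
        return self.fc(heatmap.flatten(1))


def heatmap_mae(xmap: torch.Tensor, gmap: torch.Tensor) -> torch.Tensor:
    """Per-sample pixel-wise mean absolute error between two heatmaps.

    Both inputs: [N, num_joints, H, W]; returns [N, 1].
    """
    assert xmap.shape == gmap.shape
    return (xmap - gmap).abs().mean(dim=(1, 2, 3)).unsqueeze(1)


torch.manual_seed(0)
heatmap = torch.rand(2, 17, 64, 64)  # random stand-in for real pose heatmaps
embed = HeatmapToCmap()
cmap = embed(heatmap)
print(cmap.shape)  # torch.Size([2, 512])
print(heatmap_mae(heatmap, heatmap))  # all zeros: identical heatmaps
```

With the defaults, the linear layer has 17 × 64 × 64 × 512 ≈ 35.7 million weights, which is one practical cost of the flattening route.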

