PyTorch：将张量转换为 one-hot 编码

2024-04-26 18:38:47 发布

您现在位置:Python中文网/ 问答频道 /正文

将形状为 (batch_size, height, width)、取值为 0 到 n-1 的整数张量，转换为形状为 (batch_size, n, height, width) 的 one-hot 张量，最简单的方法是什么？我在下面写了一个解决方案，但似乎应该有更简单、更快的方法。


def batch_tensor_to_onehot(tnsr, classes):
    """One-hot encode ``tnsr`` by comparing it against every class index.

    A channel axis is inserted at dim 1. For each ``k`` in ``range(classes)``
    one channel is built that is 1 where ``tnsr == k`` and 0 elsewhere. The
    channels are concatenated along dim 1 and returned as int64. Labels
    outside ``[0, classes)`` simply produce all-zero positions.
    """
    expanded = tnsr.unsqueeze(1)
    channels = [(expanded == k).long() for k in range(classes)]
    return torch.cat(channels, dim=1)

Tags: to方法宽度高度defbatchres解决方案
2条回答

您可以使用torch.nn.functional.one_hot

对于您的情况:

# Snippet: `tnsr` and `classes` are assumed to be the question's inputs
# (an int64 label tensor of shape (B, H, W) and the class count); they are
# not defined here.
a = torch.nn.functional.one_hot(tnsr, num_classes=classes)
# one_hot appends the class axis last, (B, H, W, C); move it to dim 1 -> (B, C, H, W).
out = a.permute(0, 3, 1, 2)

您也可以使用 `Tensor.scatter_`，它避免了 `.permute`，但可以说比 @Alpha 提出的简单方法更难理解：

def batch_tensor_to_onehot(tnsr, classes):
    """One-hot encode integer labels by scattering ones into a zero tensor.

    Args:
        tnsr: integer label tensor of shape (B, *spatial), with values in
            ``[0, classes)``.
        classes: number of classes ``C``.

    Returns:
        int64 tensor of shape (B, C, *spatial) on the same device as ``tnsr``.
    """
    result = torch.zeros(tnsr.shape[0], classes, *tnsr.shape[1:], dtype=torch.long, device=tnsr.device)
    # scatter_ requires an int64 index. .long() lets int32/uint8 label maps work
    # and is a no-op (no copy) when tnsr is already int64.
    result.scatter_(1, tnsr.long().unsqueeze(1), 1)
    return result

基准测试结果

我很好奇，于是对这三种方法进行了基准测试。我发现，各方法之间的相对差异与批次大小、宽度或高度关系不大，主要的区别因素是类别数量。当然，与任何基准测试一样，实际结果可能因环境而异。

基准测试使用随机生成的类别索引，批次大小、高度、宽度均为 100。每个实验重复 20 次，报告平均值。在正式测量之前，先以 num_classes=100 运行一次作为预热。

CPU 结果表明，当 num_classes 小于 30 时，原始方法可能是最快的；而在 GPU 上，scatter_ 方法似乎是最快的。

测试环境：Ubuntu 18.04、NVIDIA 2060 Super、i7-9700K。

enter image description here

enter image description here

用于基准测试的代码如下所示:

import torch
from tqdm import tqdm
import time
import matplotlib.pyplot as plt


def batch_tensor_to_onehot_slavka(tnsr, classes):
    """Question's approach: build one equality mask per class and concatenate.

    ``tnsr`` gains a channel axis at dim 1. Each class label produces an int64
    mask (1 where the label matches), and the masks are stacked along dim 1.
    """
    with_channel = tnsr.unsqueeze(dim=1)
    masks = tuple(with_channel.eq(label).long() for label in range(classes))
    return torch.cat(masks, 1)


def batch_tensor_to_onehot_alpha(tnsr, classes):
    """One-hot encode labels with ``torch.nn.functional.one_hot``, channels-first.

    Args:
        tnsr: integer label tensor of shape (B, *spatial), with values in
            ``[0, classes)``.
        classes: number of classes ``C``.

    Returns:
        int64 tensor of shape (B, C, *spatial). It is a view of one_hot's
        (B, *spatial, C) output, the same as the original permute.
    """
    # one_hot only accepts int64 indices; .long() is a no-op for int64 input.
    result = torch.nn.functional.one_hot(tnsr.long(), num_classes=classes)
    # one_hot puts the class axis last. movedim(-1, 1) equals permute(0, 3, 1, 2)
    # for (B, H, W) input and also works for any number of spatial dims.
    return result.movedim(-1, 1)


def batch_tensor_to_onehot_jodag(tnsr, classes):
    """One-hot encode labels by scattering ones into a zero (B, C, ...) tensor.

    Args:
        tnsr: integer label tensor of shape (B, *spatial), with values in
            ``[0, classes)``.
        classes: number of classes ``C``.

    Returns:
        int64 tensor of shape (B, C, *spatial) on ``tnsr``'s device.
    """
    result = torch.zeros(tnsr.shape[0], classes, *tnsr.shape[1:], dtype=torch.long, device=tnsr.device)
    # scatter_ requires an int64 index. .long() accepts int32/uint8 labels and
    # returns tnsr itself (no copy) when it is already int64.
    result.scatter_(1, tnsr.long().unsqueeze(1), 1)
    return result


def _sync(device):
    """Wait for queued work on ``device`` to finish; a no-op for CPU devices."""
    # torch.cuda.synchronize() raises on CUDA-less builds/machines, so only
    # call it when actually benchmarking a CUDA device.
    if torch.device(device).type == 'cuda':
        torch.cuda.synchronize(device)


def _mean_runtime(fn, classes, batch_sizes, height, width, device):
    """Return the mean wall-clock seconds of ``fn(tnsr, classes)``.

    One random label tensor of shape ``(b, height, width)`` is generated on
    ``device`` for each ``b`` in ``batch_sizes``; only the call to ``fn`` is
    timed.
    """
    total = 0.0
    for b in batch_sizes:
        tnsr = torch.randint(classes, (b, height, width)).to(device=device)
        # Let the host-to-device copy finish before starting the clock so it
        # is not charged to ``fn``.
        _sync(device)
        t0 = time.perf_counter()
        fn(tnsr, classes)
        # GPU kernels run asynchronously; wait for them before stopping the clock.
        _sync(device)
        total += time.perf_counter() - t0
    return total / len(batch_sizes)


def main():
    """Benchmark the three one-hot implementations and plot mean run time.

    For each available device ('cpu', plus 'cuda' when present), every
    implementation is timed for each value in ``num_classes``. The result is
    averaged over ``len(bs)`` runs, plotted against ``num_classes``, and saved
    as ``<device>.png``. The first pass (``num_classes[-1]``) is a warm-up and
    its timings are discarded.
    """
    num_classes = [2, 10, 25, 50, 100]
    height = 100
    width = 100
    bs = [100] * 20

    # (plot label, implementation) in the order they are timed and plotted.
    methods = [
        ('Slavka-cat', batch_tensor_to_onehot_slavka),
        ('Alpha-one_hot', batch_tensor_to_onehot_alpha),
        ('jodag-scatter_', batch_tensor_to_onehot_jodag),
    ]

    # Skip CUDA instead of crashing when no GPU is available.
    devices = ['cpu'] + (['cuda'] if torch.cuda.is_available() else [])

    for d in devices:
        times = {label: [] for label, _ in methods}
        warmup = True
        for c in tqdm([num_classes[-1]] + num_classes, ncols=0):
            for label, fn in methods:
                mean_time = _mean_runtime(fn, c, bs, height, width, d)
                if not warmup:
                    times[label].append(mean_time)
            warmup = False

        fig = plt.figure()
        ax = fig.subplots()
        for label, _ in methods:
            ax.plot(num_classes, times[label], label=label)
        ax.set_xlabel('num_classes')
        ax.set_ylabel('time (s)')
        ax.set_title(f'{d} benchmark')
        ax.legend()
        plt.savefig(f'{d}.png')
        plt.show()


# Run the benchmark only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()

相关问题 更多 >