Is PyTorch RNN more efficient with `batch_first=False`?

Posted 2022-05-21 06:56:30


In machine translation, we always need to slice off the first time step (the SOS token) from both the annotations and the predictions.

With batch_first=False, slicing off the first time step keeps the tensor contiguous:

import torch
batch_size = 128
seq_len = 12
embedding = 50

# Making a dummy output that is `batch_first=False`
batch_not_first = torch.randn((seq_len,batch_size,embedding))
batch_not_first = batch_not_first[1:].view(-1, embedding) # slicing out the first time step

However, with batch_first=True, the tensor is no longer contiguous after slicing. We have to make it contiguous before we can perform operations such as view:

batch_first = torch.randn((batch_size,seq_len,embedding))
batch_first[:,1:].view(-1, embedding) # slicing out the first time step

output>>>
"""
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-8-a9bd590a1679> in <module>
----> 1 batch_first[:,1:].view(-1, embedding) # slicing out the first time step

RuntimeError: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.
"""

Does this mean batch_first=False is better, at least in the machine-translation setting, since it spares us the contiguous() step? Are there cases where batch_first=True works better?


1 Answer

Performance

There doesn't seem to be much difference between batch_first=True and batch_first=False. See the script below:

import time

import torch


def time_measure(batch_first: bool):
    torch.cuda.synchronize()
    layer = torch.nn.RNN(10, 20, batch_first=batch_first).cuda()
    if batch_first:
        inputs = torch.randn(100000, 7, 10).cuda()
    else:
        inputs = torch.randn(7, 100000, 10).cuda()

    start = time.perf_counter()

    for chunk in torch.chunk(inputs, 100000 // 64, dim=0 if batch_first else 1):
        _, last = layer(chunk)

    torch.cuda.synchronize()  # wait for queued CUDA kernels before stopping the clock
    return time.perf_counter() - start


print(f"Time taken for batch_first=False: {time_measure(False)}")
print(f"Time taken for batch_first=True: {time_measure(True)}")

On my device (GTX 1050 Ti), with PyTorch 1.6.0 and CUDA 11.0, the results were:

Time taken for batch_first=False: 0.3275816479999776
Time taken for batch_first=True: 0.3159054920001836

(and they vary between runs, so nothing conclusive can be drawn from them)

Code readability

batch_first=True is simpler when you want to use other PyTorch layers that require batch as the 0th dimension (which is the case for almost all torch.nn layers).

In that case, if batch_first=False was specified, you would have to permute the returned tensor.
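That permute can be sketched as follows (the shapes here are illustrative, not from the question):

```python
import torch

rnn = torch.nn.RNN(10, 20, batch_first=False)
inputs = torch.randn(7, 32, 10)        # (seq_len, batch, features)
output, _ = rnn(inputs)                # (seq_len, batch, hidden) = (7, 32, 20)

# Layers expecting batch in dimension 0 need the axes swapped first:
batch_major = output.permute(1, 0, 2)  # (batch, seq_len, hidden) = (32, 7, 20)
print(batch_major.shape)  # torch.Size([32, 7, 20])
```

Note that permute, like the batch-first slice above, returns a non-contiguous view, so the same contiguous()/reshape caveat applies downstream.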

Machine translation

It should be better, since the tensor is contiguous the whole time and no data copy is required. Slicing with [1:] also looks cleaner than [:, 1:].
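The contiguity claim is easy to check directly with is_contiguous(), reusing the shapes from the question:

```python
import torch

seq_len, batch_size, embedding = 12, 128, 50

seq_first = torch.randn(seq_len, batch_size, embedding)    # batch_first=False layout
batch_major = torch.randn(batch_size, seq_len, embedding)  # batch_first=True layout

# Dropping the first time step slices the outermost dimension
# of the seq-first tensor, so it stays contiguous...
print(seq_first[1:].is_contiguous())       # True
# ...but slices an inner dimension of the batch-first tensor,
# which leaves gaps in memory:
print(batch_major[:, 1:].is_contiguous())  # False
```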