“视图”方法在PyTorch中是如何工作的?

2024-04-18 02:51:54 发布

您现在位置:Python中文网/ 问答频道 /正文

我对下面代码片段中的方法view()感到困惑。

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool  = nn.MaxPool2d(2,2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1   = nn.Linear(16*5*5, 120)
        self.fc2   = nn.Linear(120, 84)
        self.fc3   = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16*5*5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

net = Net()

我的困惑是关于下面这一行。

x = x.view(-1, 16*5*5)

tensor.view()函数做什么?我在很多地方见过它的用法,但我不明白它是如何解释它的参数的。

如果我将负值作为view()函数的参数,会发生什么情况?例如,如果我调用tensor_variable.view(1, 1, -1),会发生什么?

有人能用一些例子解释一下view()函数的主要原理吗?


Tags: 函数selfviewnetinitdefnnlinear
3条回答

我发现x.view(-1, 16 * 5 * 5)相当于x.flatten(1),其中参数1表示展平过程从第一个维度开始(而不是展平“sample”维度) 如您所见,后一种用法在语义上更清晰、更易于使用,因此我更喜欢flatten()

view函数是用来重塑张量的。

假设你有张量

import torch
a = torch.range(1, 16)

a是一个有16个元素从1到16(包括)的张量。如果你想重塑这个张量使其成为一个4 x 4张量,那么你可以使用

a = a.view(4, 4)

现在a将是一个4 x 4张量。请注意,重塑后,元素总数需要保持不变。将张量a重塑为3 x 5张量是不合适的。

参数-1是什么意思?

如果在某些情况下,您不知道需要多少行,但确定列的数量,则可以使用-1指定。(请注意,可以将其扩展到具有更多维度的张量。只有一个轴值可以是-1)。这是告诉库的一种方式:“给我一个有这么多列的张量,然后计算出实现这一点所需的适当行数”。

这可以在上面给出的神经网络代码中看到。在forward函数中的第x = self.pool(F.relu(self.conv2(x)))行之后,您将有一个16深度的特征映射。你必须把它展平,把它交给完全连接的层。所以你告诉pytorch重塑你得到的张量,使其具有特定的列数,并告诉它自己决定行数。

绘制numpy和pytorch之间的相似性,view类似于numpy的reshape函数。

让我们举几个例子,从简单到困难。

  1. view方法返回一个张量,其数据与self张量相同(这意味着返回的张量具有相同数量的元素),但形状不同。例如:

    a = torch.arange(1, 17)  # a's shape is (16,)
    
    a.view(4, 4) # output below
      1   2   3   4
      5   6   7   8
      9  10  11  12
     13  14  15  16
    [torch.FloatTensor of size 4x4]
    
    a.view(2, 2, 4) # output below
    (0 ,.,.) = 
    1   2   3   4
    5   6   7   8
    
    (1 ,.,.) = 
     9  10  11  12
    13  14  15  16
    [torch.FloatTensor of size 2x2x4]
    
  2. 假设-1不是参数之一,当您将它们相乘时,结果必须等于张量中的元素数。如果这样做:a.view(3, 3),它将引发一个RuntimeError,因为形状(3x 3)对于16个元素的输入无效。换句话说:3×3不等于16,而等于9。

  3. 可以使用-1作为传递给函数的参数之一,但只能使用一次。所发生的一切是,该方法将为您计算如何填充该维度。例如,a.view(2, -1, 4)等同于a.view(2, 2, 4)。[16/(2 x 4)=2]

  4. 注意,返回的张量共享相同的数据。如果在“视图”中进行更改,则更改的是原始张量的数据:

    b = a.view(4, 4)
    b[0, 2] = 2
    a[2] == 3.0
    False
    
  5. 现在,对于更复杂的用例。文档说明,每个新视图维度必须是原始维度的子空间,或者只能是spand,d+1,…,d+k满足以下类似于连续性的条件,即对于所有i=0,…,k-1,步长[i]=步长[i+1]x大小[i+1]。否则,需要调用contiguous()才能查看张量。例如:

    a = torch.rand(5, 4, 3, 2) # size (5, 4, 3, 2)
    a_t = a.permute(0, 2, 3, 1) # size (5, 3, 2, 4)
    
    # The commented line below will raise a RuntimeError, because one dimension
    # spans across two contiguous subspaces
    # a_t.view(-1, 4)
    
    # instead do:
    a_t.contiguous().view(-1, 4)
    
    # To see why the first one does not work and the second does,
    # compare a.stride() and a_t.stride()
    a.stride() # (24, 6, 2, 1)
    a_t.stride() # (24, 2, 1, 6)
    

    注意,对于a_t跨步[0]!=步幅[1]x大小[1]24起!=2 x 3

相关问题 更多 >