Passing an embedded sequence to an LSTM raises TypeError: 'int' object is not subscriptable

Posted 2024-04-19 03:33:29


I have some fairly basic PyTorch code here, and I'm trying to test-run an input tensor through what will eventually become my forward function.

Goal: after embedding each word index, treat the sentence as a single input sequence

  1. Embed the tensor
  2. Convert the embedding back to a float32 tensor
  3. Reshape the embedding to (batch_size, seq_len, input_size)
  4. Pass it through the LSTM

I convert back to a float32 tensor right after the embedding, so I don't know why I'm getting this error.

hidden_size=10
embedding = nn.Embedding(VOC.n_words, hidden_size)
lstm = nn.LSTM(hidden_size, hidden_size, # Will output 2x hidden size
               num_layers=1, dropout=0.5,
               bidirectional=True, batch_first=True)

print("Input tensor",idx_sentence)
# Forward test
embedded = embedding(idx_sentence.long())
embedded = torch.tensor(embedded, dtype=torch.float32)
print(f"embedding: {embedded.size()}")

# reshape to (batch_size, seq_len, input_size)
sequence = embedded.view(1,-1,hidden_size)
print(f"sequence shape: {sequence.size()}")

output, hidden = lstm(sequence, hidden_size)
print(f"output shape: {output.size()}")
Input tensor tensor([ 3., 20., 21., 90.,  9.])
embedding: torch.Size([5, 10])
sequence shape: torch.Size([1, 5, 10])
/usr/local/lib/python3.6/dist-packages/ipykernel_launcher.py:10: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
  # Remove the CWD from sys.path while we load stuff.
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-116-ab3d6ed0e51c> in <module>()
     16 
     17 # Input have shape (seq_len, batch, input_size)
---> 18 output, hidden = lstm(sequence, hidden_size)
     19 print(f"output shape: {output.size()}")

2 frames
/usr/local/lib/python3.6/dist-packages/torch/nn/modules/rnn.py in check_forward_args(self, input, hidden, batch_sizes)
    520         expected_hidden_size = self.get_expected_hidden_size(input, batch_sizes)
    521 
--> 522         self.check_hidden_size(hidden[0], expected_hidden_size,
    523                                'Expected hidden[0] size {}, got {}')
    524         self.check_hidden_size(hidden[1], expected_hidden_size,

TypeError: 'int' object is not subscriptable

Tags: self, input, output, size, batch, torch, embedding, hidden
1 Answer
Answered 2024-04-19 03:33:29

An LSTM takes two inputs, as described in the nn.LSTM documentation under Inputs:

  • input: the input sequence
  • (h_0, c_0): a tuple containing the initial hidden state h_0 and the initial cell state c_0

But you are passing hidden_size as the second argument, which is an int, not a tuple. Unpacking that tuple fails because hidden_size[0] doesn't work: an integer can't be indexed.

The second argument is optional; if you don't provide it, the hidden state and cell state default to zeros. That's usually what you want, so you can simply leave it out:

output, hidden = lstm(sequence)
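
If you ever do want to pass initial states explicitly, they must be a tuple of two tensors, each of shape (num_layers * num_directions, batch, hidden_size). A minimal runnable sketch of the fixed pipeline follows; the vocabulary size (standing in for VOC.n_words) and the index tensor are made up for illustration, and dropout is dropped since PyTorch warns that dropout has no effect with num_layers=1:

```python
import torch
import torch.nn as nn

hidden_size = 10
n_words = 100  # hypothetical stand-in for VOC.n_words

embedding = nn.Embedding(n_words, hidden_size)
lstm = nn.LSTM(hidden_size, hidden_size, num_layers=1,
               bidirectional=True, batch_first=True)

idx_sentence = torch.tensor([3, 20, 21, 90, 9])

# Embedding output is already float32, so no torch.tensor(...) copy is needed
embedded = embedding(idx_sentence)
sequence = embedded.view(1, -1, hidden_size)  # (batch, seq_len, input_size)

# Explicit initial states: (num_layers * num_directions, batch, hidden_size)
num_directions = 2  # bidirectional
h_0 = torch.zeros(1 * num_directions, 1, hidden_size)
c_0 = torch.zeros(1 * num_directions, 1, hidden_size)

output, (h_n, c_n) = lstm(sequence, (h_0, c_0))
print(output.shape)  # (1, 5, 20): 2x hidden size because bidirectional=True
```

Passing (h_0, c_0) here is equivalent to omitting the second argument entirely, since both states are zero tensors.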
