Runtime error with LongTensor in a PyTorch LSTM

Posted 2024-06-16 08:32:47


I am trying to build a seq2seq autoencoder. I want to use pretrained embeddings, so I do not include an embedding layer inside the autoencoder architecture itself. The LSTM is giving me a strange error:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-91-fd376d45585e> in <module>
     18 for ei in range(input_length):
     19     encoder_output, encoder_hidden, encoder_cell = encoder(
---> 20         input_tensor[ei], encoder_hidden, encoder_cell)
     21     encoder_outputs[ei] = encoder_output[0, 0]
     22

~/miniconda3/envs/simple_code/lib/python3.7/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
    545             result = self._slow_forward(*input, **kwargs)
    546         else:
--> 547             result = self.forward(*input, **kwargs)
    548         for hook in self._forward_hooks.values():
    549             hook_result = hook(self, input, result)

<ipython-input-89-012b4c3b2071> in forward(self, _input, hidden, cell)
      8         _input = _input.view(1,1,-1)
      9         print(type(_input), type(hidden), type(cell))
---> 10         output, (hidden, cell) = self.lstm(_input.view(1,1,-1).long(), (hidden.long(), cell.long()))
     11         return output, hidden, cell
     12

~/miniconda3/envs/simple_code/lib/python3.7/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
    545             result = self._slow_forward(*input, **kwargs)
    546         else:
--> 547             result = self.forward(*input, **kwargs)
    548         for hook in self._forward_hooks.values():
    549             hook_result = hook(self, input, result)

~/miniconda3/envs/simple_code/lib/python3.7/site-packages/torch/nn/modules/rnn.py in forward(self, input, hx)
    562             return self.forward_packed(input, hx)
    563         else:
--> 564             return self.forward_tensor(input, hx)
    565
    566 class GRU(RNNBase):

~/miniconda3/envs/simple_code/lib/python3.7/site-packages/torch/nn/modules/rnn.py in forward_tensor(self, input, hx)
    541             unsorted_indices = None
    542
--> 543         output, hidden = self.forward_impl(input, hx, batch_sizes, max_batch_size, sorted_indices)
    544
    545         return output, self.permute_hidden(hidden, unsorted_indices)

~/miniconda3/envs/simple_code/lib/python3.7/site-packages/torch/nn/modules/rnn.py in forward_impl(self, input, hx, batch_sizes, max_batch_size, sorted_indices)
    524         if batch_sizes is None:
    525             result = _VF.lstm(input, hx, self._get_flat_weights(), self.bias, self.num_layers,
--> 526                               self.dropout, self.training, self.bidirectional, self.batch_first)
    527         else:
    528             result = _VF.lstm(input, batch_sizes, hx, self._get_flat_weights(), self.bias,

RuntimeError: Expected object of scalar type Long but got scalar type Float for argument #2 'mat2'

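For context, the encoder involved looks roughly like the sketch below. Only the forward body is taken from the traceback; the class name EncoderRNN, the constructor, and the size arguments are assumptions added so the snippet is self-contained:

import torch
import torch.nn as nn

class EncoderRNN(nn.Module):                      # class name is an assumption
    def __init__(self, input_size, hidden_size):  # sizes are assumptions
        super().__init__()
        # No embedding layer here: the pretrained embeddings are applied
        # outside the autoencoder, so the LSTM receives embedded vectors.
        self.lstm = nn.LSTM(input_size, hidden_size)

    def forward(self, _input, hidden, cell):
        _input = _input.view(1, 1, -1)
        print(type(_input), type(hidden), type(cell))
        # The failing line (line 10 in the traceback): every tensor is cast
        # to Long right before the LSTM call.
        output, (hidden, cell) = self.lstm(_input.view(1, 1, -1).long(),
                                           (hidden.long(), cell.long()))
        return output, hidden, cell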
Here is the line of code on which the error is raised:

output, (hidden, cell) = self.lstm(_input.view(1,1,-1).long(), (hidden.long(), cell.long()))

I don't understand why the error occurs, since I force-cast all of the tensors to Long.
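For reference, here is a minimal, self-contained snippet (the sizes are arbitrary assumptions) that reproduces the same kind of dtype error with a plain nn.LSTM once the input and states are cast to Long, while the LSTM's own weights stay float32:

import torch
import torch.nn as nn

lstm = nn.LSTM(input_size=300, hidden_size=128)   # sizes are arbitrary assumptions

x = torch.randn(1, 1, 300)    # (seq_len, batch, input_size), float32
h = torch.zeros(1, 1, 128)    # (num_layers, batch, hidden_size), float32
c = torch.zeros(1, 1, 128)

out, (h_n, c_n) = lstm(x, (h, c))   # works: inputs match the float32 weights

try:
    # Casting everything to Long makes the LSTM's internal matrix
    # multiplications mix integer tensors with float32 weights.
    lstm(x.long(), (h.long(), c.long()))
except RuntimeError as e:
    print(e)   # dtype-mismatch error; exact wording depends on the PyTorch version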

