Pythorch LSTM源代码问题

2024-04-24 14:31:43 发布

您现在位置:Python中文网/ 问答频道 /正文

我正在使用带有batch_first=True的双向LSTM。但是,它给了我一个关于尺寸的错误。 **Error: Expected hidden[0] size (6, 5, 40), got (5, 6, 40)** 当我检查源代码时,错误是由下面的函数引起的

if is_input_packed:
            mini_batch = int(batch_sizes[0])
        else:
            mini_batch = input.size(0) if self.batch_first else input.size(1)

        num_directions = 2 if self.bidirectional else 1
        expected_hidden_size = (self.num_layers * num_directions,
                                mini_batch, self.hidden_size)

        def check_hidden_size(hx, expected_hidden_size, msg='Expected hidden size {}, got {}'):
            if tuple(hx.size()) != expected_hidden_size:
                raise RuntimeError(msg.format(expected_hidden_size, tuple(hx.size())))

默认情况下,expected_hidden_size是根据序列先写入的。我相信这是问题的根源。有人能告诉我是不是对的,这个问题需要解决吗?在


Tags: selfinputsizeif错误batchelsenum