在chainer中,如何准确地使用trainer编写BPTT更新程序?

2024-05-16 12:28:43 发布

您现在位置:Python中文网/ 问答频道 /正文

链接器文档RNN教程在此页中包含错误代码: https://docs.chainer.org/en/stable/tutorial/recurrentnet.html

def update_bptt(updater):
    loss = 0
    for i in range(35):
        batch = train_iter.__next__()
        x, t = chainer.dataset.concat_examples(batch)
        loss += model(chainer.Variable(x), chainer.Variable(t))

    model.cleargrads()
    loss.backward()
    loss.unchain_backward()  # truncate
    optimizer.update()

updater = training.StandardUpdater(train_iter, optimizer, **update_bptt**)

那个培训.StandardUpdater第三个参数是converter=concat\u示例,而不是update函数。 如何准确地使用trainer编写BPTT?你知道吗


Tags: model链接batchupdatetrainvariableoptimizerchainer