PyTorch: manually setting GRU/LSTM weight parameters from numpy arrays


I am trying to fill a GRU/LSTM in PyTorch with manually defined parameters.

I have numpy arrays for the parameters, with the shapes defined in the documentation (https://pytorch.org/docs/stable/nn.html#torch.nn.GRU).
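
For reference, this is how I inspected the parameter names and shapes that nn.GRU actually registers (a small sketch with toy sizes, not my real configuration):

import torch.nn as nn

# Toy sizes, only to print the expected parameter layout
probe = nn.GRU(input_size=4, hidden_size=3, num_layers=2, bidirectional=True)
for name, p in probe.named_parameters():
    print(name, tuple(p.shape))
# weight_ih_l0 -> (3*hidden_size, input_size)
# weight_hh_l0 -> (3*hidden_size, hidden_size)
# for layers > 0, weight_ih_l{k} has shape (3*hidden_size, num_directions*hidden_size)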

It seems to work, but I am not sure whether the values it returns are correct.

Is this the correct way to fill a GRU/LSTM with numpy parameters?

import numpy as np
import torch
import torch.nn as nn

# input_size, hidden_size, num_layers, num_directions, dropout, bidirectional,
# x_t, h_t and the numpy arrays w0, w (defined below) are assumed to be set already.
gru = nn.GRU(input_size, hidden_size, num_layers,
             bias=True, batch_first=False, dropout=dropout, bidirectional=bidirectional)

def set_nn_wih(layer, parameter_name, w, l0=True):
    # Copy a flattened input-to-hidden weight vector into the parameter, row by row.
    param = getattr(layer, parameter_name)
    if l0:
        # First layer: each row has input_size entries.
        for i in range(3 * hidden_size):
            param.data[i] = w[i * input_size:(i + 1) * input_size]
    else:
        # Deeper layers: each row has num_directions*hidden_size entries.
        for i in range(3 * hidden_size):
            param.data[i] = w[i * num_directions * hidden_size:(i + 1) * num_directions * hidden_size]

def set_nn_whh(layer, parameter_name, w):
    # Copy a flattened hidden-to-hidden weight vector into the parameter, row by row.
    param = getattr(layer, parameter_name)
    for i in range(3 * hidden_size):
        param.data[i] = w[i * hidden_size:(i + 1) * hidden_size]

l0 = True

for i in range(num_directions):
    for j in range(num_layers):
        if j == 0:
            # First layer: slice the input-to-hidden and hidden-to-hidden parts out of w0.
            wih = w0[i, :, :3 * input_size]
            whh = w0[i, :, 3 * input_size:]  # check
            l0 = True
        else:
            # Deeper layers: slice the corresponding parts out of w.
            wih = w[j - 1, i, :, :num_directions * 3 * hidden_size]
            whh = w[j - 1, i, :, num_directions * 3 * hidden_size:]
            l0 = False

        if i == 0:
            set_nn_wih(
                gru, "weight_ih_l{}".format(j), torch.from_numpy(wih.flatten()), l0)
            set_nn_whh(
                gru, "weight_hh_l{}".format(j), torch.from_numpy(whh.flatten()))
        else:
            # Reverse direction of the bidirectional GRU.
            set_nn_wih(
                gru, "weight_ih_l{}_reverse".format(j), torch.from_numpy(wih.flatten()), l0)
            set_nn_whh(
                gru, "weight_hh_l{}_reverse".format(j), torch.from_numpy(whh.flatten()))

y, hn = gru(x_t, h_t)
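
(x_t and h_t are not shown here; they follow the documented input shapes for batch_first=False. A hypothetical setup, just to make the shapes explicit:)

seq_len, batch = 5, 2  # hypothetical sizes
x_t = torch.randn(seq_len, batch, input_size)                       # (seq_len, batch, input_size)
h_t = torch.zeros(num_layers * num_directions, batch, hidden_size)  # initial hidden state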

The numpy arrays are defined as follows:

rng = np.random.RandomState(313)
w0 = rng.randn(num_directions, hidden_size, 3*(input_size +
               hidden_size)).astype(np.float32)
w = rng.randn(max(1, num_layers-1), num_directions, hidden_size,
              3*(num_directions*hidden_size + hidden_size)).astype(np.float32)
