PyTorch:错误>>应为标量类型浮点,但找到双精度

2024-04-20 05:00:49 发布

您现在位置:Python中文网/ 问答频道 /正文

我刚刚开始使用Pytork,我正在尝试一种简单的多层感知器。我的ReLU激活功能如下:

def ReLU_activation_func(outputs):
    print(type(outputs))
    result = torch.where(outputs > 0, outputs, 0.)
    result = float(result)
    return result

因此,我试图保持大于0的值,如果该值小于0,则将该值更改为0。 这是我使用ReLU函数的主代码的一部分(我有错误):

def forward_pass(train_loader):
    for batch_idx, (image, label) in enumerate(train_loader):
        print(image.size())
        x = image.view(-1, 28 * 28)
        print(x.size())
    
        input_node_num = 28 * 28
        hidden_node_num = 100
        output_node_num = 10
        W_ih = torch.rand(input_node_num, hidden_node_num)
        W_ho = torch.rand(hidden_node_num, output_node_num)
        final_output_n = ReLU_activation_func(torch.matmul(x, W_ih))

当我运行代码时,我得到以下错误:

RuntimeError:
1 forward_pass(train_loader)

in forward_pass(train_loader)
-----14         W_ih = torch.rand(input_node_num, hidden_node_num)
-----15         W_ho = torch.rand(hidden_node_num, output_node_num)
---->16         final_output_n = ReLU_activation_func(torch.matmul(x, W_ih))

in ReLU_activation_func(outputs)
-----10     print(type(outputs))
---->11     result = torch.where(outputs > 0, outputs, 0.)
-----12     result = float(result)
-----13     return result

RuntimeError: expected scalar type float but found double

有什么帮助吗


Tags: nodeoutputtraintorchloaderresultoutputsactivation
1条回答
网友
1楼 · 发布于 2024-04-20 05:00:49

问题不在result,而是在XW_ihtorch.where(outputs > 0, outputs, 0.)

如果不为^{}dtype设置参数,它将根据pytorch的全局默认值分配数据类型

可以使用^{}更改全局变量

或者走简单的路线:

def ReLU_activation_func(outputs):
    print((outputs).dtype)
    result = torch.where(outputs > 0, outputs, torch.zeros_like(outputs)).float()
    return result

# for the forward pass function, convert the tensor to floats before matmul
def forward_pass(train_loader):
    for batch_idx, (image, label) in enumerate(train_loader):
        ... <your code>
        X, W_ih = X.float(), W_ih.float()
        final_output_n = ReLU_activation_func(torch.matmul(x, W_ih))

相关问题 更多 >