我刚刚开始使用Pytork,我正在尝试一种简单的多层感知器。我的ReLU激活功能如下:
def ReLU_activation_func(outputs):
print(type(outputs))
result = torch.where(outputs > 0, outputs, 0.)
result = float(result)
return result
因此,我试图保持大于0的值,如果该值小于0,则将该值更改为0。 这是我使用ReLU函数的主代码的一部分(我有错误):
def forward_pass(train_loader):
for batch_idx, (image, label) in enumerate(train_loader):
print(image.size())
x = image.view(-1, 28 * 28)
print(x.size())
input_node_num = 28 * 28
hidden_node_num = 100
output_node_num = 10
W_ih = torch.rand(input_node_num, hidden_node_num)
W_ho = torch.rand(hidden_node_num, output_node_num)
final_output_n = ReLU_activation_func(torch.matmul(x, W_ih))
当我运行代码时,我得到以下错误:
RuntimeError:
1 forward_pass(train_loader)
in forward_pass(train_loader)
-----14 W_ih = torch.rand(input_node_num, hidden_node_num)
-----15 W_ho = torch.rand(hidden_node_num, output_node_num)
---->16 final_output_n = ReLU_activation_func(torch.matmul(x, W_ih))
in ReLU_activation_func(outputs)
-----10 print(type(outputs))
---->11 result = torch.where(outputs > 0, outputs, 0.)
-----12 result = float(result)
-----13 return result
RuntimeError: expected scalar type float but found double
有什么帮助吗
问题不在
result
,而是在X
、W_ih
或torch.where(outputs > 0, outputs, 0.)
上如果不为^{} 的
dtype
设置参数,它将根据pytorch的全局默认值分配数据类型可以使用^{} 更改全局变量
或者走简单的路线:
相关问题 更多 >
编程相关推荐