擅长:python、mysql、java
<p>如果你的模型是“正确的”,它只是预测一只狗,你可以得到带有<code>torch.argmax(output, dim=1)</code>的标签,不管<code>batch</code>的大小</p>
<p>无论如何,你不应该使用<code>LogSoftmax</code>作为激活,请使用<a href="https://pytorch.org/docs/stable/nn.html#bcewithlogitsloss" rel="nofollow noreferrer">^{<cd4>}</a>作为你的损失函数,<strong>从你的最后一层移除激活</strong>,<strong>只输出一个神经元</strong>(图像仅为狗的概率)。在您的案例中,它看起来是这样的:</p>
<pre><code>classifier = nn.Sequential(
OrderedDict(
[
("fc1", nn.Linear(1024, 500)),
("relu", nn.ReLU()),
("fc2", nn.Linear(500, 1)),
# See? No activation needed
]
)
)
</code></pre>
<p>只需运行<code>output > 0</code>即可使用上述网络选择正确的标签,并“免费”获得数值稳定性</p>