<p>(嗯。。。这里有很多困惑,无论是问题还是答案……)</p>
<p>首先,这两个解(即您的解和建议的解)是<strong>而不是</strong>等价的;它们<strong>发生</strong>仅在一维分数数组的特殊情况下是等价的。如果您也尝试过Udacity测试提供的示例中的二维分数数组,您就会发现它。</p>
<p>从结果上看,这两种解决方案之间唯一的实际区别是<code>axis=0</code>参数。为了证明这是真的,让我们试试您的解决方案(<code>your_softmax</code>),其中唯一的区别是<code>axis</code>参数:</p>
<pre><code>import numpy as np
# your solution:
def your_softmax(x):
"""Compute softmax values for each sets of scores in x."""
e_x = np.exp(x - np.max(x))
return e_x / e_x.sum()
# correct solution:
def softmax(x):
"""Compute softmax values for each sets of scores in x."""
e_x = np.exp(x - np.max(x))
return e_x / e_x.sum(axis=0) # only difference
</code></pre>
<p>正如我所说,对于一维分数数组,结果确实是相同的:</p>
<pre><code>scores = [3.0, 1.0, 0.2]
print(your_softmax(scores))
# [ 0.8360188 0.11314284 0.05083836]
print(softmax(scores))
# [ 0.8360188 0.11314284 0.05083836]
your_softmax(scores) == softmax(scores)
# array([ True, True, True], dtype=bool)
</code></pre>
<p>然而,以下是作为测试示例的Udacity测验中给出的二维分数数组的结果:</p>
<pre><code>scores2D = np.array([[1, 2, 3, 6],
[2, 4, 5, 6],
[3, 8, 7, 6]])
print(your_softmax(scores2D))
# [[ 4.89907947e-04 1.33170787e-03 3.61995731e-03 7.27087861e-02]
# [ 1.33170787e-03 9.84006416e-03 2.67480676e-02 7.27087861e-02]
# [ 3.61995731e-03 5.37249300e-01 1.97642972e-01 7.27087861e-02]]
print(softmax(scores2D))
# [[ 0.09003057 0.00242826 0.01587624 0.33333333]
# [ 0.24472847 0.01794253 0.11731043 0.33333333]
# [ 0.66524096 0.97962921 0.86681333 0.33333333]]
</code></pre>
<p>结果是不同的-第二个结果确实与Udacity测验中预期的结果相同,其中所有列的总和确实为1,而第一个(错误的)结果则不是这样。</p>
<p>所以,所有的麻烦实际上都是为了实现细节-参数<code>axis</code>。根据<a href="http://docs.scipy.org/doc/numpy/reference/generated/numpy.sum.html" rel="noreferrer">numpy.sum documentation</a>:</p>
<blockquote>
<p>The default, axis=None, will sum all of the elements of the input array</p>
</blockquote>
<p>在这里我们要按行求和,因此<code>axis=0</code>。对于一维数组,(仅)行的和和和所有元素的和碰巧是相同的,因此在这种情况下得到相同的结果。。。</p>
<p>撇开<code>axis</code>问题不谈,您的实现(即您选择先减去最大值)实际上比建议的解决方案要好!事实上,这是实现softmax函数的推荐方法-请参见<a href="http://cs231n.github.io/linear-classify/#softmax" rel="noreferrer">here</a>以获得理由(数字稳定性,也由上面的一些答案指出)。</p>