numpy.random.multinomial 输出不正常?

2 投票
1 回答
915 浏览
提问于 2025-04-18 04:01

我有一个函数:

import numpy as np 
def unhot(vec):
    """ takes a one-hot vector and returns the corresponding integer """
    assert np.sum(vec) == 1    # this assertion shouldn't fail, but it did...
    return list(vec).index(1)

我在调用以下内容的输出时使用这个函数:

numpy.random.multinomial(1, coe)

但是在运行的时候,我遇到了一个断言错误。这是怎么回事呢?难道numpy.random.multinomial的输出不一定是一个独热编码向量吗?

然后我去掉了那个断言错误,现在我有:

ValueError: 1 is not in list

我是不是漏掉了什么细节,还是说这个功能就是有问题?

1 个回答

1

好吧,这就是问题所在,我应该早就意识到,因为我之前遇到过这个情况:

np.random.multinomial(1,A([  0.,   0.,  np.nan,   0.]))

返回

array([0,                    0, -9223372036854775807,0])

我使用了一个不稳定的softmax实现,导致出现了Nans(不是一个数字)。现在,我试图确保我传给多项式的参数总和小于等于1,但我这样做:

coe = softmax(coeffs)
while np.sum(coe) > 1-1e-9:
    coe /= (1+1e-5)

而且里面有NaNs的话,while语句可能根本不会被触发,我觉得。

撰写回答