我有一个张量probs
和probs.shape = (max_time, num_batches, num_labels)
。在
我有一个张量targets
和targets.shape = (max_seq_len, num_batches)
,其中的值是标签索引,也就是说,probs
中的第三维度。在
现在我想得到一个张量probs_y
,其中第三维是targets
中的索引。基本上
probs_y[:,i,:] = probs[:,i,targets[:,i]]
所有人0 <= i < num_batches
。在
我怎样才能做到这一点?在
解决方案中的一个类似问题已发布here。在
如果我理解正确的话,解决方案是:
^{pr2}$但这似乎行不通。我得到:
IndexError: only integers, slices (`:`), ellipsis (`...`), numpy.newaxis (`None`) and integer or boolean arrays are valid indices
。在
另外,创建临时的T.arange
不是有点贵吗?特别是当我试图解决这个问题时,把它变成一个完全密集的整数数组。应该有更好的办法。在
也许theano.map
?但据我所知,这并不能并行化代码,所以这也不是一个解决方案。在
这对我有用:
上面使用了你的索引的转置版本。你的确切主张也有效
^{pr2}$也许你的问题在别的地方。在
至于速度,我不知道还有什么比这更快。}的一个特化,几乎肯定不是。我不知道在多大程度上,
map
,它是{arange
实际上是构建而不是简单地迭代。在相关问题 更多 >
编程相关推荐