Theano 张量高阶索引,共享索引

2024-06-02 07:22:37 发布

您现在位置:Python中文网/ 问答频道 /正文

我有一个张量probsprobs.shape = (max_time, num_batches, num_labels)。在

我有一个张量targetstargets.shape = (max_seq_len, num_batches),其中的值是标签索引,也就是说,probs中的第三维度。在

现在我想得到一个张量probs_y,其中第三维是targets中的索引。基本上

probs_y[:,i,:] = probs[:,i,targets[:,i]]

所有人0 <= i < num_batches。在

我怎样才能做到这一点?在

解决方案中的一个类似问题已发布here。在

如果我理解正确的话,解决方案是:

^{pr2}$

但这似乎行不通。我得到: IndexError: only integers, slices (`:`), ellipsis (`...`), numpy.newaxis (`None`) and integer or boolean arrays are valid indices。在

另外,创建临时的T.arange不是有点贵吗?特别是当我试图解决这个问题时,把它变成一个完全密集的整数数组。应该有更好的办法。在

也许theano.map?但据我所知,这并不能并行化代码,所以这也不是一个解决方案。在


Tags: labelslenheretimebatches标签解决方案num
1条回答
网友
1楼 · 发布于 2024-06-02 07:22:37

这对我有用:

import theano
import theano.tensor as T

max_time, num_batches, num_labels = 3, 4, 6
max_seq_len = 5

probs_ = np.arange(max_time * num_batches * num_labels).reshape(
    max_time, num_batches, num_labels)

targets_ = np.arange(num_batches * max_seq_len).reshape(max_seq_len, 
    num_batches) % (num_batches - 1)  # mix stuff up

probs, targets = map(theano.shared, (probs_, targets_))

print probs_
print targets_

probs_y = probs[:, T.arange(targets.shape[1])[:, np.newaxis], targets.T]

print probs_y.eval()

上面使用了你的索引的转置版本。你的确切主张也有效

^{pr2}$

也许你的问题在别的地方。在

至于速度,我不知道还有什么比这更快。map,它是{}的一个特化,几乎肯定不是。我不知道在多大程度上,arange实际上是构建而不是简单地迭代。在

相关问题 更多 >