擅长:python、mysql、java
<p>您可以在这里使用<a href="https://numpy.org/doc/stable/reference/arrays.indexing.html" rel="nofollow noreferrer">fancy indexing</a>来选择所需的张量部分</p>
<p>本质上,如果预先生成传递访问模式的索引数组,则可以直接使用它们提取张量的某些片段。每个维度的索引数组的形状应与要提取的输出张量或切片的形状相同</p>
<pre class="lang-py prettyprint-override"><code>i = torch.arange(bs).reshape(bs, 1, 1) # shape = [bs, 1, 1]
j = labels.reshape(bs, sample, 1) # shape = [bs, sample, 1]
k = torch.arange(v) # shape = [v, ]
# Get result as
t1 = t0[i, j, k]
</code></pre>
<p>注意上面3个张量的形状<a href="https://numpy.org/doc/stable/user/basics.broadcasting.html" rel="nofollow noreferrer">Broadcasting</a>在张量的前面附加额外的维度,从而从本质上将<code>k</code>重塑为<code>[1, 1, v]</code>形状,这使得所有3个维度都兼容元素操作</p>
<p>广播后<code>(i, j, k)</code>一起将产生3<code>[bs, sample, v]</code>形状的数组,这些数组将(按元素)索引原始张量以产生形状<code>[bs, sample, v]</code>的输出张量<code>t1</code></p>