<pre class="lang-py prettyprint-override"><code>>>> import torch
>>> import numpy as np
>>> s = np.arange(12).reshape(4, 3)
>>> s = torch.tensor(s)
>>> s
tensor([[ 0,  1,  2],
        [ 3,  4,  5],
        [ 6,  7,  8],
        [ 9, 10, 11]])
>>> idx = torch.tensor([0, 2, 1, 2])
>>> torch.gather(s, -1, idx.unsqueeze(-1))
tensor([[ 0],
        [ 5],
        [ 7],
        [11]])
</code></pre>
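<p><code>torch.gather(s, -1, idx.unsqueeze(-1))</code> picks, for each row <code>i</code>, the element <code>s[i, idx[i]]</code>. <code>torch.gather</code> requires the index tensor to have the same number of dimensions as the input, which is why <code>idx</code> (shape <code>(4,)</code>) is unsqueezed to shape <code>(4, 1)</code>. The result therefore also has shape <code>(4, 1)</code>.</p>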
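<p>As a sanity check, the sketch below compares <code>gather</code> with two other ways of picking one element per row: advanced indexing with a row-number tensor, and an explicit Python loop. It builds <code>s</code> with <code>torch.arange</code> directly instead of going through NumPy, and the names <code>rows</code>, <code>picked_flat</code> and <code>manual</code> are illustrative only. If it behaves as expected, all three give the same values (after squeezing the extra dimension from the <code>gather</code> result).</p>

```python
import torch

# Same data as in the answer: a 4x3 tensor and one column index per row.
s = torch.arange(12).reshape(4, 3)
idx = torch.tensor([0, 2, 1, 2])

# gather along the last dim: picked[i][0] == s[i][idx[i]]; shape (4, 1).
picked = torch.gather(s, -1, idx.unsqueeze(-1))

# Equivalent advanced indexing: pair each row number with its column index.
rows = torch.arange(s.size(0))
picked_flat = s[rows, idx]  # shape (4,)

# Explicit loop spelling out what gather computes, for comparison.
manual = torch.stack([s[i, idx[i]] for i in range(s.size(0))])

print(picked.squeeze(-1).tolist())                    # [0, 5, 7, 11]
print(torch.equal(picked.squeeze(-1), picked_flat))   # True
print(torch.equal(picked_flat, manual))               # True
```

<p>Advanced indexing returns a flat <code>(4,)</code> tensor directly. <code>gather</code> keeps the index's shape, which is convenient when the result has to line up with the input's dimensions, for example when gathering several columns per row.</p>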