擅长:python、mysql、java
<p>执行for循环会在一定程度上降低性能。假设贡献向量中唯一类型的数量远小于B(或C)的长度,则可以使用假设<code>O(num_types)</code><;<;O(len_B)`,对for循环执行如下操作:</p>
<pre class="lang-py prettyprint-override"><code>
num_types = 3
B_len = 5
C_len = B_len
B = torch.randint(0, num_types, size=[B_len,])
"""
>>> B
tensor([2, 1, 1, 0, 0])
"""
C = torch.randint(0, 10, size=[C_len,])
C = C.float()
"""
>>> C
tensor([1., 5., 7., 6., 2.])
"""
# For loop here
A = [torch.sum(C * (torch.eq(B, type).float()) for type in range(num_types)]
# A = [tensor(8.), tensor(12.), tensor(1.)]
# Convert it to torch.tensor
A = torch.stack(A)
# tensor([ 8., 12., 1.])
</code></pre>