<p>I guessed that <code>torch.where</code> would be faster, so I measured both approaches on the CPU. The results are below:</p>
<pre><code>import torch
a = torch.rand(3**10)
b = torch.rand(3**10)
</code></pre>
<pre><code>%timeit a[b > 0.5] = 0.
852 µs ± 30.2 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
</code></pre>
<pre><code>%timeit temp = torch.where(b > 0.5, torch.tensor(0.), a)
294 µs ± 4.51 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
</code></pre>
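<p>One caveat about the comparison: <code>a[b > 0.5] = 0.</code> modifies <code>a</code> in place, whereas <code>torch.where</code> allocates and returns a new tensor. The sketch below is a plain Python script that uses the standard <code>timeit</code> module instead of the IPython <code>%timeit</code> magic. It first checks that the approaches give the same result and then times them, with every variant returning a fresh tensor. It also adds <code>Tensor.masked_fill</code> as a third option, which was not part of the measurement above. The function names are hypothetical, and no timings are claimed here because they depend on hardware, tensor size and PyTorch version.</p>

```python
import timeit

import torch

# Same setup as above: two random 1-D tensors of length 3**10 on the CPU.
a = torch.rand(3**10)
b = torch.rand(3**10)
ZERO = torch.tensor(0.)  # built once, outside the timing loop


def with_indexing(x, cond):
    # Boolean-mask assignment writes in place, so clone first to leave x untouched.
    y = x.clone()
    y[cond > 0.5] = 0.
    return y


def with_where(x, cond):
    # torch.where(condition, input, other) returns a new tensor:
    # ZERO where the condition holds, x elsewhere.
    return torch.where(cond > 0.5, ZERO, x)


def with_masked_fill(x, cond):
    # Tensor.masked_fill is another out-of-place option (not in the original timings).
    return x.masked_fill(cond > 0.5, 0.)


# All three variants should produce identical tensors.
expected = with_indexing(a, b)
assert torch.equal(with_where(a, b), expected)
assert torch.equal(with_masked_fill(a, b), expected)

for fn in (with_indexing, with_where, with_masked_fill):
    seconds = timeit.timeit(lambda: fn(a, b), number=1000) / 1000
    print(f"{fn.__name__}: {seconds * 1e6:.1f} µs per call")
```

<p>Because <code>with_indexing</code> clones its input, it includes the cost of one extra copy. Its numbers are therefore not directly comparable with the 852 µs in-place measurement above, but they do put all three variants on the same footing.</p>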