擅长:python、mysql、java
<p>根据你所寻找的,有四种可能性。你需要问自己两个问题:</p>
<p><strong>Q1)</strong>您的激活函数是否有可学习的参数?在</p>
<p>如果<strong>是</strong>,则您没有选择将激活函数创建为<code>nn.Module</code>类,因为您需要存储这些权重。在</p>
<p>如果<strong>否</strong>,您可以自由地创建一个普通函数或一个类,这取决于您是否方便。在</p>
<p><strong>Q2)</strong>您的激活函数可以表示为现有Pythorch函数的组合吗?在</p>
<p>如果<strong>是</strong>,您只需将其编写为现有PyTorch函数的组合,而不需要创建定义渐变的<code>backward</code>函数。在</p>
<p>如果<strong>否</strong>则需要手动编写渐变。在</p>
<p><strong>示例1:Swish函数</strong></p>
<p>swish函数<code>f(x) = x * sigmoid(x)</code>没有任何学习的权重,可以完全用现有的PyTorch函数编写,因此您可以简单地将其定义为函数:</p>
<pre><code>def swish(x):
return x * torch.sigmoid(x)
</code></pre>
<p>然后简单地使用它,就像你有<code>torch.relu</code>或任何其他激活函数一样。在</p>
<p><strong>示例2:使用学习的坡度进行快速滑行</strong></p>
<p>在本例中,您有一个学习的参数,斜率,因此您需要将其生成一个类。在</p>
^{pr2}$
<p><strong>例3:带后退</strong></p>
<p>如果您有一些需要创建自己的渐变函数的东西,可以看看这个例子:<a href="https://stackoverflow.com/questions/46509039/pytorch-define-custom-function">Pytorch: define custom function</a></p>