Pythorch自定义激活功能?

2024-03-29 05:39:13 发布

您现在位置:Python中文网/ 问答频道 /正文

我在Pythorch中实现自定义激活函数时遇到问题,比如Swish。我应该如何在Pytorch中实现和使用自定义激活函数?在


Tags: 函数pytorchpythorchswish
2条回答

根据你所寻找的,有四种可能性。你需要问自己两个问题:

Q1)您的激活函数是否有可学习的参数?在

如果,则您没有选择将激活函数创建为nn.Module类,因为您需要存储这些权重。在

如果,您可以自由地创建一个普通函数或一个类,这取决于您是否方便。在

Q2)您的激活函数可以表示为现有Pythorch函数的组合吗?在

如果,您只需将其编写为现有PyTorch函数的组合,而不需要创建定义渐变的backward函数。在

如果则需要手动编写渐变。在

示例1:Swish函数

swish函数f(x) = x * sigmoid(x)没有任何学习的权重,可以完全用现有的PyTorch函数编写,因此您可以简单地将其定义为函数:

def swish(x):
    return x * torch.sigmoid(x)

然后简单地使用它,就像你有torch.relu或任何其他激活函数一样。在

示例2:使用学习的坡度进行快速滑行

在本例中,您有一个学习的参数,斜率,因此您需要将其生成一个类。在

^{pr2}$

例3:带后退

如果您有一些需要创建自己的渐变函数的东西,可以看看这个例子:Pytorch: define custom function

你可以编写一个自定义的激活函数,如下所示(例如加权Tanh)。在

class weightedTanh(nn.Module):
    def __init__(self, weights = 1):
        super().__init__()
        self.weights = weights

    def forward(self, input):
        ex = torch.exp(2*self.weights*input)
        return (ex-1)/(ex+1)

{cd1>不要使用兼容的反向操作。在

相关问题 更多 >