我正在为语言任务进行多类分类(4类),并使用伯特模型进行分类任务。我正在跟踪this blog as referenceMy BERT微调模型返回nn.LogSoftmax(dim=1)
。
我的数据非常不平衡,所以我使用sklearn.utils.class_weight.compute_class_weight
来计算类的权重,并使用损失中的权重
class_weights = compute_class_weight('balanced', np.unique(train_labels), train_labels)
weights= torch.tensor(class_weights,dtype=torch.float)
cross_entropy = nn.NLLLoss(weight=weights)
我的结果不太好,所以我想用Focal Loss
进行实验,并有一个焦点丢失代码
class FocalLoss(nn.Module):
def __init__(self, alpha=1, gamma=2, logits=False, reduce=True):
super(FocalLoss, self).__init__()
self.alpha = alpha
self.gamma = gamma
self.logits = logits
self.reduce = reduce
def forward(self, inputs, targets):
BCE_loss = nn.CrossEntropyLoss()(inputs, targets)
pt = torch.exp(-BCE_loss)
F_loss = self.alpha * (1-pt)**self.gamma * BCE_loss
if self.reduce:
return torch.mean(F_loss)
else:
return F_loss
我现在有三个问题。首先也是最重要的是
Focal Loss
内实现权重,是否可以在 nn.CrossEntropyLoss()
内使用weights
参数
我想OP现在应该已经得到答案了。我写这篇文章是为了其他可能会思考这个问题的人
OPs实施焦点丢失时存在一个问题:
F_loss = self.alpha * (1-pt)**self.gamma * BCE_loss
在这一行中,相同的
alpha
值乘以每个类的输出概率,即(pt
)。此外,代码没有显示如何获得pt
。焦点丢失的一个很好的实现是here。但是这个实现只适用于二进制分类,因为它在self.alpha
张量中有alpha
和1-alpha
两个类在多类分类或多标签分类的情况下,
self.alpha
张量应包含等于标签总数的元素数。这些值可以是标签的反向标签频率或反向标签规范化频率(只需注意频率为0的标签)您可以找到以下问题的答案:
您还可以使用another focal loss version available进行实验
相关问题 更多 >
编程相关推荐