如何在PyTorch中使用具有焦点损失的类权重进行多类分类的不平衡数据集

2024-06-01 01:09:31 发布

您现在位置:Python中文网/ 问答频道 /正文

我正在为语言任务进行多类分类(4类),并使用伯特模型进行分类任务。我正在跟踪this blog as referenceMy BERT微调模型返回nn.LogSoftmax(dim=1)

我的数据非常不平衡,所以我使用sklearn.utils.class_weight.compute_class_weight来计算类的权重,并使用损失中的权重

class_weights = compute_class_weight('balanced', np.unique(train_labels), train_labels)
weights= torch.tensor(class_weights,dtype=torch.float)
cross_entropy  = nn.NLLLoss(weight=weights) 

我的结果不太好,所以我想用Focal Loss进行实验,并有一个焦点丢失代码

class FocalLoss(nn.Module):
  def __init__(self, alpha=1, gamma=2, logits=False, reduce=True):
    super(FocalLoss, self).__init__()
    self.alpha = alpha
    self.gamma = gamma
    self.logits = logits
    self.reduce = reduce

  def forward(self, inputs, targets):
    BCE_loss = nn.CrossEntropyLoss()(inputs, targets)

    pt = torch.exp(-BCE_loss)
    F_loss = self.alpha * (1-pt)**self.gamma * BCE_loss

    if self.reduce:
      return torch.mean(F_loss)
    else:
      return F_loss

我现在有三个问题。首先也是最重要的是

  1. 我是否应该使用重心减退的班级体重?
  2. 如果必须在此Focal Loss内实现权重,是否可以在 nn.CrossEntropyLoss()内使用weights参数
  3. 如果此机具不正确,则此机具的正确代码应该是什么,包括重量(如果可能)

Tags: 模型selfalphareduce分类nntorchclass
2条回答

我想OP现在应该已经得到答案了。我写这篇文章是为了其他可能会思考这个问题的人

OPs实施焦点丢失时存在一个问题:

  1. F_loss = self.alpha * (1-pt)**self.gamma * BCE_loss

在这一行中,相同的alpha值乘以每个类的输出概率,即(pt)。此外,代码没有显示如何获得pt。焦点丢失的一个很好的实现是here。但是这个实现只适用于二进制分类,因为它在self.alpha张量中有alpha1-alpha两个类

在多类分类或多标签分类的情况下,self.alpha张量应包含等于标签总数的元素数。这些值可以是标签的反向标签频率或反向标签规范化频率(只需注意频率为0的标签)

您可以找到以下问题的答案:

  1. 焦点损失自动处理类别不平衡,因此焦点损失不需要权重。α和γ因子处理焦损方程中的类别不平衡
  2. 不需要额外的权重,因为焦损使用阿尔法和伽马调制因子处理它们
  3. 根据焦损公式,您提到的实现是正确的,但是我在使我的模型与这个版本收敛时遇到了麻烦,因此,我使用了the following implementation from mmdetection framework
    pred_sigmoid = pred.sigmoid()
    target = target.type_as(pred)
    pt = (1 - pred_sigmoid) * target + pred_sigmoid * (1 - target)
    focal_weight = (alpha * target + (1 - alpha) *
                    (1 - target)) * pt.pow(gamma)
    loss = F.binary_cross_entropy_with_logits(
        pred, target, reduction='none') * focal_weight
    loss = weight_reduce_loss(loss, weight, reduction, avg_factor)
    return loss

您还可以使用another focal loss version available进行实验

相关问题 更多 >