diff --git a/timm/loss/__init__.py b/timm/loss/__init__.py
index b781472f..28a686ce 100644
--- a/timm/loss/__init__.py
+++ b/timm/loss/__init__.py
@@ -1,2 +1,3 @@
 from .cross_entropy import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy
-from .jsd import JsdCrossEntropy
\ No newline at end of file
+from .jsd import JsdCrossEntropy
+from .asymmetric_loss import AsymmetricLossMultiLabel, AsymmetricLossSingleLabel
\ No newline at end of file
diff --git a/timm/loss/asymmetric_loss.py b/timm/loss/asymmetric_loss.py
new file mode 100644
index 00000000..96a97788
--- /dev/null
+++ b/timm/loss/asymmetric_loss.py
@@ -0,0 +1,97 @@
+import torch
+import torch.nn as nn
+
+
+class AsymmetricLossMultiLabel(nn.Module):
+    def __init__(self, gamma_neg=4, gamma_pos=1, clip=0.05, eps=1e-8, disable_torch_grad_focal_loss=False):
+        super(AsymmetricLossMultiLabel, self).__init__()
+
+        self.gamma_neg = gamma_neg
+        self.gamma_pos = gamma_pos
+        self.clip = clip
+        self.disable_torch_grad_focal_loss = disable_torch_grad_focal_loss
+        self.eps = eps
+
+    def forward(self, x, y):
+        """
+        Parameters
+        ----------
+        x: input logits
+        y: targets (multi-label binarized vector)
+        """
+
+        # Calculating Probabilities
+        x_sigmoid = torch.sigmoid(x)
+        xs_pos = x_sigmoid
+        xs_neg = 1 - x_sigmoid
+
+        # Asymmetric Clipping
+        if self.clip is not None and self.clip > 0:
+            xs_neg = (xs_neg + self.clip).clamp(max=1)
+
+        # Basic CE calculation
+        los_pos = y * torch.log(xs_pos.clamp(min=self.eps))
+        los_neg = (1 - y) * torch.log(xs_neg.clamp(min=self.eps))
+        loss = los_pos + los_neg
+
+        # Asymmetric Focusing
+        if self.gamma_neg > 0 or self.gamma_pos > 0:
+            if self.disable_torch_grad_focal_loss:
+                torch._C.set_grad_enabled(False)
+            pt0 = xs_pos * y
+            pt1 = xs_neg * (1 - y)  # pt = p if t > 0 else 1-p
+            pt = pt0 + pt1
+            one_sided_gamma = self.gamma_pos * y + self.gamma_neg * (1 - y)
+            one_sided_w = torch.pow(1 - pt, one_sided_gamma)
+            if self.disable_torch_grad_focal_loss:
+                torch._C.set_grad_enabled(True)
+            loss *= one_sided_w
+
+        return -loss.sum()
+
+
+class AsymmetricLossSingleLabel(nn.Module):
+    def __init__(self, gamma_pos=1, gamma_neg=4, eps: float = 0.1, reduction='mean'):
+        super(AsymmetricLossSingleLabel, self).__init__()
+
+        self.eps = eps
+        self.logsoftmax = nn.LogSoftmax(dim=-1)
+        self.targets_classes = []  # prevent gpu repeated memory allocation
+        self.gamma_pos = gamma_pos
+        self.gamma_neg = gamma_neg
+        self.reduction = reduction
+
+    def forward(self, inputs, target, reduction=None):
+        """
+        Parameters
+        ----------
+        inputs: input logits
+        target: targets (1-hot vector)
+        """
+
+        num_classes = inputs.size()[-1]
+        log_preds = self.logsoftmax(inputs)
+        self.targets_classes = torch.zeros_like(inputs).scatter_(1, target.long().unsqueeze(1), 1)
+
+        # ASL weights
+        targets = self.targets_classes
+        anti_targets = 1 - targets
+        xs_pos = torch.exp(log_preds)
+        xs_neg = 1 - xs_pos
+        xs_pos = xs_pos * targets
+        xs_neg = xs_neg * anti_targets
+        asymmetric_w = torch.pow(1 - xs_pos - xs_neg,
+                                 self.gamma_pos * targets + self.gamma_neg * anti_targets)
+        log_preds = log_preds * asymmetric_w
+
+        if self.eps > 0:  # label smoothing
+            self.targets_classes.mul_(1 - self.eps).add_(self.eps / num_classes)
+
+        # loss calculation
+        loss = - self.targets_classes.mul(log_preds)
+
+        loss = loss.sum(dim=-1)
+        if self.reduction == 'mean':
+            loss = loss.mean()
+
+        return loss
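Usage sketch (not part of the diff): a minimal, illustrative example of calling the two losses this change exports from timm.loss; the batch size, class count, and random targets below are assumptions for demonstration, not taken from timm's tests.

    import torch
    from timm.loss import AsymmetricLossMultiLabel, AsymmetricLossSingleLabel

    logits = torch.randn(4, 10)                        # (batch, num_classes) raw scores
    multi_hot = torch.randint(0, 2, (4, 10)).float()   # multi-label binarized targets
    class_ids = torch.randint(0, 10, (4,))             # integer class index per sample

    # Multi-label variant: expects logits and a multi-hot target of the same shape,
    # returns the summed loss over the batch.
    ml_loss = AsymmetricLossMultiLabel(gamma_neg=4, gamma_pos=1, clip=0.05)(logits, multi_hot)

    # Single-label variant: expects logits and integer class indices,
    # returns the mean loss by default (reduction='mean').
    sl_loss = AsymmetricLossSingleLabel(gamma_pos=1, gamma_neg=4, eps=0.1)(logits, class_ids)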