From f2029dfb65858ec3fc1ce948bb3d4f963db39318 Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Fri, 5 Apr 2019 20:50:26 -0700
Subject: [PATCH] Add smooth loss

---
 loss/__init__.py      |  1 +
 loss/cross_entropy.py | 26 ++++++++++++++++++++++++++
 2 files changed, 27 insertions(+)
 create mode 100644 loss/__init__.py
 create mode 100644 loss/cross_entropy.py

diff --git a/loss/__init__.py b/loss/__init__.py
new file mode 100644
index 00000000..9ad83fb1
--- /dev/null
+++ b/loss/__init__.py
@@ -0,0 +1 @@
+from loss.cross_entropy import LabelSmoothingCrossEntropy
\ No newline at end of file
diff --git a/loss/cross_entropy.py b/loss/cross_entropy.py
new file mode 100644
index 00000000..db4aaed9
--- /dev/null
+++ b/loss/cross_entropy.py
@@ -0,0 +1,26 @@
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class LabelSmoothingCrossEntropy(nn.Module):
+    """
+    NLL loss with label smoothing.
+    """
+    def __init__(self, smoothing=0.1):
+        """
+        Constructor for the LabelSmoothing module.
+        :param smoothing: label smoothing factor
+        """
+        super(LabelSmoothingCrossEntropy, self).__init__()
+        assert smoothing < 1.0
+        self.smoothing = smoothing
+        self.confidence = 1. - smoothing
+
+    def forward(self, x, target):
+        logprobs = F.log_softmax(x, dim=-1)
+        nll_loss = -logprobs.gather(dim=-1, index=target.unsqueeze(1))
+        nll_loss = nll_loss.squeeze(1)
+        smooth_loss = -logprobs.mean(dim=-1)
+        loss = self.confidence * nll_loss + self.smoothing * smooth_loss
+        return loss.mean()
+
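For reference, a minimal usage sketch (not part of the patch), assuming the loss
package above is importable from the working directory. The forward pass blends
the per-sample NLL of the true class with the mean negative log-probability over
all classes, weighted by confidence (1 - smoothing) and smoothing respectively,
so smoothing=0.0 should reduce to plain cross entropy; the assert below checks
that.

    import torch
    import torch.nn.functional as F

    from loss import LabelSmoothingCrossEntropy

    # Hypothetical shapes: a batch of 8 samples over 10 classes.
    logits = torch.randn(8, 10)
    targets = torch.randint(0, 10, (8,))

    criterion = LabelSmoothingCrossEntropy(smoothing=0.1)
    loss = criterion(logits, targets)
    print(loss.item())

    # With smoothing=0.0 the confidence weight is 1.0 and the smooth term
    # drops out, so the result should match standard cross entropy.
    no_smooth = LabelSmoothingCrossEntropy(smoothing=0.0)
    assert torch.allclose(no_smooth(logits, targets),
                          F.cross_entropy(logits, targets))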