parent
b0158a593e
commit
f2029dfb65
@ -0,0 +1 @@
|
||||
from loss.cross_entropy import LabelSmoothingCrossEntropy
|
@ -0,0 +1,26 @@
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class LabelSmoothingCrossEntropy(nn.Module):
    """NLL loss with label smoothing.

    Mixes the standard negative log-likelihood of the target class with the
    mean negative log-probability over all classes (i.e. cross-entropy
    against a uniform distribution), weighted by ``smoothing``.
    """

    def __init__(self, smoothing=0.1):
        """Constructor for the LabelSmoothing module.

        :param smoothing: label smoothing factor, must lie in [0.0, 1.0)
        :raises ValueError: if ``smoothing`` is outside [0.0, 1.0)
        """
        super().__init__()
        # Raise instead of assert: asserts are stripped under `python -O`,
        # and a negative factor would silently invert the loss mixing.
        if not 0.0 <= smoothing < 1.0:
            raise ValueError(
                f"smoothing must be in [0.0, 1.0), got {smoothing}")
        self.smoothing = smoothing
        self.confidence = 1. - smoothing

    def forward(self, x, target):
        """Compute the smoothed cross-entropy loss.

        :param x: raw (unnormalized) logits; class dimension is the last one
        :param target: integer class indices, shaped like ``x`` without the
            last (class) dimension
        :return: scalar tensor, mean loss over all elements
        """
        logprobs = F.log_softmax(x, dim=-1)
        # NLL term: negative log-probability of the true class per sample.
        # Use dim=-1 indexing so any leading batch shape works, not just 2-D.
        nll_loss = -logprobs.gather(dim=-1, index=target.unsqueeze(-1))
        nll_loss = nll_loss.squeeze(-1)
        # Smoothing term: mean negative log-prob over all classes
        # (uniform-distribution cross-entropy, constant factor absorbed
        # into `smoothing`).
        smooth_loss = -logprobs.mean(dim=-1)
        loss = self.confidence * nll_loss + self.smoothing * smooth_loss
        return loss.mean()
|
||||
|
Loading…
Reference in new issue