You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
37 lines
1.1 KiB
37 lines
1.1 KiB
""" Cross Entropy w/ smoothing or soft targets

Hacked together by / Copyright 2021 Ross Wightman
"""
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
|
|
|
|
class LabelSmoothingCrossEntropy(nn.Module):
    """NLL loss with label smoothing.

    The per-position loss mixes the usual negative log-likelihood of the
    target class (weight ``1 - smoothing``) with the mean negative
    log-probability over all classes (weight ``smoothing``)::

        (1 - smoothing) * -log p[target] + smoothing * mean_c(-log p[c])

    The per-position losses are then averaged into a scalar.

    Args:
        smoothing: Weight given to the uniform-over-classes term. Must be
            less than 1.0.

    Raises:
        ValueError: If ``smoothing`` is not less than 1.0.
    """

    def __init__(self, smoothing=0.1):
        super().__init__()
        # Explicit check instead of ``assert`` so validation survives ``python -O``.
        if not (smoothing < 1.0):
            raise ValueError(f"smoothing must be < 1.0, got {smoothing}")
        self.smoothing = smoothing
        self.confidence = 1. - smoothing

    def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Compute the mean label-smoothed cross entropy.

        Args:
            x: Unnormalized logits with the class dimension last.
            target: Integer class indices shaped like ``x`` without its last
                dimension, e.g. ``(N,)`` for ``(N, C)`` logits.

        Returns:
            Scalar tensor holding the loss averaged over all positions.
        """
        logprobs = F.log_softmax(x, dim=-1)
        # Add and remove the class axis at the end (-1) so it lines up with
        # gather(dim=-1); unsqueeze(1) only matched for 2-D logits.
        nll_loss = -logprobs.gather(dim=-1, index=target.unsqueeze(-1)).squeeze(-1)
        smooth_loss = -logprobs.mean(dim=-1)
        loss = self.confidence * nll_loss + self.smoothing * smooth_loss
        return loss.mean()
|
|
|
|
|
|
class SoftTargetCrossEntropy(nn.Module):
    """Cross entropy against soft targets (per-class weights rather than indices).

    For every position the loss is ``sum_c(-target[c] * log_softmax(x)[c])``
    taken over the last dimension. These values are then averaged into a scalar.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Return the mean soft-target cross entropy.

        Args:
            x: Unnormalized logits with the class dimension last.
            target: Per-class target weights, broadcast elementwise against
                ``x`` (presumably a probability distribution per position;
                not validated here).

        Returns:
            Scalar tensor with the loss averaged over all positions.
        """
        log_probs = F.log_softmax(x, dim=-1)
        per_position = (-target * log_probs).sum(dim=-1)
        return per_position.mean()
|