diff --git a/timm/loss/binary_cross_entropy.py b/timm/loss/binary_cross_entropy.py new file mode 100644 index 00000000..6da04dba --- /dev/null +++ b/timm/loss/binary_cross_entropy.py @@ -0,0 +1,23 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class DenseBinaryCrossEntropy(nn.Module): + """ BCE using one-hot from dense targets w/ label smoothing + NOTE for experiments comparing CE to BCE /w label smoothing, may remove + """ + def __init__(self, smoothing=0.1): + super(DenseBinaryCrossEntropy, self).__init__() + assert 0. <= smoothing < 1.0 + self.smoothing = smoothing + self.bce = nn.BCEWithLogitsLoss() + + def forward(self, x, target): + num_classes = x.shape[-1] + off_value = self.smoothing / num_classes + on_value = 1. - self.smoothing + off_value + target = target.long().view(-1, 1) + target = torch.full( + (target.size()[0], num_classes), off_value, device=x.device, dtype=x.dtype).scatter_(1, target, on_value) + return self.bce(x, target)