More appropriate/correct loss name

pull/6/head
Ross Wightman 5 years ago
parent 99ab1b1276
commit e6c14427c0

@ -1 +1 @@
from loss.cross_entropy import LabelSmoothingCrossEntropy, SparseLabelCrossEntropy
from loss.cross_entropy import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy

@ -26,10 +26,10 @@ class LabelSmoothingCrossEntropy(nn.Module):
return loss.mean()
class SparseLabelCrossEntropy(nn.Module):
class SoftTargetCrossEntropy(nn.Module):
def __init__(self):
super(SparseLabelCrossEntropy, self).__init__()
super(SoftTargetCrossEntropy, self).__init__()
def forward(self, x, target):
loss = torch.sum(-target * F.log_softmax(x, dim=-1), dim=-1)

@ -13,7 +13,7 @@ except ImportError:
from data import Dataset, create_loader, resolve_data_config, FastCollateMixup, mixup_target
from models import create_model, resume_checkpoint
from utils import *
from loss import LabelSmoothingCrossEntropy, SparseLabelCrossEntropy
from loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy
from optim import create_optimizer
from scheduler import create_scheduler
@ -261,7 +261,7 @@ def main():
if args.mixup > 0.:
# smoothing is handled with mixup label transform
train_loss_fn = SparseLabelCrossEntropy().cuda()
train_loss_fn = SoftTargetCrossEntropy().cuda()
validate_loss_fn = nn.CrossEntropyLoss().cuda()
elif args.smoothing:
train_loss_fn = LabelSmoothingCrossEntropy(smoothing=args.smoothing).cuda()

Loading…
Cancel
Save