diff --git a/loss/__init__.py b/loss/__init__.py
index 6eaa4c76..43399e76 100644
--- a/loss/__init__.py
+++ b/loss/__init__.py
@@ -1 +1 @@
-from loss.cross_entropy import LabelSmoothingCrossEntropy, SparseLabelCrossEntropy
\ No newline at end of file
+from loss.cross_entropy import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy
\ No newline at end of file
diff --git a/loss/cross_entropy.py b/loss/cross_entropy.py
index 821b1fe3..60bef646 100644
--- a/loss/cross_entropy.py
+++ b/loss/cross_entropy.py
@@ -26,10 +26,10 @@ class LabelSmoothingCrossEntropy(nn.Module):
         return loss.mean()
 
 
-class SparseLabelCrossEntropy(nn.Module):
+class SoftTargetCrossEntropy(nn.Module):
 
     def __init__(self):
-        super(SparseLabelCrossEntropy, self).__init__()
+        super(SoftTargetCrossEntropy, self).__init__()
 
     def forward(self, x, target):
         loss = torch.sum(-target * F.log_softmax(x, dim=-1), dim=-1)
diff --git a/train.py b/train.py
index 38d86992..9a81eecb 100644
--- a/train.py
+++ b/train.py
@@ -13,7 +13,7 @@ except ImportError:
 from data import Dataset, create_loader, resolve_data_config, FastCollateMixup, mixup_target
 from models import create_model, resume_checkpoint
 from utils import *
-from loss import LabelSmoothingCrossEntropy, SparseLabelCrossEntropy
+from loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy
 from optim import create_optimizer
 from scheduler import create_scheduler
 
@@ -261,7 +261,7 @@ def main():
 
     if args.mixup > 0.:
         # smoothing is handled with mixup label transform
-        train_loss_fn = SparseLabelCrossEntropy().cuda()
+        train_loss_fn = SoftTargetCrossEntropy().cuda()
         validate_loss_fn = nn.CrossEntropyLoss().cuda()
     elif args.smoothing:
         train_loss_fn = LabelSmoothingCrossEntropy(smoothing=args.smoothing).cuda()
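
Note for reviewers: below is a minimal, self-contained sketch of the renamed loss, showing why the new name fits. The class body is taken from the diff above; the smoke test beneath it is illustrative only (logit shapes, tensor names, and the assert are not part of this commit). It demonstrates that SoftTargetCrossEntropy reduces to the standard hard-label cross entropy when the target distribution is one-hot, which is why train.py only swaps the loss on the mixup branch while validation keeps nn.CrossEntropyLoss.

import torch
import torch.nn as nn
import torch.nn.functional as F


class SoftTargetCrossEntropy(nn.Module):
    """Cross entropy against dense (soft) target distributions, as in the diff above."""

    def __init__(self):
        super(SoftTargetCrossEntropy, self).__init__()

    def forward(self, x, target):
        # x: (batch, num_classes) logits; target: (batch, num_classes) probabilities
        loss = torch.sum(-target * F.log_softmax(x, dim=-1), dim=-1)
        return loss.mean()


# Illustrative smoke test (not part of this commit): with one-hot targets the
# soft-target loss matches nn.CrossEntropyLoss on the corresponding class indices.
logits = torch.randn(4, 10)
labels = torch.randint(0, 10, (4,))
one_hot = F.one_hot(labels, num_classes=10).float()
assert torch.allclose(nn.CrossEntropyLoss()(logits, labels),
                      SoftTargetCrossEntropy()(logits, one_hot))

# Under mixup the target becomes lam * one_hot(y_a) + (1 - lam) * one_hot(y_b),
# a dense distribution rather than a sparse class index, which the same forward
# pass handles without change.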