More appropriate/correct loss name

pull/6/head
Ross Wightman 5 years ago
parent 99ab1b1276
commit e6c14427c0

@ -1 +1 @@
from loss.cross_entropy import LabelSmoothingCrossEntropy, SparseLabelCrossEntropy from loss.cross_entropy import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy

@ -26,10 +26,10 @@ class LabelSmoothingCrossEntropy(nn.Module):
return loss.mean() return loss.mean()
class SparseLabelCrossEntropy(nn.Module): class SoftTargetCrossEntropy(nn.Module):
def __init__(self): def __init__(self):
super(SparseLabelCrossEntropy, self).__init__() super(SoftTargetCrossEntropy, self).__init__()
def forward(self, x, target): def forward(self, x, target):
loss = torch.sum(-target * F.log_softmax(x, dim=-1), dim=-1) loss = torch.sum(-target * F.log_softmax(x, dim=-1), dim=-1)

@ -13,7 +13,7 @@ except ImportError:
from data import Dataset, create_loader, resolve_data_config, FastCollateMixup, mixup_target from data import Dataset, create_loader, resolve_data_config, FastCollateMixup, mixup_target
from models import create_model, resume_checkpoint from models import create_model, resume_checkpoint
from utils import * from utils import *
from loss import LabelSmoothingCrossEntropy, SparseLabelCrossEntropy from loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy
from optim import create_optimizer from optim import create_optimizer
from scheduler import create_scheduler from scheduler import create_scheduler
@ -261,7 +261,7 @@ def main():
if args.mixup > 0.: if args.mixup > 0.:
# smoothing is handled with mixup label transform # smoothing is handled with mixup label transform
train_loss_fn = SparseLabelCrossEntropy().cuda() train_loss_fn = SoftTargetCrossEntropy().cuda()
validate_loss_fn = nn.CrossEntropyLoss().cuda() validate_loss_fn = nn.CrossEntropyLoss().cuda()
elif args.smoothing: elif args.smoothing:
train_loss_fn = LabelSmoothingCrossEntropy(smoothing=args.smoothing).cuda() train_loss_fn = LabelSmoothingCrossEntropy(smoothing=args.smoothing).cuda()

Loading…
Cancel
Save