|
|
@ -13,7 +13,7 @@ except ImportError:
|
|
|
|
from data import Dataset, create_loader, resolve_data_config, FastCollateMixup, mixup_target
|
|
|
|
from data import Dataset, create_loader, resolve_data_config, FastCollateMixup, mixup_target
|
|
|
|
from models import create_model, resume_checkpoint
|
|
|
|
from models import create_model, resume_checkpoint
|
|
|
|
from utils import *
|
|
|
|
from utils import *
|
|
|
|
from loss import LabelSmoothingCrossEntropy, SparseLabelCrossEntropy
|
|
|
|
from loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy
|
|
|
|
from optim import create_optimizer
|
|
|
|
from optim import create_optimizer
|
|
|
|
from scheduler import create_scheduler
|
|
|
|
from scheduler import create_scheduler
|
|
|
|
|
|
|
|
|
|
|
@ -261,7 +261,7 @@ def main():
|
|
|
|
|
|
|
|
|
|
|
|
if args.mixup > 0.:
|
|
|
|
if args.mixup > 0.:
|
|
|
|
# smoothing is handled with mixup label transform
|
|
|
|
# smoothing is handled with mixup label transform
|
|
|
|
train_loss_fn = SparseLabelCrossEntropy().cuda()
|
|
|
|
train_loss_fn = SoftTargetCrossEntropy().cuda()
|
|
|
|
validate_loss_fn = nn.CrossEntropyLoss().cuda()
|
|
|
|
validate_loss_fn = nn.CrossEntropyLoss().cuda()
|
|
|
|
elif args.smoothing:
|
|
|
|
elif args.smoothing:
|
|
|
|
train_loss_fn = LabelSmoothingCrossEntropy(smoothing=args.smoothing).cuda()
|
|
|
|
train_loss_fn = LabelSmoothingCrossEntropy(smoothing=args.smoothing).cuda()
|
|
|
|