@ -10,7 +10,7 @@ try:
except ImportError :
except ImportError :
has_apex = False
has_apex = False
from data import Dataset , create_loader , resolve_data_config
from data import Dataset , create_loader , resolve_data_config , FastCollateMixup , mixup_target
from models import create_model , resume_checkpoint
from models import create_model , resume_checkpoint
from utils import *
from utils import *
from loss import LabelSmoothingCrossEntropy , SparseLabelCrossEntropy
from loss import LabelSmoothingCrossEntropy , SparseLabelCrossEntropy
@ -66,9 +66,9 @@ parser.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RA
parser . add_argument ( ' --sched ' , default = ' step ' , type = str , metavar = ' SCHEDULER ' ,
parser . add_argument ( ' --sched ' , default = ' step ' , type = str , metavar = ' SCHEDULER ' ,
help = ' LR scheduler (default: " step " ' )
help = ' LR scheduler (default: " step " ' )
parser . add_argument ( ' --drop ' , type = float , default = 0.0 , metavar = ' DROP ' ,
parser . add_argument ( ' --drop ' , type = float , default = 0.0 , metavar = ' DROP ' ,
help = ' Dropout rate (default: 0. 1 )' )
help = ' Dropout rate (default: 0. )' )
parser . add_argument ( ' --reprob ' , type = float , default = 0. 4 , metavar = ' PCT ' ,
parser . add_argument ( ' --reprob ' , type = float , default = 0. , metavar = ' PCT ' ,
help = ' Random erase prob (default: 0. 4 )' )
help = ' Random erase prob (default: 0. )' )
parser . add_argument ( ' --remode ' , type = str , default = ' const ' ,
parser . add_argument ( ' --remode ' , type = str , default = ' const ' ,
help = ' Random erase mode (default: " const " ) ' )
help = ' Random erase mode (default: " const " ) ' )
parser . add_argument ( ' --lr ' , type = float , default = 0.01 , metavar = ' LR ' ,
parser . add_argument ( ' --lr ' , type = float , default = 0.01 , metavar = ' LR ' ,
@ -109,6 +109,8 @@ parser.add_argument('--save-images', action='store_true', default=False,
help = ' save images of input bathes every log interval for debugging ' )
help = ' save images of input bathes every log interval for debugging ' )
parser . add_argument ( ' --amp ' , action = ' store_true ' , default = False ,
parser . add_argument ( ' --amp ' , action = ' store_true ' , default = False ,
help = ' use NVIDIA amp for mixed precision training ' )
help = ' use NVIDIA amp for mixed precision training ' )
parser . add_argument ( ' --no-prefetcher ' , action = ' store_true ' , default = False ,
help = ' disable fast prefetcher ' )
parser . add_argument ( ' --output ' , default = ' ' , type = str , metavar = ' PATH ' ,
parser . add_argument ( ' --output ' , default = ' ' , type = str , metavar = ' PATH ' ,
help = ' path to output folder (default: none, current dir) ' )
help = ' path to output folder (default: none, current dir) ' )
parser . add_argument ( ' --eval-metric ' , default = ' prec1 ' , type = str , metavar = ' EVAL_METRIC ' ,
parser . add_argument ( ' --eval-metric ' , default = ' prec1 ' , type = str , metavar = ' EVAL_METRIC ' ,
@ -119,6 +121,7 @@ parser.add_argument("--local_rank", default=0, type=int)
def main ( ) :
def main ( ) :
args = parser . parse_args ( )
args = parser . parse_args ( )
args . prefetcher = not args . no_prefetcher
args . distributed = False
args . distributed = False
if ' WORLD_SIZE ' in os . environ :
if ' WORLD_SIZE ' in os . environ :
args . distributed = int ( os . environ [ ' WORLD_SIZE ' ] ) > 1
args . distributed = int ( os . environ [ ' WORLD_SIZE ' ] ) > 1
@ -130,6 +133,7 @@ def main():
args . world_size = 1
args . world_size = 1
r = - 1
r = - 1
if args . distributed :
if args . distributed :
args . num_gpu = 1
args . device = ' cuda: %d ' % args . local_rank
args . device = ' cuda: %d ' % args . local_rank
torch . cuda . set_device ( args . local_rank )
torch . cuda . set_device ( args . local_rank )
torch . distributed . init_process_group ( backend = ' nccl ' ,
torch . distributed . init_process_group ( backend = ' nccl ' ,
@ -216,12 +220,16 @@ def main():
exit ( 1 )
exit ( 1 )
dataset_train = Dataset ( train_dir )
dataset_train = Dataset ( train_dir )
collate_fn = None
if args . prefetcher and args . mixup > 0 :
collate_fn = FastCollateMixup ( args . mixup , args . smoothing , args . num_classes )
loader_train = create_loader (
loader_train = create_loader (
dataset_train ,
dataset_train ,
input_size = data_config [ ' input_size ' ] ,
input_size = data_config [ ' input_size ' ] ,
batch_size = args . batch_size ,
batch_size = args . batch_size ,
is_training = True ,
is_training = True ,
use_prefetcher = True ,
use_prefetcher = args . prefetcher ,
rand_erase_prob = args . reprob ,
rand_erase_prob = args . reprob ,
rand_erase_mode = args . remode ,
rand_erase_mode = args . remode ,
interpolation = ' random ' , # FIXME cleanly resolve this? data_config['interpolation'],
interpolation = ' random ' , # FIXME cleanly resolve this? data_config['interpolation'],
@ -229,6 +237,7 @@ def main():
std = data_config [ ' std ' ] ,
std = data_config [ ' std ' ] ,
num_workers = args . workers ,
num_workers = args . workers ,
distributed = args . distributed ,
distributed = args . distributed ,
collate_fn = collate_fn ,
)
)
eval_dir = os . path . join ( args . data , ' validation ' )
eval_dir = os . path . join ( args . data , ' validation ' )
@ -242,7 +251,7 @@ def main():
input_size = data_config [ ' input_size ' ] ,
input_size = data_config [ ' input_size ' ] ,
batch_size = 4 * args . batch_size ,
batch_size = 4 * args . batch_size ,
is_training = False ,
is_training = False ,
use_prefetcher = True ,
use_prefetcher = args . prefetcher ,
interpolation = data_config [ ' interpolation ' ] ,
interpolation = data_config [ ' interpolation ' ] ,
mean = data_config [ ' mean ' ] ,
mean = data_config [ ' mean ' ] ,
std = data_config [ ' std ' ] ,
std = data_config [ ' std ' ] ,
@ -309,6 +318,10 @@ def train_epoch(
epoch , model , loader , optimizer , loss_fn , args ,
epoch , model , loader , optimizer , loss_fn , args ,
lr_scheduler = None , saver = None , output_dir = ' ' , use_amp = False ) :
lr_scheduler = None , saver = None , output_dir = ' ' , use_amp = False ) :
if args . prefetcher and args . mixup > 0 and loader . mixup_enabled :
if args . mixup_off_epoch and epoch > = args . mixup_off_epoch :
loader . mixup_enabled = False
batch_time_m = AverageMeter ( )
batch_time_m = AverageMeter ( )
data_time_m = AverageMeter ( )
data_time_m = AverageMeter ( )
losses_m = AverageMeter ( )
losses_m = AverageMeter ( )
@ -321,13 +334,15 @@ def train_epoch(
for batch_idx , ( input , target ) in enumerate ( loader ) :
for batch_idx , ( input , target ) in enumerate ( loader ) :
last_batch = batch_idx == last_idx
last_batch = batch_idx == last_idx
data_time_m . update ( time . time ( ) - end )
data_time_m . update ( time . time ( ) - end )
if not args . prefetcher :
if args . mixup > 0. :
input = input . cuda ( )
lam = 1.
target = target . cuda ( )
if not args . mixup_off_epoch or epoch < args . mixup_off_epoch :
if args . mixup > 0. :
lam = np . random . beta ( args . mixup , args . mixup )
lam = 1.
input . mul_ ( lam ) . add_ ( 1 - lam , input . flip ( 0 ) )
if not args . mixup_off_epoch or epoch < args . mixup_off_epoch :
target = mixup_target ( target , args . num_classes , lam , args . smoothing )
lam = np . random . beta ( args . mixup , args . mixup )
input . mul_ ( lam ) . add_ ( 1 - lam , input . flip ( 0 ) )
target = mixup_target ( target , args . num_classes , lam , args . smoothing )
output = model ( input )
output = model ( input )