diff --git a/data/transforms.py b/data/transforms.py index f1222e2b..90419ae0 100644 --- a/data/transforms.py +++ b/data/transforms.py @@ -20,7 +20,7 @@ def get_model_meanstd(model_name): model_name = model_name.lower() if 'dpn' in model_name: return IMAGENET_DPN_MEAN, IMAGENET_DPN_STD - elif 'ception' in model_name: + elif 'ception' in model_name or 'nasnet' in model_name: return IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD else: return IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD @@ -30,7 +30,7 @@ def get_model_mean(model_name): model_name = model_name.lower() if 'dpn' in model_name: return IMAGENET_DPN_STD - elif 'ception' in model_name: + elif 'ception' in model_name or 'nasnet' in model_name: return IMAGENET_INCEPTION_MEAN else: return IMAGENET_DEFAULT_MEAN @@ -40,7 +40,7 @@ def get_model_std(model_name): model_name = model_name.lower() if 'dpn' in model_name: return IMAGENET_DEFAULT_STD - elif 'ception' in model_name: + elif 'ception' in model_name or 'nasnet' in model_name: return IMAGENET_INCEPTION_STD else: return IMAGENET_DEFAULT_STD diff --git a/models/model_factory.py b/models/model_factory.py index e3d0e2f3..06940496 100644 --- a/models/model_factory.py +++ b/models/model_factory.py @@ -12,6 +12,7 @@ from .senet import seresnet18, seresnet34, seresnet50, seresnet101, seresnet152, seresnext26_32x4d, seresnext50_32x4d, seresnext101_32x4d #from .resnext import resnext50, resnext101, resnext152 from .xception import xception +from .pnasnet import pnasnet5large model_config_dict = { 'resnet18': { @@ -48,6 +49,8 @@ model_config_dict = { 'model_name': 'inception_resnet_v2', 'num_classes': 1000, 'input_size': 299, 'normalizer': 'le'}, 'xception': { 'model_name': 'xception', 'num_classes': 1000, 'input_size': 299, 'normalizer': 'le'}, + 'pnasnet5large': { + 'model_name': 'pnasnet5large', 'num_classes': 1000, 'input_size': 331, 'normalizer': 'le'} } @@ -118,6 +121,8 @@ def create_model( model = resnext152_32x4d(num_classes=num_classes, pretrained=pretrained, **kwargs) elif model_name == 'xception': model = xception(num_classes=num_classes, pretrained=pretrained) + elif model_name == 'pnasnet5large': + model = pnasnet5large(num_classes=num_classes, pretrained=pretrained) else: assert False and "Invalid model" diff --git a/models/pnasnet.py b/models/pnasnet.py index c169c695..6aebb772 100644 --- a/models/pnasnet.py +++ b/models/pnasnet.py @@ -5,7 +5,6 @@ import torch import torch.nn as nn import torch.utils.model_zoo as model_zoo - pretrained_settings = { 'pnasnet5large': { 'imagenet': { @@ -292,6 +291,8 @@ class PNASNet5Large(nn.Module): def __init__(self, num_classes=1001): super(PNASNet5Large, self).__init__() self.num_classes = num_classes + self.num_features = 4320 + self.conv_0 = nn.Sequential(OrderedDict([ ('conv', nn.Conv2d(3, 96, kernel_size=3, stride=2, bias=False)), ('bn', nn.BatchNorm2d(96, eps=0.001)) @@ -335,9 +336,20 @@ class PNASNet5Large(nn.Module): self.relu = nn.ReLU() self.avg_pool = nn.AvgPool2d(11, stride=1, padding=0) self.dropout = nn.Dropout(0.5) - self.last_linear = nn.Linear(4320, num_classes) + self.last_linear = nn.Linear(self.num_features, num_classes) + + def get_classifier(self): + return self.last_linear + + def reset_classifier(self, num_classes): + self.num_classes = num_classes + del self.last_linear + if num_classes: + self.last_linear = nn.Linear(self.num_features, num_classes) + else: + self.last_linear = None - def features(self, x): + def forward_features(self, x, pool=True): x_conv_0 = self.conv_0(x) x_stem_0 = self.cell_stem_0(x_conv_0) x_stem_1 = self.cell_stem_1(x_conv_0, x_stem_0) @@ -353,19 +365,16 @@ class PNASNet5Large(nn.Module): x_cell_9 = self.cell_9(x_cell_7, x_cell_8) x_cell_10 = self.cell_10(x_cell_8, x_cell_9) x_cell_11 = self.cell_11(x_cell_9, x_cell_10) - return x_cell_11 - - def logits(self, features): - x = self.relu(features) - x = self.avg_pool(x) - x = x.view(x.size(0), -1) - x = self.dropout(x) - x = self.last_linear(x) + x = self.relu(x_cell_11) + if pool: + x = self.avg_pool(x) + x = x.view(x.size(0), -1) return x def forward(self, input): - x = self.features(input) - x = self.logits(x) + x = self.forward_features(input) + x = self.dropout(x) + x = self.last_linear(x) return x @@ -375,7 +384,7 @@ def pnasnet5large(num_classes=1001, pretrained='imagenet'): `_ paper. """ if pretrained: - settings = pretrained_settings['pnasnet5large'][pretrained] + settings = pretrained_settings['pnasnet5large']['imagenet'] assert num_classes == settings[ 'num_classes'], 'num_classes should be {}, but is {}'.format( settings['num_classes'], num_classes) @@ -384,18 +393,12 @@ def pnasnet5large(num_classes=1001, pretrained='imagenet'): model = PNASNet5Large(num_classes=1001) model.load_state_dict(model_zoo.load_url(settings['url'])) - if pretrained == 'imagenet': - new_last_linear = nn.Linear(model.last_linear.in_features, 1000) - new_last_linear.weight.data = model.last_linear.weight.data[1:] - new_last_linear.bias.data = model.last_linear.bias.data[1:] - model.last_linear = new_last_linear - - model.input_space = settings['input_space'] - model.input_size = settings['input_size'] - model.input_range = settings['input_range'] + #if pretrained == 'imagenet': + new_last_linear = nn.Linear(model.last_linear.in_features, 1000) + new_last_linear.weight.data = model.last_linear.weight.data[1:] + new_last_linear.bias.data = model.last_linear.bias.data[1:] + model.last_linear = new_last_linear - model.mean = settings['mean'] - model.std = settings['std'] else: model = PNASNet5Large(num_classes=num_classes) return model diff --git a/models/xception.py b/models/xception.py index c4ae09fa..97b3947d 100644 --- a/models/xception.py +++ b/models/xception.py @@ -127,6 +127,7 @@ class Xception(nn.Module): """ super(Xception, self).__init__() self.num_classes = num_classes + self.num_features = 2048 self.conv1 = nn.Conv2d(3, 32, 3, 2, 0, bias=False) self.bn1 = nn.BatchNorm2d(32) @@ -156,10 +157,10 @@ class Xception(nn.Module): self.bn3 = nn.BatchNorm2d(1536) # do relu here - self.conv4 = SeparableConv2d(1536, 2048, 3, 1, 1) - self.bn4 = nn.BatchNorm2d(2048) + self.conv4 = SeparableConv2d(1536, self.num_features, 3, 1, 1) + self.bn4 = nn.BatchNorm2d(self.num_features) - self.fc = nn.Linear(2048, num_classes) + self.fc = nn.Linear(self.num_features, num_classes) # #------- init weights -------- for m in self.modules(): @@ -169,7 +170,18 @@ class Xception(nn.Module): m.weight.data.fill_(1) m.bias.data.zero_() - def forward_features(self, input): + def get_classifier(self): + return self.fc + + def reset_classifier(self, num_classes): + self.num_classes = num_classes + del self.fc + if num_classes: + self.fc = nn.Linear(self.num_features, num_classes) + else: + self.fc = None + + def forward_features(self, input, pool=True): x = self.conv1(input) x = self.bn1(x) x = self.relu(x) @@ -197,19 +209,16 @@ class Xception(nn.Module): x = self.conv4(x) x = self.bn4(x) - return x - - def logits(self, features): - x = self.relu(features) + x = self.relu(x) - x = F.adaptive_avg_pool2d(x, (1, 1)) - x = x.view(x.size(0), -1) - x = self.last_linear(x) + if pool: + x = F.adaptive_avg_pool2d(x, (1, 1)) + x = x.view(x.size(0), -1) return x def forward(self, input): x = self.forward_features(input) - x = self.logits(x) + x = self.fc(x) return x @@ -223,13 +232,4 @@ def xception(num_classes=1000, pretrained=False): model = Xception(num_classes=num_classes) model.load_state_dict(model_zoo.load_url(config['url'])) - model.input_space = config['input_space'] - model.input_size = config['input_size'] - model.input_range = config['input_range'] - model.mean = config['mean'] - model.std = config['std'] - - # TODO: ugly - model.last_linear = model.fc - del model.fc return model diff --git a/train.py b/train.py index 89632d39..5494b3ff 100644 --- a/train.py +++ b/train.py @@ -93,6 +93,8 @@ parser.add_argument('--amp', action='store_true', default=False, help='use NVIDIA amp for mixed precision training') parser.add_argument('--output', default='', type=str, metavar='PATH', help='path to output folder (default: none, current dir)') +parser.add_argument('--eval-metric', default='prec1', type=str, metavar='EVAL_METRIC', + help='Best metric (default: "prec1"') parser.add_argument("--local_rank", default=0, type=int) @@ -238,10 +240,13 @@ def main(): if args.local_rank == 0: print('Scheduled epochs: ', num_epochs) + eval_metric = args.eval_metric saver = None if output_dir: - saver = CheckpointSaver(checkpoint_dir=output_dir) - best_loss = None + decreasing = True if eval_metric == 'loss' else False + saver = CheckpointSaver(checkpoint_dir=output_dir, decreasing=decreasing) + best_metric = None + best_epoch = None try: for epoch in range(start_epoch, num_epochs): if args.distributed: @@ -255,15 +260,15 @@ def main(): model, loader_eval, validate_loss_fn, args) if lr_scheduler is not None: - lr_scheduler.step(epoch, eval_metrics['eval_loss']) + lr_scheduler.step(epoch, eval_metrics[eval_metric]) update_summary( epoch, train_metrics, eval_metrics, os.path.join(output_dir, 'summary.csv'), - write_header=best_loss is None) + write_header=best_metric is None) if saver is not None: # save proper checkpoint with eval metric - best_loss = saver.save_checkpoint({ + best_metric, best_epoch = saver.save_checkpoint({ 'epoch': epoch + 1, 'arch': args.model, 'state_dict': model.state_dict(), @@ -271,11 +276,12 @@ def main(): 'args': args, }, epoch=epoch + 1, - metric=eval_metrics['eval_loss']) + metric=eval_metrics[eval_metric]) except KeyboardInterrupt: pass - print('*** Best loss: {0} (epoch {1})'.format(best_loss[1], best_loss[0])) + if best_metric is not None: + print('*** Best metric: {0} (epoch {1})'.format(best_metric, best_epoch)) def train_epoch( @@ -363,7 +369,7 @@ def train_epoch( end = time.time() - return OrderedDict([('train_loss', losses_m.avg)]) + return OrderedDict([('loss', losses_m.avg)]) def validate(model, loader, loss_fn, args): @@ -418,7 +424,7 @@ def validate(model, loader, loss_fn, args): batch_time=batch_time_m, loss=losses_m, top1=prec1_m, top5=prec5_m)) - metrics = OrderedDict([('eval_loss', losses_m.avg), ('eval_prec1', prec1_m.avg)]) + metrics = OrderedDict([('loss', losses_m.avg), ('prec1', prec1_m.avg), ('prec5', prec5_m.avg)]) return metrics diff --git a/utils.py b/utils.py index 4604a258..f206f945 100644 --- a/utils.py +++ b/utils.py @@ -6,6 +6,7 @@ import os import shutil import glob import csv +import operator from collections import OrderedDict @@ -16,24 +17,32 @@ class CheckpointSaver: recovery_prefix='recovery', checkpoint_dir='', recovery_dir='', + decreasing=False, + verbose=True, max_history=10): - self.checkpoint_files = [] + # state + self.checkpoint_files = [] # (filename, metric) tuples in order of decreasing betterness + self.best_epoch = None self.best_metric = None - self.worst_metric = None - self.max_history = max_history - assert self.max_history >= 1 self.curr_recovery_file = '' self.last_recovery_file = '' + + # config self.checkpoint_dir = checkpoint_dir self.recovery_dir = recovery_dir self.save_prefix = checkpoint_prefix self.recovery_prefix = recovery_prefix self.extension = '.pth.tar' + self.decreasing = decreasing # a lower metric is better if True + self.cmp = operator.lt if decreasing else operator.gt # True if lhs better than rhs + self.verbose = verbose + self.max_history = max_history + assert self.max_history >= 1 def save_checkpoint(self, state, epoch, metric=None): - worst_metric = self.checkpoint_files[-1] if self.checkpoint_files else None - if len(self.checkpoint_files) < self.max_history or metric < worst_metric[1]: + worst_file = self.checkpoint_files[-1] if self.checkpoint_files else None + if len(self.checkpoint_files) < self.max_history or self.cmp(metric, worst_file[1]): if len(self.checkpoint_files) >= self.max_history: self._cleanup_checkpoints(1) @@ -43,16 +52,21 @@ class CheckpointSaver: state['metric'] = metric torch.save(state, save_path) self.checkpoint_files.append((save_path, metric)) - self.checkpoint_files = sorted(self.checkpoint_files, key=lambda x: x[1]) - - print("Current checkpoints:") - for c in self.checkpoint_files: - print(c) - - if metric is not None and (self.best_metric is None or metric < self.best_metric[1]): - self.best_metric = (epoch, metric) + self.checkpoint_files = sorted( + self.checkpoint_files, key=lambda x: x[1], + reverse=not self.decreasing) # sort in descending order if a lower metric is not better + + if self.verbose: + print("Current checkpoints:") + for c in self.checkpoint_files: + print(c) + + if metric is not None and (self.best_metric is None or self.cmp(metric, self.best_metric)): + self.best_epoch = epoch + self.best_metric = metric shutil.copyfile(save_path, os.path.join(self.checkpoint_dir, 'model_best' + self.extension)) - return None, None if self.best_metric is None else self.best_metric + + return (None, None) if self.best_metric is None else (self.best_metric, self.best_epoch) def _cleanup_checkpoints(self, trim=0): trim = min(len(self.checkpoint_files), trim) @@ -62,7 +76,8 @@ class CheckpointSaver: to_delete = self.checkpoint_files[delete_index:] for d in to_delete: try: - print('Cleaning checkpoint: ', d) + if self.verbose: + print('Cleaning checkpoint: ', d) os.remove(d[0]) except Exception as e: print('Exception (%s) while deleting checkpoint' % str(e)) @@ -74,7 +89,8 @@ class CheckpointSaver: torch.save(state, save_path) if os.path.exists(self.last_recovery_file): try: - print('Cleaning recovery', self.last_recovery_file) + if self.verbose: + print('Cleaning recovery', self.last_recovery_file) os.remove(self.last_recovery_file) except Exception as e: print("Exception (%s) while removing %s" % (str(e), self.last_recovery_file)) @@ -143,8 +159,8 @@ def get_outdir(path, *paths, inc=False): def update_summary(epoch, train_metrics, eval_metrics, filename, write_header=False): rowd = OrderedDict(epoch=epoch) - rowd.update(train_metrics) - rowd.update(eval_metrics) + rowd.update([('train_' + k, v) for k, v in train_metrics.items()]) + rowd.update([('eval_' + k, v) for k, v in eval_metrics.items()]) with open(filename, mode='a') as cf: dw = csv.DictWriter(cf, fieldnames=rowd.keys()) if write_header: # first iteration (epoch == 1 can't be used)