Refactor device handling in scripts, distributed init to be less 'cuda' centric. More device args passed through where needed.

pull/1479/head
Ross Wightman 2 years ago
parent c88947ad3d
commit 87939e6fab

@ -57,6 +57,8 @@ except ImportError as e:
has_functorch = False has_functorch = False
if torch.cuda.is_available():
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.benchmark = True torch.backends.cudnn.benchmark = True
_logger = logging.getLogger('validate') _logger = logging.getLogger('validate')
@ -216,7 +218,7 @@ class BenchmarkRunner:
self.device = device self.device = device
self.use_amp, self.model_dtype, self.data_dtype = resolve_precision(precision) self.use_amp, self.model_dtype, self.data_dtype = resolve_precision(precision)
self.channels_last = kwargs.pop('channels_last', False) self.channels_last = kwargs.pop('channels_last', False)
self.amp_autocast = torch.cuda.amp.autocast if self.use_amp else suppress self.amp_autocast = partial(torch.cuda.amp.autocast, dtype=torch.float16) if self.use_amp else suppress
if fuser: if fuser:
set_jit_fuser(fuser) set_jit_fuser(fuser)

@ -2,11 +2,11 @@
Hacked together by / Copyright 2019, Ross Wightman Hacked together by / Copyright 2019, Ross Wightman
""" """
import torch.utils.data as data import io
import os
import torch
import logging import logging
import torch
import torch.utils.data as data
from PIL import Image from PIL import Image
from .parsers import create_parser from .parsers import create_parser
@ -23,23 +23,32 @@ class ImageDataset(data.Dataset):
self, self,
root, root,
parser=None, parser=None,
split='train',
class_map=None, class_map=None,
load_bytes=False, load_bytes=False,
img_mode='RGB',
transform=None, transform=None,
target_transform=None, target_transform=None,
): ):
if parser is None or isinstance(parser, str): if parser is None or isinstance(parser, str):
parser = create_parser(parser or '', root=root, class_map=class_map) parser = create_parser(
parser or '',
root=root,
split=split,
class_map=class_map
)
self.parser = parser self.parser = parser
self.load_bytes = load_bytes self.load_bytes = load_bytes
self.img_mode = img_mode
self.transform = transform self.transform = transform
self.target_transform = target_transform self.target_transform = target_transform
self._consecutive_errors = 0 self._consecutive_errors = 0
def __getitem__(self, index): def __getitem__(self, index):
img, target = self.parser[index] img, target = self.parser[index]
try: try:
img = img.read() if self.load_bytes else Image.open(img).convert('RGB') img = img.read() if self.load_bytes else Image.open(img)
except Exception as e: except Exception as e:
_logger.warning(f'Skipped sample (index {index}, file {self.parser.filename(index)}). {str(e)}') _logger.warning(f'Skipped sample (index {index}, file {self.parser.filename(index)}). {str(e)}')
self._consecutive_errors += 1 self._consecutive_errors += 1
@ -48,12 +57,17 @@ class ImageDataset(data.Dataset):
else: else:
raise e raise e
self._consecutive_errors = 0 self._consecutive_errors = 0
if self.img_mode and not self.load_bytes:
img = img.convert(self.img_mode)
if self.transform is not None: if self.transform is not None:
img = self.transform(img) img = self.transform(img)
if target is None: if target is None:
target = -1 target = -1
elif self.target_transform is not None: elif self.target_transform is not None:
target = self.target_transform(target) target = self.target_transform(target)
return img, target return img, target
def __len__(self): def __len__(self):
@ -83,8 +97,14 @@ class IterableImageDataset(data.IterableDataset):
assert parser is not None assert parser is not None
if isinstance(parser, str): if isinstance(parser, str):
self.parser = create_parser( self.parser = create_parser(
parser, root=root, split=split, is_training=is_training, parser,
batch_size=batch_size, repeats=repeats, download=download) root=root,
split=split,
is_training=is_training,
batch_size=batch_size,
repeats=repeats,
download=download,
)
else: else:
self.parser = parser self.parser = parser
self.transform = transform self.transform = transform

@ -134,6 +134,10 @@ def create_dataset(
ds = IterableImageDataset( ds = IterableImageDataset(
root, parser=name, split=split, is_training=is_training, root, parser=name, split=split, is_training=is_training,
download=download, batch_size=batch_size, repeats=repeats, **kwargs) download=download, batch_size=batch_size, repeats=repeats, **kwargs)
elif name.startswith('hfds/'):
# NOTE right now, HF datasets default arrow format is a random-access Dataset,
# There will be a IterableDataset variant too, TBD
ds = ImageDataset(root, parser=name, split=split, **kwargs)
else: else:
# FIXME support more advance split cfg for ImageFolder/Tar datasets in the future # FIXME support more advance split cfg for ImageFolder/Tar datasets in the future
if search_split and os.path.isdir(root): if search_split and os.path.isdir(root):

@ -6,10 +6,12 @@ https://github.com/NVIDIA/apex/commit/d5e2bb4bdeedd27b1dfaf5bb2b24d6c000dee9be#d
Hacked together by / Copyright 2019, Ross Wightman Hacked together by / Copyright 2019, Ross Wightman
""" """
import random import random
from contextlib import suppress
from functools import partial from functools import partial
from itertools import repeat from itertools import repeat
from typing import Callable from typing import Callable
import torch
import torch.utils.data import torch.utils.data
import numpy as np import numpy as np
@ -73,6 +75,8 @@ class PrefetchLoader:
mean=IMAGENET_DEFAULT_MEAN, mean=IMAGENET_DEFAULT_MEAN,
std=IMAGENET_DEFAULT_STD, std=IMAGENET_DEFAULT_STD,
channels=3, channels=3,
device=torch.device('cuda'),
img_dtype=torch.float32,
fp16=False, fp16=False,
re_prob=0., re_prob=0.,
re_mode='const', re_mode='const',
@ -84,30 +88,42 @@ class PrefetchLoader:
normalization_shape = (1, channels, 1, 1) normalization_shape = (1, channels, 1, 1)
self.loader = loader self.loader = loader
self.mean = torch.tensor([x * 255 for x in mean]).cuda().view(normalization_shape) self.device = device
self.std = torch.tensor([x * 255 for x in std]).cuda().view(normalization_shape)
self.fp16 = fp16
if fp16: if fp16:
self.mean = self.mean.half() # fp16 arg is deprecated, but will override dtype arg if set for bwd compat
self.std = self.std.half() img_dtype = torch.float16
self.img_dtype = img_dtype
self.mean = torch.tensor(
[x * 255 for x in mean], device=device, dtype=img_dtype).view(normalization_shape)
self.std = torch.tensor(
[x * 255 for x in std], device=device, dtype=img_dtype).view(normalization_shape)
if re_prob > 0.: if re_prob > 0.:
self.random_erasing = RandomErasing( self.random_erasing = RandomErasing(
probability=re_prob, mode=re_mode, max_count=re_count, num_splits=re_num_splits) probability=re_prob,
mode=re_mode,
max_count=re_count,
num_splits=re_num_splits,
device=device,
)
else: else:
self.random_erasing = None self.random_erasing = None
self.is_cuda = torch.cuda.is_available() and device.type == 'cuda'
def __iter__(self): def __iter__(self):
stream = torch.cuda.Stream()
first = True first = True
if self.is_cuda:
stream = torch.cuda.Stream()
stream_context = partial(torch.cuda.stream, stream=stream)
else:
stream = None
stream_context = suppress
for next_input, next_target in self.loader: for next_input, next_target in self.loader:
with torch.cuda.stream(stream):
next_input = next_input.cuda(non_blocking=True) with stream_context():
next_target = next_target.cuda(non_blocking=True) next_input = next_input.to(device=self.device, non_blocking=True)
if self.fp16: next_target = next_target.to(device=self.device, non_blocking=True)
next_input = next_input.half().sub_(self.mean).div_(self.std) next_input = next_input.to(self.img_dtype).sub_(self.mean).div_(self.std)
else:
next_input = next_input.float().sub_(self.mean).div_(self.std)
if self.random_erasing is not None: if self.random_erasing is not None:
next_input = self.random_erasing(next_input) next_input = self.random_erasing(next_input)
@ -116,7 +132,9 @@ class PrefetchLoader:
else: else:
first = False first = False
if stream is not None:
torch.cuda.current_stream().wait_stream(stream) torch.cuda.current_stream().wait_stream(stream)
input = next_input input = next_input
target = next_target target = next_target
@ -189,7 +207,9 @@ def create_loader(
crop_pct=None, crop_pct=None,
collate_fn=None, collate_fn=None,
pin_memory=False, pin_memory=False,
fp16=False, fp16=False, # deprecated, use img_dtype
img_dtype=torch.float32,
device=torch.device('cuda'),
tf_preprocessing=False, tf_preprocessing=False,
use_multi_epochs_loader=False, use_multi_epochs_loader=False,
persistent_workers=True, persistent_workers=True,
@ -266,7 +286,9 @@ def create_loader(
mean=mean, mean=mean,
std=std, std=std,
channels=input_size[0], channels=input_size[0],
fp16=fp16, device=device,
fp16=fp16, # deprecated, use img_dtype
img_dtype=img_dtype,
re_prob=prefetch_re_prob, re_prob=prefetch_re_prob,
re_mode=re_mode, re_mode=re_mode,
re_count=re_count, re_count=re_count,

@ -17,6 +17,9 @@ def create_parser(name, root, split='train', **kwargs):
if prefix == 'tfds': if prefix == 'tfds':
from .parser_tfds import ParserTfds # defer tensorflow import from .parser_tfds import ParserTfds # defer tensorflow import
parser = ParserTfds(root, name, split=split, **kwargs) parser = ParserTfds(root, name, split=split, **kwargs)
elif prefix == 'hfds':
from .parser_hfds import ParserHfds # defer tensorflow import
parser = ParserHfds(root, name, split=split, **kwargs)
else: else:
assert os.path.exists(root) assert os.path.exists(root)
# default fallback path (backwards compat), use image tar if root is a .tar file, otherwise image folder # default fallback path (backwards compat), use image tar if root is a .tar file, otherwise image folder

@ -86,9 +86,9 @@ class ParserTfds(Parser):
repeats=0, repeats=0,
seed=42, seed=42,
input_name='image', input_name='image',
input_image='RGB', input_img_mode='RGB',
target_name='label', target_name='label',
target_image='', target_img_mode='',
prefetch_size=None, prefetch_size=None,
shuffle_size=None, shuffle_size=None,
max_threadpool_size=None max_threadpool_size=None
@ -105,9 +105,9 @@ class ParserTfds(Parser):
repeats: iterate through (repeat) the dataset this many times per iteration (once if 0 or 1) repeats: iterate through (repeat) the dataset this many times per iteration (once if 0 or 1)
seed: common seed for shard shuffle across all distributed/worker instances seed: common seed for shard shuffle across all distributed/worker instances
input_name: name of Feature to return as data (input) input_name: name of Feature to return as data (input)
input_image: image mode if input is an image (currently PIL mode string) input_img_mode: image mode if input is an image (currently PIL mode string)
target_name: name of Feature to return as target (label) target_name: name of Feature to return as target (label)
target_image: image mode if target is an image (currently PIL mode string) target_img_mode: image mode if target is an image (currently PIL mode string)
prefetch_size: override default tf.data prefetch buffer size prefetch_size: override default tf.data prefetch buffer size
shuffle_size: override default tf.data shuffle buffer size shuffle_size: override default tf.data shuffle buffer size
max_threadpool_size: override default threadpool size for tf.data max_threadpool_size: override default threadpool size for tf.data
@ -130,9 +130,9 @@ class ParserTfds(Parser):
# TFDS builder and split information # TFDS builder and split information
self.input_name = input_name # FIXME support tuples / lists of inputs and targets and full range of Feature self.input_name = input_name # FIXME support tuples / lists of inputs and targets and full range of Feature
self.input_image = input_image self.input_img_mode = input_img_mode
self.target_name = target_name self.target_name = target_name
self.target_image = target_image self.target_img_mode = target_img_mode
self.builder = tfds.builder(name, data_dir=root) self.builder = tfds.builder(name, data_dir=root)
# NOTE: the tfds command line app can be used download & prepare datasets if you don't enable download flag # NOTE: the tfds command line app can be used download & prepare datasets if you don't enable download flag
if download: if download:
@ -249,11 +249,11 @@ class ParserTfds(Parser):
example_count = 0 example_count = 0
for example in self.ds: for example in self.ds:
input_data = example[self.input_name] input_data = example[self.input_name]
if self.input_image: if self.input_img_mode:
input_data = Image.fromarray(input_data, mode=self.input_image) input_data = Image.fromarray(input_data, mode=self.input_img_mode)
target_data = example[self.target_name] target_data = example[self.target_name]
if self.target_image: if self.target_img_mode:
target_data = Image.fromarray(target_data, mode=self.target_image) target_data = Image.fromarray(target_data, mode=self.target_img_mode)
yield input_data, target_data yield input_data, target_data
example_count += 1 example_count += 1
if self.is_training and example_count >= target_example_count: if self.is_training and example_count >= target_example_count:

@ -7,6 +7,7 @@ Hacked together by / Copyright 2019, Ross Wightman
""" """
import random import random
import math import math
import torch import torch
@ -44,8 +45,17 @@ class RandomErasing:
def __init__( def __init__(
self, self,
probability=0.5, min_area=0.02, max_area=1/3, min_aspect=0.3, max_aspect=None, probability=0.5,
mode='const', min_count=1, max_count=None, num_splits=0, device='cuda'): min_area=0.02,
max_area=1/3,
min_aspect=0.3,
max_aspect=None,
mode='const',
min_count=1,
max_count=None,
num_splits=0,
device='cuda',
):
self.probability = probability self.probability = probability
self.min_area = min_area self.min_area = min_area
self.max_area = max_area self.max_area = max_area
@ -81,8 +91,12 @@ class RandomErasing:
top = random.randint(0, img_h - h) top = random.randint(0, img_h - h)
left = random.randint(0, img_w - w) left = random.randint(0, img_w - w)
img[:, top:top + h, left:left + w] = _get_pixels( img[:, top:top + h, left:left + w] = _get_pixels(
self.per_pixel, self.rand_color, (chan, h, w), self.per_pixel,
dtype=dtype, device=self.device) self.rand_color,
(chan, h, w),
dtype=dtype,
device=self.device,
)
break break
def __call__(self, input): def __call__(self, input):

@ -3,7 +3,8 @@ from .checkpoint_saver import CheckpointSaver
from .clip_grad import dispatch_clip_grad from .clip_grad import dispatch_clip_grad
from .cuda import ApexScaler, NativeScaler from .cuda import ApexScaler, NativeScaler
from .decay_batch import decay_batch_step, check_batch_size_retry from .decay_batch import decay_batch_step, check_batch_size_retry
from .distributed import distribute_bn, reduce_tensor from .distributed import distribute_bn, reduce_tensor, init_distributed_device,\
world_info_from_env, is_distributed_env, is_primary
from .jit import set_jit_legacy, set_jit_fuser from .jit import set_jit_legacy, set_jit_fuser
from .log import setup_default_logging, FormatterNoInfo from .log import setup_default_logging, FormatterNoInfo
from .metrics import AverageMeter, accuracy from .metrics import AverageMeter, accuracy

@ -2,9 +2,16 @@
Hacked together by / Copyright 2020 Ross Wightman Hacked together by / Copyright 2020 Ross Wightman
""" """
import os
import torch import torch
from torch import distributed as dist from torch import distributed as dist
try:
import horovod.torch as hvd
except ImportError:
hvd = None
from .model import unwrap_model from .model import unwrap_model
@ -26,3 +33,105 @@ def distribute_bn(model, world_size, reduce=False):
else: else:
# broadcast bn stats from rank 0 to whole group # broadcast bn stats from rank 0 to whole group
torch.distributed.broadcast(bn_buf, 0) torch.distributed.broadcast(bn_buf, 0)
def is_global_primary(args):
return args.rank == 0
def is_local_primary(args):
return args.local_rank == 0
def is_primary(args, local=False):
return is_local_primary(args) if local else is_global_primary(args)
def is_distributed_env():
if 'WORLD_SIZE' in os.environ:
return int(os.environ['WORLD_SIZE']) > 1
if 'SLURM_NTASKS' in os.environ:
return int(os.environ['SLURM_NTASKS']) > 1
return False
def world_info_from_env():
local_rank = 0
for v in ('LOCAL_RANK', 'MPI_LOCALRANKID', 'SLURM_LOCALID', 'OMPI_COMM_WORLD_LOCAL_RANK'):
if v in os.environ:
local_rank = int(os.environ[v])
break
global_rank = 0
for v in ('RANK', 'PMI_RANK', 'SLURM_PROCID', 'OMPI_COMM_WORLD_RANK'):
if v in os.environ:
global_rank = int(os.environ[v])
break
world_size = 1
for v in ('WORLD_SIZE', 'PMI_SIZE', 'SLURM_NTASKS', 'OMPI_COMM_WORLD_SIZE'):
if v in os.environ:
world_size = int(os.environ[v])
break
return local_rank, global_rank, world_size
def init_distributed_device(args):
# Distributed training = training on more than one GPU.
# Works in both single and multi-node scenarios.
args.distributed = False
args.world_size = 1
args.rank = 0 # global rank
args.local_rank = 0
# TBD, support horovod?
# if args.horovod:
# assert hvd is not None, "Horovod is not installed"
# hvd.init()
# args.local_rank = int(hvd.local_rank())
# args.rank = hvd.rank()
# args.world_size = hvd.size()
# args.distributed = True
# os.environ['LOCAL_RANK'] = str(args.local_rank)
# os.environ['RANK'] = str(args.rank)
# os.environ['WORLD_SIZE'] = str(args.world_size)
dist_backend = getattr(args, 'dist_backend', 'nccl')
dist_url = getattr(args, 'dist_url', 'env://')
if is_distributed_env():
if 'SLURM_PROCID' in os.environ:
# DDP via SLURM
args.local_rank, args.rank, args.world_size = world_info_from_env()
# SLURM var -> torch.distributed vars in case needed
os.environ['LOCAL_RANK'] = str(args.local_rank)
os.environ['RANK'] = str(args.rank)
os.environ['WORLD_SIZE'] = str(args.world_size)
torch.distributed.init_process_group(
backend=dist_backend,
init_method=dist_url,
world_size=args.world_size,
rank=args.rank,
)
else:
# DDP via torchrun, torch.distributed.launch
args.local_rank, _, _ = world_info_from_env()
torch.distributed.init_process_group(
backend=dist_backend,
init_method=dist_url,
)
args.world_size = torch.distributed.get_world_size()
args.rank = torch.distributed.get_rank()
args.distributed = True
if torch.cuda.is_available():
if args.distributed:
device = 'cuda:%d' % args.local_rank
else:
device = 'cuda:0'
torch.cuda.set_device(device)
else:
device = 'cpu'
args.device = device
device = torch.device(device)
return device

@ -21,6 +21,7 @@ import time
from collections import OrderedDict from collections import OrderedDict
from contextlib import suppress from contextlib import suppress
from datetime import datetime from datetime import datetime
from functools import partial
import torch import torch
import torch.nn as nn import torch.nn as nn
@ -66,7 +67,6 @@ except ImportError as e:
has_functorch = False has_functorch = False
torch.backends.cudnn.benchmark = True
_logger = logging.getLogger('train') _logger = logging.getLogger('train')
# The first arg parser parses out only the --config argument, this argument is used to # The first arg parser parses out only the --config argument, this argument is used to
@ -349,32 +349,26 @@ def main():
utils.setup_default_logging() utils.setup_default_logging()
args, args_text = _parse_args() args, args_text = _parse_args()
if torch.cuda.is_available():
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.benchmark = True
args.prefetcher = not args.no_prefetcher args.prefetcher = not args.no_prefetcher
args.distributed = False device = utils.init_distributed_device(args)
if 'WORLD_SIZE' in os.environ:
args.distributed = int(os.environ['WORLD_SIZE']) > 1
args.device = 'cuda:0'
args.world_size = 1
args.rank = 0 # global rank
if args.distributed: if args.distributed:
if 'LOCAL_RANK' in os.environ: _logger.info(
args.local_rank = int(os.getenv('LOCAL_RANK')) 'Training in distributed mode with multiple processes, 1 device per process.'
args.device = 'cuda:%d' % args.local_rank f'Process {args.rank}, total {args.world_size}, device {args.device}.')
torch.cuda.set_device(args.local_rank)
torch.distributed.init_process_group(backend='nccl', init_method='env://')
args.world_size = torch.distributed.get_world_size()
args.rank = torch.distributed.get_rank()
_logger.info('Training in distributed mode with multiple processes, 1 GPU per process. Process %d, total %d.'
% (args.rank, args.world_size))
else: else:
_logger.info('Training with a single process on 1 GPUs.') _logger.info(f'Training with a single process on 1 device ({args.device}).')
assert args.rank >= 0 assert args.rank >= 0
if args.rank == 0 and args.log_wandb: if utils.is_primary(args) and args.log_wandb:
if has_wandb: if has_wandb:
wandb.init(project=args.experiment, config=args) wandb.init(project=args.experiment, config=args)
else: else:
_logger.warning("You've requested to log metrics to wandb but package not found. " _logger.warning(
"You've requested to log metrics to wandb but package not found. "
"Metrics not being logged to wandb, try `pip install wandb`") "Metrics not being logged to wandb, try `pip install wandb`")
# resolve AMP arguments based on PyTorch / Apex availability # resolve AMP arguments based on PyTorch / Apex availability
@ -405,14 +399,14 @@ def main():
pretrained=args.pretrained, pretrained=args.pretrained,
num_classes=args.num_classes, num_classes=args.num_classes,
drop_rate=args.drop, drop_rate=args.drop,
drop_connect_rate=args.drop_connect, # DEPRECATED, use drop_path
drop_path_rate=args.drop_path, drop_path_rate=args.drop_path,
drop_block_rate=args.drop_block, drop_block_rate=args.drop_block,
global_pool=args.gp, global_pool=args.gp,
bn_momentum=args.bn_momentum, bn_momentum=args.bn_momentum,
bn_eps=args.bn_eps, bn_eps=args.bn_eps,
scriptable=args.torchscript, scriptable=args.torchscript,
checkpoint_path=args.initial_checkpoint) checkpoint_path=args.initial_checkpoint,
)
if args.num_classes is None: if args.num_classes is None:
assert hasattr(model, 'num_classes'), 'Model must have `num_classes` attr if not set on cmd line/config.' assert hasattr(model, 'num_classes'), 'Model must have `num_classes` attr if not set on cmd line/config.'
args.num_classes = model.num_classes # FIXME handle model default vs config num_classes more elegantly args.num_classes = model.num_classes # FIXME handle model default vs config num_classes more elegantly
@ -420,11 +414,11 @@ def main():
if args.grad_checkpointing: if args.grad_checkpointing:
model.set_grad_checkpointing(enable=True) model.set_grad_checkpointing(enable=True)
if args.local_rank == 0: if utils.is_primary(args):
_logger.info( _logger.info(
f'Model {safe_model_name(args.model)} created, param count:{sum([m.numel() for m in model.parameters()])}') f'Model {safe_model_name(args.model)} created, param count:{sum([m.numel() for m in model.parameters()])}')
data_config = resolve_data_config(vars(args), model=model, verbose=args.local_rank == 0) data_config = resolve_data_config(vars(args), model=model, verbose=utils.is_primary(args))
# setup augmentation batch splits for contrastive loss or split bn # setup augmentation batch splits for contrastive loss or split bn
num_aug_splits = 0 num_aug_splits = 0
@ -438,9 +432,9 @@ def main():
model = convert_splitbn_model(model, max(num_aug_splits, 2)) model = convert_splitbn_model(model, max(num_aug_splits, 2))
# move model to GPU, enable channels last layout if set # move model to GPU, enable channels last layout if set
model.cuda() model.to(device=device)
if args.channels_last: if args.channels_last:
model = model.to(memory_format=torch.channels_last) model.to(memory_format=torch.channels_last)
# setup synchronized BatchNorm for distributed training # setup synchronized BatchNorm for distributed training
if args.distributed and args.sync_bn: if args.distributed and args.sync_bn:
@ -452,7 +446,7 @@ def main():
model = convert_syncbn_model(model) model = convert_syncbn_model(model)
else: else:
model = convert_sync_batchnorm(model) model = convert_sync_batchnorm(model)
if args.local_rank == 0: if utils.is_primary(args):
_logger.info( _logger.info(
'Converted model to use Synchronized BatchNorm. WARNING: You may have issues if using ' 'Converted model to use Synchronized BatchNorm. WARNING: You may have issues if using '
'zero initialized BN layers (enabled by default for ResNets) while sync-bn enabled.') 'zero initialized BN layers (enabled by default for ResNets) while sync-bn enabled.')
@ -461,6 +455,7 @@ def main():
assert not use_amp == 'apex', 'Cannot use APEX AMP with torchscripted model' assert not use_amp == 'apex', 'Cannot use APEX AMP with torchscripted model'
assert not args.sync_bn, 'Cannot use SyncBatchNorm with torchscripted model' assert not args.sync_bn, 'Cannot use SyncBatchNorm with torchscripted model'
model = torch.jit.script(model) model = torch.jit.script(model)
if args.aot_autograd: if args.aot_autograd:
assert has_functorch, "functorch is needed for --aot-autograd" assert has_functorch, "functorch is needed for --aot-autograd"
model = memory_efficient_fusion(model) model = memory_efficient_fusion(model)
@ -471,28 +466,31 @@ def main():
amp_autocast = suppress # do nothing amp_autocast = suppress # do nothing
loss_scaler = None loss_scaler = None
if use_amp == 'apex': if use_amp == 'apex':
assert device.type == 'cuda'
model, optimizer = amp.initialize(model, optimizer, opt_level='O1') model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
loss_scaler = ApexScaler() loss_scaler = ApexScaler()
if args.local_rank == 0: if utils.is_primary(args):
_logger.info('Using NVIDIA APEX AMP. Training in mixed precision.') _logger.info('Using NVIDIA APEX AMP. Training in mixed precision.')
elif use_amp == 'native': elif use_amp == 'native':
amp_autocast = torch.cuda.amp.autocast amp_autocast = partial(torch.autocast, device_type=device.type)
if device.type == 'cuda':
loss_scaler = NativeScaler() loss_scaler = NativeScaler()
if args.local_rank == 0: if utils.is_primary(args):
_logger.info('Using native Torch AMP. Training in mixed precision.') _logger.info('Using native Torch AMP. Training in mixed precision.')
else: else:
if args.local_rank == 0: if utils.is_primary(args):
_logger.info('AMP not enabled. Training in float32.') _logger.info('AMP not enabled. Training in float32.')
# optionally resume from a checkpoint # optionally resume from a checkpoint
resume_epoch = None resume_epoch = None
if args.resume: if args.resume:
resume_epoch = resume_checkpoint( resume_epoch = resume_checkpoint(
model, args.resume, model,
args.resume,
optimizer=None if args.no_resume_opt else optimizer, optimizer=None if args.no_resume_opt else optimizer,
loss_scaler=None if args.no_resume_opt else loss_scaler, loss_scaler=None if args.no_resume_opt else loss_scaler,
log_info=args.local_rank == 0) log_info=utils.is_primary(args),
)
# setup exponential moving average of model weights, SWA could be used here too # setup exponential moving average of model weights, SWA could be used here too
model_ema = None model_ema = None
@ -507,13 +505,13 @@ def main():
if args.distributed: if args.distributed:
if has_apex and use_amp == 'apex': if has_apex and use_amp == 'apex':
# Apex DDP preferred unless native amp is activated # Apex DDP preferred unless native amp is activated
if args.local_rank == 0: if utils.is_primary(args):
_logger.info("Using NVIDIA APEX DistributedDataParallel.") _logger.info("Using NVIDIA APEX DistributedDataParallel.")
model = ApexDDP(model, delay_allreduce=True) model = ApexDDP(model, delay_allreduce=True)
else: else:
if args.local_rank == 0: if utils.is_primary(args):
_logger.info("Using native Torch DistributedDataParallel.") _logger.info("Using native Torch DistributedDataParallel.")
model = NativeDDP(model, device_ids=[args.local_rank], broadcast_buffers=not args.no_ddp_bb) model = NativeDDP(model, device_ids=[device], broadcast_buffers=not args.no_ddp_bb)
# NOTE: EMA model does not need to be wrapped by DDP # NOTE: EMA model does not need to be wrapped by DDP
# setup learning rate schedule and starting epoch # setup learning rate schedule and starting epoch
@ -527,21 +525,30 @@ def main():
if lr_scheduler is not None and start_epoch > 0: if lr_scheduler is not None and start_epoch > 0:
lr_scheduler.step(start_epoch) lr_scheduler.step(start_epoch)
if args.local_rank == 0: if utils.is_primary(args):
_logger.info('Scheduled epochs: {}'.format(num_epochs)) _logger.info('Scheduled epochs: {}'.format(num_epochs))
# create the train and eval datasets # create the train and eval datasets
dataset_train = create_dataset( dataset_train = create_dataset(
args.dataset, root=args.data_dir, split=args.train_split, is_training=True, args.dataset,
root=args.data_dir,
split=args.train_split,
is_training=True,
class_map=args.class_map, class_map=args.class_map,
download=args.dataset_download, download=args.dataset_download,
batch_size=args.batch_size, batch_size=args.batch_size,
repeats=args.epoch_repeats) repeats=args.epoch_repeats
)
dataset_eval = create_dataset( dataset_eval = create_dataset(
args.dataset, root=args.data_dir, split=args.val_split, is_training=False, args.dataset,
root=args.data_dir,
split=args.val_split,
is_training=False,
class_map=args.class_map, class_map=args.class_map,
download=args.dataset_download, download=args.dataset_download,
batch_size=args.batch_size) batch_size=args.batch_size
)
# setup mixup / cutmix # setup mixup / cutmix
collate_fn = None collate_fn = None
@ -549,9 +556,15 @@ def main():
mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None
if mixup_active: if mixup_active:
mixup_args = dict( mixup_args = dict(
mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax, mixup_alpha=args.mixup,
prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode, cutmix_alpha=args.cutmix,
label_smoothing=args.smoothing, num_classes=args.num_classes) cutmix_minmax=args.cutmix_minmax,
prob=args.mixup_prob,
switch_prob=args.mixup_switch_prob,
mode=args.mixup_mode,
label_smoothing=args.smoothing,
num_classes=args.num_classes
)
if args.prefetcher: if args.prefetcher:
assert not num_aug_splits # collate conflict (need to support deinterleaving in collate mixup) assert not num_aug_splits # collate conflict (need to support deinterleaving in collate mixup)
collate_fn = FastCollateMixup(**mixup_args) collate_fn = FastCollateMixup(**mixup_args)
@ -592,6 +605,7 @@ def main():
distributed=args.distributed, distributed=args.distributed,
collate_fn=collate_fn, collate_fn=collate_fn,
pin_memory=args.pin_mem, pin_memory=args.pin_mem,
device=device,
use_multi_epochs_loader=args.use_multi_epochs_loader, use_multi_epochs_loader=args.use_multi_epochs_loader,
worker_seeding=args.worker_seeding, worker_seeding=args.worker_seeding,
) )
@ -609,6 +623,7 @@ def main():
distributed=args.distributed, distributed=args.distributed,
crop_pct=data_config['crop_pct'], crop_pct=data_config['crop_pct'],
pin_memory=args.pin_mem, pin_memory=args.pin_mem,
device=device,
) )
# setup loss function # setup loss function
@ -628,8 +643,8 @@ def main():
train_loss_fn = LabelSmoothingCrossEntropy(smoothing=args.smoothing) train_loss_fn = LabelSmoothingCrossEntropy(smoothing=args.smoothing)
else: else:
train_loss_fn = nn.CrossEntropyLoss() train_loss_fn = nn.CrossEntropyLoss()
train_loss_fn = train_loss_fn.cuda() train_loss_fn = train_loss_fn.to(device=device)
validate_loss_fn = nn.CrossEntropyLoss().cuda() validate_loss_fn = nn.CrossEntropyLoss().to(device=device)
# setup checkpoint saver and eval metric tracking # setup checkpoint saver and eval metric tracking
eval_metric = args.eval_metric eval_metric = args.eval_metric
@ -637,7 +652,7 @@ def main():
best_epoch = None best_epoch = None
saver = None saver = None
output_dir = None output_dir = None
if args.rank == 0: if utils.is_primary(args):
if args.experiment: if args.experiment:
exp_name = args.experiment exp_name = args.experiment
else: else:
@ -649,8 +664,16 @@ def main():
output_dir = utils.get_outdir(args.output if args.output else './output/train', exp_name) output_dir = utils.get_outdir(args.output if args.output else './output/train', exp_name)
decreasing = True if eval_metric == 'loss' else False decreasing = True if eval_metric == 'loss' else False
saver = utils.CheckpointSaver( saver = utils.CheckpointSaver(
model=model, optimizer=optimizer, args=args, model_ema=model_ema, amp_scaler=loss_scaler, model=model,
checkpoint_dir=output_dir, recovery_dir=output_dir, decreasing=decreasing, max_history=args.checkpoint_hist) optimizer=optimizer,
args=args,
model_ema=model_ema,
amp_scaler=loss_scaler,
checkpoint_dir=output_dir,
recovery_dir=output_dir,
decreasing=decreasing,
max_history=args.checkpoint_hist
)
with open(os.path.join(output_dir, 'args.yaml'), 'w') as f: with open(os.path.join(output_dir, 'args.yaml'), 'w') as f:
f.write(args_text) f.write(args_text)
@ -660,22 +683,46 @@ def main():
loader_train.sampler.set_epoch(epoch) loader_train.sampler.set_epoch(epoch)
train_metrics = train_one_epoch( train_metrics = train_one_epoch(
epoch, model, loader_train, optimizer, train_loss_fn, args, epoch,
lr_scheduler=lr_scheduler, saver=saver, output_dir=output_dir, model,
amp_autocast=amp_autocast, loss_scaler=loss_scaler, model_ema=model_ema, mixup_fn=mixup_fn) loader_train,
optimizer,
train_loss_fn,
args,
lr_scheduler=lr_scheduler,
saver=saver,
output_dir=output_dir,
amp_autocast=amp_autocast,
loss_scaler=loss_scaler,
model_ema=model_ema,
mixup_fn=mixup_fn,
)
if args.distributed and args.dist_bn in ('broadcast', 'reduce'): if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
if args.local_rank == 0: if utils.is_primary(args):
_logger.info("Distributing BatchNorm running means and vars") _logger.info("Distributing BatchNorm running means and vars")
utils.distribute_bn(model, args.world_size, args.dist_bn == 'reduce') utils.distribute_bn(model, args.world_size, args.dist_bn == 'reduce')
eval_metrics = validate(model, loader_eval, validate_loss_fn, args, amp_autocast=amp_autocast) eval_metrics = validate(
model,
loader_eval,
validate_loss_fn,
args,
amp_autocast=amp_autocast,
)
if model_ema is not None and not args.model_ema_force_cpu: if model_ema is not None and not args.model_ema_force_cpu:
if args.distributed and args.dist_bn in ('broadcast', 'reduce'): if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
utils.distribute_bn(model_ema, args.world_size, args.dist_bn == 'reduce') utils.distribute_bn(model_ema, args.world_size, args.dist_bn == 'reduce')
ema_eval_metrics = validate( ema_eval_metrics = validate(
model_ema.module, loader_eval, validate_loss_fn, args, amp_autocast=amp_autocast, log_suffix=' (EMA)') model_ema.module,
loader_eval,
validate_loss_fn,
args,
amp_autocast=amp_autocast,
log_suffix=' (EMA)',
)
eval_metrics = ema_eval_metrics eval_metrics = ema_eval_metrics
if lr_scheduler is not None: if lr_scheduler is not None:
@ -684,8 +731,13 @@ def main():
if output_dir is not None: if output_dir is not None:
utils.update_summary( utils.update_summary(
epoch, train_metrics, eval_metrics, os.path.join(output_dir, 'summary.csv'), epoch,
write_header=best_metric is None, log_wandb=args.log_wandb and has_wandb) train_metrics,
eval_metrics,
os.path.join(output_dir, 'summary.csv'),
write_header=best_metric is None,
log_wandb=args.log_wandb and has_wandb,
)
if saver is not None: if saver is not None:
# save proper checkpoint with eval metric # save proper checkpoint with eval metric
@ -699,10 +751,21 @@ def main():
def train_one_epoch( def train_one_epoch(
epoch, model, loader, optimizer, loss_fn, args, epoch,
lr_scheduler=None, saver=None, output_dir=None, amp_autocast=suppress, model,
loss_scaler=None, model_ema=None, mixup_fn=None): loader,
optimizer,
loss_fn,
args,
device=torch.device('cuda'),
lr_scheduler=None,
saver=None,
output_dir=None,
amp_autocast=suppress,
loss_scaler=None,
model_ema=None,
mixup_fn=None
):
if args.mixup_off_epoch and epoch >= args.mixup_off_epoch: if args.mixup_off_epoch and epoch >= args.mixup_off_epoch:
if args.prefetcher and loader.mixup_enabled: if args.prefetcher and loader.mixup_enabled:
loader.mixup_enabled = False loader.mixup_enabled = False
@ -723,7 +786,7 @@ def train_one_epoch(
last_batch = batch_idx == last_idx last_batch = batch_idx == last_idx
data_time_m.update(time.time() - end) data_time_m.update(time.time() - end)
if not args.prefetcher: if not args.prefetcher:
input, target = input.cuda(), target.cuda() input, target = input.to(device), target.to(device)
if mixup_fn is not None: if mixup_fn is not None:
input, target = mixup_fn(input, target) input, target = mixup_fn(input, target)
if args.channels_last: if args.channels_last:
@ -740,21 +803,26 @@ def train_one_epoch(
if loss_scaler is not None: if loss_scaler is not None:
loss_scaler( loss_scaler(
loss, optimizer, loss, optimizer,
clip_grad=args.clip_grad, clip_mode=args.clip_mode, clip_grad=args.clip_grad,
clip_mode=args.clip_mode,
parameters=model_parameters(model, exclude_head='agc' in args.clip_mode), parameters=model_parameters(model, exclude_head='agc' in args.clip_mode),
create_graph=second_order) create_graph=second_order
)
else: else:
loss.backward(create_graph=second_order) loss.backward(create_graph=second_order)
if args.clip_grad is not None: if args.clip_grad is not None:
utils.dispatch_clip_grad( utils.dispatch_clip_grad(
model_parameters(model, exclude_head='agc' in args.clip_mode), model_parameters(model, exclude_head='agc' in args.clip_mode),
value=args.clip_grad, mode=args.clip_mode) value=args.clip_grad,
mode=args.clip_mode
)
optimizer.step() optimizer.step()
if model_ema is not None: if model_ema is not None:
model_ema.update(model) model_ema.update(model)
torch.cuda.synchronize() torch.cuda.synchronize()
num_updates += 1 num_updates += 1
batch_time_m.update(time.time() - end) batch_time_m.update(time.time() - end)
if last_batch or batch_idx % args.log_interval == 0: if last_batch or batch_idx % args.log_interval == 0:
@ -765,7 +833,7 @@ def train_one_epoch(
reduced_loss = utils.reduce_tensor(loss.data, args.world_size) reduced_loss = utils.reduce_tensor(loss.data, args.world_size)
losses_m.update(reduced_loss.item(), input.size(0)) losses_m.update(reduced_loss.item(), input.size(0))
if args.local_rank == 0: if utils.is_primary(args):
_logger.info( _logger.info(
'Train: {} [{:>4d}/{} ({:>3.0f}%)] ' 'Train: {} [{:>4d}/{} ({:>3.0f}%)] '
'Loss: {loss.val:#.4g} ({loss.avg:#.3g}) ' 'Loss: {loss.val:#.4g} ({loss.avg:#.3g}) '
@ -781,14 +849,16 @@ def train_one_epoch(
rate=input.size(0) * args.world_size / batch_time_m.val, rate=input.size(0) * args.world_size / batch_time_m.val,
rate_avg=input.size(0) * args.world_size / batch_time_m.avg, rate_avg=input.size(0) * args.world_size / batch_time_m.avg,
lr=lr, lr=lr,
data_time=data_time_m)) data_time=data_time_m)
)
if args.save_images and output_dir: if args.save_images and output_dir:
torchvision.utils.save_image( torchvision.utils.save_image(
input, input,
os.path.join(output_dir, 'train-batch-%d.jpg' % batch_idx), os.path.join(output_dir, 'train-batch-%d.jpg' % batch_idx),
padding=0, padding=0,
normalize=True) normalize=True
)
if saver is not None and args.recovery_interval and ( if saver is not None and args.recovery_interval and (
last_batch or (batch_idx + 1) % args.recovery_interval == 0): last_batch or (batch_idx + 1) % args.recovery_interval == 0):
@ -806,7 +876,15 @@ def train_one_epoch(
return OrderedDict([('loss', losses_m.avg)]) return OrderedDict([('loss', losses_m.avg)])
def validate(model, loader, loss_fn, args, amp_autocast=suppress, log_suffix=''): def validate(
model,
loader,
loss_fn,
args,
device=torch.device('cuda'),
amp_autocast=suppress,
log_suffix=''
):
batch_time_m = utils.AverageMeter() batch_time_m = utils.AverageMeter()
losses_m = utils.AverageMeter() losses_m = utils.AverageMeter()
top1_m = utils.AverageMeter() top1_m = utils.AverageMeter()
@ -820,8 +898,8 @@ def validate(model, loader, loss_fn, args, amp_autocast=suppress, log_suffix='')
for batch_idx, (input, target) in enumerate(loader): for batch_idx, (input, target) in enumerate(loader):
last_batch = batch_idx == last_idx last_batch = batch_idx == last_idx
if not args.prefetcher: if not args.prefetcher:
input = input.cuda() input = input.to(device)
target = target.cuda() target = target.to(device)
if args.channels_last: if args.channels_last:
input = input.contiguous(memory_format=torch.channels_last) input = input.contiguous(memory_format=torch.channels_last)
@ -846,6 +924,7 @@ def validate(model, loader, loss_fn, args, amp_autocast=suppress, log_suffix='')
else: else:
reduced_loss = loss.data reduced_loss = loss.data
if device.type == 'cuda':
torch.cuda.synchronize() torch.cuda.synchronize()
losses_m.update(reduced_loss.item(), input.size(0)) losses_m.update(reduced_loss.item(), input.size(0))
@ -854,7 +933,7 @@ def validate(model, loader, loss_fn, args, amp_autocast=suppress, log_suffix='')
batch_time_m.update(time.time() - end) batch_time_m.update(time.time() - end)
end = time.time() end = time.time()
if args.local_rank == 0 and (last_batch or batch_idx % args.log_interval == 0): if utils.is_primary(args) and (last_batch or batch_idx % args.log_interval == 0):
log_name = 'Test' + log_suffix log_name = 'Test' + log_suffix
_logger.info( _logger.info(
'{0}: [{1:>4d}/{2}] ' '{0}: [{1:>4d}/{2}] '
@ -862,8 +941,12 @@ def validate(model, loader, loss_fn, args, amp_autocast=suppress, log_suffix='')
'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f}) ' 'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f}) '
'Acc@1: {top1.val:>7.4f} ({top1.avg:>7.4f}) ' 'Acc@1: {top1.val:>7.4f} ({top1.avg:>7.4f}) '
'Acc@5: {top5.val:>7.4f} ({top5.avg:>7.4f})'.format( 'Acc@5: {top5.val:>7.4f} ({top5.avg:>7.4f})'.format(
log_name, batch_idx, last_idx, batch_time=batch_time_m, log_name, batch_idx, last_idx,
loss=losses_m, top1=top1_m, top5=top5_m)) batch_time=batch_time_m,
loss=losses_m,
top1=top1_m,
top5=top5_m)
)
metrics = OrderedDict([('loss', losses_m.avg), ('top1', top1_m.avg), ('top5', top5_m.avg)]) metrics = OrderedDict([('loss', losses_m.avg), ('top1', top1_m.avg), ('top5', top5_m.avg)])

@ -19,6 +19,7 @@ import torch.nn as nn
import torch.nn.parallel import torch.nn.parallel
from collections import OrderedDict from collections import OrderedDict
from contextlib import suppress from contextlib import suppress
from functools import partial
from timm.models import create_model, apply_test_time_pool, load_checkpoint, is_model, list_models, set_fast_norm from timm.models import create_model, apply_test_time_pool, load_checkpoint, is_model, list_models, set_fast_norm
from timm.data import create_dataset, create_loader, resolve_data_config, RealLabelsImagenet from timm.data import create_dataset, create_loader, resolve_data_config, RealLabelsImagenet
@ -45,7 +46,6 @@ try:
except ImportError as e: except ImportError as e:
has_functorch = False has_functorch = False
torch.backends.cudnn.benchmark = True
_logger = logging.getLogger('validate') _logger = logging.getLogger('validate')
@ -100,6 +100,8 @@ parser.add_argument('--pin-mem', action='store_true', default=False,
help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.') help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
parser.add_argument('--channels-last', action='store_true', default=False, parser.add_argument('--channels-last', action='store_true', default=False,
help='Use channels_last memory layout') help='Use channels_last memory layout')
parser.add_argument('--device', default='cuda', type=str,
help="Device (accelerator) to use.")
parser.add_argument('--amp', action='store_true', default=False, parser.add_argument('--amp', action='store_true', default=False,
help='Use AMP mixed precision. Defaults to Apex, fallback to native Torch AMP.') help='Use AMP mixed precision. Defaults to Apex, fallback to native Torch AMP.')
parser.add_argument('--apex-amp', action='store_true', default=False, parser.add_argument('--apex-amp', action='store_true', default=False,
@ -133,6 +135,13 @@ def validate(args):
# might as well try to validate something # might as well try to validate something
args.pretrained = args.pretrained or not args.checkpoint args.pretrained = args.pretrained or not args.checkpoint
args.prefetcher = not args.no_prefetcher args.prefetcher = not args.no_prefetcher
if torch.cuda.is_available():
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.benchmark = True
device = torch.device(args.device)
amp_autocast = suppress # do nothing amp_autocast = suppress # do nothing
if args.amp: if args.amp:
if has_native_amp: if has_native_amp:
@ -143,15 +152,17 @@ def validate(args):
_logger.warning("Neither APEX or Native Torch AMP is available.") _logger.warning("Neither APEX or Native Torch AMP is available.")
assert not args.apex_amp or not args.native_amp, "Only one AMP mode should be set." assert not args.apex_amp or not args.native_amp, "Only one AMP mode should be set."
if args.native_amp: if args.native_amp:
amp_autocast = torch.cuda.amp.autocast amp_autocast = partial(torch.autocast, device_type=device.type)
_logger.info('Validating in mixed precision with native PyTorch AMP.') _logger.info('Validating in mixed precision with native PyTorch AMP.')
elif args.apex_amp: elif args.apex_amp:
assert device.type == 'cuda'
_logger.info('Validating in mixed precision with NVIDIA APEX AMP.') _logger.info('Validating in mixed precision with NVIDIA APEX AMP.')
else: else:
_logger.info('Validating in float32. AMP not enabled.') _logger.info('Validating in float32. AMP not enabled.')
if args.fuser: if args.fuser:
set_jit_fuser(args.fuser) set_jit_fuser(args.fuser)
if args.fast_norm: if args.fast_norm:
set_fast_norm() set_fast_norm()
@ -162,7 +173,8 @@ def validate(args):
num_classes=args.num_classes, num_classes=args.num_classes,
in_chans=3, in_chans=3,
global_pool=args.gp, global_pool=args.gp,
scriptable=args.torchscript) scriptable=args.torchscript,
)
if args.num_classes is None: if args.num_classes is None:
assert hasattr(model, 'num_classes'), 'Model must have `num_classes` attr if not set on cmd line/config.' assert hasattr(model, 'num_classes'), 'Model must have `num_classes` attr if not set on cmd line/config.'
args.num_classes = model.num_classes args.num_classes = model.num_classes
@ -177,7 +189,7 @@ def validate(args):
vars(args), vars(args),
model=model, model=model,
use_test_size=not args.use_train_size, use_test_size=not args.use_train_size,
verbose=True verbose=True,
) )
test_time_pool = False test_time_pool = False
if args.test_pool: if args.test_pool:
@ -186,11 +198,12 @@ def validate(args):
if args.torchscript: if args.torchscript:
torch.jit.optimized_execution(True) torch.jit.optimized_execution(True)
model = torch.jit.script(model) model = torch.jit.script(model)
if args.aot_autograd: if args.aot_autograd:
assert has_functorch, "functorch is needed for --aot-autograd" assert has_functorch, "functorch is needed for --aot-autograd"
model = memory_efficient_fusion(model) model = memory_efficient_fusion(model)
model = model.cuda() model = model.to(device)
if args.apex_amp: if args.apex_amp:
model = amp.initialize(model, opt_level='O1') model = amp.initialize(model, opt_level='O1')
@ -200,11 +213,16 @@ def validate(args):
if args.num_gpu > 1: if args.num_gpu > 1:
model = torch.nn.DataParallel(model, device_ids=list(range(args.num_gpu))) model = torch.nn.DataParallel(model, device_ids=list(range(args.num_gpu)))
criterion = nn.CrossEntropyLoss().cuda() criterion = nn.CrossEntropyLoss().to(device)
dataset = create_dataset( dataset = create_dataset(
root=args.data, name=args.dataset, split=args.split, root=args.data,
download=args.dataset_download, load_bytes=args.tf_preprocessing, class_map=args.class_map) name=args.dataset,
split=args.split,
download=args.dataset_download,
load_bytes=args.tf_preprocessing,
class_map=args.class_map,
)
if args.valid_labels: if args.valid_labels:
with open(args.valid_labels, 'r') as f: with open(args.valid_labels, 'r') as f:
@ -230,7 +248,9 @@ def validate(args):
num_workers=args.workers, num_workers=args.workers,
crop_pct=crop_pct, crop_pct=crop_pct,
pin_memory=args.pin_mem, pin_memory=args.pin_mem,
tf_preprocessing=args.tf_preprocessing) device=device,
tf_preprocessing=args.tf_preprocessing,
)
batch_time = AverageMeter() batch_time = AverageMeter()
losses = AverageMeter() losses = AverageMeter()
@ -240,7 +260,7 @@ def validate(args):
model.eval() model.eval()
with torch.no_grad(): with torch.no_grad():
# warmup, reduce variability of first batch time, especially for comparing torchscript vs non # warmup, reduce variability of first batch time, especially for comparing torchscript vs non
input = torch.randn((args.batch_size,) + tuple(data_config['input_size'])).cuda() input = torch.randn((args.batch_size,) + tuple(data_config['input_size'])).to(device)
if args.channels_last: if args.channels_last:
input = input.contiguous(memory_format=torch.channels_last) input = input.contiguous(memory_format=torch.channels_last)
with amp_autocast(): with amp_autocast():
@ -249,8 +269,8 @@ def validate(args):
end = time.time() end = time.time()
for batch_idx, (input, target) in enumerate(loader): for batch_idx, (input, target) in enumerate(loader):
if args.no_prefetcher: if args.no_prefetcher:
target = target.cuda() target = target.to(device)
input = input.cuda() input = input.to(device)
if args.channels_last: if args.channels_last:
input = input.contiguous(memory_format=torch.channels_last) input = input.contiguous(memory_format=torch.channels_last)
@ -282,9 +302,15 @@ def validate(args):
'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f}) ' 'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f}) '
'Acc@1: {top1.val:>7.3f} ({top1.avg:>7.3f}) ' 'Acc@1: {top1.val:>7.3f} ({top1.avg:>7.3f}) '
'Acc@5: {top5.val:>7.3f} ({top5.avg:>7.3f})'.format( 'Acc@5: {top5.val:>7.3f} ({top5.avg:>7.3f})'.format(
batch_idx, len(loader), batch_time=batch_time, batch_idx,
len(loader),
batch_time=batch_time,
rate_avg=input.size(0) / batch_time.avg, rate_avg=input.size(0) / batch_time.avg,
loss=losses, top1=top1, top5=top5)) loss=losses,
top1=top1,
top5=top5
)
)
if real_labels is not None: if real_labels is not None:
# real labels mode replaces topk values at the end # real labels mode replaces topk values at the end
@ -298,7 +324,8 @@ def validate(args):
param_count=round(param_count / 1e6, 2), param_count=round(param_count / 1e6, 2),
img_size=data_config['input_size'][-1], img_size=data_config['input_size'][-1],
crop_pct=crop_pct, crop_pct=crop_pct,
interpolation=data_config['interpolation']) interpolation=data_config['interpolation'],
)
_logger.info(' * Acc@1 {:.3f} ({:.3f}) Acc@5 {:.3f} ({:.3f})'.format( _logger.info(' * Acc@1 {:.3f} ({:.3f}) Acc@5 {:.3f} ({:.3f})'.format(
results['top1'], results['top1_err'], results['top5'], results['top5_err'])) results['top1'], results['top1_err'], results['top5'], results['top5_err']))
@ -313,6 +340,7 @@ def _try_run(args, initial_batch_size):
while batch_size: while batch_size:
args.batch_size = batch_size * args.num_gpu # multiply by num-gpu for DataParallel case args.batch_size = batch_size * args.num_gpu # multiply by num-gpu for DataParallel case
try: try:
if torch.cuda.is_available() and 'cuda' in args.device:
torch.cuda.empty_cache() torch.cuda.empty_cache()
results = validate(args) results = validate(args)
return results return results

Loading…
Cancel
Save