From 87939e6fab4bf40bec7eeac7c6fbcd0a82294f17 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Fri, 23 Sep 2022 16:08:59 -0700 Subject: [PATCH] Refactor device handling in scripts, distributed init to be less 'cuda' centric. More device args passed through where needed. --- benchmark.py | 6 +- timm/data/dataset.py | 34 +++- timm/data/dataset_factory.py | 4 + timm/data/loader.py | 56 +++++-- timm/data/parsers/parser_factory.py | 3 + timm/data/parsers/parser_tfds.py | 20 +-- timm/data/random_erasing.py | 22 ++- timm/utils/__init__.py | 3 +- timm/utils/distributed.py | 109 +++++++++++++ train.py | 237 +++++++++++++++++++--------- validate.py | 60 +++++-- 11 files changed, 420 insertions(+), 134 deletions(-) diff --git a/benchmark.py b/benchmark.py index 4a89441b..a03c1982 100755 --- a/benchmark.py +++ b/benchmark.py @@ -57,7 +57,9 @@ except ImportError as e: has_functorch = False -torch.backends.cudnn.benchmark = True +if torch.cuda.is_available(): + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.benchmark = True _logger = logging.getLogger('validate') @@ -216,7 +218,7 @@ class BenchmarkRunner: self.device = device self.use_amp, self.model_dtype, self.data_dtype = resolve_precision(precision) self.channels_last = kwargs.pop('channels_last', False) - self.amp_autocast = torch.cuda.amp.autocast if self.use_amp else suppress + self.amp_autocast = partial(torch.cuda.amp.autocast, dtype=torch.float16) if self.use_amp else suppress if fuser: set_jit_fuser(fuser) diff --git a/timm/data/dataset.py b/timm/data/dataset.py index 20b663ce..0599c78a 100644 --- a/timm/data/dataset.py +++ b/timm/data/dataset.py @@ -2,11 +2,11 @@ Hacked together by / Copyright 2019, Ross Wightman """ -import torch.utils.data as data -import os -import torch +import io import logging +import torch +import torch.utils.data as data from PIL import Image from .parsers import create_parser @@ -23,23 +23,32 @@ class ImageDataset(data.Dataset): self, root, parser=None, + split='train', class_map=None, load_bytes=False, + img_mode='RGB', transform=None, target_transform=None, ): if parser is None or isinstance(parser, str): - parser = create_parser(parser or '', root=root, class_map=class_map) + parser = create_parser( + parser or '', + root=root, + split=split, + class_map=class_map + ) self.parser = parser self.load_bytes = load_bytes + self.img_mode = img_mode self.transform = transform self.target_transform = target_transform self._consecutive_errors = 0 def __getitem__(self, index): img, target = self.parser[index] + try: - img = img.read() if self.load_bytes else Image.open(img).convert('RGB') + img = img.read() if self.load_bytes else Image.open(img) except Exception as e: _logger.warning(f'Skipped sample (index {index}, file {self.parser.filename(index)}). {str(e)}') self._consecutive_errors += 1 @@ -48,12 +57,17 @@ class ImageDataset(data.Dataset): else: raise e self._consecutive_errors = 0 + + if self.img_mode and not self.load_bytes: + img = img.convert(self.img_mode) if self.transform is not None: img = self.transform(img) + if target is None: target = -1 elif self.target_transform is not None: target = self.target_transform(target) + return img, target def __len__(self): @@ -83,8 +97,14 @@ class IterableImageDataset(data.IterableDataset): assert parser is not None if isinstance(parser, str): self.parser = create_parser( - parser, root=root, split=split, is_training=is_training, - batch_size=batch_size, repeats=repeats, download=download) + parser, + root=root, + split=split, + is_training=is_training, + batch_size=batch_size, + repeats=repeats, + download=download, + ) else: self.parser = parser self.transform = transform diff --git a/timm/data/dataset_factory.py b/timm/data/dataset_factory.py index d0ac30b1..c2be63ad 100644 --- a/timm/data/dataset_factory.py +++ b/timm/data/dataset_factory.py @@ -134,6 +134,10 @@ def create_dataset( ds = IterableImageDataset( root, parser=name, split=split, is_training=is_training, download=download, batch_size=batch_size, repeats=repeats, **kwargs) + elif name.startswith('hfds/'): + # NOTE right now, HF datasets default arrow format is a random-access Dataset, + # There will be a IterableDataset variant too, TBD + ds = ImageDataset(root, parser=name, split=split, **kwargs) else: # FIXME support more advance split cfg for ImageFolder/Tar datasets in the future if search_split and os.path.isdir(root): diff --git a/timm/data/loader.py b/timm/data/loader.py index ecc075c0..a77e0a4c 100644 --- a/timm/data/loader.py +++ b/timm/data/loader.py @@ -6,10 +6,12 @@ https://github.com/NVIDIA/apex/commit/d5e2bb4bdeedd27b1dfaf5bb2b24d6c000dee9be#d Hacked together by / Copyright 2019, Ross Wightman """ import random +from contextlib import suppress from functools import partial from itertools import repeat from typing import Callable +import torch import torch.utils.data import numpy as np @@ -73,6 +75,8 @@ class PrefetchLoader: mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, channels=3, + device=torch.device('cuda'), + img_dtype=torch.float32, fp16=False, re_prob=0., re_mode='const', @@ -84,30 +88,42 @@ class PrefetchLoader: normalization_shape = (1, channels, 1, 1) self.loader = loader - self.mean = torch.tensor([x * 255 for x in mean]).cuda().view(normalization_shape) - self.std = torch.tensor([x * 255 for x in std]).cuda().view(normalization_shape) - self.fp16 = fp16 + self.device = device if fp16: - self.mean = self.mean.half() - self.std = self.std.half() + # fp16 arg is deprecated, but will override dtype arg if set for bwd compat + img_dtype = torch.float16 + self.img_dtype = img_dtype + self.mean = torch.tensor( + [x * 255 for x in mean], device=device, dtype=img_dtype).view(normalization_shape) + self.std = torch.tensor( + [x * 255 for x in std], device=device, dtype=img_dtype).view(normalization_shape) if re_prob > 0.: self.random_erasing = RandomErasing( - probability=re_prob, mode=re_mode, max_count=re_count, num_splits=re_num_splits) + probability=re_prob, + mode=re_mode, + max_count=re_count, + num_splits=re_num_splits, + device=device, + ) else: self.random_erasing = None + self.is_cuda = torch.cuda.is_available() and device.type == 'cuda' def __iter__(self): - stream = torch.cuda.Stream() first = True + if self.is_cuda: + stream = torch.cuda.Stream() + stream_context = partial(torch.cuda.stream, stream=stream) + else: + stream = None + stream_context = suppress for next_input, next_target in self.loader: - with torch.cuda.stream(stream): - next_input = next_input.cuda(non_blocking=True) - next_target = next_target.cuda(non_blocking=True) - if self.fp16: - next_input = next_input.half().sub_(self.mean).div_(self.std) - else: - next_input = next_input.float().sub_(self.mean).div_(self.std) + + with stream_context(): + next_input = next_input.to(device=self.device, non_blocking=True) + next_target = next_target.to(device=self.device, non_blocking=True) + next_input = next_input.to(self.img_dtype).sub_(self.mean).div_(self.std) if self.random_erasing is not None: next_input = self.random_erasing(next_input) @@ -116,7 +132,9 @@ class PrefetchLoader: else: first = False - torch.cuda.current_stream().wait_stream(stream) + if stream is not None: + torch.cuda.current_stream().wait_stream(stream) + input = next_input target = next_target @@ -189,7 +207,9 @@ def create_loader( crop_pct=None, collate_fn=None, pin_memory=False, - fp16=False, + fp16=False, # deprecated, use img_dtype + img_dtype=torch.float32, + device=torch.device('cuda'), tf_preprocessing=False, use_multi_epochs_loader=False, persistent_workers=True, @@ -266,7 +286,9 @@ def create_loader( mean=mean, std=std, channels=input_size[0], - fp16=fp16, + device=device, + fp16=fp16, # deprecated, use img_dtype + img_dtype=img_dtype, re_prob=prefetch_re_prob, re_mode=re_mode, re_count=re_count, diff --git a/timm/data/parsers/parser_factory.py b/timm/data/parsers/parser_factory.py index 0665c02a..a204bf6a 100644 --- a/timm/data/parsers/parser_factory.py +++ b/timm/data/parsers/parser_factory.py @@ -17,6 +17,9 @@ def create_parser(name, root, split='train', **kwargs): if prefix == 'tfds': from .parser_tfds import ParserTfds # defer tensorflow import parser = ParserTfds(root, name, split=split, **kwargs) + elif prefix == 'hfds': + from .parser_hfds import ParserHfds # defer tensorflow import + parser = ParserHfds(root, name, split=split, **kwargs) else: assert os.path.exists(root) # default fallback path (backwards compat), use image tar if root is a .tar file, otherwise image folder diff --git a/timm/data/parsers/parser_tfds.py b/timm/data/parsers/parser_tfds.py index 739f3813..c0128a5b 100644 --- a/timm/data/parsers/parser_tfds.py +++ b/timm/data/parsers/parser_tfds.py @@ -86,9 +86,9 @@ class ParserTfds(Parser): repeats=0, seed=42, input_name='image', - input_image='RGB', + input_img_mode='RGB', target_name='label', - target_image='', + target_img_mode='', prefetch_size=None, shuffle_size=None, max_threadpool_size=None @@ -105,9 +105,9 @@ class ParserTfds(Parser): repeats: iterate through (repeat) the dataset this many times per iteration (once if 0 or 1) seed: common seed for shard shuffle across all distributed/worker instances input_name: name of Feature to return as data (input) - input_image: image mode if input is an image (currently PIL mode string) + input_img_mode: image mode if input is an image (currently PIL mode string) target_name: name of Feature to return as target (label) - target_image: image mode if target is an image (currently PIL mode string) + target_img_mode: image mode if target is an image (currently PIL mode string) prefetch_size: override default tf.data prefetch buffer size shuffle_size: override default tf.data shuffle buffer size max_threadpool_size: override default threadpool size for tf.data @@ -130,9 +130,9 @@ class ParserTfds(Parser): # TFDS builder and split information self.input_name = input_name # FIXME support tuples / lists of inputs and targets and full range of Feature - self.input_image = input_image + self.input_img_mode = input_img_mode self.target_name = target_name - self.target_image = target_image + self.target_img_mode = target_img_mode self.builder = tfds.builder(name, data_dir=root) # NOTE: the tfds command line app can be used download & prepare datasets if you don't enable download flag if download: @@ -249,11 +249,11 @@ class ParserTfds(Parser): example_count = 0 for example in self.ds: input_data = example[self.input_name] - if self.input_image: - input_data = Image.fromarray(input_data, mode=self.input_image) + if self.input_img_mode: + input_data = Image.fromarray(input_data, mode=self.input_img_mode) target_data = example[self.target_name] - if self.target_image: - target_data = Image.fromarray(target_data, mode=self.target_image) + if self.target_img_mode: + target_data = Image.fromarray(target_data, mode=self.target_img_mode) yield input_data, target_data example_count += 1 if self.is_training and example_count >= target_example_count: diff --git a/timm/data/random_erasing.py b/timm/data/random_erasing.py index 98108488..1dee5f86 100644 --- a/timm/data/random_erasing.py +++ b/timm/data/random_erasing.py @@ -7,6 +7,7 @@ Hacked together by / Copyright 2019, Ross Wightman """ import random import math + import torch @@ -44,8 +45,17 @@ class RandomErasing: def __init__( self, - probability=0.5, min_area=0.02, max_area=1/3, min_aspect=0.3, max_aspect=None, - mode='const', min_count=1, max_count=None, num_splits=0, device='cuda'): + probability=0.5, + min_area=0.02, + max_area=1/3, + min_aspect=0.3, + max_aspect=None, + mode='const', + min_count=1, + max_count=None, + num_splits=0, + device='cuda', + ): self.probability = probability self.min_area = min_area self.max_area = max_area @@ -81,8 +91,12 @@ class RandomErasing: top = random.randint(0, img_h - h) left = random.randint(0, img_w - w) img[:, top:top + h, left:left + w] = _get_pixels( - self.per_pixel, self.rand_color, (chan, h, w), - dtype=dtype, device=self.device) + self.per_pixel, + self.rand_color, + (chan, h, w), + dtype=dtype, + device=self.device, + ) break def __call__(self, input): diff --git a/timm/utils/__init__.py b/timm/utils/__init__.py index 7b139852..a9ff0c78 100644 --- a/timm/utils/__init__.py +++ b/timm/utils/__init__.py @@ -3,7 +3,8 @@ from .checkpoint_saver import CheckpointSaver from .clip_grad import dispatch_clip_grad from .cuda import ApexScaler, NativeScaler from .decay_batch import decay_batch_step, check_batch_size_retry -from .distributed import distribute_bn, reduce_tensor +from .distributed import distribute_bn, reduce_tensor, init_distributed_device,\ + world_info_from_env, is_distributed_env, is_primary from .jit import set_jit_legacy, set_jit_fuser from .log import setup_default_logging, FormatterNoInfo from .metrics import AverageMeter, accuracy diff --git a/timm/utils/distributed.py b/timm/utils/distributed.py index 3c5dba8c..ee9a358c 100644 --- a/timm/utils/distributed.py +++ b/timm/utils/distributed.py @@ -2,9 +2,16 @@ Hacked together by / Copyright 2020 Ross Wightman """ +import os + import torch from torch import distributed as dist +try: + import horovod.torch as hvd +except ImportError: + hvd = None + from .model import unwrap_model @@ -26,3 +33,105 @@ def distribute_bn(model, world_size, reduce=False): else: # broadcast bn stats from rank 0 to whole group torch.distributed.broadcast(bn_buf, 0) + + +def is_global_primary(args): + return args.rank == 0 + + +def is_local_primary(args): + return args.local_rank == 0 + + +def is_primary(args, local=False): + return is_local_primary(args) if local else is_global_primary(args) + + +def is_distributed_env(): + if 'WORLD_SIZE' in os.environ: + return int(os.environ['WORLD_SIZE']) > 1 + if 'SLURM_NTASKS' in os.environ: + return int(os.environ['SLURM_NTASKS']) > 1 + return False + + +def world_info_from_env(): + local_rank = 0 + for v in ('LOCAL_RANK', 'MPI_LOCALRANKID', 'SLURM_LOCALID', 'OMPI_COMM_WORLD_LOCAL_RANK'): + if v in os.environ: + local_rank = int(os.environ[v]) + break + + global_rank = 0 + for v in ('RANK', 'PMI_RANK', 'SLURM_PROCID', 'OMPI_COMM_WORLD_RANK'): + if v in os.environ: + global_rank = int(os.environ[v]) + break + + world_size = 1 + for v in ('WORLD_SIZE', 'PMI_SIZE', 'SLURM_NTASKS', 'OMPI_COMM_WORLD_SIZE'): + if v in os.environ: + world_size = int(os.environ[v]) + break + + return local_rank, global_rank, world_size + + +def init_distributed_device(args): + # Distributed training = training on more than one GPU. + # Works in both single and multi-node scenarios. + args.distributed = False + args.world_size = 1 + args.rank = 0 # global rank + args.local_rank = 0 + + # TBD, support horovod? + # if args.horovod: + # assert hvd is not None, "Horovod is not installed" + # hvd.init() + # args.local_rank = int(hvd.local_rank()) + # args.rank = hvd.rank() + # args.world_size = hvd.size() + # args.distributed = True + # os.environ['LOCAL_RANK'] = str(args.local_rank) + # os.environ['RANK'] = str(args.rank) + # os.environ['WORLD_SIZE'] = str(args.world_size) + dist_backend = getattr(args, 'dist_backend', 'nccl') + dist_url = getattr(args, 'dist_url', 'env://') + if is_distributed_env(): + if 'SLURM_PROCID' in os.environ: + # DDP via SLURM + args.local_rank, args.rank, args.world_size = world_info_from_env() + # SLURM var -> torch.distributed vars in case needed + os.environ['LOCAL_RANK'] = str(args.local_rank) + os.environ['RANK'] = str(args.rank) + os.environ['WORLD_SIZE'] = str(args.world_size) + torch.distributed.init_process_group( + backend=dist_backend, + init_method=dist_url, + world_size=args.world_size, + rank=args.rank, + ) + else: + # DDP via torchrun, torch.distributed.launch + args.local_rank, _, _ = world_info_from_env() + torch.distributed.init_process_group( + backend=dist_backend, + init_method=dist_url, + ) + args.world_size = torch.distributed.get_world_size() + args.rank = torch.distributed.get_rank() + args.distributed = True + + if torch.cuda.is_available(): + if args.distributed: + device = 'cuda:%d' % args.local_rank + else: + device = 'cuda:0' + torch.cuda.set_device(device) + else: + device = 'cpu' + + args.device = device + device = torch.device(device) + return device diff --git a/train.py b/train.py index ee137217..91980cb6 100755 --- a/train.py +++ b/train.py @@ -21,6 +21,7 @@ import time from collections import OrderedDict from contextlib import suppress from datetime import datetime +from functools import partial import torch import torch.nn as nn @@ -66,7 +67,6 @@ except ImportError as e: has_functorch = False -torch.backends.cudnn.benchmark = True _logger = logging.getLogger('train') # The first arg parser parses out only the --config argument, this argument is used to @@ -349,33 +349,27 @@ def main(): utils.setup_default_logging() args, args_text = _parse_args() + if torch.cuda.is_available(): + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.benchmark = True + args.prefetcher = not args.no_prefetcher - args.distributed = False - if 'WORLD_SIZE' in os.environ: - args.distributed = int(os.environ['WORLD_SIZE']) > 1 - args.device = 'cuda:0' - args.world_size = 1 - args.rank = 0 # global rank + device = utils.init_distributed_device(args) if args.distributed: - if 'LOCAL_RANK' in os.environ: - args.local_rank = int(os.getenv('LOCAL_RANK')) - args.device = 'cuda:%d' % args.local_rank - torch.cuda.set_device(args.local_rank) - torch.distributed.init_process_group(backend='nccl', init_method='env://') - args.world_size = torch.distributed.get_world_size() - args.rank = torch.distributed.get_rank() - _logger.info('Training in distributed mode with multiple processes, 1 GPU per process. Process %d, total %d.' - % (args.rank, args.world_size)) + _logger.info( + 'Training in distributed mode with multiple processes, 1 device per process.' + f'Process {args.rank}, total {args.world_size}, device {args.device}.') else: - _logger.info('Training with a single process on 1 GPUs.') + _logger.info(f'Training with a single process on 1 device ({args.device}).') assert args.rank >= 0 - if args.rank == 0 and args.log_wandb: + if utils.is_primary(args) and args.log_wandb: if has_wandb: wandb.init(project=args.experiment, config=args) else: - _logger.warning("You've requested to log metrics to wandb but package not found. " - "Metrics not being logged to wandb, try `pip install wandb`") + _logger.warning( + "You've requested to log metrics to wandb but package not found. " + "Metrics not being logged to wandb, try `pip install wandb`") # resolve AMP arguments based on PyTorch / Apex availability use_amp = None @@ -405,14 +399,14 @@ def main(): pretrained=args.pretrained, num_classes=args.num_classes, drop_rate=args.drop, - drop_connect_rate=args.drop_connect, # DEPRECATED, use drop_path drop_path_rate=args.drop_path, drop_block_rate=args.drop_block, global_pool=args.gp, bn_momentum=args.bn_momentum, bn_eps=args.bn_eps, scriptable=args.torchscript, - checkpoint_path=args.initial_checkpoint) + checkpoint_path=args.initial_checkpoint, + ) if args.num_classes is None: assert hasattr(model, 'num_classes'), 'Model must have `num_classes` attr if not set on cmd line/config.' args.num_classes = model.num_classes # FIXME handle model default vs config num_classes more elegantly @@ -420,11 +414,11 @@ def main(): if args.grad_checkpointing: model.set_grad_checkpointing(enable=True) - if args.local_rank == 0: + if utils.is_primary(args): _logger.info( f'Model {safe_model_name(args.model)} created, param count:{sum([m.numel() for m in model.parameters()])}') - data_config = resolve_data_config(vars(args), model=model, verbose=args.local_rank == 0) + data_config = resolve_data_config(vars(args), model=model, verbose=utils.is_primary(args)) # setup augmentation batch splits for contrastive loss or split bn num_aug_splits = 0 @@ -438,9 +432,9 @@ def main(): model = convert_splitbn_model(model, max(num_aug_splits, 2)) # move model to GPU, enable channels last layout if set - model.cuda() + model.to(device=device) if args.channels_last: - model = model.to(memory_format=torch.channels_last) + model.to(memory_format=torch.channels_last) # setup synchronized BatchNorm for distributed training if args.distributed and args.sync_bn: @@ -452,7 +446,7 @@ def main(): model = convert_syncbn_model(model) else: model = convert_sync_batchnorm(model) - if args.local_rank == 0: + if utils.is_primary(args): _logger.info( 'Converted model to use Synchronized BatchNorm. WARNING: You may have issues if using ' 'zero initialized BN layers (enabled by default for ResNets) while sync-bn enabled.') @@ -461,6 +455,7 @@ def main(): assert not use_amp == 'apex', 'Cannot use APEX AMP with torchscripted model' assert not args.sync_bn, 'Cannot use SyncBatchNorm with torchscripted model' model = torch.jit.script(model) + if args.aot_autograd: assert has_functorch, "functorch is needed for --aot-autograd" model = memory_efficient_fusion(model) @@ -471,28 +466,31 @@ def main(): amp_autocast = suppress # do nothing loss_scaler = None if use_amp == 'apex': + assert device.type == 'cuda' model, optimizer = amp.initialize(model, optimizer, opt_level='O1') loss_scaler = ApexScaler() - if args.local_rank == 0: + if utils.is_primary(args): _logger.info('Using NVIDIA APEX AMP. Training in mixed precision.') elif use_amp == 'native': - amp_autocast = torch.cuda.amp.autocast - loss_scaler = NativeScaler() - if args.local_rank == 0: + amp_autocast = partial(torch.autocast, device_type=device.type) + if device.type == 'cuda': + loss_scaler = NativeScaler() + if utils.is_primary(args): _logger.info('Using native Torch AMP. Training in mixed precision.') else: - if args.local_rank == 0: + if utils.is_primary(args): _logger.info('AMP not enabled. Training in float32.') - # optionally resume from a checkpoint resume_epoch = None if args.resume: resume_epoch = resume_checkpoint( - model, args.resume, + model, + args.resume, optimizer=None if args.no_resume_opt else optimizer, loss_scaler=None if args.no_resume_opt else loss_scaler, - log_info=args.local_rank == 0) + log_info=utils.is_primary(args), + ) # setup exponential moving average of model weights, SWA could be used here too model_ema = None @@ -507,13 +505,13 @@ def main(): if args.distributed: if has_apex and use_amp == 'apex': # Apex DDP preferred unless native amp is activated - if args.local_rank == 0: + if utils.is_primary(args): _logger.info("Using NVIDIA APEX DistributedDataParallel.") model = ApexDDP(model, delay_allreduce=True) else: - if args.local_rank == 0: + if utils.is_primary(args): _logger.info("Using native Torch DistributedDataParallel.") - model = NativeDDP(model, device_ids=[args.local_rank], broadcast_buffers=not args.no_ddp_bb) + model = NativeDDP(model, device_ids=[device], broadcast_buffers=not args.no_ddp_bb) # NOTE: EMA model does not need to be wrapped by DDP # setup learning rate schedule and starting epoch @@ -527,21 +525,30 @@ def main(): if lr_scheduler is not None and start_epoch > 0: lr_scheduler.step(start_epoch) - if args.local_rank == 0: + if utils.is_primary(args): _logger.info('Scheduled epochs: {}'.format(num_epochs)) # create the train and eval datasets dataset_train = create_dataset( - args.dataset, root=args.data_dir, split=args.train_split, is_training=True, + args.dataset, + root=args.data_dir, + split=args.train_split, + is_training=True, class_map=args.class_map, download=args.dataset_download, batch_size=args.batch_size, - repeats=args.epoch_repeats) + repeats=args.epoch_repeats + ) + dataset_eval = create_dataset( - args.dataset, root=args.data_dir, split=args.val_split, is_training=False, + args.dataset, + root=args.data_dir, + split=args.val_split, + is_training=False, class_map=args.class_map, download=args.dataset_download, - batch_size=args.batch_size) + batch_size=args.batch_size + ) # setup mixup / cutmix collate_fn = None @@ -549,9 +556,15 @@ def main(): mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None if mixup_active: mixup_args = dict( - mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax, - prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode, - label_smoothing=args.smoothing, num_classes=args.num_classes) + mixup_alpha=args.mixup, + cutmix_alpha=args.cutmix, + cutmix_minmax=args.cutmix_minmax, + prob=args.mixup_prob, + switch_prob=args.mixup_switch_prob, + mode=args.mixup_mode, + label_smoothing=args.smoothing, + num_classes=args.num_classes + ) if args.prefetcher: assert not num_aug_splits # collate conflict (need to support deinterleaving in collate mixup) collate_fn = FastCollateMixup(**mixup_args) @@ -592,6 +605,7 @@ def main(): distributed=args.distributed, collate_fn=collate_fn, pin_memory=args.pin_mem, + device=device, use_multi_epochs_loader=args.use_multi_epochs_loader, worker_seeding=args.worker_seeding, ) @@ -609,6 +623,7 @@ def main(): distributed=args.distributed, crop_pct=data_config['crop_pct'], pin_memory=args.pin_mem, + device=device, ) # setup loss function @@ -628,8 +643,8 @@ def main(): train_loss_fn = LabelSmoothingCrossEntropy(smoothing=args.smoothing) else: train_loss_fn = nn.CrossEntropyLoss() - train_loss_fn = train_loss_fn.cuda() - validate_loss_fn = nn.CrossEntropyLoss().cuda() + train_loss_fn = train_loss_fn.to(device=device) + validate_loss_fn = nn.CrossEntropyLoss().to(device=device) # setup checkpoint saver and eval metric tracking eval_metric = args.eval_metric @@ -637,7 +652,7 @@ def main(): best_epoch = None saver = None output_dir = None - if args.rank == 0: + if utils.is_primary(args): if args.experiment: exp_name = args.experiment else: @@ -649,8 +664,16 @@ def main(): output_dir = utils.get_outdir(args.output if args.output else './output/train', exp_name) decreasing = True if eval_metric == 'loss' else False saver = utils.CheckpointSaver( - model=model, optimizer=optimizer, args=args, model_ema=model_ema, amp_scaler=loss_scaler, - checkpoint_dir=output_dir, recovery_dir=output_dir, decreasing=decreasing, max_history=args.checkpoint_hist) + model=model, + optimizer=optimizer, + args=args, + model_ema=model_ema, + amp_scaler=loss_scaler, + checkpoint_dir=output_dir, + recovery_dir=output_dir, + decreasing=decreasing, + max_history=args.checkpoint_hist + ) with open(os.path.join(output_dir, 'args.yaml'), 'w') as f: f.write(args_text) @@ -660,22 +683,46 @@ def main(): loader_train.sampler.set_epoch(epoch) train_metrics = train_one_epoch( - epoch, model, loader_train, optimizer, train_loss_fn, args, - lr_scheduler=lr_scheduler, saver=saver, output_dir=output_dir, - amp_autocast=amp_autocast, loss_scaler=loss_scaler, model_ema=model_ema, mixup_fn=mixup_fn) + epoch, + model, + loader_train, + optimizer, + train_loss_fn, + args, + lr_scheduler=lr_scheduler, + saver=saver, + output_dir=output_dir, + amp_autocast=amp_autocast, + loss_scaler=loss_scaler, + model_ema=model_ema, + mixup_fn=mixup_fn, + ) if args.distributed and args.dist_bn in ('broadcast', 'reduce'): - if args.local_rank == 0: + if utils.is_primary(args): _logger.info("Distributing BatchNorm running means and vars") utils.distribute_bn(model, args.world_size, args.dist_bn == 'reduce') - eval_metrics = validate(model, loader_eval, validate_loss_fn, args, amp_autocast=amp_autocast) + eval_metrics = validate( + model, + loader_eval, + validate_loss_fn, + args, + amp_autocast=amp_autocast, + ) if model_ema is not None and not args.model_ema_force_cpu: if args.distributed and args.dist_bn in ('broadcast', 'reduce'): utils.distribute_bn(model_ema, args.world_size, args.dist_bn == 'reduce') + ema_eval_metrics = validate( - model_ema.module, loader_eval, validate_loss_fn, args, amp_autocast=amp_autocast, log_suffix=' (EMA)') + model_ema.module, + loader_eval, + validate_loss_fn, + args, + amp_autocast=amp_autocast, + log_suffix=' (EMA)', + ) eval_metrics = ema_eval_metrics if lr_scheduler is not None: @@ -684,8 +731,13 @@ def main(): if output_dir is not None: utils.update_summary( - epoch, train_metrics, eval_metrics, os.path.join(output_dir, 'summary.csv'), - write_header=best_metric is None, log_wandb=args.log_wandb and has_wandb) + epoch, + train_metrics, + eval_metrics, + os.path.join(output_dir, 'summary.csv'), + write_header=best_metric is None, + log_wandb=args.log_wandb and has_wandb, + ) if saver is not None: # save proper checkpoint with eval metric @@ -699,10 +751,21 @@ def main(): def train_one_epoch( - epoch, model, loader, optimizer, loss_fn, args, - lr_scheduler=None, saver=None, output_dir=None, amp_autocast=suppress, - loss_scaler=None, model_ema=None, mixup_fn=None): - + epoch, + model, + loader, + optimizer, + loss_fn, + args, + device=torch.device('cuda'), + lr_scheduler=None, + saver=None, + output_dir=None, + amp_autocast=suppress, + loss_scaler=None, + model_ema=None, + mixup_fn=None +): if args.mixup_off_epoch and epoch >= args.mixup_off_epoch: if args.prefetcher and loader.mixup_enabled: loader.mixup_enabled = False @@ -723,7 +786,7 @@ def train_one_epoch( last_batch = batch_idx == last_idx data_time_m.update(time.time() - end) if not args.prefetcher: - input, target = input.cuda(), target.cuda() + input, target = input.to(device), target.to(device) if mixup_fn is not None: input, target = mixup_fn(input, target) if args.channels_last: @@ -740,21 +803,26 @@ def train_one_epoch( if loss_scaler is not None: loss_scaler( loss, optimizer, - clip_grad=args.clip_grad, clip_mode=args.clip_mode, + clip_grad=args.clip_grad, + clip_mode=args.clip_mode, parameters=model_parameters(model, exclude_head='agc' in args.clip_mode), - create_graph=second_order) + create_graph=second_order + ) else: loss.backward(create_graph=second_order) if args.clip_grad is not None: utils.dispatch_clip_grad( model_parameters(model, exclude_head='agc' in args.clip_mode), - value=args.clip_grad, mode=args.clip_mode) + value=args.clip_grad, + mode=args.clip_mode + ) optimizer.step() if model_ema is not None: model_ema.update(model) torch.cuda.synchronize() + num_updates += 1 batch_time_m.update(time.time() - end) if last_batch or batch_idx % args.log_interval == 0: @@ -765,7 +833,7 @@ def train_one_epoch( reduced_loss = utils.reduce_tensor(loss.data, args.world_size) losses_m.update(reduced_loss.item(), input.size(0)) - if args.local_rank == 0: + if utils.is_primary(args): _logger.info( 'Train: {} [{:>4d}/{} ({:>3.0f}%)] ' 'Loss: {loss.val:#.4g} ({loss.avg:#.3g}) ' @@ -781,14 +849,16 @@ def train_one_epoch( rate=input.size(0) * args.world_size / batch_time_m.val, rate_avg=input.size(0) * args.world_size / batch_time_m.avg, lr=lr, - data_time=data_time_m)) + data_time=data_time_m) + ) if args.save_images and output_dir: torchvision.utils.save_image( input, os.path.join(output_dir, 'train-batch-%d.jpg' % batch_idx), padding=0, - normalize=True) + normalize=True + ) if saver is not None and args.recovery_interval and ( last_batch or (batch_idx + 1) % args.recovery_interval == 0): @@ -806,7 +876,15 @@ def train_one_epoch( return OrderedDict([('loss', losses_m.avg)]) -def validate(model, loader, loss_fn, args, amp_autocast=suppress, log_suffix=''): +def validate( + model, + loader, + loss_fn, + args, + device=torch.device('cuda'), + amp_autocast=suppress, + log_suffix='' +): batch_time_m = utils.AverageMeter() losses_m = utils.AverageMeter() top1_m = utils.AverageMeter() @@ -820,8 +898,8 @@ def validate(model, loader, loss_fn, args, amp_autocast=suppress, log_suffix='') for batch_idx, (input, target) in enumerate(loader): last_batch = batch_idx == last_idx if not args.prefetcher: - input = input.cuda() - target = target.cuda() + input = input.to(device) + target = target.to(device) if args.channels_last: input = input.contiguous(memory_format=torch.channels_last) @@ -846,7 +924,8 @@ def validate(model, loader, loss_fn, args, amp_autocast=suppress, log_suffix='') else: reduced_loss = loss.data - torch.cuda.synchronize() + if device.type == 'cuda': + torch.cuda.synchronize() losses_m.update(reduced_loss.item(), input.size(0)) top1_m.update(acc1.item(), output.size(0)) @@ -854,7 +933,7 @@ def validate(model, loader, loss_fn, args, amp_autocast=suppress, log_suffix='') batch_time_m.update(time.time() - end) end = time.time() - if args.local_rank == 0 and (last_batch or batch_idx % args.log_interval == 0): + if utils.is_primary(args) and (last_batch or batch_idx % args.log_interval == 0): log_name = 'Test' + log_suffix _logger.info( '{0}: [{1:>4d}/{2}] ' @@ -862,8 +941,12 @@ def validate(model, loader, loss_fn, args, amp_autocast=suppress, log_suffix='') 'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f}) ' 'Acc@1: {top1.val:>7.4f} ({top1.avg:>7.4f}) ' 'Acc@5: {top5.val:>7.4f} ({top5.avg:>7.4f})'.format( - log_name, batch_idx, last_idx, batch_time=batch_time_m, - loss=losses_m, top1=top1_m, top5=top5_m)) + log_name, batch_idx, last_idx, + batch_time=batch_time_m, + loss=losses_m, + top1=top1_m, + top5=top5_m) + ) metrics = OrderedDict([('loss', losses_m.avg), ('top1', top1_m.avg), ('top5', top5_m.avg)]) diff --git a/validate.py b/validate.py index 6244f052..cdce82bf 100755 --- a/validate.py +++ b/validate.py @@ -19,6 +19,7 @@ import torch.nn as nn import torch.nn.parallel from collections import OrderedDict from contextlib import suppress +from functools import partial from timm.models import create_model, apply_test_time_pool, load_checkpoint, is_model, list_models, set_fast_norm from timm.data import create_dataset, create_loader, resolve_data_config, RealLabelsImagenet @@ -45,7 +46,6 @@ try: except ImportError as e: has_functorch = False -torch.backends.cudnn.benchmark = True _logger = logging.getLogger('validate') @@ -100,6 +100,8 @@ parser.add_argument('--pin-mem', action='store_true', default=False, help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.') parser.add_argument('--channels-last', action='store_true', default=False, help='Use channels_last memory layout') +parser.add_argument('--device', default='cuda', type=str, + help="Device (accelerator) to use.") parser.add_argument('--amp', action='store_true', default=False, help='Use AMP mixed precision. Defaults to Apex, fallback to native Torch AMP.') parser.add_argument('--apex-amp', action='store_true', default=False, @@ -133,6 +135,13 @@ def validate(args): # might as well try to validate something args.pretrained = args.pretrained or not args.checkpoint args.prefetcher = not args.no_prefetcher + + if torch.cuda.is_available(): + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.benchmark = True + + device = torch.device(args.device) + amp_autocast = suppress # do nothing if args.amp: if has_native_amp: @@ -143,15 +152,17 @@ def validate(args): _logger.warning("Neither APEX or Native Torch AMP is available.") assert not args.apex_amp or not args.native_amp, "Only one AMP mode should be set." if args.native_amp: - amp_autocast = torch.cuda.amp.autocast + amp_autocast = partial(torch.autocast, device_type=device.type) _logger.info('Validating in mixed precision with native PyTorch AMP.') elif args.apex_amp: + assert device.type == 'cuda' _logger.info('Validating in mixed precision with NVIDIA APEX AMP.') else: _logger.info('Validating in float32. AMP not enabled.') if args.fuser: set_jit_fuser(args.fuser) + if args.fast_norm: set_fast_norm() @@ -162,7 +173,8 @@ def validate(args): num_classes=args.num_classes, in_chans=3, global_pool=args.gp, - scriptable=args.torchscript) + scriptable=args.torchscript, + ) if args.num_classes is None: assert hasattr(model, 'num_classes'), 'Model must have `num_classes` attr if not set on cmd line/config.' args.num_classes = model.num_classes @@ -177,7 +189,7 @@ def validate(args): vars(args), model=model, use_test_size=not args.use_train_size, - verbose=True + verbose=True, ) test_time_pool = False if args.test_pool: @@ -186,11 +198,12 @@ def validate(args): if args.torchscript: torch.jit.optimized_execution(True) model = torch.jit.script(model) + if args.aot_autograd: assert has_functorch, "functorch is needed for --aot-autograd" model = memory_efficient_fusion(model) - model = model.cuda() + model = model.to(device) if args.apex_amp: model = amp.initialize(model, opt_level='O1') @@ -200,11 +213,16 @@ def validate(args): if args.num_gpu > 1: model = torch.nn.DataParallel(model, device_ids=list(range(args.num_gpu))) - criterion = nn.CrossEntropyLoss().cuda() + criterion = nn.CrossEntropyLoss().to(device) dataset = create_dataset( - root=args.data, name=args.dataset, split=args.split, - download=args.dataset_download, load_bytes=args.tf_preprocessing, class_map=args.class_map) + root=args.data, + name=args.dataset, + split=args.split, + download=args.dataset_download, + load_bytes=args.tf_preprocessing, + class_map=args.class_map, + ) if args.valid_labels: with open(args.valid_labels, 'r') as f: @@ -230,7 +248,9 @@ def validate(args): num_workers=args.workers, crop_pct=crop_pct, pin_memory=args.pin_mem, - tf_preprocessing=args.tf_preprocessing) + device=device, + tf_preprocessing=args.tf_preprocessing, + ) batch_time = AverageMeter() losses = AverageMeter() @@ -240,7 +260,7 @@ def validate(args): model.eval() with torch.no_grad(): # warmup, reduce variability of first batch time, especially for comparing torchscript vs non - input = torch.randn((args.batch_size,) + tuple(data_config['input_size'])).cuda() + input = torch.randn((args.batch_size,) + tuple(data_config['input_size'])).to(device) if args.channels_last: input = input.contiguous(memory_format=torch.channels_last) with amp_autocast(): @@ -249,8 +269,8 @@ def validate(args): end = time.time() for batch_idx, (input, target) in enumerate(loader): if args.no_prefetcher: - target = target.cuda() - input = input.cuda() + target = target.to(device) + input = input.to(device) if args.channels_last: input = input.contiguous(memory_format=torch.channels_last) @@ -282,9 +302,15 @@ def validate(args): 'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f}) ' 'Acc@1: {top1.val:>7.3f} ({top1.avg:>7.3f}) ' 'Acc@5: {top5.val:>7.3f} ({top5.avg:>7.3f})'.format( - batch_idx, len(loader), batch_time=batch_time, + batch_idx, + len(loader), + batch_time=batch_time, rate_avg=input.size(0) / batch_time.avg, - loss=losses, top1=top1, top5=top5)) + loss=losses, + top1=top1, + top5=top5 + ) + ) if real_labels is not None: # real labels mode replaces topk values at the end @@ -298,7 +324,8 @@ def validate(args): param_count=round(param_count / 1e6, 2), img_size=data_config['input_size'][-1], crop_pct=crop_pct, - interpolation=data_config['interpolation']) + interpolation=data_config['interpolation'], + ) _logger.info(' * Acc@1 {:.3f} ({:.3f}) Acc@5 {:.3f} ({:.3f})'.format( results['top1'], results['top1_err'], results['top5'], results['top5_err'])) @@ -313,7 +340,8 @@ def _try_run(args, initial_batch_size): while batch_size: args.batch_size = batch_size * args.num_gpu # multiply by num-gpu for DataParallel case try: - torch.cuda.empty_cache() + if torch.cuda.is_available() and 'cuda' in args.device: + torch.cuda.empty_cache() results = validate(args) return results except RuntimeError as e: