Refactor device handling in scripts and distributed init to be less 'cuda'-centric. More device args are passed through where needed.

pull/1479/head
Ross Wightman 2 years ago
parent c88947ad3d
commit 87939e6fab
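As a rough usage sketch (not part of the diff): the refactored helpers replace the hand-rolled args.local_rank / 'cuda:0' handling in the scripts. Attribute names mirror the train.py changes further down; a timm build containing this commit is assumed, and the linear layer is just a stand-in model.

import argparse
import torch.nn as nn
from timm import utils

args = argparse.ArgumentParser().parse_args([])

# Sets args.distributed, args.world_size, args.rank, args.local_rank and args.device
# from WORLD_SIZE / SLURM_* env vars (or single-process defaults), returns a torch.device.
device = utils.init_distributed_device(args)

model = nn.Linear(8, 2).to(device=device)             # replaces model.cuda()
loss_fn = nn.CrossEntropyLoss().to(device=device)

if utils.is_primary(args):                             # replaces args.local_rank == 0 checks
    print(f'rank {args.rank}/{args.world_size} on device {device}')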

@ -57,7 +57,9 @@ except ImportError as e:
has_functorch = False
torch.backends.cudnn.benchmark = True
if torch.cuda.is_available():
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.benchmark = True
_logger = logging.getLogger('validate')
@ -216,7 +218,7 @@ class BenchmarkRunner:
self.device = device
self.use_amp, self.model_dtype, self.data_dtype = resolve_precision(precision)
self.channels_last = kwargs.pop('channels_last', False)
self.amp_autocast = torch.cuda.amp.autocast if self.use_amp else suppress
self.amp_autocast = partial(torch.cuda.amp.autocast, dtype=torch.float16) if self.use_amp else suppress
if fuser:
set_jit_fuser(fuser)

@ -2,11 +2,11 @@
Hacked together by / Copyright 2019, Ross Wightman
"""
import torch.utils.data as data
import os
import torch
import io
import logging
import torch
import torch.utils.data as data
from PIL import Image
from .parsers import create_parser
@ -23,23 +23,32 @@ class ImageDataset(data.Dataset):
self,
root,
parser=None,
split='train',
class_map=None,
load_bytes=False,
img_mode='RGB',
transform=None,
target_transform=None,
):
if parser is None or isinstance(parser, str):
parser = create_parser(parser or '', root=root, class_map=class_map)
parser = create_parser(
parser or '',
root=root,
split=split,
class_map=class_map
)
self.parser = parser
self.load_bytes = load_bytes
self.img_mode = img_mode
self.transform = transform
self.target_transform = target_transform
self._consecutive_errors = 0
def __getitem__(self, index):
img, target = self.parser[index]
try:
img = img.read() if self.load_bytes else Image.open(img).convert('RGB')
img = img.read() if self.load_bytes else Image.open(img)
except Exception as e:
_logger.warning(f'Skipped sample (index {index}, file {self.parser.filename(index)}). {str(e)}')
self._consecutive_errors += 1
@ -48,12 +57,17 @@ class ImageDataset(data.Dataset):
else:
raise e
self._consecutive_errors = 0
if self.img_mode and not self.load_bytes:
img = img.convert(self.img_mode)
if self.transform is not None:
img = self.transform(img)
if target is None:
target = -1
elif self.target_transform is not None:
target = self.target_transform(target)
return img, target
def __len__(self):
@ -83,8 +97,14 @@ class IterableImageDataset(data.IterableDataset):
assert parser is not None
if isinstance(parser, str):
self.parser = create_parser(
parser, root=root, split=split, is_training=is_training,
batch_size=batch_size, repeats=repeats, download=download)
parser,
root=root,
split=split,
is_training=is_training,
batch_size=batch_size,
repeats=repeats,
download=download,
)
else:
self.parser = parser
self.transform = transform
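Hypothetical use of the new img_mode argument (the root path is a placeholder, not from the diff); ImageDataset now defers the PIL convert so non-RGB modes can be requested:

from timm.data import ImageDataset

ds = ImageDataset(root='/path/to/images', img_mode='L')   # placeholder root; grayscale convert happens in __getitem__
img, target = ds[0]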

@ -134,6 +134,10 @@ def create_dataset(
ds = IterableImageDataset(
root, parser=name, split=split, is_training=is_training,
download=download, batch_size=batch_size, repeats=repeats, **kwargs)
elif name.startswith('hfds/'):
# NOTE right now, the HF datasets default arrow format is a random-access Dataset;
# an IterableDataset variant will follow, TBD
ds = ImageDataset(root, parser=name, split=split, **kwargs)
else:
# FIXME support more advanced split cfg for ImageFolder/Tar datasets in the future
if search_split and os.path.isdir(root):
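A hypothetical call exercising the new 'hfds/' branch above (the dataset id and root value are illustrative assumptions, not from the diff; ParserHfds lives in a separate file and is assumed to yield image/label pairs):

from timm.data import create_dataset

# 'beans' is only an example Hub dataset id; root is forwarded to ParserHfds and its
# exact role (cache dir vs. unused) is an assumption here.
ds = create_dataset('hfds/beans', root='', split='train')
img, target = ds[0]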

@ -6,10 +6,12 @@ https://github.com/NVIDIA/apex/commit/d5e2bb4bdeedd27b1dfaf5bb2b24d6c000dee9be#d
Hacked together by / Copyright 2019, Ross Wightman
"""
import random
from contextlib import suppress
from functools import partial
from itertools import repeat
from typing import Callable
import torch
import torch.utils.data
import numpy as np
@ -73,6 +75,8 @@ class PrefetchLoader:
mean=IMAGENET_DEFAULT_MEAN,
std=IMAGENET_DEFAULT_STD,
channels=3,
device=torch.device('cuda'),
img_dtype=torch.float32,
fp16=False,
re_prob=0.,
re_mode='const',
@ -84,30 +88,42 @@ class PrefetchLoader:
normalization_shape = (1, channels, 1, 1)
self.loader = loader
self.mean = torch.tensor([x * 255 for x in mean]).cuda().view(normalization_shape)
self.std = torch.tensor([x * 255 for x in std]).cuda().view(normalization_shape)
self.fp16 = fp16
self.device = device
if fp16:
self.mean = self.mean.half()
self.std = self.std.half()
# fp16 arg is deprecated, but will override dtype arg if set for bwd compat
img_dtype = torch.float16
self.img_dtype = img_dtype
self.mean = torch.tensor(
[x * 255 for x in mean], device=device, dtype=img_dtype).view(normalization_shape)
self.std = torch.tensor(
[x * 255 for x in std], device=device, dtype=img_dtype).view(normalization_shape)
if re_prob > 0.:
self.random_erasing = RandomErasing(
probability=re_prob, mode=re_mode, max_count=re_count, num_splits=re_num_splits)
probability=re_prob,
mode=re_mode,
max_count=re_count,
num_splits=re_num_splits,
device=device,
)
else:
self.random_erasing = None
self.is_cuda = torch.cuda.is_available() and device.type == 'cuda'
def __iter__(self):
stream = torch.cuda.Stream()
first = True
if self.is_cuda:
stream = torch.cuda.Stream()
stream_context = partial(torch.cuda.stream, stream=stream)
else:
stream = None
stream_context = suppress
for next_input, next_target in self.loader:
with torch.cuda.stream(stream):
next_input = next_input.cuda(non_blocking=True)
next_target = next_target.cuda(non_blocking=True)
if self.fp16:
next_input = next_input.half().sub_(self.mean).div_(self.std)
else:
next_input = next_input.float().sub_(self.mean).div_(self.std)
with stream_context():
next_input = next_input.to(device=self.device, non_blocking=True)
next_target = next_target.to(device=self.device, non_blocking=True)
next_input = next_input.to(self.img_dtype).sub_(self.mean).div_(self.std)
if self.random_erasing is not None:
next_input = self.random_erasing(next_input)
@ -116,7 +132,9 @@ class PrefetchLoader:
else:
first = False
torch.cuda.current_stream().wait_stream(stream)
if stream is not None:
torch.cuda.current_stream().wait_stream(stream)
input = next_input
target = next_target
@ -189,7 +207,9 @@ def create_loader(
crop_pct=None,
collate_fn=None,
pin_memory=False,
fp16=False,
fp16=False, # deprecated, use img_dtype
img_dtype=torch.float32,
device=torch.device('cuda'),
tf_preprocessing=False,
use_multi_epochs_loader=False,
persistent_workers=True,
@ -266,7 +286,9 @@ def create_loader(
mean=mean,
std=std,
channels=input_size[0],
fp16=fp16,
device=device,
fp16=fp16, # deprecated, use img_dtype
img_dtype=img_dtype,
re_prob=prefetch_re_prob,
re_mode=re_mode,
re_count=re_count,
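Illustrative use of the new loader arguments (not from the diff; the dataset root is a placeholder and the call assumes the create_loader signature shown above):

import torch
from timm.data import create_dataset, create_loader

dataset = create_dataset('', root='/path/to/imagenet', split='validation')  # placeholder root
loader = create_loader(
    dataset,
    input_size=(3, 224, 224),
    batch_size=32,
    use_prefetcher=True,
    device=torch.device('cpu'),   # PrefetchLoader skips CUDA streams for non-cuda devices
    img_dtype=torch.float32,      # replaces the deprecated fp16=True (i.e. img_dtype=torch.float16)
)
images, targets = next(iter(loader))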

@ -17,6 +17,9 @@ def create_parser(name, root, split='train', **kwargs):
if prefix == 'tfds':
from .parser_tfds import ParserTfds # defer tensorflow import
parser = ParserTfds(root, name, split=split, **kwargs)
elif prefix == 'hfds':
from .parser_hfds import ParserHfds # defer Hugging Face datasets import
parser = ParserHfds(root, name, split=split, **kwargs)
else:
assert os.path.exists(root)
# default fallback path (backwards compat), use image tar if root is a .tar file, otherwise image folder

@ -86,9 +86,9 @@ class ParserTfds(Parser):
repeats=0,
seed=42,
input_name='image',
input_image='RGB',
input_img_mode='RGB',
target_name='label',
target_image='',
target_img_mode='',
prefetch_size=None,
shuffle_size=None,
max_threadpool_size=None
@ -105,9 +105,9 @@ class ParserTfds(Parser):
repeats: iterate through (repeat) the dataset this many times per iteration (once if 0 or 1)
seed: common seed for shard shuffle across all distributed/worker instances
input_name: name of Feature to return as data (input)
input_image: image mode if input is an image (currently PIL mode string)
input_img_mode: image mode if input is an image (currently PIL mode string)
target_name: name of Feature to return as target (label)
target_image: image mode if target is an image (currently PIL mode string)
target_img_mode: image mode if target is an image (currently PIL mode string)
prefetch_size: override default tf.data prefetch buffer size
shuffle_size: override default tf.data shuffle buffer size
max_threadpool_size: override default threadpool size for tf.data
@ -130,9 +130,9 @@ class ParserTfds(Parser):
# TFDS builder and split information
self.input_name = input_name # FIXME support tuples / lists of inputs and targets and full range of Feature
self.input_image = input_image
self.input_img_mode = input_img_mode
self.target_name = target_name
self.target_image = target_image
self.target_img_mode = target_img_mode
self.builder = tfds.builder(name, data_dir=root)
# NOTE: the tfds command line app can be used to download & prepare datasets if you don't enable the download flag
if download:
@ -249,11 +249,11 @@ class ParserTfds(Parser):
example_count = 0
for example in self.ds:
input_data = example[self.input_name]
if self.input_image:
input_data = Image.fromarray(input_data, mode=self.input_image)
if self.input_img_mode:
input_data = Image.fromarray(input_data, mode=self.input_img_mode)
target_data = example[self.target_name]
if self.target_image:
target_data = Image.fromarray(target_data, mode=self.target_image)
if self.target_img_mode:
target_data = Image.fromarray(target_data, mode=self.target_img_mode)
yield input_data, target_data
example_count += 1
if self.is_training and example_count >= target_example_count:

@ -7,6 +7,7 @@ Hacked together by / Copyright 2019, Ross Wightman
"""
import random
import math
import torch
@ -44,8 +45,17 @@ class RandomErasing:
def __init__(
self,
probability=0.5, min_area=0.02, max_area=1/3, min_aspect=0.3, max_aspect=None,
mode='const', min_count=1, max_count=None, num_splits=0, device='cuda'):
probability=0.5,
min_area=0.02,
max_area=1/3,
min_aspect=0.3,
max_aspect=None,
mode='const',
min_count=1,
max_count=None,
num_splits=0,
device='cuda',
):
self.probability = probability
self.min_area = min_area
self.max_area = max_area
@ -81,8 +91,12 @@ class RandomErasing:
top = random.randint(0, img_h - h)
left = random.randint(0, img_w - w)
img[:, top:top + h, left:left + w] = _get_pixels(
self.per_pixel, self.rand_color, (chan, h, w),
dtype=dtype, device=self.device)
self.per_pixel,
self.rand_color,
(chan, h, w),
dtype=dtype,
device=self.device,
)
break
def __call__(self, input):
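A small hypothetical check of the device plumbing above (module path assumed to be timm.data.random_erasing, based on the surrounding hunk):

import torch
from timm.data.random_erasing import RandomErasing

re = RandomErasing(probability=1., mode='pixel', device='cpu')
x = torch.zeros(3, 224, 224)
out = re(x)                        # erases in place; fill values come from _get_pixels(..., device='cpu')
assert out.device.type == 'cpu'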

@ -3,7 +3,8 @@ from .checkpoint_saver import CheckpointSaver
from .clip_grad import dispatch_clip_grad
from .cuda import ApexScaler, NativeScaler
from .decay_batch import decay_batch_step, check_batch_size_retry
from .distributed import distribute_bn, reduce_tensor
from .distributed import distribute_bn, reduce_tensor, init_distributed_device,\
world_info_from_env, is_distributed_env, is_primary
from .jit import set_jit_legacy, set_jit_fuser
from .log import setup_default_logging, FormatterNoInfo
from .metrics import AverageMeter, accuracy

@ -2,9 +2,16 @@
Hacked together by / Copyright 2020 Ross Wightman
"""
import os
import torch
from torch import distributed as dist
try:
import horovod.torch as hvd
except ImportError:
hvd = None
from .model import unwrap_model
@ -26,3 +33,105 @@ def distribute_bn(model, world_size, reduce=False):
else:
# broadcast bn stats from rank 0 to whole group
torch.distributed.broadcast(bn_buf, 0)
def is_global_primary(args):
return args.rank == 0
def is_local_primary(args):
return args.local_rank == 0
def is_primary(args, local=False):
return is_local_primary(args) if local else is_global_primary(args)
def is_distributed_env():
if 'WORLD_SIZE' in os.environ:
return int(os.environ['WORLD_SIZE']) > 1
if 'SLURM_NTASKS' in os.environ:
return int(os.environ['SLURM_NTASKS']) > 1
return False
def world_info_from_env():
local_rank = 0
for v in ('LOCAL_RANK', 'MPI_LOCALRANKID', 'SLURM_LOCALID', 'OMPI_COMM_WORLD_LOCAL_RANK'):
if v in os.environ:
local_rank = int(os.environ[v])
break
global_rank = 0
for v in ('RANK', 'PMI_RANK', 'SLURM_PROCID', 'OMPI_COMM_WORLD_RANK'):
if v in os.environ:
global_rank = int(os.environ[v])
break
world_size = 1
for v in ('WORLD_SIZE', 'PMI_SIZE', 'SLURM_NTASKS', 'OMPI_COMM_WORLD_SIZE'):
if v in os.environ:
world_size = int(os.environ[v])
break
return local_rank, global_rank, world_size
def init_distributed_device(args):
# Distributed training = training on more than one GPU.
# Works in both single and multi-node scenarios.
args.distributed = False
args.world_size = 1
args.rank = 0 # global rank
args.local_rank = 0
# TBD, support horovod?
# if args.horovod:
# assert hvd is not None, "Horovod is not installed"
# hvd.init()
# args.local_rank = int(hvd.local_rank())
# args.rank = hvd.rank()
# args.world_size = hvd.size()
# args.distributed = True
# os.environ['LOCAL_RANK'] = str(args.local_rank)
# os.environ['RANK'] = str(args.rank)
# os.environ['WORLD_SIZE'] = str(args.world_size)
dist_backend = getattr(args, 'dist_backend', 'nccl')
dist_url = getattr(args, 'dist_url', 'env://')
if is_distributed_env():
if 'SLURM_PROCID' in os.environ:
# DDP via SLURM
args.local_rank, args.rank, args.world_size = world_info_from_env()
# SLURM var -> torch.distributed vars in case needed
os.environ['LOCAL_RANK'] = str(args.local_rank)
os.environ['RANK'] = str(args.rank)
os.environ['WORLD_SIZE'] = str(args.world_size)
torch.distributed.init_process_group(
backend=dist_backend,
init_method=dist_url,
world_size=args.world_size,
rank=args.rank,
)
else:
# DDP via torchrun, torch.distributed.launch
args.local_rank, _, _ = world_info_from_env()
torch.distributed.init_process_group(
backend=dist_backend,
init_method=dist_url,
)
args.world_size = torch.distributed.get_world_size()
args.rank = torch.distributed.get_rank()
args.distributed = True
if torch.cuda.is_available():
if args.distributed:
device = 'cuda:%d' % args.local_rank
else:
device = 'cuda:0'
torch.cuda.set_device(device)
else:
device = 'cpu'
args.device = device
device = torch.device(device)
return device
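The env-var fallbacks above can be sanity-checked in isolation; an illustrative snippet, not part of the commit:

import os
from timm.utils import world_info_from_env, is_distributed_env

# torchrun / torch.distributed.launch style variables
os.environ.update({'LOCAL_RANK': '1', 'RANK': '5', 'WORLD_SIZE': '8'})
assert is_distributed_env()
assert world_info_from_env() == (1, 5, 8)   # (local_rank, global_rank, world_size)

Under SLURM the same tuple is filled from SLURM_LOCALID / SLURM_PROCID / SLURM_NTASKS, and init_distributed_device() copies them into LOCAL_RANK / RANK / WORLD_SIZE before calling init_process_group.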

@ -21,6 +21,7 @@ import time
from collections import OrderedDict
from contextlib import suppress
from datetime import datetime
from functools import partial
import torch
import torch.nn as nn
@ -66,7 +67,6 @@ except ImportError as e:
has_functorch = False
torch.backends.cudnn.benchmark = True
_logger = logging.getLogger('train')
# The first arg parser parses out only the --config argument, this argument is used to
@ -349,33 +349,27 @@ def main():
utils.setup_default_logging()
args, args_text = _parse_args()
if torch.cuda.is_available():
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.benchmark = True
args.prefetcher = not args.no_prefetcher
args.distributed = False
if 'WORLD_SIZE' in os.environ:
args.distributed = int(os.environ['WORLD_SIZE']) > 1
args.device = 'cuda:0'
args.world_size = 1
args.rank = 0 # global rank
device = utils.init_distributed_device(args)
if args.distributed:
if 'LOCAL_RANK' in os.environ:
args.local_rank = int(os.getenv('LOCAL_RANK'))
args.device = 'cuda:%d' % args.local_rank
torch.cuda.set_device(args.local_rank)
torch.distributed.init_process_group(backend='nccl', init_method='env://')
args.world_size = torch.distributed.get_world_size()
args.rank = torch.distributed.get_rank()
_logger.info('Training in distributed mode with multiple processes, 1 GPU per process. Process %d, total %d.'
% (args.rank, args.world_size))
_logger.info(
'Training in distributed mode with multiple processes, 1 device per process. '
f'Process {args.rank}, total {args.world_size}, device {args.device}.')
else:
_logger.info('Training with a single process on 1 GPUs.')
_logger.info(f'Training with a single process on 1 device ({args.device}).')
assert args.rank >= 0
if args.rank == 0 and args.log_wandb:
if utils.is_primary(args) and args.log_wandb:
if has_wandb:
wandb.init(project=args.experiment, config=args)
else:
_logger.warning("You've requested to log metrics to wandb but package not found. "
"Metrics not being logged to wandb, try `pip install wandb`")
_logger.warning(
"You've requested to log metrics to wandb but package not found. "
"Metrics not being logged to wandb, try `pip install wandb`")
# resolve AMP arguments based on PyTorch / Apex availability
use_amp = None
@ -405,14 +399,14 @@ def main():
pretrained=args.pretrained,
num_classes=args.num_classes,
drop_rate=args.drop,
drop_connect_rate=args.drop_connect, # DEPRECATED, use drop_path
drop_path_rate=args.drop_path,
drop_block_rate=args.drop_block,
global_pool=args.gp,
bn_momentum=args.bn_momentum,
bn_eps=args.bn_eps,
scriptable=args.torchscript,
checkpoint_path=args.initial_checkpoint)
checkpoint_path=args.initial_checkpoint,
)
if args.num_classes is None:
assert hasattr(model, 'num_classes'), 'Model must have `num_classes` attr if not set on cmd line/config.'
args.num_classes = model.num_classes # FIXME handle model default vs config num_classes more elegantly
@ -420,11 +414,11 @@ def main():
if args.grad_checkpointing:
model.set_grad_checkpointing(enable=True)
if args.local_rank == 0:
if utils.is_primary(args):
_logger.info(
f'Model {safe_model_name(args.model)} created, param count:{sum([m.numel() for m in model.parameters()])}')
data_config = resolve_data_config(vars(args), model=model, verbose=args.local_rank == 0)
data_config = resolve_data_config(vars(args), model=model, verbose=utils.is_primary(args))
# setup augmentation batch splits for contrastive loss or split bn
num_aug_splits = 0
@ -438,9 +432,9 @@ def main():
model = convert_splitbn_model(model, max(num_aug_splits, 2))
# move model to GPU, enable channels last layout if set
model.cuda()
model.to(device=device)
if args.channels_last:
model = model.to(memory_format=torch.channels_last)
model.to(memory_format=torch.channels_last)
# setup synchronized BatchNorm for distributed training
if args.distributed and args.sync_bn:
@ -452,7 +446,7 @@ def main():
model = convert_syncbn_model(model)
else:
model = convert_sync_batchnorm(model)
if args.local_rank == 0:
if utils.is_primary(args):
_logger.info(
'Converted model to use Synchronized BatchNorm. WARNING: You may have issues if using '
'zero initialized BN layers (enabled by default for ResNets) while sync-bn enabled.')
@ -461,6 +455,7 @@ def main():
assert not use_amp == 'apex', 'Cannot use APEX AMP with torchscripted model'
assert not args.sync_bn, 'Cannot use SyncBatchNorm with torchscripted model'
model = torch.jit.script(model)
if args.aot_autograd:
assert has_functorch, "functorch is needed for --aot-autograd"
model = memory_efficient_fusion(model)
@ -471,28 +466,31 @@ def main():
amp_autocast = suppress # do nothing
loss_scaler = None
if use_amp == 'apex':
assert device.type == 'cuda'
model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
loss_scaler = ApexScaler()
if args.local_rank == 0:
if utils.is_primary(args):
_logger.info('Using NVIDIA APEX AMP. Training in mixed precision.')
elif use_amp == 'native':
amp_autocast = torch.cuda.amp.autocast
loss_scaler = NativeScaler()
if args.local_rank == 0:
amp_autocast = partial(torch.autocast, device_type=device.type)
if device.type == 'cuda':
loss_scaler = NativeScaler()
if utils.is_primary(args):
_logger.info('Using native Torch AMP. Training in mixed precision.')
else:
if args.local_rank == 0:
if utils.is_primary(args):
_logger.info('AMP not enabled. Training in float32.')
# optionally resume from a checkpoint
resume_epoch = None
if args.resume:
resume_epoch = resume_checkpoint(
model, args.resume,
model,
args.resume,
optimizer=None if args.no_resume_opt else optimizer,
loss_scaler=None if args.no_resume_opt else loss_scaler,
log_info=args.local_rank == 0)
log_info=utils.is_primary(args),
)
# setup exponential moving average of model weights, SWA could be used here too
model_ema = None
@ -507,13 +505,13 @@ def main():
if args.distributed:
if has_apex and use_amp == 'apex':
# Apex DDP preferred unless native amp is activated
if args.local_rank == 0:
if utils.is_primary(args):
_logger.info("Using NVIDIA APEX DistributedDataParallel.")
model = ApexDDP(model, delay_allreduce=True)
else:
if args.local_rank == 0:
if utils.is_primary(args):
_logger.info("Using native Torch DistributedDataParallel.")
model = NativeDDP(model, device_ids=[args.local_rank], broadcast_buffers=not args.no_ddp_bb)
model = NativeDDP(model, device_ids=[device], broadcast_buffers=not args.no_ddp_bb)
# NOTE: EMA model does not need to be wrapped by DDP
# setup learning rate schedule and starting epoch
@ -527,21 +525,30 @@ def main():
if lr_scheduler is not None and start_epoch > 0:
lr_scheduler.step(start_epoch)
if args.local_rank == 0:
if utils.is_primary(args):
_logger.info('Scheduled epochs: {}'.format(num_epochs))
# create the train and eval datasets
dataset_train = create_dataset(
args.dataset, root=args.data_dir, split=args.train_split, is_training=True,
args.dataset,
root=args.data_dir,
split=args.train_split,
is_training=True,
class_map=args.class_map,
download=args.dataset_download,
batch_size=args.batch_size,
repeats=args.epoch_repeats)
repeats=args.epoch_repeats
)
dataset_eval = create_dataset(
args.dataset, root=args.data_dir, split=args.val_split, is_training=False,
args.dataset,
root=args.data_dir,
split=args.val_split,
is_training=False,
class_map=args.class_map,
download=args.dataset_download,
batch_size=args.batch_size)
batch_size=args.batch_size
)
# setup mixup / cutmix
collate_fn = None
@ -549,9 +556,15 @@ def main():
mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None
if mixup_active:
mixup_args = dict(
mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax,
prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode,
label_smoothing=args.smoothing, num_classes=args.num_classes)
mixup_alpha=args.mixup,
cutmix_alpha=args.cutmix,
cutmix_minmax=args.cutmix_minmax,
prob=args.mixup_prob,
switch_prob=args.mixup_switch_prob,
mode=args.mixup_mode,
label_smoothing=args.smoothing,
num_classes=args.num_classes
)
if args.prefetcher:
assert not num_aug_splits # collate conflict (need to support deinterleaving in collate mixup)
collate_fn = FastCollateMixup(**mixup_args)
@ -592,6 +605,7 @@ def main():
distributed=args.distributed,
collate_fn=collate_fn,
pin_memory=args.pin_mem,
device=device,
use_multi_epochs_loader=args.use_multi_epochs_loader,
worker_seeding=args.worker_seeding,
)
@ -609,6 +623,7 @@ def main():
distributed=args.distributed,
crop_pct=data_config['crop_pct'],
pin_memory=args.pin_mem,
device=device,
)
# setup loss function
@ -628,8 +643,8 @@ def main():
train_loss_fn = LabelSmoothingCrossEntropy(smoothing=args.smoothing)
else:
train_loss_fn = nn.CrossEntropyLoss()
train_loss_fn = train_loss_fn.cuda()
validate_loss_fn = nn.CrossEntropyLoss().cuda()
train_loss_fn = train_loss_fn.to(device=device)
validate_loss_fn = nn.CrossEntropyLoss().to(device=device)
# setup checkpoint saver and eval metric tracking
eval_metric = args.eval_metric
@ -637,7 +652,7 @@ def main():
best_epoch = None
saver = None
output_dir = None
if args.rank == 0:
if utils.is_primary(args):
if args.experiment:
exp_name = args.experiment
else:
@ -649,8 +664,16 @@ def main():
output_dir = utils.get_outdir(args.output if args.output else './output/train', exp_name)
decreasing = True if eval_metric == 'loss' else False
saver = utils.CheckpointSaver(
model=model, optimizer=optimizer, args=args, model_ema=model_ema, amp_scaler=loss_scaler,
checkpoint_dir=output_dir, recovery_dir=output_dir, decreasing=decreasing, max_history=args.checkpoint_hist)
model=model,
optimizer=optimizer,
args=args,
model_ema=model_ema,
amp_scaler=loss_scaler,
checkpoint_dir=output_dir,
recovery_dir=output_dir,
decreasing=decreasing,
max_history=args.checkpoint_hist
)
with open(os.path.join(output_dir, 'args.yaml'), 'w') as f:
f.write(args_text)
@ -660,22 +683,46 @@ def main():
loader_train.sampler.set_epoch(epoch)
train_metrics = train_one_epoch(
epoch, model, loader_train, optimizer, train_loss_fn, args,
lr_scheduler=lr_scheduler, saver=saver, output_dir=output_dir,
amp_autocast=amp_autocast, loss_scaler=loss_scaler, model_ema=model_ema, mixup_fn=mixup_fn)
epoch,
model,
loader_train,
optimizer,
train_loss_fn,
args,
lr_scheduler=lr_scheduler,
saver=saver,
output_dir=output_dir,
amp_autocast=amp_autocast,
loss_scaler=loss_scaler,
model_ema=model_ema,
mixup_fn=mixup_fn,
)
if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
if args.local_rank == 0:
if utils.is_primary(args):
_logger.info("Distributing BatchNorm running means and vars")
utils.distribute_bn(model, args.world_size, args.dist_bn == 'reduce')
eval_metrics = validate(model, loader_eval, validate_loss_fn, args, amp_autocast=amp_autocast)
eval_metrics = validate(
model,
loader_eval,
validate_loss_fn,
args,
amp_autocast=amp_autocast,
)
if model_ema is not None and not args.model_ema_force_cpu:
if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
utils.distribute_bn(model_ema, args.world_size, args.dist_bn == 'reduce')
ema_eval_metrics = validate(
model_ema.module, loader_eval, validate_loss_fn, args, amp_autocast=amp_autocast, log_suffix=' (EMA)')
model_ema.module,
loader_eval,
validate_loss_fn,
args,
amp_autocast=amp_autocast,
log_suffix=' (EMA)',
)
eval_metrics = ema_eval_metrics
if lr_scheduler is not None:
@ -684,8 +731,13 @@ def main():
if output_dir is not None:
utils.update_summary(
epoch, train_metrics, eval_metrics, os.path.join(output_dir, 'summary.csv'),
write_header=best_metric is None, log_wandb=args.log_wandb and has_wandb)
epoch,
train_metrics,
eval_metrics,
os.path.join(output_dir, 'summary.csv'),
write_header=best_metric is None,
log_wandb=args.log_wandb and has_wandb,
)
if saver is not None:
# save proper checkpoint with eval metric
@ -699,10 +751,21 @@ def main():
def train_one_epoch(
epoch, model, loader, optimizer, loss_fn, args,
lr_scheduler=None, saver=None, output_dir=None, amp_autocast=suppress,
loss_scaler=None, model_ema=None, mixup_fn=None):
epoch,
model,
loader,
optimizer,
loss_fn,
args,
device=torch.device('cuda'),
lr_scheduler=None,
saver=None,
output_dir=None,
amp_autocast=suppress,
loss_scaler=None,
model_ema=None,
mixup_fn=None
):
if args.mixup_off_epoch and epoch >= args.mixup_off_epoch:
if args.prefetcher and loader.mixup_enabled:
loader.mixup_enabled = False
@ -723,7 +786,7 @@ def train_one_epoch(
last_batch = batch_idx == last_idx
data_time_m.update(time.time() - end)
if not args.prefetcher:
input, target = input.cuda(), target.cuda()
input, target = input.to(device), target.to(device)
if mixup_fn is not None:
input, target = mixup_fn(input, target)
if args.channels_last:
@ -740,21 +803,26 @@ def train_one_epoch(
if loss_scaler is not None:
loss_scaler(
loss, optimizer,
clip_grad=args.clip_grad, clip_mode=args.clip_mode,
clip_grad=args.clip_grad,
clip_mode=args.clip_mode,
parameters=model_parameters(model, exclude_head='agc' in args.clip_mode),
create_graph=second_order)
create_graph=second_order
)
else:
loss.backward(create_graph=second_order)
if args.clip_grad is not None:
utils.dispatch_clip_grad(
model_parameters(model, exclude_head='agc' in args.clip_mode),
value=args.clip_grad, mode=args.clip_mode)
value=args.clip_grad,
mode=args.clip_mode
)
optimizer.step()
if model_ema is not None:
model_ema.update(model)
torch.cuda.synchronize()
num_updates += 1
batch_time_m.update(time.time() - end)
if last_batch or batch_idx % args.log_interval == 0:
@ -765,7 +833,7 @@ def train_one_epoch(
reduced_loss = utils.reduce_tensor(loss.data, args.world_size)
losses_m.update(reduced_loss.item(), input.size(0))
if args.local_rank == 0:
if utils.is_primary(args):
_logger.info(
'Train: {} [{:>4d}/{} ({:>3.0f}%)] '
'Loss: {loss.val:#.4g} ({loss.avg:#.3g}) '
@ -781,14 +849,16 @@ def train_one_epoch(
rate=input.size(0) * args.world_size / batch_time_m.val,
rate_avg=input.size(0) * args.world_size / batch_time_m.avg,
lr=lr,
data_time=data_time_m))
data_time=data_time_m)
)
if args.save_images and output_dir:
torchvision.utils.save_image(
input,
os.path.join(output_dir, 'train-batch-%d.jpg' % batch_idx),
padding=0,
normalize=True)
normalize=True
)
if saver is not None and args.recovery_interval and (
last_batch or (batch_idx + 1) % args.recovery_interval == 0):
@ -806,7 +876,15 @@ def train_one_epoch(
return OrderedDict([('loss', losses_m.avg)])
def validate(model, loader, loss_fn, args, amp_autocast=suppress, log_suffix=''):
def validate(
model,
loader,
loss_fn,
args,
device=torch.device('cuda'),
amp_autocast=suppress,
log_suffix=''
):
batch_time_m = utils.AverageMeter()
losses_m = utils.AverageMeter()
top1_m = utils.AverageMeter()
@ -820,8 +898,8 @@ def validate(model, loader, loss_fn, args, amp_autocast=suppress, log_suffix='')
for batch_idx, (input, target) in enumerate(loader):
last_batch = batch_idx == last_idx
if not args.prefetcher:
input = input.cuda()
target = target.cuda()
input = input.to(device)
target = target.to(device)
if args.channels_last:
input = input.contiguous(memory_format=torch.channels_last)
@ -846,7 +924,8 @@ def validate(model, loader, loss_fn, args, amp_autocast=suppress, log_suffix='')
else:
reduced_loss = loss.data
torch.cuda.synchronize()
if device.type == 'cuda':
torch.cuda.synchronize()
losses_m.update(reduced_loss.item(), input.size(0))
top1_m.update(acc1.item(), output.size(0))
@ -854,7 +933,7 @@ def validate(model, loader, loss_fn, args, amp_autocast=suppress, log_suffix='')
batch_time_m.update(time.time() - end)
end = time.time()
if args.local_rank == 0 and (last_batch or batch_idx % args.log_interval == 0):
if utils.is_primary(args) and (last_batch or batch_idx % args.log_interval == 0):
log_name = 'Test' + log_suffix
_logger.info(
'{0}: [{1:>4d}/{2}] '
@ -862,8 +941,12 @@ def validate(model, loader, loss_fn, args, amp_autocast=suppress, log_suffix='')
'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f}) '
'Acc@1: {top1.val:>7.4f} ({top1.avg:>7.4f}) '
'Acc@5: {top5.val:>7.4f} ({top5.avg:>7.4f})'.format(
log_name, batch_idx, last_idx, batch_time=batch_time_m,
loss=losses_m, top1=top1_m, top5=top5_m))
log_name, batch_idx, last_idx,
batch_time=batch_time_m,
loss=losses_m,
top1=top1_m,
top5=top5_m)
)
metrics = OrderedDict([('loss', losses_m.avg), ('top1', top1_m.avg), ('top5', top5_m.avg)])

@ -19,6 +19,7 @@ import torch.nn as nn
import torch.nn.parallel
from collections import OrderedDict
from contextlib import suppress
from functools import partial
from timm.models import create_model, apply_test_time_pool, load_checkpoint, is_model, list_models, set_fast_norm
from timm.data import create_dataset, create_loader, resolve_data_config, RealLabelsImagenet
@ -45,7 +46,6 @@ try:
except ImportError as e:
has_functorch = False
torch.backends.cudnn.benchmark = True
_logger = logging.getLogger('validate')
@ -100,6 +100,8 @@ parser.add_argument('--pin-mem', action='store_true', default=False,
help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
parser.add_argument('--channels-last', action='store_true', default=False,
help='Use channels_last memory layout')
parser.add_argument('--device', default='cuda', type=str,
help="Device (accelerator) to use.")
parser.add_argument('--amp', action='store_true', default=False,
help='Use AMP mixed precision. Defaults to Apex, fallback to native Torch AMP.')
parser.add_argument('--apex-amp', action='store_true', default=False,
@ -133,6 +135,13 @@ def validate(args):
# might as well try to validate something
args.pretrained = args.pretrained or not args.checkpoint
args.prefetcher = not args.no_prefetcher
if torch.cuda.is_available():
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.benchmark = True
device = torch.device(args.device)
amp_autocast = suppress # do nothing
if args.amp:
if has_native_amp:
@ -143,15 +152,17 @@ def validate(args):
_logger.warning("Neither APEX or Native Torch AMP is available.")
assert not args.apex_amp or not args.native_amp, "Only one AMP mode should be set."
if args.native_amp:
amp_autocast = torch.cuda.amp.autocast
amp_autocast = partial(torch.autocast, device_type=device.type)
_logger.info('Validating in mixed precision with native PyTorch AMP.')
elif args.apex_amp:
assert device.type == 'cuda'
_logger.info('Validating in mixed precision with NVIDIA APEX AMP.')
else:
_logger.info('Validating in float32. AMP not enabled.')
if args.fuser:
set_jit_fuser(args.fuser)
if args.fast_norm:
set_fast_norm()
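The autocast change above generalizes the old torch.cuda.amp.autocast call to any device type; a minimal sketch of the same pattern in isolation (use_amp is a stand-in for the --amp / --native-amp flags):

from contextlib import suppress
from functools import partial
import torch

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
use_amp = True
amp_autocast = partial(torch.autocast, device_type=device.type) if use_amp else suppress

with amp_autocast():               # bf16 on CPU, fp16 on CUDA by default
    y = torch.nn.Linear(4, 4).to(device)(torch.randn(2, 4, device=device))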
@ -162,7 +173,8 @@ def validate(args):
num_classes=args.num_classes,
in_chans=3,
global_pool=args.gp,
scriptable=args.torchscript)
scriptable=args.torchscript,
)
if args.num_classes is None:
assert hasattr(model, 'num_classes'), 'Model must have `num_classes` attr if not set on cmd line/config.'
args.num_classes = model.num_classes
@ -177,7 +189,7 @@ def validate(args):
vars(args),
model=model,
use_test_size=not args.use_train_size,
verbose=True
verbose=True,
)
test_time_pool = False
if args.test_pool:
@ -186,11 +198,12 @@ def validate(args):
if args.torchscript:
torch.jit.optimized_execution(True)
model = torch.jit.script(model)
if args.aot_autograd:
assert has_functorch, "functorch is needed for --aot-autograd"
model = memory_efficient_fusion(model)
model = model.cuda()
model = model.to(device)
if args.apex_amp:
model = amp.initialize(model, opt_level='O1')
@ -200,11 +213,16 @@ def validate(args):
if args.num_gpu > 1:
model = torch.nn.DataParallel(model, device_ids=list(range(args.num_gpu)))
criterion = nn.CrossEntropyLoss().cuda()
criterion = nn.CrossEntropyLoss().to(device)
dataset = create_dataset(
root=args.data, name=args.dataset, split=args.split,
download=args.dataset_download, load_bytes=args.tf_preprocessing, class_map=args.class_map)
root=args.data,
name=args.dataset,
split=args.split,
download=args.dataset_download,
load_bytes=args.tf_preprocessing,
class_map=args.class_map,
)
if args.valid_labels:
with open(args.valid_labels, 'r') as f:
@ -230,7 +248,9 @@ def validate(args):
num_workers=args.workers,
crop_pct=crop_pct,
pin_memory=args.pin_mem,
tf_preprocessing=args.tf_preprocessing)
device=device,
tf_preprocessing=args.tf_preprocessing,
)
batch_time = AverageMeter()
losses = AverageMeter()
@ -240,7 +260,7 @@ def validate(args):
model.eval()
with torch.no_grad():
# warmup, reduce variability of first batch time, especially for comparing torchscript vs non
input = torch.randn((args.batch_size,) + tuple(data_config['input_size'])).cuda()
input = torch.randn((args.batch_size,) + tuple(data_config['input_size'])).to(device)
if args.channels_last:
input = input.contiguous(memory_format=torch.channels_last)
with amp_autocast():
@ -249,8 +269,8 @@ def validate(args):
end = time.time()
for batch_idx, (input, target) in enumerate(loader):
if args.no_prefetcher:
target = target.cuda()
input = input.cuda()
target = target.to(device)
input = input.to(device)
if args.channels_last:
input = input.contiguous(memory_format=torch.channels_last)
@ -282,9 +302,15 @@ def validate(args):
'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f}) '
'Acc@1: {top1.val:>7.3f} ({top1.avg:>7.3f}) '
'Acc@5: {top5.val:>7.3f} ({top5.avg:>7.3f})'.format(
batch_idx, len(loader), batch_time=batch_time,
batch_idx,
len(loader),
batch_time=batch_time,
rate_avg=input.size(0) / batch_time.avg,
loss=losses, top1=top1, top5=top5))
loss=losses,
top1=top1,
top5=top5
)
)
if real_labels is not None:
# real labels mode replaces topk values at the end
@ -298,7 +324,8 @@ def validate(args):
param_count=round(param_count / 1e6, 2),
img_size=data_config['input_size'][-1],
crop_pct=crop_pct,
interpolation=data_config['interpolation'])
interpolation=data_config['interpolation'],
)
_logger.info(' * Acc@1 {:.3f} ({:.3f}) Acc@5 {:.3f} ({:.3f})'.format(
results['top1'], results['top1_err'], results['top5'], results['top5_err']))
@ -313,7 +340,8 @@ def _try_run(args, initial_batch_size):
while batch_size:
args.batch_size = batch_size * args.num_gpu # multiply by num-gpu for DataParallel case
try:
torch.cuda.empty_cache()
if torch.cuda.is_available() and 'cuda' in args.device:
torch.cuda.empty_cache()
results = validate(args)
return results
except RuntimeError as e:
