Merge pull request #1479 from rwightman/script_cleanup

Train / val script enhancements, non-GPU (ie CPU) device support, HF datasets support, TFDS/WDS dataloading improvements
Ross Wightman 2 years ago committed by GitHub
commit 6635bc3f7d

@ -128,7 +128,7 @@ More models, more fixes
* `cs3`, `darknet`, and `vit_*relpos` weights above all trained on TPU thanks to TRC program! Rest trained on overheating GPUs.
* Hugging Face Hub support fixes verified, demo notebook TBA
* Pretrained weights / configs can be loaded externally (ie from local disk) w/ support for head adaptation.
* Add support to change image extensions scanned by `timm` datasets/parsers. See (https://github.com/rwightman/pytorch-image-models/pull/1274#issuecomment-1178303103)
* Add support to change image extensions scanned by `timm` datasets/readers. See (https://github.com/rwightman/pytorch-image-models/pull/1274#issuecomment-1178303103)
* Default ConvNeXt LayerNorm impl to use `F.layer_norm(x.permute(0, 2, 3, 1), ...).permute(0, 3, 1, 2)` via `LayerNorm2d` in all cases.
* a bit slower than previous custom impl on some hardware (ie Ampere w/ CL), but overall fewer regressions across wider HW / PyTorch version ranges.
* previous impl exists as `LayerNormExp2d` in `models/layers/norm.py`
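For the image-extension bullet above, a minimal usage sketch of the helpers now re-exported from `timm.data` (the `.webp` extension here is only an illustration):
from timm.data import add_img_extensions, get_img_extensions, is_img_extension

add_img_extensions('.webp')        # scan .webp alongside the default extensions
assert is_img_extension('.webp')   # membership check used by the folder / tar readers
print(get_img_extensions())        # current tuple of recognised extensions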

@ -57,7 +57,9 @@ except ImportError as e:
has_functorch = False
torch.backends.cudnn.benchmark = True
if torch.cuda.is_available():
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.benchmark = True
_logger = logging.getLogger('validate')
@ -216,7 +218,7 @@ class BenchmarkRunner:
self.device = device
self.use_amp, self.model_dtype, self.data_dtype = resolve_precision(precision)
self.channels_last = kwargs.pop('channels_last', False)
self.amp_autocast = torch.cuda.amp.autocast if self.use_amp else suppress
self.amp_autocast = partial(torch.cuda.amp.autocast, dtype=torch.float16) if self.use_amp else suppress
if fuser:
set_jit_fuser(fuser)

@ -6,8 +6,8 @@ from .dataset import ImageDataset, IterableImageDataset, AugMixDataset
from .dataset_factory import create_dataset
from .loader import create_loader
from .mixup import Mixup, FastCollateMixup
from .parsers import create_parser,\
get_img_extensions, is_img_extension, set_img_extensions, add_img_extensions, del_img_extensions
from .readers import create_reader
from .readers import get_img_extensions, is_img_extension, set_img_extensions, add_img_extensions, del_img_extensions
from .real_labels import RealLabelsImagenet
from .transforms import *
from .transforms_factory import create_transform

@ -2,14 +2,15 @@
Hacked together by / Copyright 2019, Ross Wightman
"""
import torch.utils.data as data
import os
import torch
import io
import logging
from typing import Optional
import torch
import torch.utils.data as data
from PIL import Image
from .parsers import create_parser
from .readers import create_reader
_logger = logging.getLogger(__name__)
@ -22,48 +23,62 @@ class ImageDataset(data.Dataset):
def __init__(
self,
root,
parser=None,
reader=None,
split='train',
class_map=None,
load_bytes=False,
img_mode='RGB',
transform=None,
target_transform=None,
):
if parser is None or isinstance(parser, str):
parser = create_parser(parser or '', root=root, class_map=class_map)
self.parser = parser
if reader is None or isinstance(reader, str):
reader = create_reader(
reader or '',
root=root,
split=split,
class_map=class_map
)
self.reader = reader
self.load_bytes = load_bytes
self.img_mode = img_mode
self.transform = transform
self.target_transform = target_transform
self._consecutive_errors = 0
def __getitem__(self, index):
img, target = self.parser[index]
img, target = self.reader[index]
try:
img = img.read() if self.load_bytes else Image.open(img).convert('RGB')
img = img.read() if self.load_bytes else Image.open(img)
except Exception as e:
_logger.warning(f'Skipped sample (index {index}, file {self.parser.filename(index)}). {str(e)}')
_logger.warning(f'Skipped sample (index {index}, file {self.reader.filename(index)}). {str(e)}')
self._consecutive_errors += 1
if self._consecutive_errors < _ERROR_RETRY:
return self.__getitem__((index + 1) % len(self.parser))
return self.__getitem__((index + 1) % len(self.reader))
else:
raise e
self._consecutive_errors = 0
if self.img_mode and not self.load_bytes:
img = img.convert(self.img_mode)
if self.transform is not None:
img = self.transform(img)
if target is None:
target = -1
elif self.target_transform is not None:
target = self.target_transform(target)
return img, target
def __len__(self):
return len(self.parser)
return len(self.reader)
def filename(self, index, basename=False, absolute=False):
return self.parser.filename(index, basename, absolute)
return self.reader.filename(index, basename, absolute)
def filenames(self, basename=False, absolute=False):
return self.parser.filenames(basename, absolute)
return self.reader.filenames(basename, absolute)
class IterableImageDataset(data.IterableDataset):
@ -71,28 +86,36 @@ class IterableImageDataset(data.IterableDataset):
def __init__(
self,
root,
parser=None,
reader=None,
split='train',
is_training=False,
batch_size=None,
seed=42,
repeats=0,
download=False,
transform=None,
target_transform=None,
):
assert parser is not None
if isinstance(parser, str):
self.parser = create_parser(
parser, root=root, split=split, is_training=is_training,
batch_size=batch_size, repeats=repeats, download=download)
assert reader is not None
if isinstance(reader, str):
self.reader = create_reader(
reader,
root=root,
split=split,
is_training=is_training,
batch_size=batch_size,
seed=seed,
repeats=repeats,
download=download,
)
else:
self.parser = parser
self.reader = reader
self.transform = transform
self.target_transform = target_transform
self._consecutive_errors = 0
def __iter__(self):
for img, target in self.parser:
for img, target in self.reader:
if self.transform is not None:
img = self.transform(img)
if self.target_transform is not None:
@ -100,16 +123,29 @@ class IterableImageDataset(data.IterableDataset):
yield img, target
def __len__(self):
if hasattr(self.parser, '__len__'):
return len(self.parser)
if hasattr(self.reader, '__len__'):
return len(self.reader)
else:
return 0
def set_epoch(self, count):
# TFDS and WDS need external epoch count for deterministic cross process shuffle
if hasattr(self.reader, 'set_epoch'):
self.reader.set_epoch(count)
def set_loader_cfg(
self,
num_workers: Optional[int] = None,
):
# TFDS and WDS readers need # workers for correct # samples estimate before loader processes are created
if hasattr(self.reader, 'set_loader_cfg'):
self.reader.set_loader_cfg(num_workers=num_workers)
def filename(self, index, basename=False, absolute=False):
assert False, 'Filename lookup by index not supported, use filenames().'
def filenames(self, basename=False, absolute=False):
return self.parser.filenames(basename, absolute)
return self.reader.filenames(basename, absolute)
class AugMixDataset(torch.utils.data.Dataset):

@ -60,6 +60,7 @@ def create_dataset(
is_training=False,
download=False,
batch_size=None,
seed=42,
repeats=0,
**kwargs
):
@ -68,7 +69,9 @@ def create_dataset(
In parentheses after each arg is the type of dataset supported, one of:
* folder - default, timm folder (or tar) based ImageDataset
* torch - torchvision based datasets
* HFDS - Hugging Face Datasets
* TFDS - Tensorflow-datasets wrapper in IterableDataset interface via IterableImageDataset
* WDS - Webdataset
* all - any of the above
Args:
@ -79,11 +82,12 @@ def create_dataset(
`imagenet/` instead of `/imagenet/val`, etc on cmd line / config. (folder, torch/folder)
class_map: specify class -> index mapping via text file or dict (folder)
load_bytes: load data, return images as undecoded bytes (folder)
download: download dataset if not present and supported (TFDS, torch)
download: download dataset if not present and supported (HFDS, TFDS, torch)
is_training: create dataset in train mode, this is different from the split.
For Iterable / TDFS it enables shuffle, ignored for other datasets. (TFDS)
batch_size: batch size hint for (TFDS)
repeats: dataset repeats per iteration i.e. epoch (TFDS)
For Iterable / TFDS it enables shuffle, ignored for other datasets. (TFDS, WDS)
batch_size: batch size hint for iterable datasets (TFDS, WDS)
seed: seed for iterable datasets (TFDS, WDS)
repeats: dataset repeats per iteration i.e. epoch (TFDS, WDS)
**kwargs: other args to pass to dataset
Returns:
@ -130,14 +134,37 @@ def create_dataset(
ds = ImageFolder(root, **kwargs)
else:
assert False, f"Unknown torchvision dataset {name}"
elif name.startswith('hfds/'):
# NOTE right now, HF datasets default arrow format is a random-access Dataset,
# There will be a IterableDataset variant too, TBD
ds = ImageDataset(root, reader=name, split=split, **kwargs)
elif name.startswith('tfds/'):
ds = IterableImageDataset(
root, parser=name, split=split, is_training=is_training,
download=download, batch_size=batch_size, repeats=repeats, **kwargs)
root,
reader=name,
split=split,
is_training=is_training,
download=download,
batch_size=batch_size,
repeats=repeats,
seed=seed,
**kwargs
)
elif name.startswith('wds/'):
ds = IterableImageDataset(
root,
reader=name,
split=split,
is_training=is_training,
batch_size=batch_size,
repeats=repeats,
seed=seed,
**kwargs
)
else:
# FIXME support more advanced split cfg for ImageFolder/Tar datasets in the future
if search_split and os.path.isdir(root):
# look for split specific sub-folder in root
root = _search_split(root, split)
ds = ImageDataset(root, parser=name, class_map=class_map, load_bytes=load_bytes, **kwargs)
ds = ImageDataset(root, reader=name, class_map=class_map, load_bytes=load_bytes, **kwargs)
return ds
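A hedged usage sketch of the new dataset name prefixes and the epoch hook; the dataset names and paths below are placeholders, and the `hfds/` reader currently assumes `image` / `label` features:
from timm.data import create_dataset

# random-access Hugging Face dataset, cached under root
ds_hf = create_dataset('hfds/some-hf-dataset', root='/data/hf_cache', split='train')

# iterable TFDS / WDS datasets; is_training enables shuffle, batch_size is a sizing hint
ds_tfds = create_dataset(
    'tfds/imagenet2012', root='/data/tfds', split='train',
    is_training=True, batch_size=256, seed=42)
ds_wds = create_dataset(
    'wds/some-wds-dataset', root='/data/wds-shards', split='train',
    is_training=True, batch_size=256, seed=42)

# iterable readers take an external epoch count for deterministic cross-process shuffling
for epoch in range(2):
    ds_tfds.set_epoch(epoch)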

@ -5,19 +5,25 @@ https://github.com/NVIDIA/apex/commit/d5e2bb4bdeedd27b1dfaf5bb2b24d6c000dee9be#d
Hacked together by / Copyright 2019, Ross Wightman
"""
import logging
import random
from contextlib import suppress
from functools import partial
from itertools import repeat
from typing import Callable
import torch
import torch.utils.data
import numpy as np
from .transforms_factory import create_transform
from .constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .dataset import IterableImageDataset
from .distributed_sampler import OrderedDistributedSampler, RepeatAugSampler
from .random_erasing import RandomErasing
from .mixup import FastCollateMixup
from .transforms_factory import create_transform
_logger = logging.getLogger(__name__)
def fast_collate(batch):
@ -55,11 +61,13 @@ def fast_collate(batch):
assert False
def expand_to_chs(x, n):
def adapt_to_chs(x, n):
if not isinstance(x, (tuple, list)):
x = tuple(repeat(x, n))
elif len(x) == 1:
x = x * n
elif len(x) != n:
x_mean = np.mean(x).item()
x = (x_mean,) * n
_logger.warning(f'Pretrained mean/std different shape than model, using avg value {x}.')
else:
assert len(x) == n, 'normalization stats must match image channels'
return x
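# Worked example (not part of this diff) of how adapt_to_chs treats mean/std values:
assert adapt_to_chs(0.5, 3) == (0.5, 0.5, 0.5)               # scalar repeated per channel
assert adapt_to_chs((0.5,), 3) == (0.5, 0.5, 0.5)            # 1-tuple broadcast per channel
assert adapt_to_chs((0.4, 0.5, 0.6), 3) == (0.4, 0.5, 0.6)   # matching length passed through
# a mismatched length falls back to the mean value and logs a warning,
# e.g. adapt_to_chs((0.25, 0.75), 3) -> (0.5, 0.5, 0.5)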
@ -73,41 +81,55 @@ class PrefetchLoader:
mean=IMAGENET_DEFAULT_MEAN,
std=IMAGENET_DEFAULT_STD,
channels=3,
device=torch.device('cuda'),
img_dtype=torch.float32,
fp16=False,
re_prob=0.,
re_mode='const',
re_count=1,
re_num_splits=0):
mean = expand_to_chs(mean, channels)
std = expand_to_chs(std, channels)
mean = adapt_to_chs(mean, channels)
std = adapt_to_chs(std, channels)
normalization_shape = (1, channels, 1, 1)
self.loader = loader
self.mean = torch.tensor([x * 255 for x in mean]).cuda().view(normalization_shape)
self.std = torch.tensor([x * 255 for x in std]).cuda().view(normalization_shape)
self.fp16 = fp16
self.device = device
if fp16:
self.mean = self.mean.half()
self.std = self.std.half()
# fp16 arg is deprecated, but will override dtype arg if set for bwd compat
img_dtype = torch.float16
self.img_dtype = img_dtype
self.mean = torch.tensor(
[x * 255 for x in mean], device=device, dtype=img_dtype).view(normalization_shape)
self.std = torch.tensor(
[x * 255 for x in std], device=device, dtype=img_dtype).view(normalization_shape)
if re_prob > 0.:
self.random_erasing = RandomErasing(
probability=re_prob, mode=re_mode, max_count=re_count, num_splits=re_num_splits)
probability=re_prob,
mode=re_mode,
max_count=re_count,
num_splits=re_num_splits,
device=device,
)
else:
self.random_erasing = None
self.is_cuda = torch.cuda.is_available() and device.type == 'cuda'
def __iter__(self):
stream = torch.cuda.Stream()
first = True
if self.is_cuda:
stream = torch.cuda.Stream()
stream_context = partial(torch.cuda.stream, stream=stream)
else:
stream = None
stream_context = suppress
for next_input, next_target in self.loader:
with torch.cuda.stream(stream):
next_input = next_input.cuda(non_blocking=True)
next_target = next_target.cuda(non_blocking=True)
if self.fp16:
next_input = next_input.half().sub_(self.mean).div_(self.std)
else:
next_input = next_input.float().sub_(self.mean).div_(self.std)
with stream_context():
next_input = next_input.to(device=self.device, non_blocking=True)
next_target = next_target.to(device=self.device, non_blocking=True)
next_input = next_input.to(self.img_dtype).sub_(self.mean).div_(self.std)
if self.random_erasing is not None:
next_input = self.random_erasing(next_input)
@ -116,7 +138,9 @@ class PrefetchLoader:
else:
first = False
if stream is not None:
torch.cuda.current_stream().wait_stream(stream)
input = next_input
target = next_target
@ -189,7 +213,9 @@ def create_loader(
crop_pct=None,
collate_fn=None,
pin_memory=False,
fp16=False,
fp16=False, # deprecated, use img_dtype
img_dtype=torch.float32,
device=torch.device('cuda'),
tf_preprocessing=False,
use_multi_epochs_loader=False,
persistent_workers=True,
@ -222,6 +248,11 @@ def create_loader(
separate=num_aug_splits > 0,
)
if isinstance(dataset, IterableImageDataset):
# give Iterable datasets early knowledge of num_workers so that sample estimates
# are correct before worker processes are launched
dataset.set_loader_cfg(num_workers=num_workers)
sampler = None
if distributed and not isinstance(dataset, torch.utils.data.IterableDataset):
if is_training:
@ -266,7 +297,9 @@ def create_loader(
mean=mean,
std=std,
channels=input_size[0],
fp16=fp16,
device=device,
fp16=fp16, # deprecated, use img_dtype
img_dtype=img_dtype,
re_prob=prefetch_re_prob,
re_mode=re_mode,
re_count=re_count,
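# Hedged end-to-end sketch (not part of this diff) of the new device / img_dtype plumbing;
# the dataset path below is a placeholder:
#
#   import torch
#   from timm.data import create_dataset, create_loader
#
#   dataset = create_dataset('', root='/data/imagenet', split='validation')
#   loader = create_loader(
#       dataset,
#       input_size=(3, 224, 224),
#       batch_size=64,
#       is_training=False,
#       num_workers=2,
#       device=torch.device('cpu'),   # prefetcher now normalizes on any device, not just CUDA
#       img_dtype=torch.bfloat16,     # replaces the deprecated fp16=True flag
#   )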

@ -1,2 +0,0 @@
from .parser_factory import create_parser
from .img_extensions import *

@ -1,28 +0,0 @@
import os
from .parser_image_folder import ParserImageFolder
from .parser_image_in_tar import ParserImageInTar
def create_parser(name, root, split='train', **kwargs):
name = name.lower()
name = name.split('/', 2)
prefix = ''
if len(name) > 1:
prefix = name[0]
name = name[-1]
# FIXME improve the selection right now just tfds prefix or fallback path, will need options to
# explicitly select other options shortly
if prefix == 'tfds':
from .parser_tfds import ParserTfds # defer tensorflow import
parser = ParserTfds(root, name, split=split, **kwargs)
else:
assert os.path.exists(root)
# default fallback path (backwards compat), use image tar if root is a .tar file, otherwise image folder
# FIXME support split here, in parser?
if os.path.isfile(root) and os.path.splitext(root)[1] == '.tar':
parser = ParserImageInTar(root, **kwargs)
else:
parser = ParserImageFolder(root, **kwargs)
return parser

@ -7,6 +7,7 @@ Hacked together by / Copyright 2019, Ross Wightman
"""
import random
import math
import torch
@ -44,8 +45,17 @@ class RandomErasing:
def __init__(
self,
probability=0.5, min_area=0.02, max_area=1/3, min_aspect=0.3, max_aspect=None,
mode='const', min_count=1, max_count=None, num_splits=0, device='cuda'):
probability=0.5,
min_area=0.02,
max_area=1/3,
min_aspect=0.3,
max_aspect=None,
mode='const',
min_count=1,
max_count=None,
num_splits=0,
device='cuda',
):
self.probability = probability
self.min_area = min_area
self.max_area = max_area
@ -81,8 +91,12 @@ class RandomErasing:
top = random.randint(0, img_h - h)
left = random.randint(0, img_w - w)
img[:, top:top + h, left:left + w] = _get_pixels(
self.per_pixel, self.rand_color, (chan, h, w),
dtype=dtype, device=self.device)
self.per_pixel,
self.rand_color,
(chan, h, w),
dtype=dtype,
device=self.device,
)
break
def __call__(self, input):

@ -0,0 +1,2 @@
from .reader_factory import create_reader
from .img_extensions import *

@ -1,7 +1,7 @@
from abc import abstractmethod
class Parser:
class Reader:
def __init__(self):
pass

@ -0,0 +1,35 @@
import os
from .reader_image_folder import ReaderImageFolder
from .reader_image_in_tar import ReaderImageInTar
def create_reader(name, root, split='train', **kwargs):
name = name.lower()
name = name.split('/', 2)
prefix = ''
if len(name) > 1:
prefix = name[0]
name = name[-1]
# FIXME improve the selection right now just tfds prefix or fallback path, will need options to
# explicitly select other options shortly
if prefix == 'hfds':
from .reader_hfds import ReaderHfds # defer Hf datasets import
reader = ReaderHfds(root, name, split=split, **kwargs)
elif prefix == 'tfds':
from .reader_tfds import ReaderTfds # defer tensorflow import
reader = ReaderTfds(root, name, split=split, **kwargs)
elif prefix == 'wds':
from .reader_wds import ReaderWds
kwargs.pop('download', False)
reader = ReaderWds(root, name, split=split, **kwargs)
else:
assert os.path.exists(root)
# default fallback path (backwards compat), use image tar if root is a .tar file, otherwise image folder
# FIXME support split here or in reader?
if os.path.isfile(root) and os.path.splitext(root)[1] == '.tar':
reader = ReaderImageInTar(root, **kwargs)
else:
reader = ReaderImageFolder(root, **kwargs)
return reader
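# Reader selection by prefix, for reference (paths and dataset names are placeholders):
#   create_reader('hfds/some-dataset', root='/cache/hf')    -> ReaderHfds
#   create_reader('tfds/imagenet2012', root='/data/tfds')   -> ReaderTfds
#   create_reader('wds/some-dataset', root='/data/shards')  -> ReaderWds
#   create_reader('', root='/data/train.tar')               -> ReaderImageInTar (single tar file)
#   create_reader('', root='/data/train/')                  -> ReaderImageFolder (recursive folder scan)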

@ -0,0 +1,70 @@
""" Dataset reader that wraps Hugging Face datasets
Hacked together by / Copyright 2022 Ross Wightman
"""
import io
import math
import torch
import torch.distributed as dist
from PIL import Image
try:
import datasets
except ImportError as e:
print("Please install Hugging Face datasets package `pip install datasets`.")
exit(1)
from .reader import Reader
def get_class_labels(info):
if 'label' not in info.features:
return {}
class_label = info.features['label']
class_to_idx = {n: class_label.str2int(n) for n in class_label.names}
return class_to_idx
class ReaderHfds(Reader):
def __init__(
self,
root,
name,
split='train',
class_map=None,
download=False,
):
"""
"""
super().__init__()
self.root = root
self.split = split
self.dataset = datasets.load_dataset(
name, # 'name' maps to path arg in hf datasets
split=split,
cache_dir=self.root, # timm doesn't expect hidden cache dir for datasets, specify a path
#use_auth_token=True,
)
# leave decode for caller, plus we want easy access to original path names...
self.dataset = self.dataset.cast_column('image', datasets.Image(decode=False))
self.class_to_idx = get_class_labels(self.dataset.info)
self.split_info = self.dataset.info.splits[split]
self.num_samples = self.split_info.num_examples
def __getitem__(self, index):
item = self.dataset[index]
image = item['image']
if 'bytes' in image and image['bytes']:
image = io.BytesIO(image['bytes'])
else:
assert 'path' in image and image['path']
image = open(image['path'], 'rb')
return image, item['label']
def __len__(self):
return len(self.dataset)
def _filename(self, index, basename=False, absolute=False):
item = self.dataset[index]
return item['image']['path']

@ -1,6 +1,6 @@
""" A dataset parser that reads images from folders
""" A dataset reader that extracts images from folders
Folders are scannerd recursively to find image files. Labels are based
Folders are scanned recursively to find image files. Labels are based
on the folder hierarchy, just leaf folders by default.
Hacked together by / Copyright 2020 Ross Wightman
@ -12,7 +12,7 @@ from timm.utils.misc import natural_key
from .class_map import load_class_map
from .img_extensions import get_img_extensions
from .parser import Parser
from .reader import Reader
def find_images_and_targets(
@ -56,7 +56,7 @@ def find_images_and_targets(
return images_and_targets, class_to_idx
class ParserImageFolder(Parser):
class ReaderImageFolder(Reader):
def __init__(
self,

@ -1,6 +1,6 @@
""" A dataset parser that reads tarfile based datasets
""" A dataset reader that reads tarfile based datasets
This parser can read and extract image samples from:
This reader can extract image samples from:
* a single tar of image files
* a folder of multiple tarfiles containing imagefiles
* a tar of tars containing image files
@ -22,7 +22,7 @@ from timm.utils.misc import natural_key
from .class_map import load_class_map
from .img_extensions import get_img_extensions
from .parser import Parser
from .reader import Reader
_logger = logging.getLogger(__name__)
CACHE_FILENAME_SUFFIX = '_tarinfos.pickle'
@ -169,8 +169,8 @@ def extract_tarinfos(
return samples, targets, class_name_to_idx, tarfiles
class ParserImageInTar(Parser):
""" Multi-tarfile dataset parser where there is one .tar file per class
class ReaderImageInTar(Reader):
""" Multi-tarfile dataset reader where there is one .tar file per class
"""
def __init__(self, root, class_map='', cache_tarfiles=True, cache_tarinfo=None):

@ -1,6 +1,6 @@
""" A dataset parser that reads single tarfile based datasets
""" A dataset reader that reads single tarfile based datasets
This parser can read datasets consisting if a single tarfile containing images.
This reader can read datasets consisting of a single tarfile containing images.
I am planning to deprecate it in favour of ReaderImageInTar.
Hacked together by / Copyright 2020 Ross Wightman
@ -12,7 +12,7 @@ from timm.utils.misc import natural_key
from .class_map import load_class_map
from .img_extensions import get_img_extensions
from .parser import Parser
from .reader import Reader
def extract_tarinfo(tarfile, class_to_idx=None, sort=True):
@ -38,9 +38,9 @@ def extract_tarinfo(tarfile, class_to_idx=None, sort=True):
return tarinfo_and_targets, class_to_idx
class ParserImageTar(Parser):
class ReaderImageTar(Reader):
""" Single tarfile dataset where classes are mapped to folders within tar
NOTE: This class is being deprecated in favour of the more capable ParserImageInTar that can
NOTE: This class is being deprecated in favour of the more capable ReaderImageInTar that can
operate on folders of tars or tars in tars.
"""
def __init__(self, root, class_map=''):

@ -1,4 +1,4 @@
""" Dataset parser interface that wraps TFDS datasets
""" Dataset reader that wraps TFDS datasets
Wraps many (most?) TFDS image-classification datasets
from https://github.com/tensorflow/datasets
@ -7,6 +7,9 @@ https://www.tensorflow.org/datasets/catalog/overview#image_classification
Hacked together by / Copyright 2020 Ross Wightman
"""
import math
import os
from typing import Optional
import torch
import torch.distributed as dist
from PIL import Image
@ -30,16 +33,18 @@ except ImportError as e:
print(e)
print("Please install tensorflow_datasets package `pip install tensorflow-datasets`.")
exit(1)
from .parser import Parser
from .reader import Reader
from .shared_count import SharedCount
MAX_TP_SIZE = 8 # maximum TF threadpool size, only doing jpeg decodes and queuing activities
SHUFFLE_SIZE = 8192 # examples to shuffle in DS queue
PREFETCH_SIZE = 2048 # examples to prefetch
MAX_TP_SIZE = int(os.environ.get('TFDS_TP_SIZE', 8)) # maximum TF threadpool size, for jpeg decodes and queuing activities
SHUFFLE_SIZE = int(os.environ.get('TFDS_SHUFFLE_SIZE', 8192)) # samples to shuffle in DS queue
PREFETCH_SIZE = int(os.environ.get('TFDS_PREFETCH_SIZE', 2048)) # samples to prefetch
def even_split_indices(split, n, num_examples):
partitions = [round(i * num_examples / n) for i in range(n + 1)]
def even_split_indices(split, n, num_samples):
partitions = [round(i * num_samples / n) for i in range(n + 1)]
return [f"{split}[{partitions[i]}:{partitions[i + 1]}]" for i in range(n)]
@ -51,24 +56,24 @@ def get_class_labels(info):
return class_to_idx
class ParserTfds(Parser):
class ReaderTfds(Reader):
""" Wrap Tensorflow Datasets for use in PyTorch
There are several things to be aware of:
* To prevent excessive examples being dropped per epoch w/ distributed training or multiplicity of
* To prevent excessive samples being dropped per epoch w/ distributed training or multiplicity of
dataloader workers, the train iterator wraps to avoid returning partial batches that trigger drop_last
https://github.com/pytorch/pytorch/issues/33413
* With PyTorch IterableDatasets, each worker in each replica operates in isolation, the final batch
from each worker could be a different size. For training this is worked around by option above, for
validation extra examples are inserted iff distributed mode is enabled so that the batches being reduced
validation extra samples are inserted iff distributed mode is enabled so that the batches being reduced
across replicas are of same size. This will slightly alter the results, distributed validation will not be
100% correct. This is similar to common handling in DistributedSampler for normal Datasets but a bit worse
since there are up to N * J extra examples with IterableDatasets.
since there are up to N * J extra samples with IterableDatasets.
* The sharding (splitting of dataset into TFRecord) files imposes limitations on the number of
replicas and dataloader workers you can use. For really small datasets that only contain a few shards
you may have to train non-distributed w/ 1-2 dataloader workers. This is likely not a huge concern as the
benefit of distributed training or fast dataloading should be much less for small datasets.
* This wrapper is currently configured to return individual, decompressed image examples from the TFDS
* This wrapper is currently configured to return individual, decompressed image samples from the TFDS
dataset. The augmentation (transforms) and batching is still done in PyTorch. It would be possible
to specify TF augmentation fn and return augmented batches w/ some modifications to other downstream
components.
@ -86,9 +91,9 @@ class ParserTfds(Parser):
repeats=0,
seed=42,
input_name='image',
input_image='RGB',
input_img_mode='RGB',
target_name='label',
target_image='',
target_img_mode='',
prefetch_size=None,
shuffle_size=None,
max_threadpool_size=None
@ -100,14 +105,14 @@ class ParserTfds(Parser):
name: tfds dataset name (eg `imagenet2012`)
split: tfds dataset split (can use all TFDS split strings eg `train[:10%]`)
is_training: training mode, shuffle enabled, dataset len rounded by batch_size
batch_size: batch_size to use to unsure total examples % batch_size == 0 in training across all dis nodes
batch_size: batch_size to use to ensure total samples % batch_size == 0 in training across all distributed nodes
download: download and build TFDS dataset if set, otherwise must use tfds CLI
repeats: iterate through (repeat) the dataset this many times per iteration (once if 0 or 1)
seed: common seed for shard shuffle across all distributed/worker instances
input_name: name of Feature to return as data (input)
input_image: image mode if input is an image (currently PIL mode string)
input_img_mode: image mode if input is an image (currently PIL mode string)
target_name: name of Feature to return as target (label)
target_image: image mode if target is an image (currently PIL mode string)
target_img_mode: image mode if target is an image (currently PIL mode string)
prefetch_size: override default tf.data prefetch buffer size
shuffle_size: override default tf.data shuffle buffer size
max_threadpool_size: override default threadpool size for tf.data
@ -130,16 +135,16 @@ class ParserTfds(Parser):
# TFDS builder and split information
self.input_name = input_name # FIXME support tuples / lists of inputs and targets and full range of Feature
self.input_image = input_image
self.input_img_mode = input_img_mode
self.target_name = target_name
self.target_image = target_image
self.target_img_mode = target_img_mode
self.builder = tfds.builder(name, data_dir=root)
# NOTE: the tfds command line app can be used to download & prepare datasets if you don't enable the download flag
if download:
self.builder.download_and_prepare()
self.class_to_idx = get_class_labels(self.builder.info) if self.target_name == 'label' else {}
self.split_info = self.builder.info.splits[split]
self.num_examples = self.split_info.num_examples
self.num_samples = self.split_info.num_examples
# Distributed world state
self.dist_rank = 0
@ -150,10 +155,29 @@ class ParserTfds(Parser):
# Attributes that are updated in _lazy_init, including the tf.data pipeline itself
self.global_num_workers = 1
self.num_workers = 1
self.worker_info = None
self.worker_seed = 0 # seed unique to each work instance
self.subsplit = None # set when data is distributed across workers using sub-splits
self.ds = None # initialized lazily on each dataloader worker process
self.init_count = 0 # number of ds TF data pipeline initializations
self.epoch_count = SharedCount()
# FIXME need to determine if reinit_each_iter is necessary. I don't completely trust the behaviour
# of `shuffle_reshuffle_each_iteration` when there are multiple workers / nodes across epochs
self.reinit_each_iter = self.is_training
def set_epoch(self, count):
self.epoch_count.value = count
def set_loader_cfg(
self,
num_workers: Optional[int] = None,
):
if self.ds is not None:
return
if num_workers is not None:
self.num_workers = num_workers
self.global_num_workers = self.dist_num_replicas * self.num_workers
def _lazy_init(self):
""" Lazily initialize the dataset.
@ -174,9 +198,9 @@ class ParserTfds(Parser):
if worker_info is not None:
self.worker_info = worker_info
self.worker_seed = worker_info.seed
num_workers = worker_info.num_workers
self.global_num_workers = self.dist_num_replicas * num_workers
global_worker_id = self.dist_rank * num_workers + worker_info.id
self.num_workers = worker_info.num_workers
self.global_num_workers = self.dist_num_replicas * self.num_workers
global_worker_id = self.dist_rank * self.num_workers + worker_info.id
""" Data sharding
InputContext will assign subset of underlying TFRecord files to each 'pipeline' if used.
@ -186,17 +210,17 @@ class ParserTfds(Parser):
I am currently using a mix of InputContext shard assignment and fine-grained sub-splits for distributing
the data across workers. For training InputContext is used to assign shards to nodes unless num_shards
in dataset < total number of workers. Otherwise sub-split API is used for datasets without enough shards or
for validation where we can't drop examples and need to avoid minimize uneven splits to avoid padding.
for validation where we can't drop samples and need to minimize uneven splits to avoid padding.
"""
should_subsplit = self.global_num_workers > 1 and (
self.split_info.num_shards < self.global_num_workers or not self.is_training)
if should_subsplit:
# split the dataset w/o using sharding for more even examples / worker, can result in less optimal
# split the dataset w/o using sharding for more even samples / worker, can result in less optimal
# read patterns for distributed training (overlap across shards) so better to use InputContext there
if has_buggy_even_splits:
# my even_split workaround doesn't work on subsplits, upgrade tfds!
if not isinstance(self.split_info, tfds.core.splits.SubSplitInfo):
subsplits = even_split_indices(self.split, self.global_num_workers, self.num_examples)
subsplits = even_split_indices(self.split, self.global_num_workers, self.num_samples)
self.subsplit = subsplits[global_worker_id]
else:
subsplits = tfds.even_splits(self.split, self.global_num_workers)
@ -211,15 +235,19 @@ class ParserTfds(Parser):
num_replicas_in_sync=self.dist_num_replicas # FIXME does this arg have any impact?
)
read_config = tfds.ReadConfig(
shuffle_seed=self.common_seed,
shuffle_seed=self.common_seed + self.epoch_count.value,
shuffle_reshuffle_each_iteration=True,
input_context=input_context)
input_context=input_context,
)
ds = self.builder.as_dataset(
split=self.subsplit or self.split, shuffle_files=self.is_training, read_config=read_config)
split=self.subsplit or self.split,
shuffle_files=self.is_training,
read_config=read_config,
)
# avoid overloading threading w/ combo of TF ds threads + PyTorch workers
options = tf.data.Options()
thread_member = 'threading' if hasattr(options, 'threading') else 'experimental_threading'
getattr(options, thread_member).private_threadpool_size = max(1, self.max_threadpool_size // num_workers)
getattr(options, thread_member).private_threadpool_size = max(1, self.max_threadpool_size // self.num_workers)
getattr(options, thread_member).max_intra_op_parallelism = 1
ds = ds.with_options(options)
if self.is_training or self.repeats > 1:
@ -227,59 +255,65 @@ class ParserTfds(Parser):
# see warnings at https://pytorch.org/docs/stable/data.html#multi-process-data-loading
ds = ds.repeat() # allow wrap around and break iteration manually
if self.is_training:
ds = ds.shuffle(min(self.num_examples, self.shuffle_size) // self.global_num_workers, seed=self.worker_seed)
ds = ds.prefetch(min(self.num_examples // self.global_num_workers, self.prefetch_size))
ds = ds.shuffle(min(self.num_samples, self.shuffle_size) // self.global_num_workers, seed=self.worker_seed)
ds = ds.prefetch(min(self.num_samples // self.global_num_workers, self.prefetch_size))
self.ds = tfds.as_numpy(ds)
self.init_count += 1
def _num_samples_per_worker(self):
num_worker_samples = \
max(1, self.repeats) * self.num_samples / max(self.global_num_workers, self.dist_num_replicas)
if self.is_training or self.dist_num_replicas > 1:
num_worker_samples = math.ceil(num_worker_samples)
if self.is_training and self.batch_size is not None:
num_worker_samples = math.ceil(num_worker_samples / self.batch_size) * self.batch_size
return int(num_worker_samples)
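# Worked example with illustrative numbers: 1,281,167 samples, 4 replicas x 8 workers = 32 global
# workers, batch_size=256, is_training=True -> 1,281,167 / 32 = 40,036.47 -> ceil -> 40,037 ->
# rounded up to a multiple of the batch size -> 40,192 samples per worker per epoch.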
def __iter__(self):
if self.ds is None:
if self.ds is None or self.reinit_each_iter:
self._lazy_init()
# Compute a rounded up sample count that is used to:
# 1. make batches even across workers & replicas in distributed validation.
# This adds extra examples and will slightly alter validation results.
# This adds extra samples and will slightly alter validation results.
# 2. determine loop ending condition in training w/ repeat enabled so that only full batch_size
# batches are produced (underlying tfds iter wraps around)
target_example_count = math.ceil(max(1, self.repeats) * self.num_examples / self.global_num_workers)
if self.is_training:
# round up to nearest batch_size per worker-replica
target_example_count = math.ceil(target_example_count / self.batch_size) * self.batch_size
target_sample_count = self._num_samples_per_worker()
# Iterate until exhausted or sample count hits target when training (ds.repeat enabled)
example_count = 0
for example in self.ds:
input_data = example[self.input_name]
if self.input_image:
input_data = Image.fromarray(input_data, mode=self.input_image)
target_data = example[self.target_name]
if self.target_image:
target_data = Image.fromarray(target_data, mode=self.target_image)
sample_count = 0
for sample in self.ds:
input_data = sample[self.input_name]
if self.input_img_mode:
input_data = Image.fromarray(input_data, mode=self.input_img_mode)
target_data = sample[self.target_name]
if self.target_img_mode:
target_data = Image.fromarray(target_data, mode=self.target_img_mode)
yield input_data, target_data
example_count += 1
if self.is_training and example_count >= target_example_count:
sample_count += 1
if self.is_training and sample_count >= target_sample_count:
# Need to break out of loop when repeat() is enabled for training w/ oversampling
# this results in extra examples per epoch but seems more desirable than dropping
# this results in extra samples per epoch but seems more desirable than dropping
# up to N*J batches per epoch (where N = num distributed processes, and J = num worker processes)
break
# Pad across distributed nodes (make counts equal by adding examples)
# Pad across distributed nodes (make counts equal by adding samples)
if not self.is_training and self.dist_num_replicas > 1 and self.subsplit is not None and \
0 < example_count < target_example_count:
0 < sample_count < target_sample_count:
# Validation batch padding only done for distributed training where results are reduced across nodes.
# For single process case, it won't matter if workers return different batch sizes.
# If using input_context or % based splits, sample count can vary significantly across workers and this
# approach should not be used (hence disabled if self.subsplit isn't set).
while example_count < target_example_count:
while sample_count < target_sample_count:
yield input_data, target_data # yield prev sample again
example_count += 1
sample_count += 1
def __len__(self):
# this is just an estimate and does not factor in extra examples added to pad batches based on
# complete worker & replica info (not available until init in dataloader).
return math.ceil(max(1, self.repeats) * self.num_examples / self.dist_num_replicas)
num_samples = self._num_samples_per_worker() * self.num_workers
return num_samples
def _filename(self, index, basename=False, absolute=False):
assert False, "Not supported" # no random access to examples
assert False, "Not supported" # no random access to samples
def filenames(self, basename=False, absolute=False):
""" Return all filenames in dataset, overrides base"""
@ -287,7 +321,7 @@ class ParserTfds(Parser):
self._lazy_init()
names = []
for sample in self.ds:
if len(names) > self.num_examples:
if len(names) > self.num_samples:
break # safety for ds.repeat() case
if 'file_name' in sample:
name = sample['file_name']

@ -0,0 +1,461 @@
""" Dataset reader for webdataset
Hacked together by / Copyright 2022 Ross Wightman
"""
import io
import json
import logging
import math
import os
import random
import sys
from dataclasses import dataclass
from functools import partial
from itertools import islice
from typing import Any, Callable, Dict, List, Optional, Tuple
import torch
import torch.distributed as dist
import yaml
from PIL import Image
from torch.utils.data import Dataset, IterableDataset, get_worker_info
try:
import webdataset as wds
from webdataset.filters import _shuffle
from webdataset.shardlists import expand_urls
from webdataset.tariterators import base_plus_ext, url_opener, tar_file_expander, valid_sample
except ImportError:
wds = None
expand_urls = None
from .reader import Reader
from .shared_count import SharedCount
_logger = logging.getLogger(__name__)
SHUFFLE_SIZE = int(os.environ.get('WDS_SHUFFLE_SIZE', 8192))
def _load_info(root, basename='info'):
info_json = os.path.join(root, basename + '.json')
info_yaml = os.path.join(root, basename + '.yaml')
err_str = ''
try:
with wds.gopen.gopen(info_json) as f:
info_dict = json.load(f)
return info_dict
except Exception as e:
err_str = str(e)
try:
with wds.gopen.gopen(info_yaml) as f:
info_dict = yaml.safe_load(f)
return info_dict
except Exception:
pass
_logger.warning(
f'Dataset info file not found at {info_json} or {info_yaml}. Error: {err_str}. '
'Falling back to provided split and size arg.')
return {}
@dataclass
class SplitInfo:
num_samples: int
filenames: Tuple[str]
shard_lengths: Tuple[int] = ()
alt_label: str = ''
name: str = ''
def _parse_split_info(split: str, info: Dict):
def _info_convert(dict_info):
return SplitInfo(
num_samples=dict_info['num_samples'],
filenames=tuple(dict_info['filenames']),
shard_lengths=tuple(dict_info['shard_lengths']),
alt_label=dict_info.get('alt_label', ''),
name=dict_info['name'],
)
if 'tar' in split or '..' in split:
# split in WDS string braceexpand format, sample count can be included with a | separator
# ex: `dataset-split-{0000..9999}.tar|100000` for 10,000 shards, covering 100,000 samples
split = split.split('|')
num_samples = 0
split_name = ''
if len(split) > 1:
num_samples = int(split[1])
split = split[0]
if '::' not in split:
split_parts = split.split('-', 3)
split_idx = len(split_parts) - 1
if split_idx and 'splits' in info and split_parts[split_idx] in info['splits']:
split_name = split_parts[split_idx]
split_filenames = expand_urls(split)
if split_name:
split_info = info['splits'][split_name]
if not num_samples:
_fc = {f: c for f, c in zip(split_info['filenames'], split_info['shard_lengths'])}
num_samples = sum(_fc[f] for f in split_filenames)
split_info['filenames'] = tuple(_fc.keys())
split_info['shard_lengths'] = tuple(_fc.values())
split_info['num_samples'] = num_samples
split_info = _info_convert(split_info)
else:
split_info = SplitInfo(
name=split_name,
num_samples=num_samples,
filenames=split_filenames,
)
else:
if split not in info['splits']:
raise RuntimeError(f"split {split} not found in info ({info['splits'].keys()})")
split = split
split_info = info['splits'][split]
split_info = _info_convert(split_info)
return split_info
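# Illustrative split strings handled above (shard pattern and counts are made up):
#   'train'                                    -> resolved from the dataset's info.json / .yaml splits
#   'mydata-train-{000000..000127}.tar|12800'  -> 128 explicit shards; the '|' suffix supplies the
#                                                 total sample count when no info file entry exists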
def log_and_continue(exn):
"""Call in an exception handler to ignore any exception, isssue a warning, and continue."""
_logger.warning(f'Handling webdataset error ({repr(exn)}). Ignoring.')
return True
def _decode(
sample,
image_key='jpg',
image_format='RGB',
target_key='cls',
alt_label=''
):
""" Custom sample decode
* decode and convert PIL Image
* cls byte string label to int
* pass through JSON byte string (if it exists) without parse
"""
# decode class label, skip if alternate label not valid
if alt_label:
# alternative labels are encoded in json metadata
meta = json.loads(sample['json'])
class_label = int(meta[alt_label])
if class_label < 0:
# skipped labels currently encoded as -1, may change to a null/None value
return None
else:
class_label = int(sample[target_key])
# decode image
with io.BytesIO(sample[image_key]) as b:
img = Image.open(b)
img.load()
if image_format:
img = img.convert(image_format)
# json passed through in undecoded state
decoded = dict(jpg=img, cls=class_label, json=sample.get('json', None))
return decoded
def _decode_samples(
data,
image_key='jpg',
image_format='RGB',
target_key='cls',
alt_label='',
handler=log_and_continue):
"""Decode samples with skip."""
for sample in data:
try:
result = _decode(
sample,
image_key=image_key,
image_format=image_format,
target_key=target_key,
alt_label=alt_label
)
except Exception as exn:
if handler(exn):
continue
else:
break
# null results are skipped
if result is not None:
if isinstance(sample, dict) and isinstance(result, dict):
result["__key__"] = sample.get("__key__")
yield result
def pytorch_worker_seed():
"""get dataloader worker seed from pytorch"""
worker_info = get_worker_info()
if worker_info is not None:
# favour the seed already created for pytorch dataloader workers if it exists
return worker_info.seed
# fallback to wds rank based seed
return wds.utils.pytorch_worker_seed()
if wds is not None:
# conditional to avoid mandatory wds import (via inheritance of wds.PipelineStage)
class detshuffle2(wds.PipelineStage):
def __init__(
self,
bufsize=1000,
initial=100,
seed=0,
epoch=-1,
):
self.bufsize = bufsize
self.initial = initial
self.seed = seed
self.epoch = epoch
def run(self, src):
if isinstance(self.epoch, SharedCount):
epoch = self.epoch.value
else:
# NOTE: this epoch tracking is problematic in a multiprocess (dataloader workers or train)
# situation as different workers may wrap at different times (or not at all).
self.epoch += 1
epoch = self.epoch
if self.seed < 0:
seed = pytorch_worker_seed() + epoch
else:
seed = self.seed + epoch
# _logger.info(f'shuffle seed: {self.seed}, {seed}, epoch: {epoch}') # FIXME temporary
rng = random.Random(seed)
return _shuffle(src, self.bufsize, self.initial, rng)
else:
detshuffle2 = None
class ResampledShards2(IterableDataset):
"""An iterable dataset yielding a list of urls."""
def __init__(
self,
urls,
nshards=sys.maxsize,
worker_seed=None,
deterministic=True,
epoch=-1,
):
"""Sample shards from the shard list with replacement.
:param urls: a list of URLs as a Python list or brace notation string
"""
super().__init__()
urls = wds.shardlists.expand_urls(urls)
self.urls = urls
assert isinstance(self.urls[0], str)
self.nshards = nshards
self.rng = random.Random()
self.worker_seed = pytorch_worker_seed if worker_seed is None else worker_seed
self.deterministic = deterministic
self.epoch = epoch
def __iter__(self):
"""Return an iterator over the shards."""
if isinstance(self.epoch, SharedCount):
epoch = self.epoch.value
else:
# NOTE: this epoch tracking is problematic in a multiprocess (dataloader workers or train)
# situation as different workers may wrap at different times (or not at all).
self.epoch += 1
epoch = self.epoch
if self.deterministic:
# reset seed w/ epoch if deterministic, worker seed should be deterministic due to arg.seed
self.rng = random.Random(self.worker_seed() + epoch)
for _ in range(self.nshards):
index = self.rng.randint(0, len(self.urls) - 1)
yield dict(url=self.urls[index])
class ReaderWds(Reader):
def __init__(
self,
root,
name,
split,
is_training=False,
batch_size=None,
repeats=0,
seed=42,
input_name='jpg',
input_image='RGB',
target_name='cls',
target_image='',
prefetch_size=None,
shuffle_size=None,
):
super().__init__()
if wds is None:
raise RuntimeError(
'Please install webdataset 0.2.x package `pip install git+https://github.com/webdataset/webdataset`.')
self.root = root
self.is_training = is_training
self.batch_size = batch_size
self.repeats = repeats
self.common_seed = seed # a seed that's fixed across all worker / distributed instances
self.shard_shuffle_size = 500
self.sample_shuffle_size = shuffle_size or SHUFFLE_SIZE
self.image_key = input_name
self.image_format = input_image
self.target_key = target_name
self.filename_key = 'filename'
self.key_ext = '.JPEG' # extension to add to key for original filenames (DS specific, default ImageNet)
self.info = _load_info(self.root)
self.split_info = _parse_split_info(split, self.info)
self.num_samples = self.split_info.num_samples
if not self.num_samples:
raise RuntimeError(f'Invalid split definition, no samples found.')
# Distributed world state
self.dist_rank = 0
self.dist_num_replicas = 1
if dist.is_available() and dist.is_initialized() and dist.get_world_size() > 1:
self.dist_rank = dist.get_rank()
self.dist_num_replicas = dist.get_world_size()
# Attributes that are updated in _lazy_init
self.worker_info = None
self.worker_id = 0
self.worker_seed = seed # seed unique to each worker instance
self.num_workers = 1
self.global_worker_id = 0
self.global_num_workers = 1
self.init_count = 0
self.epoch_count = SharedCount()
# DataPipeline is lazy init, majority of WDS DataPipeline could be init here, BUT, shuffle seed
# is not handled in manner where it can be deterministic for each worker AND initialized up front
self.ds = None
def set_epoch(self, count):
self.epoch_count.value = count
def set_loader_cfg(
self,
num_workers: Optional[int] = None,
):
if self.ds is not None:
return
if num_workers is not None:
self.num_workers = num_workers
self.global_num_workers = self.dist_num_replicas * self.num_workers
def _lazy_init(self):
""" Lazily initialize worker (in worker processes)
"""
if self.worker_info is None:
worker_info = torch.utils.data.get_worker_info()
if worker_info is not None:
self.worker_info = worker_info
self.worker_id = worker_info.id
self.worker_seed = worker_info.seed
self.num_workers = worker_info.num_workers
self.global_num_workers = self.dist_num_replicas * self.num_workers
self.global_worker_id = self.dist_rank * self.num_workers + self.worker_id
# init data pipeline
abs_shard_filenames = [os.path.join(self.root, f) for f in self.split_info.filenames]
pipeline = [wds.SimpleShardList(abs_shard_filenames)]
# at this point we have an iterator over all the shards
if self.is_training:
pipeline.extend([
detshuffle2(self.shard_shuffle_size, seed=self.common_seed, epoch=self.epoch_count),
self._split_by_node_and_worker,
# at this point, we have an iterator over the shards assigned to each worker
wds.tarfile_to_samples(handler=log_and_continue),
wds.shuffle(
self.sample_shuffle_size,
rng=random.Random(self.worker_seed)), # this is why we lazy-init whole DataPipeline
])
else:
pipeline.extend([
self._split_by_node_and_worker,
# at this point, we have an iterator over the shards assigned to each worker
wds.tarfile_to_samples(handler=log_and_continue),
])
pipeline.extend([
partial(
_decode_samples,
image_key=self.image_key,
image_format=self.image_format,
alt_label=self.split_info.alt_label
)
])
self.ds = wds.DataPipeline(*pipeline)
def _split_by_node_and_worker(self, src):
if self.global_num_workers > 1:
for s in islice(src, self.global_worker_id, None, self.global_num_workers):
yield s
else:
for s in src:
yield s
def _num_samples_per_worker(self):
num_worker_samples = self.num_samples / max(self.global_num_workers, self.dist_num_replicas)
if self.is_training or self.dist_num_replicas > 1:
num_worker_samples = math.ceil(num_worker_samples)
if self.is_training and self.batch_size is not None:
num_worker_samples = math.ceil(num_worker_samples / self.batch_size) * self.batch_size
return int(num_worker_samples)
def __iter__(self):
if self.ds is None:
self._lazy_init()
num_worker_samples = self._num_samples_per_worker()
if self.is_training or self.dist_num_replicas > 1:
# NOTE: doing distributed validation w/ WDS is messy, hard to meet the constraint that
# the same # of batches is needed across all replicas while seeing each sample only once.
# with_epoch() is simple but could miss a shard's worth of samples in some workers,
# and duplicate in others. Best to keep num DL workers low and a divisor of #val shards.
ds = self.ds.with_epoch(num_worker_samples)
else:
ds = self.ds
i = 0
# _logger.info(f'start {i}, {self.worker_id}') # FIXME temporary debug
for sample in ds:
yield sample[self.image_key], sample[self.target_key]
i += 1
# _logger.info(f'end {i}, {self.worker_id}') # FIXME temporary debug
def __len__(self):
num_samples = self._num_samples_per_worker() * self.num_workers
return num_samples
def _filename(self, index, basename=False, absolute=False):
assert False, "Not supported" # no random access to examples
def filenames(self, basename=False, absolute=False):
""" Return all filenames in dataset, overrides base"""
if self.ds is None:
self._lazy_init()
names = []
for sample in self.ds:
if self.filename_key in sample:
name = sample[self.filename_key]
elif '__key__' in sample:
name = sample['__key__'] + self.key_ext
else:
assert False, "No supported name field present"
names.append(name)
if len(names) >= self.num_samples:
break # safety for ds.repeat() case
return names

@ -0,0 +1,14 @@
from multiprocessing import Value
class SharedCount:
def __init__(self, epoch: int = 0):
self.shared_epoch = Value('i', epoch)
@property
def value(self):
return self.shared_epoch.value
@value.setter
def value(self, epoch):
self.shared_epoch.value = epoch
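# Hedged illustration (not part of this diff): the multiprocessing.Value backing means an epoch
# set by the parent process (e.g. via IterableImageDataset.set_epoch) is visible to dataloader
# workers holding a reference to the same SharedCount.
epoch = SharedCount(0)
epoch.value = 3
assert epoch.value == 3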

@ -63,7 +63,7 @@ def load_state_dict(checkpoint_path, use_ema=True):
raise FileNotFoundError()
def load_checkpoint(model, checkpoint_path, use_ema=True, strict=True):
def load_checkpoint(model, checkpoint_path, use_ema=True, strict=True, remap=False):
if os.path.splitext(checkpoint_path)[-1].lower() in ('.npz', '.npy'):
# numpy checkpoint, try to load via model specific load_pretrained fn
if hasattr(model, 'load_pretrained'):
@ -72,10 +72,28 @@ def load_checkpoint(model, checkpoint_path, use_ema=True, strict=True):
raise NotImplementedError('Model cannot load numpy checkpoint')
return
state_dict = load_state_dict(checkpoint_path, use_ema)
if remap:
state_dict = remap_checkpoint(model, state_dict)
incompatible_keys = model.load_state_dict(state_dict, strict=strict)
return incompatible_keys
def remap_checkpoint(model, state_dict, allow_reshape=True):
""" remap checkpoint by iterating over state dicts in order (ignoring original keys).
This assumes models (and originating state dict) were created with params registered in same order.
"""
out_dict = {}
for (ka, va), (kb, vb) in zip(model.state_dict().items(), state_dict.items()):
assert va.numel() == vb.numel(), f'Tensor size mismatch {ka}: {va.shape} vs {kb}: {vb.shape}. Remap failed.'
if va.shape != vb.shape:
if allow_reshape:
vb = vb.reshape(va.shape)
else:
assert False, f'Tensor shape mismatch {ka}: {va.shape} vs {kb}: {vb.shape}. Remap failed.'
out_dict[ka] = vb
return out_dict
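# Hedged usage sketch (not part of this diff): remap=True lets a checkpoint with renamed keys load
# into a model whose parameters were registered in the same order, reshaping where sizes allow, e.g.
#   load_checkpoint(model, 'weights_with_old_key_names.pth', remap=True)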
def resume_checkpoint(model, checkpoint_path, optimizer=None, loss_scaler=None, log_info=True):
resume_epoch = None
if os.path.isfile(checkpoint_path):

@ -72,3 +72,31 @@ class EffectiveSEModule(nn.Module):
EffectiveSqueezeExcite = EffectiveSEModule # alias
class SqueezeExciteCl(nn.Module):
""" SE Module as defined in original SE-Nets with a few additions
Additions include:
* divisor can be specified to keep channels % div == 0 (default: 8)
* reduction channels can be specified directly by arg (if rd_channels is set)
* reduction channels can be specified by float rd_ratio (default: 1/16)
* global max pooling can be added to the squeeze aggregation
* customizable activation, normalization, and gate layer
"""
def __init__(
self, channels, rd_ratio=1. / 16, rd_channels=None, rd_divisor=8,
bias=True, act_layer=nn.ReLU, gate_layer='sigmoid'):
super().__init__()
if not rd_channels:
rd_channels = make_divisible(channels * rd_ratio, rd_divisor, round_limit=0.)
self.fc1 = nn.Linear(channels, rd_channels, bias=bias)
self.act = create_act_layer(act_layer, inplace=True)
self.fc2 = nn.Linear(rd_channels, channels, bias=bias)
self.gate = create_act_layer(gate_layer)
def forward(self, x):
x_se = x.mean((1, 2), keepdims=True) # FIXME avg dim [1:n-1], don't assume 2D NHWC
x_se = self.fc1(x_se)
x_se = self.act(x_se)
x_se = self.fc2(x_se)
return x * self.gate(x_se)
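# Hedged usage note (not part of this diff): SqueezeExciteCl expects channels-last activations, e.g.
#   se = SqueezeExciteCl(channels=64)      # reduction width from rd_ratio (default 1/16), kept divisible by 8
#   y = se(torch.randn(2, 7, 7, 64))       # NHWC in, NHWC out, channel-wise re-weighted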

@ -0,0 +1,124 @@
""" Adan Optimizer
Adan: Adaptive Nesterov Momentum Algorithm for Faster Optimizing Deep Models[J]. arXiv preprint arXiv:2208.06677, 2022.
https://arxiv.org/abs/2208.06677
Implementation adapted from https://github.com/sail-sg/Adan
"""
import math
import torch
from torch.optim import Optimizer
class Adan(Optimizer):
"""
Implements a pytorch variant of Adan
Adan was proposed in
Adan: Adaptive Nesterov Momentum Algorithm for Faster Optimizing Deep Models[J]. arXiv preprint arXiv:2208.06677, 2022.
https://arxiv.org/abs/2208.06677
Arguments:
params (iterable): iterable of parameters to optimize or dicts defining parameter groups.
lr (float, optional): learning rate. (default: 1e-3)
betas (Tuple[float, float, float], optional): coefficients used for computing
running averages of gradient and its norm. (default: (0.98, 0.92, 0.99))
eps (float, optional): term added to the denominator to improve
numerical stability. (default: 1e-8)
weight_decay (float, optional): decoupled weight decay (L2 penalty) (default: 0)
no_prox (bool): how to perform the decoupled weight decay (default: False)
"""
def __init__(
self,
params,
lr=1e-3,
betas=(0.98, 0.92, 0.99),
eps=1e-8,
weight_decay=0.0,
no_prox=False,
):
if not 0.0 <= lr:
raise ValueError("Invalid learning rate: {}".format(lr))
if not 0.0 <= eps:
raise ValueError("Invalid epsilon value: {}".format(eps))
if not 0.0 <= betas[0] < 1.0:
raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
if not 0.0 <= betas[1] < 1.0:
raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
if not 0.0 <= betas[2] < 1.0:
raise ValueError("Invalid beta parameter at index 2: {}".format(betas[2]))
defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, no_prox=no_prox)
super(Adan, self).__init__(params, defaults)
@torch.no_grad()
def restart_opt(self):
for group in self.param_groups:
group['step'] = 0
for p in group['params']:
if p.requires_grad:
state = self.state[p]
# State initialization
# Exponential moving average of gradient values
state['exp_avg'] = torch.zeros_like(p)
# Exponential moving average of squared gradient values
state['exp_avg_sq'] = torch.zeros_like(p)
# Exponential moving average of gradient difference
state['exp_avg_diff'] = torch.zeros_like(p)
@torch.no_grad()
def step(self, closure=None):
""" Performs a single optimization step.
"""
loss = None
if closure is not None:
with torch.enable_grad():
loss = closure()
for group in self.param_groups:
beta1, beta2, beta3 = group['betas']
# assume same step across group now to simplify things
# per parameter step can be easily supported by making it a tensor, or passing a list into the kernel
if 'step' in group:
group['step'] += 1
else:
group['step'] = 1
bias_correction1 = 1.0 - beta1 ** group['step']
bias_correction2 = 1.0 - beta2 ** group['step']
bias_correction3 = 1.0 - beta3 ** group['step']
for p in group['params']:
if p.grad is None:
continue
grad = p.grad
state = self.state[p]
if len(state) == 0:
state['exp_avg'] = torch.zeros_like(p)
state['exp_avg_diff'] = torch.zeros_like(p)
state['exp_avg_sq'] = torch.zeros_like(p)
state['pre_grad'] = grad.clone()
exp_avg, exp_avg_sq, exp_avg_diff = state['exp_avg'], state['exp_avg_diff'], state['exp_avg_sq']
grad_diff = grad - state['pre_grad']
exp_avg.lerp_(grad, 1. - beta1) # m_t
exp_avg_diff.lerp_(grad_diff, 1. - beta2) # diff_t (v)
update = grad + beta2 * grad_diff
exp_avg_sq.mul_(beta3).addcmul_(update, update, value=1. - beta3) # n_t
denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction3)).add_(group['eps'])
update = (exp_avg / bias_correction1 + beta2 * exp_avg_diff / bias_correction2).div_(denom)
if group['no_prox']:
p.data.mul_(1 - group['lr'] * group['weight_decay'])
p.add_(update, alpha=-group['lr'])
else:
p.add_(update, alpha=-group['lr'])
p.data.div_(1 + group['lr'] * group['weight_decay'])
state['pre_grad'].copy_(grad)
return loss

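For reference, a minimal usage sketch of the Adan class defined above; the toy model, batch, and hyper-parameter values are illustrative assumptions, not part of the commit.

import torch
import torch.nn as nn

# toy model and batch, purely for illustration
model = nn.Linear(10, 2)
optimizer = Adan(model.parameters(), lr=1e-3, betas=(0.98, 0.92, 0.99), weight_decay=0.02)

x, y = torch.randn(8, 10), torch.randint(0, 2, (8,))
loss = nn.functional.cross_entropy(model(x), y)
loss.backward()
optimizer.step()        # applies the Adan update, with decoupled or proximal weight decay per no_prox
optimizer.zero_grad()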
@ -15,6 +15,7 @@ from .adabelief import AdaBelief
from .adafactor import Adafactor
from .adahessian import Adahessian
from .adamp import AdamP
from .adan import Adan
from .lamb import Lamb
from .lars import Lars
from .lookahead import Lookahead
@ -192,7 +193,8 @@ def create_optimizer_v2(
filter_bias_and_bn: bool = True,
layer_decay: Optional[float] = None,
param_group_fn: Optional[Callable] = None,
**kwargs):
**kwargs,
):
""" Create an optimizer.
TODO currently the model is passed in and all parameters are selected for optimization.
@ -285,6 +287,10 @@ def create_optimizer_v2(
optimizer = optim.Adagrad(parameters, **opt_args)
elif opt_lower == 'adafactor':
optimizer = Adafactor(parameters, **opt_args)
elif opt_lower == 'adanp':
optimizer = Adan(parameters, no_prox=False, **opt_args)
elif opt_lower == 'adanw':
optimizer = Adan(parameters, no_prox=True, **opt_args)
elif opt_lower == 'lamb':
optimizer = Lamb(parameters, **opt_args)
elif opt_lower == 'lambc':

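A hedged sketch of selecting the new variants through the factory: 'adanp' maps to the proximal form (no_prox=False) and 'adanw' to the decoupled, AdamW-like form (no_prox=True). The model and hyper-parameter values below are illustrative.

import torch.nn as nn
from timm.optim import create_optimizer_v2

model = nn.Linear(10, 2)  # stand-in for a timm model
optimizer = create_optimizer_v2(model, opt='adanw', lr=1e-3, weight_decay=0.02)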
@ -5,4 +5,4 @@ from .poly_lr import PolyLRScheduler
from .step_lr import StepLRScheduler
from .tanh_lr import TanhLRScheduler
from .scheduler_factory import create_scheduler
from .scheduler_factory import create_scheduler, create_scheduler_v2, scheduler_kwargs

@ -26,7 +26,8 @@ class CosineLRScheduler(Scheduler):
k-decay option based on `k-decay: A New Method For Learning Rate Schedule` - https://arxiv.org/abs/2004.05909
"""
def __init__(self,
def __init__(
self,
optimizer: torch.optim.Optimizer,
t_initial: int,
lr_min: float = 0.,
@ -42,16 +43,24 @@ class CosineLRScheduler(Scheduler):
noise_std=1.0,
noise_seed=42,
k_decay=1.0,
initialize=True) -> None:
initialize=True,
) -> None:
super().__init__(
optimizer, param_group_field="lr",
noise_range_t=noise_range_t, noise_pct=noise_pct, noise_std=noise_std, noise_seed=noise_seed,
initialize=initialize)
optimizer,
param_group_field="lr",
t_in_epochs=t_in_epochs,
noise_range_t=noise_range_t,
noise_pct=noise_pct,
noise_std=noise_std,
noise_seed=noise_seed,
initialize=initialize,
)
assert t_initial > 0
assert lr_min >= 0
if t_initial == 1 and cycle_mul == 1 and cycle_decay == 1:
_logger.warning("Cosine annealing scheduler will have no effect on the learning "
_logger.warning(
"Cosine annealing scheduler will have no effect on the learning "
"rate since t_initial = t_mul = eta_mul = 1.")
self.t_initial = t_initial
self.lr_min = lr_min
@ -61,7 +70,6 @@ class CosineLRScheduler(Scheduler):
self.warmup_t = warmup_t
self.warmup_lr_init = warmup_lr_init
self.warmup_prefix = warmup_prefix
self.t_in_epochs = t_in_epochs
self.k_decay = k_decay
if self.warmup_t:
self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in self.base_values]
@ -99,18 +107,6 @@ class CosineLRScheduler(Scheduler):
return lrs
def get_epoch_values(self, epoch: int):
if self.t_in_epochs:
return self._get_lr(epoch)
else:
return None
def get_update_values(self, num_updates: int):
if not self.t_in_epochs:
return self._get_lr(num_updates)
else:
return None
def get_cycle_length(self, cycles=0):
cycles = max(1, cycles or self.cycle_limit)
if self.cycle_mul == 1.0:

@ -11,12 +11,14 @@ class MultiStepLRScheduler(Scheduler):
"""
"""
def __init__(self,
def __init__(
self,
optimizer: torch.optim.Optimizer,
decay_t: List[int],
decay_rate: float = 1.,
warmup_t=0,
warmup_lr_init=0,
warmup_prefix=True,
t_in_epochs=True,
noise_range_t=None,
noise_pct=0.67,
@ -25,15 +27,21 @@ class MultiStepLRScheduler(Scheduler):
initialize=True,
) -> None:
super().__init__(
optimizer, param_group_field="lr",
noise_range_t=noise_range_t, noise_pct=noise_pct, noise_std=noise_std, noise_seed=noise_seed,
initialize=initialize)
optimizer,
param_group_field="lr",
t_in_epochs=t_in_epochs,
noise_range_t=noise_range_t,
noise_pct=noise_pct,
noise_std=noise_std,
noise_seed=noise_seed,
initialize=initialize,
)
self.decay_t = decay_t
self.decay_rate = decay_rate
self.warmup_t = warmup_t
self.warmup_lr_init = warmup_lr_init
self.t_in_epochs = t_in_epochs
self.warmup_prefix = warmup_prefix
if self.warmup_t:
self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in self.base_values]
super().update_groups(self.warmup_lr_init)
@ -43,23 +51,13 @@ class MultiStepLRScheduler(Scheduler):
def get_curr_decay_steps(self, t):
# find where in the array t goes,
# assumes self.decay_t is sorted
return bisect.bisect_right(self.decay_t, t+1)
return bisect.bisect_right(self.decay_t, t + 1)
def _get_lr(self, t):
if t < self.warmup_t:
lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps]
else:
if self.warmup_prefix:
t = t - self.warmup_t
lrs = [v * (self.decay_rate ** self.get_curr_decay_steps(t)) for v in self.base_values]
return lrs
def get_epoch_values(self, epoch: int):
if self.t_in_epochs:
return self._get_lr(epoch)
else:
return None
def get_update_values(self, num_updates: int):
if not self.t_in_epochs:
return self._get_lr(num_updates)
else:
return None

@ -12,7 +12,8 @@ from .scheduler import Scheduler
class PlateauLRScheduler(Scheduler):
"""Decay the LR by a factor every time the validation loss plateaus."""
def __init__(self,
def __init__(
self,
optimizer,
decay_rate=0.1,
patience_t=10,
@ -89,6 +90,9 @@ class PlateauLRScheduler(Scheduler):
if self._is_apply_noise(epoch):
self._apply_noise(epoch)
def step_update(self, num_updates: int, metric: float = None):
return None
def _apply_noise(self, epoch):
noise = self._calculate_noise(epoch)
@ -101,3 +105,6 @@ class PlateauLRScheduler(Scheduler):
new_lr = old_lr + old_lr * noise
param_group['lr'] = new_lr
self.restore_lr = restore_lr
def _get_lr(self, t: int) -> float:
assert False, 'should not be called as step is overridden'

@ -21,7 +21,8 @@ class PolyLRScheduler(Scheduler):
k-decay option based on `k-decay: A New Method For Learning Rate Schedule` - https://arxiv.org/abs/2004.05909
"""
def __init__(self,
def __init__(
self,
optimizer: torch.optim.Optimizer,
t_initial: int,
power: float = 0.5,
@ -38,11 +39,18 @@ class PolyLRScheduler(Scheduler):
noise_std=1.0,
noise_seed=42,
k_decay=1.0,
initialize=True) -> None:
initialize=True,
) -> None:
super().__init__(
optimizer, param_group_field="lr",
noise_range_t=noise_range_t, noise_pct=noise_pct, noise_std=noise_std, noise_seed=noise_seed,
initialize=initialize)
optimizer,
param_group_field="lr",
t_in_epochs=t_in_epochs,
noise_range_t=noise_range_t,
noise_pct=noise_pct,
noise_std=noise_std,
noise_seed=noise_seed,
initialize=initialize
)
assert t_initial > 0
assert lr_min >= 0
@ -58,7 +66,6 @@ class PolyLRScheduler(Scheduler):
self.warmup_t = warmup_t
self.warmup_lr_init = warmup_lr_init
self.warmup_prefix = warmup_prefix
self.t_in_epochs = t_in_epochs
self.k_decay = k_decay
if self.warmup_t:
self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in self.base_values]
@ -96,18 +103,6 @@ class PolyLRScheduler(Scheduler):
return lrs
def get_epoch_values(self, epoch: int):
if self.t_in_epochs:
return self._get_lr(epoch)
else:
return None
def get_update_values(self, num_updates: int):
if not self.t_in_epochs:
return self._get_lr(num_updates)
else:
return None
def get_cycle_length(self, cycles=0):
cycles = max(1, cycles or self.cycle_limit)
if self.cycle_mul == 1.0:

@ -1,9 +1,11 @@
from typing import Dict, Any
import abc
from abc import ABC
from typing import Any, Dict, Optional
import torch
class Scheduler:
class Scheduler(ABC):
""" Parameter Scheduler Base Class
A scheduler base class that can be used to schedule any optimizer parameter groups.
@ -22,15 +24,18 @@ class Scheduler:
* https://github.com/allenai/allennlp/tree/master/allennlp/training/learning_rate_schedulers
"""
def __init__(self,
def __init__(
self,
optimizer: torch.optim.Optimizer,
param_group_field: str,
t_in_epochs: bool = True,
noise_range_t=None,
noise_type='normal',
noise_pct=0.67,
noise_std=1.0,
noise_seed=None,
initialize: bool = True) -> None:
initialize: bool = True,
) -> None:
self.optimizer = optimizer
self.param_group_field = param_group_field
self._initial_param_group_field = f"initial_{param_group_field}"
@ -45,6 +50,7 @@ class Scheduler:
raise KeyError(f"{self._initial_param_group_field} missing from param_groups[{i}]")
self.base_values = [group[self._initial_param_group_field] for group in self.optimizer.param_groups]
self.metric = None # any point to having this for all?
self.t_in_epochs = t_in_epochs
self.noise_range_t = noise_range_t
self.noise_pct = noise_pct
self.noise_type = noise_type
@ -58,22 +64,26 @@ class Scheduler:
def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
self.__dict__.update(state_dict)
def get_epoch_values(self, epoch: int):
return None
@abc.abstractmethod
def _get_lr(self, t: int) -> float:
pass
def get_update_values(self, num_updates: int):
def _get_values(self, t: int, on_epoch: bool = True) -> Optional[float]:
proceed = (on_epoch and self.t_in_epochs) or (not on_epoch and not self.t_in_epochs)
if not proceed:
return None
return self._get_lr(t)
def step(self, epoch: int, metric: float = None) -> None:
self.metric = metric
values = self.get_epoch_values(epoch)
values = self._get_values(epoch, on_epoch=True)
if values is not None:
values = self._add_noise(values, epoch)
self.update_groups(values)
def step_update(self, num_updates: int, metric: float = None):
self.metric = metric
values = self.get_update_values(num_updates)
values = self._get_values(num_updates, on_epoch=False)
if values is not None:
values = self._add_noise(values, num_updates)
self.update_groups(values)

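To illustrate the refactor above: concrete schedulers now implement only _get_lr(), and the base class uses t_in_epochs to decide whether step() (per epoch) or step_update() (per optimizer update) actually changes the learning rate. A minimal hypothetical subclass, not part of the commit:

from timm.scheduler.scheduler import Scheduler

class GeometricDecay(Scheduler):
    """Illustrative only: decay each group's lr by a constant factor per step."""
    def __init__(self, optimizer, decay=0.99, t_in_epochs=True):
        super().__init__(optimizer, param_group_field='lr', t_in_epochs=t_in_epochs)
        self.decay = decay

    def _get_lr(self, t):
        # one lr per param group, decayed geometrically with t
        return [v * (self.decay ** t) for v in self.base_values]

# with t_in_epochs=True, scheduler.step(epoch) updates the lrs and step_update() is a no-op;
# with t_in_epochs=False the roles are reversed.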
@ -1,6 +1,10 @@
""" Scheduler Factory
Hacked together by / Copyright 2021 Ross Wightman
"""
from typing import List, Union
from torch.optim import Optimizer
from .cosine_lr import CosineLRScheduler
from .multistep_lr import MultiStepLRScheduler
from .plateau_lr import PlateauLRScheduler
@ -9,99 +13,191 @@ from .step_lr import StepLRScheduler
from .tanh_lr import TanhLRScheduler
def create_scheduler(args, optimizer):
num_epochs = args.epochs
def scheduler_kwargs(cfg):
""" cfg/argparse to kwargs helper
Convert scheduler args from an argparse namespace or a cfg-like object with dot attribute access into keyword args.
"""
eval_metric = getattr(cfg, 'eval_metric', 'top1')
plateau_mode = 'min' if 'loss' in eval_metric else 'max'
kwargs = dict(
sched=cfg.sched,
num_epochs=getattr(cfg, 'epochs', 100),
decay_epochs=getattr(cfg, 'decay_epochs', 30),
decay_milestones=getattr(cfg, 'decay_milestones', [30, 60]),
warmup_epochs=getattr(cfg, 'warmup_epochs', 5),
cooldown_epochs=getattr(cfg, 'cooldown_epochs', 0),
patience_epochs=getattr(cfg, 'patience_epochs', 10),
decay_rate=getattr(cfg, 'decay_rate', 0.1),
min_lr=getattr(cfg, 'min_lr', 0.),
warmup_lr=getattr(cfg, 'warmup_lr', 1e-5),
warmup_prefix=getattr(cfg, 'warmup_prefix', False),
noise=getattr(cfg, 'lr_noise', None),
noise_pct=getattr(cfg, 'lr_noise_pct', 0.67),
noise_std=getattr(cfg, 'lr_noise_std', 1.),
noise_seed=getattr(cfg, 'seed', 42),
cycle_mul=getattr(cfg, 'lr_cycle_mul', 1.),
cycle_decay=getattr(cfg, 'lr_cycle_decay', 0.1),
cycle_limit=getattr(cfg, 'lr_cycle_limit', 1),
k_decay=getattr(cfg, 'lr_k_decay', 1.0),
plateau_mode=plateau_mode,
step_on_epochs=not getattr(cfg, 'sched_on_updates', False),
)
return kwargs
def create_scheduler(
args,
optimizer: Optimizer,
updates_per_epoch: int = 0,
):
return create_scheduler_v2(
optimizer=optimizer,
**scheduler_kwargs(args),
updates_per_epoch=updates_per_epoch,
)
def create_scheduler_v2(
optimizer: Optimizer,
sched: str = 'cosine',
num_epochs: int = 300,
decay_epochs: int = 90,
decay_milestones: List[int] = (90, 180, 270),
cooldown_epochs: int = 0,
patience_epochs: int = 10,
decay_rate: float = 0.1,
min_lr: float = 0,
warmup_lr: float = 1e-5,
warmup_epochs: int = 0,
warmup_prefix: bool = False,
noise: Union[float, List[float]] = None,
noise_pct: float = 0.67,
noise_std: float = 1.,
noise_seed: int = 42,
cycle_mul: float = 1.,
cycle_decay: float = 0.1,
cycle_limit: int = 1,
k_decay: float = 1.0,
plateau_mode: str = 'max',
step_on_epochs: bool = True,
updates_per_epoch: int = 0,
):
t_initial = num_epochs
warmup_t = warmup_epochs
decay_t = decay_epochs
cooldown_t = cooldown_epochs
if not step_on_epochs:
assert updates_per_epoch > 0, 'updates_per_epoch must be set to number of dataloader batches'
t_initial = t_initial * updates_per_epoch
warmup_t = warmup_t * updates_per_epoch
decay_t = decay_t * updates_per_epoch
decay_milestones = [d * updates_per_epoch for d in decay_milestones]
cooldown_t = cooldown_t * updates_per_epoch
# warmup args
warmup_args = dict(
warmup_lr_init=warmup_lr,
warmup_t=warmup_t,
warmup_prefix=warmup_prefix,
)
if getattr(args, 'lr_noise', None) is not None:
lr_noise = getattr(args, 'lr_noise')
if isinstance(lr_noise, (list, tuple)):
noise_range = [n * num_epochs for n in lr_noise]
# setup noise args for supporting schedulers
if noise is not None:
if isinstance(noise, (list, tuple)):
noise_range = [n * t_initial for n in noise]
if len(noise_range) == 1:
noise_range = noise_range[0]
else:
noise_range = lr_noise * num_epochs
noise_range = noise * t_initial
else:
noise_range = None
noise_args = dict(
noise_range_t=noise_range,
noise_pct=getattr(args, 'lr_noise_pct', 0.67),
noise_std=getattr(args, 'lr_noise_std', 1.),
noise_seed=getattr(args, 'seed', 42),
noise_pct=noise_pct,
noise_std=noise_std,
noise_seed=noise_seed,
)
# setup cycle args for supporting schedulers
cycle_args = dict(
cycle_mul=getattr(args, 'lr_cycle_mul', 1.),
cycle_decay=getattr(args, 'lr_cycle_decay', 0.1),
cycle_limit=getattr(args, 'lr_cycle_limit', 1),
cycle_mul=cycle_mul,
cycle_decay=cycle_decay,
cycle_limit=cycle_limit,
)
lr_scheduler = None
if args.sched == 'cosine':
if sched == 'cosine':
lr_scheduler = CosineLRScheduler(
optimizer,
t_initial=num_epochs,
lr_min=args.min_lr,
warmup_lr_init=args.warmup_lr,
warmup_t=args.warmup_epochs,
k_decay=getattr(args, 'lr_k_decay', 1.0),
t_initial=t_initial,
lr_min=min_lr,
t_in_epochs=step_on_epochs,
**cycle_args,
**warmup_args,
**noise_args,
k_decay=k_decay,
)
num_epochs = lr_scheduler.get_cycle_length() + args.cooldown_epochs
elif args.sched == 'tanh':
elif sched == 'tanh':
lr_scheduler = TanhLRScheduler(
optimizer,
t_initial=num_epochs,
lr_min=args.min_lr,
warmup_lr_init=args.warmup_lr,
warmup_t=args.warmup_epochs,
t_in_epochs=True,
t_initial=t_initial,
lr_min=min_lr,
t_in_epochs=step_on_epochs,
**cycle_args,
**warmup_args,
**noise_args,
)
num_epochs = lr_scheduler.get_cycle_length() + args.cooldown_epochs
elif args.sched == 'step':
elif sched == 'step':
lr_scheduler = StepLRScheduler(
optimizer,
decay_t=args.decay_epochs,
decay_rate=args.decay_rate,
warmup_lr_init=args.warmup_lr,
warmup_t=args.warmup_epochs,
decay_t=decay_t,
decay_rate=decay_rate,
t_in_epochs=step_on_epochs,
**warmup_args,
**noise_args,
)
elif args.sched == 'multistep':
elif sched == 'multistep':
lr_scheduler = MultiStepLRScheduler(
optimizer,
decay_t=args.decay_milestones,
decay_rate=args.decay_rate,
warmup_lr_init=args.warmup_lr,
warmup_t=args.warmup_epochs,
decay_t=decay_milestones,
decay_rate=decay_rate,
t_in_epochs=step_on_epochs,
**warmup_args,
**noise_args,
)
elif args.sched == 'plateau':
mode = 'min' if 'loss' in getattr(args, 'eval_metric', '') else 'max'
elif sched == 'plateau':
assert step_on_epochs, 'Plateau LR only supports step per epoch.'
warmup_args.pop('warmup_prefix', False)
lr_scheduler = PlateauLRScheduler(
optimizer,
decay_rate=args.decay_rate,
patience_t=args.patience_epochs,
lr_min=args.min_lr,
mode=mode,
warmup_lr_init=args.warmup_lr,
warmup_t=args.warmup_epochs,
decay_rate=decay_rate,
patience_t=patience_epochs,
cooldown_t=0,
**warmup_args,
lr_min=min_lr,
mode=plateau_mode,
**noise_args,
)
elif args.sched == 'poly':
elif sched == 'poly':
lr_scheduler = PolyLRScheduler(
optimizer,
power=args.decay_rate, # overloading 'decay_rate' as polynomial power
t_initial=num_epochs,
lr_min=args.min_lr,
warmup_lr_init=args.warmup_lr,
warmup_t=args.warmup_epochs,
k_decay=getattr(args, 'lr_k_decay', 1.0),
power=decay_rate, # overloading 'decay_rate' as polynomial power
t_initial=t_initial,
lr_min=min_lr,
t_in_epochs=step_on_epochs,
k_decay=k_decay,
**cycle_args,
**warmup_args,
**noise_args,
)
num_epochs = lr_scheduler.get_cycle_length() + args.cooldown_epochs
if hasattr(lr_scheduler, 'get_cycle_length'):
# for cycle based schedulers (cosine, tanh, poly) recalculate total epochs w/ cycles & cooldown
t_with_cycles_and_cooldown = lr_scheduler.get_cycle_length() + cooldown_t
if step_on_epochs:
num_epochs = t_with_cycles_and_cooldown
else:
num_epochs = t_with_cycles_and_cooldown // updates_per_epoch
return lr_scheduler, num_epochs

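A hedged usage sketch of the new factory with per-update stepping; the stand-in optimizer and hyper-parameter values below are assumptions for illustration.

import torch
from timm.scheduler import create_scheduler_v2

# stand-ins for the real optimizer and train loader
model = torch.nn.Linear(10, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.5)
updates_per_epoch = 1000  # normally len(loader_train)

lr_scheduler, num_epochs = create_scheduler_v2(
    optimizer,
    sched='cosine',
    num_epochs=100,
    warmup_epochs=5,
    warmup_lr=1e-5,
    min_lr=1e-6,
    step_on_epochs=False,                 # step the LR every optimizer update
    updates_per_epoch=updates_per_epoch,  # required when step_on_epochs=False
)
# in the inner loop: lr_scheduler.step_update(num_updates)
# after validation:  lr_scheduler.step(epoch + 1, eval_metric_value)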
@ -14,12 +14,14 @@ class StepLRScheduler(Scheduler):
"""
"""
def __init__(self,
def __init__(
self,
optimizer: torch.optim.Optimizer,
decay_t: float,
decay_rate: float = 1.,
warmup_t=0,
warmup_lr_init=0,
warmup_prefix=True,
t_in_epochs=True,
noise_range_t=None,
noise_pct=0.67,
@ -28,15 +30,21 @@ class StepLRScheduler(Scheduler):
initialize=True,
) -> None:
super().__init__(
optimizer, param_group_field="lr",
noise_range_t=noise_range_t, noise_pct=noise_pct, noise_std=noise_std, noise_seed=noise_seed,
initialize=initialize)
optimizer,
param_group_field="lr",
t_in_epochs=t_in_epochs,
noise_range_t=noise_range_t,
noise_pct=noise_pct,
noise_std=noise_std,
noise_seed=noise_seed,
initialize=initialize,
)
self.decay_t = decay_t
self.decay_rate = decay_rate
self.warmup_t = warmup_t
self.warmup_lr_init = warmup_lr_init
self.t_in_epochs = t_in_epochs
self.warmup_prefix = warmup_prefix
if self.warmup_t:
self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in self.base_values]
super().update_groups(self.warmup_lr_init)
@ -47,17 +55,7 @@ class StepLRScheduler(Scheduler):
if t < self.warmup_t:
lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps]
else:
if self.warmup_prefix:
t = t - self.warmup_t
lrs = [v * (self.decay_rate ** (t // self.decay_t)) for v in self.base_values]
return lrs
def get_epoch_values(self, epoch: int):
if self.t_in_epochs:
return self._get_lr(epoch)
else:
return None
def get_update_values(self, num_updates: int):
if not self.t_in_epochs:
return self._get_lr(num_updates)
else:
return None

@ -21,7 +21,8 @@ class TanhLRScheduler(Scheduler):
This is described in the paper https://arxiv.org/abs/1806.01593
"""
def __init__(self,
def __init__(
self,
optimizer: torch.optim.Optimizer,
t_initial: int,
lb: float = -7.,
@ -38,11 +39,18 @@ class TanhLRScheduler(Scheduler):
noise_pct=0.67,
noise_std=1.0,
noise_seed=42,
initialize=True) -> None:
initialize=True,
) -> None:
super().__init__(
optimizer, param_group_field="lr",
noise_range_t=noise_range_t, noise_pct=noise_pct, noise_std=noise_std, noise_seed=noise_seed,
initialize=initialize)
optimizer,
param_group_field="lr",
t_in_epochs=t_in_epochs,
noise_range_t=noise_range_t,
noise_pct=noise_pct,
noise_std=noise_std,
noise_seed=noise_seed,
initialize=initialize,
)
assert t_initial > 0
assert lr_min >= 0
@ -60,7 +68,6 @@ class TanhLRScheduler(Scheduler):
self.warmup_t = warmup_t
self.warmup_lr_init = warmup_lr_init
self.warmup_prefix = warmup_prefix
self.t_in_epochs = t_in_epochs
if self.warmup_t:
t_v = self.base_values if self.warmup_prefix else self._get_lr(self.warmup_t)
self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in t_v]
@ -97,18 +104,6 @@ class TanhLRScheduler(Scheduler):
lrs = [self.lr_min for _ in self.base_values]
return lrs
def get_epoch_values(self, epoch: int):
if self.t_in_epochs:
return self._get_lr(epoch)
else:
return None
def get_update_values(self, num_updates: int):
if not self.t_in_epochs:
return self._get_lr(num_updates)
else:
return None
def get_cycle_length(self, cycles=0):
cycles = max(1, cycles or self.cycle_limit)
if self.cycle_mul == 1.0:

@ -3,7 +3,8 @@ from .checkpoint_saver import CheckpointSaver
from .clip_grad import dispatch_clip_grad
from .cuda import ApexScaler, NativeScaler
from .decay_batch import decay_batch_step, check_batch_size_retry
from .distributed import distribute_bn, reduce_tensor
from .distributed import distribute_bn, reduce_tensor, init_distributed_device,\
world_info_from_env, is_distributed_env, is_primary
from .jit import set_jit_legacy, set_jit_fuser
from .log import setup_default_logging, FormatterNoInfo
from .metrics import AverageMeter, accuracy

@ -2,9 +2,16 @@
Hacked together by / Copyright 2020 Ross Wightman
"""
import os
import torch
from torch import distributed as dist
try:
import horovod.torch as hvd
except ImportError:
hvd = None
from .model import unwrap_model
@ -26,3 +33,105 @@ def distribute_bn(model, world_size, reduce=False):
else:
# broadcast bn stats from rank 0 to whole group
torch.distributed.broadcast(bn_buf, 0)
def is_global_primary(args):
return args.rank == 0
def is_local_primary(args):
return args.local_rank == 0
def is_primary(args, local=False):
return is_local_primary(args) if local else is_global_primary(args)
def is_distributed_env():
if 'WORLD_SIZE' in os.environ:
return int(os.environ['WORLD_SIZE']) > 1
if 'SLURM_NTASKS' in os.environ:
return int(os.environ['SLURM_NTASKS']) > 1
return False
def world_info_from_env():
local_rank = 0
for v in ('LOCAL_RANK', 'MPI_LOCALRANKID', 'SLURM_LOCALID', 'OMPI_COMM_WORLD_LOCAL_RANK'):
if v in os.environ:
local_rank = int(os.environ[v])
break
global_rank = 0
for v in ('RANK', 'PMI_RANK', 'SLURM_PROCID', 'OMPI_COMM_WORLD_RANK'):
if v in os.environ:
global_rank = int(os.environ[v])
break
world_size = 1
for v in ('WORLD_SIZE', 'PMI_SIZE', 'SLURM_NTASKS', 'OMPI_COMM_WORLD_SIZE'):
if v in os.environ:
world_size = int(os.environ[v])
break
return local_rank, global_rank, world_size
def init_distributed_device(args):
# Distributed training = training on more than one GPU.
# Works in both single and multi-node scenarios.
args.distributed = False
args.world_size = 1
args.rank = 0 # global rank
args.local_rank = 0
# TBD, support horovod?
# if args.horovod:
# assert hvd is not None, "Horovod is not installed"
# hvd.init()
# args.local_rank = int(hvd.local_rank())
# args.rank = hvd.rank()
# args.world_size = hvd.size()
# args.distributed = True
# os.environ['LOCAL_RANK'] = str(args.local_rank)
# os.environ['RANK'] = str(args.rank)
# os.environ['WORLD_SIZE'] = str(args.world_size)
dist_backend = getattr(args, 'dist_backend', 'nccl')
dist_url = getattr(args, 'dist_url', 'env://')
if is_distributed_env():
if 'SLURM_PROCID' in os.environ:
# DDP via SLURM
args.local_rank, args.rank, args.world_size = world_info_from_env()
# SLURM var -> torch.distributed vars in case needed
os.environ['LOCAL_RANK'] = str(args.local_rank)
os.environ['RANK'] = str(args.rank)
os.environ['WORLD_SIZE'] = str(args.world_size)
torch.distributed.init_process_group(
backend=dist_backend,
init_method=dist_url,
world_size=args.world_size,
rank=args.rank,
)
else:
# DDP via torchrun, torch.distributed.launch
args.local_rank, _, _ = world_info_from_env()
torch.distributed.init_process_group(
backend=dist_backend,
init_method=dist_url,
)
args.world_size = torch.distributed.get_world_size()
args.rank = torch.distributed.get_rank()
args.distributed = True
if torch.cuda.is_available():
if args.distributed:
device = 'cuda:%d' % args.local_rank
else:
device = 'cuda:0'
torch.cuda.set_device(device)
else:
device = 'cpu'
args.device = device
device = torch.device(device)
return device

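A hedged sketch of how a training script consumes the new distributed helpers; args stands in for the parsed CLI arguments, and the attributes set on it (distributed, world_size, rank, local_rank, device) mirror the code above.

import argparse
from timm import utils

args = argparse.Namespace(dist_backend='nccl', dist_url='env://')  # normally the parsed training args
device = utils.init_distributed_device(args)  # initializes DDP if a distributed env is detected
if utils.is_primary(args):
    print(f'rank {args.rank}/{args.world_size} running on {device}')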
@ -10,6 +10,7 @@ try:
except ImportError:
pass
def get_outdir(path, *paths, inc=False):
outdir = os.path.join(path, *paths)
if not os.path.exists(outdir):
@ -26,10 +27,20 @@ def get_outdir(path, *paths, inc=False):
return outdir
def update_summary(epoch, train_metrics, eval_metrics, filename, write_header=False, log_wandb=False):
def update_summary(
epoch,
train_metrics,
eval_metrics,
filename,
lr=None,
write_header=False,
log_wandb=False,
):
rowd = OrderedDict(epoch=epoch)
rowd.update([('train_' + k, v) for k, v in train_metrics.items()])
rowd.update([('eval_' + k, v) for k, v in eval_metrics.items()])
if lr is not None:
rowd['lr'] = lr
if log_wandb:
wandb.log(rowd)
with open(filename, mode='a') as cf:

@ -21,6 +21,7 @@ import time
from collections import OrderedDict
from contextlib import suppress
from datetime import datetime
from functools import partial
import torch
import torch.nn as nn
@ -35,7 +36,7 @@ from timm.loss import JsdCrossEntropy, SoftTargetCrossEntropy, BinaryCrossEntrop
from timm.models import create_model, safe_model_name, resume_checkpoint, load_checkpoint, \
convert_splitbn_model, convert_sync_batchnorm, model_parameters, set_fast_norm
from timm.optim import create_optimizer_v2, optimizer_kwargs
from timm.scheduler import create_scheduler
from timm.scheduler import create_scheduler_v2, scheduler_kwargs
from timm.utils import ApexScaler, NativeScaler
try:
@ -66,7 +67,6 @@ except ImportError as e:
has_functorch = False
torch.backends.cudnn.benchmark = True
_logger = logging.getLogger('train')
# The first arg parser parses out only the --config argument, this argument is used to
@ -111,7 +111,9 @@ group.add_argument('--num-classes', type=int, default=None, metavar='N',
group.add_argument('--gp', default=None, type=str, metavar='POOL',
help='Global pool type, one of (fast, avg, max, avgmax, avgmaxc). Model default if None.')
group.add_argument('--img-size', type=int, default=None, metavar='N',
help='Image patch size (default: None => model default)')
help='Image size (default: None => model default)')
group.add_argument('--in-chans', type=int, default=None, metavar='N',
help='Image input channels (default: None => 3)')
group.add_argument('--input-size', default=None, nargs=3, type=int,
metavar='N N N', help='Input all image dimensions (d h w, e.g. --input-size 3 224 224), uses model default if empty')
group.add_argument('--crop-pct', default=None, type=float,
@ -161,10 +163,18 @@ group.add_argument('--layer-decay', type=float, default=None,
# Learning rate schedule parameters
group = parser.add_argument_group('Learning rate schedule parameters')
group.add_argument('--sched', default='cosine', type=str, metavar='SCHEDULER',
group.add_argument('--sched', type=str, default='cosine', metavar='SCHEDULER',
help='LR scheduler (default: "cosine")')
group.add_argument('--lr', type=float, default=0.05, metavar='LR',
help='learning rate (default: 0.05)')
group.add_argument('--sched-on-updates', action='store_true', default=False,
help='Apply LR scheduler step on update instead of epoch end.')
group.add_argument('--lr', type=float, default=None, metavar='LR',
help='learning rate, overrides lr-base if set (default: None)')
group.add_argument('--lr-base', type=float, default=0.1, metavar='LR',
help='base learning rate: lr = lr_base * global_batch_size / base_size')
group.add_argument('--lr-base-size', type=int, default=256, metavar='DIV',
help='base learning rate batch size (divisor, default: 256).')
group.add_argument('--lr-base-scale', type=str, default='', metavar='SCALE',
help='base learning rate vs batch_size scaling ("linear", "sqrt", based on opt if empty)')
group.add_argument('--lr-noise', type=float, nargs='+', default=None, metavar='pct, pct',
help='learning rate noise on/off epoch percentages')
group.add_argument('--lr-noise-pct', type=float, default=0.67, metavar='PERCENT',
@ -179,23 +189,25 @@ group.add_argument('--lr-cycle-limit', type=int, default=1, metavar='N',
help='learning rate cycle limit, cycles enabled if > 1')
group.add_argument('--lr-k-decay', type=float, default=1.0,
help='learning rate k-decay for cosine/poly (default: 1.0)')
group.add_argument('--warmup-lr', type=float, default=0.0001, metavar='LR',
help='warmup learning rate (default: 0.0001)')
group.add_argument('--min-lr', type=float, default=1e-6, metavar='LR',
help='lower lr bound for cyclic schedulers that hit 0 (1e-5)')
group.add_argument('--warmup-lr', type=float, default=1e-5, metavar='LR',
help='warmup learning rate (default: 1e-5)')
group.add_argument('--min-lr', type=float, default=0, metavar='LR',
help='lower lr bound for cyclic schedulers that hit 0 (default: 0)')
group.add_argument('--epochs', type=int, default=300, metavar='N',
help='number of epochs to train (default: 300)')
group.add_argument('--epoch-repeats', type=float, default=0., metavar='N',
help='epoch repeat multiplier (number of times to repeat dataset epoch per train epoch).')
group.add_argument('--start-epoch', default=None, type=int, metavar='N',
help='manual epoch number (useful on restarts)')
group.add_argument('--decay-milestones', default=[30, 60], type=int, nargs='+', metavar="MILESTONES",
group.add_argument('--decay-milestones', default=[90, 180, 270], type=int, nargs='+', metavar="MILESTONES",
help='list of decay epoch indices for multistep lr. must be increasing')
group.add_argument('--decay-epochs', type=float, default=100, metavar='N',
group.add_argument('--decay-epochs', type=float, default=90, metavar='N',
help='epoch interval to decay LR')
group.add_argument('--warmup-epochs', type=int, default=3, metavar='N',
group.add_argument('--warmup-epochs', type=int, default=5, metavar='N',
help='epochs to warmup LR, if scheduler supports')
group.add_argument('--cooldown-epochs', type=int, default=10, metavar='N',
group.add_argument('--warmup-prefix', action='store_true', default=False,
help='Exclude warmup period from decay schedule.'),
group.add_argument('--cooldown-epochs', type=int, default=0, metavar='N',
help='epochs to cooldown LR at min_lr, after cyclic schedule ends')
group.add_argument('--patience-epochs', type=int, default=10, metavar='N',
help='patience epochs for Plateau LR scheduler (default: 10)')
@ -303,10 +315,10 @@ group.add_argument('--save-images', action='store_true', default=False,
help='save images of input batches every log interval for debugging')
group.add_argument('--amp', action='store_true', default=False,
help='use NVIDIA Apex AMP or Native AMP for mixed precision training')
group.add_argument('--apex-amp', action='store_true', default=False,
help='Use NVIDIA Apex AMP mixed precision')
group.add_argument('--native-amp', action='store_true', default=False,
help='Use Native Torch AMP mixed precision')
group.add_argument('--amp-dtype', default='float16', type=str,
help='lower precision AMP dtype (default: float16)')
group.add_argument('--amp-impl', default='native', type=str,
help='AMP impl to use, "native" or "apex" (default: native)')
group.add_argument('--no-ddp-bb', action='store_true', default=False,
help='Force broadcast buffers for native DDP to off.')
group.add_argument('--pin-mem', action='store_true', default=False,
@ -349,49 +361,42 @@ def main():
utils.setup_default_logging()
args, args_text = _parse_args()
if torch.cuda.is_available():
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.benchmark = True
args.prefetcher = not args.no_prefetcher
args.distributed = False
if 'WORLD_SIZE' in os.environ:
args.distributed = int(os.environ['WORLD_SIZE']) > 1
args.device = 'cuda:0'
args.world_size = 1
args.rank = 0 # global rank
device = utils.init_distributed_device(args)
if args.distributed:
if 'LOCAL_RANK' in os.environ:
args.local_rank = int(os.getenv('LOCAL_RANK'))
args.device = 'cuda:%d' % args.local_rank
torch.cuda.set_device(args.local_rank)
torch.distributed.init_process_group(backend='nccl', init_method='env://')
args.world_size = torch.distributed.get_world_size()
args.rank = torch.distributed.get_rank()
_logger.info('Training in distributed mode with multiple processes, 1 GPU per process. Process %d, total %d.'
% (args.rank, args.world_size))
_logger.info(
'Training in distributed mode with multiple processes, 1 device per process.'
f'Process {args.rank}, total {args.world_size}, device {args.device}.')
else:
_logger.info('Training with a single process on 1 GPUs.')
_logger.info(f'Training with a single process on 1 device ({args.device}).')
assert args.rank >= 0
if args.rank == 0 and args.log_wandb:
if utils.is_primary(args) and args.log_wandb:
if has_wandb:
wandb.init(project=args.experiment, config=args)
else:
_logger.warning("You've requested to log metrics to wandb but package not found. "
_logger.warning(
"You've requested to log metrics to wandb but package not found. "
"Metrics not being logged to wandb, try `pip install wandb`")
# resolve AMP arguments based on PyTorch / Apex availability
use_amp = None
amp_dtype = torch.float16
if args.amp:
# `--amp` chooses native amp before apex (APEX ver not actively maintained)
if has_native_amp:
args.native_amp = True
elif has_apex:
args.apex_amp = True
if args.apex_amp and has_apex:
if args.amp_impl == 'apex':
assert has_apex, 'AMP impl specified as APEX but APEX is not installed.'
use_amp = 'apex'
elif args.native_amp and has_native_amp:
assert args.amp_dtype == 'float16'
else:
assert has_native_amp, 'Please update PyTorch to a version with native AMP (or use APEX).'
use_amp = 'native'
elif args.apex_amp or args.native_amp:
_logger.warning("Neither APEX or native Torch AMP is available, using float32. "
"Install NVIDA apex or upgrade to PyTorch 1.6")
assert args.amp_dtype in ('float16', 'bfloat16')
if args.amp_dtype == 'bfloat16':
amp_dtype = torch.bfloat16
utils.random_seed(args.seed, args.rank)
@ -400,19 +405,26 @@ def main():
if args.fast_norm:
set_fast_norm()
in_chans = 3
if args.in_chans is not None:
in_chans = args.in_chans
elif args.input_size is not None:
in_chans = args.input_size[0]
model = create_model(
args.model,
pretrained=args.pretrained,
in_chans=in_chans,
num_classes=args.num_classes,
drop_rate=args.drop,
drop_connect_rate=args.drop_connect, # DEPRECATED, use drop_path
drop_path_rate=args.drop_path,
drop_block_rate=args.drop_block,
global_pool=args.gp,
bn_momentum=args.bn_momentum,
bn_eps=args.bn_eps,
scriptable=args.torchscript,
checkpoint_path=args.initial_checkpoint)
checkpoint_path=args.initial_checkpoint,
)
if args.num_classes is None:
assert hasattr(model, 'num_classes'), 'Model must have `num_classes` attr if not set on cmd line/config.'
args.num_classes = model.num_classes # FIXME handle model default vs config num_classes more elegantly
@ -420,11 +432,11 @@ def main():
if args.grad_checkpointing:
model.set_grad_checkpointing(enable=True)
if args.local_rank == 0:
if utils.is_primary(args):
_logger.info(
f'Model {safe_model_name(args.model)} created, param count:{sum([m.numel() for m in model.parameters()])}')
data_config = resolve_data_config(vars(args), model=model, verbose=args.local_rank == 0)
data_config = resolve_data_config(vars(args), model=model, verbose=utils.is_primary(args))
# setup augmentation batch splits for contrastive loss or split bn
num_aug_splits = 0
@ -438,9 +450,9 @@ def main():
model = convert_splitbn_model(model, max(num_aug_splits, 2))
# move model to GPU, enable channels last layout if set
model.cuda()
model.to(device=device)
if args.channels_last:
model = model.to(memory_format=torch.channels_last)
model.to(memory_format=torch.channels_last)
# setup synchronized BatchNorm for distributed training
if args.distributed and args.sync_bn:
@ -452,7 +464,7 @@ def main():
model = convert_syncbn_model(model)
else:
model = convert_sync_batchnorm(model)
if args.local_rank == 0:
if utils.is_primary(args):
_logger.info(
'Converted model to use Synchronized BatchNorm. WARNING: You may have issues if using '
'zero initialized BN layers (enabled by default for ResNets) while sync-bn enabled.')
@ -461,38 +473,56 @@ def main():
assert not use_amp == 'apex', 'Cannot use APEX AMP with torchscripted model'
assert not args.sync_bn, 'Cannot use SyncBatchNorm with torchscripted model'
model = torch.jit.script(model)
if args.aot_autograd:
assert has_functorch, "functorch is needed for --aot-autograd"
model = memory_efficient_fusion(model)
if args.lr is None:
global_batch_size = args.batch_size * args.world_size
batch_ratio = global_batch_size / args.lr_base_size
if not args.lr_base_scale:
on = args.opt.lower()
args.lr_base_scale = 'sqrt' if any([o in on for o in ('ada', 'lamb')]) else 'linear'
if args.lr_base_scale == 'sqrt':
batch_ratio = batch_ratio ** 0.5
args.lr = args.lr_base * batch_ratio
if utils.is_primary(args):
_logger.info(
f'Learning rate ({args.lr}) calculated from base learning rate ({args.lr_base}) '
f'and global batch size ({global_batch_size}) with {args.lr_base_scale} scaling.')
optimizer = create_optimizer_v2(model, **optimizer_kwargs(cfg=args))
# setup automatic mixed-precision (AMP) loss scaling and op casting
amp_autocast = suppress # do nothing
loss_scaler = None
if use_amp == 'apex':
assert device.type == 'cuda'
model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
loss_scaler = ApexScaler()
if args.local_rank == 0:
if utils.is_primary(args):
_logger.info('Using NVIDIA APEX AMP. Training in mixed precision.')
elif use_amp == 'native':
amp_autocast = torch.cuda.amp.autocast
amp_autocast = partial(torch.autocast, device_type=device.type, dtype=amp_dtype)
if device.type == 'cuda':
loss_scaler = NativeScaler()
if args.local_rank == 0:
if utils.is_primary(args):
_logger.info('Using native Torch AMP. Training in mixed precision.')
else:
if args.local_rank == 0:
if utils.is_primary(args):
_logger.info('AMP not enabled. Training in float32.')
# optionally resume from a checkpoint
resume_epoch = None
if args.resume:
resume_epoch = resume_checkpoint(
model, args.resume,
model,
args.resume,
optimizer=None if args.no_resume_opt else optimizer,
loss_scaler=None if args.no_resume_opt else loss_scaler,
log_info=args.local_rank == 0)
log_info=utils.is_primary(args),
)
# setup exponential moving average of model weights, SWA could be used here too
model_ema = None
@ -507,41 +537,37 @@ def main():
if args.distributed:
if has_apex and use_amp == 'apex':
# Apex DDP preferred unless native amp is activated
if args.local_rank == 0:
if utils.is_primary(args):
_logger.info("Using NVIDIA APEX DistributedDataParallel.")
model = ApexDDP(model, delay_allreduce=True)
else:
if args.local_rank == 0:
if utils.is_primary(args):
_logger.info("Using native Torch DistributedDataParallel.")
model = NativeDDP(model, device_ids=[args.local_rank], broadcast_buffers=not args.no_ddp_bb)
model = NativeDDP(model, device_ids=[device], broadcast_buffers=not args.no_ddp_bb)
# NOTE: EMA model does not need to be wrapped by DDP
# setup learning rate schedule and starting epoch
lr_scheduler, num_epochs = create_scheduler(args, optimizer)
start_epoch = 0
if args.start_epoch is not None:
# a specified start_epoch will always override the resume epoch
start_epoch = args.start_epoch
elif resume_epoch is not None:
start_epoch = resume_epoch
if lr_scheduler is not None and start_epoch > 0:
lr_scheduler.step(start_epoch)
if args.local_rank == 0:
_logger.info('Scheduled epochs: {}'.format(num_epochs))
# create the train and eval datasets
dataset_train = create_dataset(
args.dataset, root=args.data_dir, split=args.train_split, is_training=True,
args.dataset,
root=args.data_dir,
split=args.train_split,
is_training=True,
class_map=args.class_map,
download=args.dataset_download,
batch_size=args.batch_size,
repeats=args.epoch_repeats)
seed=args.seed,
repeats=args.epoch_repeats,
)
dataset_eval = create_dataset(
args.dataset, root=args.data_dir, split=args.val_split, is_training=False,
args.dataset,
root=args.data_dir,
split=args.val_split,
is_training=False,
class_map=args.class_map,
download=args.dataset_download,
batch_size=args.batch_size)
batch_size=args.batch_size,
)
# setup mixup / cutmix
collate_fn = None
@ -549,9 +575,15 @@ def main():
mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None
if mixup_active:
mixup_args = dict(
mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax,
prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode,
label_smoothing=args.smoothing, num_classes=args.num_classes)
mixup_alpha=args.mixup,
cutmix_alpha=args.cutmix,
cutmix_minmax=args.cutmix_minmax,
prob=args.mixup_prob,
switch_prob=args.mixup_switch_prob,
mode=args.mixup_mode,
label_smoothing=args.smoothing,
num_classes=args.num_classes
)
if args.prefetcher:
assert not num_aug_splits # collate conflict (need to support deinterleaving in collate mixup)
collate_fn = FastCollateMixup(**mixup_args)
@ -592,10 +624,15 @@ def main():
distributed=args.distributed,
collate_fn=collate_fn,
pin_memory=args.pin_mem,
device=device,
use_multi_epochs_loader=args.use_multi_epochs_loader,
worker_seeding=args.worker_seeding,
)
eval_workers = args.workers
if args.distributed and ('tfds' in args.dataset or 'wds' in args.dataset):
# FIXME reduces validation padding issues when using TFDS, WDS w/ workers and distributed training
eval_workers = min(2, args.workers)
loader_eval = create_loader(
dataset_eval,
input_size=data_config['input_size'],
@ -605,10 +642,11 @@ def main():
interpolation=data_config['interpolation'],
mean=data_config['mean'],
std=data_config['std'],
num_workers=args.workers,
num_workers=eval_workers,
distributed=args.distributed,
crop_pct=data_config['crop_pct'],
pin_memory=args.pin_mem,
device=device,
)
# setup loss function
@ -628,8 +666,8 @@ def main():
train_loss_fn = LabelSmoothingCrossEntropy(smoothing=args.smoothing)
else:
train_loss_fn = nn.CrossEntropyLoss()
train_loss_fn = train_loss_fn.cuda()
validate_loss_fn = nn.CrossEntropyLoss().cuda()
train_loss_fn = train_loss_fn.to(device=device)
validate_loss_fn = nn.CrossEntropyLoss().to(device=device)
# setup checkpoint saver and eval metric tracking
eval_metric = args.eval_metric
@ -637,7 +675,7 @@ def main():
best_epoch = None
saver = None
output_dir = None
if args.rank == 0:
if utils.is_primary(args):
if args.experiment:
exp_name = args.experiment
else:
@ -649,60 +687,136 @@ def main():
output_dir = utils.get_outdir(args.output if args.output else './output/train', exp_name)
decreasing = True if eval_metric == 'loss' else False
saver = utils.CheckpointSaver(
model=model, optimizer=optimizer, args=args, model_ema=model_ema, amp_scaler=loss_scaler,
checkpoint_dir=output_dir, recovery_dir=output_dir, decreasing=decreasing, max_history=args.checkpoint_hist)
model=model,
optimizer=optimizer,
args=args,
model_ema=model_ema,
amp_scaler=loss_scaler,
checkpoint_dir=output_dir,
recovery_dir=output_dir,
decreasing=decreasing,
max_history=args.checkpoint_hist
)
with open(os.path.join(output_dir, 'args.yaml'), 'w') as f:
f.write(args_text)
# setup learning rate schedule and starting epoch
updates_per_epoch = len(loader_train)
lr_scheduler, num_epochs = create_scheduler_v2(
optimizer,
**scheduler_kwargs(args),
updates_per_epoch=updates_per_epoch,
)
start_epoch = 0
if args.start_epoch is not None:
# a specified start_epoch will always override the resume epoch
start_epoch = args.start_epoch
elif resume_epoch is not None:
start_epoch = resume_epoch
if lr_scheduler is not None and start_epoch > 0:
if args.sched_on_updates:
lr_scheduler.step_update(start_epoch * updates_per_epoch)
else:
lr_scheduler.step(start_epoch)
if utils.is_primary(args):
_logger.info(
f'Scheduled epochs: {num_epochs}. LR stepped per {"epoch" if lr_scheduler.t_in_epochs else "update"}.')
try:
for epoch in range(start_epoch, num_epochs):
if args.distributed and hasattr(loader_train.sampler, 'set_epoch'):
if hasattr(dataset_train, 'set_epoch'):
dataset_train.set_epoch(epoch)
elif args.distributed and hasattr(loader_train.sampler, 'set_epoch'):
loader_train.sampler.set_epoch(epoch)
train_metrics = train_one_epoch(
epoch, model, loader_train, optimizer, train_loss_fn, args,
lr_scheduler=lr_scheduler, saver=saver, output_dir=output_dir,
amp_autocast=amp_autocast, loss_scaler=loss_scaler, model_ema=model_ema, mixup_fn=mixup_fn)
epoch,
model,
loader_train,
optimizer,
train_loss_fn,
args,
lr_scheduler=lr_scheduler,
saver=saver,
output_dir=output_dir,
amp_autocast=amp_autocast,
loss_scaler=loss_scaler,
model_ema=model_ema,
mixup_fn=mixup_fn,
)
if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
if args.local_rank == 0:
if utils.is_primary(args):
_logger.info("Distributing BatchNorm running means and vars")
utils.distribute_bn(model, args.world_size, args.dist_bn == 'reduce')
eval_metrics = validate(model, loader_eval, validate_loss_fn, args, amp_autocast=amp_autocast)
eval_metrics = validate(
model,
loader_eval,
validate_loss_fn,
args,
amp_autocast=amp_autocast,
)
if model_ema is not None and not args.model_ema_force_cpu:
if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
utils.distribute_bn(model_ema, args.world_size, args.dist_bn == 'reduce')
ema_eval_metrics = validate(
model_ema.module, loader_eval, validate_loss_fn, args, amp_autocast=amp_autocast, log_suffix=' (EMA)')
model_ema.module,
loader_eval,
validate_loss_fn,
args,
amp_autocast=amp_autocast,
log_suffix=' (EMA)',
)
eval_metrics = ema_eval_metrics
if lr_scheduler is not None:
# step LR for next epoch
lr_scheduler.step(epoch + 1, eval_metrics[eval_metric])
if output_dir is not None:
lrs = [param_group['lr'] for param_group in optimizer.param_groups]
utils.update_summary(
epoch, train_metrics, eval_metrics, os.path.join(output_dir, 'summary.csv'),
write_header=best_metric is None, log_wandb=args.log_wandb and has_wandb)
epoch,
train_metrics,
eval_metrics,
filename=os.path.join(output_dir, 'summary.csv'),
lr=sum(lrs) / len(lrs),
write_header=best_metric is None,
log_wandb=args.log_wandb and has_wandb,
)
if saver is not None:
# save proper checkpoint with eval metric
save_metric = eval_metrics[eval_metric]
best_metric, best_epoch = saver.save_checkpoint(epoch, metric=save_metric)
if lr_scheduler is not None:
# step LR for next epoch
lr_scheduler.step(epoch + 1, eval_metrics[eval_metric])
except KeyboardInterrupt:
pass
if best_metric is not None:
_logger.info('*** Best metric: {0} (epoch {1})'.format(best_metric, best_epoch))
def train_one_epoch(
epoch, model, loader, optimizer, loss_fn, args,
lr_scheduler=None, saver=None, output_dir=None, amp_autocast=suppress,
loss_scaler=None, model_ema=None, mixup_fn=None):
epoch,
model,
loader,
optimizer,
loss_fn,
args,
device=torch.device('cuda'),
lr_scheduler=None,
saver=None,
output_dir=None,
amp_autocast=suppress,
loss_scaler=None,
model_ema=None,
mixup_fn=None
):
if args.mixup_off_epoch and epoch >= args.mixup_off_epoch:
if args.prefetcher and loader.mixup_enabled:
loader.mixup_enabled = False
@ -717,13 +831,14 @@ def train_one_epoch(
model.train()
end = time.time()
last_idx = len(loader) - 1
num_updates = epoch * len(loader)
num_batches_per_epoch = len(loader)
last_idx = num_batches_per_epoch - 1
num_updates = epoch * num_batches_per_epoch
for batch_idx, (input, target) in enumerate(loader):
last_batch = batch_idx == last_idx
data_time_m.update(time.time() - end)
if not args.prefetcher:
input, target = input.cuda(), target.cuda()
input, target = input.to(device), target.to(device)
if mixup_fn is not None:
input, target = mixup_fn(input, target)
if args.channels_last:
@ -740,21 +855,26 @@ def train_one_epoch(
if loss_scaler is not None:
loss_scaler(
loss, optimizer,
clip_grad=args.clip_grad, clip_mode=args.clip_mode,
clip_grad=args.clip_grad,
clip_mode=args.clip_mode,
parameters=model_parameters(model, exclude_head='agc' in args.clip_mode),
create_graph=second_order)
create_graph=second_order
)
else:
loss.backward(create_graph=second_order)
if args.clip_grad is not None:
utils.dispatch_clip_grad(
model_parameters(model, exclude_head='agc' in args.clip_mode),
value=args.clip_grad, mode=args.clip_mode)
value=args.clip_grad,
mode=args.clip_mode
)
optimizer.step()
if model_ema is not None:
model_ema.update(model)
torch.cuda.synchronize()
num_updates += 1
batch_time_m.update(time.time() - end)
if last_batch or batch_idx % args.log_interval == 0:
@ -765,7 +885,7 @@ def train_one_epoch(
reduced_loss = utils.reduce_tensor(loss.data, args.world_size)
losses_m.update(reduced_loss.item(), input.size(0))
if args.local_rank == 0:
if utils.is_primary(args):
_logger.info(
'Train: {} [{:>4d}/{} ({:>3.0f}%)] '
'Loss: {loss.val:#.4g} ({loss.avg:#.3g}) '
@ -781,14 +901,16 @@ def train_one_epoch(
rate=input.size(0) * args.world_size / batch_time_m.val,
rate_avg=input.size(0) * args.world_size / batch_time_m.avg,
lr=lr,
data_time=data_time_m))
data_time=data_time_m)
)
if args.save_images and output_dir:
torchvision.utils.save_image(
input,
os.path.join(output_dir, 'train-batch-%d.jpg' % batch_idx),
padding=0,
normalize=True)
normalize=True
)
if saver is not None and args.recovery_interval and (
last_batch or (batch_idx + 1) % args.recovery_interval == 0):
@ -806,7 +928,15 @@ def train_one_epoch(
return OrderedDict([('loss', losses_m.avg)])
def validate(model, loader, loss_fn, args, amp_autocast=suppress, log_suffix=''):
def validate(
model,
loader,
loss_fn,
args,
device=torch.device('cuda'),
amp_autocast=suppress,
log_suffix=''
):
batch_time_m = utils.AverageMeter()
losses_m = utils.AverageMeter()
top1_m = utils.AverageMeter()
@ -820,8 +950,8 @@ def validate(model, loader, loss_fn, args, amp_autocast=suppress, log_suffix='')
for batch_idx, (input, target) in enumerate(loader):
last_batch = batch_idx == last_idx
if not args.prefetcher:
input = input.cuda()
target = target.cuda()
input = input.to(device)
target = target.to(device)
if args.channels_last:
input = input.contiguous(memory_format=torch.channels_last)
@ -846,6 +976,7 @@ def validate(model, loader, loss_fn, args, amp_autocast=suppress, log_suffix='')
else:
reduced_loss = loss.data
if device.type == 'cuda':
torch.cuda.synchronize()
losses_m.update(reduced_loss.item(), input.size(0))
@ -854,7 +985,7 @@ def validate(model, loader, loss_fn, args, amp_autocast=suppress, log_suffix='')
batch_time_m.update(time.time() - end)
end = time.time()
if args.local_rank == 0 and (last_batch or batch_idx % args.log_interval == 0):
if utils.is_primary(args) and (last_batch or batch_idx % args.log_interval == 0):
log_name = 'Test' + log_suffix
_logger.info(
'{0}: [{1:>4d}/{2}] '
@ -862,8 +993,12 @@ def validate(model, loader, loss_fn, args, amp_autocast=suppress, log_suffix='')
'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f}) '
'Acc@1: {top1.val:>7.4f} ({top1.avg:>7.4f}) '
'Acc@5: {top5.val:>7.4f} ({top5.avg:>7.4f})'.format(
log_name, batch_idx, last_idx, batch_time=batch_time_m,
loss=losses_m, top1=top1_m, top5=top5_m))
log_name, batch_idx, last_idx,
batch_time=batch_time_m,
loss=losses_m,
top1=top1_m,
top5=top5_m)
)
metrics = OrderedDict([('loss', losses_m.avg), ('top1', top1_m.avg), ('top5', top5_m.avg)])

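The --lr-base handling added above in the training script derives --lr from the global batch size when --lr is not given. A standalone sketch of that rule (default values assumed; optimizer names containing 'ada' or 'lamb' get square-root scaling):

def resolve_base_lr(lr_base=0.1, lr_base_size=256, lr_base_scale='', batch_size=128, world_size=2, opt='sgd'):
    # mirrors the lr-base logic in main(): scale lr_base by the ratio of global batch size to lr_base_size
    global_batch_size = batch_size * world_size
    batch_ratio = global_batch_size / lr_base_size
    if not lr_base_scale:
        lr_base_scale = 'sqrt' if any(o in opt.lower() for o in ('ada', 'lamb')) else 'linear'
    if lr_base_scale == 'sqrt':
        batch_ratio = batch_ratio ** 0.5
    return lr_base * batch_ratio

# sgd with a global batch of 1024:  0.1 * (1024 / 256)        = 0.4
# adamw with the same global batch: 0.1 * (1024 / 256) ** 0.5 = 0.2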
@ -19,6 +19,7 @@ import torch.nn as nn
import torch.nn.parallel
from collections import OrderedDict
from contextlib import suppress
from functools import partial
from timm.models import create_model, apply_test_time_pool, load_checkpoint, is_model, list_models, set_fast_norm
from timm.data import create_dataset, create_loader, resolve_data_config, RealLabelsImagenet
@ -45,7 +46,6 @@ try:
except ImportError as e:
has_functorch = False
torch.backends.cudnn.benchmark = True
_logger = logging.getLogger('validate')
@ -100,12 +100,14 @@ parser.add_argument('--pin-mem', action='store_true', default=False,
help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
parser.add_argument('--channels-last', action='store_true', default=False,
help='Use channels_last memory layout')
parser.add_argument('--device', default='cuda', type=str,
help="Device (accelerator) to use.")
parser.add_argument('--amp', action='store_true', default=False,
help='Use AMP mixed precision. Defaults to Apex, fallback to native Torch AMP.')
parser.add_argument('--apex-amp', action='store_true', default=False,
help='Use NVIDIA Apex AMP mixed precision')
parser.add_argument('--native-amp', action='store_true', default=False,
help='Use Native Torch AMP mixed precision')
help='use NVIDIA Apex AMP or Native AMP for mixed precision validation')
parser.add_argument('--amp-dtype', default='float16', type=str,
help='lower precision AMP dtype (default: float16)')
parser.add_argument('--amp-impl', default='native', type=str,
help='AMP impl to use, "native" or "apex" (default: native)')
parser.add_argument('--tf-preprocessing', action='store_true', default=False,
help='Use Tensorflow preprocessing pipeline (requires CPU TF installed)')
parser.add_argument('--use-ema', dest='use_ema', action='store_true',
@ -133,25 +135,35 @@ def validate(args):
# might as well try to validate something
args.pretrained = args.pretrained or not args.checkpoint
args.prefetcher = not args.no_prefetcher
amp_autocast = suppress # do nothing
if torch.cuda.is_available():
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.benchmark = True
device = torch.device(args.device)
# resolve AMP arguments based on PyTorch / Apex availability
use_amp = None
amp_autocast = suppress
if args.amp:
if has_native_amp:
args.native_amp = True
elif has_apex:
args.apex_amp = True
if args.amp_impl == 'apex':
assert has_apex, 'AMP impl specified as APEX but APEX is not installed.'
assert args.amp_dtype == 'float16'
use_amp = 'apex'
_logger.info('Validating in mixed precision with NVIDIA APEX AMP.')
else:
_logger.warning("Neither APEX or Native Torch AMP is available.")
assert not args.apex_amp or not args.native_amp, "Only one AMP mode should be set."
if args.native_amp:
amp_autocast = torch.cuda.amp.autocast
assert has_native_amp, 'Please update PyTorch to a version with native AMP (or use APEX).'
assert args.amp_dtype in ('float16', 'bfloat16')
use_amp = 'native'
amp_dtype = torch.bfloat16 if args.amp_dtype == 'bfloat16' else torch.float16
amp_autocast = partial(torch.autocast, device_type=device.type, dtype=amp_dtype)
_logger.info('Validating in mixed precision with native PyTorch AMP.')
elif args.apex_amp:
_logger.info('Validating in mixed precision with NVIDIA APEX AMP.')
else:
_logger.info('Validating in float32. AMP not enabled.')
if args.fuser:
set_jit_fuser(args.fuser)
if args.fast_norm:
set_fast_norm()
@ -162,7 +174,8 @@ def validate(args):
num_classes=args.num_classes,
in_chans=3,
global_pool=args.gp,
scriptable=args.torchscript)
scriptable=args.torchscript,
)
if args.num_classes is None:
assert hasattr(model, 'num_classes'), 'Model must have `num_classes` attr if not set on cmd line/config.'
args.num_classes = model.num_classes
@ -177,7 +190,7 @@ def validate(args):
vars(args),
model=model,
use_test_size=not args.use_train_size,
verbose=True
verbose=True,
)
test_time_pool = False
if args.test_pool:
@ -186,12 +199,13 @@ def validate(args):
if args.torchscript:
torch.jit.optimized_execution(True)
model = torch.jit.script(model)
if args.aot_autograd:
assert has_functorch, "functorch is needed for --aot-autograd"
model = memory_efficient_fusion(model)
model = model.cuda()
if args.apex_amp:
model = model.to(device)
if use_amp == 'apex':
model = amp.initialize(model, opt_level='O1')
if args.channels_last:
@ -200,11 +214,16 @@ def validate(args):
if args.num_gpu > 1:
model = torch.nn.DataParallel(model, device_ids=list(range(args.num_gpu)))
criterion = nn.CrossEntropyLoss().cuda()
criterion = nn.CrossEntropyLoss().to(device)
dataset = create_dataset(
root=args.data, name=args.dataset, split=args.split,
download=args.dataset_download, load_bytes=args.tf_preprocessing, class_map=args.class_map)
root=args.data,
name=args.dataset,
split=args.split,
download=args.dataset_download,
load_bytes=args.tf_preprocessing,
class_map=args.class_map,
)
if args.valid_labels:
with open(args.valid_labels, 'r') as f:
@ -230,7 +249,9 @@ def validate(args):
num_workers=args.workers,
crop_pct=crop_pct,
pin_memory=args.pin_mem,
tf_preprocessing=args.tf_preprocessing)
device=device,
tf_preprocessing=args.tf_preprocessing,
)
batch_time = AverageMeter()
losses = AverageMeter()
@ -240,7 +261,7 @@ def validate(args):
model.eval()
with torch.no_grad():
# warmup, reduce variability of first batch time, especially for comparing torchscript vs non
input = torch.randn((args.batch_size,) + tuple(data_config['input_size'])).cuda()
input = torch.randn((args.batch_size,) + tuple(data_config['input_size'])).to(device)
if args.channels_last:
input = input.contiguous(memory_format=torch.channels_last)
with amp_autocast():
@ -249,8 +270,8 @@ def validate(args):
end = time.time()
for batch_idx, (input, target) in enumerate(loader):
if args.no_prefetcher:
target = target.cuda()
input = input.cuda()
target = target.to(device)
input = input.to(device)
if args.channels_last:
input = input.contiguous(memory_format=torch.channels_last)
@ -282,9 +303,15 @@ def validate(args):
'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f}) '
'Acc@1: {top1.val:>7.3f} ({top1.avg:>7.3f}) '
'Acc@5: {top5.val:>7.3f} ({top5.avg:>7.3f})'.format(
batch_idx, len(loader), batch_time=batch_time,
batch_idx,
len(loader),
batch_time=batch_time,
rate_avg=input.size(0) / batch_time.avg,
loss=losses, top1=top1, top5=top5))
loss=losses,
top1=top1,
top5=top5
)
)
if real_labels is not None:
# real labels mode replaces topk values at the end
@ -298,7 +325,8 @@ def validate(args):
param_count=round(param_count / 1e6, 2),
img_size=data_config['input_size'][-1],
crop_pct=crop_pct,
interpolation=data_config['interpolation'])
interpolation=data_config['interpolation'],
)
_logger.info(' * Acc@1 {:.3f} ({:.3f}) Acc@5 {:.3f} ({:.3f})'.format(
results['top1'], results['top1_err'], results['top5'], results['top5_err']))
@ -313,6 +341,7 @@ def _try_run(args, initial_batch_size):
while batch_size:
args.batch_size = batch_size * args.num_gpu # multiply by num-gpu for DataParallel case
try:
if torch.cuda.is_available() and 'cuda' in args.device:
torch.cuda.empty_cache()
results = validate(args)
return results

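The validation script above now builds autocast via functools.partial over torch.autocast, so the same code path works on CPU and CUDA with float16 or bfloat16. A small self-contained sketch of that pattern; the device, dtype, and toy module are illustrative.

import torch
from contextlib import suppress
from functools import partial

device = torch.device('cpu')   # e.g. --device cpu
amp_dtype = torch.bfloat16     # e.g. --amp-dtype bfloat16
use_amp = True                 # e.g. --amp

# same construction as in the scripts: a no-op context manager when AMP is disabled
amp_autocast = partial(torch.autocast, device_type=device.type, dtype=amp_dtype) if use_amp else suppress

model = torch.nn.Linear(8, 2).to(device)
x = torch.randn(4, 8, device=device)
with amp_autocast():
    y = model(x)  # forward pass runs under bfloat16 autocast on the chosen device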