diff --git a/README.md b/README.md index e5094e73..0fb37886 100644 --- a/README.md +++ b/README.md @@ -128,7 +128,7 @@ More models, more fixes * `cs3`, `darknet`, and `vit_*relpos` weights above all trained on TPU thanks to TRC program! Rest trained on overheating GPUs. * Hugging Face Hub support fixes verified, demo notebook TBA * Pretrained weights / configs can be loaded externally (ie from local disk) w/ support for head adaptation. -* Add support to change image extensions scanned by `timm` datasets/parsers. See (https://github.com/rwightman/pytorch-image-models/pull/1274#issuecomment-1178303103) +* Add support to change image extensions scanned by `timm` datasets/readers. See (https://github.com/rwightman/pytorch-image-models/pull/1274#issuecomment-1178303103) * Default ConvNeXt LayerNorm impl to use `F.layer_norm(x.permute(0, 2, 3, 1), ...).permute(0, 3, 1, 2)` via `LayerNorm2d` in all cases. * a bit slower than previous custom impl on some hardware (ie Ampere w/ CL), but overall fewer regressions across wider HW / PyTorch version ranges. * previous impl exists as `LayerNormExp2d` in `models/layers/norm.py` diff --git a/benchmark.py b/benchmark.py index 4a89441b..a03c1982 100755 --- a/benchmark.py +++ b/benchmark.py @@ -57,7 +57,9 @@ except ImportError as e: has_functorch = False -torch.backends.cudnn.benchmark = True +if torch.cuda.is_available(): + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.benchmark = True _logger = logging.getLogger('validate') @@ -216,7 +218,7 @@ class BenchmarkRunner: self.device = device self.use_amp, self.model_dtype, self.data_dtype = resolve_precision(precision) self.channels_last = kwargs.pop('channels_last', False) - self.amp_autocast = torch.cuda.amp.autocast if self.use_amp else suppress + self.amp_autocast = partial(torch.cuda.amp.autocast, dtype=torch.float16) if self.use_amp else suppress if fuser: set_jit_fuser(fuser) diff --git a/timm/data/__init__.py b/timm/data/__init__.py index 0eb10a66..7cc7b0b0 100644 --- a/timm/data/__init__.py +++ b/timm/data/__init__.py @@ -6,8 +6,8 @@ from .dataset import ImageDataset, IterableImageDataset, AugMixDataset from .dataset_factory import create_dataset from .loader import create_loader from .mixup import Mixup, FastCollateMixup -from .parsers import create_parser,\ - get_img_extensions, is_img_extension, set_img_extensions, add_img_extensions, del_img_extensions +from .readers import create_reader +from .readers import get_img_extensions, is_img_extension, set_img_extensions, add_img_extensions, del_img_extensions from .real_labels import RealLabelsImagenet from .transforms import * from .transforms_factory import create_transform diff --git a/timm/data/dataset.py b/timm/data/dataset.py index 20b663ce..e7f67925 100644 --- a/timm/data/dataset.py +++ b/timm/data/dataset.py @@ -2,14 +2,15 @@ Hacked together by / Copyright 2019, Ross Wightman """ -import torch.utils.data as data -import os -import torch +import io import logging +from typing import Optional +import torch +import torch.utils.data as data from PIL import Image -from .parsers import create_parser +from .readers import create_reader _logger = logging.getLogger(__name__) @@ -22,48 +23,62 @@ class ImageDataset(data.Dataset): def __init__( self, root, - parser=None, + reader=None, + split='train', class_map=None, load_bytes=False, + img_mode='RGB', transform=None, target_transform=None, ): - if parser is None or isinstance(parser, str): - parser = create_parser(parser or '', root=root, class_map=class_map) - self.parser = parser 
+ if reader is None or isinstance(reader, str): + reader = create_reader( + reader or '', + root=root, + split=split, + class_map=class_map + ) + self.reader = reader self.load_bytes = load_bytes + self.img_mode = img_mode self.transform = transform self.target_transform = target_transform self._consecutive_errors = 0 def __getitem__(self, index): - img, target = self.parser[index] + img, target = self.reader[index] + try: - img = img.read() if self.load_bytes else Image.open(img).convert('RGB') + img = img.read() if self.load_bytes else Image.open(img) except Exception as e: - _logger.warning(f'Skipped sample (index {index}, file {self.parser.filename(index)}). {str(e)}') + _logger.warning(f'Skipped sample (index {index}, file {self.reader.filename(index)}). {str(e)}') self._consecutive_errors += 1 if self._consecutive_errors < _ERROR_RETRY: - return self.__getitem__((index + 1) % len(self.parser)) + return self.__getitem__((index + 1) % len(self.reader)) else: raise e self._consecutive_errors = 0 + + if self.img_mode and not self.load_bytes: + img = img.convert(self.img_mode) if self.transform is not None: img = self.transform(img) + if target is None: target = -1 elif self.target_transform is not None: target = self.target_transform(target) + return img, target def __len__(self): - return len(self.parser) + return len(self.reader) def filename(self, index, basename=False, absolute=False): - return self.parser.filename(index, basename, absolute) + return self.reader.filename(index, basename, absolute) def filenames(self, basename=False, absolute=False): - return self.parser.filenames(basename, absolute) + return self.reader.filenames(basename, absolute) class IterableImageDataset(data.IterableDataset): @@ -71,28 +86,36 @@ class IterableImageDataset(data.IterableDataset): def __init__( self, root, - parser=None, + reader=None, split='train', is_training=False, batch_size=None, + seed=42, repeats=0, download=False, transform=None, target_transform=None, ): - assert parser is not None - if isinstance(parser, str): - self.parser = create_parser( - parser, root=root, split=split, is_training=is_training, - batch_size=batch_size, repeats=repeats, download=download) + assert reader is not None + if isinstance(reader, str): + self.reader = create_reader( + reader, + root=root, + split=split, + is_training=is_training, + batch_size=batch_size, + seed=seed, + repeats=repeats, + download=download, + ) else: - self.parser = parser + self.reader = reader self.transform = transform self.target_transform = target_transform self._consecutive_errors = 0 def __iter__(self): - for img, target in self.parser: + for img, target in self.reader: if self.transform is not None: img = self.transform(img) if self.target_transform is not None: @@ -100,16 +123,29 @@ class IterableImageDataset(data.IterableDataset): yield img, target def __len__(self): - if hasattr(self.parser, '__len__'): - return len(self.parser) + if hasattr(self.reader, '__len__'): + return len(self.reader) else: return 0 + def set_epoch(self, count): + # TFDS and WDS need external epoch count for deterministic cross process shuffle + if hasattr(self.reader, 'set_epoch'): + self.reader.set_epoch(count) + + def set_loader_cfg( + self, + num_workers: Optional[int] = None, + ): + # TFDS and WDS readers need # workers for correct # samples estimate before loader processes created + if hasattr(self.reader, 'set_loader_cfg'): + self.reader.set_loader_cfg(num_workers=num_workers) + def filename(self, index, basename=False, absolute=False): assert False, 
'Filename lookup by index not supported, use filenames().' def filenames(self, basename=False, absolute=False): - return self.parser.filenames(basename, absolute) + return self.reader.filenames(basename, absolute) class AugMixDataset(torch.utils.data.Dataset): diff --git a/timm/data/dataset_factory.py b/timm/data/dataset_factory.py index d0ac30b1..3777a5aa 100644 --- a/timm/data/dataset_factory.py +++ b/timm/data/dataset_factory.py @@ -60,6 +60,7 @@ def create_dataset( is_training=False, download=False, batch_size=None, + seed=42, repeats=0, **kwargs ): @@ -68,7 +69,9 @@ def create_dataset( In parenthesis after each arg are the type of dataset supported for each arg, one of: * folder - default, timm folder (or tar) based ImageDataset * torch - torchvision based datasets + * HFDS - Hugging Face Datasets * TFDS - Tensorflow-datasets wrapper in IterabeDataset interface via IterableImageDataset + * WDS - Webdataset * all - any of the above Args: @@ -79,11 +82,12 @@ def create_dataset( `imagenet/` instead of `/imagenet/val`, etc on cmd line / config. (folder, torch/folder) class_map: specify class -> index mapping via text file or dict (folder) load_bytes: load data, return images as undecoded bytes (folder) - download: download dataset if not present and supported (TFDS, torch) + download: download dataset if not present and supported (HFDS, TFDS, torch) is_training: create dataset in train mode, this is different from the split. - For Iterable / TDFS it enables shuffle, ignored for other datasets. (TFDS) - batch_size: batch size hint for (TFDS) - repeats: dataset repeats per iteration i.e. epoch (TFDS) + For Iterable / TDFS it enables shuffle, ignored for other datasets. (TFDS, WDS) + batch_size: batch size hint for (TFDS, WDS) + seed: seed for iterable datasets (TFDS, WDS) + repeats: dataset repeats per iteration i.e. 
epoch (TFDS, WDS) **kwargs: other args to pass to dataset Returns: @@ -130,14 +134,37 @@ def create_dataset( ds = ImageFolder(root, **kwargs) else: assert False, f"Unknown torchvision dataset {name}" + elif name.startswith('hfds/'): + # NOTE right now, HF datasets default arrow format is a random-access Dataset, + # There will be a IterableDataset variant too, TBD + ds = ImageDataset(root, reader=name, split=split, **kwargs) elif name.startswith('tfds/'): ds = IterableImageDataset( - root, parser=name, split=split, is_training=is_training, - download=download, batch_size=batch_size, repeats=repeats, **kwargs) + root, + reader=name, + split=split, + is_training=is_training, + download=download, + batch_size=batch_size, + repeats=repeats, + seed=seed, + **kwargs + ) + elif name.startswith('wds/'): + ds = IterableImageDataset( + root, + reader=name, + split=split, + is_training=is_training, + batch_size=batch_size, + repeats=repeats, + seed=seed, + **kwargs + ) else: # FIXME support more advance split cfg for ImageFolder/Tar datasets in the future if search_split and os.path.isdir(root): # look for split specific sub-folder in root root = _search_split(root, split) - ds = ImageDataset(root, parser=name, class_map=class_map, load_bytes=load_bytes, **kwargs) + ds = ImageDataset(root, reader=name, class_map=class_map, load_bytes=load_bytes, **kwargs) return ds diff --git a/timm/data/loader.py b/timm/data/loader.py index ecc075c0..1a4800f8 100644 --- a/timm/data/loader.py +++ b/timm/data/loader.py @@ -5,19 +5,25 @@ https://github.com/NVIDIA/apex/commit/d5e2bb4bdeedd27b1dfaf5bb2b24d6c000dee9be#d Hacked together by / Copyright 2019, Ross Wightman """ +import logging import random +from contextlib import suppress from functools import partial from itertools import repeat from typing import Callable +import torch import torch.utils.data import numpy as np -from .transforms_factory import create_transform from .constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from .dataset import IterableImageDataset from .distributed_sampler import OrderedDistributedSampler, RepeatAugSampler from .random_erasing import RandomErasing from .mixup import FastCollateMixup +from .transforms_factory import create_transform + +_logger = logging.getLogger(__name__) def fast_collate(batch): @@ -55,11 +61,13 @@ def fast_collate(batch): assert False -def expand_to_chs(x, n): +def adapt_to_chs(x, n): if not isinstance(x, (tuple, list)): x = tuple(repeat(x, n)) - elif len(x) == 1: - x = x * n + elif len(x) != n: + x_mean = np.mean(x).item() + x = (x_mean,) * n + _logger.warning(f'Pretrained mean/std different shape than model, using avg value {x}.') else: assert len(x) == n, 'normalization stats must match image channels' return x @@ -73,41 +81,55 @@ class PrefetchLoader: mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, channels=3, + device=torch.device('cuda'), + img_dtype=torch.float32, fp16=False, re_prob=0., re_mode='const', re_count=1, re_num_splits=0): - mean = expand_to_chs(mean, channels) - std = expand_to_chs(std, channels) + mean = adapt_to_chs(mean, channels) + std = adapt_to_chs(std, channels) normalization_shape = (1, channels, 1, 1) self.loader = loader - self.mean = torch.tensor([x * 255 for x in mean]).cuda().view(normalization_shape) - self.std = torch.tensor([x * 255 for x in std]).cuda().view(normalization_shape) - self.fp16 = fp16 + self.device = device if fp16: - self.mean = self.mean.half() - self.std = self.std.half() + # fp16 arg is deprecated, but will override dtype arg if set for bwd 
compat + img_dtype = torch.float16 + self.img_dtype = img_dtype + self.mean = torch.tensor( + [x * 255 for x in mean], device=device, dtype=img_dtype).view(normalization_shape) + self.std = torch.tensor( + [x * 255 for x in std], device=device, dtype=img_dtype).view(normalization_shape) if re_prob > 0.: self.random_erasing = RandomErasing( - probability=re_prob, mode=re_mode, max_count=re_count, num_splits=re_num_splits) + probability=re_prob, + mode=re_mode, + max_count=re_count, + num_splits=re_num_splits, + device=device, + ) else: self.random_erasing = None + self.is_cuda = torch.cuda.is_available() and device.type == 'cuda' def __iter__(self): - stream = torch.cuda.Stream() first = True + if self.is_cuda: + stream = torch.cuda.Stream() + stream_context = partial(torch.cuda.stream, stream=stream) + else: + stream = None + stream_context = suppress for next_input, next_target in self.loader: - with torch.cuda.stream(stream): - next_input = next_input.cuda(non_blocking=True) - next_target = next_target.cuda(non_blocking=True) - if self.fp16: - next_input = next_input.half().sub_(self.mean).div_(self.std) - else: - next_input = next_input.float().sub_(self.mean).div_(self.std) + + with stream_context(): + next_input = next_input.to(device=self.device, non_blocking=True) + next_target = next_target.to(device=self.device, non_blocking=True) + next_input = next_input.to(self.img_dtype).sub_(self.mean).div_(self.std) if self.random_erasing is not None: next_input = self.random_erasing(next_input) @@ -116,7 +138,9 @@ class PrefetchLoader: else: first = False - torch.cuda.current_stream().wait_stream(stream) + if stream is not None: + torch.cuda.current_stream().wait_stream(stream) + input = next_input target = next_target @@ -189,7 +213,9 @@ def create_loader( crop_pct=None, collate_fn=None, pin_memory=False, - fp16=False, + fp16=False, # deprecated, use img_dtype + img_dtype=torch.float32, + device=torch.device('cuda'), tf_preprocessing=False, use_multi_epochs_loader=False, persistent_workers=True, @@ -222,6 +248,11 @@ def create_loader( separate=num_aug_splits > 0, ) + if isinstance(dataset, IterableImageDataset): + # give Iterable datasets early knowledge of num_workers so that sample estimates + # are correct before worker processes are launched + dataset.set_loader_cfg(num_workers=num_workers) + sampler = None if distributed and not isinstance(dataset, torch.utils.data.IterableDataset): if is_training: @@ -266,7 +297,9 @@ def create_loader( mean=mean, std=std, channels=input_size[0], - fp16=fp16, + device=device, + fp16=fp16, # deprecated, use img_dtype + img_dtype=img_dtype, re_prob=prefetch_re_prob, re_mode=re_mode, re_count=re_count, diff --git a/timm/data/parsers/__init__.py b/timm/data/parsers/__init__.py deleted file mode 100644 index 4e820d5e..00000000 --- a/timm/data/parsers/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from .parser_factory import create_parser -from .img_extensions import * diff --git a/timm/data/parsers/parser_factory.py b/timm/data/parsers/parser_factory.py deleted file mode 100644 index 0665c02a..00000000 --- a/timm/data/parsers/parser_factory.py +++ /dev/null @@ -1,28 +0,0 @@ -import os - -from .parser_image_folder import ParserImageFolder -from .parser_image_in_tar import ParserImageInTar - - -def create_parser(name, root, split='train', **kwargs): - name = name.lower() - name = name.split('/', 2) - prefix = '' - if len(name) > 1: - prefix = name[0] - name = name[-1] - - # FIXME improve the selection right now just tfds prefix or fallback path, will need 
options to - # explicitly select other options shortly - if prefix == 'tfds': - from .parser_tfds import ParserTfds # defer tensorflow import - parser = ParserTfds(root, name, split=split, **kwargs) - else: - assert os.path.exists(root) - # default fallback path (backwards compat), use image tar if root is a .tar file, otherwise image folder - # FIXME support split here, in parser? - if os.path.isfile(root) and os.path.splitext(root)[1] == '.tar': - parser = ParserImageInTar(root, **kwargs) - else: - parser = ParserImageFolder(root, **kwargs) - return parser diff --git a/timm/data/random_erasing.py b/timm/data/random_erasing.py index 98108488..1dee5f86 100644 --- a/timm/data/random_erasing.py +++ b/timm/data/random_erasing.py @@ -7,6 +7,7 @@ Hacked together by / Copyright 2019, Ross Wightman """ import random import math + import torch @@ -44,8 +45,17 @@ class RandomErasing: def __init__( self, - probability=0.5, min_area=0.02, max_area=1/3, min_aspect=0.3, max_aspect=None, - mode='const', min_count=1, max_count=None, num_splits=0, device='cuda'): + probability=0.5, + min_area=0.02, + max_area=1/3, + min_aspect=0.3, + max_aspect=None, + mode='const', + min_count=1, + max_count=None, + num_splits=0, + device='cuda', + ): self.probability = probability self.min_area = min_area self.max_area = max_area @@ -81,8 +91,12 @@ class RandomErasing: top = random.randint(0, img_h - h) left = random.randint(0, img_w - w) img[:, top:top + h, left:left + w] = _get_pixels( - self.per_pixel, self.rand_color, (chan, h, w), - dtype=dtype, device=self.device) + self.per_pixel, + self.rand_color, + (chan, h, w), + dtype=dtype, + device=self.device, + ) break def __call__(self, input): diff --git a/timm/data/readers/__init__.py b/timm/data/readers/__init__.py new file mode 100644 index 00000000..63e4e649 --- /dev/null +++ b/timm/data/readers/__init__.py @@ -0,0 +1,2 @@ +from .reader_factory import create_reader +from .img_extensions import * diff --git a/timm/data/parsers/class_map.py b/timm/data/readers/class_map.py similarity index 100% rename from timm/data/parsers/class_map.py rename to timm/data/readers/class_map.py diff --git a/timm/data/parsers/img_extensions.py b/timm/data/readers/img_extensions.py similarity index 100% rename from timm/data/parsers/img_extensions.py rename to timm/data/readers/img_extensions.py diff --git a/timm/data/parsers/parser.py b/timm/data/readers/reader.py similarity index 97% rename from timm/data/parsers/parser.py rename to timm/data/readers/reader.py index 76ab6d18..fe55d313 100644 --- a/timm/data/parsers/parser.py +++ b/timm/data/readers/reader.py @@ -1,7 +1,7 @@ from abc import abstractmethod -class Parser: +class Reader: def __init__(self): pass diff --git a/timm/data/readers/reader_factory.py b/timm/data/readers/reader_factory.py new file mode 100644 index 00000000..58ff56cd --- /dev/null +++ b/timm/data/readers/reader_factory.py @@ -0,0 +1,35 @@ +import os + +from .reader_image_folder import ReaderImageFolder +from .reader_image_in_tar import ReaderImageInTar + + +def create_reader(name, root, split='train', **kwargs): + name = name.lower() + name = name.split('/', 2) + prefix = '' + if len(name) > 1: + prefix = name[0] + name = name[-1] + + # FIXME improve the selection right now just tfds prefix or fallback path, will need options to + # explicitly select other options shortly + if prefix == 'hfds': + from .reader_hfds import ReaderHfds # defer tensorflow import + reader = ReaderHfds(root, name, split=split, **kwargs) + elif prefix == 'tfds': + from .reader_tfds 
import ReaderTfds # defer tensorflow import + reader = ReaderTfds(root, name, split=split, **kwargs) + elif prefix == 'wds': + from .reader_wds import ReaderWds + kwargs.pop('download', False) + reader = ReaderWds(root, name, split=split, **kwargs) + else: + assert os.path.exists(root) + # default fallback path (backwards compat), use image tar if root is a .tar file, otherwise image folder + # FIXME support split here or in reader? + if os.path.isfile(root) and os.path.splitext(root)[1] == '.tar': + reader = ReaderImageInTar(root, **kwargs) + else: + reader = ReaderImageFolder(root, **kwargs) + return reader diff --git a/timm/data/readers/reader_hfds.py b/timm/data/readers/reader_hfds.py new file mode 100644 index 00000000..901cf4bc --- /dev/null +++ b/timm/data/readers/reader_hfds.py @@ -0,0 +1,70 @@ +""" Dataset reader that wraps Hugging Face datasets + +Hacked together by / Copyright 2022 Ross Wightman +""" +import io +import math +import torch +import torch.distributed as dist +from PIL import Image + +try: + import datasets +except ImportError as e: + print("Please install Hugging Face datasets package `pip install datasets`.") + exit(1) +from .reader import Reader + + +def get_class_labels(info): + if 'label' not in info.features: + return {} + class_label = info.features['label'] + class_to_idx = {n: class_label.str2int(n) for n in class_label.names} + return class_to_idx + + +class ReaderHfds(Reader): + + def __init__( + self, + root, + name, + split='train', + class_map=None, + download=False, + ): + """ + """ + super().__init__() + self.root = root + self.split = split + self.dataset = datasets.load_dataset( + name, # 'name' maps to path arg in hf datasets + split=split, + cache_dir=self.root, # timm doesn't expect hidden cache dir for datasets, specify a path + #use_auth_token=True, + ) + # leave decode for caller, plus we want easy access to original path names... + self.dataset = self.dataset.cast_column('image', datasets.Image(decode=False)) + + self.class_to_idx = get_class_labels(self.dataset.info) + self.split_info = self.dataset.info.splits[split] + self.num_samples = self.split_info.num_examples + + def __getitem__(self, index): + item = self.dataset[index] + image = item['image'] + if 'bytes' in image and image['bytes']: + image = io.BytesIO(image['bytes']) + else: + assert 'path' in image and image['path'] + image = open(image['path'], 'rb') + return image, item['label'] + + def __len__(self): + return len(self.dataset) + + def _filename(self, index, basename=False, absolute=False): + item = self.dataset[index] + return item['image']['path'] diff --git a/timm/data/parsers/parser_image_folder.py b/timm/data/readers/reader_image_folder.py similarity index 94% rename from timm/data/parsers/parser_image_folder.py rename to timm/data/readers/reader_image_folder.py index 3d22a17b..05823372 100644 --- a/timm/data/parsers/parser_image_folder.py +++ b/timm/data/readers/reader_image_folder.py @@ -1,6 +1,6 @@ -""" A dataset parser that reads images from folders +""" A dataset reader that extracts images from folders -Folders are scannerd recursively to find image files. Labels are based +Folders are scanned recursively to find image files. Labels are based on the folder hierarchy, just leaf folders by default. 
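For reference, a minimal sketch of how the prefix routing implemented in `reader_factory.py` above is typically exercised through `create_dataset`: an `hfds/`, `tfds/`, or `wds/` prefix on the dataset name selects the matching reader, while an empty or unprefixed name falls back to the folder / tar readers. The dataset names, root paths, and split strings below are placeholders, not tested configurations.

```python
from timm.data import create_dataset

# local folder (or .tar) dataset -> ReaderImageFolder / ReaderImageInTar
ds_folder = create_dataset('', root='/data/my-imagenet', split='validation')

# Hugging Face datasets backend (random access) -> ReaderHfds
ds_hfds = create_dataset('hfds/my_dataset', root='/data/hf-cache', split='train')

# TFDS backend (iterable) -> ReaderTfds
ds_tfds = create_dataset(
    'tfds/imagenet2012', root='/data/tfds', split='validation',
    is_training=False, batch_size=256)

# webdataset backend (iterable) -> ReaderWds; a sample count can follow the shard spec after '|'
ds_wds = create_dataset(
    'wds/my_dataset', root='/data/my-dataset-wds',
    split='train-{0000..0127}.tar|100000',
    is_training=True, batch_size=256, seed=0)
```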
Hacked together by / Copyright 2020 Ross Wightman @@ -12,7 +12,7 @@ from timm.utils.misc import natural_key from .class_map import load_class_map from .img_extensions import get_img_extensions -from .parser import Parser +from .reader import Reader def find_images_and_targets( @@ -56,7 +56,7 @@ return images_and_targets, class_to_idx -class ParserImageFolder(Parser): +class ReaderImageFolder(Reader): def __init__( self, diff --git a/timm/data/parsers/parser_image_in_tar.py b/timm/data/readers/reader_image_in_tar.py similarity index 97% rename from timm/data/parsers/parser_image_in_tar.py rename to timm/data/readers/reader_image_in_tar.py index 4fcad797..001c9f4e 100644 --- a/timm/data/parsers/parser_image_in_tar.py +++ b/timm/data/readers/reader_image_in_tar.py @@ -1,6 +1,6 @@ -""" A dataset parser that reads tarfile based datasets +""" A dataset reader that reads tarfile based datasets -This parser can read and extract image samples from: +This reader can extract image samples from: * a single tar of image files * a folder of multiple tarfiles containing imagefiles * a tar of tars containing image files @@ -22,7 +22,7 @@ from timm.utils.misc import natural_key from .class_map import load_class_map from .img_extensions import get_img_extensions -from .parser import Parser +from .reader import Reader _logger = logging.getLogger(__name__) CACHE_FILENAME_SUFFIX = '_tarinfos.pickle' @@ -169,8 +169,8 @@ return samples, targets, class_name_to_idx, tarfiles -class ParserImageInTar(Parser): - """ Multi-tarfile dataset parser where there is one .tar file per class +class ReaderImageInTar(Reader): + """ Multi-tarfile dataset reader where there is one .tar file per class """ def __init__(self, root, class_map='', cache_tarfiles=True, cache_tarinfo=None): diff --git a/timm/data/parsers/parser_image_tar.py b/timm/data/readers/reader_image_tar.py similarity index 91% rename from timm/data/parsers/parser_image_tar.py rename to timm/data/readers/reader_image_tar.py index c2ed429d..6051f26d 100644 --- a/timm/data/parsers/parser_image_tar.py +++ b/timm/data/readers/reader_image_tar.py @@ -1,6 +1,6 @@ -""" A dataset parser that reads single tarfile based datasets +""" A dataset reader that reads single tarfile based datasets -This parser can read datasets consisting if a single tarfile containing images. +This reader can read datasets consisting of a single tarfile containing images. I am planning to deprecate it in favour of ReaderImageInTar. Hacked together by / Copyright 2020 Ross Wightman @@ -12,7 +12,7 @@ from timm.utils.misc import natural_key from .class_map import load_class_map from .img_extensions import get_img_extensions -from .parser import Parser +from .reader import Reader def extract_tarinfo(tarfile, class_to_idx=None, sort=True): @@ -38,9 +38,9 @@ return tarinfo_and_targets, class_to_idx -class ParserImageTar(Parser): +class ReaderImageTar(Reader): """ Single tarfile dataset where classes are mapped to folders within tar - NOTE: This class is being deprecated in favour of the more capable ParserImageInTar that can + NOTE: This class is being deprecated in favour of the more capable ReaderImageInTar that can operate on folders of tars or tars in tars. 
""" def __init__(self, root, class_map=''): diff --git a/timm/data/parsers/parser_tfds.py b/timm/data/readers/reader_tfds.py similarity index 72% rename from timm/data/parsers/parser_tfds.py rename to timm/data/readers/reader_tfds.py index 739f3813..7ccbf908 100644 --- a/timm/data/parsers/parser_tfds.py +++ b/timm/data/readers/reader_tfds.py @@ -1,4 +1,4 @@ -""" Dataset parser interface that wraps TFDS datasets +""" Dataset reader that wraps TFDS datasets Wraps many (most?) TFDS image-classification datasets from https://github.com/tensorflow/datasets @@ -7,6 +7,9 @@ https://www.tensorflow.org/datasets/catalog/overview#image_classification Hacked together by / Copyright 2020 Ross Wightman """ import math +import os +from typing import Optional + import torch import torch.distributed as dist from PIL import Image @@ -30,16 +33,18 @@ except ImportError as e: print(e) print("Please install tensorflow_datasets package `pip install tensorflow-datasets`.") exit(1) -from .parser import Parser + +from .reader import Reader +from .shared_count import SharedCount -MAX_TP_SIZE = 8 # maximum TF threadpool size, only doing jpeg decodes and queuing activities -SHUFFLE_SIZE = 8192 # examples to shuffle in DS queue -PREFETCH_SIZE = 2048 # examples to prefetch +MAX_TP_SIZE = os.environ.get('TFDS_TP_SIZE', 8) # maximum TF threadpool size, for jpeg decodes and queuing activities +SHUFFLE_SIZE = os.environ.get('TFDS_SHUFFLE_SIZE', 8192) # samples to shuffle in DS queue +PREFETCH_SIZE = os.environ.get('TFDS_PREFETCH_SIZE', 2048) # samples to prefetch -def even_split_indices(split, n, num_examples): - partitions = [round(i * num_examples / n) for i in range(n + 1)] +def even_split_indices(split, n, num_samples): + partitions = [round(i * num_samples / n) for i in range(n + 1)] return [f"{split}[{partitions[i]}:{partitions[i + 1]}]" for i in range(n)] @@ -51,24 +56,24 @@ def get_class_labels(info): return class_to_idx -class ParserTfds(Parser): +class ReaderTfds(Reader): """ Wrap Tensorflow Datasets for use in PyTorch There several things to be aware of: - * To prevent excessive examples being dropped per epoch w/ distributed training or multiplicity of + * To prevent excessive samples being dropped per epoch w/ distributed training or multiplicity of dataloader workers, the train iterator wraps to avoid returning partial batches that trigger drop_last https://github.com/pytorch/pytorch/issues/33413 * With PyTorch IterableDatasets, each worker in each replica operates in isolation, the final batch from each worker could be a different size. For training this is worked around by option above, for - validation extra examples are inserted iff distributed mode is enabled so that the batches being reduced + validation extra samples are inserted iff distributed mode is enabled so that the batches being reduced across replicas are of same size. This will slightly alter the results, distributed validation will not be 100% correct. This is similar to common handling in DistributedSampler for normal Datasets but a bit worse - since there are up to N * J extra examples with IterableDatasets. + since there are up to N * J extra samples with IterableDatasets. * The sharding (splitting of dataset into TFRecord) files imposes limitations on the number of replicas and dataloader workers you can use. For really small datasets that only contain a few shards you may have to train non-distributed w/ 1-2 dataloader workers. 
This is likely not a huge concern as the benefit of distributed training or fast dataloading should be much less for small datasets. - * This wrapper is currently configured to return individual, decompressed image examples from the TFDS + * This wrapper is currently configured to return individual, decompressed image samples from the TFDS dataset. The augmentation (transforms) and batching is still done in PyTorch. It would be possible to specify TF augmentation fn and return augmented batches w/ some modifications to other downstream components. @@ -86,9 +91,9 @@ class ParserTfds(Parser): repeats=0, seed=42, input_name='image', - input_image='RGB', + input_img_mode='RGB', target_name='label', - target_image='', + target_img_mode='', prefetch_size=None, shuffle_size=None, max_threadpool_size=None @@ -100,14 +105,14 @@ class ParserTfds(Parser): name: tfds dataset name (eg `imagenet2012`) split: tfds dataset split (can use all TFDS split strings eg `train[:10%]`) is_training: training mode, shuffle enabled, dataset len rounded by batch_size - batch_size: batch_size to use to unsure total examples % batch_size == 0 in training across all dis nodes + batch_size: batch_size to use to unsure total samples % batch_size == 0 in training across all dis nodes download: download and build TFDS dataset if set, otherwise must use tfds CLI repeats: iterate through (repeat) the dataset this many times per iteration (once if 0 or 1) seed: common seed for shard shuffle across all distributed/worker instances input_name: name of Feature to return as data (input) - input_image: image mode if input is an image (currently PIL mode string) + input_img_mode: image mode if input is an image (currently PIL mode string) target_name: name of Feature to return as target (label) - target_image: image mode if target is an image (currently PIL mode string) + target_img_mode: image mode if target is an image (currently PIL mode string) prefetch_size: override default tf.data prefetch buffer size shuffle_size: override default tf.data shuffle buffer size max_threadpool_size: override default threadpool size for tf.data @@ -130,16 +135,16 @@ class ParserTfds(Parser): # TFDS builder and split information self.input_name = input_name # FIXME support tuples / lists of inputs and targets and full range of Feature - self.input_image = input_image + self.input_img_mode = input_img_mode self.target_name = target_name - self.target_image = target_image + self.target_img_mode = target_img_mode self.builder = tfds.builder(name, data_dir=root) # NOTE: the tfds command line app can be used download & prepare datasets if you don't enable download flag if download: self.builder.download_and_prepare() self.class_to_idx = get_class_labels(self.builder.info) if self.target_name == 'label' else {} self.split_info = self.builder.info.splits[split] - self.num_examples = self.split_info.num_examples + self.num_samples = self.split_info.num_examples # Distributed world state self.dist_rank = 0 @@ -150,10 +155,29 @@ class ParserTfds(Parser): # Attributes that are updated in _lazy_init, including the tf.data pipeline itself self.global_num_workers = 1 + self.num_workers = 1 self.worker_info = None self.worker_seed = 0 # seed unique to each work instance self.subsplit = None # set when data is distributed across workers using sub-splits self.ds = None # initialized lazily on each dataloader worker process + self.init_count = 0 # number of ds TF data pipeline initializations + self.epoch_count = SharedCount() + # FIXME need to determine if 
reinit_each_iter is necessary. I'm don't completely trust behaviour + # of `shuffle_reshuffle_each_iteration` when there are multiple workers / nodes across epochs + self.reinit_each_iter = self.is_training + + def set_epoch(self, count): + self.epoch_count.value = count + + def set_loader_cfg( + self, + num_workers: Optional[int] = None, + ): + if self.ds is not None: + return + if num_workers is not None: + self.num_workers = num_workers + self.global_num_workers = self.dist_num_replicas * self.num_workers def _lazy_init(self): """ Lazily initialize the dataset. @@ -174,9 +198,9 @@ class ParserTfds(Parser): if worker_info is not None: self.worker_info = worker_info self.worker_seed = worker_info.seed - num_workers = worker_info.num_workers - self.global_num_workers = self.dist_num_replicas * num_workers - global_worker_id = self.dist_rank * num_workers + worker_info.id + self.num_workers = worker_info.num_workers + self.global_num_workers = self.dist_num_replicas * self.num_workers + global_worker_id = self.dist_rank * self.num_workers + worker_info.id """ Data sharding InputContext will assign subset of underlying TFRecord files to each 'pipeline' if used. @@ -186,17 +210,17 @@ class ParserTfds(Parser): I am currently using a mix of InputContext shard assignment and fine-grained sub-splits for distributing the data across workers. For training InputContext is used to assign shards to nodes unless num_shards in dataset < total number of workers. Otherwise sub-split API is used for datasets without enough shards or - for validation where we can't drop examples and need to avoid minimize uneven splits to avoid padding. + for validation where we can't drop samples and need to avoid minimize uneven splits to avoid padding. """ should_subsplit = self.global_num_workers > 1 and ( self.split_info.num_shards < self.global_num_workers or not self.is_training) if should_subsplit: - # split the dataset w/o using sharding for more even examples / worker, can result in less optimal + # split the dataset w/o using sharding for more even samples / worker, can result in less optimal # read patterns for distributed training (overlap across shards) so better to use InputContext there if has_buggy_even_splits: # my even_split workaround doesn't work on subsplits, upgrade tfds! if not isinstance(self.split_info, tfds.core.splits.SubSplitInfo): - subsplits = even_split_indices(self.split, self.global_num_workers, self.num_examples) + subsplits = even_split_indices(self.split, self.global_num_workers, self.num_samples) self.subsplit = subsplits[global_worker_id] else: subsplits = tfds.even_splits(self.split, self.global_num_workers) @@ -211,15 +235,19 @@ class ParserTfds(Parser): num_replicas_in_sync=self.dist_num_replicas # FIXME does this arg have any impact? 
) read_config = tfds.ReadConfig( - shuffle_seed=self.common_seed, + shuffle_seed=self.common_seed + self.epoch_count.value, shuffle_reshuffle_each_iteration=True, - input_context=input_context) + input_context=input_context, + ) ds = self.builder.as_dataset( - split=self.subsplit or self.split, shuffle_files=self.is_training, read_config=read_config) + split=self.subsplit or self.split, + shuffle_files=self.is_training, + read_config=read_config, + ) # avoid overloading threading w/ combo of TF ds threads + PyTorch workers options = tf.data.Options() thread_member = 'threading' if hasattr(options, 'threading') else 'experimental_threading' - getattr(options, thread_member).private_threadpool_size = max(1, self.max_threadpool_size // num_workers) + getattr(options, thread_member).private_threadpool_size = max(1, self.max_threadpool_size // self.num_workers) getattr(options, thread_member).max_intra_op_parallelism = 1 ds = ds.with_options(options) if self.is_training or self.repeats > 1: @@ -227,59 +255,65 @@ class ParserTfds(Parser): # see warnings at https://pytorch.org/docs/stable/data.html#multi-process-data-loading ds = ds.repeat() # allow wrap around and break iteration manually if self.is_training: - ds = ds.shuffle(min(self.num_examples, self.shuffle_size) // self.global_num_workers, seed=self.worker_seed) - ds = ds.prefetch(min(self.num_examples // self.global_num_workers, self.prefetch_size)) + ds = ds.shuffle(min(self.num_samples, self.shuffle_size) // self.global_num_workers, seed=self.worker_seed) + ds = ds.prefetch(min(self.num_samples // self.global_num_workers, self.prefetch_size)) self.ds = tfds.as_numpy(ds) + self.init_count += 1 + + def _num_samples_per_worker(self): + num_worker_samples = \ + max(1, self.repeats) * self.num_samples / max(self.global_num_workers, self.dist_num_replicas) + if self.is_training or self.dist_num_replicas > 1: + num_worker_samples = math.ceil(num_worker_samples) + if self.is_training and self.batch_size is not None: + num_worker_samples = math.ceil(num_worker_samples / self.batch_size) * self.batch_size + return int(num_worker_samples) def __iter__(self): - if self.ds is None: + if self.ds is None or self.reinit_each_iter: self._lazy_init() # Compute a rounded up sample count that is used to: # 1. make batches even cross workers & replicas in distributed validation. - # This adds extra examples and will slightly alter validation results. + # This adds extra samples and will slightly alter validation results. # 2. 
determine loop ending condition in training w/ repeat enabled so that only full batch_size # batches are produced (underlying tfds iter wraps around) - target_example_count = math.ceil(max(1, self.repeats) * self.num_examples / self.global_num_workers) - if self.is_training: - # round up to nearest batch_size per worker-replica - target_example_count = math.ceil(target_example_count / self.batch_size) * self.batch_size + target_sample_count = self._num_samples_per_worker() # Iterate until exhausted or sample count hits target when training (ds.repeat enabled) - example_count = 0 - for example in self.ds: - input_data = example[self.input_name] - if self.input_image: - input_data = Image.fromarray(input_data, mode=self.input_image) - target_data = example[self.target_name] - if self.target_image: - target_data = Image.fromarray(target_data, mode=self.target_image) + sample_count = 0 + for sample in self.ds: + input_data = sample[self.input_name] + if self.input_img_mode: + input_data = Image.fromarray(input_data, mode=self.input_img_mode) + target_data = sample[self.target_name] + if self.target_img_mode: + target_data = Image.fromarray(target_data, mode=self.target_img_mode) yield input_data, target_data - example_count += 1 - if self.is_training and example_count >= target_example_count: + sample_count += 1 + if self.is_training and sample_count >= target_sample_count: # Need to break out of loop when repeat() is enabled for training w/ oversampling - # this results in extra examples per epoch but seems more desirable than dropping + # this results in extra samples per epoch but seems more desirable than dropping # up to N*J batches per epoch (where N = num distributed processes, and J = num worker processes) break - # Pad across distributed nodes (make counts equal by adding examples) + # Pad across distributed nodes (make counts equal by adding samples) if not self.is_training and self.dist_num_replicas > 1 and self.subsplit is not None and \ - 0 < example_count < target_example_count: + 0 < sample_count < target_sample_count: # Validation batch padding only done for distributed training where results are reduced across nodes. # For single process case, it won't matter if workers return different batch sizes. # If using input_context or % based splits, sample count can vary significantly across workers and this # approach should not be used (hence disabled if self.subsplit isn't set). - while example_count < target_example_count: + while sample_count < target_sample_count: yield input_data, target_data # yield prev sample again - example_count += 1 + sample_count += 1 def __len__(self): - # this is just an estimate and does not factor in extra examples added to pad batches based on - # complete worker & replica info (not available until init in dataloader). 
- return math.ceil(max(1, self.repeats) * self.num_examples / self.dist_num_replicas) + num_samples = self._num_samples_per_worker() * self.num_workers + return num_samples def _filename(self, index, basename=False, absolute=False): - assert False, "Not supported" # no random access to examples + assert False, "Not supported" # no random access to samples def filenames(self, basename=False, absolute=False): """ Return all filenames in dataset, overrides base""" @@ -287,7 +321,7 @@ class ParserTfds(Parser): self._lazy_init() names = [] for sample in self.ds: - if len(names) > self.num_examples: + if len(names) > self.num_samples: break # safety for ds.repeat() case if 'file_name' in sample: name = sample['file_name'] diff --git a/timm/data/readers/reader_wds.py b/timm/data/readers/reader_wds.py new file mode 100644 index 00000000..afc9030b --- /dev/null +++ b/timm/data/readers/reader_wds.py @@ -0,0 +1,461 @@ +""" Dataset reader for webdataset + +Hacked together by / Copyright 2022 Ross Wightman +""" +import io +import json +import logging +import math +import os +import random +import sys +from dataclasses import dataclass +from functools import partial +from itertools import islice +from typing import Any, Callable, Dict, List, Optional, Tuple + +import torch +import torch.distributed as dist +import yaml +from PIL import Image +from torch.utils.data import Dataset, IterableDataset, get_worker_info + +try: + import webdataset as wds + from webdataset.filters import _shuffle + from webdataset.shardlists import expand_urls + from webdataset.tariterators import base_plus_ext, url_opener, tar_file_expander, valid_sample +except ImportError: + wds = None + expand_urls = None + +from .reader import Reader +from .shared_count import SharedCount + +_logger = logging.getLogger(__name__) + +SHUFFLE_SIZE = os.environ.get('WDS_SHUFFLE_SIZE', 8192) + + +def _load_info(root, basename='info'): + info_json = os.path.join(root, basename + '.json') + info_yaml = os.path.join(root, basename + '.yaml') + err_str = '' + try: + with wds.gopen.gopen(info_json) as f: + info_dict = json.load(f) + return info_dict + except Exception as e: + err_str = str(e) + try: + with wds.gopen.gopen(info_yaml) as f: + info_dict = yaml.safe_load(f) + return info_dict + except Exception: + pass + _logger.warning( + f'Dataset info file not found at {info_json} or {info_yaml}. Error: {err_str}. ' + 'Falling back to provided split and size arg.') + return {} + + +@dataclass +class SplitInfo: + num_samples: int + filenames: Tuple[str] + shard_lengths: Tuple[int] = () + alt_label: str = '' + name: str = '' + + +def _parse_split_info(split: str, info: Dict): + def _info_convert(dict_info): + return SplitInfo( + num_samples=dict_info['num_samples'], + filenames=tuple(dict_info['filenames']), + shard_lengths=tuple(dict_info['shard_lengths']), + alt_label=dict_info.get('alt_label', ''), + name=dict_info['name'], + ) + + if 'tar' in split or '..' 
in split: + # split in WDS string braceexpand format, sample count can be included with a | separator + # ex: `dataset-split-{0000..9999}.tar|100000` for 9999 shards, covering 100,000 samples + split = split.split('|') + num_samples = 0 + split_name = '' + if len(split) > 1: + num_samples = int(split[1]) + split = split[0] + if '::' not in split: + split_parts = split.split('-', 3) + split_idx = len(split_parts) - 1 + if split_idx and 'splits' in info and split_parts[split_idx] in info['splits']: + split_name = split_parts[split_idx] + + split_filenames = expand_urls(split) + if split_name: + split_info = info['splits'][split_name] + if not num_samples: + _fc = {f: c for f, c in zip(split_info['filenames'], split_info['shard_lengths'])} + num_samples = sum(_fc[f] for f in split_filenames) + split_info['filenames'] = tuple(_fc.keys()) + split_info['shard_lengths'] = tuple(_fc.values()) + split_info['num_samples'] = num_samples + split_info = _info_convert(split_info) + else: + split_info = SplitInfo( + name=split_name, + num_samples=num_samples, + filenames=split_filenames, + ) + else: + if split not in info['splits']: + raise RuntimeError(f"split {split} not found in info ({info['splits'].keys()})") + split = split + split_info = info['splits'][split] + split_info = _info_convert(split_info) + + return split_info + + +def log_and_continue(exn): + """Call in an exception handler to ignore any exception, isssue a warning, and continue.""" + _logger.warning(f'Handling webdataset error ({repr(exn)}). Ignoring.') + return True + + +def _decode( + sample, + image_key='jpg', + image_format='RGB', + target_key='cls', + alt_label='' +): + """ Custom sample decode + * decode and convert PIL Image + * cls byte string label to int + * pass through JSON byte string (if it exists) without parse + """ + # decode class label, skip if alternate label not valid + if alt_label: + # alternative labels are encoded in json metadata + meta = json.loads(sample['json']) + class_label = int(meta[alt_label]) + if class_label < 0: + # skipped labels currently encoded as -1, may change to a null/None value + return None + else: + class_label = int(sample[target_key]) + + # decode image + with io.BytesIO(sample[image_key]) as b: + img = Image.open(b) + img.load() + if image_format: + img = img.convert(image_format) + + # json passed through in undecoded state + decoded = dict(jpg=img, cls=class_label, json=sample.get('json', None)) + return decoded + + +def _decode_samples( + data, + image_key='jpg', + image_format='RGB', + target_key='cls', + alt_label='', + handler=log_and_continue): + """Decode samples with skip.""" + for sample in data: + try: + result = _decode( + sample, + image_key=image_key, + image_format=image_format, + target_key=target_key, + alt_label=alt_label + ) + except Exception as exn: + if handler(exn): + continue + else: + break + + # null results are skipped + if result is not None: + if isinstance(sample, dict) and isinstance(result, dict): + result["__key__"] = sample.get("__key__") + yield result + + +def pytorch_worker_seed(): + """get dataloader worker seed from pytorch""" + worker_info = get_worker_info() + if worker_info is not None: + # favour the seed already created for pytorch dataloader workers if it exists + return worker_info.seed + # fallback to wds rank based seed + return wds.utils.pytorch_worker_seed() + + +if wds is not None: + # conditional to avoid mandatory wds import (via inheritance of wds.PipelineStage) + class detshuffle2(wds.PipelineStage): + def __init__( + self, + 
bufsize=1000, + initial=100, + seed=0, + epoch=-1, + ): + self.bufsize = bufsize + self.initial = initial + self.seed = seed + self.epoch = epoch + + def run(self, src): + if isinstance(self.epoch, SharedCount): + epoch = self.epoch.value + else: + # NOTE: this is epoch tracking is problematic in a multiprocess (dataloader workers or train) + # situation as different workers may wrap at different times (or not at all). + self.epoch += 1 + epoch = self.epoch + + if self.seed < 0: + seed = pytorch_worker_seed() + epoch + else: + seed = self.seed + epoch + # _logger.info(f'shuffle seed: {self.seed}, {seed}, epoch: {epoch}') # FIXME temporary + rng = random.Random(seed) + return _shuffle(src, self.bufsize, self.initial, rng) + +else: + detshuffle2 = None + + +class ResampledShards2(IterableDataset): + """An iterable dataset yielding a list of urls.""" + + def __init__( + self, + urls, + nshards=sys.maxsize, + worker_seed=None, + deterministic=True, + epoch=-1, + ): + """Sample shards from the shard list with replacement. + + :param urls: a list of URLs as a Python list or brace notation string + """ + super().__init__() + urls = wds.shardlists.expand_urls(urls) + self.urls = urls + assert isinstance(self.urls[0], str) + self.nshards = nshards + self.rng = random.Random() + self.worker_seed = pytorch_worker_seed if worker_seed is None else worker_seed + self.deterministic = deterministic + self.epoch = epoch + + def __iter__(self): + """Return an iterator over the shards.""" + if isinstance(self.epoch, SharedCount): + epoch = self.epoch.value + else: + # NOTE: this is epoch tracking is problematic in a multiprocess (dataloader workers or train) + # situation as different workers may wrap at different times (or not at all). + self.epoch += 1 + epoch = self.epoch + + if self.deterministic: + # reset seed w/ epoch if deterministic, worker seed should be deterministic due to arg.seed + self.rng = random.Random(self.worker_seed() + epoch) + + for _ in range(self.nshards): + index = self.rng.randint(0, len(self.urls) - 1) + yield dict(url=self.urls[index]) + + +class ReaderWds(Reader): + def __init__( + self, + root, + name, + split, + is_training=False, + batch_size=None, + repeats=0, + seed=42, + input_name='jpg', + input_image='RGB', + target_name='cls', + target_image='', + prefetch_size=None, + shuffle_size=None, + ): + super().__init__() + if wds is None: + raise RuntimeError( + 'Please install webdataset 0.2.x package `pip install git+https://github.com/webdataset/webdataset`.') + self.root = root + self.is_training = is_training + self.batch_size = batch_size + self.repeats = repeats + self.common_seed = seed # a seed that's fixed across all worker / distributed instances + self.shard_shuffle_size = 500 + self.sample_shuffle_size = shuffle_size or SHUFFLE_SIZE + + self.image_key = input_name + self.image_format = input_image + self.target_key = target_name + self.filename_key = 'filename' + self.key_ext = '.JPEG' # extension to add to key for original filenames (DS specific, default ImageNet) + + self.info = _load_info(self.root) + self.split_info = _parse_split_info(split, self.info) + self.num_samples = self.split_info.num_samples + if not self.num_samples: + raise RuntimeError(f'Invalid split definition, no samples found.') + + # Distributed world state + self.dist_rank = 0 + self.dist_num_replicas = 1 + if dist.is_available() and dist.is_initialized() and dist.get_world_size() > 1: + self.dist_rank = dist.get_rank() + self.dist_num_replicas = dist.get_world_size() + + # Attributes that 
are updated in _lazy_init + self.worker_info = None + self.worker_id = 0 + self.worker_seed = seed # seed unique to each worker instance + self.num_workers = 1 + self.global_worker_id = 0 + self.global_num_workers = 1 + self.init_count = 0 + self.epoch_count = SharedCount() + + # DataPipeline is lazy init, majority of WDS DataPipeline could be init here, BUT, shuffle seed + # is not handled in manner where it can be deterministic for each worker AND initialized up front + self.ds = None + + def set_epoch(self, count): + self.epoch_count.value = count + + def set_loader_cfg( + self, + num_workers: Optional[int] = None, + ): + if self.ds is not None: + return + if num_workers is not None: + self.num_workers = num_workers + self.global_num_workers = self.dist_num_replicas * self.num_workers + + def _lazy_init(self): + """ Lazily initialize worker (in worker processes) + """ + if self.worker_info is None: + worker_info = torch.utils.data.get_worker_info() + if worker_info is not None: + self.worker_info = worker_info + self.worker_id = worker_info.id + self.worker_seed = worker_info.seed + self.num_workers = worker_info.num_workers + self.global_num_workers = self.dist_num_replicas * self.num_workers + self.global_worker_id = self.dist_rank * self.num_workers + self.worker_id + + # init data pipeline + abs_shard_filenames = [os.path.join(self.root, f) for f in self.split_info.filenames] + pipeline = [wds.SimpleShardList(abs_shard_filenames)] + # at this point we have an iterator over all the shards + if self.is_training: + pipeline.extend([ + detshuffle2(self.shard_shuffle_size, seed=self.common_seed, epoch=self.epoch_count), + self._split_by_node_and_worker, + # at this point, we have an iterator over the shards assigned to each worker + wds.tarfile_to_samples(handler=log_and_continue), + wds.shuffle( + self.sample_shuffle_size, + rng=random.Random(self.worker_seed)), # this is why we lazy-init whole DataPipeline + ]) + else: + pipeline.extend([ + self._split_by_node_and_worker, + # at this point, we have an iterator over the shards assigned to each worker + wds.tarfile_to_samples(handler=log_and_continue), + ]) + pipeline.extend([ + partial( + _decode_samples, + image_key=self.image_key, + image_format=self.image_format, + alt_label=self.split_info.alt_label + ) + ]) + self.ds = wds.DataPipeline(*pipeline) + + def _split_by_node_and_worker(self, src): + if self.global_num_workers > 1: + for s in islice(src, self.global_worker_id, None, self.global_num_workers): + yield s + else: + for s in src: + yield s + + def _num_samples_per_worker(self): + num_worker_samples = self.num_samples / max(self.global_num_workers, self.dist_num_replicas) + if self.is_training or self.dist_num_replicas > 1: + num_worker_samples = math.ceil(num_worker_samples) + if self.is_training and self.batch_size is not None: + num_worker_samples = math.ceil(num_worker_samples / self.batch_size) * self.batch_size + return int(num_worker_samples) + + def __iter__(self): + if self.ds is None: + self._lazy_init() + + num_worker_samples = self._num_samples_per_worker() + if self.is_training or self.dist_num_replicas > 1: + # NOTE: doing distributed validation w/ WDS is messy, hard to meet constraints that + # same # of batches needed across all replicas w/ seeing each sample once. + # with_epoch() is simple but could miss a shard's worth of samples in some workers, + # and duplicate in others. Best to keep num DL workers low and a divisor of #val shards. 
+ ds = self.ds.with_epoch(num_worker_samples) + else: + ds = self.ds + + i = 0 + # _logger.info(f'start {i}, {self.worker_id}') # FIXME temporary debug + for sample in ds: + yield sample[self.image_key], sample[self.target_key] + i += 1 + # _logger.info(f'end {i}, {self.worker_id}') # FIXME temporary debug + + def __len__(self): + num_samples = self._num_samples_per_worker() * self.num_workers + return num_samples + + def _filename(self, index, basename=False, absolute=False): + assert False, "Not supported" # no random access to examples + + def filenames(self, basename=False, absolute=False): + """ Return all filenames in dataset, overrides base""" + if self.ds is None: + self._lazy_init() + + names = [] + for sample in self.ds: + if self.filename_key in sample: + name = sample[self.filename_key] + elif '__key__' in sample: + name = sample['__key__'] + self.key_ext + else: + assert False, "No supported name field present" + names.append(name) + if len(names) >= self.num_samples: + break # safety for ds.repeat() case + return names diff --git a/timm/data/readers/shared_count.py b/timm/data/readers/shared_count.py new file mode 100644 index 00000000..fe4e85b4 --- /dev/null +++ b/timm/data/readers/shared_count.py @@ -0,0 +1,14 @@ +from multiprocessing import Value + + +class SharedCount: + def __init__(self, epoch: int = 0): + self.shared_epoch = Value('i', epoch) + + @property + def value(self): + return self.shared_epoch.value + + @value.setter + def value(self, epoch): + self.shared_epoch.value = epoch diff --git a/timm/models/helpers.py b/timm/models/helpers.py index d68c7e65..c771e825 100644 --- a/timm/models/helpers.py +++ b/timm/models/helpers.py @@ -63,7 +63,7 @@ def load_state_dict(checkpoint_path, use_ema=True): raise FileNotFoundError() -def load_checkpoint(model, checkpoint_path, use_ema=True, strict=True): +def load_checkpoint(model, checkpoint_path, use_ema=True, strict=True, remap=False): if os.path.splitext(checkpoint_path)[-1].lower() in ('.npz', '.npy'): # numpy checkpoint, try to load via model specific load_pretrained fn if hasattr(model, 'load_pretrained'): @@ -72,10 +72,28 @@ def load_checkpoint(model, checkpoint_path, use_ema=True, strict=True): raise NotImplementedError('Model cannot load numpy checkpoint') return state_dict = load_state_dict(checkpoint_path, use_ema) + if remap: + state_dict = remap_checkpoint(model, state_dict) incompatible_keys = model.load_state_dict(state_dict, strict=strict) return incompatible_keys +def remap_checkpoint(model, state_dict, allow_reshape=True): + """ remap checkpoint by iterating over state dicts in order (ignoring original keys). + This assumes models (and originating state dict) were created with params registered in same order. + """ + out_dict = {} + for (ka, va), (kb, vb) in zip(model.state_dict().items(), state_dict.items()): + assert va.numel == vb.numel, f'Tensor size mismatch {ka}: {va.shape} vs {kb}: {vb.shape}. Remap failed.' + if va.shape != vb.shape: + if allow_reshape: + vb = vb.reshape(va.shape) + else: + assert False, f'Tensor shape mismatch {ka}: {va.shape} vs {kb}: {vb.shape}. Remap failed.' 
+ out_dict[ka] = vb + return out_dict + + def resume_checkpoint(model, checkpoint_path, optimizer=None, loss_scaler=None, log_info=True): resume_epoch = None if os.path.isfile(checkpoint_path): diff --git a/timm/models/layers/squeeze_excite.py b/timm/models/layers/squeeze_excite.py index 2e41d956..4fe568fe 100644 --- a/timm/models/layers/squeeze_excite.py +++ b/timm/models/layers/squeeze_excite.py @@ -72,3 +72,31 @@ class EffectiveSEModule(nn.Module): EffectiveSqueezeExcite = EffectiveSEModule # alias + + +class SqueezeExciteCl(nn.Module): + """ SE Module as defined in original SE-Nets with a few additions + Additions include: + * divisor can be specified to keep channels % div == 0 (default: 8) + * reduction channels can be specified directly by arg (if rd_channels is set) + * reduction channels can be specified by float rd_ratio (default: 1/16) + * global max pooling can be added to the squeeze aggregation + * customizable activation, normalization, and gate layer + """ + def __init__( + self, channels, rd_ratio=1. / 16, rd_channels=None, rd_divisor=8, + bias=True, act_layer=nn.ReLU, gate_layer='sigmoid'): + super().__init__() + if not rd_channels: + rd_channels = make_divisible(channels * rd_ratio, rd_divisor, round_limit=0.) + self.fc1 = nn.Linear(channels, rd_channels, bias=bias) + self.act = create_act_layer(act_layer, inplace=True) + self.fc2 = nn.Linear(rd_channels, channels, bias=bias) + self.gate = create_act_layer(gate_layer) + + def forward(self, x): + x_se = x.mean((1, 2), keepdims=True) # FIXME avg dim [1:n-1], don't assume 2D NHWC + x_se = self.fc1(x_se) + x_se = self.act(x_se) + x_se = self.fc2(x_se) + return x * self.gate(x_se) \ No newline at end of file diff --git a/timm/optim/adan.py b/timm/optim/adan.py new file mode 100644 index 00000000..1d2a7585 --- /dev/null +++ b/timm/optim/adan.py @@ -0,0 +1,124 @@ +""" Adan Optimizer + +Adan: Adaptive Nesterov Momentum Algorithm for Faster Optimizing Deep Models[J]. arXiv preprint arXiv:2208.06677, 2022. + https://arxiv.org/abs/2208.06677 + +Implementation adapted from https://github.com/sail-sg/Adan +""" + +import math + +import torch + +from torch.optim import Optimizer + + +class Adan(Optimizer): + """ + Implements a pytorch variant of Adan + Adan was proposed in + Adan: Adaptive Nesterov Momentum Algorithm for Faster Optimizing Deep Models[J]. arXiv preprint arXiv:2208.06677, 2022. + https://arxiv.org/abs/2208.06677 + Arguments: + params (iterable): iterable of parameters to optimize or dicts defining parameter groups. + lr (float, optional): learning rate. (default: 1e-3) + betas (Tuple[float, float, flot], optional): coefficients used for computing + running averages of gradient and its norm. (default: (0.98, 0.92, 0.99)) + eps (float, optional): term added to the denominator to improve + numerical stability. 
(default: 1e-8) + weight_decay (float, optional): decoupled weight decay (L2 penalty) (default: 0) + no_prox (bool): how to perform the decoupled weight decay (default: False) + """ + + def __init__( + self, + params, + lr=1e-3, + betas=(0.98, 0.92, 0.99), + eps=1e-8, + weight_decay=0.0, + no_prox=False, + ): + if not 0.0 <= lr: + raise ValueError("Invalid learning rate: {}".format(lr)) + if not 0.0 <= eps: + raise ValueError("Invalid epsilon value: {}".format(eps)) + if not 0.0 <= betas[0] < 1.0: + raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) + if not 0.0 <= betas[1] < 1.0: + raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) + if not 0.0 <= betas[2] < 1.0: + raise ValueError("Invalid beta parameter at index 2: {}".format(betas[2])) + defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, no_prox=no_prox) + super(Adan, self).__init__(params, defaults) + + @torch.no_grad() + def restart_opt(self): + for group in self.param_groups: + group['step'] = 0 + for p in group['params']: + if p.requires_grad: + state = self.state[p] + # State initialization + + # Exponential moving average of gradient values + state['exp_avg'] = torch.zeros_like(p) + # Exponential moving average of squared gradient values + state['exp_avg_sq'] = torch.zeros_like(p) + # Exponential moving average of gradient difference + state['exp_avg_diff'] = torch.zeros_like(p) + + @torch.no_grad() + def step(self, closure=None): + """ Performs a single optimization step. + """ + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + beta1, beta2, beta3 = group['betas'] + # assume same step across group now to simplify things + # per parameter step can be easily support by making it tensor, or pass list into kernel + if 'step' in group: + group['step'] += 1 + else: + group['step'] = 1 + + bias_correction1 = 1.0 - beta1 ** group['step'] + bias_correction2 = 1.0 - beta2 ** group['step'] + bias_correction3 = 1.0 - beta3 ** group['step'] + + for p in group['params']: + if p.grad is None: + continue + grad = p.grad + + state = self.state[p] + if len(state) == 0: + state['exp_avg'] = torch.zeros_like(p) + state['exp_avg_diff'] = torch.zeros_like(p) + state['exp_avg_sq'] = torch.zeros_like(p) + state['pre_grad'] = grad.clone() + + exp_avg, exp_avg_sq, exp_avg_diff = state['exp_avg'], state['exp_avg_diff'], state['exp_avg_sq'] + grad_diff = grad - state['pre_grad'] + + exp_avg.lerp_(grad, 1. - beta1) # m_t + exp_avg_diff.lerp_(grad_diff, 1. - beta2) # diff_t (v) + update = grad + beta2 * grad_diff + exp_avg_sq.mul_(beta3).addcmul_(update, update, value=1. 
- beta3) # n_t + + denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction3)).add_(group['eps']) + update = (exp_avg / bias_correction1 + beta2 * exp_avg_diff / bias_correction2).div_(denom) + if group['no_prox']: + p.data.mul_(1 - group['lr'] * group['weight_decay']) + p.add_(update, alpha=-group['lr']) + else: + p.add_(update, alpha=-group['lr']) + p.data.div_(1 + group['lr'] * group['weight_decay']) + + state['pre_grad'].copy_(grad) + + return loss diff --git a/timm/optim/optim_factory.py b/timm/optim/optim_factory.py index c82fd3d2..02f0e250 100644 --- a/timm/optim/optim_factory.py +++ b/timm/optim/optim_factory.py @@ -15,6 +15,7 @@ from .adabelief import AdaBelief from .adafactor import Adafactor from .adahessian import Adahessian from .adamp import AdamP +from .adan import Adan from .lamb import Lamb from .lars import Lars from .lookahead import Lookahead @@ -192,7 +193,8 @@ def create_optimizer_v2( filter_bias_and_bn: bool = True, layer_decay: Optional[float] = None, param_group_fn: Optional[Callable] = None, - **kwargs): + **kwargs, +): """ Create an optimizer. TODO currently the model is passed in and all parameters are selected for optimization. @@ -285,6 +287,10 @@ def create_optimizer_v2( optimizer = optim.Adagrad(parameters, **opt_args) elif opt_lower == 'adafactor': optimizer = Adafactor(parameters, **opt_args) + elif opt_lower == 'adanp': + optimizer = Adan(parameters, no_prox=False, **opt_args) + elif opt_lower == 'adanw': + optimizer = Adan(parameters, no_prox=True, **opt_args) elif opt_lower == 'lamb': optimizer = Lamb(parameters, **opt_args) elif opt_lower == 'lambc': diff --git a/timm/scheduler/__init__.py b/timm/scheduler/__init__.py index f1961b88..9f7191bb 100644 --- a/timm/scheduler/__init__.py +++ b/timm/scheduler/__init__.py @@ -5,4 +5,4 @@ from .poly_lr import PolyLRScheduler from .step_lr import StepLRScheduler from .tanh_lr import TanhLRScheduler -from .scheduler_factory import create_scheduler +from .scheduler_factory import create_scheduler, create_scheduler_v2, scheduler_kwargs diff --git a/timm/scheduler/cosine_lr.py b/timm/scheduler/cosine_lr.py index 84ee349e..e2c975fb 100644 --- a/timm/scheduler/cosine_lr.py +++ b/timm/scheduler/cosine_lr.py @@ -26,33 +26,42 @@ class CosineLRScheduler(Scheduler): k-decay option based on `k-decay: A New Method For Learning Rate Schedule` - https://arxiv.org/abs/2004.05909 """ - def __init__(self, - optimizer: torch.optim.Optimizer, - t_initial: int, - lr_min: float = 0., - cycle_mul: float = 1., - cycle_decay: float = 1., - cycle_limit: int = 1, - warmup_t=0, - warmup_lr_init=0, - warmup_prefix=False, - t_in_epochs=True, - noise_range_t=None, - noise_pct=0.67, - noise_std=1.0, - noise_seed=42, - k_decay=1.0, - initialize=True) -> None: + def __init__( + self, + optimizer: torch.optim.Optimizer, + t_initial: int, + lr_min: float = 0., + cycle_mul: float = 1., + cycle_decay: float = 1., + cycle_limit: int = 1, + warmup_t=0, + warmup_lr_init=0, + warmup_prefix=False, + t_in_epochs=True, + noise_range_t=None, + noise_pct=0.67, + noise_std=1.0, + noise_seed=42, + k_decay=1.0, + initialize=True, + ) -> None: super().__init__( - optimizer, param_group_field="lr", - noise_range_t=noise_range_t, noise_pct=noise_pct, noise_std=noise_std, noise_seed=noise_seed, - initialize=initialize) + optimizer, + param_group_field="lr", + t_in_epochs=t_in_epochs, + noise_range_t=noise_range_t, + noise_pct=noise_pct, + noise_std=noise_std, + noise_seed=noise_seed, + initialize=initialize, + ) assert t_initial > 0 assert lr_min >= 0 if t_initial == 
1 and cycle_mul == 1 and cycle_decay == 1: - _logger.warning("Cosine annealing scheduler will have no effect on the learning " - "rate since t_initial = t_mul = eta_mul = 1.") + _logger.warning( + "Cosine annealing scheduler will have no effect on the learning " + "rate since t_initial = t_mul = eta_mul = 1.") self.t_initial = t_initial self.lr_min = lr_min self.cycle_mul = cycle_mul @@ -61,7 +70,6 @@ class CosineLRScheduler(Scheduler): self.warmup_t = warmup_t self.warmup_lr_init = warmup_lr_init self.warmup_prefix = warmup_prefix - self.t_in_epochs = t_in_epochs self.k_decay = k_decay if self.warmup_t: self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in self.base_values] @@ -99,18 +107,6 @@ class CosineLRScheduler(Scheduler): return lrs - def get_epoch_values(self, epoch: int): - if self.t_in_epochs: - return self._get_lr(epoch) - else: - return None - - def get_update_values(self, num_updates: int): - if not self.t_in_epochs: - return self._get_lr(num_updates) - else: - return None - def get_cycle_length(self, cycles=0): cycles = max(1, cycles or self.cycle_limit) if self.cycle_mul == 1.0: diff --git a/timm/scheduler/multistep_lr.py b/timm/scheduler/multistep_lr.py index a5d5fe19..10f2fb50 100644 --- a/timm/scheduler/multistep_lr.py +++ b/timm/scheduler/multistep_lr.py @@ -11,29 +11,37 @@ class MultiStepLRScheduler(Scheduler): """ """ - def __init__(self, - optimizer: torch.optim.Optimizer, - decay_t: List[int], - decay_rate: float = 1., - warmup_t=0, - warmup_lr_init=0, - t_in_epochs=True, - noise_range_t=None, - noise_pct=0.67, - noise_std=1.0, - noise_seed=42, - initialize=True, - ) -> None: + def __init__( + self, + optimizer: torch.optim.Optimizer, + decay_t: List[int], + decay_rate: float = 1., + warmup_t=0, + warmup_lr_init=0, + warmup_prefix=True, + t_in_epochs=True, + noise_range_t=None, + noise_pct=0.67, + noise_std=1.0, + noise_seed=42, + initialize=True, + ) -> None: super().__init__( - optimizer, param_group_field="lr", - noise_range_t=noise_range_t, noise_pct=noise_pct, noise_std=noise_std, noise_seed=noise_seed, - initialize=initialize) + optimizer, + param_group_field="lr", + t_in_epochs=t_in_epochs, + noise_range_t=noise_range_t, + noise_pct=noise_pct, + noise_std=noise_std, + noise_seed=noise_seed, + initialize=initialize, + ) self.decay_t = decay_t self.decay_rate = decay_rate self.warmup_t = warmup_t self.warmup_lr_init = warmup_lr_init - self.t_in_epochs = t_in_epochs + self.warmup_prefix = warmup_prefix if self.warmup_t: self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in self.base_values] super().update_groups(self.warmup_lr_init) @@ -43,23 +51,13 @@ class MultiStepLRScheduler(Scheduler): def get_curr_decay_steps(self, t): # find where in the array t goes, # assumes self.decay_t is sorted - return bisect.bisect_right(self.decay_t, t+1) + return bisect.bisect_right(self.decay_t, t + 1) def _get_lr(self, t): if t < self.warmup_t: lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps] else: + if self.warmup_prefix: + t = t - self.warmup_t lrs = [v * (self.decay_rate ** self.get_curr_decay_steps(t)) for v in self.base_values] return lrs - - def get_epoch_values(self, epoch: int): - if self.t_in_epochs: - return self._get_lr(epoch) - else: - return None - - def get_update_values(self, num_updates: int): - if not self.t_in_epochs: - return self._get_lr(num_updates) - else: - return None diff --git a/timm/scheduler/plateau_lr.py b/timm/scheduler/plateau_lr.py index cacfab3c..9f827157 100644 --- a/timm/scheduler/plateau_lr.py +++ 
b/timm/scheduler/plateau_lr.py @@ -12,24 +12,25 @@ from .scheduler import Scheduler class PlateauLRScheduler(Scheduler): """Decay the LR by a factor every time the validation loss plateaus.""" - def __init__(self, - optimizer, - decay_rate=0.1, - patience_t=10, - verbose=True, - threshold=1e-4, - cooldown_t=0, - warmup_t=0, - warmup_lr_init=0, - lr_min=0, - mode='max', - noise_range_t=None, - noise_type='normal', - noise_pct=0.67, - noise_std=1.0, - noise_seed=None, - initialize=True, - ): + def __init__( + self, + optimizer, + decay_rate=0.1, + patience_t=10, + verbose=True, + threshold=1e-4, + cooldown_t=0, + warmup_t=0, + warmup_lr_init=0, + lr_min=0, + mode='max', + noise_range_t=None, + noise_type='normal', + noise_pct=0.67, + noise_std=1.0, + noise_seed=None, + initialize=True, + ): super().__init__( optimizer, 'lr', @@ -89,6 +90,9 @@ class PlateauLRScheduler(Scheduler): if self._is_apply_noise(epoch): self._apply_noise(epoch) + def step_update(self, num_updates: int, metric: float = None): + return None + def _apply_noise(self, epoch): noise = self._calculate_noise(epoch) @@ -101,3 +105,6 @@ class PlateauLRScheduler(Scheduler): new_lr = old_lr + old_lr * noise param_group['lr'] = new_lr self.restore_lr = restore_lr + + def _get_lr(self, t: int) -> float: + assert False, 'should not be called as step is overridden' diff --git a/timm/scheduler/poly_lr.py b/timm/scheduler/poly_lr.py index 9c351be6..906f6acf 100644 --- a/timm/scheduler/poly_lr.py +++ b/timm/scheduler/poly_lr.py @@ -21,28 +21,36 @@ class PolyLRScheduler(Scheduler): k-decay option based on `k-decay: A New Method For Learning Rate Schedule` - https://arxiv.org/abs/2004.05909 """ - def __init__(self, - optimizer: torch.optim.Optimizer, - t_initial: int, - power: float = 0.5, - lr_min: float = 0., - cycle_mul: float = 1., - cycle_decay: float = 1., - cycle_limit: int = 1, - warmup_t=0, - warmup_lr_init=0, - warmup_prefix=False, - t_in_epochs=True, - noise_range_t=None, - noise_pct=0.67, - noise_std=1.0, - noise_seed=42, - k_decay=1.0, - initialize=True) -> None: + def __init__( + self, + optimizer: torch.optim.Optimizer, + t_initial: int, + power: float = 0.5, + lr_min: float = 0., + cycle_mul: float = 1., + cycle_decay: float = 1., + cycle_limit: int = 1, + warmup_t=0, + warmup_lr_init=0, + warmup_prefix=False, + t_in_epochs=True, + noise_range_t=None, + noise_pct=0.67, + noise_std=1.0, + noise_seed=42, + k_decay=1.0, + initialize=True, + ) -> None: super().__init__( - optimizer, param_group_field="lr", - noise_range_t=noise_range_t, noise_pct=noise_pct, noise_std=noise_std, noise_seed=noise_seed, - initialize=initialize) + optimizer, + param_group_field="lr", + t_in_epochs=t_in_epochs, + noise_range_t=noise_range_t, + noise_pct=noise_pct, + noise_std=noise_std, + noise_seed=noise_seed, + initialize=initialize + ) assert t_initial > 0 assert lr_min >= 0 @@ -58,7 +66,6 @@ class PolyLRScheduler(Scheduler): self.warmup_t = warmup_t self.warmup_lr_init = warmup_lr_init self.warmup_prefix = warmup_prefix - self.t_in_epochs = t_in_epochs self.k_decay = k_decay if self.warmup_t: self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in self.base_values] @@ -96,18 +103,6 @@ class PolyLRScheduler(Scheduler): return lrs - def get_epoch_values(self, epoch: int): - if self.t_in_epochs: - return self._get_lr(epoch) - else: - return None - - def get_update_values(self, num_updates: int): - if not self.t_in_epochs: - return self._get_lr(num_updates) - else: - return None - def get_cycle_length(self, cycles=0): cycles = max(1, 
cycles or self.cycle_limit) if self.cycle_mul == 1.0: diff --git a/timm/scheduler/scheduler.py b/timm/scheduler/scheduler.py index af20be9b..4ae2e2ae 100644 --- a/timm/scheduler/scheduler.py +++ b/timm/scheduler/scheduler.py @@ -1,9 +1,11 @@ -from typing import Dict, Any +import abc +from abc import ABC +from typing import Any, Dict, Optional import torch -class Scheduler: +class Scheduler(ABC): """ Parameter Scheduler Base Class A scheduler base class that can be used to schedule any optimizer parameter groups. @@ -22,15 +24,18 @@ class Scheduler: * https://github.com/allenai/allennlp/tree/master/allennlp/training/learning_rate_schedulers """ - def __init__(self, - optimizer: torch.optim.Optimizer, - param_group_field: str, - noise_range_t=None, - noise_type='normal', - noise_pct=0.67, - noise_std=1.0, - noise_seed=None, - initialize: bool = True) -> None: + def __init__( + self, + optimizer: torch.optim.Optimizer, + param_group_field: str, + t_in_epochs: bool = True, + noise_range_t=None, + noise_type='normal', + noise_pct=0.67, + noise_std=1.0, + noise_seed=None, + initialize: bool = True, + ) -> None: self.optimizer = optimizer self.param_group_field = param_group_field self._initial_param_group_field = f"initial_{param_group_field}" @@ -45,6 +50,7 @@ class Scheduler: raise KeyError(f"{self._initial_param_group_field} missing from param_groups[{i}]") self.base_values = [group[self._initial_param_group_field] for group in self.optimizer.param_groups] self.metric = None # any point to having this for all? + self.t_in_epochs = t_in_epochs self.noise_range_t = noise_range_t self.noise_pct = noise_pct self.noise_type = noise_type @@ -58,22 +64,26 @@ class Scheduler: def load_state_dict(self, state_dict: Dict[str, Any]) -> None: self.__dict__.update(state_dict) - def get_epoch_values(self, epoch: int): - return None + @abc.abstractmethod + def _get_lr(self, t: int) -> float: + pass - def get_update_values(self, num_updates: int): - return None + def _get_values(self, t: int, on_epoch: bool = True) -> Optional[float]: + proceed = (on_epoch and self.t_in_epochs) or (not on_epoch and not self.t_in_epochs) + if not proceed: + return None + return self._get_lr(t) def step(self, epoch: int, metric: float = None) -> None: self.metric = metric - values = self.get_epoch_values(epoch) + values = self._get_values(epoch, on_epoch=True) if values is not None: values = self._add_noise(values, epoch) self.update_groups(values) def step_update(self, num_updates: int, metric: float = None): self.metric = metric - values = self.get_update_values(num_updates) + values = self._get_values(num_updates, on_epoch=False) if values is not None: values = self._add_noise(values, num_updates) self.update_groups(values) diff --git a/timm/scheduler/scheduler_factory.py b/timm/scheduler/scheduler_factory.py index 3e100fe0..6cb506a5 100644 --- a/timm/scheduler/scheduler_factory.py +++ b/timm/scheduler/scheduler_factory.py @@ -1,6 +1,10 @@ """ Scheduler Factory Hacked together by / Copyright 2021 Ross Wightman """ +from typing import List, Union + +from torch.optim import Optimizer + from .cosine_lr import CosineLRScheduler from .multistep_lr import MultiStepLRScheduler from .plateau_lr import PlateauLRScheduler @@ -9,99 +13,191 @@ from .step_lr import StepLRScheduler from .tanh_lr import TanhLRScheduler -def create_scheduler(args, optimizer): - num_epochs = args.epochs +def scheduler_kwargs(cfg): + """ cfg/argparse to kwargs helper + Convert scheduler args in argparse args or cfg (.dot) like object to keyword args. 
+ """ + eval_metric = getattr(cfg, 'eval_metric', 'top1') + plateau_mode = 'min' if 'loss' in eval_metric else 'max' + kwargs = dict( + sched=cfg.sched, + num_epochs=getattr(cfg, 'epochs', 100), + decay_epochs=getattr(cfg, 'decay_epochs', 30), + decay_milestones=getattr(cfg, 'decay_milestones', [30, 60]), + warmup_epochs=getattr(cfg, 'warmup_epochs', 5), + cooldown_epochs=getattr(cfg, 'cooldown_epochs', 0), + patience_epochs=getattr(cfg, 'patience_epochs', 10), + decay_rate=getattr(cfg, 'decay_rate', 0.1), + min_lr=getattr(cfg, 'min_lr', 0.), + warmup_lr=getattr(cfg, 'warmup_lr', 1e-5), + warmup_prefix=getattr(cfg, 'warmup_prefix', False), + noise=getattr(cfg, 'lr_noise', None), + noise_pct=getattr(cfg, 'lr_noise_pct', 0.67), + noise_std=getattr(cfg, 'lr_noise_std', 1.), + noise_seed=getattr(cfg, 'seed', 42), + cycle_mul=getattr(cfg, 'lr_cycle_mul', 1.), + cycle_decay=getattr(cfg, 'lr_cycle_decay', 0.1), + cycle_limit=getattr(cfg, 'lr_cycle_limit', 1), + k_decay=getattr(cfg, 'lr_k_decay', 1.0), + plateau_mode=plateau_mode, + step_on_epochs=not getattr(cfg, 'sched_on_updates', False), + ) + return kwargs + + +def create_scheduler( + args, + optimizer: Optimizer, + updates_per_epoch: int = 0, +): + return create_scheduler_v2( + optimizer=optimizer, + **scheduler_kwargs(args), + updates_per_epoch=updates_per_epoch, + ) + + +def create_scheduler_v2( + optimizer: Optimizer, + sched: str = 'cosine', + num_epochs: int = 300, + decay_epochs: int = 90, + decay_milestones: List[int] = (90, 180, 270), + cooldown_epochs: int = 0, + patience_epochs: int = 10, + decay_rate: float = 0.1, + min_lr: float = 0, + warmup_lr: float = 1e-5, + warmup_epochs: int = 0, + warmup_prefix: bool = False, + noise: Union[float, List[float]] = None, + noise_pct: float = 0.67, + noise_std: float = 1., + noise_seed: int = 42, + cycle_mul: float = 1., + cycle_decay: float = 0.1, + cycle_limit: int = 1, + k_decay: float = 1.0, + plateau_mode: str = 'max', + step_on_epochs: bool = True, + updates_per_epoch: int = 0, +): + t_initial = num_epochs + warmup_t = warmup_epochs + decay_t = decay_epochs + cooldown_t = cooldown_epochs + + if not step_on_epochs: + assert updates_per_epoch > 0, 'updates_per_epoch must be set to number of dataloader batches' + t_initial = t_initial * updates_per_epoch + warmup_t = warmup_t * updates_per_epoch + decay_t = decay_t * updates_per_epoch + decay_milestones = [d * updates_per_epoch for d in decay_milestones] + cooldown_t = cooldown_t * updates_per_epoch + + # warmup args + warmup_args = dict( + warmup_lr_init=warmup_lr, + warmup_t=warmup_t, + warmup_prefix=warmup_prefix, + ) - if getattr(args, 'lr_noise', None) is not None: - lr_noise = getattr(args, 'lr_noise') - if isinstance(lr_noise, (list, tuple)): - noise_range = [n * num_epochs for n in lr_noise] + # setup noise args for supporting schedulers + if noise is not None: + if isinstance(noise, (list, tuple)): + noise_range = [n * t_initial for n in noise] if len(noise_range) == 1: noise_range = noise_range[0] else: - noise_range = lr_noise * num_epochs + noise_range = noise * t_initial else: noise_range = None noise_args = dict( noise_range_t=noise_range, - noise_pct=getattr(args, 'lr_noise_pct', 0.67), - noise_std=getattr(args, 'lr_noise_std', 1.), - noise_seed=getattr(args, 'seed', 42), + noise_pct=noise_pct, + noise_std=noise_std, + noise_seed=noise_seed, ) + + # setup cycle args for supporting schedulers cycle_args = dict( - cycle_mul=getattr(args, 'lr_cycle_mul', 1.), - cycle_decay=getattr(args, 'lr_cycle_decay', 0.1), - 
cycle_limit=getattr(args, 'lr_cycle_limit', 1), + cycle_mul=cycle_mul, + cycle_decay=cycle_decay, + cycle_limit=cycle_limit, ) lr_scheduler = None - if args.sched == 'cosine': + if sched == 'cosine': lr_scheduler = CosineLRScheduler( optimizer, - t_initial=num_epochs, - lr_min=args.min_lr, - warmup_lr_init=args.warmup_lr, - warmup_t=args.warmup_epochs, - k_decay=getattr(args, 'lr_k_decay', 1.0), + t_initial=t_initial, + lr_min=min_lr, + t_in_epochs=step_on_epochs, **cycle_args, + **warmup_args, **noise_args, + k_decay=k_decay, ) - num_epochs = lr_scheduler.get_cycle_length() + args.cooldown_epochs - elif args.sched == 'tanh': + elif sched == 'tanh': lr_scheduler = TanhLRScheduler( optimizer, - t_initial=num_epochs, - lr_min=args.min_lr, - warmup_lr_init=args.warmup_lr, - warmup_t=args.warmup_epochs, - t_in_epochs=True, + t_initial=t_initial, + lr_min=min_lr, + t_in_epochs=step_on_epochs, **cycle_args, + **warmup_args, **noise_args, ) - num_epochs = lr_scheduler.get_cycle_length() + args.cooldown_epochs - elif args.sched == 'step': + elif sched == 'step': lr_scheduler = StepLRScheduler( optimizer, - decay_t=args.decay_epochs, - decay_rate=args.decay_rate, - warmup_lr_init=args.warmup_lr, - warmup_t=args.warmup_epochs, + decay_t=decay_t, + decay_rate=decay_rate, + t_in_epochs=step_on_epochs, + **warmup_args, **noise_args, ) - elif args.sched == 'multistep': + elif sched == 'multistep': lr_scheduler = MultiStepLRScheduler( optimizer, - decay_t=args.decay_milestones, - decay_rate=args.decay_rate, - warmup_lr_init=args.warmup_lr, - warmup_t=args.warmup_epochs, + decay_t=decay_milestones, + decay_rate=decay_rate, + t_in_epochs=step_on_epochs, + **warmup_args, **noise_args, ) - elif args.sched == 'plateau': - mode = 'min' if 'loss' in getattr(args, 'eval_metric', '') else 'max' + elif sched == 'plateau': + assert step_on_epochs, 'Plateau LR only supports step per epoch.' 
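For orientation, a minimal call into this factory might look like the sketch below (hypothetical values; it mirrors the train.py usage later in this diff and assumes an existing optimizer and train loader):

# sketch only, not part of the patch
lr_scheduler, num_epochs = create_scheduler_v2(
    optimizer,
    sched='cosine',
    num_epochs=100,
    warmup_epochs=5,
    warmup_lr=1e-5,
    min_lr=0.,
    step_on_epochs=False,                 # step the schedule per optimizer update
    updates_per_epoch=len(loader_train),  # epoch counts are converted to update counts above
)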
+ warmup_args.pop('warmup_prefix', False) lr_scheduler = PlateauLRScheduler( optimizer, - decay_rate=args.decay_rate, - patience_t=args.patience_epochs, - lr_min=args.min_lr, - mode=mode, - warmup_lr_init=args.warmup_lr, - warmup_t=args.warmup_epochs, + decay_rate=decay_rate, + patience_t=patience_epochs, cooldown_t=0, + **warmup_args, + lr_min=min_lr, + mode=plateau_mode, **noise_args, ) - elif args.sched == 'poly': + elif sched == 'poly': lr_scheduler = PolyLRScheduler( optimizer, - power=args.decay_rate, # overloading 'decay_rate' as polynomial power - t_initial=num_epochs, - lr_min=args.min_lr, - warmup_lr_init=args.warmup_lr, - warmup_t=args.warmup_epochs, - k_decay=getattr(args, 'lr_k_decay', 1.0), + power=decay_rate, # overloading 'decay_rate' as polynomial power + t_initial=t_initial, + lr_min=min_lr, + t_in_epochs=step_on_epochs, + k_decay=k_decay, **cycle_args, + **warmup_args, **noise_args, ) - num_epochs = lr_scheduler.get_cycle_length() + args.cooldown_epochs + + if hasattr(lr_scheduler, 'get_cycle_length'): + # for cycle based schedulers (cosine, tanh, poly) recalculate total epochs w/ cycles & cooldown + t_with_cycles_and_cooldown = lr_scheduler.get_cycle_length() + cooldown_t + if step_on_epochs: + num_epochs = t_with_cycles_and_cooldown + else: + num_epochs = t_with_cycles_and_cooldown // updates_per_epoch return lr_scheduler, num_epochs diff --git a/timm/scheduler/step_lr.py b/timm/scheduler/step_lr.py index f797e1a8..70a45a70 100644 --- a/timm/scheduler/step_lr.py +++ b/timm/scheduler/step_lr.py @@ -14,29 +14,37 @@ class StepLRScheduler(Scheduler): """ """ - def __init__(self, - optimizer: torch.optim.Optimizer, - decay_t: float, - decay_rate: float = 1., - warmup_t=0, - warmup_lr_init=0, - t_in_epochs=True, - noise_range_t=None, - noise_pct=0.67, - noise_std=1.0, - noise_seed=42, - initialize=True, - ) -> None: + def __init__( + self, + optimizer: torch.optim.Optimizer, + decay_t: float, + decay_rate: float = 1., + warmup_t=0, + warmup_lr_init=0, + warmup_prefix=True, + t_in_epochs=True, + noise_range_t=None, + noise_pct=0.67, + noise_std=1.0, + noise_seed=42, + initialize=True, + ) -> None: super().__init__( - optimizer, param_group_field="lr", - noise_range_t=noise_range_t, noise_pct=noise_pct, noise_std=noise_std, noise_seed=noise_seed, - initialize=initialize) + optimizer, + param_group_field="lr", + t_in_epochs=t_in_epochs, + noise_range_t=noise_range_t, + noise_pct=noise_pct, + noise_std=noise_std, + noise_seed=noise_seed, + initialize=initialize, + ) self.decay_t = decay_t self.decay_rate = decay_rate self.warmup_t = warmup_t self.warmup_lr_init = warmup_lr_init - self.t_in_epochs = t_in_epochs + self.warmup_prefix = warmup_prefix if self.warmup_t: self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in self.base_values] super().update_groups(self.warmup_lr_init) @@ -47,17 +55,7 @@ class StepLRScheduler(Scheduler): if t < self.warmup_t: lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps] else: + if self.warmup_prefix: + t = t - self.warmup_t lrs = [v * (self.decay_rate ** (t // self.decay_t)) for v in self.base_values] return lrs - - def get_epoch_values(self, epoch: int): - if self.t_in_epochs: - return self._get_lr(epoch) - else: - return None - - def get_update_values(self, num_updates: int): - if not self.t_in_epochs: - return self._get_lr(num_updates) - else: - return None diff --git a/timm/scheduler/tanh_lr.py b/timm/scheduler/tanh_lr.py index f2d3c9cd..48acc61b 100644 --- a/timm/scheduler/tanh_lr.py +++ b/timm/scheduler/tanh_lr.py @@ 
-21,28 +21,36 @@ class TanhLRScheduler(Scheduler): This is described in the paper https://arxiv.org/abs/1806.01593 """ - def __init__(self, - optimizer: torch.optim.Optimizer, - t_initial: int, - lb: float = -7., - ub: float = 3., - lr_min: float = 0., - cycle_mul: float = 1., - cycle_decay: float = 1., - cycle_limit: int = 1, - warmup_t=0, - warmup_lr_init=0, - warmup_prefix=False, - t_in_epochs=True, - noise_range_t=None, - noise_pct=0.67, - noise_std=1.0, - noise_seed=42, - initialize=True) -> None: + def __init__( + self, + optimizer: torch.optim.Optimizer, + t_initial: int, + lb: float = -7., + ub: float = 3., + lr_min: float = 0., + cycle_mul: float = 1., + cycle_decay: float = 1., + cycle_limit: int = 1, + warmup_t=0, + warmup_lr_init=0, + warmup_prefix=False, + t_in_epochs=True, + noise_range_t=None, + noise_pct=0.67, + noise_std=1.0, + noise_seed=42, + initialize=True, + ) -> None: super().__init__( - optimizer, param_group_field="lr", - noise_range_t=noise_range_t, noise_pct=noise_pct, noise_std=noise_std, noise_seed=noise_seed, - initialize=initialize) + optimizer, + param_group_field="lr", + t_in_epochs=t_in_epochs, + noise_range_t=noise_range_t, + noise_pct=noise_pct, + noise_std=noise_std, + noise_seed=noise_seed, + initialize=initialize, + ) assert t_initial > 0 assert lr_min >= 0 @@ -60,7 +68,6 @@ class TanhLRScheduler(Scheduler): self.warmup_t = warmup_t self.warmup_lr_init = warmup_lr_init self.warmup_prefix = warmup_prefix - self.t_in_epochs = t_in_epochs if self.warmup_t: t_v = self.base_values if self.warmup_prefix else self._get_lr(self.warmup_t) self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in t_v] @@ -97,18 +104,6 @@ class TanhLRScheduler(Scheduler): lrs = [self.lr_min for _ in self.base_values] return lrs - def get_epoch_values(self, epoch: int): - if self.t_in_epochs: - return self._get_lr(epoch) - else: - return None - - def get_update_values(self, num_updates: int): - if not self.t_in_epochs: - return self._get_lr(num_updates) - else: - return None - def get_cycle_length(self, cycles=0): cycles = max(1, cycles or self.cycle_limit) if self.cycle_mul == 1.0: diff --git a/timm/utils/__init__.py b/timm/utils/__init__.py index 7b139852..a9ff0c78 100644 --- a/timm/utils/__init__.py +++ b/timm/utils/__init__.py @@ -3,7 +3,8 @@ from .checkpoint_saver import CheckpointSaver from .clip_grad import dispatch_clip_grad from .cuda import ApexScaler, NativeScaler from .decay_batch import decay_batch_step, check_batch_size_retry -from .distributed import distribute_bn, reduce_tensor +from .distributed import distribute_bn, reduce_tensor, init_distributed_device,\ + world_info_from_env, is_distributed_env, is_primary from .jit import set_jit_legacy, set_jit_fuser from .log import setup_default_logging, FormatterNoInfo from .metrics import AverageMeter, accuracy diff --git a/timm/utils/distributed.py b/timm/utils/distributed.py index 3c5dba8c..ee9a358c 100644 --- a/timm/utils/distributed.py +++ b/timm/utils/distributed.py @@ -2,9 +2,16 @@ Hacked together by / Copyright 2020 Ross Wightman """ +import os + import torch from torch import distributed as dist +try: + import horovod.torch as hvd +except ImportError: + hvd = None + from .model import unwrap_model @@ -26,3 +33,105 @@ def distribute_bn(model, world_size, reduce=False): else: # broadcast bn stats from rank 0 to whole group torch.distributed.broadcast(bn_buf, 0) + + +def is_global_primary(args): + return args.rank == 0 + + +def is_local_primary(args): + return args.local_rank == 0 + + +def is_primary(args, 
local=False): + return is_local_primary(args) if local else is_global_primary(args) + + +def is_distributed_env(): + if 'WORLD_SIZE' in os.environ: + return int(os.environ['WORLD_SIZE']) > 1 + if 'SLURM_NTASKS' in os.environ: + return int(os.environ['SLURM_NTASKS']) > 1 + return False + + +def world_info_from_env(): + local_rank = 0 + for v in ('LOCAL_RANK', 'MPI_LOCALRANKID', 'SLURM_LOCALID', 'OMPI_COMM_WORLD_LOCAL_RANK'): + if v in os.environ: + local_rank = int(os.environ[v]) + break + + global_rank = 0 + for v in ('RANK', 'PMI_RANK', 'SLURM_PROCID', 'OMPI_COMM_WORLD_RANK'): + if v in os.environ: + global_rank = int(os.environ[v]) + break + + world_size = 1 + for v in ('WORLD_SIZE', 'PMI_SIZE', 'SLURM_NTASKS', 'OMPI_COMM_WORLD_SIZE'): + if v in os.environ: + world_size = int(os.environ[v]) + break + + return local_rank, global_rank, world_size + + +def init_distributed_device(args): + # Distributed training = training on more than one GPU. + # Works in both single and multi-node scenarios. + args.distributed = False + args.world_size = 1 + args.rank = 0 # global rank + args.local_rank = 0 + + # TBD, support horovod? + # if args.horovod: + # assert hvd is not None, "Horovod is not installed" + # hvd.init() + # args.local_rank = int(hvd.local_rank()) + # args.rank = hvd.rank() + # args.world_size = hvd.size() + # args.distributed = True + # os.environ['LOCAL_RANK'] = str(args.local_rank) + # os.environ['RANK'] = str(args.rank) + # os.environ['WORLD_SIZE'] = str(args.world_size) + dist_backend = getattr(args, 'dist_backend', 'nccl') + dist_url = getattr(args, 'dist_url', 'env://') + if is_distributed_env(): + if 'SLURM_PROCID' in os.environ: + # DDP via SLURM + args.local_rank, args.rank, args.world_size = world_info_from_env() + # SLURM var -> torch.distributed vars in case needed + os.environ['LOCAL_RANK'] = str(args.local_rank) + os.environ['RANK'] = str(args.rank) + os.environ['WORLD_SIZE'] = str(args.world_size) + torch.distributed.init_process_group( + backend=dist_backend, + init_method=dist_url, + world_size=args.world_size, + rank=args.rank, + ) + else: + # DDP via torchrun, torch.distributed.launch + args.local_rank, _, _ = world_info_from_env() + torch.distributed.init_process_group( + backend=dist_backend, + init_method=dist_url, + ) + args.world_size = torch.distributed.get_world_size() + args.rank = torch.distributed.get_rank() + args.distributed = True + + if torch.cuda.is_available(): + if args.distributed: + device = 'cuda:%d' % args.local_rank + else: + device = 'cuda:0' + torch.cuda.set_device(device) + else: + device = 'cpu' + + args.device = device + device = torch.device(device) + return device diff --git a/timm/utils/summary.py b/timm/utils/summary.py index 9f5af9a0..c377a75f 100644 --- a/timm/utils/summary.py +++ b/timm/utils/summary.py @@ -10,6 +10,7 @@ try: except ImportError: pass + def get_outdir(path, *paths, inc=False): outdir = os.path.join(path, *paths) if not os.path.exists(outdir): @@ -26,10 +27,20 @@ def get_outdir(path, *paths, inc=False): return outdir -def update_summary(epoch, train_metrics, eval_metrics, filename, write_header=False, log_wandb=False): +def update_summary( + epoch, + train_metrics, + eval_metrics, + filename, + lr=None, + write_header=False, + log_wandb=False, +): rowd = OrderedDict(epoch=epoch) rowd.update([('train_' + k, v) for k, v in train_metrics.items()]) rowd.update([('eval_' + k, v) for k, v in eval_metrics.items()]) + if lr is not None: + rowd['lr'] = lr if log_wandb: wandb.log(rowd) with open(filename, mode='a') as cf: diff 
--git a/train.py b/train.py index ee137217..08dadb02 100755 --- a/train.py +++ b/train.py @@ -21,6 +21,7 @@ import time from collections import OrderedDict from contextlib import suppress from datetime import datetime +from functools import partial import torch import torch.nn as nn @@ -35,7 +36,7 @@ from timm.loss import JsdCrossEntropy, SoftTargetCrossEntropy, BinaryCrossEntrop from timm.models import create_model, safe_model_name, resume_checkpoint, load_checkpoint, \ convert_splitbn_model, convert_sync_batchnorm, model_parameters, set_fast_norm from timm.optim import create_optimizer_v2, optimizer_kwargs -from timm.scheduler import create_scheduler +from timm.scheduler import create_scheduler_v2, scheduler_kwargs from timm.utils import ApexScaler, NativeScaler try: @@ -66,7 +67,6 @@ except ImportError as e: has_functorch = False -torch.backends.cudnn.benchmark = True _logger = logging.getLogger('train') # The first arg parser parses out only the --config argument, this argument is used to @@ -111,7 +111,9 @@ group.add_argument('--num-classes', type=int, default=None, metavar='N', group.add_argument('--gp', default=None, type=str, metavar='POOL', help='Global pool type, one of (fast, avg, max, avgmax, avgmaxc). Model default if None.') group.add_argument('--img-size', type=int, default=None, metavar='N', - help='Image patch size (default: None => model default)') + help='Image size (default: None => model default)') +group.add_argument('--in-chans', type=int, default=None, metavar='N', + help='Image input channels (default: None => 3)') group.add_argument('--input-size', default=None, nargs=3, type=int, metavar='N N N', help='Input all image dimensions (d h w, e.g. --input-size 3 224 224), uses model default if empty') group.add_argument('--crop-pct', default=None, type=float, @@ -161,10 +163,18 @@ group.add_argument('--layer-decay', type=float, default=None, # Learning rate schedule parameters group = parser.add_argument_group('Learning rate schedule parameters') -group.add_argument('--sched', default='cosine', type=str, metavar='SCHEDULER', +group.add_argument('--sched', type=str, default='cosine', metavar='SCHEDULER', help='LR scheduler (default: "step"') -group.add_argument('--lr', type=float, default=0.05, metavar='LR', - help='learning rate (default: 0.05)') +group.add_argument('--sched-on-updates', action='store_true', default=False, + help='Apply LR scheduler step on update instead of epoch end.') +group.add_argument('--lr', type=float, default=None, metavar='LR', + help='learning rate, overrides lr-base if set (default: None)') +group.add_argument('--lr-base', type=float, default=0.1, metavar='LR', + help='base learning rate: lr = lr_base * global_batch_size / base_size') +group.add_argument('--lr-base-size', type=int, default=256, metavar='DIV', + help='base learning rate batch size (divisor, default: 256).') +group.add_argument('--lr-base-scale', type=str, default='', metavar='SCALE', + help='base learning rate vs batch_size scaling ("linear", "sqrt", based on opt if empty)') group.add_argument('--lr-noise', type=float, nargs='+', default=None, metavar='pct, pct', help='learning rate noise on/off epoch percentages') group.add_argument('--lr-noise-pct', type=float, default=0.67, metavar='PERCENT', @@ -179,23 +189,25 @@ group.add_argument('--lr-cycle-limit', type=int, default=1, metavar='N', help='learning rate cycle limit, cycles enabled if > 1') group.add_argument('--lr-k-decay', type=float, default=1.0, help='learning rate k-decay for cosine/poly (default: 1.0)') 
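Since `--lr` now defaults to None, the new `--lr-base`, `--lr-base-size`, and `--lr-base-scale` options derive the learning rate from the global batch size; a small worked sketch with hypothetical numbers (the actual logic is added to main() further below):

# hypothetical: lr_base=0.1, lr_base_size=256, world_size=4, batch_size=256
global_batch_size = 4 * 256
batch_ratio = global_batch_size / 256   # 4.0
lr_linear = 0.1 * batch_ratio           # 0.4, 'linear' scaling (e.g. SGD-style optimizers)
lr_sqrt = 0.1 * batch_ratio ** 0.5      # 0.2, 'sqrt' scaling (picked by default for 'ada*' / 'lamb' optimizers)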
-group.add_argument('--warmup-lr', type=float, default=0.0001, metavar='LR', - help='warmup learning rate (default: 0.0001)') -group.add_argument('--min-lr', type=float, default=1e-6, metavar='LR', - help='lower lr bound for cyclic schedulers that hit 0 (1e-5)') +group.add_argument('--warmup-lr', type=float, default=1e-5, metavar='LR', + help='warmup learning rate (default: 1e-5)') +group.add_argument('--min-lr', type=float, default=0, metavar='LR', + help='lower lr bound for cyclic schedulers that hit 0 (default: 0)') group.add_argument('--epochs', type=int, default=300, metavar='N', help='number of epochs to train (default: 300)') group.add_argument('--epoch-repeats', type=float, default=0., metavar='N', help='epoch repeat multiplier (number of times to repeat dataset epoch per train epoch).') group.add_argument('--start-epoch', default=None, type=int, metavar='N', help='manual epoch number (useful on restarts)') -group.add_argument('--decay-milestones', default=[30, 60], type=int, nargs='+', metavar="MILESTONES", +group.add_argument('--decay-milestones', default=[90, 180, 270], type=int, nargs='+', metavar="MILESTONES", help='list of decay epoch indices for multistep lr. must be increasing') -group.add_argument('--decay-epochs', type=float, default=100, metavar='N', +group.add_argument('--decay-epochs', type=float, default=90, metavar='N', help='epoch interval to decay LR') -group.add_argument('--warmup-epochs', type=int, default=3, metavar='N', +group.add_argument('--warmup-epochs', type=int, default=5, metavar='N', help='epochs to warmup LR, if scheduler supports') -group.add_argument('--cooldown-epochs', type=int, default=10, metavar='N', +group.add_argument('--warmup-prefix', action='store_true', default=False, + help='Exclude warmup period from decay schedule.'), +group.add_argument('--cooldown-epochs', type=int, default=0, metavar='N', help='epochs to cooldown LR at min_lr, after cyclic schedule ends') group.add_argument('--patience-epochs', type=int, default=10, metavar='N', help='patience epochs for Plateau LR scheduler (default: 10') @@ -303,10 +315,10 @@ group.add_argument('--save-images', action='store_true', default=False, help='save images of input bathes every log interval for debugging') group.add_argument('--amp', action='store_true', default=False, help='use NVIDIA Apex AMP or Native AMP for mixed precision training') -group.add_argument('--apex-amp', action='store_true', default=False, - help='Use NVIDIA Apex AMP mixed precision') -group.add_argument('--native-amp', action='store_true', default=False, - help='Use Native Torch AMP mixed precision') +group.add_argument('--amp-dtype', default='float16', type=str, + help='lower precision AMP dtype (default: float16)') +group.add_argument('--amp-impl', default='native', type=str, + help='AMP impl to use, "native" or "apex" (default: native)') group.add_argument('--no-ddp-bb', action='store_true', default=False, help='Force broadcast buffers for native DDP to off.') group.add_argument('--pin-mem', action='store_true', default=False, @@ -349,49 +361,42 @@ def main(): utils.setup_default_logging() args, args_text = _parse_args() + if torch.cuda.is_available(): + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.benchmark = True + args.prefetcher = not args.no_prefetcher - args.distributed = False - if 'WORLD_SIZE' in os.environ: - args.distributed = int(os.environ['WORLD_SIZE']) > 1 - args.device = 'cuda:0' - args.world_size = 1 - args.rank = 0 # global rank + device = utils.init_distributed_device(args) if 
args.distributed: - if 'LOCAL_RANK' in os.environ: - args.local_rank = int(os.getenv('LOCAL_RANK')) - args.device = 'cuda:%d' % args.local_rank - torch.cuda.set_device(args.local_rank) - torch.distributed.init_process_group(backend='nccl', init_method='env://') - args.world_size = torch.distributed.get_world_size() - args.rank = torch.distributed.get_rank() - _logger.info('Training in distributed mode with multiple processes, 1 GPU per process. Process %d, total %d.' - % (args.rank, args.world_size)) + _logger.info( + 'Training in distributed mode with multiple processes, 1 device per process.' + f'Process {args.rank}, total {args.world_size}, device {args.device}.') else: - _logger.info('Training with a single process on 1 GPUs.') + _logger.info(f'Training with a single process on 1 device ({args.device}).') assert args.rank >= 0 - if args.rank == 0 and args.log_wandb: + if utils.is_primary(args) and args.log_wandb: if has_wandb: wandb.init(project=args.experiment, config=args) else: - _logger.warning("You've requested to log metrics to wandb but package not found. " - "Metrics not being logged to wandb, try `pip install wandb`") + _logger.warning( + "You've requested to log metrics to wandb but package not found. " + "Metrics not being logged to wandb, try `pip install wandb`") # resolve AMP arguments based on PyTorch / Apex availability use_amp = None + amp_dtype = torch.float16 if args.amp: - # `--amp` chooses native amp before apex (APEX ver not actively maintained) - if has_native_amp: - args.native_amp = True - elif has_apex: - args.apex_amp = True - if args.apex_amp and has_apex: - use_amp = 'apex' - elif args.native_amp and has_native_amp: - use_amp = 'native' - elif args.apex_amp or args.native_amp: - _logger.warning("Neither APEX or native Torch AMP is available, using float32. " - "Install NVIDA apex or upgrade to PyTorch 1.6") + if args.amp_impl == 'apex': + assert has_apex, 'AMP impl specified as APEX but APEX is not installed.' + use_amp = 'apex' + assert args.amp_dtype == 'float16' + else: + assert has_native_amp, 'Please update PyTorch to a version with native AMP (or use APEX).' + use_amp = 'native' + assert args.amp_dtype in ('float16', 'bfloat16') + if args.amp_dtype == 'bfloat16': + amp_dtype = torch.bfloat16 utils.random_seed(args.seed, args.rank) @@ -400,19 +405,26 @@ def main(): if args.fast_norm: set_fast_norm() + in_chans = 3 + if args.in_chans is not None: + in_chans = args.in_chans + elif args.input_size is not None: + in_chans = args.input_size[0] + model = create_model( args.model, pretrained=args.pretrained, + in_chans=in_chans, num_classes=args.num_classes, drop_rate=args.drop, - drop_connect_rate=args.drop_connect, # DEPRECATED, use drop_path drop_path_rate=args.drop_path, drop_block_rate=args.drop_block, global_pool=args.gp, bn_momentum=args.bn_momentum, bn_eps=args.bn_eps, scriptable=args.torchscript, - checkpoint_path=args.initial_checkpoint) + checkpoint_path=args.initial_checkpoint, + ) if args.num_classes is None: assert hasattr(model, 'num_classes'), 'Model must have `num_classes` attr if not set on cmd line/config.'
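For context on the new `--in-chans` plumbing above, a minimal sketch of requesting a non-RGB input stem (hypothetical model and values; timm adapts the pretrained first-conv weights when in_chans != 3):

import timm
# hypothetical: single-channel (grayscale) inputs, 10 output classes
model = timm.create_model('resnet50', pretrained=True, in_chans=1, num_classes=10)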
args.num_classes = model.num_classes # FIXME handle model default vs config num_classes more elegantly @@ -420,11 +432,11 @@ def main(): if args.grad_checkpointing: model.set_grad_checkpointing(enable=True) - if args.local_rank == 0: + if utils.is_primary(args): _logger.info( f'Model {safe_model_name(args.model)} created, param count:{sum([m.numel() for m in model.parameters()])}') - data_config = resolve_data_config(vars(args), model=model, verbose=args.local_rank == 0) + data_config = resolve_data_config(vars(args), model=model, verbose=utils.is_primary(args)) # setup augmentation batch splits for contrastive loss or split bn num_aug_splits = 0 @@ -438,9 +450,9 @@ def main(): model = convert_splitbn_model(model, max(num_aug_splits, 2)) # move model to GPU, enable channels last layout if set - model.cuda() + model.to(device=device) if args.channels_last: - model = model.to(memory_format=torch.channels_last) + model.to(memory_format=torch.channels_last) # setup synchronized BatchNorm for distributed training if args.distributed and args.sync_bn: @@ -452,7 +464,7 @@ def main(): model = convert_syncbn_model(model) else: model = convert_sync_batchnorm(model) - if args.local_rank == 0: + if utils.is_primary(args): _logger.info( 'Converted model to use Synchronized BatchNorm. WARNING: You may have issues if using ' 'zero initialized BN layers (enabled by default for ResNets) while sync-bn enabled.') @@ -461,38 +473,56 @@ def main(): assert not use_amp == 'apex', 'Cannot use APEX AMP with torchscripted model' assert not args.sync_bn, 'Cannot use SyncBatchNorm with torchscripted model' model = torch.jit.script(model) + if args.aot_autograd: assert has_functorch, "functorch is needed for --aot-autograd" model = memory_efficient_fusion(model) + if args.lr is None: + global_batch_size = args.batch_size * args.world_size + batch_ratio = global_batch_size / args.lr_base_size + if not args.lr_base_scale: + on = args.opt.lower() + args.lr_base_scale = 'sqrt' if any([o in on for o in ('ada', 'lamb')]) else 'linear' + if args.lr_base_scale == 'sqrt': + batch_ratio = batch_ratio ** 0.5 + args.lr = args.lr_base * batch_ratio + if utils.is_primary(args): + _logger.info( + f'Learning rate ({args.lr}) calculated from base learning rate ({args.lr_base}) ' + f'and global batch size ({global_batch_size}) with {args.lr_base_scale} scaling.') + optimizer = create_optimizer_v2(model, **optimizer_kwargs(cfg=args)) # setup automatic mixed-precision (AMP) loss scaling and op casting amp_autocast = suppress # do nothing loss_scaler = None if use_amp == 'apex': + assert device.type == 'cuda' model, optimizer = amp.initialize(model, optimizer, opt_level='O1') loss_scaler = ApexScaler() - if args.local_rank == 0: + if utils.is_primary(args): _logger.info('Using NVIDIA APEX AMP. Training in mixed precision.') elif use_amp == 'native': - amp_autocast = torch.cuda.amp.autocast - loss_scaler = NativeScaler() - if args.local_rank == 0: + amp_autocast = partial(torch.autocast, device_type=device.type, dtype=amp_dtype) + if device.type == 'cuda': + loss_scaler = NativeScaler() + if utils.is_primary(args): _logger.info('Using native Torch AMP. Training in mixed precision.') else: - if args.local_rank == 0: + if utils.is_primary(args): _logger.info('AMP not enabled. 
Training in float32.') - # optionally resume from a checkpoint resume_epoch = None if args.resume: resume_epoch = resume_checkpoint( - model, args.resume, + model, + args.resume, optimizer=None if args.no_resume_opt else optimizer, loss_scaler=None if args.no_resume_opt else loss_scaler, - log_info=args.local_rank == 0) + log_info=utils.is_primary(args), + ) # setup exponential moving average of model weights, SWA could be used here too model_ema = None @@ -507,41 +537,37 @@ def main(): if args.distributed: if has_apex and use_amp == 'apex': # Apex DDP preferred unless native amp is activated - if args.local_rank == 0: + if utils.is_primary(args): _logger.info("Using NVIDIA APEX DistributedDataParallel.") model = ApexDDP(model, delay_allreduce=True) else: - if args.local_rank == 0: + if utils.is_primary(args): _logger.info("Using native Torch DistributedDataParallel.") - model = NativeDDP(model, device_ids=[args.local_rank], broadcast_buffers=not args.no_ddp_bb) + model = NativeDDP(model, device_ids=[device], broadcast_buffers=not args.no_ddp_bb) # NOTE: EMA model does not need to be wrapped by DDP - # setup learning rate schedule and starting epoch - lr_scheduler, num_epochs = create_scheduler(args, optimizer) - start_epoch = 0 - if args.start_epoch is not None: - # a specified start_epoch will always override the resume epoch - start_epoch = args.start_epoch - elif resume_epoch is not None: - start_epoch = resume_epoch - if lr_scheduler is not None and start_epoch > 0: - lr_scheduler.step(start_epoch) - - if args.local_rank == 0: - _logger.info('Scheduled epochs: {}'.format(num_epochs)) - # create the train and eval datasets dataset_train = create_dataset( - args.dataset, root=args.data_dir, split=args.train_split, is_training=True, + args.dataset, + root=args.data_dir, + split=args.train_split, + is_training=True, class_map=args.class_map, download=args.dataset_download, batch_size=args.batch_size, - repeats=args.epoch_repeats) + seed=args.seed, + repeats=args.epoch_repeats, + ) + dataset_eval = create_dataset( - args.dataset, root=args.data_dir, split=args.val_split, is_training=False, + args.dataset, + root=args.data_dir, + split=args.val_split, + is_training=False, class_map=args.class_map, download=args.dataset_download, - batch_size=args.batch_size) + batch_size=args.batch_size, + ) # setup mixup / cutmix collate_fn = None @@ -549,9 +575,15 @@ def main(): mixup_active = args.mixup > 0 or args.cutmix > 0. 
or args.cutmix_minmax is not None if mixup_active: mixup_args = dict( - mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax, - prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode, - label_smoothing=args.smoothing, num_classes=args.num_classes) + mixup_alpha=args.mixup, + cutmix_alpha=args.cutmix, + cutmix_minmax=args.cutmix_minmax, + prob=args.mixup_prob, + switch_prob=args.mixup_switch_prob, + mode=args.mixup_mode, + label_smoothing=args.smoothing, + num_classes=args.num_classes + ) if args.prefetcher: assert not num_aug_splits # collate conflict (need to support deinterleaving in collate mixup) collate_fn = FastCollateMixup(**mixup_args) @@ -592,10 +624,15 @@ def main(): distributed=args.distributed, collate_fn=collate_fn, pin_memory=args.pin_mem, + device=device, use_multi_epochs_loader=args.use_multi_epochs_loader, worker_seeding=args.worker_seeding, ) + eval_workers = args.workers + if args.distributed and ('tfds' in args.dataset or 'wds' in args.dataset): + # FIXME reduces validation padding issues when using TFDS, WDS w/ workers and distributed training + eval_workers = min(2, args.workers) loader_eval = create_loader( dataset_eval, input_size=data_config['input_size'], @@ -605,10 +642,11 @@ interpolation=data_config['interpolation'], mean=data_config['mean'], std=data_config['std'], - num_workers=args.workers, + num_workers=eval_workers, distributed=args.distributed, crop_pct=data_config['crop_pct'], pin_memory=args.pin_mem, + device=device, ) # setup loss function @@ -628,8 +666,8 @@ def main(): train_loss_fn = LabelSmoothingCrossEntropy(smoothing=args.smoothing) else: train_loss_fn = nn.CrossEntropyLoss() - train_loss_fn = train_loss_fn.cuda() - validate_loss_fn = nn.CrossEntropyLoss().cuda() + train_loss_fn = train_loss_fn.to(device=device) + validate_loss_fn = nn.CrossEntropyLoss().to(device=device) # setup checkpoint saver and eval metric tracking eval_metric = args.eval_metric @@ -637,7 +675,7 @@ def main(): best_epoch = None saver = None output_dir = None - if args.rank == 0: + if utils.is_primary(args): if args.experiment: exp_name = args.experiment else: @@ -649,60 +687,136 @@ def main(): output_dir = utils.get_outdir(args.output if args.output else './output/train', exp_name) decreasing = True if eval_metric == 'loss' else False saver = utils.CheckpointSaver( - model=model, optimizer=optimizer, args=args, model_ema=model_ema, amp_scaler=loss_scaler, - checkpoint_dir=output_dir, recovery_dir=output_dir, decreasing=decreasing, max_history=args.checkpoint_hist) + model=model, + optimizer=optimizer, + args=args, + model_ema=model_ema, + amp_scaler=loss_scaler, + checkpoint_dir=output_dir, + recovery_dir=output_dir, + decreasing=decreasing, + max_history=args.checkpoint_hist + ) with open(os.path.join(output_dir, 'args.yaml'), 'w') as f: f.write(args_text) + # setup learning rate schedule and starting epoch + updates_per_epoch = len(loader_train) + lr_scheduler, num_epochs = create_scheduler_v2( + optimizer, + **scheduler_kwargs(args), + updates_per_epoch=updates_per_epoch, + ) + start_epoch = 0 + if args.start_epoch is not None: + # a specified start_epoch will always override the resume epoch + start_epoch = args.start_epoch + elif resume_epoch is not None: + start_epoch = resume_epoch + if lr_scheduler is not None and start_epoch > 0: + if args.sched_on_updates: + lr_scheduler.step_update(start_epoch * updates_per_epoch) + else: + lr_scheduler.step(start_epoch) + + if utils.is_primary(args): +
_logger.info( + f'Scheduled epochs: {num_epochs}. LR stepped per {"epoch" if lr_scheduler.t_in_epochs else "update"}.') + try: for epoch in range(start_epoch, num_epochs): - if args.distributed and hasattr(loader_train.sampler, 'set_epoch'): + if hasattr(dataset_train, 'set_epoch'): + dataset_train.set_epoch(epoch) + elif args.distributed and hasattr(loader_train.sampler, 'set_epoch'): loader_train.sampler.set_epoch(epoch) train_metrics = train_one_epoch( - epoch, model, loader_train, optimizer, train_loss_fn, args, - lr_scheduler=lr_scheduler, saver=saver, output_dir=output_dir, - amp_autocast=amp_autocast, loss_scaler=loss_scaler, model_ema=model_ema, mixup_fn=mixup_fn) + epoch, + model, + loader_train, + optimizer, + train_loss_fn, + args, + lr_scheduler=lr_scheduler, + saver=saver, + output_dir=output_dir, + amp_autocast=amp_autocast, + loss_scaler=loss_scaler, + model_ema=model_ema, + mixup_fn=mixup_fn, + ) if args.distributed and args.dist_bn in ('broadcast', 'reduce'): - if args.local_rank == 0: + if utils.is_primary(args): _logger.info("Distributing BatchNorm running means and vars") utils.distribute_bn(model, args.world_size, args.dist_bn == 'reduce') - eval_metrics = validate(model, loader_eval, validate_loss_fn, args, amp_autocast=amp_autocast) + eval_metrics = validate( + model, + loader_eval, + validate_loss_fn, + args, + amp_autocast=amp_autocast, + ) if model_ema is not None and not args.model_ema_force_cpu: if args.distributed and args.dist_bn in ('broadcast', 'reduce'): utils.distribute_bn(model_ema, args.world_size, args.dist_bn == 'reduce') + ema_eval_metrics = validate( - model_ema.module, loader_eval, validate_loss_fn, args, amp_autocast=amp_autocast, log_suffix=' (EMA)') + model_ema.module, + loader_eval, + validate_loss_fn, + args, + amp_autocast=amp_autocast, + log_suffix=' (EMA)', + ) eval_metrics = ema_eval_metrics - if lr_scheduler is not None: - # step LR for next epoch - lr_scheduler.step(epoch + 1, eval_metrics[eval_metric]) - if output_dir is not None: + lrs = [param_group['lr'] for param_group in optimizer.param_groups] utils.update_summary( - epoch, train_metrics, eval_metrics, os.path.join(output_dir, 'summary.csv'), - write_header=best_metric is None, log_wandb=args.log_wandb and has_wandb) + epoch, + train_metrics, + eval_metrics, + filename=os.path.join(output_dir, 'summary.csv'), + lr=sum(lrs) / len(lrs), + write_header=best_metric is None, + log_wandb=args.log_wandb and has_wandb, + ) if saver is not None: # save proper checkpoint with eval metric save_metric = eval_metrics[eval_metric] best_metric, best_epoch = saver.save_checkpoint(epoch, metric=save_metric) + if lr_scheduler is not None: + # step LR for next epoch + lr_scheduler.step(epoch + 1, eval_metrics[eval_metric]) + except KeyboardInterrupt: pass + if best_metric is not None: _logger.info('*** Best metric: {0} (epoch {1})'.format(best_metric, best_epoch)) def train_one_epoch( - epoch, model, loader, optimizer, loss_fn, args, - lr_scheduler=None, saver=None, output_dir=None, amp_autocast=suppress, - loss_scaler=None, model_ema=None, mixup_fn=None): - + epoch, + model, + loader, + optimizer, + loss_fn, + args, + device=torch.device('cuda'), + lr_scheduler=None, + saver=None, + output_dir=None, + amp_autocast=suppress, + loss_scaler=None, + model_ema=None, + mixup_fn=None +): if args.mixup_off_epoch and epoch >= args.mixup_off_epoch: if args.prefetcher and loader.mixup_enabled: loader.mixup_enabled = False @@ -717,13 +831,14 @@ def train_one_epoch( model.train() end = time.time() - last_idx = 
-    num_updates = epoch * len(loader)
+    num_batches_per_epoch = len(loader)
+    last_idx = num_batches_per_epoch - 1
+    num_updates = epoch * num_batches_per_epoch
     for batch_idx, (input, target) in enumerate(loader):
         last_batch = batch_idx == last_idx
         data_time_m.update(time.time() - end)
         if not args.prefetcher:
-            input, target = input.cuda(), target.cuda()
+            input, target = input.to(device), target.to(device)
             if mixup_fn is not None:
                 input, target = mixup_fn(input, target)
         if args.channels_last:
@@ -740,21 +855,26 @@ def train_one_epoch(
         if loss_scaler is not None:
             loss_scaler(
                 loss, optimizer,
-                clip_grad=args.clip_grad, clip_mode=args.clip_mode,
+                clip_grad=args.clip_grad,
+                clip_mode=args.clip_mode,
                 parameters=model_parameters(model, exclude_head='agc' in args.clip_mode),
-                create_graph=second_order)
+                create_graph=second_order
+            )
         else:
             loss.backward(create_graph=second_order)
             if args.clip_grad is not None:
                 utils.dispatch_clip_grad(
                     model_parameters(model, exclude_head='agc' in args.clip_mode),
-                    value=args.clip_grad, mode=args.clip_mode)
+                    value=args.clip_grad,
+                    mode=args.clip_mode
+                )
             optimizer.step()

         if model_ema is not None:
             model_ema.update(model)

         torch.cuda.synchronize()
+
         num_updates += 1
         batch_time_m.update(time.time() - end)
         if last_batch or batch_idx % args.log_interval == 0:
@@ -765,7 +885,7 @@ def train_one_epoch(
                 reduced_loss = utils.reduce_tensor(loss.data, args.world_size)
                 losses_m.update(reduced_loss.item(), input.size(0))

-            if args.local_rank == 0:
+            if utils.is_primary(args):
                 _logger.info(
                     'Train: {} [{:>4d}/{} ({:>3.0f}%)] '
                     'Loss: {loss.val:#.4g} ({loss.avg:#.3g}) '
@@ -781,14 +901,16 @@ def train_one_epoch(
                         rate=input.size(0) * args.world_size / batch_time_m.val,
                         rate_avg=input.size(0) * args.world_size / batch_time_m.avg,
                         lr=lr,
-                        data_time=data_time_m))
+                        data_time=data_time_m)
+                )

                 if args.save_images and output_dir:
                     torchvision.utils.save_image(
                         input,
                         os.path.join(output_dir, 'train-batch-%d.jpg' % batch_idx),
                         padding=0,
-                        normalize=True)
+                        normalize=True
+                    )

         if saver is not None and args.recovery_interval and (
                 last_batch or (batch_idx + 1) % args.recovery_interval == 0):
@@ -806,7 +928,15 @@ def train_one_epoch(
     return OrderedDict([('loss', losses_m.avg)])


-def validate(model, loader, loss_fn, args, amp_autocast=suppress, log_suffix=''):
+def validate(
+        model,
+        loader,
+        loss_fn,
+        args,
+        device=torch.device('cuda'),
+        amp_autocast=suppress,
+        log_suffix=''
+):
     batch_time_m = utils.AverageMeter()
     losses_m = utils.AverageMeter()
     top1_m = utils.AverageMeter()
@@ -820,8 +950,8 @@ def validate(model, loader, loss_fn, args, amp_autocast=suppress, log_suffix=''):
         for batch_idx, (input, target) in enumerate(loader):
             last_batch = batch_idx == last_idx
             if not args.prefetcher:
-                input = input.cuda()
-                target = target.cuda()
+                input = input.to(device)
+                target = target.to(device)
             if args.channels_last:
                 input = input.contiguous(memory_format=torch.channels_last)
@@ -846,7 +976,8 @@ def validate(model, loader, loss_fn, args, amp_autocast=suppress, log_suffix=''):
             else:
                 reduced_loss = loss.data

-            torch.cuda.synchronize()
+            if device.type == 'cuda':
+                torch.cuda.synchronize()

             losses_m.update(reduced_loss.item(), input.size(0))
             top1_m.update(acc1.item(), output.size(0))
@@ -854,7 +985,7 @@ def validate(model, loader, loss_fn, args, amp_autocast=suppress, log_suffix=''):
             batch_time_m.update(time.time() - end)
             end = time.time()
-            if args.local_rank == 0 and (last_batch or batch_idx % args.log_interval == 0):
+            if utils.is_primary(args) and (last_batch or batch_idx % args.log_interval == 0):
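`train_one_epoch` now takes a `device` argument and keeps its two backward paths: `loss_scaler(...)` when a native-AMP scaler is active, and a plain `loss.backward()` with optional clipping otherwise. A minimal sketch of the scale → backward → unscale → clip → step sequence that a native-AMP loss scaler wraps (standard PyTorch API, not timm's exact implementation; runs on CPU too because the scaler and autocast are disabled there):

```python
import torch
import torch.nn.functional as F

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = torch.nn.Linear(32, 10).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = torch.cuda.amp.GradScaler(enabled=device.type == 'cuda')
clip_grad = 1.0

def train_step(input, target):
    optimizer.zero_grad()
    with torch.autocast(device_type=device.type, enabled=device.type == 'cuda'):
        loss = F.cross_entropy(model(input), target)
    scaler.scale(loss).backward()        # backward on the (possibly) scaled loss
    scaler.unscale_(optimizer)           # unscale before measuring gradient norms
    torch.nn.utils.clip_grad_norm_(model.parameters(), clip_grad)
    scaler.step(optimizer)               # skipped internally if inf/nan grads were found
    scaler.update()
    return loss.detach()

x = torch.randn(8, 32, device=device)
y = torch.randint(0, 10, (8,), device=device)
print(train_step(x, y).item())
```

Unscaling before clipping matters because gradient norms would otherwise be measured in the scaled space, making the clip threshold meaningless.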
                log_name = 'Test' + log_suffix
                _logger.info(
                    '{0}: [{1:>4d}/{2}] '
@@ -862,8 +993,12 @@ def validate(model, loader, loss_fn, args, amp_autocast=suppress, log_suffix=''):
                    'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f}) '
                    'Acc@1: {top1.val:>7.4f} ({top1.avg:>7.4f}) '
                    'Acc@5: {top5.val:>7.4f} ({top5.avg:>7.4f})'.format(
-                        log_name, batch_idx, last_idx, batch_time=batch_time_m,
-                        loss=losses_m, top1=top1_m, top5=top5_m))
+                        log_name, batch_idx, last_idx,
+                        batch_time=batch_time_m,
+                        loss=losses_m,
+                        top1=top1_m,
+                        top5=top5_m)
+                )

     metrics = OrderedDict([('loss', losses_m.avg), ('top1', top1_m.avg), ('top5', top5_m.avg)])

diff --git a/validate.py b/validate.py
index 6244f052..1a1ea9cd 100755
--- a/validate.py
+++ b/validate.py
@@ -19,6 +19,7 @@ import torch.nn as nn
 import torch.nn.parallel
 from collections import OrderedDict
 from contextlib import suppress
+from functools import partial

 from timm.models import create_model, apply_test_time_pool, load_checkpoint, is_model, list_models, set_fast_norm
 from timm.data import create_dataset, create_loader, resolve_data_config, RealLabelsImagenet
@@ -45,7 +46,6 @@ try:
 except ImportError as e:
     has_functorch = False

-torch.backends.cudnn.benchmark = True

 _logger = logging.getLogger('validate')

@@ -100,12 +100,14 @@ parser.add_argument('--pin-mem', action='store_true', default=False,
                     help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
 parser.add_argument('--channels-last', action='store_true', default=False,
                     help='Use channels_last memory layout')
+parser.add_argument('--device', default='cuda', type=str,
+                    help="Device (accelerator) to use.")
 parser.add_argument('--amp', action='store_true', default=False,
-                    help='Use AMP mixed precision. Defaults to Apex, fallback to native Torch AMP.')
-parser.add_argument('--apex-amp', action='store_true', default=False,
-                    help='Use NVIDIA Apex AMP mixed precision')
-parser.add_argument('--native-amp', action='store_true', default=False,
-                    help='Use Native Torch AMP mixed precision')
+                    help='use NVIDIA Apex AMP or Native AMP for mixed precision training')
+parser.add_argument('--amp-dtype', default='float16', type=str,
+                    help='lower precision AMP dtype (default: float16)')
+parser.add_argument('--amp-impl', default='native', type=str,
+                    help='AMP impl to use, "native" or "apex" (default: native)')
 parser.add_argument('--tf-preprocessing', action='store_true', default=False,
                     help='Use Tensorflow preprocessing pipeline (require CPU TF installed')
 parser.add_argument('--use-ema', dest='use_ema', action='store_true',
@@ -133,25 +135,35 @@ def validate(args):
     # might as well try to validate something
     args.pretrained = args.pretrained or not args.checkpoint
     args.prefetcher = not args.no_prefetcher
-    amp_autocast = suppress  # do nothing
+
+    if torch.cuda.is_available():
+        torch.backends.cuda.matmul.allow_tf32 = True
+        torch.backends.cudnn.benchmark = True
+
+    device = torch.device(args.device)
+
+    # resolve AMP arguments based on PyTorch / Apex availability
+    use_amp = None
+    amp_autocast = suppress
     if args.amp:
-        if has_native_amp:
-            args.native_amp = True
-        elif has_apex:
-            args.apex_amp = True
+        if args.amp_impl == 'apex':
+            assert has_apex, 'AMP impl specified as APEX but APEX is not installed.'
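validate.py drops the separate `--apex-amp` / `--native-amp` switches in favour of `--amp-impl` and `--amp-dtype`, and builds the autocast context once via `functools.partial` over the device-generic `torch.autocast`. A small sketch of that resolution step, assuming a PyTorch version that ships `torch.autocast` (the `resolve_autocast` name is illustrative, not from the patch):

```python
from contextlib import suppress
from functools import partial

import torch

def resolve_autocast(device: torch.device, amp: bool, amp_dtype: str = 'float16'):
    # Mirrors the --amp / --amp-dtype handling above: a no-op context when AMP is off,
    # otherwise a torch.autocast partial bound to the target device type and dtype.
    if not amp:
        return suppress
    dtype = torch.bfloat16 if amp_dtype == 'bfloat16' else torch.float16
    return partial(torch.autocast, device_type=device.type, dtype=dtype)

device = torch.device('cpu')
amp_autocast = resolve_autocast(device, amp=True, amp_dtype='bfloat16')
with amp_autocast():
    y = torch.nn.Linear(8, 4)(torch.randn(2, 8))
print(y.dtype)  # torch.bfloat16 under CPU autocast
```

Returning either `suppress` or an autocast partial lets the rest of the script use `with amp_autocast():` unconditionally, which is exactly how the patched code consumes it.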
+            assert args.amp_dtype == 'float16'
+            use_amp = 'apex'
+            _logger.info('Validating in mixed precision with NVIDIA APEX AMP.')
         else:
-            _logger.warning("Neither APEX or Native Torch AMP is available.")
-    assert not args.apex_amp or not args.native_amp, "Only one AMP mode should be set."
-    if args.native_amp:
-        amp_autocast = torch.cuda.amp.autocast
-        _logger.info('Validating in mixed precision with native PyTorch AMP.')
-    elif args.apex_amp:
-        _logger.info('Validating in mixed precision with NVIDIA APEX AMP.')
+            assert has_native_amp, 'Please update PyTorch to a version with native AMP (or use APEX).'
+            assert args.amp_dtype in ('float16', 'bfloat16')
+            use_amp = 'native'
+            amp_dtype = torch.bfloat16 if args.amp_dtype == 'bfloat16' else torch.float16
+            amp_autocast = partial(torch.autocast, device_type=device.type, dtype=amp_dtype)
+            _logger.info('Validating in mixed precision with native PyTorch AMP.')
     else:
         _logger.info('Validating in float32. AMP not enabled.')

     if args.fuser:
         set_jit_fuser(args.fuser)
+
     if args.fast_norm:
         set_fast_norm()

@@ -162,7 +174,8 @@ def validate(args):
         num_classes=args.num_classes,
         in_chans=3,
         global_pool=args.gp,
-        scriptable=args.torchscript)
+        scriptable=args.torchscript,
+    )
     if args.num_classes is None:
         assert hasattr(model, 'num_classes'), 'Model must have `num_classes` attr if not set on cmd line/config.'
         args.num_classes = model.num_classes
@@ -177,7 +190,7 @@ def validate(args):
         vars(args),
         model=model,
         use_test_size=not args.use_train_size,
-        verbose=True
+        verbose=True,
     )
     test_time_pool = False
     if args.test_pool:
@@ -186,12 +199,13 @@ def validate(args):
     if args.torchscript:
         torch.jit.optimized_execution(True)
         model = torch.jit.script(model)
+
     if args.aot_autograd:
         assert has_functorch, "functorch is needed for --aot-autograd"
         model = memory_efficient_fusion(model)

-    model = model.cuda()
-    if args.apex_amp:
+    model = model.to(device)
+    if use_amp == 'apex':
         model = amp.initialize(model, opt_level='O1')

     if args.channels_last:
@@ -200,11 +214,16 @@ def validate(args):
     if args.num_gpu > 1:
         model = torch.nn.DataParallel(model, device_ids=list(range(args.num_gpu)))

-    criterion = nn.CrossEntropyLoss().cuda()
+    criterion = nn.CrossEntropyLoss().to(device)

     dataset = create_dataset(
-        root=args.data, name=args.dataset, split=args.split,
-        download=args.dataset_download, load_bytes=args.tf_preprocessing, class_map=args.class_map)
+        root=args.data,
+        name=args.dataset,
+        split=args.split,
+        download=args.dataset_download,
+        load_bytes=args.tf_preprocessing,
+        class_map=args.class_map,
+    )

     if args.valid_labels:
         with open(args.valid_labels, 'r') as f:
@@ -230,7 +249,9 @@ def validate(args):
         num_workers=args.workers,
         crop_pct=crop_pct,
         pin_memory=args.pin_mem,
-        tf_preprocessing=args.tf_preprocessing)
+        device=device,
+        tf_preprocessing=args.tf_preprocessing,
+    )

     batch_time = AverageMeter()
     losses = AverageMeter()
@@ -240,7 +261,7 @@ def validate(args):
     model.eval()
     with torch.no_grad():
         # warmup, reduce variability of first batch time, especially for comparing torchscript vs non
-        input = torch.randn((args.batch_size,) + tuple(data_config['input_size'])).cuda()
+        input = torch.randn((args.batch_size,) + tuple(data_config['input_size'])).to(device)
         if args.channels_last:
             input = input.contiguous(memory_format=torch.channels_last)
         with amp_autocast():
@@ -249,8 +270,8 @@ def validate(args):
         end = time.time()
         for batch_idx, (input, target) in enumerate(loader):
             if args.no_prefetcher:
-                target = target.cuda()
-                input = input.cuda()
+                target = target.to(device)
+                input = input.to(device)
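The remaining hard-coded `.cuda()` calls become `.to(device)`, with the device resolved once from `--device` and also passed to `create_loader`. A device-agnostic placement sketch with a toy model (not the script's real model), showing the same pattern of moving model, criterion, and batch to one resolved device, with optional channels_last layout:

```python
import torch
import torch.nn as nn

# Resolve the target device once, e.g. from a '--device' style string.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = nn.Sequential(
    nn.Conv2d(3, 8, 3), nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(8, 10),
).to(device)
criterion = nn.CrossEntropyLoss().to(device)

input = torch.randn(4, 3, 32, 32, device=device)
target = torch.randint(0, 10, (4,), device=device)

# Optional channels_last layout, mirroring args.channels_last above.
input = input.contiguous(memory_format=torch.channels_last)

with torch.no_grad():
    print(criterion(model(input), target).item())
```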
             if args.channels_last:
                 input = input.contiguous(memory_format=torch.channels_last)
@@ -282,9 +303,15 @@ def validate(args):
                     'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f}) '
                     'Acc@1: {top1.val:>7.3f} ({top1.avg:>7.3f}) '
                     'Acc@5: {top5.val:>7.3f} ({top5.avg:>7.3f})'.format(
-                        batch_idx, len(loader), batch_time=batch_time,
+                        batch_idx,
+                        len(loader),
+                        batch_time=batch_time,
                         rate_avg=input.size(0) / batch_time.avg,
-                        loss=losses, top1=top1, top5=top5))
+                        loss=losses,
+                        top1=top1,
+                        top5=top5
+                    )
+                )

     if real_labels is not None:
         # real labels mode replaces topk values at the end
@@ -298,7 +325,8 @@ def validate(args):
         param_count=round(param_count / 1e6, 2),
         img_size=data_config['input_size'][-1],
         crop_pct=crop_pct,
-        interpolation=data_config['interpolation'])
+        interpolation=data_config['interpolation'],
+    )

     _logger.info(' * Acc@1 {:.3f} ({:.3f}) Acc@5 {:.3f} ({:.3f})'.format(
         results['top1'], results['top1_err'], results['top5'], results['top5_err']))

@@ -313,7 +341,8 @@ def _try_run(args, initial_batch_size):
     while batch_size:
         args.batch_size = batch_size * args.num_gpu  # multiply by num-gpu for DataParallel case
         try:
-            torch.cuda.empty_cache()
+            if torch.cuda.is_available() and 'cuda' in args.device:
+                torch.cuda.empty_cache()
             results = validate(args)
             return results
         except RuntimeError as e:
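`_try_run` now guards `torch.cuda.empty_cache()` so the decreasing-batch-size retry loop also works when `--device` is not a CUDA device. A hedged sketch of that retry-on-failure pattern; the helper name and the halving policy here are illustrative, while the real `_try_run` re-invokes `validate(args)` with a reduced `args.batch_size`:

```python
import torch

def try_run(run_fn, initial_batch_size: int):
    """Retry run_fn with smaller batch sizes until it fits in memory."""
    batch_size = initial_batch_size
    while batch_size:
        try:
            if torch.cuda.is_available():
                torch.cuda.empty_cache()  # release cached blocks before retrying
            return run_fn(batch_size)
        except RuntimeError as e:
            if 'out of memory' not in str(e).lower():
                raise  # only swallow OOM-style failures
            batch_size //= 2
            print(f'Reducing batch size to {batch_size} after OOM')
    raise RuntimeError('run failed even at batch size 1')
```

Catching `RuntimeError` rather than a dedicated OOM exception matches how CUDA out-of-memory errors surface in PyTorch, which is why the string check on the message is needed.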