Merge pull request #323 from rwightman/imagenet21k_datasets_more
BiT (Big Transfer) ResNetV2 models, Official ViT Hybrid R50 weights, VIT IN21K weights updated w/ repr layer, ImageNet21k and dataset / parser refactorpull/401/head
commit
9a38416fbd
File diff suppressed because it is too large
Load Diff
@ -1,10 +1,12 @@
|
||||
from .constants import *
|
||||
from .auto_augment import RandAugment, AutoAugment, rand_augment_ops, auto_augment_policy,\
|
||||
rand_augment_transform, auto_augment_transform
|
||||
from .config import resolve_data_config
|
||||
from .dataset import Dataset, DatasetTar, AugMixDataset
|
||||
from .transforms import *
|
||||
from .constants import *
|
||||
from .dataset import ImageDataset, IterableImageDataset, AugMixDataset
|
||||
from .dataset_factory import create_dataset
|
||||
from .loader import create_loader
|
||||
from .transforms_factory import create_transform
|
||||
from .mixup import Mixup, FastCollateMixup
|
||||
from .auto_augment import RandAugment, AutoAugment, rand_augment_ops, auto_augment_policy,\
|
||||
rand_augment_transform, auto_augment_transform
|
||||
from .parsers import create_parser
|
||||
from .real_labels import RealLabelsImagenet
|
||||
from .transforms import *
|
||||
from .transforms_factory import create_transform
|
@ -0,0 +1,29 @@
|
||||
import os
|
||||
|
||||
from .dataset import IterableImageDataset, ImageDataset
|
||||
|
||||
|
||||
def _search_split(root, split):
|
||||
# look for sub-folder with name of split in root and use that if it exists
|
||||
split_name = split.split('[')[0]
|
||||
try_root = os.path.join(root, split_name)
|
||||
if os.path.exists(try_root):
|
||||
return try_root
|
||||
if split_name == 'validation':
|
||||
try_root = os.path.join(root, 'val')
|
||||
if os.path.exists(try_root):
|
||||
return try_root
|
||||
return root
|
||||
|
||||
|
||||
def create_dataset(name, root, split='validation', search_split=True, is_training=False, batch_size=None, **kwargs):
|
||||
name = name.lower()
|
||||
if name.startswith('tfds'):
|
||||
ds = IterableImageDataset(
|
||||
root, parser=name, split=split, is_training=is_training, batch_size=batch_size, **kwargs)
|
||||
else:
|
||||
# FIXME support more advance split cfg for ImageFolder/Tar datasets in the future
|
||||
if search_split and os.path.isdir(root):
|
||||
root = _search_split(root, split)
|
||||
ds = ImageDataset(root, parser=name, **kwargs)
|
||||
return ds
|
@ -0,0 +1 @@
|
||||
from .parser_factory import create_parser
|
@ -0,0 +1,16 @@
|
||||
import os
|
||||
|
||||
|
||||
def load_class_map(filename, root=''):
|
||||
class_map_path = filename
|
||||
if not os.path.exists(class_map_path):
|
||||
class_map_path = os.path.join(root, filename)
|
||||
assert os.path.exists(class_map_path), 'Cannot locate specified class map file (%s)' % filename
|
||||
class_map_ext = os.path.splitext(filename)[-1].lower()
|
||||
if class_map_ext == '.txt':
|
||||
with open(class_map_path) as f:
|
||||
class_to_idx = {v.strip(): k for k, v in enumerate(f)}
|
||||
else:
|
||||
assert False, 'Unsupported class map extension'
|
||||
return class_to_idx
|
||||
|
@ -0,0 +1 @@
|
||||
IMG_EXTENSIONS = ('.png', '.jpg', '.jpeg')
|
@ -0,0 +1,17 @@
|
||||
from abc import abstractmethod
|
||||
|
||||
|
||||
class Parser:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def _filename(self, index, basename=False, absolute=False):
|
||||
pass
|
||||
|
||||
def filename(self, index, basename=False, absolute=False):
|
||||
return self._filename(index, basename=basename, absolute=absolute)
|
||||
|
||||
def filenames(self, basename=False, absolute=False):
|
||||
return [self._filename(index, basename=basename, absolute=absolute) for index in range(len(self))]
|
||||
|
@ -0,0 +1,29 @@
|
||||
import os
|
||||
|
||||
from .parser_image_folder import ParserImageFolder
|
||||
from .parser_image_tar import ParserImageTar
|
||||
from .parser_image_in_tar import ParserImageInTar
|
||||
|
||||
|
||||
def create_parser(name, root, split='train', **kwargs):
|
||||
name = name.lower()
|
||||
name = name.split('/', 2)
|
||||
prefix = ''
|
||||
if len(name) > 1:
|
||||
prefix = name[0]
|
||||
name = name[-1]
|
||||
|
||||
# FIXME improve the selection right now just tfds prefix or fallback path, will need options to
|
||||
# explicitly select other options shortly
|
||||
if prefix == 'tfds':
|
||||
from .parser_tfds import ParserTfds # defer tensorflow import
|
||||
parser = ParserTfds(root, name, split=split, shuffle=kwargs.pop('shuffle', False), **kwargs)
|
||||
else:
|
||||
assert os.path.exists(root)
|
||||
# default fallback path (backwards compat), use image tar if root is a .tar file, otherwise image folder
|
||||
# FIXME support split here, in parser?
|
||||
if os.path.isfile(root) and os.path.splitext(root)[1] == '.tar':
|
||||
parser = ParserImageInTar(root, **kwargs)
|
||||
else:
|
||||
parser = ParserImageFolder(root, **kwargs)
|
||||
return parser
|
@ -0,0 +1,69 @@
|
||||
""" A dataset parser that reads images from folders
|
||||
|
||||
Folders are scannerd recursively to find image files. Labels are based
|
||||
on the folder hierarchy, just leaf folders by default.
|
||||
|
||||
Hacked together by / Copyright 2020 Ross Wightman
|
||||
"""
|
||||
import os
|
||||
|
||||
from timm.utils.misc import natural_key
|
||||
|
||||
from .parser import Parser
|
||||
from .class_map import load_class_map
|
||||
from .constants import IMG_EXTENSIONS
|
||||
|
||||
|
||||
def find_images_and_targets(folder, types=IMG_EXTENSIONS, class_to_idx=None, leaf_name_only=True, sort=True):
|
||||
labels = []
|
||||
filenames = []
|
||||
for root, subdirs, files in os.walk(folder, topdown=False, followlinks=True):
|
||||
rel_path = os.path.relpath(root, folder) if (root != folder) else ''
|
||||
label = os.path.basename(rel_path) if leaf_name_only else rel_path.replace(os.path.sep, '_')
|
||||
for f in files:
|
||||
base, ext = os.path.splitext(f)
|
||||
if ext.lower() in types:
|
||||
filenames.append(os.path.join(root, f))
|
||||
labels.append(label)
|
||||
if class_to_idx is None:
|
||||
# building class index
|
||||
unique_labels = set(labels)
|
||||
sorted_labels = list(sorted(unique_labels, key=natural_key))
|
||||
class_to_idx = {c: idx for idx, c in enumerate(sorted_labels)}
|
||||
images_and_targets = [(f, class_to_idx[l]) for f, l in zip(filenames, labels) if l in class_to_idx]
|
||||
if sort:
|
||||
images_and_targets = sorted(images_and_targets, key=lambda k: natural_key(k[0]))
|
||||
return images_and_targets, class_to_idx
|
||||
|
||||
|
||||
class ParserImageFolder(Parser):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
root,
|
||||
class_map=''):
|
||||
super().__init__()
|
||||
|
||||
self.root = root
|
||||
class_to_idx = None
|
||||
if class_map:
|
||||
class_to_idx = load_class_map(class_map, root)
|
||||
self.samples, self.class_to_idx = find_images_and_targets(root, class_to_idx=class_to_idx)
|
||||
if len(self.samples) == 0:
|
||||
raise RuntimeError(
|
||||
f'Found 0 images in subfolders of {root}. Supported image extensions are {", ".join(IMG_EXTENSIONS)}')
|
||||
|
||||
def __getitem__(self, index):
|
||||
path, target = self.samples[index]
|
||||
return open(path, 'rb'), target
|
||||
|
||||
def __len__(self):
|
||||
return len(self.samples)
|
||||
|
||||
def _filename(self, index, basename=False, absolute=False):
|
||||
filename = self.samples[index][0]
|
||||
if basename:
|
||||
filename = os.path.basename(filename)
|
||||
elif not absolute:
|
||||
filename = os.path.relpath(filename, self.root)
|
||||
return filename
|
@ -0,0 +1,222 @@
|
||||
""" A dataset parser that reads tarfile based datasets
|
||||
|
||||
This parser can read and extract image samples from:
|
||||
* a single tar of image files
|
||||
* a folder of multiple tarfiles containing imagefiles
|
||||
* a tar of tars containing image files
|
||||
|
||||
Labels are based on the combined folder and/or tar name structure.
|
||||
|
||||
Hacked together by / Copyright 2020 Ross Wightman
|
||||
"""
|
||||
import os
|
||||
import tarfile
|
||||
import pickle
|
||||
import logging
|
||||
import numpy as np
|
||||
from glob import glob
|
||||
from typing import List, Dict
|
||||
|
||||
from timm.utils.misc import natural_key
|
||||
|
||||
from .parser import Parser
|
||||
from .class_map import load_class_map
|
||||
from .constants import IMG_EXTENSIONS
|
||||
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
CACHE_FILENAME_SUFFIX = '_tarinfos.pickle'
|
||||
|
||||
|
||||
class TarState:
|
||||
|
||||
def __init__(self, tf: tarfile.TarFile = None, ti: tarfile.TarInfo = None):
|
||||
self.tf: tarfile.TarFile = tf
|
||||
self.ti: tarfile.TarInfo = ti
|
||||
self.children: Dict[str, TarState] = {} # child states (tars within tars)
|
||||
|
||||
def reset(self):
|
||||
self.tf = None
|
||||
|
||||
|
||||
def _extract_tarinfo(tf: tarfile.TarFile, parent_info: Dict, extensions=IMG_EXTENSIONS):
|
||||
sample_count = 0
|
||||
for i, ti in enumerate(tf):
|
||||
if not ti.isfile():
|
||||
continue
|
||||
dirname, basename = os.path.split(ti.path)
|
||||
name, ext = os.path.splitext(basename)
|
||||
ext = ext.lower()
|
||||
if ext == '.tar':
|
||||
with tarfile.open(fileobj=tf.extractfile(ti), mode='r|') as ctf:
|
||||
child_info = dict(
|
||||
name=ti.name, path=os.path.join(parent_info['path'], name), ti=ti, children=[], samples=[])
|
||||
sample_count += _extract_tarinfo(ctf, child_info, extensions=extensions)
|
||||
_logger.debug(f'{i}/?. Extracted child tarinfos from {ti.name}. {len(child_info["samples"])} images.')
|
||||
parent_info['children'].append(child_info)
|
||||
elif ext in extensions:
|
||||
parent_info['samples'].append(ti)
|
||||
sample_count += 1
|
||||
return sample_count
|
||||
|
||||
|
||||
def extract_tarinfos(root, class_name_to_idx=None, cache_tarinfo=None, extensions=IMG_EXTENSIONS, sort=True):
|
||||
root_is_tar = False
|
||||
if os.path.isfile(root):
|
||||
assert os.path.splitext(root)[-1].lower() == '.tar'
|
||||
tar_filenames = [root]
|
||||
root, root_name = os.path.split(root)
|
||||
root_name = os.path.splitext(root_name)[0]
|
||||
root_is_tar = True
|
||||
else:
|
||||
root_name = root.strip(os.path.sep).split(os.path.sep)[-1]
|
||||
tar_filenames = glob(os.path.join(root, '*.tar'), recursive=True)
|
||||
num_tars = len(tar_filenames)
|
||||
tar_bytes = sum([os.path.getsize(f) for f in tar_filenames])
|
||||
assert num_tars, f'No .tar files found at specified path ({root}).'
|
||||
|
||||
_logger.info(f'Scanning {tar_bytes/1024**2:.2f}MB of tar files...')
|
||||
info = dict(tartrees=[])
|
||||
cache_path = ''
|
||||
if cache_tarinfo is None:
|
||||
cache_tarinfo = True if tar_bytes > 10*1024**3 else False # FIXME magic number, 10GB
|
||||
if cache_tarinfo:
|
||||
cache_filename = '_' + root_name + CACHE_FILENAME_SUFFIX
|
||||
cache_path = os.path.join(root, cache_filename)
|
||||
if os.path.exists(cache_path):
|
||||
_logger.info(f'Reading tar info from cache file {cache_path}.')
|
||||
with open(cache_path, 'rb') as pf:
|
||||
info = pickle.load(pf)
|
||||
assert len(info['tartrees']) == num_tars, "Cached tartree len doesn't match number of tarfiles"
|
||||
else:
|
||||
for i, fn in enumerate(tar_filenames):
|
||||
path = '' if root_is_tar else os.path.splitext(os.path.basename(fn))[0]
|
||||
with tarfile.open(fn, mode='r|') as tf: # tarinfo scans done in streaming mode
|
||||
parent_info = dict(name=os.path.relpath(fn, root), path=path, ti=None, children=[], samples=[])
|
||||
num_samples = _extract_tarinfo(tf, parent_info, extensions=extensions)
|
||||
num_children = len(parent_info["children"])
|
||||
_logger.debug(
|
||||
f'{i}/{num_tars}. Extracted tarinfos from {fn}. {num_children} children, {num_samples} samples.')
|
||||
info['tartrees'].append(parent_info)
|
||||
if cache_path:
|
||||
_logger.info(f'Writing tar info to cache file {cache_path}.')
|
||||
with open(cache_path, 'wb') as pf:
|
||||
pickle.dump(info, pf)
|
||||
|
||||
samples = []
|
||||
labels = []
|
||||
build_class_map = False
|
||||
if class_name_to_idx is None:
|
||||
build_class_map = True
|
||||
|
||||
# Flatten tartree info into lists of samples and targets w/ targets based on label id via
|
||||
# class map arg or from unique paths.
|
||||
# NOTE: currently only flattening up to two-levels, filesystem .tars and then one level of sub-tar children
|
||||
# this covers my current use cases and keeps things a little easier to test for now.
|
||||
tarfiles = []
|
||||
|
||||
def _label_from_paths(*path, leaf_only=True):
|
||||
path = os.path.join(*path).strip(os.path.sep)
|
||||
return path.split(os.path.sep)[-1] if leaf_only else path.replace(os.path.sep, '_')
|
||||
|
||||
def _add_samples(info, fn):
|
||||
added = 0
|
||||
for s in info['samples']:
|
||||
label = _label_from_paths(info['path'], os.path.dirname(s.path))
|
||||
if not build_class_map and label not in class_name_to_idx:
|
||||
continue
|
||||
samples.append((s, fn, info['ti']))
|
||||
labels.append(label)
|
||||
added += 1
|
||||
return added
|
||||
|
||||
_logger.info(f'Collecting samples and building tar states.')
|
||||
for parent_info in info['tartrees']:
|
||||
# if tartree has children, we assume all samples are at the child level
|
||||
tar_name = None if root_is_tar else parent_info['name']
|
||||
tar_state = TarState()
|
||||
parent_added = 0
|
||||
for child_info in parent_info['children']:
|
||||
child_added = _add_samples(child_info, fn=tar_name)
|
||||
if child_added:
|
||||
tar_state.children[child_info['name']] = TarState(ti=child_info['ti'])
|
||||
parent_added += child_added
|
||||
parent_added += _add_samples(parent_info, fn=tar_name)
|
||||
if parent_added:
|
||||
tarfiles.append((tar_name, tar_state))
|
||||
del info
|
||||
|
||||
if build_class_map:
|
||||
# build class index
|
||||
sorted_labels = list(sorted(set(labels), key=natural_key))
|
||||
class_name_to_idx = {c: idx for idx, c in enumerate(sorted_labels)}
|
||||
|
||||
_logger.info(f'Mapping targets and sorting samples.')
|
||||
samples_and_targets = [(s, class_name_to_idx[l]) for s, l in zip(samples, labels) if l in class_name_to_idx]
|
||||
if sort:
|
||||
samples_and_targets = sorted(samples_and_targets, key=lambda k: natural_key(k[0][0].path))
|
||||
samples, targets = zip(*samples_and_targets)
|
||||
samples = np.array(samples)
|
||||
targets = np.array(targets)
|
||||
_logger.info(f'Finished processing {len(samples)} samples across {len(tarfiles)} tar files.')
|
||||
return samples, targets, class_name_to_idx, tarfiles
|
||||
|
||||
|
||||
class ParserImageInTar(Parser):
|
||||
""" Multi-tarfile dataset parser where there is one .tar file per class
|
||||
"""
|
||||
|
||||
def __init__(self, root, class_map='', cache_tarfiles=True, cache_tarinfo=None):
|
||||
super().__init__()
|
||||
|
||||
class_name_to_idx = None
|
||||
if class_map:
|
||||
class_name_to_idx = load_class_map(class_map, root)
|
||||
self.root = root
|
||||
self.samples, self.targets, self.class_name_to_idx, tarfiles = extract_tarinfos(
|
||||
self.root,
|
||||
class_name_to_idx=class_name_to_idx,
|
||||
cache_tarinfo=cache_tarinfo,
|
||||
extensions=IMG_EXTENSIONS)
|
||||
self.class_idx_to_name = {v: k for k, v in self.class_name_to_idx.items()}
|
||||
if len(tarfiles) == 1 and tarfiles[0][0] is None:
|
||||
self.root_is_tar = True
|
||||
self.tar_state = tarfiles[0][1]
|
||||
else:
|
||||
self.root_is_tar = False
|
||||
self.tar_state = dict(tarfiles)
|
||||
self.cache_tarfiles = cache_tarfiles
|
||||
|
||||
def __len__(self):
|
||||
return len(self.samples)
|
||||
|
||||
def __getitem__(self, index):
|
||||
sample = self.samples[index]
|
||||
target = self.targets[index]
|
||||
sample_ti, parent_fn, child_ti = sample
|
||||
parent_abs = os.path.join(self.root, parent_fn) if parent_fn else self.root
|
||||
|
||||
tf = None
|
||||
cache_state = None
|
||||
if self.cache_tarfiles:
|
||||
cache_state = self.tar_state if self.root_is_tar else self.tar_state[parent_fn]
|
||||
tf = cache_state.tf
|
||||
if tf is None:
|
||||
tf = tarfile.open(parent_abs)
|
||||
if self.cache_tarfiles:
|
||||
cache_state.tf = tf
|
||||
if child_ti is not None:
|
||||
ctf = cache_state.children[child_ti.name].tf if self.cache_tarfiles else None
|
||||
if ctf is None:
|
||||
ctf = tarfile.open(fileobj=tf.extractfile(child_ti))
|
||||
if self.cache_tarfiles:
|
||||
cache_state.children[child_ti.name].tf = ctf
|
||||
tf = ctf
|
||||
|
||||
return tf.extractfile(sample_ti), target
|
||||
|
||||
def _filename(self, index, basename=False, absolute=False):
|
||||
filename = self.samples[index][0].name
|
||||
if basename:
|
||||
filename = os.path.basename(filename)
|
||||
return filename
|
@ -0,0 +1,72 @@
|
||||
""" A dataset parser that reads single tarfile based datasets
|
||||
|
||||
This parser can read datasets consisting if a single tarfile containing images.
|
||||
I am planning to deprecated it in favour of ParerImageInTar.
|
||||
|
||||
Hacked together by / Copyright 2020 Ross Wightman
|
||||
"""
|
||||
import os
|
||||
import tarfile
|
||||
|
||||
from .parser import Parser
|
||||
from .class_map import load_class_map
|
||||
from .constants import IMG_EXTENSIONS
|
||||
from timm.utils.misc import natural_key
|
||||
|
||||
|
||||
def extract_tarinfo(tarfile, class_to_idx=None, sort=True):
|
||||
files = []
|
||||
labels = []
|
||||
for ti in tarfile.getmembers():
|
||||
if not ti.isfile():
|
||||
continue
|
||||
dirname, basename = os.path.split(ti.path)
|
||||
label = os.path.basename(dirname)
|
||||
ext = os.path.splitext(basename)[1]
|
||||
if ext.lower() in IMG_EXTENSIONS:
|
||||
files.append(ti)
|
||||
labels.append(label)
|
||||
if class_to_idx is None:
|
||||
unique_labels = set(labels)
|
||||
sorted_labels = list(sorted(unique_labels, key=natural_key))
|
||||
class_to_idx = {c: idx for idx, c in enumerate(sorted_labels)}
|
||||
tarinfo_and_targets = [(f, class_to_idx[l]) for f, l in zip(files, labels) if l in class_to_idx]
|
||||
if sort:
|
||||
tarinfo_and_targets = sorted(tarinfo_and_targets, key=lambda k: natural_key(k[0].path))
|
||||
return tarinfo_and_targets, class_to_idx
|
||||
|
||||
|
||||
class ParserImageTar(Parser):
|
||||
""" Single tarfile dataset where classes are mapped to folders within tar
|
||||
NOTE: This class is being deprecated in favour of the more capable ParserImageInTar that can
|
||||
operate on folders of tars or tars in tars.
|
||||
"""
|
||||
def __init__(self, root, class_map=''):
|
||||
super().__init__()
|
||||
|
||||
class_to_idx = None
|
||||
if class_map:
|
||||
class_to_idx = load_class_map(class_map, root)
|
||||
assert os.path.isfile(root)
|
||||
self.root = root
|
||||
|
||||
with tarfile.open(root) as tf: # cannot keep this open across processes, reopen later
|
||||
self.samples, self.class_to_idx = extract_tarinfo(tf, class_to_idx)
|
||||
self.imgs = self.samples
|
||||
self.tarfile = None # lazy init in __getitem__
|
||||
|
||||
def __getitem__(self, index):
|
||||
if self.tarfile is None:
|
||||
self.tarfile = tarfile.open(self.root)
|
||||
tarinfo, target = self.samples[index]
|
||||
fileobj = self.tarfile.extractfile(tarinfo)
|
||||
return fileobj, target
|
||||
|
||||
def __len__(self):
|
||||
return len(self.samples)
|
||||
|
||||
def _filename(self, index, basename=False, absolute=False):
|
||||
filename = self.samples[index][0].name
|
||||
if basename:
|
||||
filename = os.path.basename(filename)
|
||||
return filename
|
@ -0,0 +1,201 @@
|
||||
""" Dataset parser interface that wraps TFDS datasets
|
||||
|
||||
Wraps many (most?) TFDS image-classification datasets
|
||||
from https://github.com/tensorflow/datasets
|
||||
https://www.tensorflow.org/datasets/catalog/overview#image_classification
|
||||
|
||||
Hacked together by / Copyright 2020 Ross Wightman
|
||||
"""
|
||||
import os
|
||||
import io
|
||||
import math
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from PIL import Image
|
||||
|
||||
try:
|
||||
import tensorflow as tf
|
||||
tf.config.set_visible_devices([], 'GPU') # Hands off my GPU! (or pip install tensorflow-cpu)
|
||||
import tensorflow_datasets as tfds
|
||||
except ImportError as e:
|
||||
print(e)
|
||||
print("Please install tensorflow_datasets package `pip install tensorflow-datasets`.")
|
||||
exit(1)
|
||||
from .parser import Parser
|
||||
|
||||
|
||||
MAX_TP_SIZE = 8 # maximum TF threadpool size, only doing jpeg decodes and queuing activities
|
||||
SHUFFLE_SIZE = 16834 # samples to shuffle in DS queue
|
||||
PREFETCH_SIZE = 4096 # samples to prefetch
|
||||
|
||||
|
||||
class ParserTfds(Parser):
|
||||
""" Wrap Tensorflow Datasets for use in PyTorch
|
||||
|
||||
There several things to be aware of:
|
||||
* To prevent excessive samples being dropped per epoch w/ distributed training or multiplicity of
|
||||
dataloader workers, the train iterator wraps to avoid returning partial batches that trigger drop_last
|
||||
https://github.com/pytorch/pytorch/issues/33413
|
||||
* With PyTorch IterableDatasets, each worker in each replica operates in isolation, the final batch
|
||||
from each worker could be a different size. For training this is worked around by option above, for
|
||||
validation extra samples are inserted iff distributed mode is enabled so that the batches being reduced
|
||||
across replicas are of same size. This will slightly alter the results, distributed validation will not be
|
||||
100% correct. This is similar to common handling in DistributedSampler for normal Datasets but a bit worse
|
||||
since there are up to N * J extra samples with IterableDatasets.
|
||||
* The sharding (splitting of dataset into TFRecord) files imposes limitations on the number of
|
||||
replicas and dataloader workers you can use. For really small datasets that only contain a few shards
|
||||
you may have to train non-distributed w/ 1-2 dataloader workers. This is likely not a huge concern as the
|
||||
benefit of distributed training or fast dataloading should be much less for small datasets.
|
||||
* This wrapper is currently configured to return individual, decompressed image samples from the TFDS
|
||||
dataset. The augmentation (transforms) and batching is still done in PyTorch. It would be possible
|
||||
to specify TF augmentation fn and return augmented batches w/ some modifications to other downstream
|
||||
components.
|
||||
|
||||
"""
|
||||
def __init__(self, root, name, split='train', shuffle=False, is_training=False, batch_size=None):
|
||||
super().__init__()
|
||||
self.root = root
|
||||
self.split = split
|
||||
self.shuffle = shuffle
|
||||
self.is_training = is_training
|
||||
if self.is_training:
|
||||
assert batch_size is not None,\
|
||||
"Must specify batch_size in training mode for reasonable behaviour w/ TFDS wrapper"
|
||||
self.batch_size = batch_size
|
||||
|
||||
self.builder = tfds.builder(name, data_dir=root)
|
||||
# NOTE: please use tfds command line app to download & prepare datasets, I don't want to call
|
||||
# download_and_prepare() by default here as it's caused issues generating unwanted paths.
|
||||
self.num_samples = self.builder.info.splits[split].num_examples
|
||||
self.ds = None # initialized lazily on each dataloader worker process
|
||||
|
||||
self.worker_info = None
|
||||
self.dist_rank = 0
|
||||
self.dist_num_replicas = 1
|
||||
if dist.is_available() and dist.is_initialized() and dist.get_world_size() > 1:
|
||||
self.dist_rank = dist.get_rank()
|
||||
self.dist_num_replicas = dist.get_world_size()
|
||||
|
||||
def _lazy_init(self):
|
||||
""" Lazily initialize the dataset.
|
||||
|
||||
This is necessary to init the Tensorflow dataset pipeline in the (dataloader) process that
|
||||
will be using the dataset instance. The __init__ method is called on the main process,
|
||||
this will be called in a dataloader worker process.
|
||||
|
||||
NOTE: There will be problems if you try to re-use this dataset across different loader/worker
|
||||
instances once it has been initialized. Do not call any dataset methods that can call _lazy_init
|
||||
before it is passed to dataloader.
|
||||
"""
|
||||
worker_info = torch.utils.data.get_worker_info()
|
||||
|
||||
# setup input context to split dataset across distributed processes
|
||||
split = self.split
|
||||
num_workers = 1
|
||||
if worker_info is not None:
|
||||
self.worker_info = worker_info
|
||||
num_workers = worker_info.num_workers
|
||||
worker_id = worker_info.id
|
||||
|
||||
# FIXME I need to spend more time figuring out the best way to distribute/split data across
|
||||
# combo of distributed replicas + dataloader worker processes
|
||||
"""
|
||||
InputContext will assign subset of underlying TFRecord files to each 'pipeline' if used.
|
||||
My understanding is that using split, the underling TFRecord files will shuffle (shuffle_files=True)
|
||||
between the splits each iteration, but that understanding could be wrong.
|
||||
Possible split options include:
|
||||
* InputContext for both distributed & worker processes (current)
|
||||
* InputContext for distributed and sub-splits for worker processes
|
||||
* sub-splits for both
|
||||
"""
|
||||
# split_size = self.num_samples // num_workers
|
||||
# start = worker_id * split_size
|
||||
# if worker_id == num_workers - 1:
|
||||
# split = split + '[{}:]'.format(start)
|
||||
# else:
|
||||
# split = split + '[{}:{}]'.format(start, start + split_size)
|
||||
|
||||
input_context = tf.distribute.InputContext(
|
||||
num_input_pipelines=self.dist_num_replicas * num_workers,
|
||||
input_pipeline_id=self.dist_rank * num_workers + worker_id,
|
||||
num_replicas_in_sync=self.dist_num_replicas # FIXME does this have any impact?
|
||||
)
|
||||
|
||||
read_config = tfds.ReadConfig(input_context=input_context)
|
||||
ds = self.builder.as_dataset(split=split, shuffle_files=self.shuffle, read_config=read_config)
|
||||
# avoid overloading threading w/ combo fo TF ds threads + PyTorch workers
|
||||
ds.options().experimental_threading.private_threadpool_size = max(1, MAX_TP_SIZE // num_workers)
|
||||
ds.options().experimental_threading.max_intra_op_parallelism = 1
|
||||
if self.is_training:
|
||||
# to prevent excessive drop_last batch behaviour w/ IterableDatasets
|
||||
# see warnings at https://pytorch.org/docs/stable/data.html#multi-process-data-loading
|
||||
ds = ds.repeat() # allow wrap around and break iteration manually
|
||||
if self.shuffle:
|
||||
ds = ds.shuffle(min(self.num_samples // self._num_pipelines, SHUFFLE_SIZE), seed=0)
|
||||
ds = ds.prefetch(min(self.num_samples // self._num_pipelines, PREFETCH_SIZE))
|
||||
self.ds = tfds.as_numpy(ds)
|
||||
|
||||
def __iter__(self):
|
||||
if self.ds is None:
|
||||
self._lazy_init()
|
||||
# compute a rounded up sample count that is used to:
|
||||
# 1. make batches even cross workers & replicas in distributed validation.
|
||||
# This adds extra samples and will slightly alter validation results.
|
||||
# 2. determine loop ending condition in training w/ repeat enabled so that only full batch_size
|
||||
# batches are produced (underlying tfds iter wraps around)
|
||||
target_sample_count = math.ceil(self.num_samples / self._num_pipelines)
|
||||
if self.is_training:
|
||||
# round up to nearest batch_size per worker-replica
|
||||
target_sample_count = math.ceil(target_sample_count / self.batch_size) * self.batch_size
|
||||
sample_count = 0
|
||||
for sample in self.ds:
|
||||
img = Image.fromarray(sample['image'], mode='RGB')
|
||||
yield img, sample['label']
|
||||
sample_count += 1
|
||||
if self.is_training and sample_count >= target_sample_count:
|
||||
# Need to break out of loop when repeat() is enabled for training w/ oversampling
|
||||
# this results in extra samples per epoch but seems more desirable than dropping
|
||||
# up to N*J batches per epoch (where N = num distributed processes, and J = num worker processes)
|
||||
break
|
||||
if not self.is_training and self.dist_num_replicas and 0 < sample_count < target_sample_count:
|
||||
# Validation batch padding only done for distributed training where results are reduced across nodes.
|
||||
# For single process case, it won't matter if workers return different batch sizes.
|
||||
# FIXME this needs more testing, possible for sharding / split api to cause differences of > 1?
|
||||
assert target_sample_count - sample_count == 1 # should only be off by 1 or sharding is not optimal
|
||||
yield img, sample['label'] # yield prev sample again
|
||||
sample_count += 1
|
||||
|
||||
@property
|
||||
def _num_workers(self):
|
||||
return 1 if self.worker_info is None else self.worker_info.num_workers
|
||||
|
||||
@property
|
||||
def _num_pipelines(self):
|
||||
return self._num_workers * self.dist_num_replicas
|
||||
|
||||
def __len__(self):
|
||||
# this is just an estimate and does not factor in extra samples added to pad batches based on
|
||||
# complete worker & replica info (not available until init in dataloader).
|
||||
return math.ceil(self.num_samples / self.dist_num_replicas)
|
||||
|
||||
def _filename(self, index, basename=False, absolute=False):
|
||||
assert False, "Not supported" # no random access to samples
|
||||
|
||||
def filenames(self, basename=False, absolute=False):
|
||||
""" Return all filenames in dataset, overrides base"""
|
||||
if self.ds is None:
|
||||
self._lazy_init()
|
||||
names = []
|
||||
for sample in self.ds:
|
||||
if len(names) > self.num_samples:
|
||||
break # safety for ds.repeat() case
|
||||
if 'file_name' in sample:
|
||||
name = sample['file_name']
|
||||
elif 'filename' in sample:
|
||||
name = sample['filename']
|
||||
elif 'id' in sample:
|
||||
name = sample['id']
|
||||
else:
|
||||
assert False, "No supported name field present"
|
||||
names.append(name)
|
||||
return names
|
@ -0,0 +1,593 @@
|
||||
"""Pre-Activation ResNet v2 with GroupNorm and Weight Standardization.
|
||||
|
||||
A PyTorch implementation of ResNetV2 adapted from the Google Big-Transfoer (BiT) source code
|
||||
at https://github.com/google-research/big_transfer to match timm interfaces. The BiT weights have
|
||||
been included here as pretrained models from their original .NPZ checkpoints.
|
||||
|
||||
Additionally, supports non pre-activation bottleneck for use as a backbone for Vision Transfomers (ViT) and
|
||||
extra padding support to allow porting of official Hybrid ResNet pretrained weights from
|
||||
https://github.com/google-research/vision_transformer
|
||||
|
||||
Thanks to the Google team for the above two repositories and associated papers:
|
||||
* Big Transfer (BiT): General Visual Representation Learning - https://arxiv.org/abs/1912.11370
|
||||
* An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale - https://arxiv.org/abs/2010.11929
|
||||
|
||||
Original copyright of Google code below, modifications by Ross Wightman, Copyright 2020.
|
||||
"""
|
||||
# Copyright 2020 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from collections import OrderedDict # pylint: disable=g-importing-member
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from functools import partial
|
||||
|
||||
from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
|
||||
from .helpers import build_model_with_cfg
|
||||
from .registry import register_model
|
||||
from .layers import get_padding, GroupNormAct, ClassifierHead, DropPath, AvgPool2dSame, create_pool2d, conv2d_same
|
||||
|
||||
|
||||
def _cfg(url='', **kwargs):
|
||||
return {
|
||||
'url': url,
|
||||
'num_classes': 1000, 'input_size': (3, 480, 480), 'pool_size': (7, 7),
|
||||
'crop_pct': 1.0, 'interpolation': 'bilinear',
|
||||
'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD,
|
||||
'first_conv': 'stem.conv', 'classifier': 'head.fc',
|
||||
**kwargs
|
||||
}
|
||||
|
||||
|
||||
default_cfgs = {
|
||||
# pretrained on imagenet21k, finetuned on imagenet1k
|
||||
'resnetv2_50x1_bitm': _cfg(
|
||||
url='https://storage.googleapis.com/bit_models/BiT-M-R50x1-ILSVRC2012.npz'),
|
||||
'resnetv2_50x3_bitm': _cfg(
|
||||
url='https://storage.googleapis.com/bit_models/BiT-M-R50x3-ILSVRC2012.npz'),
|
||||
'resnetv2_101x1_bitm': _cfg(
|
||||
url='https://storage.googleapis.com/bit_models/BiT-M-R101x1-ILSVRC2012.npz'),
|
||||
'resnetv2_101x3_bitm': _cfg(
|
||||
url='https://storage.googleapis.com/bit_models/BiT-M-R101x3-ILSVRC2012.npz'),
|
||||
'resnetv2_152x2_bitm': _cfg(
|
||||
url='https://storage.googleapis.com/bit_models/BiT-M-R152x2-ILSVRC2012.npz'),
|
||||
'resnetv2_152x4_bitm': _cfg(
|
||||
url='https://storage.googleapis.com/bit_models/BiT-M-R152x4-ILSVRC2012.npz'),
|
||||
|
||||
# trained on imagenet-21k
|
||||
'resnetv2_50x1_bitm_in21k': _cfg(
|
||||
url='https://storage.googleapis.com/bit_models/BiT-M-R50x1.npz',
|
||||
num_classes=21843),
|
||||
'resnetv2_50x3_bitm_in21k': _cfg(
|
||||
url='https://storage.googleapis.com/bit_models/BiT-M-R50x3.npz',
|
||||
num_classes=21843),
|
||||
'resnetv2_101x1_bitm_in21k': _cfg(
|
||||
url='https://storage.googleapis.com/bit_models/BiT-M-R101x1.npz',
|
||||
num_classes=21843),
|
||||
'resnetv2_101x3_bitm_in21k': _cfg(
|
||||
url='https://storage.googleapis.com/bit_models/BiT-M-R101x3.npz',
|
||||
num_classes=21843),
|
||||
'resnetv2_152x2_bitm_in21k': _cfg(
|
||||
url='https://storage.googleapis.com/bit_models/BiT-M-R152x2.npz',
|
||||
num_classes=21843),
|
||||
'resnetv2_152x4_bitm_in21k': _cfg(
|
||||
url='https://storage.googleapis.com/bit_models/BiT-M-R152x4.npz',
|
||||
num_classes=21843),
|
||||
|
||||
|
||||
# trained on imagenet-1k, NOTE not overly interesting set of weights, leaving disabled for now
|
||||
# 'resnetv2_50x1_bits': _cfg(
|
||||
# url='https://storage.googleapis.com/bit_models/BiT-S-R50x1.npz'),
|
||||
# 'resnetv2_50x3_bits': _cfg(
|
||||
# url='https://storage.googleapis.com/bit_models/BiT-S-R50x3.npz'),
|
||||
# 'resnetv2_101x1_bits': _cfg(
|
||||
# url='https://storage.googleapis.com/bit_models/BiT-S-R101x3.npz'),
|
||||
# 'resnetv2_101x3_bits': _cfg(
|
||||
# url='https://storage.googleapis.com/bit_models/BiT-S-R101x3.npz'),
|
||||
# 'resnetv2_152x2_bits': _cfg(
|
||||
# url='https://storage.googleapis.com/bit_models/BiT-S-R152x2.npz'),
|
||||
# 'resnetv2_152x4_bits': _cfg(
|
||||
# url='https://storage.googleapis.com/bit_models/BiT-S-R152x4.npz'),
|
||||
}
|
||||
|
||||
|
||||
def make_div(v, divisor=8):
|
||||
min_value = divisor
|
||||
new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
|
||||
if new_v < 0.9 * v:
|
||||
new_v += divisor
|
||||
return new_v
|
||||
|
||||
|
||||
class StdConv2d(nn.Conv2d):
|
||||
|
||||
def __init__(
|
||||
self, in_channel, out_channels, kernel_size, stride=1, dilation=1, bias=False, groups=1, eps=1e-5):
|
||||
padding = get_padding(kernel_size, stride, dilation)
|
||||
super().__init__(
|
||||
in_channel, out_channels, kernel_size, stride=stride,
|
||||
padding=padding, dilation=dilation, bias=bias, groups=groups)
|
||||
self.eps = eps
|
||||
|
||||
def forward(self, x):
|
||||
w = self.weight
|
||||
v, m = torch.var_mean(w, dim=[1, 2, 3], keepdim=True, unbiased=False)
|
||||
w = (w - m) / (torch.sqrt(v) + self.eps)
|
||||
x = F.conv2d(x, w, self.bias, self.stride, self.padding, self.dilation, self.groups)
|
||||
return x
|
||||
|
||||
|
||||
class StdConv2dSame(nn.Conv2d):
|
||||
"""StdConv2d w/ TF compatible SAME padding. Used for ViT Hybrid model.
|
||||
"""
|
||||
def __init__(
|
||||
self, in_channel, out_channels, kernel_size, stride=1, dilation=1, bias=False, groups=1, eps=1e-5):
|
||||
padding = get_padding(kernel_size, stride, dilation)
|
||||
super().__init__(
|
||||
in_channel, out_channels, kernel_size, stride=stride,
|
||||
padding=padding, dilation=dilation, bias=bias, groups=groups)
|
||||
self.eps = eps
|
||||
|
||||
def forward(self, x):
|
||||
w = self.weight
|
||||
v, m = torch.var_mean(w, dim=[1, 2, 3], keepdim=True, unbiased=False)
|
||||
w = (w - m) / (torch.sqrt(v) + self.eps)
|
||||
x = conv2d_same(x, w, self.bias, self.stride, self.padding, self.dilation, self.groups)
|
||||
return x
|
||||
|
||||
|
||||
def tf2th(conv_weights):
|
||||
"""Possibly convert HWIO to OIHW."""
|
||||
if conv_weights.ndim == 4:
|
||||
conv_weights = conv_weights.transpose([3, 2, 0, 1])
|
||||
return torch.from_numpy(conv_weights)
|
||||
|
||||
|
||||
class PreActBottleneck(nn.Module):
|
||||
"""Pre-activation (v2) bottleneck block.
|
||||
|
||||
Follows the implementation of "Identity Mappings in Deep Residual Networks":
|
||||
https://github.com/KaimingHe/resnet-1k-layers/blob/master/resnet-pre-act.lua
|
||||
|
||||
Except it puts the stride on 3x3 conv when available.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, in_chs, out_chs=None, bottle_ratio=0.25, stride=1, dilation=1, first_dilation=None, groups=1,
|
||||
act_layer=None, conv_layer=None, norm_layer=None, proj_layer=None, drop_path_rate=0.):
|
||||
super().__init__()
|
||||
first_dilation = first_dilation or dilation
|
||||
conv_layer = conv_layer or StdConv2d
|
||||
norm_layer = norm_layer or partial(GroupNormAct, num_groups=32)
|
||||
out_chs = out_chs or in_chs
|
||||
mid_chs = make_div(out_chs * bottle_ratio)
|
||||
|
||||
if proj_layer is not None:
|
||||
self.downsample = proj_layer(
|
||||
in_chs, out_chs, stride=stride, dilation=dilation, first_dilation=first_dilation, preact=True,
|
||||
conv_layer=conv_layer, norm_layer=norm_layer)
|
||||
else:
|
||||
self.downsample = None
|
||||
|
||||
self.norm1 = norm_layer(in_chs)
|
||||
self.conv1 = conv_layer(in_chs, mid_chs, 1)
|
||||
self.norm2 = norm_layer(mid_chs)
|
||||
self.conv2 = conv_layer(mid_chs, mid_chs, 3, stride=stride, dilation=first_dilation, groups=groups)
|
||||
self.norm3 = norm_layer(mid_chs)
|
||||
self.conv3 = conv_layer(mid_chs, out_chs, 1)
|
||||
self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0 else nn.Identity()
|
||||
|
||||
def forward(self, x):
|
||||
x_preact = self.norm1(x)
|
||||
|
||||
# shortcut branch
|
||||
shortcut = x
|
||||
if self.downsample is not None:
|
||||
shortcut = self.downsample(x_preact)
|
||||
|
||||
# residual branch
|
||||
x = self.conv1(x_preact)
|
||||
x = self.conv2(self.norm2(x))
|
||||
x = self.conv3(self.norm3(x))
|
||||
x = self.drop_path(x)
|
||||
return x + shortcut
|
||||
|
||||
|
||||
class Bottleneck(nn.Module):
|
||||
"""Non Pre-activation bottleneck block, equiv to V1.5/V1b Bottleneck. Used for ViT.
|
||||
"""
|
||||
def __init__(
|
||||
self, in_chs, out_chs=None, bottle_ratio=0.25, stride=1, dilation=1, first_dilation=None, groups=1,
|
||||
act_layer=None, conv_layer=None, norm_layer=None, proj_layer=None, drop_path_rate=0.):
|
||||
super().__init__()
|
||||
first_dilation = first_dilation or dilation
|
||||
act_layer = act_layer or nn.ReLU
|
||||
conv_layer = conv_layer or StdConv2d
|
||||
norm_layer = norm_layer or partial(GroupNormAct, num_groups=32)
|
||||
out_chs = out_chs or in_chs
|
||||
mid_chs = make_div(out_chs * bottle_ratio)
|
||||
|
||||
if proj_layer is not None:
|
||||
self.downsample = proj_layer(
|
||||
in_chs, out_chs, stride=stride, dilation=dilation, preact=False,
|
||||
conv_layer=conv_layer, norm_layer=norm_layer)
|
||||
else:
|
||||
self.downsample = None
|
||||
|
||||
self.conv1 = conv_layer(in_chs, mid_chs, 1)
|
||||
self.norm1 = norm_layer(mid_chs)
|
||||
self.conv2 = conv_layer(mid_chs, mid_chs, 3, stride=stride, dilation=first_dilation, groups=groups)
|
||||
self.norm2 = norm_layer(mid_chs)
|
||||
self.conv3 = conv_layer(mid_chs, out_chs, 1)
|
||||
self.norm3 = norm_layer(out_chs, apply_act=False)
|
||||
self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0 else nn.Identity()
|
||||
self.act3 = act_layer(inplace=True)
|
||||
|
||||
def forward(self, x):
|
||||
# shortcut branch
|
||||
shortcut = x
|
||||
if self.downsample is not None:
|
||||
shortcut = self.downsample(x)
|
||||
|
||||
# residual
|
||||
x = self.conv1(x)
|
||||
x = self.norm1(x)
|
||||
x = self.conv2(x)
|
||||
x = self.norm2(x)
|
||||
x = self.conv3(x)
|
||||
x = self.norm3(x)
|
||||
x = self.drop_path(x)
|
||||
x = self.act3(x + shortcut)
|
||||
return x
|
||||
|
||||
|
||||
class DownsampleConv(nn.Module):
|
||||
def __init__(
|
||||
self, in_chs, out_chs, stride=1, dilation=1, first_dilation=None, preact=True,
|
||||
conv_layer=None, norm_layer=None):
|
||||
super(DownsampleConv, self).__init__()
|
||||
self.conv = conv_layer(in_chs, out_chs, 1, stride=stride)
|
||||
self.norm = nn.Identity() if preact else norm_layer(out_chs, apply_act=False)
|
||||
|
||||
def forward(self, x):
|
||||
return self.norm(self.conv(x))
|
||||
|
||||
|
||||
class DownsampleAvg(nn.Module):
|
||||
def __init__(
|
||||
self, in_chs, out_chs, stride=1, dilation=1, first_dilation=None,
|
||||
preact=True, conv_layer=None, norm_layer=None):
|
||||
""" AvgPool Downsampling as in 'D' ResNet variants. This is not in RegNet space but I might experiment."""
|
||||
super(DownsampleAvg, self).__init__()
|
||||
avg_stride = stride if dilation == 1 else 1
|
||||
if stride > 1 or dilation > 1:
|
||||
avg_pool_fn = AvgPool2dSame if avg_stride == 1 and dilation > 1 else nn.AvgPool2d
|
||||
self.pool = avg_pool_fn(2, avg_stride, ceil_mode=True, count_include_pad=False)
|
||||
else:
|
||||
self.pool = nn.Identity()
|
||||
self.conv = conv_layer(in_chs, out_chs, 1, stride=1)
|
||||
self.norm = nn.Identity() if preact else norm_layer(out_chs, apply_act=False)
|
||||
|
||||
def forward(self, x):
|
||||
return self.norm(self.conv(self.pool(x)))
|
||||
|
||||
|
||||
class ResNetStage(nn.Module):
|
||||
"""ResNet Stage."""
|
||||
def __init__(self, in_chs, out_chs, stride, dilation, depth, bottle_ratio=0.25, groups=1,
|
||||
avg_down=False, block_dpr=None, block_fn=PreActBottleneck,
|
||||
act_layer=None, conv_layer=None, norm_layer=None, **block_kwargs):
|
||||
super(ResNetStage, self).__init__()
|
||||
first_dilation = 1 if dilation in (1, 2) else 2
|
||||
layer_kwargs = dict(act_layer=act_layer, conv_layer=conv_layer, norm_layer=norm_layer)
|
||||
proj_layer = DownsampleAvg if avg_down else DownsampleConv
|
||||
prev_chs = in_chs
|
||||
self.blocks = nn.Sequential()
|
||||
for block_idx in range(depth):
|
||||
drop_path_rate = block_dpr[block_idx] if block_dpr else 0.
|
||||
stride = stride if block_idx == 0 else 1
|
||||
self.blocks.add_module(str(block_idx), block_fn(
|
||||
prev_chs, out_chs, stride=stride, dilation=dilation, bottle_ratio=bottle_ratio, groups=groups,
|
||||
first_dilation=first_dilation, proj_layer=proj_layer, drop_path_rate=drop_path_rate,
|
||||
**layer_kwargs, **block_kwargs))
|
||||
prev_chs = out_chs
|
||||
first_dilation = dilation
|
||||
proj_layer = None
|
||||
|
||||
def forward(self, x):
|
||||
x = self.blocks(x)
|
||||
return x
|
||||
|
||||
|
||||
def create_stem(in_chs, out_chs, stem_type='', preact=True, conv_layer=None, norm_layer=None):
|
||||
stem = OrderedDict()
|
||||
assert stem_type in ('', 'fixed', 'same', 'deep', 'deep_fixed', 'deep_same')
|
||||
|
||||
# NOTE conv padding mode can be changed by overriding the conv_layer def
|
||||
if 'deep' in stem_type:
|
||||
# A 3 deep 3x3 conv stack as in ResNet V1D models
|
||||
mid_chs = out_chs // 2
|
||||
stem['conv1'] = conv_layer(in_chs, mid_chs, kernel_size=3, stride=2)
|
||||
stem['conv2'] = conv_layer(mid_chs, mid_chs, kernel_size=3, stride=1)
|
||||
stem['conv3'] = conv_layer(mid_chs, out_chs, kernel_size=3, stride=1)
|
||||
else:
|
||||
# The usual 7x7 stem conv
|
||||
stem['conv'] = conv_layer(in_chs, out_chs, kernel_size=7, stride=2)
|
||||
|
||||
if not preact:
|
||||
stem['norm'] = norm_layer(out_chs)
|
||||
|
||||
if 'fixed' in stem_type:
|
||||
# 'fixed' SAME padding approximation that is used in BiT models
|
||||
stem['pad'] = nn.ConstantPad2d(1, 0.)
|
||||
stem['pool'] = nn.MaxPool2d(kernel_size=3, stride=2, padding=0)
|
||||
elif 'same' in stem_type:
|
||||
# full, input size based 'SAME' padding, used in ViT Hybrid model
|
||||
stem['pool'] = create_pool2d('max', kernel_size=3, stride=2, padding='same')
|
||||
else:
|
||||
# the usual PyTorch symmetric padding
|
||||
stem['pool'] = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
|
||||
|
||||
return nn.Sequential(stem)
|
||||
|
||||
|
||||
class ResNetV2(nn.Module):
|
||||
"""Implementation of Pre-activation (v2) ResNet mode.
|
||||
"""
|
||||
|
||||
def __init__(self, layers, channels=(256, 512, 1024, 2048),
|
||||
num_classes=1000, in_chans=3, global_pool='avg', output_stride=32,
|
||||
width_factor=1, stem_chs=64, stem_type='', avg_down=False, preact=True,
|
||||
act_layer=nn.ReLU, conv_layer=StdConv2d, norm_layer=partial(GroupNormAct, num_groups=32),
|
||||
drop_rate=0., drop_path_rate=0.):
|
||||
super().__init__()
|
||||
self.num_classes = num_classes
|
||||
self.drop_rate = drop_rate
|
||||
wf = width_factor
|
||||
|
||||
self.feature_info = []
|
||||
stem_chs = make_div(stem_chs * wf)
|
||||
self.stem = create_stem(in_chans, stem_chs, stem_type, preact, conv_layer=conv_layer, norm_layer=norm_layer)
|
||||
# NOTE no, reduction 2 feature if preact
|
||||
self.feature_info.append(dict(num_chs=stem_chs, reduction=2, module='' if preact else 'stem.norm'))
|
||||
|
||||
prev_chs = stem_chs
|
||||
curr_stride = 4
|
||||
dilation = 1
|
||||
block_dprs = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(layers)).split(layers)]
|
||||
block_fn = PreActBottleneck if preact else Bottleneck
|
||||
self.stages = nn.Sequential()
|
||||
for stage_idx, (d, c, bdpr) in enumerate(zip(layers, channels, block_dprs)):
|
||||
out_chs = make_div(c * wf)
|
||||
stride = 1 if stage_idx == 0 else 2
|
||||
if curr_stride >= output_stride:
|
||||
dilation *= stride
|
||||
stride = 1
|
||||
stage = ResNetStage(
|
||||
prev_chs, out_chs, stride=stride, dilation=dilation, depth=d, avg_down=avg_down,
|
||||
act_layer=act_layer, conv_layer=conv_layer, norm_layer=norm_layer, block_dpr=bdpr, block_fn=block_fn)
|
||||
prev_chs = out_chs
|
||||
curr_stride *= stride
|
||||
feat_name = f'stages.{stage_idx}'
|
||||
if preact:
|
||||
feat_name = f'stages.{stage_idx + 1}.blocks.0.norm1' if (stage_idx + 1) != len(channels) else 'norm'
|
||||
self.feature_info += [dict(num_chs=prev_chs, reduction=curr_stride, module=feat_name)]
|
||||
self.stages.add_module(str(stage_idx), stage)
|
||||
|
||||
self.num_features = prev_chs
|
||||
self.norm = norm_layer(self.num_features) if preact else nn.Identity()
|
||||
self.head = ClassifierHead(
|
||||
self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate, use_conv=True)
|
||||
|
||||
for n, m in self.named_modules():
|
||||
if isinstance(m, nn.Linear) or ('.fc' in n and isinstance(m, nn.Conv2d)):
|
||||
nn.init.normal_(m.weight, mean=0.0, std=0.01)
|
||||
nn.init.zeros_(m.bias)
|
||||
elif isinstance(m, nn.Conv2d):
|
||||
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
|
||||
|
||||
def get_classifier(self):
|
||||
return self.head.fc
|
||||
|
||||
def reset_classifier(self, num_classes, global_pool='avg'):
|
||||
self.head = ClassifierHead(
|
||||
self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate, use_conv=True)
|
||||
|
||||
def forward_features(self, x):
|
||||
x = self.stem(x)
|
||||
x = self.stages(x)
|
||||
x = self.norm(x)
|
||||
return x
|
||||
|
||||
def forward(self, x):
|
||||
x = self.forward_features(x)
|
||||
x = self.head(x)
|
||||
if not self.head.global_pool.is_identity():
|
||||
x = x.flatten(1) # conv classifier, flatten if pooling isn't pass-through (disabled)
|
||||
return x
|
||||
|
||||
def load_pretrained(self, checkpoint_path, prefix='resnet/'):
|
||||
import numpy as np
|
||||
weights = np.load(checkpoint_path)
|
||||
with torch.no_grad():
|
||||
stem_conv_w = tf2th(weights[f'{prefix}root_block/standardized_conv2d/kernel'])
|
||||
if self.stem.conv.weight.shape[1] == 1:
|
||||
self.stem.conv.weight.copy_(stem_conv_w.sum(dim=1, keepdim=True))
|
||||
# FIXME handle > 3 in_chans?
|
||||
else:
|
||||
self.stem.conv.weight.copy_(stem_conv_w)
|
||||
self.norm.weight.copy_(tf2th(weights[f'{prefix}group_norm/gamma']))
|
||||
self.norm.bias.copy_(tf2th(weights[f'{prefix}group_norm/beta']))
|
||||
self.head.fc.weight.copy_(tf2th(weights[f'{prefix}head/conv2d/kernel']))
|
||||
self.head.fc.bias.copy_(tf2th(weights[f'{prefix}head/conv2d/bias']))
|
||||
for i, (sname, stage) in enumerate(self.stages.named_children()):
|
||||
for j, (bname, block) in enumerate(stage.blocks.named_children()):
|
||||
convname = 'standardized_conv2d'
|
||||
block_prefix = f'{prefix}block{i + 1}/unit{j + 1:02d}/'
|
||||
block.conv1.weight.copy_(tf2th(weights[f'{block_prefix}a/{convname}/kernel']))
|
||||
block.conv2.weight.copy_(tf2th(weights[f'{block_prefix}b/{convname}/kernel']))
|
||||
block.conv3.weight.copy_(tf2th(weights[f'{block_prefix}c/{convname}/kernel']))
|
||||
block.norm1.weight.copy_(tf2th(weights[f'{block_prefix}a/group_norm/gamma']))
|
||||
block.norm2.weight.copy_(tf2th(weights[f'{block_prefix}b/group_norm/gamma']))
|
||||
block.norm3.weight.copy_(tf2th(weights[f'{block_prefix}c/group_norm/gamma']))
|
||||
block.norm1.bias.copy_(tf2th(weights[f'{block_prefix}a/group_norm/beta']))
|
||||
block.norm2.bias.copy_(tf2th(weights[f'{block_prefix}b/group_norm/beta']))
|
||||
block.norm3.bias.copy_(tf2th(weights[f'{block_prefix}c/group_norm/beta']))
|
||||
if block.downsample is not None:
|
||||
w = weights[f'{block_prefix}a/proj/{convname}/kernel']
|
||||
block.downsample.conv.weight.copy_(tf2th(w))
|
||||
|
||||
|
||||
def _create_resnetv2(variant, pretrained=False, **kwargs):
|
||||
# FIXME feature map extraction is not setup properly for pre-activation mode right now
|
||||
preact = kwargs.get('preact', True)
|
||||
feature_cfg = dict(flatten_sequential=True)
|
||||
if preact:
|
||||
feature_cfg['feature_cls'] = 'hook'
|
||||
feature_cfg['out_indices'] = (1, 2, 3, 4) # no stride 2, 0 level feat for preact
|
||||
|
||||
return build_model_with_cfg(
|
||||
ResNetV2, variant, pretrained, default_cfg=default_cfgs[variant], pretrained_custom_load=True,
|
||||
feature_cfg=feature_cfg, **kwargs)
|
||||
|
||||
|
||||
@register_model
|
||||
def resnetv2_50x1_bitm(pretrained=False, **kwargs):
|
||||
return _create_resnetv2(
|
||||
'resnetv2_50x1_bitm', pretrained=pretrained,
|
||||
layers=[3, 4, 6, 3], width_factor=1, stem_type='fixed', **kwargs)
|
||||
|
||||
|
||||
@register_model
|
||||
def resnetv2_50x3_bitm(pretrained=False, **kwargs):
|
||||
return _create_resnetv2(
|
||||
'resnetv2_50x3_bitm', pretrained=pretrained,
|
||||
layers=[3, 4, 6, 3], width_factor=3, stem_type='fixed', **kwargs)
|
||||
|
||||
|
||||
@register_model
|
||||
def resnetv2_101x1_bitm(pretrained=False, **kwargs):
|
||||
return _create_resnetv2(
|
||||
'resnetv2_101x1_bitm', pretrained=pretrained,
|
||||
layers=[3, 4, 23, 3], width_factor=1, stem_type='fixed', **kwargs)
|
||||
|
||||
|
||||
@register_model
|
||||
def resnetv2_101x3_bitm(pretrained=False, **kwargs):
|
||||
return _create_resnetv2(
|
||||
'resnetv2_101x3_bitm', pretrained=pretrained,
|
||||
layers=[3, 4, 23, 3], width_factor=3, stem_type='fixed', **kwargs)
|
||||
|
||||
|
||||
@register_model
|
||||
def resnetv2_152x2_bitm(pretrained=False, **kwargs):
|
||||
return _create_resnetv2(
|
||||
'resnetv2_152x2_bitm', pretrained=pretrained,
|
||||
layers=[3, 8, 36, 3], width_factor=2, stem_type='fixed', **kwargs)
|
||||
|
||||
|
||||
@register_model
|
||||
def resnetv2_152x4_bitm(pretrained=False, **kwargs):
|
||||
return _create_resnetv2(
|
||||
'resnetv2_152x4_bitm', pretrained=pretrained,
|
||||
layers=[3, 8, 36, 3], width_factor=4, stem_type='fixed', **kwargs)
|
||||
|
||||
|
||||
@register_model
|
||||
def resnetv2_50x1_bitm_in21k(pretrained=False, **kwargs):
|
||||
return _create_resnetv2(
|
||||
'resnetv2_50x1_bitm_in21k', pretrained=pretrained, num_classes=kwargs.pop('num_classes', 21843),
|
||||
layers=[3, 4, 6, 3], width_factor=1, stem_type='fixed', **kwargs)
|
||||
|
||||
|
||||
@register_model
|
||||
def resnetv2_50x3_bitm_in21k(pretrained=False, **kwargs):
|
||||
return _create_resnetv2(
|
||||
'resnetv2_50x3_bitm_in21k', pretrained=pretrained, num_classes=kwargs.pop('num_classes', 21843),
|
||||
layers=[3, 4, 6, 3], width_factor=3, stem_type='fixed', **kwargs)
|
||||
|
||||
|
||||
@register_model
|
||||
def resnetv2_101x1_bitm_in21k(pretrained=False, **kwargs):
|
||||
return _create_resnetv2(
|
||||
'resnetv2_101x1_bitm_in21k', pretrained=pretrained, num_classes=kwargs.pop('num_classes', 21843),
|
||||
layers=[3, 4, 23, 3], width_factor=1, stem_type='fixed', **kwargs)
|
||||
|
||||
|
||||
@register_model
|
||||
def resnetv2_101x3_bitm_in21k(pretrained=False, **kwargs):
|
||||
return _create_resnetv2(
|
||||
'resnetv2_101x3_bitm_in21k', pretrained=pretrained, num_classes=kwargs.pop('num_classes', 21843),
|
||||
layers=[3, 4, 23, 3], width_factor=3, stem_type='fixed', **kwargs)
|
||||
|
||||
|
||||
@register_model
|
||||
def resnetv2_152x2_bitm_in21k(pretrained=False, **kwargs):
|
||||
return _create_resnetv2(
|
||||
'resnetv2_152x2_bitm_in21k', pretrained=pretrained, num_classes=kwargs.pop('num_classes', 21843),
|
||||
layers=[3, 8, 36, 3], width_factor=2, stem_type='fixed', **kwargs)
|
||||
|
||||
|
||||
@register_model
|
||||
def resnetv2_152x4_bitm_in21k(pretrained=False, **kwargs):
|
||||
return _create_resnetv2(
|
||||
'resnetv2_152x4_bitm_in21k', pretrained=pretrained, num_classes=kwargs.pop('num_classes', 21843),
|
||||
layers=[3, 8, 36, 3], width_factor=4, stem_type='fixed', **kwargs)
|
||||
|
||||
|
||||
# NOTE the 'S' versions of the model weights arent as interesting as original 21k or transfer to 1K M.
|
||||
|
||||
# @register_model
|
||||
# def resnetv2_50x1_bits(pretrained=False, **kwargs):
|
||||
# return _create_resnetv2(
|
||||
# 'resnetv2_50x1_bits', pretrained=pretrained,
|
||||
# layers=[3, 4, 6, 3], width_factor=1, stem_type='fixed', **kwargs)
|
||||
#
|
||||
#
|
||||
# @register_model
|
||||
# def resnetv2_50x3_bits(pretrained=False, **kwargs):
|
||||
# return _create_resnetv2(
|
||||
# 'resnetv2_50x3_bits', pretrained=pretrained,
|
||||
# layers=[3, 4, 6, 3], width_factor=3, stem_type='fixed', **kwargs)
|
||||
#
|
||||
#
|
||||
# @register_model
|
||||
# def resnetv2_101x1_bits(pretrained=False, **kwargs):
|
||||
# return _create_resnetv2(
|
||||
# 'resnetv2_101x1_bits', pretrained=pretrained,
|
||||
# layers=[3, 4, 23, 3], width_factor=1, stem_type='fixed', **kwargs)
|
||||
#
|
||||
#
|
||||
# @register_model
|
||||
# def resnetv2_101x3_bits(pretrained=False, **kwargs):
|
||||
# return _create_resnetv2(
|
||||
# 'resnetv2_101x3_bits', pretrained=pretrained,
|
||||
# layers=[3, 4, 23, 3], width_factor=3, stem_type='fixed', **kwargs)
|
||||
#
|
||||
#
|
||||
# @register_model
|
||||
# def resnetv2_152x2_bits(pretrained=False, **kwargs):
|
||||
# return _create_resnetv2(
|
||||
# 'resnetv2_152x2_bits', pretrained=pretrained,
|
||||
# layers=[3, 8, 36, 3], width_factor=2, stem_type='fixed', **kwargs)
|
||||
#
|
||||
#
|
||||
# @register_model
|
||||
# def resnetv2_152x4_bits(pretrained=False, **kwargs):
|
||||
# return _create_resnetv2(
|
||||
# 'resnetv2_152x4_bits', pretrained=pretrained,
|
||||
# layers=[3, 8, 36, 3], width_factor=4, stem_type='fixed', **kwargs)
|
||||
#
|
@ -1 +1 @@
|
||||
__version__ = '0.3.4'
|
||||
__version__ = '0.4.0'
|
||||
|
Loading…
Reference in new issue