Merge pull request #323 from rwightman/imagenet21k_datasets_more
BiT (Big Transfer) ResNetV2 models, Official ViT Hybrid R50 weights, VIT IN21K weights updated w/ repr layer, ImageNet21k and dataset / parser refactorpull/401/head
File diff suppressed because it is too large
Load Diff
@ -1,10 +1,12 @@
from .constants import *
from .auto_augment import RandAugment, AutoAugment, rand_augment_ops, auto_augment_policy,\
rand_augment_transform, auto_augment_transform
from .config import resolve_data_config
from .dataset import Dataset, DatasetTar, AugMixDataset
from .transforms import *
from .constants import *
from .dataset import ImageDataset, IterableImageDataset, AugMixDataset
from .dataset_factory import create_dataset
from .loader import create_loader
from .transforms_factory import create_transform
from .mixup import Mixup, FastCollateMixup
from .auto_augment import RandAugment, AutoAugment, rand_augment_ops, auto_augment_policy,\
rand_augment_transform, auto_augment_transform
from .parsers import create_parser
from .real_labels import RealLabelsImagenet
from .transforms import *
from .transforms_factory import create_transform
@ -0,0 +1,29 @@
import os
from .dataset import IterableImageDataset, ImageDataset
def _search_split(root, split):
    """Resolve the folder that holds ``split`` beneath ``root``.

    Anything from the first ``[`` onwards (a slice suffix) is dropped from
    ``split`` before it is used as a sub-folder name. For a ``validation``
    split a ``val`` sub-folder is accepted as well.

    Returns:
        The matching sub-folder path if one exists, otherwise ``root`` unchanged.
    """
    # look for sub-folder with name of split in root and use that if it exists
    split_name = split.split('[')[0]
    candidates = [split_name]
    if split_name == 'validation':
        candidates.append('val')  # short alias, only tried after the full name
    for candidate in candidates:
        try_root = os.path.join(root, candidate)
        if os.path.exists(try_root):
            return try_root
    return root
def create_dataset(name, root, split='validation', search_split=True, is_training=False, batch_size=None, **kwargs):
    """Build a dataset for ``name`` rooted at ``root``.

    Names starting with ``tfds`` (case-insensitive) produce an
    ``IterableImageDataset``; every other name produces an ``ImageDataset``.

    Args:
        name: Dataset / parser name; lower-cased before use.
        root: Dataset root path.
        split: Split name, forwarded to the iterable dataset or used for the
            sub-folder search.
        search_split: If True and ``root`` is a directory, look for a sub-folder
            matching ``split`` (non-tfds path only).
        is_training: Forwarded to the iterable dataset.
        batch_size: Forwarded to the iterable dataset.
        **kwargs: Forwarded to the dataset constructor.
    """
    name = name.lower()
    if name.startswith('tfds'):
        ds = IterableImageDataset(
            root, parser=name, split=split, is_training=is_training, batch_size=batch_size, **kwargs)
    else:
        # FIXME support more advanced split cfg for ImageFolder/Tar datasets in the future
        if search_split and os.path.isdir(root):
            root = _search_split(root, split)
        ds = ImageDataset(root, parser=name, **kwargs)
    return ds
@ -0,0 +1 @@
from .parser_factory import create_parser
@ -0,0 +1,16 @@
import os
def load_class_map(filename, root=''):
    """Load a class-name -> index mapping from a class map file.

    ``filename`` is tried as given first; if that path does not exist it is
    looked up relative to ``root``. Only ``.txt`` files are supported: each line
    (stripped) is a class name and its zero-based line number is the index.

    Raises:
        AssertionError: If the file cannot be located or the extension is not ``.txt``.
    """
    class_map_path = filename
    if not os.path.exists(class_map_path):
        class_map_path = os.path.join(root, filename)
        assert os.path.exists(class_map_path), 'Cannot locate specified class map file (%s)' % filename
    class_map_ext = os.path.splitext(filename)[-1].lower()
    if class_map_ext == '.txt':
        with open(class_map_path) as f:
            class_to_idx = {v.strip(): k for k, v in enumerate(f)}
    else:
        assert False, 'Unsupported class map extension'
    return class_to_idx
@ -0,0 +1 @@
# Lower-case, dot-prefixed file extensions that the dataset parsers accept as images.
IMG_EXTENSIONS = ('.png', '.jpg', '.jpeg')
@ -0,0 +1,17 @@
from abc import abstractmethod
class Parser:
    """Base class for dataset parsers.

    Subclasses provide ``_filename`` and ``__len__``; ``filename`` and
    ``filenames`` are built on top of those.
    """

    def __init__(self):
        pass

    @abstractmethod
    def _filename(self, index, basename=False, absolute=False):
        """Return the filename of the sample at ``index`` (subclass hook)."""
        pass

    def filename(self, index, basename=False, absolute=False):
        """Return the filename of the sample at ``index``."""
        return self._filename(index, basename=basename, absolute=absolute)

    def filenames(self, basename=False, absolute=False):
        """Return the filenames of all ``len(self)`` samples, in index order."""
        return [self._filename(index, basename=basename, absolute=absolute) for index in range(len(self))]
@ -0,0 +1,29 @@
import os
from .parser_image_folder import ParserImageFolder
from .parser_image_tar import ParserImageTar
from .parser_image_in_tar import ParserImageInTar
def create_parser(name, root, split='train', **kwargs):
    """Create the parser selected by ``name`` for the data at ``root``.

    ``name`` is lower-cased and split on ``/``; when it has more than one part
    the first part is the prefix and the last part is the dataset name. A
    ``tfds`` prefix selects ``ParserTfds``. Otherwise ``root`` must exist: a
    ``.tar`` file selects ``ParserImageInTar`` and anything else selects
    ``ParserImageFolder``.

    Args:
        name: Parser / dataset name, optionally prefixed (e.g. ``tfds/...``).
        root: Dataset root path.
        split: Split name, only forwarded to the tfds parser.
        **kwargs: Forwarded to the parser constructor (``shuffle`` is popped
            and passed explicitly on the tfds path).
    """
    name = name.lower()
    name = name.split('/', 2)
    prefix = ''
    if len(name) > 1:
        prefix = name[0]
    name = name[-1]

    # FIXME improve the selection right now just tfds prefix or fallback path, will need options to
    # explicitly select other options shortly
    if prefix == 'tfds':
        from .parser_tfds import ParserTfds  # defer tensorflow import
        parser = ParserTfds(root, name, split=split, shuffle=kwargs.pop('shuffle', False), **kwargs)
    else:
        assert os.path.exists(root)
        # default fallback path (backwards compat), use image tar if root is a .tar file, otherwise image folder
        # FIXME support split here, in parser?
        if os.path.isfile(root) and os.path.splitext(root)[1] == '.tar':
            parser = ParserImageInTar(root, **kwargs)
        else:
            parser = ParserImageFolder(root, **kwargs)
    return parser
@ -0,0 +1,69 @@
""" A dataset parser that reads images from folders
Folders are scanned recursively to find image files. Labels are based
on the folder hierarchy, just leaf folders by default.
Hacked together by / Copyright 2020 Ross Wightman
import os
from timm.utils.misc import natural_key
from .parser import Parser
from .class_map import load_class_map
from .constants import IMG_EXTENSIONS
def find_images_and_targets(folder, types=IMG_EXTENSIONS, class_to_idx=None, leaf_name_only=True, sort=True):
    """Walk ``folder`` and collect (image path, class index) pairs.

    Args:
        folder: Root folder to walk (symlinks are followed).
        types: Accepted file extensions, compared against the lower-cased extension.
        class_to_idx: Optional label -> index mapping. When None it is built
            from the labels found, ordered with ``natural_key``. Files whose
            label is not in the mapping are dropped.
        leaf_name_only: If True the label is the name of the folder holding the
            file; otherwise it is the path relative to ``folder`` with path
            separators replaced by ``_``.
        sort: If True, sort the result by file path using ``natural_key``.

    Returns:
        Tuple of (list of (path, class index), class_to_idx).
    """
    labels = []
    filenames = []
    for root, subdirs, files in os.walk(folder, topdown=False, followlinks=True):
        rel_path = os.path.relpath(root, folder) if (root != folder) else ''
        label = os.path.basename(rel_path) if leaf_name_only else rel_path.replace(os.path.sep, '_')
        for f in files:
            base, ext = os.path.splitext(f)
            if ext.lower() in types:
                filenames.append(os.path.join(root, f))
                labels.append(label)  # kept in lockstep with filenames for the zip below
    if class_to_idx is None:
        # building class index
        unique_labels = set(labels)
        sorted_labels = list(sorted(unique_labels, key=natural_key))
        class_to_idx = {c: idx for idx, c in enumerate(sorted_labels)}
    images_and_targets = [(f, class_to_idx[l]) for f, l in zip(filenames, labels) if l in class_to_idx]
    if sort:
        images_and_targets = sorted(images_and_targets, key=lambda k: natural_key(k[0]))
    return images_and_targets, class_to_idx
class ParserImageFolder(Parser):
    """Parser that serves (open image file, class index) samples from a folder tree."""

    def __init__(self, root, class_map=''):
        """Scan ``root`` for images.

        Args:
            root: Folder to scan.
            class_map: Optional class map file passed to ``load_class_map``;
                when empty the class index is built from the folders found.

        Raises:
            RuntimeError: If no images are found.
        """
        super().__init__()
        self.root = root
        class_to_idx = None
        if class_map:
            class_to_idx = load_class_map(class_map, root)
        self.samples, self.class_to_idx = find_images_and_targets(root, class_to_idx=class_to_idx)
        if len(self.samples) == 0:
            raise RuntimeError(
                f'Found 0 images in subfolders of {root}. Supported image extensions are {", ".join(IMG_EXTENSIONS)}')

    def __getitem__(self, index):
        """Return (binary file object opened for reading, class index) for ``index``."""
        path, target = self.samples[index]
        return open(path, 'rb'), target

    def __len__(self):
        return len(self.samples)

    def _filename(self, index, basename=False, absolute=False):
        """Return the sample path: base name, path relative to ``root``, or the stored path."""
        filename = self.samples[index][0]
        if basename:
            filename = os.path.basename(filename)
        elif not absolute:
            filename = os.path.relpath(filename, self.root)
        return filename
@ -0,0 +1,222 @@
""" A dataset parser that reads tarfile based datasets
This parser can read and extract image samples from:
* a single tar of image files
* a folder of multiple tarfiles containing imagefiles
* a tar of tars containing image files
Labels are based on the combined folder and/or tar name structure.
Hacked together by / Copyright 2020 Ross Wightman
import os
import tarfile
import pickle
import logging
import numpy as np
from glob import glob
from typing import List, Dict
from timm.utils.misc import natural_key
from .parser import Parser
from .class_map import load_class_map
from .constants import IMG_EXTENSIONS
_logger = logging.getLogger(__name__)
CACHE_FILENAME_SUFFIX = '_tarinfos.pickle'
class TarState:
    """Holds an (optionally open) tar file handle plus the states of nested tars."""

    def __init__(self, tf: tarfile.TarFile = None, ti: tarfile.TarInfo = None):
        self.tf: tarfile.TarFile = tf  # tar file handle, None until assigned
        self.ti: tarfile.TarInfo = ti  # tar member info associated with this state
        self.children: Dict[str, TarState] = {}  # child states (tars within tars)

    def reset(self):
        """Drop the tar file handle."""
        self.tf = None
def _extract_tarinfo(tf: tarfile.TarFile, parent_info: Dict, extensions=IMG_EXTENSIONS):
    """Collect image members of ``tf`` into ``parent_info``, recursing into nested tars.

    Members with a ``.tar`` extension are opened in streaming mode and scanned
    into a child info dict that is appended to ``parent_info['children']``.
    Members whose lower-cased extension is in ``extensions`` are appended to
    ``parent_info['samples']``. Non-file members are ignored.

    Returns:
        Number of image samples found, including those inside nested tars.
    """
    sample_count = 0
    for i, ti in enumerate(tf):
        if not ti.isfile():
            continue
        dirname, basename = os.path.split(ti.path)
        name, ext = os.path.splitext(basename)
        ext = ext.lower()
        if ext == '.tar':
            with tarfile.open(fileobj=tf.extractfile(ti), mode='r|') as ctf:
                child_info = dict(
                    name=ti.name, path=os.path.join(parent_info['path'], name), ti=ti, children=[], samples=[])
                sample_count += _extract_tarinfo(ctf, child_info, extensions=extensions)
                _logger.debug(f'{i}/?. Extracted child tarinfos from {ti.name}. {len(child_info["samples"])} images.')
                parent_info['children'].append(child_info)
        elif ext in extensions:
            parent_info['samples'].append(ti)
            sample_count += 1
    return sample_count
def extract_tarinfos(root, class_name_to_idx=None, cache_tarinfo=None, extensions=IMG_EXTENSIONS, sort=True):
root_is_tar = False
if os.path.isfile(root):
assert os.path.splitext(root)[-1].lower() == '.tar'
tar_filenames = [root]
root, root_name = os.path.split(root)
root_name = os.path.splitext(root_name)[0]
root_is_tar = True
root_name = root.strip(os.path.sep).split(os.path.sep)[-1]
tar_filenames = glob(os.path.join(root, '*.tar'), recursive=True)
num_tars = len(tar_filenames)
tar_bytes = sum([os.path.getsize(f) for f in tar_filenames])
assert num_tars, f'No .tar files found at specified path ({root}).'
||||'Scanning {tar_bytes/1024**2:.2f}MB of tar files...')
info = dict(tartrees=[])
cache_path = ''
if cache_tarinfo is None:
cache_tarinfo = True if tar_bytes > 10*1024**3 else False # FIXME magic number, 10GB
if cache_tarinfo:
cache_filename = '_' + root_name + CACHE_FILENAME_SUFFIX
cache_path = os.path.join(root, cache_filename)
if os.path.exists(cache_path):
||||'Reading tar info from cache file {cache_path}.')
with open(cache_path, 'rb') as pf:
info = pickle.load(pf)
assert len(info['tartrees']) == num_tars, "Cached tartree len doesn't match number of tarfiles"
for i, fn in enumerate(tar_filenames):
path = '' if root_is_tar else os.path.splitext(os.path.basename(fn))[0]
with, mode='r|') as tf: # tarinfo scans done in streaming mode
parent_info = dict(name=os.path.relpath(fn, root), path=path, ti=None, children=[], samples=[])
num_samples = _extract_tarinfo(tf, parent_info, extensions=extensions)
num_children = len(parent_info["children"])
f'{i}/{num_tars}. Extracted tarinfos from {fn}. {num_children} children, {num_samples} samples.')
if cache_path:
||||'Writing tar info to cache file {cache_path}.')
with open(cache_path, 'wb') as pf:
pickle.dump(info, pf)
samples = []
labels = []
build_class_map = False
if class_name_to_idx is None:
build_class_map = True
# Flatten tartree info into lists of samples and targets w/ targets based on label id via
# class map arg or from unique paths.
# NOTE: currently only flattening up to two-levels, filesystem .tars and then one level of sub-tar children
# this covers my current use cases and keeps things a little easier to test for now.
tarfiles = []
def _label_from_paths(*path, leaf_only=True):
path = os.path.join(*path).strip(os.path.sep)
return path.split(os.path.sep)[-1] if leaf_only else path.replace(os.path.sep, '_')
def _add_samples(info, fn):
added = 0
for s in info['samples']:
label = _label_from_paths(info['path'], os.path.dirname(s.path))
if not build_class_map and label not in class_name_to_idx:
samples.append((s, fn, info['ti']))
added += 1
return added
||||'Collecting samples and building tar states.')
for parent_info in info['tartrees']:
# if tartree has children, we assume all samples are at the child level
tar_name = None if root_is_tar else parent_info['name']
tar_state = TarState()
parent_added = 0
for child_info in parent_info['children']:
child_added = _add_samples(child_info, fn=tar_name)
if child_added:
tar_state.children[child_info['name']] = TarState(ti=child_info['ti'])
parent_added += child_added
parent_added += _add_samples(parent_info, fn=tar_name)
if parent_added:
tarfiles.append((tar_name, tar_state))
del info
if build_class_map:
# build class index
sorted_labels = list(sorted(set(labels), key=natural_key))
class_name_to_idx = {c: idx for idx, c in enumerate(sorted_labels)}
||||'Mapping targets and sorting samples.')
samples_and_targets = [(s, class_name_to_idx[l]) for s, l in zip(samples, labels) if l in class_name_to_idx]
if sort:
samples_and_targets = sorted(samples_and_targets, key=lambda k: natural_key(k[0][0].path))
samples, targets = zip(*samples_and_targets)
samples = np.array(samples)
targets = np.array(targets)
||||'Finished processing {len(samples)} samples across {len(tarfiles)} tar files.')
return samples, targets, class_name_to_idx, tarfiles
class ParserImageInTar(Parser):
""" Multi-tarfile dataset parser where there is one .tar file per class
def __init__(self, root, class_map='', cache_tarfiles=True, cache_tarinfo=None):
class_name_to_idx = None
if class_map:
class_name_to_idx = load_class_map(class_map, root)
self.root = root
self.samples, self.targets, self.class_name_to_idx, tarfiles = extract_tarinfos(
self.class_idx_to_name = {v: k for k, v in self.class_name_to_idx.items()}
if len(tarfiles) == 1 and tarfiles[0][0] is None:
self.root_is_tar = True
self.tar_state = tarfiles[0][1]
self.root_is_tar = False
self.tar_state = dict(tarfiles)
self.cache_tarfiles = cache_tarfiles
def __len__(self):
return len(self.samples)
def __getitem__(self, index):
sample = self.samples[index]
target = self.targets[index]
sample_ti, parent_fn, child_ti = sample
parent_abs = os.path.join(self.root, parent_fn) if parent_fn else self.root
tf = None
cache_state = None
if self.cache_tarfiles:
cache_state = self.tar_state if self.root_is_tar else self.tar_state[parent_fn]
tf =
if tf is None:
tf =
if self.cache_tarfiles:
|||| = tf
if child_ti is not None:
ctf = cache_state.children[].tf if self.cache_tarfiles else None
if ctf is None:
ctf =
if self.cache_tarfiles:
cache_state.children[].tf = ctf
tf = ctf
return tf.extractfile(sample_ti), target
def _filename(self, index, basename=False, absolute=False):
filename = self.samples[index][0].name
if basename:
filename = os.path.basename(filename)
return filename
@ -0,0 +1,72 @@
""" A dataset parser that reads single tarfile based datasets
This parser can read datasets consisting of a single tarfile containing images.
I am planning to deprecate it in favour of ParserImageInTar.
Hacked together by / Copyright 2020 Ross Wightman
import os
import tarfile
from .parser import Parser
from .class_map import load_class_map
from .constants import IMG_EXTENSIONS
from timm.utils.misc import natural_key
def extract_tarinfo(tarfile, class_to_idx=None, sort=True):
    """Collect (TarInfo, class index) pairs for the image members of an open tar.

    The label of a member is the name of the folder that directly contains it.

    Args:
        tarfile: Open tar file object (note: shadows the ``tarfile`` module here).
        class_to_idx: Optional label -> index mapping. When None it is built
            from the labels found, ordered with ``natural_key``. Members whose
            label is not in the mapping are dropped.
        sort: If True, sort the result by member path using ``natural_key``.

    Returns:
        Tuple of (list of (TarInfo, class index), class_to_idx).
    """
    files = []
    labels = []
    for ti in tarfile.getmembers():
        if not ti.isfile():
            continue
        dirname, basename = os.path.split(ti.path)
        label = os.path.basename(dirname)
        ext = os.path.splitext(basename)[1]
        if ext.lower() in IMG_EXTENSIONS:
            files.append(ti)
            labels.append(label)  # kept in lockstep with files for the zip below
    if class_to_idx is None:
        unique_labels = set(labels)
        sorted_labels = list(sorted(unique_labels, key=natural_key))
        class_to_idx = {c: idx for idx, c in enumerate(sorted_labels)}
    tarinfo_and_targets = [(f, class_to_idx[l]) for f, l in zip(files, labels) if l in class_to_idx]
    if sort:
        tarinfo_and_targets = sorted(tarinfo_and_targets, key=lambda k: natural_key(k[0].path))
    return tarinfo_and_targets, class_to_idx
class ParserImageTar(Parser):
    """ Single tarfile dataset where classes are mapped to folders within tar
    NOTE: This class is being deprecated in favour of the more capable ParserImageInTar that can
    operate on folders of tars or tars in tars.
    """

    def __init__(self, root, class_map=''):
        """Index the image members of the tar file at ``root``.

        Args:
            root: Path to the tar file; must be an existing file.
            class_map: Optional class map file passed to ``load_class_map``.
        """
        super().__init__()
        class_to_idx = None
        if class_map:
            class_to_idx = load_class_map(class_map, root)
        assert os.path.isfile(root)
        self.root = root

        with tarfile.open(root) as tf:  # cannot keep this open across processes, reopen later
            self.samples, self.class_to_idx = extract_tarinfo(tf, class_to_idx)
        self.imgs = self.samples
        self.tarfile = None  # lazy init in __getitem__

    def __getitem__(self, index):
        """Return (file object extracted from the tar, class index) for ``index``."""
        if self.tarfile is None:
            self.tarfile = tarfile.open(self.root)
        tarinfo, target = self.samples[index]
        fileobj = self.tarfile.extractfile(tarinfo)
        return fileobj, target

    def __len__(self):
        return len(self.samples)

    def _filename(self, index, basename=False, absolute=False):
        """Return the member name inside the tar, or only its base name if ``basename``."""
        filename = self.samples[index][0].name
        if basename:
            filename = os.path.basename(filename)
        return filename
@ -0,0 +1,201 @@
""" Dataset parser interface that wraps TFDS datasets
Wraps many (most?) TFDS image-classification datasets
Hacked together by / Copyright 2020 Ross Wightman
import os
import io
import math
import torch
import torch.distributed as dist
from PIL import Image
import tensorflow as tf
tf.config.set_visible_devices([], 'GPU') # Hands off my GPU! (or pip install tensorflow-cpu)
import tensorflow_datasets as tfds
except ImportError as e:
print("Please install tensorflow_datasets package `pip install tensorflow-datasets`.")
from .parser import Parser
MAX_TP_SIZE = 8 # maximum TF threadpool size, only doing jpeg decodes and queuing activities
SHUFFLE_SIZE = 16834 # samples to shuffle in DS queue
PREFETCH_SIZE = 4096 # samples to prefetch
class ParserTfds(Parser):
""" Wrap Tensorflow Datasets for use in PyTorch
There are several things to be aware of:
* To prevent excessive samples being dropped per epoch w/ distributed training or multiplicity of
dataloader workers, the train iterator wraps to avoid returning partial batches that trigger drop_last
* With PyTorch IterableDatasets, each worker in each replica operates in isolation, the final batch
from each worker could be a different size. For training this is worked around by option above, for
validation extra samples are inserted iff distributed mode is enabled so that the batches being reduced
across replicas are of same size. This will slightly alter the results, distributed validation will not be
100% correct. This is similar to common handling in DistributedSampler for normal Datasets but a bit worse
since there are up to N * J extra samples with IterableDatasets.
* The sharding (splitting of dataset into TFRecord) files imposes limitations on the number of
replicas and dataloader workers you can use. For really small datasets that only contain a few shards
you may have to train non-distributed w/ 1-2 dataloader workers. This is likely not a huge concern as the
benefit of distributed training or fast dataloading should be much less for small datasets.
* This wrapper is currently configured to return individual, decompressed image samples from the TFDS
dataset. The augmentation (transforms) and batching is still done in PyTorch. It would be possible
to specify TF augmentation fn and return augmented batches w/ some modifications to other downstream
def __init__(self, root, name, split='train', shuffle=False, is_training=False, batch_size=None):
self.root = root
self.split = split
self.shuffle = shuffle
self.is_training = is_training
if self.is_training:
assert batch_size is not None,\
"Must specify batch_size in training mode for reasonable behaviour w/ TFDS wrapper"
self.batch_size = batch_size
self.builder = tfds.builder(name, data_dir=root)
# NOTE: please use tfds command line app to download & prepare datasets, I don't want to call
# download_and_prepare() by default here as it's caused issues generating unwanted paths.
self.num_samples =[split].num_examples
self.ds = None # initialized lazily on each dataloader worker process
self.worker_info = None
self.dist_rank = 0
self.dist_num_replicas = 1
if dist.is_available() and dist.is_initialized() and dist.get_world_size() > 1:
self.dist_rank = dist.get_rank()
self.dist_num_replicas = dist.get_world_size()
def _lazy_init(self):
""" Lazily initialize the dataset.
This is necessary to init the Tensorflow dataset pipeline in the (dataloader) process that
will be using the dataset instance. The __init__ method is called on the main process,
this will be called in a dataloader worker process.
NOTE: There will be problems if you try to re-use this dataset across different loader/worker
instances once it has been initialized. Do not call any dataset methods that can call _lazy_init
before it is passed to dataloader.
worker_info =
# setup input context to split dataset across distributed processes
split = self.split
num_workers = 1
if worker_info is not None:
self.worker_info = worker_info
num_workers = worker_info.num_workers
worker_id =
# FIXME I need to spend more time figuring out the best way to distribute/split data across
# combo of distributed replicas + dataloader worker processes
InputContext will assign subset of underlying TFRecord files to each 'pipeline' if used.
My understanding is that using split, the underlying TFRecord files will shuffle (shuffle_files=True)
between the splits each iteration, but that understanding could be wrong.
Possible split options include:
* InputContext for both distributed & worker processes (current)
* InputContext for distributed and sub-splits for worker processes
* sub-splits for both
# split_size = self.num_samples // num_workers
# start = worker_id * split_size
# if worker_id == num_workers - 1:
# split = split + '[{}:]'.format(start)
# else:
# split = split + '[{}:{}]'.format(start, start + split_size)
input_context = tf.distribute.InputContext(
num_input_pipelines=self.dist_num_replicas * num_workers,
input_pipeline_id=self.dist_rank * num_workers + worker_id,
num_replicas_in_sync=self.dist_num_replicas # FIXME does this have any impact?
read_config = tfds.ReadConfig(input_context=input_context)
ds = self.builder.as_dataset(split=split, shuffle_files=self.shuffle, read_config=read_config)
# avoid overloading threading w/ combo of TF ds threads + PyTorch workers
ds.options().experimental_threading.private_threadpool_size = max(1, MAX_TP_SIZE // num_workers)
ds.options().experimental_threading.max_intra_op_parallelism = 1
if self.is_training:
# to prevent excessive drop_last batch behaviour w/ IterableDatasets
# see warnings at https://pytorch.org/docs/stable/data.html#multi-process-data-loading
ds = ds.repeat() # allow wrap around and break iteration manually
if self.shuffle:
ds = ds.shuffle(min(self.num_samples // self._num_pipelines, SHUFFLE_SIZE), seed=0)
ds = ds.prefetch(min(self.num_samples // self._num_pipelines, PREFETCH_SIZE))
self.ds = tfds.as_numpy(ds)
def __iter__(self):
if self.ds is None:
# compute a rounded up sample count that is used to:
# 1. make batches even cross workers & replicas in distributed validation.
# This adds extra samples and will slightly alter validation results.
# 2. determine loop ending condition in training w/ repeat enabled so that only full batch_size
# batches are produced (underlying tfds iter wraps around)
target_sample_count = math.ceil(self.num_samples / self._num_pipelines)
if self.is_training:
# round up to nearest batch_size per worker-replica
target_sample_count = math.ceil(target_sample_count / self.batch_size) * self.batch_size
sample_count = 0
for sample in self.ds:
img = Image.fromarray(sample['image'], mode='RGB')
yield img, sample['label']
sample_count += 1
if self.is_training and sample_count >= target_sample_count:
# Need to break out of loop when repeat() is enabled for training w/ oversampling
# this results in extra samples per epoch but seems more desirable than dropping
# up to N*J batches per epoch (where N = num distributed processes, and J = num worker processes)
if not self.is_training and self.dist_num_replicas and 0 < sample_count < target_sample_count:
# Validation batch padding only done for distributed training where results are reduced across nodes.
# For single process case, it won't matter if workers return different batch sizes.
# FIXME this needs more testing, possible for sharding / split api to cause differences of > 1?
assert target_sample_count - sample_count == 1 # should only be off by 1 or sharding is not optimal
yield img, sample['label'] # yield prev sample again
sample_count += 1
def _num_workers(self):
return 1 if self.worker_info is None else self.worker_info.num_workers
def _num_pipelines(self):
return self._num_workers * self.dist_num_replicas
def __len__(self):
# this is just an estimate and does not factor in extra samples added to pad batches based on
# complete worker & replica info (not available until init in dataloader).
return math.ceil(self.num_samples / self.dist_num_replicas)
def _filename(self, index, basename=False, absolute=False):
assert False, "Not supported" # no random access to samples
def filenames(self, basename=False, absolute=False):
""" Return all filenames in dataset, overrides base"""
if self.ds is None:
names = []
for sample in self.ds:
if len(names) > self.num_samples:
break # safety for ds.repeat() case
if 'file_name' in sample:
name = sample['file_name']
elif 'filename' in sample:
name = sample['filename']
elif 'id' in sample:
name = sample['id']
assert False, "No supported name field present"
return names
@ -0,0 +1,593 @@
"""Pre-Activation ResNet v2 with GroupNorm and Weight Standardization.
A PyTorch implementation of ResNetV2 adapted from the Google Big-Transfer (BiT) source code
at https://github.com/google-research/big_transfer to match timm interfaces. The BiT weights have
been included here as pretrained models from their original .NPZ checkpoints.
Additionally, supports non pre-activation bottleneck for use as a backbone for Vision Transformers (ViT) and
extra padding support to allow porting of official Hybrid ResNet pretrained weights from
https://github.com/google-research/vision_transformer
Thanks to the Google team for the above two repositories and associated papers:
* Big Transfer (BiT): General Visual Representation Learning - https://arxiv.org/abs/1912.11370
* An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale - https://arxiv.org/abs/2010.11929
Original copyright of Google code below, modifications by Ross Wightman, Copyright 2020.
# Copyright 2020 Google LLC
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import OrderedDict # pylint: disable=g-importing-member
import torch
import torch.nn as nn
import torch.nn.functional as F
from functools import partial
from .helpers import build_model_with_cfg
from .registry import register_model
from .layers import get_padding, GroupNormAct, ClassifierHead, DropPath, AvgPool2dSame, create_pool2d, conv2d_same
def _cfg(url='', **kwargs):
return {
'url': url,
'num_classes': 1000, 'input_size': (3, 480, 480), 'pool_size': (7, 7),
'crop_pct': 1.0, 'interpolation': 'bilinear',
'first_conv': 'stem.conv', 'classifier': 'head.fc',
default_cfgs = {
# pretrained on imagenet21k, finetuned on imagenet1k
'resnetv2_50x1_bitm': _cfg(
'resnetv2_50x3_bitm': _cfg(
'resnetv2_101x1_bitm': _cfg(
'resnetv2_101x3_bitm': _cfg(
'resnetv2_152x2_bitm': _cfg(
'resnetv2_152x4_bitm': _cfg(
# trained on imagenet-21k
'resnetv2_50x1_bitm_in21k': _cfg(
'resnetv2_50x3_bitm_in21k': _cfg(
'resnetv2_101x1_bitm_in21k': _cfg(
'resnetv2_101x3_bitm_in21k': _cfg(
'resnetv2_152x2_bitm_in21k': _cfg(
'resnetv2_152x4_bitm_in21k': _cfg(
# trained on imagenet-1k, NOTE not overly interesting set of weights, leaving disabled for now
# 'resnetv2_50x1_bits': _cfg(
# url=''),
# 'resnetv2_50x3_bits': _cfg(
# url=''),
# 'resnetv2_101x1_bits': _cfg(
# url=''),
# 'resnetv2_101x3_bits': _cfg(
# url=''),
# 'resnetv2_152x2_bits': _cfg(
# url=''),
# 'resnetv2_152x4_bits': _cfg(
# url=''),
def make_div(v, divisor=8):
    """Round ``v`` to a multiple of ``divisor``, never below ``divisor``.

    The value is rounded to the nearest multiple; if that rounding lands more
    than 10% below ``v``, one extra ``divisor`` is added.
    """
    min_value = divisor
    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
    if new_v < 0.9 * v:
        new_v += divisor
    return new_v
class StdConv2d(nn.Conv2d):
    """Conv2d whose weight is standardized on every forward pass.

    The weight is normalized per output channel (mean and biased variance over
    dims 1-3) before the convolution is applied; ``eps`` is added to the
    standard deviation. Padding is derived with ``get_padding``.
    """

    def __init__(
            self, in_channel, out_channels, kernel_size, stride=1, dilation=1, bias=False, groups=1, eps=1e-5):
        padding = get_padding(kernel_size, stride, dilation)
        super().__init__(
            in_channel, out_channels, kernel_size, stride=stride,
            padding=padding, dilation=dilation, bias=bias, groups=groups)
        self.eps = eps

    def forward(self, x):
        w = self.weight
        v, m = torch.var_mean(w, dim=[1, 2, 3], keepdim=True, unbiased=False)
        w = (w - m) / (torch.sqrt(v) + self.eps)
        x = F.conv2d(x, w, self.bias, self.stride, self.padding, self.dilation, self.groups)
        return x
class StdConv2dSame(nn.Conv2d):
    """StdConv2d w/ TF compatible SAME padding. Used for ViT Hybrid model.

    The weight is standardized per output channel exactly as in StdConv2d; the
    convolution itself is delegated to ``conv2d_same``.
    """

    def __init__(
            self, in_channel, out_channels, kernel_size, stride=1, dilation=1, bias=False, groups=1, eps=1e-5):
        padding = get_padding(kernel_size, stride, dilation)
        super().__init__(
            in_channel, out_channels, kernel_size, stride=stride,
            padding=padding, dilation=dilation, bias=bias, groups=groups)
        self.eps = eps

    def forward(self, x):
        w = self.weight
        v, m = torch.var_mean(w, dim=[1, 2, 3], keepdim=True, unbiased=False)
        w = (w - m) / (torch.sqrt(v) + self.eps)
        x = conv2d_same(x, w, self.bias, self.stride, self.padding, self.dilation, self.groups)
        return x
def tf2th(conv_weights):
    """Possibly convert HWIO to OIHW.

    4-d arrays are transposed with axis order (3, 2, 0, 1); arrays of any other
    rank are passed through. The result is wrapped with ``torch.from_numpy``.
    """
    if conv_weights.ndim == 4:
        conv_weights = conv_weights.transpose([3, 2, 0, 1])
    return torch.from_numpy(conv_weights)
class PreActBottleneck(nn.Module):
    """Pre-activation (v2) bottleneck block.

    Follows the implementation of "Identity Mappings in Deep Residual Networks",
    except it puts the stride on the 3x3 conv when available.
    """

    def __init__(
            self, in_chs, out_chs=None, bottle_ratio=0.25, stride=1, dilation=1, first_dilation=None, groups=1,
            act_layer=None, conv_layer=None, norm_layer=None, proj_layer=None, drop_path_rate=0.):
        super().__init__()
        first_dilation = first_dilation or dilation
        conv_layer = conv_layer or StdConv2d
        norm_layer = norm_layer or partial(GroupNormAct, num_groups=32)
        out_chs = out_chs or in_chs
        mid_chs = make_div(out_chs * bottle_ratio)

        # the shortcut projection only exists when a proj_layer is supplied
        if proj_layer is not None:
            self.downsample = proj_layer(
                in_chs, out_chs, stride=stride, dilation=dilation, first_dilation=first_dilation, preact=True,
                conv_layer=conv_layer, norm_layer=norm_layer)
        else:
            self.downsample = None

        self.norm1 = norm_layer(in_chs)
        self.conv1 = conv_layer(in_chs, mid_chs, 1)
        self.norm2 = norm_layer(mid_chs)
        self.conv2 = conv_layer(mid_chs, mid_chs, 3, stride=stride, dilation=first_dilation, groups=groups)
        self.norm3 = norm_layer(mid_chs)
        self.conv3 = conv_layer(mid_chs, out_chs, 1)
        self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0 else nn.Identity()

    def forward(self, x):
        x_preact = self.norm1(x)

        # shortcut branch
        shortcut = x
        if self.downsample is not None:
            shortcut = self.downsample(x_preact)  # projection consumes the pre-activated input

        # residual branch
        x = self.conv1(x_preact)
        x = self.conv2(self.norm2(x))
        x = self.conv3(self.norm3(x))
        x = self.drop_path(x)
        return x + shortcut
class Bottleneck(nn.Module):
    """Non Pre-activation bottleneck block, equiv to V1.5/V1b Bottleneck. Used for ViT.

    Each conv is followed by its norm layer; the last norm is created with
    ``apply_act=False`` and the final activation runs after the residual add.
    """

    def __init__(
            self, in_chs, out_chs=None, bottle_ratio=0.25, stride=1, dilation=1, first_dilation=None, groups=1,
            act_layer=None, conv_layer=None, norm_layer=None, proj_layer=None, drop_path_rate=0.):
        super().__init__()
        first_dilation = first_dilation or dilation
        act_layer = act_layer or nn.ReLU
        conv_layer = conv_layer or StdConv2d
        norm_layer = norm_layer or partial(GroupNormAct, num_groups=32)
        out_chs = out_chs or in_chs
        mid_chs = make_div(out_chs * bottle_ratio)

        # the shortcut projection only exists when a proj_layer is supplied
        if proj_layer is not None:
            self.downsample = proj_layer(
                in_chs, out_chs, stride=stride, dilation=dilation, preact=False,
                conv_layer=conv_layer, norm_layer=norm_layer)
        else:
            self.downsample = None

        self.conv1 = conv_layer(in_chs, mid_chs, 1)
        self.norm1 = norm_layer(mid_chs)
        self.conv2 = conv_layer(mid_chs, mid_chs, 3, stride=stride, dilation=first_dilation, groups=groups)
        self.norm2 = norm_layer(mid_chs)
        self.conv3 = conv_layer(mid_chs, out_chs, 1)
        self.norm3 = norm_layer(out_chs, apply_act=False)
        self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0 else nn.Identity()
        self.act3 = act_layer(inplace=True)

    def forward(self, x):
        # shortcut branch
        shortcut = x
        if self.downsample is not None:
            shortcut = self.downsample(x)

        # residual
        x = self.conv1(x)
        x = self.norm1(x)
        x = self.conv2(x)
        x = self.norm2(x)
        x = self.conv3(x)
        x = self.norm3(x)
        x = self.drop_path(x)
        x = self.act3(x + shortcut)
        return x
class DownsampleConv(nn.Module):
    """Shortcut projection: a 1x1 conv carrying the stride, then an optional norm.

    With ``preact=True`` no norm is applied (identity); otherwise the norm layer
    is created with ``apply_act=False``. ``dilation`` and ``first_dilation``
    are accepted for signature parity and not used.
    """

    def __init__(
            self, in_chs, out_chs, stride=1, dilation=1, first_dilation=None, preact=True,
            conv_layer=None, norm_layer=None):
        super(DownsampleConv, self).__init__()
        self.conv = conv_layer(in_chs, out_chs, 1, stride=stride)
        self.norm = nn.Identity() if preact else norm_layer(out_chs, apply_act=False)

    def forward(self, x):
        return self.norm(self.conv(x))
class DownsampleAvg(nn.Module):
    """Shortcut projection: average pool, then a stride-1 1x1 conv and an optional norm."""

    def __init__(
            self, in_chs, out_chs, stride=1, dilation=1, first_dilation=None,
            preact=True, conv_layer=None, norm_layer=None):
        """ AvgPool Downsampling as in 'D' ResNet variants. This is not in RegNet space but I might experiment."""
        super(DownsampleAvg, self).__init__()
        # when dilated, the pool keeps stride 1
        avg_stride = stride if dilation == 1 else 1
        if stride > 1 or dilation > 1:
            avg_pool_fn = AvgPool2dSame if avg_stride == 1 and dilation > 1 else nn.AvgPool2d
            self.pool = avg_pool_fn(2, avg_stride, ceil_mode=True, count_include_pad=False)
        else:
            self.pool = nn.Identity()
        self.conv = conv_layer(in_chs, out_chs, 1, stride=1)
        self.norm = nn.Identity() if preact else norm_layer(out_chs, apply_act=False)

    def forward(self, x):
        return self.norm(self.conv(self.pool(x)))
class ResNetStage(nn.Module):
    """ResNet Stage.

    A stage is ``depth`` blocks built by ``block_fn`` and stored in ``self.blocks``
    under the names '0', '1', ... Only the first block receives the stage stride,
    the shortcut projection layer, the stage input width and the reduced first
    dilation; every later block runs at stride 1 with ``out_chs`` in and out.
    """

    def __init__(self, in_chs, out_chs, stride, dilation, depth, bottle_ratio=0.25, groups=1,
                 avg_down=False, block_dpr=None, block_fn=PreActBottleneck,
                 act_layer=None, conv_layer=None, norm_layer=None, **block_kwargs):
        super(ResNetStage, self).__init__()
        # settings that apply to the first block only
        entry_dilation = 1 if dilation in (1, 2) else 2
        entry_proj = DownsampleAvg if avg_down else DownsampleConv
        common = dict(act_layer=act_layer, conv_layer=conv_layer, norm_layer=norm_layer)

        self.blocks = nn.Sequential()
        for idx in range(depth):
            is_first = idx == 0
            block = block_fn(
                in_chs if is_first else out_chs,
                out_chs,
                stride=stride if is_first else 1,
                dilation=dilation,
                bottle_ratio=bottle_ratio,
                groups=groups,
                first_dilation=entry_dilation if is_first else dilation,
                proj_layer=entry_proj if is_first else None,
                drop_path_rate=block_dpr[idx] if block_dpr else 0.,
                **common, **block_kwargs)
            self.blocks.add_module(str(idx), block)

    def forward(self, x):
        return self.blocks(x)
def create_stem(in_chs, out_chs, stem_type='', preact=True, conv_layer=None, norm_layer=None):
    """Build the network stem as an ``nn.Sequential`` with named children.

    Args:
        in_chs: Number of input channels.
        out_chs: Number of channels produced by the stem conv(s).
        stem_type: One of '', 'fixed', 'same', 'deep', 'deep_fixed', 'deep_same'.
            'deep' selects three 3x3 convs instead of one 7x7 conv; 'fixed' / 'same'
            select the padding scheme of the max pool.
        preact: When False, ``norm_layer(out_chs)`` is added after the conv(s).
        conv_layer: Callable used to build every conv in the stem.
        norm_layer: Callable used to build the norm when ``preact`` is False.

    Returns:
        ``nn.Sequential`` whose children are named 'conv' (or 'conv1'..'conv3'),
        optionally 'norm', optionally 'pad', and 'pool'.
    """
    stem = OrderedDict()
    assert stem_type in ('', 'fixed', 'same', 'deep', 'deep_fixed', 'deep_same')

    # NOTE conv padding mode can be changed by overriding the conv_layer def
    if 'deep' in stem_type:
        # A 3 deep 3x3 conv stack as in ResNet V1D models
        mid_chs = out_chs // 2
        stem['conv1'] = conv_layer(in_chs, mid_chs, kernel_size=3, stride=2)
        stem['conv2'] = conv_layer(mid_chs, mid_chs, kernel_size=3, stride=1)
        stem['conv3'] = conv_layer(mid_chs, out_chs, kernel_size=3, stride=1)
    else:
        # The usual 7x7 stem conv (mutually exclusive with the deep stack above)
        stem['conv'] = conv_layer(in_chs, out_chs, kernel_size=7, stride=2)

    if not preact:
        stem['norm'] = norm_layer(out_chs)

    if 'fixed' in stem_type:
        # 'fixed' SAME padding approximation that is used in BiT models
        stem['pad'] = nn.ConstantPad2d(1, 0.)
        stem['pool'] = nn.MaxPool2d(kernel_size=3, stride=2, padding=0)
    elif 'same' in stem_type:
        # full, input size based 'SAME' padding, used in ViT Hybrid model
        stem['pool'] = create_pool2d('max', kernel_size=3, stride=2, padding='same')
    else:
        # the usual PyTorch symmetric padding
        stem['pool'] = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

    return nn.Sequential(stem)
class ResNetV2(nn.Module):
    """Implementation of Pre-activation (v2) ResNet model.

    The network is a stem (``create_stem``), a sequence of ``ResNetStage`` modules,
    a final norm (identity when ``preact`` is False) and a ``ClassifierHead`` built
    with ``use_conv=True``.
    """

    def __init__(self, layers, channels=(256, 512, 1024, 2048),
                 num_classes=1000, in_chans=3, global_pool='avg', output_stride=32,
                 width_factor=1, stem_chs=64, stem_type='', avg_down=False, preact=True,
                 act_layer=nn.ReLU, conv_layer=StdConv2d, norm_layer=partial(GroupNormAct, num_groups=32),
                 drop_rate=0., drop_path_rate=0.):
        """
        Args:
            layers: Number of blocks in each stage.
            channels: Base output width of each stage, scaled by ``width_factor``.
            num_classes: Number of classifier outputs.
            in_chans: Number of input image channels.
            global_pool: Pooling type passed to the classifier head.
            output_stride: Once the running stride reaches this value, later stages
                switch from striding to dilation.
            width_factor: Multiplier applied to the stem and stage widths.
            stem_chs: Base stem width, scaled by ``width_factor``.
            stem_type: Stem variant, see ``create_stem``.
            avg_down: Use the avg-pool shortcut projection inside stages.
            preact: Use ``PreActBottleneck`` blocks and a final norm; otherwise
                ``Bottleneck`` blocks and a norm in the stem.
            act_layer, conv_layer, norm_layer: Layer factories passed to stem and stages.
            drop_rate: Dropout rate passed to the classifier head.
            drop_path_rate: Maximum drop-path rate; per-block rates ramp linearly from 0.
        """
        # nn.Module must be initialised before any submodule is assigned
        super(ResNetV2, self).__init__()
        self.num_classes = num_classes
        self.drop_rate = drop_rate
        wf = width_factor

        self.feature_info = []
        stem_chs = make_div(stem_chs * wf)
        self.stem = create_stem(in_chans, stem_chs, stem_type, preact, conv_layer=conv_layer, norm_layer=norm_layer)
        # NOTE no, reduction 2 feature if preact
        self.feature_info.append(dict(num_chs=stem_chs, reduction=2, module='' if preact else 'stem.norm'))

        prev_chs = stem_chs
        curr_stride = 4
        dilation = 1
        # one list of drop-path rates per stage, ramping linearly over all blocks
        block_dprs = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(layers)).split(layers)]
        block_fn = PreActBottleneck if preact else Bottleneck
        self.stages = nn.Sequential()
        for stage_idx, (d, c, bdpr) in enumerate(zip(layers, channels, block_dprs)):
            out_chs = make_div(c * wf)
            stride = 1 if stage_idx == 0 else 2
            if curr_stride >= output_stride:
                # target output stride reached: trade further striding for dilation
                dilation *= stride
                stride = 1
            stage = ResNetStage(
                prev_chs, out_chs, stride=stride, dilation=dilation, depth=d, avg_down=avg_down,
                act_layer=act_layer, conv_layer=conv_layer, norm_layer=norm_layer, block_dpr=bdpr, block_fn=block_fn)
            prev_chs = out_chs
            curr_stride *= stride
            feat_name = f'stages.{stage_idx}'
            if preact:
                # in pre-activation mode the feature is tapped at the first norm of the
                # following stage, or at the final norm for the last stage
                feat_name = f'stages.{stage_idx + 1}.blocks.0.norm1' if (stage_idx + 1) != len(channels) else 'norm'
            self.feature_info += [dict(num_chs=prev_chs, reduction=curr_stride, module=feat_name)]
            self.stages.add_module(str(stage_idx), stage)

        self.num_features = prev_chs
        self.norm = norm_layer(self.num_features) if preact else nn.Identity()
        self.head = ClassifierHead(
            self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate, use_conv=True)

        for n, m in self.named_modules():
            if isinstance(m, nn.Linear) or ('.fc' in n and isinstance(m, nn.Conv2d)):
                nn.init.normal_(m.weight, mean=0.0, std=0.01)
            elif isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')

    def get_classifier(self):
        """Return the classifier layer of the head (``self.head.fc``)."""
        return self.head.fc

    def reset_classifier(self, num_classes, global_pool='avg'):
        """Replace the head with a fresh ``ClassifierHead`` for ``num_classes`` outputs."""
        # keep the recorded class count in sync with the new head
        self.num_classes = num_classes
        self.head = ClassifierHead(
            self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate, use_conv=True)

    def forward_features(self, x):
        """Apply stem, stages and the final norm; the head is not applied."""
        x = self.stem(x)
        x = self.stages(x)
        x = self.norm(x)
        return x

    def forward(self, x):
        x = self.forward_features(x)
        x = self.head(x)
        if not self.head.global_pool.is_identity():
            x = x.flatten(1)  # conv classifier, flatten if pooling isn't pass-through (disabled)
        return x

    def load_pretrained(self, checkpoint_path, prefix='resnet/'):
        """Copy weights from a NumPy checkpoint (read with ``np.load``) into this model.

        Args:
            checkpoint_path: Path handed to ``np.load``.
            prefix: Key prefix of the entries inside the checkpoint.

        NOTE(review): only the stem conv kernel and the per-block shortcut projection
        kernels are copied by the code in this view. Whether the remaining block,
        norm and head parameters are loaded elsewhere could not be confirmed here -
        verify before relying on this for a full weight load.
        """
        import numpy as np
        weights = np.load(checkpoint_path)
        with torch.no_grad():
            # tf2th presumably converts the stored kernel to torch layout - TODO confirm
            stem_conv_w = tf2th(weights[f'{prefix}root_block/standardized_conv2d/kernel'])
            if self.stem.conv.weight.shape[1] == 1:
                # single input channel: sum the kernel over its input-channel dim
                self.stem.conv.weight.copy_(stem_conv_w.sum(dim=1, keepdim=True))
            else:
                # FIXME handle > 3 in_chans?
                self.stem.conv.weight.copy_(stem_conv_w)
            for i, stage in enumerate(self.stages.children()):
                for j, block in enumerate(stage.blocks.children()):
                    convname = 'standardized_conv2d'
                    block_prefix = f'{prefix}block{i + 1}/unit{j + 1:02d}/'
                    if block.downsample is not None:
                        w = weights[f'{block_prefix}a/proj/{convname}/kernel']
                        # both projection layer classes expose their 1x1 conv as `.conv`
                        block.downsample.conv.weight.copy_(tf2th(w))
def _create_resnetv2(variant, pretrained=False, **kwargs):
    """Build a ``ResNetV2`` variant through ``build_model_with_cfg``.

    The default config is looked up in ``default_cfgs`` by ``variant``. In
    pre-activation mode (``preact`` absent from ``kwargs`` or truthy) feature
    extraction is configured to use hooks on output indices 1-4.
    """
    # FIXME feature map extraction is not setup properly for pre-activation mode right now
    feature_cfg = {'flatten_sequential': True}
    if kwargs.get('preact', True):
        feature_cfg.update(
            feature_cls='hook',
            out_indices=(1, 2, 3, 4),  # no stride 2, 0 level feat for preact
        )
    return build_model_with_cfg(
        ResNetV2, variant, pretrained,
        default_cfg=default_cfgs[variant],
        pretrained_custom_load=True,
        feature_cfg=feature_cfg,
        **kwargs)
def resnetv2_50x1_bitm(pretrained=False, **kwargs):
    """Create the 'resnetv2_50x1_bitm' model: layers [3, 4, 6, 3], width factor 1, 'fixed' stem."""
    model_args = dict(layers=[3, 4, 6, 3], width_factor=1, stem_type='fixed')
    return _create_resnetv2('resnetv2_50x1_bitm', pretrained=pretrained, **model_args, **kwargs)
def resnetv2_50x3_bitm(pretrained=False, **kwargs):
    """Create the 'resnetv2_50x3_bitm' model: layers [3, 4, 6, 3], width factor 3, 'fixed' stem."""
    model_args = dict(layers=[3, 4, 6, 3], width_factor=3, stem_type='fixed')
    return _create_resnetv2('resnetv2_50x3_bitm', pretrained=pretrained, **model_args, **kwargs)
def resnetv2_101x1_bitm(pretrained=False, **kwargs):
    """Create the 'resnetv2_101x1_bitm' model: layers [3, 4, 23, 3], width factor 1, 'fixed' stem."""
    model_args = dict(layers=[3, 4, 23, 3], width_factor=1, stem_type='fixed')
    return _create_resnetv2('resnetv2_101x1_bitm', pretrained=pretrained, **model_args, **kwargs)
def resnetv2_101x3_bitm(pretrained=False, **kwargs):
    """Create the 'resnetv2_101x3_bitm' model: layers [3, 4, 23, 3], width factor 3, 'fixed' stem."""
    model_args = dict(layers=[3, 4, 23, 3], width_factor=3, stem_type='fixed')
    return _create_resnetv2('resnetv2_101x3_bitm', pretrained=pretrained, **model_args, **kwargs)
def resnetv2_152x2_bitm(pretrained=False, **kwargs):
    """Create the 'resnetv2_152x2_bitm' model: layers [3, 8, 36, 3], width factor 2, 'fixed' stem."""
    model_args = dict(layers=[3, 8, 36, 3], width_factor=2, stem_type='fixed')
    return _create_resnetv2('resnetv2_152x2_bitm', pretrained=pretrained, **model_args, **kwargs)
def resnetv2_152x4_bitm(pretrained=False, **kwargs):
    """Create the 'resnetv2_152x4_bitm' model: layers [3, 8, 36, 3], width factor 4, 'fixed' stem."""
    model_args = dict(layers=[3, 8, 36, 3], width_factor=4, stem_type='fixed')
    return _create_resnetv2('resnetv2_152x4_bitm', pretrained=pretrained, **model_args, **kwargs)
def resnetv2_50x1_bitm_in21k(pretrained=False, **kwargs):
    """Create the 'resnetv2_50x1_bitm_in21k' model: layers [3, 4, 6, 3], width factor 1, 'fixed' stem.

    ``num_classes`` defaults to 21843 unless the caller supplies it.
    """
    # take num_classes out of kwargs first so it is not passed twice
    num_classes = kwargs.pop('num_classes', 21843)
    model_args = dict(layers=[3, 4, 6, 3], width_factor=1, stem_type='fixed')
    return _create_resnetv2(
        'resnetv2_50x1_bitm_in21k', pretrained=pretrained, num_classes=num_classes,
        **model_args, **kwargs)
def resnetv2_50x3_bitm_in21k(pretrained=False, **kwargs):
    """Create the 'resnetv2_50x3_bitm_in21k' model: layers [3, 4, 6, 3], width factor 3, 'fixed' stem.

    ``num_classes`` defaults to 21843 unless the caller supplies it.
    """
    # take num_classes out of kwargs first so it is not passed twice
    num_classes = kwargs.pop('num_classes', 21843)
    model_args = dict(layers=[3, 4, 6, 3], width_factor=3, stem_type='fixed')
    return _create_resnetv2(
        'resnetv2_50x3_bitm_in21k', pretrained=pretrained, num_classes=num_classes,
        **model_args, **kwargs)
def resnetv2_101x1_bitm_in21k(pretrained=False, **kwargs):
    """Create the 'resnetv2_101x1_bitm_in21k' model: layers [3, 4, 23, 3], width factor 1, 'fixed' stem.

    ``num_classes`` defaults to 21843 unless the caller supplies it.
    """
    # take num_classes out of kwargs first so it is not passed twice
    num_classes = kwargs.pop('num_classes', 21843)
    model_args = dict(layers=[3, 4, 23, 3], width_factor=1, stem_type='fixed')
    return _create_resnetv2(
        'resnetv2_101x1_bitm_in21k', pretrained=pretrained, num_classes=num_classes,
        **model_args, **kwargs)
def resnetv2_101x3_bitm_in21k(pretrained=False, **kwargs):
    """Create the 'resnetv2_101x3_bitm_in21k' model: layers [3, 4, 23, 3], width factor 3, 'fixed' stem.

    ``num_classes`` defaults to 21843 unless the caller supplies it.
    """
    # take num_classes out of kwargs first so it is not passed twice
    num_classes = kwargs.pop('num_classes', 21843)
    model_args = dict(layers=[3, 4, 23, 3], width_factor=3, stem_type='fixed')
    return _create_resnetv2(
        'resnetv2_101x3_bitm_in21k', pretrained=pretrained, num_classes=num_classes,
        **model_args, **kwargs)
def resnetv2_152x2_bitm_in21k(pretrained=False, **kwargs):
    """Create the 'resnetv2_152x2_bitm_in21k' model: layers [3, 8, 36, 3], width factor 2, 'fixed' stem.

    ``num_classes`` defaults to 21843 unless the caller supplies it.
    """
    # take num_classes out of kwargs first so it is not passed twice
    num_classes = kwargs.pop('num_classes', 21843)
    model_args = dict(layers=[3, 8, 36, 3], width_factor=2, stem_type='fixed')
    return _create_resnetv2(
        'resnetv2_152x2_bitm_in21k', pretrained=pretrained, num_classes=num_classes,
        **model_args, **kwargs)
def resnetv2_152x4_bitm_in21k(pretrained=False, **kwargs):
    """Create the 'resnetv2_152x4_bitm_in21k' model: layers [3, 8, 36, 3], width factor 4, 'fixed' stem.

    ``num_classes`` defaults to 21843 unless the caller supplies it.
    """
    # take num_classes out of kwargs first so it is not passed twice
    num_classes = kwargs.pop('num_classes', 21843)
    model_args = dict(layers=[3, 8, 36, 3], width_factor=4, stem_type='fixed')
    return _create_resnetv2(
        'resnetv2_152x4_bitm_in21k', pretrained=pretrained, num_classes=num_classes,
        **model_args, **kwargs)
# NOTE the 'S' versions of the model weights aren't as interesting as the original 21k or transfer to 1K 'M' weights.
# @register_model
# def resnetv2_50x1_bits(pretrained=False, **kwargs):
# return _create_resnetv2(
# 'resnetv2_50x1_bits', pretrained=pretrained,
# layers=[3, 4, 6, 3], width_factor=1, stem_type='fixed', **kwargs)
# @register_model
# def resnetv2_50x3_bits(pretrained=False, **kwargs):
# return _create_resnetv2(
# 'resnetv2_50x3_bits', pretrained=pretrained,
# layers=[3, 4, 6, 3], width_factor=3, stem_type='fixed', **kwargs)
# @register_model
# def resnetv2_101x1_bits(pretrained=False, **kwargs):
# return _create_resnetv2(
# 'resnetv2_101x1_bits', pretrained=pretrained,
# layers=[3, 4, 23, 3], width_factor=1, stem_type='fixed', **kwargs)
# @register_model
# def resnetv2_101x3_bits(pretrained=False, **kwargs):
# return _create_resnetv2(
# 'resnetv2_101x3_bits', pretrained=pretrained,
# layers=[3, 4, 23, 3], width_factor=3, stem_type='fixed', **kwargs)
# @register_model
# def resnetv2_152x2_bits(pretrained=False, **kwargs):
# return _create_resnetv2(
# 'resnetv2_152x2_bits', pretrained=pretrained,
# layers=[3, 8, 36, 3], width_factor=2, stem_type='fixed', **kwargs)
# @register_model
# def resnetv2_152x4_bits(pretrained=False, **kwargs):
# return _create_resnetv2(
# 'resnetv2_152x4_bits', pretrained=pretrained,
# layers=[3, 8, 36, 3], width_factor=4, stem_type='fixed', **kwargs)
@ -1 +1 @@
__version__ = '0.3.4'
__version__ = '0.4.0'
Reference in new issue