Merge branch 'master' of https://github.com/rwightman/pytorch-image-models into timm/code_refactoring

 Conflicts:
	timm/models/hrnet.py
pull/137/head
Vyacheslav Shults 5 years ago
commit 4de2869e23

@ -140,6 +140,7 @@ def create_loader(
pin_memory=False, pin_memory=False,
fp16=False, fp16=False,
tf_preprocessing=False, tf_preprocessing=False,
use_multi_epochs_loader=False
): ):
re_num_splits = 0 re_num_splits = 0
if re_split: if re_split:
@ -175,7 +176,12 @@ def create_loader(
if collate_fn is None: if collate_fn is None:
collate_fn = fast_collate if use_prefetcher else torch.utils.data.dataloader.default_collate collate_fn = fast_collate if use_prefetcher else torch.utils.data.dataloader.default_collate
loader = torch.utils.data.DataLoader( loader_class = torch.utils.data.DataLoader
if use_multi_epochs_loader:
loader_class = MultiEpochsDataLoader
loader = loader_class(
dataset, dataset,
batch_size=batch_size, batch_size=batch_size,
shuffle=sampler is None and is_training, shuffle=sampler is None and is_training,
@ -198,3 +204,35 @@ def create_loader(
) )
return loader return loader
class MultiEpochsDataLoader(torch.utils.data.DataLoader):
    """DataLoader variant that keeps one persistent iterator across epochs.

    A stock ``DataLoader`` tears down and respawns its workers at the start
    of every epoch. Here the batch sampler is wrapped in ``_RepeatSampler``
    (which yields batches forever) and a single iterator is created once, so
    per-epoch startup cost is paid only on the first epoch.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # DataLoader guards attribute mutation after construction via its
        # name-mangled ``__initialized`` flag; clear it briefly so the
        # batch_sampler swap below is permitted, then restore it.
        self._DataLoader__initialized = False
        self.batch_sampler = _RepeatSampler(self.batch_sampler)
        self._DataLoader__initialized = True
        # Long-lived iterator over the now-infinite batch stream; shared by
        # every subsequent __iter__ call.
        self.iterator = super().__iter__()

    def __len__(self):
        # One epoch's length: the length of the original (wrapped) batch
        # sampler, not the infinite repeat wrapper.
        return len(self.batch_sampler.sampler)

    def __iter__(self):
        # Draw exactly one epoch's worth of batches from the persistent
        # iterator; the underlying stream keeps going for the next epoch.
        for i in range(len(self)):
            yield next(self.iterator)
class _RepeatSampler(object):
""" Sampler that repeats forever.
Args:
sampler (Sampler)
"""
def __init__(self, sampler):
self.sampler = sampler
def __iter__(self):
while True:
yield from iter(self.sampler)

@ -675,7 +675,10 @@ class HighResolutionNet(nn.Module):
self.num_classes = num_classes self.num_classes = num_classes
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool) self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
num_features = self.num_features * self.global_pool.feat_mult() num_features = self.num_features * self.global_pool.feat_mult()
self.classifier = nn.Linear(num_features, num_classes) if num_classes else nn.Identity() if num_classes:
self.classifier = nn.Linear(num_features, num_classes)
else:
self.classifier = nn.Identity()
def forward_features(self, x): def forward_features(self, x):
x = self.conv1(x) x = self.conv1(x)

@ -198,6 +198,8 @@ parser.add_argument('--eval-metric', default='top1', type=str, metavar='EVAL_MET
parser.add_argument('--tta', type=int, default=0, metavar='N', parser.add_argument('--tta', type=int, default=0, metavar='N',
help='Test/inference time augmentation (oversampling) factor. 0=None (default: 0)') help='Test/inference time augmentation (oversampling) factor. 0=None (default: 0)')
parser.add_argument("--local_rank", default=0, type=int) parser.add_argument("--local_rank", default=0, type=int)
parser.add_argument('--use-multi-epochs-loader', action='store_true', default=False,
help='use the multi-epochs-loader to save time at the beginning of every epoch')
def _parse_args(): def _parse_args():
@ -391,6 +393,7 @@ def main():
distributed=args.distributed, distributed=args.distributed,
collate_fn=collate_fn, collate_fn=collate_fn,
pin_memory=args.pin_mem, pin_memory=args.pin_mem,
use_multi_epochs_loader=args.use_multi_epochs_loader
) )
eval_dir = os.path.join(args.data, 'val') eval_dir = os.path.join(args.data, 'val')

Loading…
Cancel
Save