You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
pytorch-image-models/timm/data/bag_sampler.py

18 lines
735 B

import torch
import random
from torch.utils.data.sampler import Sampler
# TODO: Ensure that all videos are covered. Make a copy of the indices and remove the ones that were already picked.
class BagSampler(Sampler):
    """Sampler that shuffles the two halves of a dataset independently.

    The indices ``0 .. len(dataset) - 1`` are split at ``len(dataset) // 2``.
    Each pass yields every index of the first half in random order, followed
    by every index of the second half in random order, so the two halves are
    never interleaved. For an odd length the second half holds the extra
    index.

    Shuffling draws from the global ``random`` module state.
    """

    def __init__(self, dataset):
        """Build the two index pools for *dataset*.

        Args:
            dataset: Any object supporting ``len()``; only its length is used.
        """
        total = len(dataset)
        # Floor division keeps the midpoint an exact integer; going through
        # float division (``int(total / 2)``) loses precision for huge sizes.
        midpoint = total // 2
        self.first_half_indices = list(range(midpoint))
        self.second_half_indices = list(range(midpoint, total))

    def __iter__(self):
        """Return an iterator over a freshly shuffled ordering of all indices.

        The stored index lists are shuffled in place, so their order persists
        between passes; the returned iterator walks a separate concatenated
        list and is unaffected by later reshuffles.
        """
        for half in (self.first_half_indices, self.second_half_indices):
            random.shuffle(half)
        return iter(self.first_half_indices + self.second_half_indices)

    def __len__(self):
        """Return the total number of indices yielded per pass."""
        return len(self.first_half_indices) + len(self.second_half_indices)