You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
18 lines
735 B
18 lines
735 B
import torch
|
|
import random
|
|
from torch.utils.data.sampler import Sampler
|
|
|
|
# TODO: Ensure that all videos are covered: make a copy of the indices and remove the indices that have already been picked.
|
|
class BagSampler(Sampler):
    """Yield every dataset index once per epoch, first half before second half.

    The index range ``[0, len(dataset))`` is split at ``len(dataset) // 2``.
    Each call to ``__iter__`` shuffles the two halves independently. It then
    yields every index of the first half, followed by every index of the
    second half. No second-half index is produced before the first half has
    been exhausted.

    Args:
        dataset: Any object supporting ``len()``; only its length is used.
        seed: Optional seed for a private ``random.Random`` instance, giving a
            reproducible ordering independent of global state. When ``None``
            (the default), the module-level ``random`` functions are used, so
            ``random.seed(...)`` controls the ordering.
    """

    def __init__(self, dataset, *, seed=None):
        num_samples = len(dataset)
        # Floor division on ints: same result as int(n / 2) for n >= 0,
        # without a round-trip through float.
        halfway_point = num_samples // 2
        self.first_half_indices = list(range(halfway_point))
        self.second_half_indices = list(range(halfway_point, num_samples))
        # Store None instead of the `random` module itself so the sampler stays
        # picklable / deep-copyable; random.Random instances pickle fine.
        self._rng = None if seed is None else random.Random(seed)

    def __iter__(self):
        shuffle = random.shuffle if self._rng is None else self._rng.shuffle
        # The stored lists are shuffled in place. The concatenation below builds
        # a new list, so iterators handed out earlier are not affected.
        shuffle(self.first_half_indices)
        shuffle(self.second_half_indices)
        return iter(self.first_half_indices + self.second_half_indices)

    def __len__(self):
        return len(self.first_half_indices) + len(self.second_half_indices)