Fix logs in WDS parser

pull/1479/head
Ross Wightman 2 years ago
parent b8c8550841
commit 4f18d6dc5f

@ -228,7 +228,7 @@ if wds is not None:
seed = pytorch_worker_seed() + epoch seed = pytorch_worker_seed() + epoch
else: else:
seed = self.seed + epoch seed = self.seed + epoch
_logger.info('shuffle', self.seed, epoch, seed) # FIXME temporary _logger.info(f'shuffle seed: {self.seed}, {seed}, epoch: {epoch}') # FIXME temporary
rng = random.Random(seed) rng = random.Random(seed)
return _shuffle(src, self.bufsize, self.initial, rng) return _shuffle(src, self.bufsize, self.initial, rng)
@ -417,11 +417,11 @@ class ParserWds(Parser):
ds = self.ds ds = self.ds
i = 0 i = 0
_logger.info('start', i, self.worker_id) # FIXME temporary debug _logger.info(f'start {i}, {self.worker_id}') # FIXME temporary debug
for sample in ds: for sample in ds:
yield sample[self.image_key], sample[self.target_key] yield sample[self.image_key], sample[self.target_key]
i += 1 i += 1
_logger.info('end', i, self.worker_id) # FIXME temporary debug _logger.info(f'end {i}, {self.worker_id}') # FIXME temporary debug
def __len__(self): def __len__(self):
return math.ceil(max(1, self.repeats) * self.num_samples / self.dist_num_replicas) return math.ceil(max(1, self.repeats) * self.num_samples / self.dist_num_replicas)

Loading…
Cancel
Save