|
|
|
@ -23,6 +23,7 @@ except ImportError as e:
|
|
|
|
|
exit(1)
|
|
|
|
|
from .parser import Parser
|
|
|
|
|
|
|
|
|
|
from timm.bits import get_device
|
|
|
|
|
|
|
|
|
|
MAX_TP_SIZE = 8 # maximum TF threadpool size, only doing jpeg decodes and queuing activities
|
|
|
|
|
SHUFFLE_SIZE = 16834 # samples to shuffle in DS queue
|
|
|
|
@ -79,9 +80,14 @@ class ParserTfds(Parser):
|
|
|
|
|
self.worker_info = None
|
|
|
|
|
self.dist_rank = 0
|
|
|
|
|
self.dist_num_replicas = 1
|
|
|
|
|
if dist.is_available() and dist.is_initialized() and dist.get_world_size() > 1:
|
|
|
|
|
self.dist_rank = dist.get_rank()
|
|
|
|
|
self.dist_num_replicas = dist.get_world_size()
|
|
|
|
|
dev_env = get_device()
|
|
|
|
|
# FIXME allow to work without devenv usage?
|
|
|
|
|
if dev_env.is_distributed and dev_env.world_size > 1:
|
|
|
|
|
self.dist_rank = dev_env.global_rank
|
|
|
|
|
self.dist_num_replicas = dev_env.world_size
|
|
|
|
|
# if dist.is_available() and dist.is_initialized() and dist.get_world_size() > 1:
|
|
|
|
|
# self.dist_rank = dist.get_rank()
|
|
|
|
|
# self.dist_num_replicas = dist.get_world_size()
|
|
|
|
|
|
|
|
|
|
def _lazy_init(self):
|
|
|
|
|
""" Lazily initialize the dataset.
|
|
|
|
|