Fix import issue, use devenv for dist info in parser_tfds

pull/1239/head
Ross Wightman 4 years ago
parent 76de984a5f
commit 938716c753

@ -4,7 +4,6 @@ import torch
try:
import torch_xla.core.xla_model as xm
import torch_xla.amp as xa
_HAS_XLA = True
except ImportError as e:
xm = None

@ -23,6 +23,7 @@ except ImportError as e:
exit(1)
from .parser import Parser
from timm.bits import get_device
MAX_TP_SIZE = 8 # maximum TF threadpool size, only doing jpeg decodes and queuing activities
SHUFFLE_SIZE = 16834 # samples to shuffle in DS queue
@ -79,9 +80,14 @@ class ParserTfds(Parser):
self.worker_info = None
self.dist_rank = 0
self.dist_num_replicas = 1
if dist.is_available() and dist.is_initialized() and dist.get_world_size() > 1:
self.dist_rank = dist.get_rank()
self.dist_num_replicas = dist.get_world_size()
dev_env = get_device()
# FIXME allow to work without devenv usage?
if dev_env.is_distributed and dev_env.world_size > 1:
self.dist_rank = dev_env.global_rank
self.dist_num_replicas = dev_env.world_size
# if dist.is_available() and dist.is_initialized() and dist.get_world_size() > 1:
# self.dist_rank = dist.get_rank()
# self.dist_num_replicas = dist.get_world_size()
def _lazy_init(self):
""" Lazily initialize the dataset.

Loading…
Cancel
Save