diff --git a/timm/bits/device_env_xla.py b/timm/bits/device_env_xla.py index 18e0fd3b..518cd993 100644 --- a/timm/bits/device_env_xla.py +++ b/timm/bits/device_env_xla.py @@ -4,7 +4,6 @@ import torch try: import torch_xla.core.xla_model as xm - import torch_xla.amp as xa _HAS_XLA = True except ImportError as e: xm = None diff --git a/timm/data/parsers/parser_tfds.py b/timm/data/parsers/parser_tfds.py index 0b12a4db..92495d12 100644 --- a/timm/data/parsers/parser_tfds.py +++ b/timm/data/parsers/parser_tfds.py @@ -23,6 +23,7 @@ except ImportError as e: exit(1) from .parser import Parser +from timm.bits import get_device MAX_TP_SIZE = 8 # maximum TF threadpool size, only doing jpeg decodes and queuing activities SHUFFLE_SIZE = 16834 # samples to shuffle in DS queue @@ -79,9 +80,14 @@ class ParserTfds(Parser): self.worker_info = None self.dist_rank = 0 self.dist_num_replicas = 1 - if dist.is_available() and dist.is_initialized() and dist.get_world_size() > 1: - self.dist_rank = dist.get_rank() - self.dist_num_replicas = dist.get_world_size() + dev_env = get_device() + # FIXME allow to work without devenv usage? + if dev_env.is_distributed and dev_env.world_size > 1: + self.dist_rank = dev_env.global_rank + self.dist_num_replicas = dev_env.world_size + # if dist.is_available() and dist.is_initialized() and dist.get_world_size() > 1: + # self.dist_rank = dist.get_rank() + # self.dist_num_replicas = dist.get_world_size() def _lazy_init(self): """ Lazily initialize the dataset.