diff --git a/timm/data/readers/reader_tfds.py b/timm/data/readers/reader_tfds.py index 7ccbf908..a327df7b 100644 --- a/timm/data/readers/reader_tfds.py +++ b/timm/data/readers/reader_tfds.py @@ -38,9 +38,9 @@ from .reader import Reader from .shared_count import SharedCount -MAX_TP_SIZE = os.environ.get('TFDS_TP_SIZE', 8) # maximum TF threadpool size, for jpeg decodes and queuing activities -SHUFFLE_SIZE = os.environ.get('TFDS_SHUFFLE_SIZE', 8192) # samples to shuffle in DS queue -PREFETCH_SIZE = os.environ.get('TFDS_PREFETCH_SIZE', 2048) # samples to prefetch +MAX_TP_SIZE = int(os.environ.get('TFDS_TP_SIZE', 8)) # maximum TF threadpool size, for jpeg decodes and queuing activities +SHUFFLE_SIZE = int(os.environ.get('TFDS_SHUFFLE_SIZE', 8192)) # samples to shuffle in DS queue +PREFETCH_SIZE = int(os.environ.get('TFDS_PREFETCH_SIZE', 2048)) # samples to prefetch def even_split_indices(split, n, num_samples): diff --git a/timm/data/readers/reader_wds.py b/timm/data/readers/reader_wds.py index afc9030b..36890eed 100644 --- a/timm/data/readers/reader_wds.py +++ b/timm/data/readers/reader_wds.py @@ -34,7 +34,7 @@ from .shared_count import SharedCount _logger = logging.getLogger(__name__) -SHUFFLE_SIZE = os.environ.get('WDS_SHUFFLE_SIZE', 8192) +SHUFFLE_SIZE = int(os.environ.get('WDS_SHUFFLE_SIZE', 8192)) def _load_info(root, basename='info'):