|
|
|
@ -43,6 +43,15 @@ SHUFFLE_SIZE = int(os.environ.get('TFDS_SHUFFLE_SIZE', 8192)) # samples to shuf
|
|
|
|
|
PREFETCH_SIZE = int(os.environ.get('TFDS_PREFETCH_SIZE', 2048)) # samples to prefetch
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@tfds.decode.make_decoder()
|
|
|
|
|
def decode_example(serialized_image, feature, dct_method='INTEGER_ACCURATE'):
|
|
|
|
|
return tf.image.decode_jpeg(
|
|
|
|
|
serialized_image,
|
|
|
|
|
channels=3,
|
|
|
|
|
dct_method=dct_method,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def even_split_indices(split, n, num_samples):
|
|
|
|
|
partitions = [round(i * num_samples / n) for i in range(n + 1)]
|
|
|
|
|
return [f"{split}[{partitions[i]}:{partitions[i + 1]}]" for i in range(n)]
|
|
|
|
@ -242,6 +251,7 @@ class ReaderTfds(Reader):
|
|
|
|
|
ds = self.builder.as_dataset(
|
|
|
|
|
split=self.subsplit or self.split,
|
|
|
|
|
shuffle_files=self.is_training,
|
|
|
|
|
decoders=dict(image=decode_example()),
|
|
|
|
|
read_config=read_config,
|
|
|
|
|
)
|
|
|
|
|
# avoid overloading threading w/ combo of TF ds threads + PyTorch workers
|
|
|
|
|