Switch TFDS dataset to use INTEGER_ACCURATE jpeg decode by default

pull/1520/head
Ross Wightman 2 years ago
parent 0ed0cc7eba
commit d1e0a4607d

@ -43,6 +43,15 @@ SHUFFLE_SIZE = int(os.environ.get('TFDS_SHUFFLE_SIZE', 8192)) # samples to shuf
PREFETCH_SIZE = int(os.environ.get('TFDS_PREFETCH_SIZE', 2048)) # samples to prefetch PREFETCH_SIZE = int(os.environ.get('TFDS_PREFETCH_SIZE', 2048)) # samples to prefetch
@tfds.decode.make_decoder()
def decode_example(serialized_image, feature, dct_method='INTEGER_ACCURATE'):
return tf.image.decode_jpeg(
serialized_image,
channels=3,
dct_method=dct_method,
)
def even_split_indices(split, n, num_samples): def even_split_indices(split, n, num_samples):
partitions = [round(i * num_samples / n) for i in range(n + 1)] partitions = [round(i * num_samples / n) for i in range(n + 1)]
return [f"{split}[{partitions[i]}:{partitions[i + 1]}]" for i in range(n)] return [f"{split}[{partitions[i]}:{partitions[i + 1]}]" for i in range(n)]
@ -242,6 +251,7 @@ class ReaderTfds(Reader):
ds = self.builder.as_dataset( ds = self.builder.as_dataset(
split=self.subsplit or self.split, split=self.subsplit or self.split,
shuffle_files=self.is_training, shuffle_files=self.is_training,
decoders=dict(image=decode_example()),
read_config=read_config, read_config=read_config,
) )
# avoid overloading threading w/ combo of TF ds threads + PyTorch workers # avoid overloading threading w/ combo of TF ds threads + PyTorch workers

Loading…
Cancel
Save