|
|
|
@ -31,14 +31,24 @@ SHUFFLE_SIZE = 8192
|
|
|
|
|
def _load_info(root, basename='info'):
|
|
|
|
|
info_json = os.path.join(root, basename + '.json')
|
|
|
|
|
info_yaml = os.path.join(root, basename + '.yaml')
|
|
|
|
|
info_dict = {}
|
|
|
|
|
if os.path.exists(info_json):
|
|
|
|
|
with open(info_json, 'r') as f:
|
|
|
|
|
err_str = ''
|
|
|
|
|
try:
|
|
|
|
|
with wds.gopen.gopen(info_json) as f:
|
|
|
|
|
info_dict = json.load(f)
|
|
|
|
|
elif os.path.exists(info_yaml):
|
|
|
|
|
with open(info_yaml, 'r') as f:
|
|
|
|
|
return info_dict
|
|
|
|
|
except Exception:
|
|
|
|
|
pass
|
|
|
|
|
try:
|
|
|
|
|
with wds.gopen.gopen(info_yaml) as f:
|
|
|
|
|
info_dict = yaml.safe_load(f)
|
|
|
|
|
return info_dict
|
|
|
|
|
return info_dict
|
|
|
|
|
except Exception as e:
|
|
|
|
|
err_str = str(e)
|
|
|
|
|
# FIXME change to log
|
|
|
|
|
print(f'Dataset info file not found at {info_json} or {info_yaml}. Error: {err_str}. '
|
|
|
|
|
f'Falling back to provided split and size arg.')
|
|
|
|
|
return {}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass
|
|
|
|
|
class SplitInfo:
|
|
|
|
@ -171,6 +181,9 @@ class ParserWebdataset(Parser):
|
|
|
|
|
shuffle_size=None,
|
|
|
|
|
):
|
|
|
|
|
super().__init__()
|
|
|
|
|
if wds is None:
|
|
|
|
|
raise RuntimeError(
|
|
|
|
|
'Please install webdataset 0.2.x package `pip install git+https://github.com/webdataset/webdataset`.')
|
|
|
|
|
self.root = root
|
|
|
|
|
self.is_training = is_training
|
|
|
|
|
self.batch_size = batch_size
|
|
|
|
|