|
|
|
@ -29,6 +29,7 @@ except ImportError:
|
|
|
|
|
wds = None
|
|
|
|
|
expand_urls = None
|
|
|
|
|
|
|
|
|
|
from .class_map import load_class_map
|
|
|
|
|
from .reader import Reader
|
|
|
|
|
from .shared_count import SharedCount
|
|
|
|
|
|
|
|
|
@ -42,13 +43,13 @@ def _load_info(root, basename='info'):
|
|
|
|
|
info_yaml = os.path.join(root, basename + '.yaml')
|
|
|
|
|
err_str = ''
|
|
|
|
|
try:
|
|
|
|
|
with wds.gopen.gopen(info_json) as f:
|
|
|
|
|
with wds.gopen(info_json) as f:
|
|
|
|
|
info_dict = json.load(f)
|
|
|
|
|
return info_dict
|
|
|
|
|
except Exception as e:
|
|
|
|
|
err_str = str(e)
|
|
|
|
|
try:
|
|
|
|
|
with wds.gopen.gopen(info_yaml) as f:
|
|
|
|
|
with wds.gopen(info_yaml) as f:
|
|
|
|
|
info_dict = yaml.safe_load(f)
|
|
|
|
|
return info_dict
|
|
|
|
|
except Exception:
|
|
|
|
@ -110,8 +111,8 @@ def _parse_split_info(split: str, info: Dict):
|
|
|
|
|
filenames=split_filenames,
|
|
|
|
|
)
|
|
|
|
|
else:
|
|
|
|
|
if split not in info['splits']:
|
|
|
|
|
raise RuntimeError(f"split {split} not found in info ({info['splits'].keys()})")
|
|
|
|
|
if 'splits' not in info or split not in info['splits']:
|
|
|
|
|
raise RuntimeError(f"split {split} not found in info ({info.get('splits', {}).keys()})")
|
|
|
|
|
split = split
|
|
|
|
|
split_info = info['splits'][split]
|
|
|
|
|
split_info = _info_convert(split_info)
|
|
|
|
@ -290,6 +291,7 @@ class ReaderWds(Reader):
|
|
|
|
|
batch_size=None,
|
|
|
|
|
repeats=0,
|
|
|
|
|
seed=42,
|
|
|
|
|
class_map=None,
|
|
|
|
|
input_name='jpg',
|
|
|
|
|
input_image='RGB',
|
|
|
|
|
target_name='cls',
|
|
|
|
@ -320,6 +322,12 @@ class ReaderWds(Reader):
|
|
|
|
|
self.num_samples = self.split_info.num_samples
|
|
|
|
|
if not self.num_samples:
|
|
|
|
|
raise RuntimeError(f'Invalid split definition, no samples found.')
|
|
|
|
|
self.remap_class = False
|
|
|
|
|
if class_map:
|
|
|
|
|
self.class_to_idx = load_class_map(class_map)
|
|
|
|
|
self.remap_class = True
|
|
|
|
|
else:
|
|
|
|
|
self.class_to_idx = {}
|
|
|
|
|
|
|
|
|
|
# Distributed world state
|
|
|
|
|
self.dist_rank = 0
|
|
|
|
@ -431,7 +439,10 @@ class ReaderWds(Reader):
|
|
|
|
|
i = 0
|
|
|
|
|
# _logger.info(f'start {i}, {self.worker_id}') # FIXME temporary debug
|
|
|
|
|
for sample in ds:
|
|
|
|
|
yield sample[self.image_key], sample[self.target_key]
|
|
|
|
|
target = sample[self.target_key]
|
|
|
|
|
if self.remap_class:
|
|
|
|
|
target = self.class_to_idx[target]
|
|
|
|
|
yield sample[self.image_key], target
|
|
|
|
|
i += 1
|
|
|
|
|
# _logger.info(f'end {i}, {self.worker_id}') # FIXME temporary debug
|
|
|
|
|
|
|
|
|
|