Fix partially removed alt_lable impl from TFDS variant of ImageNet22/12k

pull/1414/head
Ross Wightman 3 years ago
parent 5e1be34a60
commit 95739b45d7

@ -77,12 +77,12 @@ class Imagenet22k(tfds.core.GeneratorBasedBuilder):
'validation': self._generate_examples(val_records, manual_dir), 'validation': self._generate_examples(val_records, manual_dir),
} }
def _generate_examples(self, records, manual_dir, alt_label=None, resize_short=True, max_img_size=MAX_DIM): def _generate_examples(self, records, manual_dir, resize_short=True, max_img_size=MAX_DIM):
"""Yields examples.""" """Yields examples."""
for r in records: for r in records:
try: try:
filename, output_record = _process_record( filename, output_record = _process_record(
r, manual_dir, alt_label=alt_label, resize_short=resize_short, max_img_size=max_img_size) r, manual_dir, resize_short=resize_short, max_img_size=max_img_size)
yield filename, output_record yield filename, output_record
except Exception as e: except Exception as e:
print('Exception:', e) print('Exception:', e)
@ -114,8 +114,6 @@ def _load_records(
train_csv, train_csv,
validation_csv, validation_csv,
labels, labels,
alt_labels=None,
alt_label_name='',
min_img_size=MIN_DIM, min_img_size=MIN_DIM,
): ):
pd = tfds.core.lazy_imports.pandas pd = tfds.core.lazy_imports.pandas
@ -133,12 +131,10 @@ def _load_records(
train_record_df['label'] = train_record_df['cls'].map(class_to_idx).astype(int) train_record_df['label'] = train_record_df['cls'].map(class_to_idx).astype(int)
train_record_df = train_record_df[['filename', 'label']] train_record_df = train_record_df[['filename', 'label']]
train_record_df = train_record_df.sample(frac=1, random_state=42)
print('num train records:', len(train_record_df.index)) print('num train records:', len(train_record_df.index))
val_record_df['label'] = val_record_df['cls'].map(class_to_idx).astype(int) val_record_df['label'] = val_record_df['cls'].map(class_to_idx).astype(int)
val_record_df = val_record_df[['filename', 'label']] val_record_df = val_record_df[['filename', 'label']]
val_record_df = val_record_df.sample(frac=1, random_state=42)
print('num val records:', len(val_record_df.index)) print('num val records:', len(val_record_df.index))
train_records = train_record_df.to_records(index=False) train_records = train_record_df.to_records(index=False)

Loading…
Cancel
Save