Add synset/label indices for results generation. Add 'valid labels' to validation script to support imagenet-a/r label subsets properly.

pull/175/head
Ross Wightman 4 years ago
parent ec4976fdba
commit c53ec33ae0

@ -0,0 +1,200 @@
6
11
13
15
17
22
23
27
30
37
39
42
47
50
57
70
71
76
79
89
90
94
96
97
99
105
107
108
110
113
124
125
130
132
143
144
150
151
207
234
235
254
277
283
287
291
295
298
301
306
307
308
309
310
311
313
314
315
317
319
323
324
326
327
330
334
335
336
347
361
363
372
378
386
397
400
401
402
404
407
411
416
417
420
425
428
430
437
438
445
456
457
461
462
470
472
483
486
488
492
496
514
516
528
530
539
542
543
549
552
557
561
562
569
572
573
575
579
589
606
607
609
614
626
627
640
641
642
643
658
668
677
682
684
687
701
704
719
736
746
749
752
758
763
765
768
773
774
776
779
780
786
792
797
802
803
804
813
815
820
823
831
833
835
839
845
847
850
859
862
870
879
880
888
890
897
900
907
913
924
932
933
934
937
943
945
947
951
954
956
957
959
971
972
980
981
984
986
987
988

@ -0,0 +1,200 @@
n01498041
n01531178
n01534433
n01558993
n01580077
n01614925
n01616318
n01631663
n01641577
n01669191
n01677366
n01687978
n01694178
n01698640
n01735189
n01770081
n01770393
n01774750
n01784675
n01819313
n01820546
n01833805
n01843383
n01847000
n01855672
n01882714
n01910747
n01914609
n01924916
n01944390
n01985128
n01986214
n02007558
n02009912
n02037110
n02051845
n02077923
n02085620
n02099601
n02106550
n02106662
n02110958
n02119022
n02123394
n02127052
n02129165
n02133161
n02137549
n02165456
n02174001
n02177972
n02190166
n02206856
n02219486
n02226429
n02231487
n02233338
n02236044
n02259212
n02268443
n02279972
n02280649
n02281787
n02317335
n02325366
n02346627
n02356798
n02361337
n02410509
n02445715
n02454379
n02486410
n02492035
n02504458
n02655020
n02669723
n02672831
n02676566
n02690373
n02701002
n02730930
n02777292
n02782093
n02787622
n02793495
n02797295
n02802426
n02814860
n02815834
n02837789
n02879718
n02883205
n02895154
n02906734
n02948072
n02951358
n02980441
n02992211
n02999410
n03014705
n03026506
n03124043
n03125729
n03187595
n03196217
n03223299
n03250847
n03255030
n03291819
n03325584
n03355925
n03384352
n03388043
n03417042
n03443371
n03444034
n03445924
n03452741
n03483316
n03584829
n03590841
n03594945
n03617480
n03666591
n03670208
n03717622
n03720891
n03721384
n03724870
n03775071
n03788195
n03804744
n03837869
n03840681
n03854065
n03888257
n03891332
n03935335
n03982430
n04019541
n04033901
n04039381
n04067472
n04086273
n04099969
n04118538
n04131690
n04133789
n04141076
n04146614
n04147183
n04179913
n04208210
n04235860
n04252077
n04252225
n04254120
n04270147
n04275548
n04310018
n04317175
n04344873
n04347754
n04355338
n04366367
n04376876
n04389033
n04399382
n04442312
n04456115
n04482393
n04507155
n04509417
n04532670
n04540053
n04554684
n04562935
n04591713
n04606251
n07583066
n07695742
n07697313
n07697537
n07714990
n07718472
n07720875
n07734744
n07749582
n07753592
n07760859
n07768694
n07831146
n09229709
n09246464
n09472597
n09835506
n11879895
n12057211
n12144580
n12267677

@ -0,0 +1,200 @@
1
2
4
6
8
9
11
13
22
23
26
29
31
39
47
63
71
76
79
84
90
94
96
97
99
100
105
107
113
122
125
130
132
144
145
147
148
150
151
155
160
161
162
163
171
172
178
187
195
199
203
207
208
219
231
232
234
235
242
245
247
250
251
254
259
260
263
265
267
269
276
277
281
288
289
291
292
293
296
299
301
308
309
310
311
314
315
319
323
327
330
334
335
337
338
340
341
344
347
353
355
361
362
365
366
367
368
372
388
390
393
397
401
407
413
414
425
428
430
435
437
441
447
448
457
462
463
469
470
471
472
476
483
487
515
546
555
558
570
579
583
587
593
594
596
609
613
617
621
629
637
657
658
701
717
724
763
768
774
776
779
780
787
805
812
815
820
824
833
847
852
866
875
883
889
895
907
928
931
932
933
934
936
937
943
945
947
948
949
951
953
954
957
963
965
967
980
981
983
988

@ -0,0 +1,200 @@
n01443537
n01484850
n01494475
n01498041
n01514859
n01518878
n01531178
n01534433
n01614925
n01616318
n01630670
n01632777
n01644373
n01677366
n01694178
n01748264
n01770393
n01774750
n01784675
n01806143
n01820546
n01833805
n01843383
n01847000
n01855672
n01860187
n01882714
n01910747
n01944390
n01983481
n01986214
n02007558
n02009912
n02051845
n02056570
n02066245
n02071294
n02077923
n02085620
n02086240
n02088094
n02088238
n02088364
n02088466
n02091032
n02091134
n02092339
n02094433
n02096585
n02097298
n02098286
n02099601
n02099712
n02102318
n02106030
n02106166
n02106550
n02106662
n02108089
n02108915
n02109525
n02110185
n02110341
n02110958
n02112018
n02112137
n02113023
n02113624
n02113799
n02114367
n02117135
n02119022
n02123045
n02128385
n02128757
n02129165
n02129604
n02130308
n02134084
n02138441
n02165456
n02190166
n02206856
n02219486
n02226429
n02233338
n02236044
n02268443
n02279972
n02317335
n02325366
n02346627
n02356798
n02363005
n02364673
n02391049
n02395406
n02398521
n02410509
n02423022
n02437616
n02445715
n02447366
n02480495
n02480855
n02481823
n02483362
n02486410
n02510455
n02526121
n02607072
n02655020
n02672831
n02701002
n02749479
n02769748
n02793495
n02797295
n02802426
n02808440
n02814860
n02823750
n02841315
n02843684
n02883205
n02906734
n02909870
n02939185
n02948072
n02950826
n02951358
n02966193
n02980441
n02992529
n03124170
n03272010
n03345487
n03372029
n03424325
n03452741
n03467068
n03481172
n03494278
n03495258
n03498962
n03594945
n03602883
n03630383
n03649909
n03676483
n03710193
n03773504
n03775071
n03888257
n03930630
n03947888
n04086273
n04118538
n04133789
n04141076
n04146614
n04147183
n04192698
n04254680
n04266014
n04275548
n04310018
n04325704
n04347754
n04389033
n04409515
n04465501
n04487394
n04522168
n04536866
n04552348
n04591713
n07614500
n07693725
n07695742
n07697313
n07697537
n07714571
n07714990
n07718472
n07720875
n07734744
n07742313
n07745940
n07749582
n07753275
n07753592
n07768694
n07873807
n07880968
n07920052
n09472597
n09835506
n10565667
n12267677

@ -39,14 +39,13 @@ def find_images_and_targets(folder, types=IMG_EXTENSIONS, class_to_idx=None, lea
unique_labels = set(labels)
sorted_labels = list(sorted(unique_labels, key=natural_key))
class_to_idx = {c: idx for idx, c in enumerate(sorted_labels)}
images_and_targets = zip(filenames, [class_to_idx[l] for l in labels])
images_and_targets = [(f, class_to_idx[l]) for f, l in zip(filenames, labels) if l in class_to_idx]
if sort:
images_and_targets = sorted(images_and_targets, key=lambda k: natural_key(k[0]))
return images_and_targets, class_to_idx
def load_class_map(filename, root=''):
class_to_idx = {}
class_map_path = filename
if not os.path.exists(class_map_path):
class_map_path = os.path.join(root, filename)
@ -74,8 +73,8 @@ class Dataset(data.Dataset):
class_to_idx = load_class_map(class_map, root)
images, class_to_idx = find_images_and_targets(root, class_to_idx=class_to_idx)
if len(images) == 0:
raise(RuntimeError("Found 0 images in subfolders of: " + root + "\n"
"Supported image extensions are: " + ",".join(IMG_EXTENSIONS)))
raise RuntimeError(f'Found 0 images in subfolders of {root}. '
f'Supported image extensions are {", ".join(IMG_EXTENSIONS)}')
self.root = root
self.samples = images
self.imgs = self.samples # torchvision ImageFolder compat
@ -124,7 +123,7 @@ def _extract_tar_info(tarfile, class_to_idx=None, sort=True):
unique_labels = set(labels)
sorted_labels = list(sorted(unique_labels, key=natural_key))
class_to_idx = {c: idx for idx, c in enumerate(sorted_labels)}
tarinfo_and_targets = zip(files, [class_to_idx[l] for l in labels])
tarinfo_and_targets = [(f, class_to_idx[l]) for f, l in zip(files, labels) if l in class_to_idx]
if sort:
tarinfo_and_targets = sorted(tarinfo_and_targets, key=lambda k: natural_key(k[0].path))
return tarinfo_and_targets, class_to_idx
@ -141,6 +140,7 @@ class DatasetTar(data.Dataset):
self.root = root
with tarfile.open(root) as tf: # cannot keep this open across processes, reopen later
self.samples, self.class_to_idx = _extract_tar_info(tf, class_to_idx)
self.imgs = self.samples
self.tarfile = None # lazy init in __getitem__
self.load_bytes = load_bytes
self.transform = transform

@ -83,6 +83,8 @@ parser.add_argument('--results-file', default='', type=str, metavar='FILENAME',
help='Output csv file for validation results (summary)')
parser.add_argument('--real-labels', default='', type=str, metavar='FILENAME',
help='Real labels JSON file for imagenet evaluation')
parser.add_argument('--valid-labels', default='', type=str, metavar='FILENAME',
help='Valid label indices txt file for validation of partial label space')
def set_jit_legacy():
@ -141,6 +143,13 @@ def validate(args):
else:
dataset = Dataset(args.data, load_bytes=args.tf_preprocessing, class_map=args.class_map)
if args.valid_labels:
with open(args.valid_labels, 'r') as f:
valid_labels = {int(line.rstrip()) for line in f}
valid_labels = [i in valid_labels for i in range(args.num_classes)]
else:
valid_labels = None
if args.real_labels:
real_labels = RealLabelsImagenet(dataset.filenames(basename=True), real_json=args.real_labels)
else:
@ -180,6 +189,8 @@ def validate(args):
# compute output
output = model(input)
if valid_labels is not None:
output = output[:, valid_labels]
loss = criterion(output, target)
if real_labels is not None:

Loading…
Cancel
Save