diff --git a/timm/data/parsers/class_map.py b/timm/data/parsers/class_map.py index 6b6fe453..6cf3f57e 100644 --- a/timm/data/parsers/class_map.py +++ b/timm/data/parsers/class_map.py @@ -1,5 +1,5 @@ import os - +import pickle def load_class_map(map_or_filename, root=''): if isinstance(map_or_filename, dict): @@ -13,6 +13,9 @@ def load_class_map(map_or_filename, root=''): if class_map_ext == '.txt': with open(class_map_path) as f: class_to_idx = {v.strip(): k for k, v in enumerate(f)} + elif class_map_ext == '.pkl': + with open(class_map_path,'rb') as f: + class_to_idx = pickle.load(f) else: assert False, f'Unsupported class map file extension ({class_map_ext}).' return class_to_idx