diff --git a/timm/data/__init__.py b/timm/data/__init__.py index 7d3cb2b4..0eb10a66 100644 --- a/timm/data/__init__.py +++ b/timm/data/__init__.py @@ -6,7 +6,8 @@ from .dataset import ImageDataset, IterableImageDataset, AugMixDataset from .dataset_factory import create_dataset from .loader import create_loader from .mixup import Mixup, FastCollateMixup -from .parsers import create_parser +from .parsers import create_parser,\ + get_img_extensions, is_img_extension, set_img_extensions, add_img_extensions, del_img_extensions from .real_labels import RealLabelsImagenet from .transforms import * -from .transforms_factory import create_transform \ No newline at end of file +from .transforms_factory import create_transform diff --git a/timm/data/parsers/__init__.py b/timm/data/parsers/__init__.py index eeb44e37..4e820d5e 100644 --- a/timm/data/parsers/__init__.py +++ b/timm/data/parsers/__init__.py @@ -1 +1,2 @@ from .parser_factory import create_parser +from .img_extensions import * diff --git a/timm/data/parsers/constants.py b/timm/data/parsers/constants.py deleted file mode 100644 index e7ba484e..00000000 --- a/timm/data/parsers/constants.py +++ /dev/null @@ -1 +0,0 @@ -IMG_EXTENSIONS = ('.png', '.jpg', '.jpeg') diff --git a/timm/data/parsers/img_extensions.py b/timm/data/parsers/img_extensions.py new file mode 100644 index 00000000..45c85aab --- /dev/null +++ b/timm/data/parsers/img_extensions.py @@ -0,0 +1,50 @@ +from copy import deepcopy + +__all__ = ['get_img_extensions', 'is_img_extension', 'set_img_extensions', 'add_img_extensions', 'del_img_extensions'] + + +IMG_EXTENSIONS = ('.png', '.jpg', '.jpeg') # singleton, kept public for bwd compat use +_IMG_EXTENSIONS_SET = set(IMG_EXTENSIONS) # set version, private, kept in sync + + +def _set_extensions(extensions): + global IMG_EXTENSIONS + global _IMG_EXTENSIONS_SET + dedupe = set() # NOTE de-duping tuple while keeping original order + IMG_EXTENSIONS = tuple(x for x in extensions if x not in dedupe and not dedupe.add(x)) + _IMG_EXTENSIONS_SET = set(extensions) + + +def _valid_extension(x: str): + return x and isinstance(x, str) and len(x) >= 2 and x.startswith('.') + + +def is_img_extension(ext): + return ext in _IMG_EXTENSIONS_SET + + +def get_img_extensions(as_set=False): + return deepcopy(_IMG_EXTENSIONS_SET if as_set else IMG_EXTENSIONS) + + +def set_img_extensions(extensions): + assert len(extensions) + for x in extensions: + assert _valid_extension(x) + _set_extensions(extensions) + + +def add_img_extensions(ext): + if not isinstance(ext, (list, tuple, set)): + ext = (ext,) + for x in ext: + assert _valid_extension(x) + extensions = IMG_EXTENSIONS + tuple(ext) + _set_extensions(extensions) + + +def del_img_extensions(ext): + if not isinstance(ext, (list, tuple, set)): + ext = (ext,) + extensions = tuple(x for x in IMG_EXTENSIONS if x not in ext) + _set_extensions(extensions) diff --git a/timm/data/parsers/parser_factory.py b/timm/data/parsers/parser_factory.py index 892090ad..0665c02a 100644 --- a/timm/data/parsers/parser_factory.py +++ b/timm/data/parsers/parser_factory.py @@ -1,7 +1,6 @@ import os from .parser_image_folder import ParserImageFolder -from .parser_image_tar import ParserImageTar from .parser_image_in_tar import ParserImageInTar diff --git a/timm/data/parsers/parser_image_folder.py b/timm/data/parsers/parser_image_folder.py index ed349009..3d22a17b 100644 --- a/timm/data/parsers/parser_image_folder.py +++ b/timm/data/parsers/parser_image_folder.py @@ -6,15 +6,35 @@ on the folder hierarchy, just leaf folders by default. Hacked together by / Copyright 2020 Ross Wightman """ import os +from typing import Dict, List, Optional, Set, Tuple, Union from timm.utils.misc import natural_key -from .parser import Parser from .class_map import load_class_map -from .constants import IMG_EXTENSIONS +from .img_extensions import get_img_extensions +from .parser import Parser + + +def find_images_and_targets( + folder: str, + types: Optional[Union[List, Tuple, Set]] = None, + class_to_idx: Optional[Dict] = None, + leaf_name_only: bool = True, + sort: bool = True +): + """ Walk folder recursively to discover images and map them to classes by folder names. + Args: + folder: root of folder to recrusively search + types: types (file extensions) to search for in path + class_to_idx: specify mapping for class (folder name) to class index if set + leaf_name_only: use only leaf-name of folder walk for class names + sort: re-sort found images by name (for consistent ordering) -def find_images_and_targets(folder, types=IMG_EXTENSIONS, class_to_idx=None, leaf_name_only=True, sort=True): + Returns: + A list of image and target tuples, class_to_idx mapping + """ + types = get_img_extensions(as_set=True) if not types else set(types) labels = [] filenames = [] for root, subdirs, files in os.walk(folder, topdown=False, followlinks=True): @@ -51,7 +71,8 @@ class ParserImageFolder(Parser): self.samples, self.class_to_idx = find_images_and_targets(root, class_to_idx=class_to_idx) if len(self.samples) == 0: raise RuntimeError( - f'Found 0 images in subfolders of {root}. Supported image extensions are {", ".join(IMG_EXTENSIONS)}') + f'Found 0 images in subfolders of {root}. ' + f'Supported image extensions are {", ".join(get_img_extensions())}') def __getitem__(self, index): path, target = self.samples[index] diff --git a/timm/data/parsers/parser_image_in_tar.py b/timm/data/parsers/parser_image_in_tar.py index c6ada962..4fcad797 100644 --- a/timm/data/parsers/parser_image_in_tar.py +++ b/timm/data/parsers/parser_image_in_tar.py @@ -9,20 +9,20 @@ Labels are based on the combined folder and/or tar name structure. Hacked together by / Copyright 2020 Ross Wightman """ +import logging import os -import tarfile import pickle -import logging -import numpy as np +import tarfile from glob import glob -from typing import List, Dict +from typing import List, Tuple, Dict, Set, Optional, Union + +import numpy as np from timm.utils.misc import natural_key -from .parser import Parser from .class_map import load_class_map -from .constants import IMG_EXTENSIONS - +from .img_extensions import get_img_extensions +from .parser import Parser _logger = logging.getLogger(__name__) CACHE_FILENAME_SUFFIX = '_tarinfos.pickle' @@ -39,7 +39,7 @@ class TarState: self.tf = None -def _extract_tarinfo(tf: tarfile.TarFile, parent_info: Dict, extensions=IMG_EXTENSIONS): +def _extract_tarinfo(tf: tarfile.TarFile, parent_info: Dict, extensions: Set[str]): sample_count = 0 for i, ti in enumerate(tf): if not ti.isfile(): @@ -60,7 +60,14 @@ def _extract_tarinfo(tf: tarfile.TarFile, parent_info: Dict, extensions=IMG_EXTE return sample_count -def extract_tarinfos(root, class_name_to_idx=None, cache_tarinfo=None, extensions=IMG_EXTENSIONS, sort=True): +def extract_tarinfos( + root, + class_name_to_idx: Optional[Dict] = None, + cache_tarinfo: Optional[bool] = None, + extensions: Optional[Union[List, Tuple, Set]] = None, + sort: bool = True +): + extensions = get_img_extensions(as_set=True) if not extensions else set(extensions) root_is_tar = False if os.path.isfile(root): assert os.path.splitext(root)[-1].lower() == '.tar' @@ -176,8 +183,8 @@ class ParserImageInTar(Parser): self.samples, self.targets, self.class_name_to_idx, tarfiles = extract_tarinfos( self.root, class_name_to_idx=class_name_to_idx, - cache_tarinfo=cache_tarinfo, - extensions=IMG_EXTENSIONS) + cache_tarinfo=cache_tarinfo + ) self.class_idx_to_name = {v: k for k, v in self.class_name_to_idx.items()} if len(tarfiles) == 1 and tarfiles[0][0] is None: self.root_is_tar = True diff --git a/timm/data/parsers/parser_image_tar.py b/timm/data/parsers/parser_image_tar.py index 467537f4..c2ed429d 100644 --- a/timm/data/parsers/parser_image_tar.py +++ b/timm/data/parsers/parser_image_tar.py @@ -8,13 +8,15 @@ Hacked together by / Copyright 2020 Ross Wightman import os import tarfile -from .parser import Parser -from .class_map import load_class_map -from .constants import IMG_EXTENSIONS from timm.utils.misc import natural_key +from .class_map import load_class_map +from .img_extensions import get_img_extensions +from .parser import Parser + def extract_tarinfo(tarfile, class_to_idx=None, sort=True): + extensions = get_img_extensions(as_set=True) files = [] labels = [] for ti in tarfile.getmembers(): @@ -23,7 +25,7 @@ def extract_tarinfo(tarfile, class_to_idx=None, sort=True): dirname, basename = os.path.split(ti.path) label = os.path.basename(dirname) ext = os.path.splitext(basename)[1] - if ext.lower() in IMG_EXTENSIONS: + if ext.lower() in extensions: files.append(ti) labels.append(label) if class_to_idx is None: