Improve image extension handling, add methods to modify / get defaults. Fix #1335 fix #1274.

pull/1327/head
Ross Wightman 2 years ago
parent 7d4b3807d5
commit bfc0dccb0e

@ -6,7 +6,8 @@ from .dataset import ImageDataset, IterableImageDataset, AugMixDataset
from .dataset_factory import create_dataset from .dataset_factory import create_dataset
from .loader import create_loader from .loader import create_loader
from .mixup import Mixup, FastCollateMixup from .mixup import Mixup, FastCollateMixup
from .parsers import create_parser from .parsers import create_parser,\
get_img_extensions, is_img_extension, set_img_extensions, add_img_extensions, del_img_extensions
from .real_labels import RealLabelsImagenet from .real_labels import RealLabelsImagenet
from .transforms import * from .transforms import *
from .transforms_factory import create_transform from .transforms_factory import create_transform

@ -1 +1,2 @@
from .parser_factory import create_parser from .parser_factory import create_parser
from .img_extensions import *

@ -1 +0,0 @@
IMG_EXTENSIONS = ('.png', '.jpg', '.jpeg')

@ -0,0 +1,50 @@
from copy import deepcopy
__all__ = ['get_img_extensions', 'is_img_extension', 'set_img_extensions', 'add_img_extensions', 'del_img_extensions']
IMG_EXTENSIONS = ('.png', '.jpg', '.jpeg') # singleton, kept public for bwd compat use
_IMG_EXTENSIONS_SET = set(IMG_EXTENSIONS) # set version, private, kept in sync
def _set_extensions(extensions):
global IMG_EXTENSIONS
global _IMG_EXTENSIONS_SET
dedupe = set() # NOTE de-duping tuple while keeping original order
IMG_EXTENSIONS = tuple(x for x in extensions if x not in dedupe and not dedupe.add(x))
_IMG_EXTENSIONS_SET = set(extensions)
def _valid_extension(x: str):
return x and isinstance(x, str) and len(x) >= 2 and x.startswith('.')
def is_img_extension(ext):
return ext in _IMG_EXTENSIONS_SET
def get_img_extensions(as_set=False):
return deepcopy(_IMG_EXTENSIONS_SET if as_set else IMG_EXTENSIONS)
def set_img_extensions(extensions):
assert len(extensions)
for x in extensions:
assert _valid_extension(x)
_set_extensions(extensions)
def add_img_extensions(ext):
if not isinstance(ext, (list, tuple, set)):
ext = (ext,)
for x in ext:
assert _valid_extension(x)
extensions = IMG_EXTENSIONS + tuple(ext)
_set_extensions(extensions)
def del_img_extensions(ext):
if not isinstance(ext, (list, tuple, set)):
ext = (ext,)
extensions = tuple(x for x in IMG_EXTENSIONS if x not in ext)
_set_extensions(extensions)

@ -1,7 +1,6 @@
import os import os
from .parser_image_folder import ParserImageFolder from .parser_image_folder import ParserImageFolder
from .parser_image_tar import ParserImageTar
from .parser_image_in_tar import ParserImageInTar from .parser_image_in_tar import ParserImageInTar

@ -6,15 +6,35 @@ on the folder hierarchy, just leaf folders by default.
Hacked together by / Copyright 2020 Ross Wightman Hacked together by / Copyright 2020 Ross Wightman
""" """
import os import os
from typing import Dict, List, Optional, Set, Tuple, Union
from timm.utils.misc import natural_key from timm.utils.misc import natural_key
from .parser import Parser
from .class_map import load_class_map from .class_map import load_class_map
from .constants import IMG_EXTENSIONS from .img_extensions import get_img_extensions
from .parser import Parser
def find_images_and_targets(
folder: str,
types: Optional[Union[List, Tuple, Set]] = None,
class_to_idx: Optional[Dict] = None,
leaf_name_only: bool = True,
sort: bool = True
):
""" Walk folder recursively to discover images and map them to classes by folder names.
Args:
folder: root of folder to recrusively search
types: types (file extensions) to search for in path
class_to_idx: specify mapping for class (folder name) to class index if set
leaf_name_only: use only leaf-name of folder walk for class names
sort: re-sort found images by name (for consistent ordering)
def find_images_and_targets(folder, types=IMG_EXTENSIONS, class_to_idx=None, leaf_name_only=True, sort=True): Returns:
A list of image and target tuples, class_to_idx mapping
"""
types = get_img_extensions(as_set=True) if not types else set(types)
labels = [] labels = []
filenames = [] filenames = []
for root, subdirs, files in os.walk(folder, topdown=False, followlinks=True): for root, subdirs, files in os.walk(folder, topdown=False, followlinks=True):
@ -51,7 +71,8 @@ class ParserImageFolder(Parser):
self.samples, self.class_to_idx = find_images_and_targets(root, class_to_idx=class_to_idx) self.samples, self.class_to_idx = find_images_and_targets(root, class_to_idx=class_to_idx)
if len(self.samples) == 0: if len(self.samples) == 0:
raise RuntimeError( raise RuntimeError(
f'Found 0 images in subfolders of {root}. Supported image extensions are {", ".join(IMG_EXTENSIONS)}') f'Found 0 images in subfolders of {root}. '
f'Supported image extensions are {", ".join(get_img_extensions())}')
def __getitem__(self, index): def __getitem__(self, index):
path, target = self.samples[index] path, target = self.samples[index]

@ -9,20 +9,20 @@ Labels are based on the combined folder and/or tar name structure.
Hacked together by / Copyright 2020 Ross Wightman Hacked together by / Copyright 2020 Ross Wightman
""" """
import logging
import os import os
import tarfile
import pickle import pickle
import logging import tarfile
import numpy as np
from glob import glob from glob import glob
from typing import List, Dict from typing import List, Tuple, Dict, Set, Optional, Union
import numpy as np
from timm.utils.misc import natural_key from timm.utils.misc import natural_key
from .parser import Parser
from .class_map import load_class_map from .class_map import load_class_map
from .constants import IMG_EXTENSIONS from .img_extensions import get_img_extensions
from .parser import Parser
_logger = logging.getLogger(__name__) _logger = logging.getLogger(__name__)
CACHE_FILENAME_SUFFIX = '_tarinfos.pickle' CACHE_FILENAME_SUFFIX = '_tarinfos.pickle'
@ -39,7 +39,7 @@ class TarState:
self.tf = None self.tf = None
def _extract_tarinfo(tf: tarfile.TarFile, parent_info: Dict, extensions=IMG_EXTENSIONS): def _extract_tarinfo(tf: tarfile.TarFile, parent_info: Dict, extensions: Set[str]):
sample_count = 0 sample_count = 0
for i, ti in enumerate(tf): for i, ti in enumerate(tf):
if not ti.isfile(): if not ti.isfile():
@ -60,7 +60,14 @@ def _extract_tarinfo(tf: tarfile.TarFile, parent_info: Dict, extensions=IMG_EXTE
return sample_count return sample_count
def extract_tarinfos(root, class_name_to_idx=None, cache_tarinfo=None, extensions=IMG_EXTENSIONS, sort=True): def extract_tarinfos(
root,
class_name_to_idx: Optional[Dict] = None,
cache_tarinfo: Optional[bool] = None,
extensions: Optional[Union[List, Tuple, Set]] = None,
sort: bool = True
):
extensions = get_img_extensions(as_set=True) if not extensions else set(extensions)
root_is_tar = False root_is_tar = False
if os.path.isfile(root): if os.path.isfile(root):
assert os.path.splitext(root)[-1].lower() == '.tar' assert os.path.splitext(root)[-1].lower() == '.tar'
@ -176,8 +183,8 @@ class ParserImageInTar(Parser):
self.samples, self.targets, self.class_name_to_idx, tarfiles = extract_tarinfos( self.samples, self.targets, self.class_name_to_idx, tarfiles = extract_tarinfos(
self.root, self.root,
class_name_to_idx=class_name_to_idx, class_name_to_idx=class_name_to_idx,
cache_tarinfo=cache_tarinfo, cache_tarinfo=cache_tarinfo
extensions=IMG_EXTENSIONS) )
self.class_idx_to_name = {v: k for k, v in self.class_name_to_idx.items()} self.class_idx_to_name = {v: k for k, v in self.class_name_to_idx.items()}
if len(tarfiles) == 1 and tarfiles[0][0] is None: if len(tarfiles) == 1 and tarfiles[0][0] is None:
self.root_is_tar = True self.root_is_tar = True

@ -8,13 +8,15 @@ Hacked together by / Copyright 2020 Ross Wightman
import os import os
import tarfile import tarfile
from .parser import Parser
from .class_map import load_class_map
from .constants import IMG_EXTENSIONS
from timm.utils.misc import natural_key from timm.utils.misc import natural_key
from .class_map import load_class_map
from .img_extensions import get_img_extensions
from .parser import Parser
def extract_tarinfo(tarfile, class_to_idx=None, sort=True): def extract_tarinfo(tarfile, class_to_idx=None, sort=True):
extensions = get_img_extensions(as_set=True)
files = [] files = []
labels = [] labels = []
for ti in tarfile.getmembers(): for ti in tarfile.getmembers():
@ -23,7 +25,7 @@ def extract_tarinfo(tarfile, class_to_idx=None, sort=True):
dirname, basename = os.path.split(ti.path) dirname, basename = os.path.split(ti.path)
label = os.path.basename(dirname) label = os.path.basename(dirname)
ext = os.path.splitext(basename)[1] ext = os.path.splitext(basename)[1]
if ext.lower() in IMG_EXTENSIONS: if ext.lower() in extensions:
files.append(ti) files.append(ti)
labels.append(label) labels.append(label)
if class_to_idx is None: if class_to_idx is None:

Loading…
Cancel
Save