diff --git a/README.md b/README.md
index d0f6cd0e..677ecafc 100644
--- a/README.md
+++ b/README.md
@@ -112,7 +112,7 @@ More models, more fixes
 * `cs3`, `darknet`, and `vit_*relpos` weights above all trained on TPU thanks to TRC program! Rest trained on overheating GPUs.
 * Hugging Face Hub support fixes verified, demo notebook TBA
 * Pretrained weights / configs can be loaded externally (ie from local disk) w/ support for head adaptation.
-* Add support to change image extensions scanned by `timm` datasets/parsers. See (https://github.com/rwightman/pytorch-image-models/pull/1274#issuecomment-1178303103)
+* Add support to change image extensions scanned by `timm` datasets/readers. See (https://github.com/rwightman/pytorch-image-models/pull/1274#issuecomment-1178303103)
 * Default ConvNeXt LayerNorm impl to use `F.layer_norm(x.permute(0, 2, 3, 1), ...).permute(0, 3, 1, 2)` via `LayerNorm2d` in all cases.
   * a bit slower than previous custom impl on some hardware (ie Ampere w/ CL), but overall fewer regressions across wider HW / PyTorch version ranges.
   * previous impl exists as `LayerNormExp2d` in `models/layers/norm.py`
diff --git a/timm/data/__init__.py b/timm/data/__init__.py
index 0eb10a66..7cc7b0b0 100644
--- a/timm/data/__init__.py
+++ b/timm/data/__init__.py
@@ -6,8 +6,8 @@ from .dataset import ImageDataset, IterableImageDataset, AugMixDataset
 from .dataset_factory import create_dataset
 from .loader import create_loader
 from .mixup import Mixup, FastCollateMixup
-from .parsers import create_parser,\
-    get_img_extensions, is_img_extension, set_img_extensions, add_img_extensions, del_img_extensions
+from .readers import create_reader
+from .readers import get_img_extensions, is_img_extension, set_img_extensions, add_img_extensions, del_img_extensions
 from .real_labels import RealLabelsImagenet
 from .transforms import *
 from .transforms_factory import create_transform
diff --git a/timm/data/dataset.py b/timm/data/dataset.py
index 93b429e0..e7f67925 100644
--- a/timm/data/dataset.py
+++ b/timm/data/dataset.py
@@ -10,7 +10,7 @@ import torch
 import torch.utils.data as data
 from PIL import Image

-from .parsers import create_parser
+from .readers import create_reader

 _logger = logging.getLogger(__name__)

@@ -23,7 +23,7 @@ class ImageDataset(data.Dataset):
     def __init__(
             self,
             root,
-            parser=None,
+            reader=None,
             split='train',
             class_map=None,
             load_bytes=False,
@@ -31,14 +31,14 @@ class ImageDataset(data.Dataset):
             transform=None,
             target_transform=None,
     ):
-        if parser is None or isinstance(parser, str):
-            parser = create_parser(
-                parser or '',
+        if reader is None or isinstance(reader, str):
+            reader = create_reader(
+                reader or '',
                 root=root,
                 split=split,
                 class_map=class_map
             )
-        self.parser = parser
+        self.reader = reader
         self.load_bytes = load_bytes
         self.img_mode = img_mode
         self.transform = transform
@@ -46,15 +46,15 @@ class ImageDataset(data.Dataset):
         self._consecutive_errors = 0

     def __getitem__(self, index):
-        img, target = self.parser[index]
+        img, target = self.reader[index]
         try:
             img = img.read() if self.load_bytes else Image.open(img)
         except Exception as e:
-            _logger.warning(f'Skipped sample (index {index}, file {self.parser.filename(index)}). {str(e)}')
+            _logger.warning(f'Skipped sample (index {index}, file {self.reader.filename(index)}). {str(e)}')
             self._consecutive_errors += 1
             if self._consecutive_errors < _ERROR_RETRY:
-                return self.__getitem__((index + 1) % len(self.parser))
+                return self.__getitem__((index + 1) % len(self.reader))
             else:
                 raise e
         self._consecutive_errors = 0

@@ -72,13 +72,13 @@ class ImageDataset(data.Dataset):
         return img, target

     def __len__(self):
-        return len(self.parser)
+        return len(self.reader)

     def filename(self, index, basename=False, absolute=False):
-        return self.parser.filename(index, basename, absolute)
+        return self.reader.filename(index, basename, absolute)

     def filenames(self, basename=False, absolute=False):
-        return self.parser.filenames(basename, absolute)
+        return self.reader.filenames(basename, absolute)


 class IterableImageDataset(data.IterableDataset):
@@ -86,7 +86,7 @@ class IterableImageDataset(data.IterableDataset):
     def __init__(
             self,
             root,
-            parser=None,
+            reader=None,
             split='train',
             is_training=False,
             batch_size=None,
@@ -96,10 +96,10 @@ class IterableImageDataset(data.IterableDataset):
             transform=None,
             target_transform=None,
     ):
-        assert parser is not None
-        if isinstance(parser, str):
-            self.parser = create_parser(
-                parser,
+        assert reader is not None
+        if isinstance(reader, str):
+            self.reader = create_reader(
+                reader,
                 root=root,
                 split=split,
                 is_training=is_training,
@@ -109,13 +109,13 @@ class IterableImageDataset(data.IterableDataset):
                 download=download,
             )
         else:
-            self.parser = parser
+            self.reader = reader
         self.transform = transform
         self.target_transform = target_transform
         self._consecutive_errors = 0

     def __iter__(self):
-        for img, target in self.parser:
+        for img, target in self.reader:
             if self.transform is not None:
                 img = self.transform(img)
             if self.target_transform is not None:
@@ -123,29 +123,29 @@ class IterableImageDataset(data.IterableDataset):
             yield img, target

     def __len__(self):
-        if hasattr(self.parser, '__len__'):
-            return len(self.parser)
+        if hasattr(self.reader, '__len__'):
+            return len(self.reader)
         else:
             return 0

     def set_epoch(self, count):
         # TFDS and WDS need external epoch count for deterministic cross process shuffle
-        if hasattr(self.parser, 'set_epoch'):
-            self.parser.set_epoch(count)
+        if hasattr(self.reader, 'set_epoch'):
+            self.reader.set_epoch(count)

     def set_loader_cfg(
             self,
             num_workers: Optional[int] = None,
     ):
         # TFDS and WDS readers need # workers for correct # samples estimate before loader processes created
-        if hasattr(self.parser, 'set_loader_cfg'):
-            self.parser.set_loader_cfg(num_workers=num_workers)
+        if hasattr(self.reader, 'set_loader_cfg'):
+            self.reader.set_loader_cfg(num_workers=num_workers)

     def filename(self, index, basename=False, absolute=False):
         assert False, 'Filename lookup by index not supported, use filenames().'

     def filenames(self, basename=False, absolute=False):
-        return self.parser.filenames(basename, absolute)
+        return self.reader.filenames(basename, absolute)


 class AugMixDataset(torch.utils.data.Dataset):
diff --git a/timm/data/dataset_factory.py b/timm/data/dataset_factory.py
index 2c2bb0bf..3777a5aa 100644
--- a/timm/data/dataset_factory.py
+++ b/timm/data/dataset_factory.py
@@ -137,11 +137,11 @@ def create_dataset(
     elif name.startswith('hfds/'):
         # NOTE right now, HF datasets default arrow format is a random-access Dataset,
         # There will be a IterableDataset variant too, TBD
-        ds = ImageDataset(root, parser=name, split=split, **kwargs)
+        ds = ImageDataset(root, reader=name, split=split, **kwargs)
     elif name.startswith('tfds/'):
         ds = IterableImageDataset(
             root,
-            parser=name,
+            reader=name,
             split=split,
             is_training=is_training,
             download=download,
@@ -153,7 +153,7 @@ def create_dataset(
     elif name.startswith('wds/'):
         ds = IterableImageDataset(
             root,
-            parser=name,
+            reader=name,
             split=split,
             is_training=is_training,
             batch_size=batch_size,
@@ -166,5 +166,5 @@ def create_dataset(
         if search_split and os.path.isdir(root):
             # look for split specific sub-folder in root
             root = _search_split(root, split)
-        ds = ImageDataset(root, parser=name, class_map=class_map, load_bytes=load_bytes, **kwargs)
+        ds = ImageDataset(root, reader=name, class_map=class_map, load_bytes=load_bytes, **kwargs)
     return ds
diff --git a/timm/data/parsers/__init__.py b/timm/data/parsers/__init__.py
deleted file mode 100644
index 4e820d5e..00000000
--- a/timm/data/parsers/__init__.py
+++ /dev/null
@@ -1,2 +0,0 @@
-from .parser_factory import create_parser
-from .img_extensions import *
diff --git a/timm/data/parsers/parser_factory.py b/timm/data/parsers/parser_factory.py
deleted file mode 100644
index f5133433..00000000
--- a/timm/data/parsers/parser_factory.py
+++ /dev/null
@@ -1,35 +0,0 @@
-import os
-
-from .parser_image_folder import ParserImageFolder
-from .parser_image_in_tar import ParserImageInTar
-
-
-def create_parser(name, root, split='train', **kwargs):
-    name = name.lower()
-    name = name.split('/', 2)
-    prefix = ''
-    if len(name) > 1:
-        prefix = name[0]
-    name = name[-1]
-
-    # FIXME improve the selection right now just tfds prefix or fallback path, will need options to
-    #  explicitly select other options shortly
-    if prefix == 'hfds':
-        from .parser_hfds import ParserHfds  # defer tensorflow import
-        parser = ParserHfds(root, name, split=split, **kwargs)
-    elif prefix == 'tfds':
-        from .parser_tfds import ParserTfds  # defer tensorflow import
-        parser = ParserTfds(root, name, split=split, **kwargs)
-    elif prefix == 'wds':
-        from .parser_wds import ParserWds
-        kwargs.pop('download', False)
-        parser = ParserWds(root, name, split=split, **kwargs)
-    else:
-        assert os.path.exists(root)
-        # default fallback path (backwards compat), use image tar if root is a .tar file, otherwise image folder
-        # FIXME support split here, in parser?
-        if os.path.isfile(root) and os.path.splitext(root)[1] == '.tar':
-            parser = ParserImageInTar(root, **kwargs)
-        else:
-            parser = ParserImageFolder(root, **kwargs)
-    return parser
diff --git a/timm/data/readers/__init__.py b/timm/data/readers/__init__.py
new file mode 100644
index 00000000..63e4e649
--- /dev/null
+++ b/timm/data/readers/__init__.py
@@ -0,0 +1,2 @@
+from .reader_factory import create_reader
+from .img_extensions import *
diff --git a/timm/data/parsers/class_map.py b/timm/data/readers/class_map.py
similarity index 100%
rename from timm/data/parsers/class_map.py
rename to timm/data/readers/class_map.py
diff --git a/timm/data/parsers/img_extensions.py b/timm/data/readers/img_extensions.py
similarity index 100%
rename from timm/data/parsers/img_extensions.py
rename to timm/data/readers/img_extensions.py
diff --git a/timm/data/parsers/parser.py b/timm/data/readers/reader.py
similarity index 97%
rename from timm/data/parsers/parser.py
rename to timm/data/readers/reader.py
index 76ab6d18..fe55d313 100644
--- a/timm/data/parsers/parser.py
+++ b/timm/data/readers/reader.py
@@ -1,7 +1,7 @@
 from abc import abstractmethod


-class Parser:
+class Reader:
     def __init__(self):
         pass

diff --git a/timm/data/readers/reader_factory.py b/timm/data/readers/reader_factory.py
new file mode 100644
index 00000000..58ff56cd
--- /dev/null
+++ b/timm/data/readers/reader_factory.py
@@ -0,0 +1,35 @@
+import os
+
+from .reader_image_folder import ReaderImageFolder
+from .reader_image_in_tar import ReaderImageInTar
+
+
+def create_reader(name, root, split='train', **kwargs):
+    name = name.lower()
+    name = name.split('/', 2)
+    prefix = ''
+    if len(name) > 1:
+        prefix = name[0]
+    name = name[-1]
+
+    # FIXME improve the selection right now just tfds prefix or fallback path, will need options to
+    #  explicitly select other options shortly
+    if prefix == 'hfds':
+        from .reader_hfds import ReaderHfds  # defer Hugging Face datasets import
+        reader = ReaderHfds(root, name, split=split, **kwargs)
+    elif prefix == 'tfds':
+        from .reader_tfds import ReaderTfds  # defer tensorflow import
+        reader = ReaderTfds(root, name, split=split, **kwargs)
+    elif prefix == 'wds':
+        from .reader_wds import ReaderWds
+        kwargs.pop('download', False)
+        reader = ReaderWds(root, name, split=split, **kwargs)
+    else:
+        assert os.path.exists(root)
+        # default fallback path (backwards compat), use image tar if root is a .tar file, otherwise image folder
+        # FIXME support split here or in reader?
+        if os.path.isfile(root) and os.path.splitext(root)[1] == '.tar':
+            reader = ReaderImageInTar(root, **kwargs)
+        else:
+            reader = ReaderImageFolder(root, **kwargs)
+    return reader
diff --git a/timm/data/parsers/parser_hfds.py b/timm/data/readers/reader_hfds.py
similarity index 94%
rename from timm/data/parsers/parser_hfds.py
rename to timm/data/readers/reader_hfds.py
index a558aaf3..901cf4bc 100644
--- a/timm/data/parsers/parser_hfds.py
+++ b/timm/data/readers/reader_hfds.py
@@ -1,4 +1,5 @@
-""" Dataset parser interface that wraps Hugging Face datasets
+""" Dataset reader that wraps Hugging Face datasets
+
 Hacked together by / Copyright 2022 Ross Wightman
 """
 import io
@@ -12,7 +13,7 @@ try:
 except ImportError as e:
     print("Please install Hugging Face datasets package `pip install datasets`.")
     exit(1)
-from .parser import Parser
+from .reader import Reader


 def get_class_labels(info):
@@ -23,7 +24,7 @@ def get_class_labels(info):
     return class_to_idx


-class ParserHfds(Parser):
+class ReaderHfds(Reader):

     def __init__(
             self,
diff --git a/timm/data/parsers/parser_image_folder.py b/timm/data/readers/reader_image_folder.py
similarity index 94%
rename from timm/data/parsers/parser_image_folder.py
rename to timm/data/readers/reader_image_folder.py
index 3d22a17b..05823372 100644
--- a/timm/data/parsers/parser_image_folder.py
+++ b/timm/data/readers/reader_image_folder.py
@@ -1,6 +1,6 @@
-""" A dataset parser that reads images from folders
+""" A dataset reader that extracts images from folders

-Folders are scannerd recursively to find image files. Labels are based
+Folders are scanned recursively to find image files. Labels are based
 on the folder hierarchy, just leaf folders by default.

 Hacked together by / Copyright 2020 Ross Wightman
@@ -12,7 +12,7 @@ from timm.utils.misc import natural_key

 from .class_map import load_class_map
 from .img_extensions import get_img_extensions
-from .parser import Parser
+from .reader import Reader


 def find_images_and_targets(
@@ -56,7 +56,7 @@ def find_images_and_targets(
     return images_and_targets, class_to_idx


-class ParserImageFolder(Parser):
+class ReaderImageFolder(Reader):

     def __init__(
             self,
diff --git a/timm/data/parsers/parser_image_in_tar.py b/timm/data/readers/reader_image_in_tar.py
similarity index 97%
rename from timm/data/parsers/parser_image_in_tar.py
rename to timm/data/readers/reader_image_in_tar.py
index 4fcad797..001c9f4e 100644
--- a/timm/data/parsers/parser_image_in_tar.py
+++ b/timm/data/readers/reader_image_in_tar.py
@@ -1,6 +1,6 @@
-""" A dataset parser that reads tarfile based datasets
+""" A dataset reader that reads tarfile based datasets

-This parser can read and extract image samples from:
+This reader can extract image samples from:
 * a single tar of image files
 * a folder of multiple tarfiles containing imagefiles
 * a tar of tars containing image files
@@ -22,7 +22,7 @@ from timm.utils.misc import natural_key

 from .class_map import load_class_map
 from .img_extensions import get_img_extensions
-from .parser import Parser
+from .reader import Reader

 _logger = logging.getLogger(__name__)
 CACHE_FILENAME_SUFFIX = '_tarinfos.pickle'
@@ -169,8 +169,8 @@ def extract_tarinfos(
     return samples, targets, class_name_to_idx, tarfiles


-class ParserImageInTar(Parser):
-    """ Multi-tarfile dataset parser where there is one .tar file per class
+class ReaderImageInTar(Reader):
+    """ Multi-tarfile dataset reader where there is one .tar file per class
     """

     def __init__(self, root, class_map='', cache_tarfiles=True, cache_tarinfo=None):
diff --git a/timm/data/parsers/parser_image_tar.py b/timm/data/readers/reader_image_tar.py
similarity index 91%
rename from timm/data/parsers/parser_image_tar.py
rename to timm/data/readers/reader_image_tar.py
index c2ed429d..6051f26d 100644
--- a/timm/data/parsers/parser_image_tar.py
+++ b/timm/data/readers/reader_image_tar.py
@@ -1,6 +1,6 @@
-""" A dataset parser that reads single tarfile based datasets
+""" A dataset reader that reads single tarfile based datasets

-This parser can read datasets consisting if a single tarfile containing images.
-I am planning to deprecated it in favour of ParerImageInTar.
+This reader can read datasets consisting of a single tarfile containing images.
+I am planning to deprecate it in favour of ReaderImageInTar.

 Hacked together by / Copyright 2020 Ross Wightman
@@ -12,7 +12,7 @@ from timm.utils.misc import natural_key

 from .class_map import load_class_map
 from .img_extensions import get_img_extensions
-from .parser import Parser
+from .reader import Reader


 def extract_tarinfo(tarfile, class_to_idx=None, sort=True):
@@ -38,9 +38,9 @@ def extract_tarinfo(tarfile, class_to_idx=None, sort=True):
     return tarinfo_and_targets, class_to_idx


-class ParserImageTar(Parser):
+class ReaderImageTar(Reader):
     """ Single tarfile dataset where classes are mapped to folders within tar
-    NOTE: This class is being deprecated in favour of the more capable ParserImageInTar that can
+    NOTE: This class is being deprecated in favour of the more capable ReaderImageInTar that can
     operate on folders of tars or tars in tars.
     """
     def __init__(self, root, class_map=''):
diff --git a/timm/data/parsers/parser_tfds.py b/timm/data/readers/reader_tfds.py
similarity index 99%
rename from timm/data/parsers/parser_tfds.py
rename to timm/data/readers/reader_tfds.py
index f55f012f..7ccbf908 100644
--- a/timm/data/parsers/parser_tfds.py
+++ b/timm/data/readers/reader_tfds.py
@@ -1,4 +1,4 @@
-""" Dataset parser interface that wraps TFDS datasets
+""" Dataset reader that wraps TFDS datasets

 Wraps many (most?) TFDS image-classification datasets
 from https://github.com/tensorflow/datasets
@@ -34,7 +34,7 @@
 except ImportError as e:
     print("Please install tensorflow_datasets package `pip install tensorflow-datasets`.")
     exit(1)
-from .parser import Parser
+from .reader import Reader
 from .shared_count import SharedCount

@@ -56,7 +56,7 @@ def get_class_labels(info):
     return class_to_idx


-class ParserTfds(Parser):
+class ReaderTfds(Reader):
     """ Wrap Tensorflow Datasets for use in PyTorch

-    There several things to be aware of:
+    There are several things to be aware of:
diff --git a/timm/data/parsers/parser_wds.py b/timm/data/readers/reader_wds.py
similarity index 99%
rename from timm/data/parsers/parser_wds.py
rename to timm/data/readers/reader_wds.py
index 21bfc3e2..c0121dae 100644
--- a/timm/data/parsers/parser_wds.py
+++ b/timm/data/readers/reader_wds.py
@@ -1,4 +1,4 @@
-""" Dataset parser interface for webdataset
+""" Dataset reader for webdataset

 Hacked together by / Copyright 2022 Ross Wightman
 """
@@ -29,7 +29,7 @@ except ImportError:
     wds = None
     expand_urls = None

-from .parser import Parser
+from .reader import Reader
 from .shared_count import SharedCount

 _logger = logging.getLogger(__name__)
@@ -280,7 +280,7 @@ class ResampledShards2(IterableDataset):
             yield dict(url=self.urls[index])


-class ParserWds(Parser):
+class ReaderWds(Reader):
     def __init__(
             self,
             root,
diff --git a/timm/data/parsers/shared_count.py b/timm/data/readers/shared_count.py
similarity index 100%
rename from timm/data/parsers/shared_count.py
rename to timm/data/readers/shared_count.py
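
Usage sketch of the renamed API (illustrative, not part of the patch): the entry points behave as before, with `reader`/`create_reader` replacing `parser`/`create_parser`. The dataset path and the `.webp` registration below are placeholder assumptions; the imports, signatures, and fallback behavior come from the diff above.

```python
# Sketch of the post-rename timm data API; '/data/imagenet' and '.webp' are placeholders.
from timm.data import ImageDataset, add_img_extensions, get_img_extensions
from timm.data.readers import create_reader

# Build a reader explicitly. Without an 'hfds/', 'tfds/', or 'wds/' prefix,
# create_reader falls back to ReaderImageInTar for a .tar root, else ReaderImageFolder.
reader = create_reader('', root='/data/imagenet', split='train')
ds = ImageDataset('/data/imagenet', reader=reader)

# Or pass a string (or None) and let ImageDataset call create_reader itself,
# exactly as it previously did with the `parser` argument.
ds = ImageDataset('/data/imagenet', reader='', split='train')

# The extension set scanned by the folder/tar readers is now adjustable
# (see the README note above).
add_img_extensions('.webp')
print(get_img_extensions())
```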