Rename dataset/parsers -> dataset/readers, create_parser to create_reader, etc

pull/1479/head
Ross Wightman 2 years ago
parent f67a7ee8bd
commit e9dccc918c
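The user-facing effect of the rename, as a minimal before/after sketch (the root path and split are illustrative values, not taken from this commit):

    # before this commit
    from timm.data import create_parser, ImageDataset
    ds = ImageDataset('/data/imagenet', parser='', split='validation')

    # after this commit
    from timm.data import create_reader, ImageDataset
    ds = ImageDataset('/data/imagenet', reader='', split='validation')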

@@ -112,7 +112,7 @@ More models, more fixes
 * `cs3`, `darknet`, and `vit_*relpos` weights above all trained on TPU thanks to TRC program! Rest trained on overheating GPUs.
 * Hugging Face Hub support fixes verified, demo notebook TBA
 * Pretrained weights / configs can be loaded externally (ie from local disk) w/ support for head adaptation.
-* Add support to change image extensions scanned by `timm` datasets/parsers. See (https://github.com/rwightman/pytorch-image-models/pull/1274#issuecomment-1178303103)
+* Add support to change image extensions scanned by `timm` datasets/readers. See (https://github.com/rwightman/pytorch-image-models/pull/1274#issuecomment-1178303103)
 * Default ConvNeXt LayerNorm impl to use `F.layer_norm(x.permute(0, 2, 3, 1), ...).permute(0, 3, 1, 2)` via `LayerNorm2d` in all cases.
   * a bit slower than previous custom impl on some hardware (ie Ampere w/ CL), but overall fewer regressions across wider HW / PyTorch version ranges.
   * previous impl exists as `LayerNormExp2d` in `models/layers/norm.py`
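The image-extension bullet above refers to the helpers exported from `timm.data` (see the `__init__.py` hunk below). A minimal usage sketch; the default extension set and the exact accepted argument forms are assumptions, not shown in this diff:

    from timm.data import get_img_extensions, add_img_extensions, del_img_extensions

    print(get_img_extensions())  # assumed default: ('.png', '.jpg', '.jpeg')
    add_img_extensions('.webp')  # also scan .webp when building folder datasets
    del_img_extensions('.png')   # stop scanning .png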

@@ -6,8 +6,8 @@ from .dataset import ImageDataset, IterableImageDataset, AugMixDataset
 from .dataset_factory import create_dataset
 from .loader import create_loader
 from .mixup import Mixup, FastCollateMixup
-from .parsers import create_parser,\
-    get_img_extensions, is_img_extension, set_img_extensions, add_img_extensions, del_img_extensions
+from .readers import create_reader
+from .readers import get_img_extensions, is_img_extension, set_img_extensions, add_img_extensions, del_img_extensions
 from .real_labels import RealLabelsImagenet
 from .transforms import *
 from .transforms_factory import create_transform

@@ -10,7 +10,7 @@ import torch
 import torch.utils.data as data
 from PIL import Image

-from .parsers import create_parser
+from .readers import create_reader

 _logger = logging.getLogger(__name__)
@@ -23,7 +23,7 @@ class ImageDataset(data.Dataset):
     def __init__(
             self,
             root,
-            parser=None,
+            reader=None,
             split='train',
             class_map=None,
             load_bytes=False,
@@ -31,14 +31,14 @@ class ImageDataset(data.Dataset):
             transform=None,
             target_transform=None,
     ):
-        if parser is None or isinstance(parser, str):
-            parser = create_parser(
-                parser or '',
+        if reader is None or isinstance(reader, str):
+            reader = create_reader(
+                reader or '',
                 root=root,
                 split=split,
                 class_map=class_map
             )
-        self.parser = parser
+        self.reader = reader
         self.load_bytes = load_bytes
         self.img_mode = img_mode
         self.transform = transform
@@ -46,15 +46,15 @@ class ImageDataset(data.Dataset):
         self._consecutive_errors = 0

     def __getitem__(self, index):
-        img, target = self.parser[index]
+        img, target = self.reader[index]
         try:
             img = img.read() if self.load_bytes else Image.open(img)
         except Exception as e:
-            _logger.warning(f'Skipped sample (index {index}, file {self.parser.filename(index)}). {str(e)}')
+            _logger.warning(f'Skipped sample (index {index}, file {self.reader.filename(index)}). {str(e)}')
             self._consecutive_errors += 1
             if self._consecutive_errors < _ERROR_RETRY:
-                return self.__getitem__((index + 1) % len(self.parser))
+                return self.__getitem__((index + 1) % len(self.reader))
             else:
                 raise e
         self._consecutive_errors = 0
@@ -72,13 +72,13 @@ class ImageDataset(data.Dataset):
         return img, target

     def __len__(self):
-        return len(self.parser)
+        return len(self.reader)

     def filename(self, index, basename=False, absolute=False):
-        return self.parser.filename(index, basename, absolute)
+        return self.reader.filename(index, basename, absolute)

     def filenames(self, basename=False, absolute=False):
-        return self.parser.filenames(basename, absolute)
+        return self.reader.filenames(basename, absolute)


 class IterableImageDataset(data.IterableDataset):
@@ -86,7 +86,7 @@ class IterableImageDataset(data.IterableDataset):
     def __init__(
             self,
             root,
-            parser=None,
+            reader=None,
             split='train',
             is_training=False,
             batch_size=None,
@@ -96,10 +96,10 @@ class IterableImageDataset(data.IterableDataset):
             transform=None,
             target_transform=None,
     ):
-        assert parser is not None
-        if isinstance(parser, str):
-            self.parser = create_parser(
-                parser,
+        assert reader is not None
+        if isinstance(reader, str):
+            self.reader = create_reader(
+                reader,
                 root=root,
                 split=split,
                 is_training=is_training,
@@ -109,13 +109,13 @@ class IterableImageDataset(data.IterableDataset):
                 download=download,
             )
         else:
-            self.parser = parser
+            self.reader = reader
         self.transform = transform
         self.target_transform = target_transform
         self._consecutive_errors = 0

     def __iter__(self):
-        for img, target in self.parser:
+        for img, target in self.reader:
             if self.transform is not None:
                 img = self.transform(img)
             if self.target_transform is not None:
@@ -123,29 +123,29 @@ class IterableImageDataset(data.IterableDataset):
             yield img, target

     def __len__(self):
-        if hasattr(self.parser, '__len__'):
-            return len(self.parser)
+        if hasattr(self.reader, '__len__'):
+            return len(self.reader)
         else:
             return 0

     def set_epoch(self, count):
         # TFDS and WDS need external epoch count for deterministic cross process shuffle
-        if hasattr(self.parser, 'set_epoch'):
-            self.parser.set_epoch(count)
+        if hasattr(self.reader, 'set_epoch'):
+            self.reader.set_epoch(count)

     def set_loader_cfg(
             self,
             num_workers: Optional[int] = None,
     ):
         # TFDS and WDS readers need # workers for correct # samples estimate before loader processes created
-        if hasattr(self.parser, 'set_loader_cfg'):
-            self.parser.set_loader_cfg(num_workers=num_workers)
+        if hasattr(self.reader, 'set_loader_cfg'):
+            self.reader.set_loader_cfg(num_workers=num_workers)

     def filename(self, index, basename=False, absolute=False):
         assert False, 'Filename lookup by index not supported, use filenames().'

     def filenames(self, basename=False, absolute=False):
-        return self.parser.filenames(basename, absolute)
+        return self.reader.filenames(basename, absolute)


 class AugMixDataset(torch.utils.data.Dataset):
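Per the `set_epoch` / `set_loader_cfg` comments above, streaming readers get the epoch pushed in from outside the dataset. A minimal training-loop sketch (the TFDS dataset name and root are illustrative; assumes `tensorflow-datasets` is installed):

    from timm.data import IterableImageDataset

    dataset = IterableImageDataset(
        root='/data/tfds',         # illustrative path
        reader='tfds/imagenette',  # illustrative TFDS dataset name
        split='train',
        is_training=True,
        batch_size=256,
    )
    for epoch in range(10):
        dataset.set_epoch(epoch)  # deterministic cross-process shuffle for TFDS/WDS
        for img, target in dataset:
            pass  # training step goes here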

@@ -137,11 +137,11 @@ def create_dataset(
     elif name.startswith('hfds/'):
         # NOTE right now, HF datasets default arrow format is a random-access Dataset,
         # There will be an IterableDataset variant too, TBD
-        ds = ImageDataset(root, parser=name, split=split, **kwargs)
+        ds = ImageDataset(root, reader=name, split=split, **kwargs)
     elif name.startswith('tfds/'):
         ds = IterableImageDataset(
             root,
-            parser=name,
+            reader=name,
             split=split,
             is_training=is_training,
             download=download,
@@ -153,7 +153,7 @@ def create_dataset(
     elif name.startswith('wds/'):
         ds = IterableImageDataset(
             root,
-            parser=name,
+            reader=name,
             split=split,
             is_training=is_training,
             batch_size=batch_size,
@@ -166,5 +166,5 @@ def create_dataset(
         if search_split and os.path.isdir(root):
             # look for split specific sub-folder in root
             root = _search_split(root, split)
-        ds = ImageDataset(root, parser=name, class_map=class_map, load_bytes=load_bytes, **kwargs)
+        ds = ImageDataset(root, reader=name, class_map=class_map, load_bytes=load_bytes, **kwargs)
     return ds
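How the branches above route a dataset name, as a usage sketch (dataset names and paths are illustrative):

    from timm.data import create_dataset

    # 'hfds/' prefix -> random-access ImageDataset backed by Hugging Face datasets
    ds = create_dataset('hfds/beans', root='/data/hfds', split='train')
    # 'tfds/' prefix -> IterableImageDataset backed by TFDS
    ds = create_dataset('tfds/imagenette', root='/data/tfds', split='train',
                        is_training=True, batch_size=256)
    # no prefix -> fall through to folder/tar scan under root
    ds = create_dataset('', root='/data/imagenet', split='validation')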

@@ -1,2 +0,0 @@
-from .parser_factory import create_parser
-from .img_extensions import *

@@ -1,35 +0,0 @@
-import os
-
-from .parser_image_folder import ParserImageFolder
-from .parser_image_in_tar import ParserImageInTar
-
-
-def create_parser(name, root, split='train', **kwargs):
-    name = name.lower()
-    name = name.split('/', 2)
-    prefix = ''
-    if len(name) > 1:
-        prefix = name[0]
-    name = name[-1]
-
-    # FIXME improve the selection right now just tfds prefix or fallback path, will need options to
-    # explicitly select other options shortly
-    if prefix == 'hfds':
-        from .parser_hfds import ParserHfds  # defer tensorflow import
-        parser = ParserHfds(root, name, split=split, **kwargs)
-    elif prefix == 'tfds':
-        from .parser_tfds import ParserTfds  # defer tensorflow import
-        parser = ParserTfds(root, name, split=split, **kwargs)
-    elif prefix == 'wds':
-        from .parser_wds import ParserWds
-        kwargs.pop('download', False)
-        parser = ParserWds(root, name, split=split, **kwargs)
-    else:
-        assert os.path.exists(root)
-        # default fallback path (backwards compat), use image tar if root is a .tar file, otherwise image folder
-        # FIXME support split here, in parser?
-        if os.path.isfile(root) and os.path.splitext(root)[1] == '.tar':
-            parser = ParserImageInTar(root, **kwargs)
-        else:
-            parser = ParserImageFolder(root, **kwargs)
-    return parser

@@ -0,0 +1,2 @@
+from .reader_factory import create_reader
+from .img_extensions import *

@@ -1,7 +1,7 @@
 from abc import abstractmethod


-class Parser:
+class Reader:
     def __init__(self):
         pass
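`ImageDataset` above only needs `reader[index]` -> (file-like, target), `len(reader)`, and `filename`/`filenames`, so a custom reader can be quite small. A hypothetical sketch (this class does not exist in timm, and the `Reader` base class body is only partially visible in this diff):

    import os
    from timm.data.readers.reader import Reader

    class ReaderFileList(Reader):
        """ Hypothetical reader over an explicit list of (path, target) pairs. """
        def __init__(self, samples):
            super().__init__()
            self.samples = samples  # list of (filepath, target) tuples

        def __getitem__(self, index):
            path, target = self.samples[index]
            return open(path, 'rb'), target  # ImageDataset reads or PIL-opens this file object

        def __len__(self):
            return len(self.samples)

        def filename(self, index, basename=False, absolute=False):
            path = self.samples[index][0]
            if basename:
                path = os.path.basename(path)
            elif absolute:
                path = os.path.abspath(path)
            return path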

@@ -0,0 +1,35 @@
+import os
+
+from .reader_image_folder import ReaderImageFolder
+from .reader_image_in_tar import ReaderImageInTar
+
+
+def create_reader(name, root, split='train', **kwargs):
+    name = name.lower()
+    name = name.split('/', 2)
+    prefix = ''
+    if len(name) > 1:
+        prefix = name[0]
+    name = name[-1]
+
+    # FIXME improve the selection right now just tfds prefix or fallback path, will need options to
+    # explicitly select other options shortly
+    if prefix == 'hfds':
+        from .reader_hfds import ReaderHfds  # defer tensorflow import
+        reader = ReaderHfds(root, name, split=split, **kwargs)
+    elif prefix == 'tfds':
+        from .reader_tfds import ReaderTfds  # defer tensorflow import
+        reader = ReaderTfds(root, name, split=split, **kwargs)
+    elif prefix == 'wds':
+        from .reader_wds import ReaderWds
+        kwargs.pop('download', False)
+        reader = ReaderWds(root, name, split=split, **kwargs)
+    else:
+        assert os.path.exists(root)
+        # default fallback path (backwards compat), use image tar if root is a .tar file, otherwise image folder
+        # FIXME support split here or in reader?
+        if os.path.isfile(root) and os.path.splitext(root)[1] == '.tar':
+            reader = ReaderImageInTar(root, **kwargs)
+        else:
+            reader = ReaderImageFolder(root, **kwargs)
+    return reader
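Prefix routing in the factory above, as a usage sketch (dataset names and paths are illustrative):

    from timm.data import create_reader

    r = create_reader('tfds/imagenette', root='/data/tfds', split='train')  # -> ReaderTfds
    r = create_reader('wds/imagenet', root='/data/wds', split='train')      # -> ReaderWds
    r = create_reader('', root='/data/imagenet/train')                      # folder fallback -> ReaderImageFolder
    r = create_reader('', root='/data/imagenet_train.tar')                  # .tar fallback -> ReaderImageInTar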

@@ -1,4 +1,5 @@
-""" Dataset parser interface that wraps Hugging Face datasets
+""" Dataset reader that wraps Hugging Face datasets
+
 Hacked together by / Copyright 2022 Ross Wightman
 """
 import io
@@ -12,7 +13,7 @@ try:
 except ImportError as e:
     print("Please install Hugging Face datasets package `pip install datasets`.")
     exit(1)

-from .parser import Parser
+from .reader import Reader
def get_class_labels(info): def get_class_labels(info):
@ -23,7 +24,7 @@ def get_class_labels(info):
return class_to_idx return class_to_idx
class ParserHfds(Parser): class ReaderHfds(Reader):
def __init__( def __init__(
self, self,

@@ -1,6 +1,6 @@
-""" A dataset parser that reads images from folders
+""" A dataset reader that extracts images from folders

-Folders are scannerd recursively to find image files. Labels are based
+Folders are scanned recursively to find image files. Labels are based
 on the folder hierarchy, just leaf folders by default.

 Hacked together by / Copyright 2020 Ross Wightman
@@ -12,7 +12,7 @@ from timm.utils.misc import natural_key

 from .class_map import load_class_map
 from .img_extensions import get_img_extensions
-from .parser import Parser
+from .reader import Reader
def find_images_and_targets( def find_images_and_targets(
@ -56,7 +56,7 @@ def find_images_and_targets(
return images_and_targets, class_to_idx return images_and_targets, class_to_idx
class ParserImageFolder(Parser): class ReaderImageFolder(Reader):
def __init__( def __init__(
self, self,

@@ -1,6 +1,6 @@
-""" A dataset parser that reads tarfile based datasets
+""" A dataset reader that reads tarfile based datasets

-This parser can read and extract image samples from:
+This reader can extract image samples from:
 * a single tar of image files
 * a folder of multiple tarfiles containing imagefiles
 * a tar of tars containing image files
@@ -22,7 +22,7 @@ from timm.utils.misc import natural_key
 from .class_map import load_class_map
 from .img_extensions import get_img_extensions
-from .parser import Parser
+from .reader import Reader

 _logger = logging.getLogger(__name__)

 CACHE_FILENAME_SUFFIX = '_tarinfos.pickle'
@@ -169,8 +169,8 @@ def extract_tarinfos(
     return samples, targets, class_name_to_idx, tarfiles


-class ParserImageInTar(Parser):
-    """ Multi-tarfile dataset parser where there is one .tar file per class
+class ReaderImageInTar(Reader):
+    """ Multi-tarfile dataset reader where there is one .tar file per class
     """

     def __init__(self, root, class_map='', cache_tarfiles=True, cache_tarinfo=None):
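Driving `ReaderImageInTar` directly, per its docstring and `__init__` above (the path is illustrative; per the module docstring it also handles a folder of tars or a tar of tars):

    reader = ReaderImageInTar('/data/imagenet_tars')  # e.g. a folder of per-class .tar files
    img_file, target = reader[0]  # file-like object + target, as ImageDataset expects
    print(reader.filename(0, basename=True), target)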

@@ -1,6 +1,6 @@
-""" A dataset parser that reads single tarfile based datasets
+""" A dataset reader that reads single tarfile based datasets

-This parser can read datasets consisting if a single tarfile containing images.
-I am planning to deprecated it in favour of ParerImageInTar.
+This reader can read datasets consisting of a single tarfile containing images.
+I am planning to deprecate it in favour of ReaderImageInTar.

 Hacked together by / Copyright 2020 Ross Wightman
@@ -12,7 +12,7 @@ from timm.utils.misc import natural_key

 from .class_map import load_class_map
 from .img_extensions import get_img_extensions
-from .parser import Parser
+from .reader import Reader
def extract_tarinfo(tarfile, class_to_idx=None, sort=True): def extract_tarinfo(tarfile, class_to_idx=None, sort=True):
@ -38,9 +38,9 @@ def extract_tarinfo(tarfile, class_to_idx=None, sort=True):
return tarinfo_and_targets, class_to_idx return tarinfo_and_targets, class_to_idx
class ParserImageTar(Parser): class ReaderImageTar(Reader):
""" Single tarfile dataset where classes are mapped to folders within tar """ Single tarfile dataset where classes are mapped to folders within tar
NOTE: This class is being deprecated in favour of the more capable ParserImageInTar that can NOTE: This class is being deprecated in favour of the more capable ReaderImageInTar that can
operate on folders of tars or tars in tars. operate on folders of tars or tars in tars.
""" """
def __init__(self, root, class_map=''): def __init__(self, root, class_map=''):

@@ -1,4 +1,4 @@
-""" Dataset parser interface that wraps TFDS datasets
+""" Dataset reader that wraps TFDS datasets

 Wraps many (most?) TFDS image-classification datasets
 from https://github.com/tensorflow/datasets
@@ -34,7 +34,7 @@ except ImportError as e:
     print("Please install tensorflow_datasets package `pip install tensorflow-datasets`.")
     exit(1)

-from .parser import Parser
+from .reader import Reader
 from .shared_count import SharedCount
@@ -56,7 +56,7 @@ def get_class_labels(info):
     return class_to_idx


-class ParserTfds(Parser):
+class ReaderTfds(Reader):
     """ Wrap Tensorflow Datasets for use in PyTorch

     There are several things to be aware of:

@@ -1,4 +1,4 @@
-""" Dataset parser interface for webdataset
+""" Dataset reader for webdataset

 Hacked together by / Copyright 2022 Ross Wightman
 """
@@ -29,7 +29,7 @@ except ImportError:
     wds = None
     expand_urls = None

-from .parser import Parser
+from .reader import Reader
 from .shared_count import SharedCount

 _logger = logging.getLogger(__name__)
@@ -280,7 +280,7 @@ class ResampledShards2(IterableDataset):
             yield dict(url=self.urls[index])


-class ParserWds(Parser):
+class ReaderWds(Reader):
     def __init__(
             self,
             root,