Rename dataset/parsers -> dataset/readers, create_parser to create_reader, etc

pull/1479/head
Ross Wightman 2 years ago
parent f67a7ee8bd
commit e9dccc918c

@@ -112,7 +112,7 @@ More models, more fixes
* `cs3`, `darknet`, and `vit_*relpos` weights above all trained on TPU thanks to TRC program! Rest trained on overheating GPUs.
* Hugging Face Hub support fixes verified, demo notebook TBA
* Pretrained weights / configs can be loaded externally (ie from local disk) w/ support for head adaptation.
* Add support to change image extensions scanned by `timm` datasets/parsers. See (https://github.com/rwightman/pytorch-image-models/pull/1274#issuecomment-1178303103)
* Add support to change image extensions scanned by `timm` datasets/readers. See https://github.com/rwightman/pytorch-image-models/pull/1274#issuecomment-1178303103
* Default ConvNeXt LayerNorm impl to use `F.layer_norm(x.permute(0, 2, 3, 1), ...).permute(0, 3, 1, 2)` via `LayerNorm2d` in all cases.
* a bit slower than previous custom impl on some hardware (ie Ampere w/ CL), but overall fewer regressions across wider HW / PyTorch version ranges.
* previous impl exists as `LayerNormExp2d` in `models/layers/norm.py`
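
The extension-change bullet above now has public hooks; the `timm.data` exports in the next hunk show their names. A minimal usage sketch, with the default extension tuple shown as an assumption:

```python
from timm.data import add_img_extensions, get_img_extensions, is_img_extension

# Hypothetical usage: let folder/tar scans pick up .webp files as well.
add_img_extensions('.webp')
print(get_img_extensions())       # e.g. ('.png', '.jpg', '.jpeg', '.webp') -- defaults assumed
print(is_img_extension('.webp'))  # True once registered (assumed to test an extension string)
```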

@@ -6,8 +6,8 @@ from .dataset import ImageDataset, IterableImageDataset, AugMixDataset
from .dataset_factory import create_dataset
from .loader import create_loader
from .mixup import Mixup, FastCollateMixup
from .parsers import create_parser,\
get_img_extensions, is_img_extension, set_img_extensions, add_img_extensions, del_img_extensions
from .readers import create_reader
from .readers import get_img_extensions, is_img_extension, set_img_extensions, add_img_extensions, del_img_extensions
from .real_labels import RealLabelsImagenet
from .transforms import *
from .transforms_factory import create_transform

@@ -10,7 +10,7 @@ import torch
import torch.utils.data as data
from PIL import Image
from .parsers import create_parser
from .readers import create_reader
_logger = logging.getLogger(__name__)
@@ -23,7 +23,7 @@ class ImageDataset(data.Dataset):
def __init__(
self,
root,
parser=None,
reader=None,
split='train',
class_map=None,
load_bytes=False,
@@ -31,14 +31,14 @@ class ImageDataset(data.Dataset):
transform=None,
target_transform=None,
):
if parser is None or isinstance(parser, str):
parser = create_parser(
parser or '',
if reader is None or isinstance(reader, str):
reader = create_reader(
reader or '',
root=root,
split=split,
class_map=class_map
)
self.parser = parser
self.reader = reader
self.load_bytes = load_bytes
self.img_mode = img_mode
self.transform = transform
@@ -46,15 +46,15 @@ class ImageDataset(data.Dataset):
self._consecutive_errors = 0
def __getitem__(self, index):
img, target = self.parser[index]
img, target = self.reader[index]
try:
img = img.read() if self.load_bytes else Image.open(img)
except Exception as e:
_logger.warning(f'Skipped sample (index {index}, file {self.parser.filename(index)}). {str(e)}')
_logger.warning(f'Skipped sample (index {index}, file {self.reader.filename(index)}). {str(e)}')
self._consecutive_errors += 1
if self._consecutive_errors < _ERROR_RETRY:
return self.__getitem__((index + 1) % len(self.parser))
return self.__getitem__((index + 1) % len(self.reader))
else:
raise e
self._consecutive_errors = 0
@@ -72,13 +72,13 @@ class ImageDataset(data.Dataset):
return img, target
def __len__(self):
return len(self.parser)
return len(self.reader)
def filename(self, index, basename=False, absolute=False):
return self.parser.filename(index, basename, absolute)
return self.reader.filename(index, basename, absolute)
def filenames(self, basename=False, absolute=False):
return self.parser.filenames(basename, absolute)
return self.reader.filenames(basename, absolute)
class IterableImageDataset(data.IterableDataset):
@@ -86,7 +86,7 @@ class IterableImageDataset(data.IterableDataset):
def __init__(
self,
root,
parser=None,
reader=None,
split='train',
is_training=False,
batch_size=None,
@@ -96,10 +96,10 @@ class IterableImageDataset(data.IterableDataset):
transform=None,
target_transform=None,
):
assert parser is not None
if isinstance(parser, str):
self.parser = create_parser(
parser,
assert reader is not None
if isinstance(reader, str):
self.reader = create_reader(
reader,
root=root,
split=split,
is_training=is_training,
@@ -109,13 +109,13 @@ class IterableImageDataset(data.IterableDataset):
download=download,
)
else:
self.parser = parser
self.reader = reader
self.transform = transform
self.target_transform = target_transform
self._consecutive_errors = 0
def __iter__(self):
for img, target in self.parser:
for img, target in self.reader:
if self.transform is not None:
img = self.transform(img)
if self.target_transform is not None:
@@ -123,29 +123,29 @@ class IterableImageDataset(data.IterableDataset):
yield img, target
def __len__(self):
if hasattr(self.parser, '__len__'):
return len(self.parser)
if hasattr(self.reader, '__len__'):
return len(self.reader)
else:
return 0
def set_epoch(self, count):
# TFDS and WDS need external epoch count for deterministic cross process shuffle
if hasattr(self.parser, 'set_epoch'):
self.parser.set_epoch(count)
if hasattr(self.reader, 'set_epoch'):
self.reader.set_epoch(count)
def set_loader_cfg(
self,
num_workers: Optional[int] = None,
):
# TFDS and WDS readers need # workers for correct # samples estimate before loader processes are created
if hasattr(self.parser, 'set_loader_cfg'):
self.parser.set_loader_cfg(num_workers=num_workers)
if hasattr(self.reader, 'set_loader_cfg'):
self.reader.set_loader_cfg(num_workers=num_workers)
def filename(self, index, basename=False, absolute=False):
assert False, 'Filename lookup by index not supported, use filenames().'
def filenames(self, basename=False, absolute=False):
return self.parser.filenames(basename, absolute)
return self.reader.filenames(basename, absolute)
class AugMixDataset(torch.utils.data.Dataset):
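
To make the rename concrete, a minimal sketch of the `reader` argument in both dataset classes (paths and dataset names are hypothetical; per the `__init__` hunks above, `ImageDataset` builds a default reader from `root` when none is given, while `IterableImageDataset` requires one):

```python
from timm.data import ImageDataset, IterableImageDataset

# Map-style: reader inferred from root (folder scan, or tar reader if root is a .tar)
ds = ImageDataset('/data/imagenet/val')   # hypothetical path
img, target = ds[0]                       # PIL image + int target (no transform set)
print(ds.filename(0, basename=True))

# Iterable: reader passed as a prefixed name string
ids = IterableImageDataset(
    '/data/tfds',                         # hypothetical root
    reader='tfds/imagenette',             # hypothetical TFDS dataset name
    split='train', is_training=True, batch_size=32,
)
ids.set_epoch(0)  # TFDS/WDS readers use this for deterministic cross-process shuffle
```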

@@ -137,11 +137,11 @@ def create_dataset(
elif name.startswith('hfds/'):
# NOTE right now, HF datasets' default arrow format is a random-access Dataset,
# There will be an IterableDataset variant too, TBD
ds = ImageDataset(root, parser=name, split=split, **kwargs)
ds = ImageDataset(root, reader=name, split=split, **kwargs)
elif name.startswith('tfds/'):
ds = IterableImageDataset(
root,
parser=name,
reader=name,
split=split,
is_training=is_training,
download=download,
@@ -153,7 +153,7 @@
elif name.startswith('wds/'):
ds = IterableImageDataset(
root,
parser=name,
reader=name,
split=split,
is_training=is_training,
batch_size=batch_size,
@@ -166,5 +166,5 @@
if search_split and os.path.isdir(root):
# look for split specific sub-folder in root
root = _search_split(root, split)
ds = ImageDataset(root, parser=name, class_map=class_map, load_bytes=load_bytes, **kwargs)
ds = ImageDataset(root, reader=name, class_map=class_map, load_bytes=load_bytes, **kwargs)
return ds

@@ -1,2 +0,0 @@
from .parser_factory import create_parser
from .img_extensions import *

@@ -1,35 +0,0 @@
import os
from .parser_image_folder import ParserImageFolder
from .parser_image_in_tar import ParserImageInTar
def create_parser(name, root, split='train', **kwargs):
name = name.lower()
name = name.split('/', 2)
prefix = ''
if len(name) > 1:
prefix = name[0]
name = name[-1]
# FIXME improve the selection right now just tfds prefix or fallback path, will need options to
# explicitly select other options shortly
if prefix == 'hfds':
from .parser_hfds import ParserHfds # defer tensorflow import
parser = ParserHfds(root, name, split=split, **kwargs)
elif prefix == 'tfds':
from .parser_tfds import ParserTfds # defer tensorflow import
parser = ParserTfds(root, name, split=split, **kwargs)
elif prefix == 'wds':
from .parser_wds import ParserWds
kwargs.pop('download', False)
parser = ParserWds(root, name, split=split, **kwargs)
else:
assert os.path.exists(root)
# default fallback path (backwards compat), use image tar if root is a .tar file, otherwise image folder
# FIXME support split here, in parser?
if os.path.isfile(root) and os.path.splitext(root)[1] == '.tar':
parser = ParserImageInTar(root, **kwargs)
else:
parser = ParserImageFolder(root, **kwargs)
return parser

@@ -0,0 +1,2 @@
from .reader_factory import create_reader
from .img_extensions import *

@@ -1,7 +1,7 @@
from abc import abstractmethod
class Parser:
class Reader:
def __init__(self):
pass
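
The hunk above only shows the renamed class header; the rest of the interface can be inferred from how readers are indexed and queried in `dataset.py` (`__getitem__`, `__len__`, `filename`). A toy subclass sketch, with the `_filename` hook being an assumption about what backs the `filename`/`filenames` wrappers:

```python
from timm.data.readers.reader import Reader

class ListReader(Reader):
    """Toy reader over an in-memory list of (path, target) pairs."""

    def __init__(self, samples):
        super().__init__()
        self.samples = samples

    def __getitem__(self, index):
        path, target = self.samples[index]
        # ImageDataset expects a file-like object it can Image.open() or .read()
        return open(path, 'rb'), target

    def __len__(self):
        return len(self.samples)

    def _filename(self, index, basename=False, absolute=False):
        return self.samples[index][0]  # assumed hook behind Reader.filename()
```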

@@ -0,0 +1,35 @@
import os
from .reader_image_folder import ReaderImageFolder
from .reader_image_in_tar import ReaderImageInTar
def create_reader(name, root, split='train', **kwargs):
name = name.lower()
name = name.split('/', 2)
prefix = ''
if len(name) > 1:
prefix = name[0]
name = name[-1]
# FIXME improve the selection; right now it's just a known prefix (hfds/tfds/wds) or the fallback path,
# will need options to explicitly select other readers shortly
if prefix == 'hfds':
from .reader_hfds import ReaderHfds # defer tensorflow import
reader = ReaderHfds(root, name, split=split, **kwargs)
elif prefix == 'tfds':
from .reader_tfds import ReaderTfds # defer tensorflow import
reader = ReaderTfds(root, name, split=split, **kwargs)
elif prefix == 'wds':
from .reader_wds import ReaderWds
kwargs.pop('download', False)
reader = ReaderWds(root, name, split=split, **kwargs)
else:
assert os.path.exists(root)
# default fallback path (backwards compat), use image tar if root is a .tar file, otherwise image folder
# FIXME support split here or in reader?
if os.path.isfile(root) and os.path.splitext(root)[1] == '.tar':
reader = ReaderImageInTar(root, **kwargs)
else:
reader = ReaderImageFolder(root, **kwargs)
return reader
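
Direct use of the new factory, following the fallback logic above (paths hypothetical):

```python
from timm.data.readers import create_reader

# Empty name -> no prefix -> fallback path: folder scan of root
reader = create_reader('', root='/data/imagenet/train')
print(len(reader), reader.filename(0, basename=True))

# A .tar root takes the ReaderImageInTar branch of the same fallback
tar_reader = create_reader('', root='/data/imagenet_train.tar')
```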

@@ -1,4 +1,5 @@
""" Dataset parser interface that wraps Hugging Face datasets
""" Dataset reader that wraps Hugging Face datasets
Hacked together by / Copyright 2022 Ross Wightman
"""
import io
@@ -12,7 +13,7 @@ try:
except ImportError as e:
print("Please install Hugging Face datasets package `pip install datasets`.")
exit(1)
from .parser import Parser
from .reader import Reader
def get_class_labels(info):
@@ -23,7 +24,7 @@ def get_class_labels(info):
return class_to_idx
class ParserHfds(Parser):
class ReaderHfds(Reader):
def __init__(
self,

@@ -1,6 +1,6 @@
""" A dataset parser that reads images from folders
""" A dataset reader that extracts images from folders
Folders are scannerd recursively to find image files. Labels are based
Folders are scanned recursively to find image files. Labels are based
on the folder hierarchy, just leaf folders by default.
Hacked together by / Copyright 2020 Ross Wightman
@@ -12,7 +12,7 @@ from timm.utils.misc import natural_key
from .class_map import load_class_map
from .img_extensions import get_img_extensions
from .parser import Parser
from .reader import Reader
def find_images_and_targets(
@@ -56,7 +56,7 @@ def find_images_and_targets(
return images_and_targets, class_to_idx
class ParserImageFolder(Parser):
class ReaderImageFolder(Reader):
def __init__(
self,
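
The folder-to-label mapping described in the docstring above, illustrated with a hypothetical layout (leaf folder names become class labels unless a class map overrides them):

```python
# root/
#   cat/001.jpg   -> label 'cat'
#   cat/002.jpg   -> label 'cat'
#   dog/001.jpg   -> label 'dog'
#
# find_images_and_targets() returns ([(path, class_idx), ...], class_to_idx),
# e.g. class_to_idx == {'cat': 0, 'dog': 1} with natural-key sorted class names.
```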

@@ -1,6 +1,6 @@
""" A dataset parser that reads tarfile based datasets
""" A dataset reader that reads tarfile based datasets
This parser can read and extract image samples from:
This reader can extract image samples from:
* a single tar of image files
* a folder of multiple tarfiles containing imagefiles
* a tar of tars containing image files
@@ -22,7 +22,7 @@ from timm.utils.misc import natural_key
from .class_map import load_class_map
from .img_extensions import get_img_extensions
from .parser import Parser
from .reader import Reader
_logger = logging.getLogger(__name__)
CACHE_FILENAME_SUFFIX = '_tarinfos.pickle'
@@ -169,8 +169,8 @@ def extract_tarinfos(
return samples, targets, class_name_to_idx, tarfiles
class ParserImageInTar(Parser):
""" Multi-tarfile dataset parser where there is one .tar file per class
class ReaderImageInTar(Reader):
""" Multi-tarfile dataset reader where there is one .tar file per class
"""
def __init__(self, root, class_map='', cache_tarfiles=True, cache_tarinfo=None):

@@ -1,6 +1,6 @@
""" A dataset parser that reads single tarfile based datasets
""" A dataset reader that reads single tarfile based datasets
This parser can read datasets consisting if a single tarfile containing images.
This reader can read datasets consisting of a single tarfile containing images.
I am planning to deprecate it in favour of ReaderImageInTar.
Hacked together by / Copyright 2020 Ross Wightman
@@ -12,7 +12,7 @@ from timm.utils.misc import natural_key
from .class_map import load_class_map
from .img_extensions import get_img_extensions
from .parser import Parser
from .reader import Reader
def extract_tarinfo(tarfile, class_to_idx=None, sort=True):
@@ -38,9 +38,9 @@ def extract_tarinfo(tarfile, class_to_idx=None, sort=True):
return tarinfo_and_targets, class_to_idx
class ParserImageTar(Parser):
class ReaderImageTar(Reader):
""" Single tarfile dataset where classes are mapped to folders within tar
NOTE: This class is being deprecated in favour of the more capable ParserImageInTar that can
NOTE: This class is being deprecated in favour of the more capable ReaderImageInTar that can
operate on folders of tars or tars in tars.
"""
def __init__(self, root, class_map=''):

@@ -1,4 +1,4 @@
""" Dataset parser interface that wraps TFDS datasets
""" Dataset reader that wraps TFDS datasets
Wraps many (most?) TFDS image-classification datasets
from https://github.com/tensorflow/datasets
@@ -34,7 +34,7 @@ except ImportError as e:
print("Please install tensorflow_datasets package `pip install tensorflow-datasets`.")
exit(1)
from .parser import Parser
from .reader import Reader
from .shared_count import SharedCount
@@ -56,7 +56,7 @@ def get_class_labels(info):
return class_to_idx
class ParserTfds(Parser):
class ReaderTfds(Reader):
""" Wrap Tensorflow Datasets for use in PyTorch
There are several things to be aware of:

@@ -1,4 +1,4 @@
""" Dataset parser interface for webdataset
""" Dataset reader for webdataset
Hacked together by / Copyright 2022 Ross Wightman
"""
@@ -29,7 +29,7 @@ except ImportError:
wds = None
expand_urls = None
from .parser import Parser
from .reader import Reader
from .shared_count import SharedCount
_logger = logging.getLogger(__name__)
@@ -280,7 +280,7 @@ class ResampledShards2(IterableDataset):
yield dict(url=self.urls[index])
class ParserWds(Parser):
class ReaderWds(Reader):
def __init__(
self,
root,