@ -10,7 +10,7 @@ import torch
import torch . utils . data as data
from PIL import Image
from . parsers import create_pars er
from . readers import create_read er
_logger = logging . getLogger ( __name__ )
@ -23,7 +23,7 @@ class ImageDataset(data.Dataset):
def __init__ (
self ,
root ,
pars er= None ,
read er= None ,
split = ' train ' ,
class_map = None ,
load_bytes = False ,
@ -31,14 +31,14 @@ class ImageDataset(data.Dataset):
transform = None ,
target_transform = None ,
) :
if parser is None or isinstance ( pars er, str ) :
parser = create_pars er(
pars er or ' ' ,
if reader is None or isinstance ( read er, str ) :
reader = create_read er(
read er or ' ' ,
root = root ,
split = split ,
class_map = class_map
)
self . parser = pars er
self . reader = read er
self . load_bytes = load_bytes
self . img_mode = img_mode
self . transform = transform
@ -46,15 +46,15 @@ class ImageDataset(data.Dataset):
self . _consecutive_errors = 0
def __getitem__ ( self , index ) :
img , target = self . pars er[ index ]
img , target = self . read er[ index ]
try :
img = img . read ( ) if self . load_bytes else Image . open ( img )
except Exception as e :
_logger . warning ( f ' Skipped sample (index { index } , file { self . pars er. filename ( index ) } ). { str ( e ) } ' )
_logger . warning ( f ' Skipped sample (index { index } , file { self . read er. filename ( index ) } ). { str ( e ) } ' )
self . _consecutive_errors + = 1
if self . _consecutive_errors < _ERROR_RETRY :
return self . __getitem__ ( ( index + 1 ) % len ( self . pars er) )
return self . __getitem__ ( ( index + 1 ) % len ( self . read er) )
else :
raise e
self . _consecutive_errors = 0
@ -72,13 +72,13 @@ class ImageDataset(data.Dataset):
return img , target
def __len__ ( self ) :
return len ( self . pars er)
return len ( self . read er)
def filename ( self , index , basename = False , absolute = False ) :
return self . pars er. filename ( index , basename , absolute )
return self . read er. filename ( index , basename , absolute )
def filenames ( self , basename = False , absolute = False ) :
return self . pars er. filenames ( basename , absolute )
return self . read er. filenames ( basename , absolute )
class IterableImageDataset ( data . IterableDataset ) :
@ -86,7 +86,7 @@ class IterableImageDataset(data.IterableDataset):
def __init__ (
self ,
root ,
pars er= None ,
read er= None ,
split = ' train ' ,
is_training = False ,
batch_size = None ,
@ -96,10 +96,10 @@ class IterableImageDataset(data.IterableDataset):
transform = None ,
target_transform = None ,
) :
assert pars er is not None
if isinstance ( pars er, str ) :
self . parser = create_pars er(
pars er,
assert read er is not None
if isinstance ( read er, str ) :
self . reader = create_read er(
read er,
root = root ,
split = split ,
is_training = is_training ,
@ -109,13 +109,13 @@ class IterableImageDataset(data.IterableDataset):
download = download ,
)
else :
self . parser = pars er
self . reader = read er
self . transform = transform
self . target_transform = target_transform
self . _consecutive_errors = 0
def __iter__ ( self ) :
for img , target in self . pars er:
for img , target in self . read er:
if self . transform is not None :
img = self . transform ( img )
if self . target_transform is not None :
@ -123,29 +123,29 @@ class IterableImageDataset(data.IterableDataset):
yield img , target
def __len__ ( self ) :
if hasattr ( self . pars er, ' __len__ ' ) :
return len ( self . pars er)
if hasattr ( self . read er, ' __len__ ' ) :
return len ( self . read er)
else :
return 0
def set_epoch ( self , count ) :
# TFDS and WDS need external epoch count for deterministic cross process shuffle
if hasattr ( self . pars er, ' set_epoch ' ) :
self . pars er. set_epoch ( count )
if hasattr ( self . read er, ' set_epoch ' ) :
self . read er. set_epoch ( count )
def set_loader_cfg (
self ,
num_workers : Optional [ int ] = None ,
) :
# TFDS and WDS readers need # workers for correct # samples estimate before loader processes created
if hasattr ( self . pars er, ' set_loader_cfg ' ) :
self . pars er. set_loader_cfg ( num_workers = num_workers )
if hasattr ( self . read er, ' set_loader_cfg ' ) :
self . read er. set_loader_cfg ( num_workers = num_workers )
def filename ( self , index , basename = False , absolute = False ) :
assert False , ' Filename lookup by index not supported, use filenames(). '
def filenames ( self , basename = False , absolute = False ) :
return self . pars er. filenames ( basename , absolute )
return self . read er. filenames ( basename , absolute )
class AugMixDataset ( torch . utils . data . Dataset ) :