diff --git a/inference.py b/inference.py index 16d19944..243f6f38 100755 --- a/inference.py +++ b/inference.py @@ -31,8 +31,8 @@ parser.add_argument('-j', '--workers', default=2, type=int, metavar='N', help='number of data loading workers (default: 2)') parser.add_argument('-b', '--batch-size', default=256, type=int, metavar='N', help='mini-batch size (default: 256)') -parser.add_argument('--img-size', default=None, type=int, - metavar='N', help='Input image dimension') +parser.add_argument('--img-size', default=None, nargs=3, type=int, + metavar='N N N', help='Input image dimension (e.g. --img-size 3 224 224), uses model default if empty') parser.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN', help='Override mean pixel value of dataset') parser.add_argument('--std', type=float, nargs='+', default=None, metavar='STD', diff --git a/timm/data/config.py b/timm/data/config.py index 9cb4bda8..cb686e84 100644 --- a/timm/data/config.py +++ b/timm/data/config.py @@ -23,8 +23,10 @@ def resolve_data_config(args, default_cfg={}, model=None, verbose=True): input_size = tuple(args['input_size']) in_chans = input_size[0] # input_size overrides in_chans elif 'img_size' in args and args['img_size'] is not None: - assert isinstance(args['img_size'], int) - input_size = (in_chans, args['img_size'], args['img_size']) + assert isinstance(args['img_size'], (tuple, list)) + assert len(args['img_size']) == 3 + input_size = tuple(args['img_size']) + in_chans = input_size[0] # input_size overrides in_chans elif 'input_size' in default_cfg: input_size = default_cfg['input_size'] new_config['input_size'] = input_size diff --git a/train.py b/train.py index 7a93a1b6..b95953ed 100755 --- a/train.py +++ b/train.py @@ -80,8 +80,8 @@ parser.add_argument('--num-classes', type=int, default=1000, metavar='N', help='number of label classes (default: 1000)') parser.add_argument('--gp', default=None, type=str, metavar='POOL', help='Global pool type, one of (fast, avg, max, avgmax, avgmaxc). Model default if None.') -parser.add_argument('--img-size', type=int, default=None, metavar='N', - help='Image patch size (default: None => model default)') +parser.add_argument('--img-size', default=None, nargs=3, type=int, + metavar='N N N', help='Input image dimension (e.g. --img-size 3 224 224), uses model default if empty') parser.add_argument('--crop-pct', default=None, type=float, metavar='N', help='Input image center crop percent (for validation only)') parser.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN', diff --git a/validate.py b/validate.py index 5a0d388c..8e5a9662 100755 --- a/validate.py +++ b/validate.py @@ -50,8 +50,8 @@ parser.add_argument('-j', '--workers', default=4, type=int, metavar='N', help='number of data loading workers (default: 2)') parser.add_argument('-b', '--batch-size', default=256, type=int, metavar='N', help='mini-batch size (default: 256)') -parser.add_argument('--img-size', default=None, type=int, - metavar='N', help='Input image dimension, uses model default if empty') +parser.add_argument('--img-size', default=None, nargs=3, type=int, + metavar='N N N', help='Input image dimension (e.g. --img-size 3 224 224), uses model default if empty') parser.add_argument('--crop-pct', default=None, type=float, metavar='N', help='Input image center crop pct') parser.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN',