diff --git a/clean_checkpoint.py b/clean_checkpoint.py
index d51e0d96..b088aa8f 100644
--- a/clean_checkpoint.py
+++ b/clean_checkpoint.py
@@ -8,12 +8,15 @@ from collections import OrderedDict
 parser = argparse.ArgumentParser(description='PyTorch ImageNet Validation')
 parser.add_argument('--checkpoint', default='', type=str, metavar='PATH',
                     help='path to latest checkpoint (default: none)')
-parser.add_argument('--output', default='./cleaned.pth', type=str, metavar='PATH',
+parser.add_argument('--output', default='', type=str, metavar='PATH',
                     help='output path')
 parser.add_argument('--use-ema', dest='use_ema', action='store_true',
                     help='use ema version of weights if present')
 
+_TEMP_NAME = './_checkpoint.pth'
+
+
 def main():
     args = parser.parse_args()
 
@@ -40,13 +43,18 @@ def main():
             new_state_dict[name] = v
         print("=> Loaded state_dict from '{}'".format(args.checkpoint))
 
-        torch.save(new_state_dict, args.output)
-        with open(args.output, 'rb') as f:
+        torch.save(new_state_dict, _TEMP_NAME)
+        with open(_TEMP_NAME, 'rb') as f:
             sha_hash = hashlib.sha256(f.read()).hexdigest()
 
-        checkpoint_base = os.path.splitext(args.checkpoint)[0]
+        if args.output:
+            checkpoint_root, checkpoint_base = os.path.split(args.output)
+            checkpoint_base = os.path.splitext(checkpoint_base)[0]
+        else:
+            checkpoint_root = ''
+            checkpoint_base = os.path.splitext(args.checkpoint)[0]
         final_filename = '-'.join([checkpoint_base, sha_hash[:8]]) + '.pth'
-        shutil.move(args.output, final_filename)
+        shutil.move(_TEMP_NAME, os.path.join(checkpoint_root, final_filename))
         print("=> Saved state_dict to '{}, SHA256: {}'".format(final_filename, sha_hash))
     else:
         print("Error: Checkpoint ({}) doesn't exist".format(args.checkpoint))
diff --git a/results/results-inv2-matched-frequency.csv b/results/results-inv2-matched-frequency.csv
index 4b42fdd9..3f791d6e 100644
--- a/results/results-inv2-matched-frequency.csv
+++ b/results/results-inv2-matched-frequency.csv
@@ -1,97 +1,155 @@
 model,top1,top1_err,top5,top5_err,param_count,img_size,cropt_pct,interpolation
-resnet18,57.18,42.82,80.19,19.81,11.69,224,0.875,bilinear
-gluon_resnet18_v1b,58.32,41.68,80.96,19.04,11.69,224,0.875,bicubic
-seresnet18,59.81,40.19,81.68,18.32,11.78,224,0.875,bicubic
-tv_resnet34,61.2,38.8,82.72,17.28,21.8,224,0.875,bilinear
-spnasnet_100,61.21,38.79,82.77,17.23,4.42,224,0.875,bilinear
-mnasnet_100,61.91,38.09,83.71,16.29,4.38,224,0.875,bicubic
-fbnetc_100,62.43,37.57,83.39,16.61,5.57,224,0.875,bilinear
-gluon_resnet34_v1b,62.56,37.44,84,16,21.8,224,0.875,bicubic
-resnet34,62.82,37.18,84.12,15.88,21.8,224,0.875,bilinear
-seresnet34,62.89,37.11,84.22,15.78,21.96,224,0.875,bilinear
-densenet121,62.94,37.06,84.26,15.74,7.98,224,0.875,bicubic
-semnasnet_100,63.12,36.88,84.53,15.47,3.89,224,0.875,bicubic
-mobilenetv3_100,63.23,36.77,84.52,15.48,5.48,224,0.875,bicubic
-tv_resnet50,63.33,36.67,84.65,15.35,25.56,224,0.875,bilinear
-mixnet_s,63.38,36.62,84.71,15.29,4.13,224,0.875,bicubic
-resnet26,63.45,36.55,84.27,15.73,16,224,0.875,bicubic
-tf_mixnet_s,63.59,36.41,84.27,15.73,4.13,224,0.875,bicubic
-dpn68,64.22,35.78,85.18,14.82,12.61,224,0.875,bicubic
-tf_mixnet_m,64.27,35.73,85.09,14.91,5.01,224,0.875,bicubic
-tf_efficientnet_b0,64.29,35.71,85.25,14.75,5.29,224,0.875,bicubic
-efficientnet_b0,64.58,35.42,85.89,14.11,5.29,224,0.875,bicubic
-resnet26d,64.63,35.37,85.12,14.88,16.01,224,0.875,bicubic
-mixnet_m,64.69,35.31,85.47,14.53,5.01,224,0.875,bicubic
-densenet169,64.78,35.22,85.25,14.75,14.15,224,0.875,bicubic
-seresnext26_32x4d,65.04,34.96,85.65,14.35,16.79,224,0.875,bicubic -tf_efficientnet_es,65.24,34.76,85.54,14.46,5.44,224,0.875,bicubic -densenet201,65.28,34.72,85.67,14.33,20.01,224,0.875,bicubic -dpn68b,65.6,34.4,85.94,14.06,12.61,224,0.875,bicubic -resnet101,65.68,34.32,85.98,14.02,44.55,224,0.875,bilinear -densenet161,65.85,34.15,86.46,13.54,28.68,224,0.875,bicubic -gluon_resnet50_v1b,66.04,33.96,86.27,13.73,25.56,224,0.875,bicubic -inception_v3,66.12,33.88,86.34,13.66,27.16,299,0.875,bicubic -tv_resnext50_32x4d,66.18,33.82,86.04,13.96,25.03,224,0.875,bilinear -seresnet50,66.24,33.76,86.33,13.67,28.09,224,0.875,bilinear -tf_inception_v3,66.41,33.59,86.68,13.32,23.83,299,0.875,bicubic -gluon_resnet50_v1c,66.54,33.46,86.16,13.84,25.58,224,0.875,bicubic -adv_inception_v3,66.6,33.4,86.56,13.44,23.83,299,0.875,bicubic -wide_resnet50_2,66.65,33.35,86.81,13.19,68.88,224,0.875,bilinear -wide_resnet101_2,66.68,33.32,87.04,12.96,126.89,224,0.875,bilinear -tf_mixnet_l,66.78,33.22,86.46,13.54,7.33,224,0.875,bicubic -resnet50,66.81,33.19,87,13,25.56,224,0.875,bicubic -tf_efficientnet_em,66.87,33.13,86.98,13.02,6.9,240,0.882,bicubic -resnext50_32x4d,66.88,33.12,86.36,13.64,25.03,224,0.875,bicubic -tf_efficientnet_b1,66.89,33.11,87.04,12.96,7.79,240,0.882,bicubic -mixnet_l,66.97,33.03,86.94,13.06,7.33,224,0.875,bicubic -resnet152,67.02,32.98,87.57,12.43,60.19,224,0.875,bilinear -gluon_resnet50_v1s,67.1,32.9,86.86,13.14,25.68,224,0.875,bicubic -seresnet101,67.15,32.85,87.05,12.95,49.33,224,0.875,bilinear -gluon_resnet101_v1b,67.45,32.55,87.23,12.77,44.55,224,0.875,bicubic -efficientnet_b1,67.55,32.45,87.29,12.71,7.79,240,0.882,bicubic -seresnet152,67.55,32.45,87.39,12.61,66.82,224,0.875,bilinear -gluon_resnet101_v1c,67.56,32.44,87.16,12.84,44.57,224,0.875,bicubic -gluon_inception_v3,67.59,32.41,87.46,12.54,23.83,299,0.875,bicubic -xception,67.67,32.33,87.57,12.43,22.86,299,0.8975,bicubic -efficientnet_b2,67.8,32.2,88.2,11.8,9.11,260,0.89,bicubic -resnext101_32x8d,67.85,32.15,87.48,12.52,88.79,224,0.875,bilinear -seresnext50_32x4d,67.87,32.13,87.62,12.38,27.56,224,0.875,bilinear -gluon_resnet50_v1d,67.91,32.09,87.12,12.88,25.58,224,0.875,bicubic -dpn92,68.01,31.99,87.59,12.41,37.67,224,0.875,bicubic -tf_efficientnet_el,68.18,31.82,88.35,11.65,10.59,300,0.904,bicubic -gluon_resnext50_32x4d,68.28,31.72,87.32,12.68,25.03,224,0.875,bicubic -dpn98,68.58,31.42,87.66,12.34,61.57,224,0.875,bicubic -gluon_seresnext50_32x4d,68.67,31.33,88.32,11.68,27.56,224,0.875,bicubic -dpn107,68.71,31.29,88.13,11.87,86.92,224,0.875,bicubic -gluon_resnet101_v1s,68.72,31.28,87.9,12.1,44.67,224,0.875,bicubic +ig_resnext101_32x48d,76.87,23.13,93.32,6.68,828.41,224,0.875,bilinear +ig_resnext101_32x32d,76.84,23.16,93.19,6.81,468.53,224,0.875,bilinear +tf_efficientnet_b7_ap,76.09,23.91,92.97,7.03,66.35,600,0.949,bicubic +tf_efficientnet_b8_ap,76.09,23.91,92.73,7.27,87.41,672,0.954,bicubic +ig_resnext101_32x16d,75.71,24.29,92.9,7.1,194.03,224,0.875,bilinear +swsl_resnext101_32x8d,75.45,24.55,92.75,7.25,88.79,224,0.875,bilinear +tf_efficientnet_b6_ap,75.38,24.62,92.44,7.56,43.04,528,0.942,bicubic +tf_efficientnet_b7,74.72,25.28,92.22,7.78,66.35,600,0.949,bicubic +tf_efficientnet_b5_ap,74.59,25.41,91.99,8.01,30.39,456,0.934,bicubic +swsl_resnext101_32x4d,74.15,25.85,91.99,8.01,44.18,224,0.875,bilinear +swsl_resnext101_32x16d,74.01,25.99,92.17,7.83,194.03,224,0.875,bilinear +tf_efficientnet_b6,73.9,26.1,91.75,8.25,43.04,528,0.942,bicubic +ig_resnext101_32x8d,73.66,26.34,92.15,7.85,88.79,224,0.875,bilinear 
+tf_efficientnet_b5,73.54,26.46,91.46,8.54,30.39,456,0.934,bicubic +tf_efficientnet_b4_ap,72.89,27.11,90.98,9.02,19.34,380,0.922,bicubic +swsl_resnext50_32x4d,72.58,27.42,90.84,9.16,25.03,224,0.875,bilinear +pnasnet5large,72.37,27.63,90.26,9.74,86.06,331,0.875,bicubic +nasnetalarge,72.31,27.69,90.51,9.49,88.75,331,0.875,bicubic +tf_efficientnet_b4,72.28,27.72,90.6,9.4,19.34,380,0.922,bicubic +swsl_resnet50,71.69,28.31,90.51,9.49,25.56,224,0.875,bilinear +ssl_resnext101_32x8d,71.49,28.51,90.47,9.53,88.79,224,0.875,bilinear +ssl_resnext101_32x16d,71.4,28.6,90.55,9.45,194.03,224,0.875,bilinear +tf_efficientnet_b3_ap,70.92,29.08,89.43,10.57,12.23,300,0.904,bicubic +tf_efficientnet_b3,70.62,29.38,89.44,10.56,12.23,300,0.904,bicubic +gluon_senet154,70.6,29.4,88.92,11.08,115.09,224,0.875,bicubic +ssl_resnext101_32x4d,70.5,29.5,89.76,10.24,44.18,224,0.875,bilinear +senet154,70.48,29.52,88.99,11.01,115.09,224,0.875,bilinear +gluon_seresnext101_64x4d,70.44,29.56,89.35,10.65,88.23,224,0.875,bicubic +gluon_resnet152_v1s,70.32,29.68,88.87,11.13,60.32,224,0.875,bicubic +inception_resnet_v2,70.12,29.88,88.68,11.32,55.84,299,0.8975,bicubic +gluon_seresnext101_32x4d,70.01,29.99,88.91,11.09,48.96,224,0.875,bicubic +gluon_resnet152_v1d,69.95,30.05,88.47,11.53,60.21,224,0.875,bicubic +gluon_resnext101_64x4d,69.69,30.31,88.26,11.74,83.46,224,0.875,bicubic +ssl_resnext50_32x4d,69.69,30.31,89.42,10.58,25.03,224,0.875,bilinear +ens_adv_inception_resnet_v2,69.52,30.48,88.5,11.5,55.84,299,0.8975,bicubic +inception_v4,69.35,30.65,88.78,11.22,42.68,299,0.875,bicubic +seresnext101_32x4d,69.34,30.66,88.05,11.95,48.96,224,0.875,bilinear +gluon_resnet152_v1c,69.13,30.87,87.89,12.11,60.21,224,0.875,bicubic +mixnet_xl,69,31,88.19,11.81,11.9,224,0.875,bicubic +gluon_resnet101_v1d,68.99,31.01,88.08,11.92,44.57,224,0.875,bicubic +gluon_xception65,68.98,31.02,88.32,11.68,39.92,299,0.875,bicubic +gluon_resnext101_32x4d,68.96,31.04,88.34,11.66,44.18,224,0.875,bicubic +tf_efficientnet_b2_ap,68.93,31.07,88.34,11.66,9.11,260,0.89,bicubic +gluon_resnet152_v1b,68.81,31.19,87.71,12.29,60.19,224,0.875,bicubic +dpn131,68.76,31.24,87.48,12.52,79.25,224,0.875,bicubic resnext50d_32x4d,68.75,31.25,88.31,11.69,25.05,224,0.875,bicubic tf_efficientnet_b2,68.75,31.25,87.95,12.05,9.11,260,0.89,bicubic -dpn131,68.76,31.24,87.48,12.52,79.25,224,0.875,bicubic -gluon_resnet152_v1b,68.81,31.19,87.71,12.29,60.19,224,0.875,bicubic -gluon_resnext101_32x4d,68.96,31.04,88.34,11.66,44.18,224,0.875,bicubic -gluon_xception65,68.98,31.02,88.32,11.68,39.92,299,0.875,bicubic -gluon_resnet101_v1d,68.99,31.01,88.08,11.92,44.57,224,0.875,bicubic -mixnet_xl,69,31,88.19,11.81,11.9,224,0.875,bicubic -gluon_resnet152_v1c,69.13,30.87,87.89,12.11,60.21,224,0.875,bicubic -seresnext101_32x4d,69.34,30.66,88.05,11.95,48.96,224,0.875,bilinear -inception_v4,69.35,30.65,88.78,11.22,42.68,299,0.875,bicubic -ens_adv_inception_resnet_v2,69.52,30.48,88.5,11.5,55.84,299,0.8975,bicubic -gluon_resnext101_64x4d,69.69,30.31,88.26,11.74,83.46,224,0.875,bicubic -gluon_resnet152_v1d,69.95,30.05,88.47,11.53,60.21,224,0.875,bicubic -gluon_seresnext101_32x4d,70.01,29.99,88.91,11.09,48.96,224,0.875,bicubic -inception_resnet_v2,70.12,29.88,88.68,11.32,55.84,299,0.8975,bicubic -gluon_resnet152_v1s,70.32,29.68,88.87,11.13,60.32,224,0.875,bicubic -gluon_seresnext101_64x4d,70.44,29.56,89.35,10.65,88.23,224,0.875,bicubic -senet154,70.48,29.52,88.99,11.01,115.09,224,0.875,bilinear -gluon_senet154,70.6,29.4,88.92,11.08,115.09,224,0.875,bicubic 
-tf_efficientnet_b3,70.62,29.38,89.44,10.56,12.23,300,0.904,bicubic -tf_efficientnet_b4,72.28,27.72,90.6,9.4,19.34,380,0.922,bicubic -nasnetalarge,72.31,27.69,90.51,9.49,88.75,331,0.875,bicubic -pnasnet5large,72.37,27.63,90.26,9.74,86.06,331,0.875,bicubic -tf_efficientnet_b5,73.37,26.63,91.21,8.79,30.39,456,0.934,bicubic -ig_resnext101_32x8d,73.66,26.34,92.15,7.85,88.79,224,0.875,bilinear -tf_efficientnet_b6,73.9,26.1,91.75,8.25,43.04,528,0.942,bicubic -tf_efficientnet_b7,74.04,25.96,91.86,8.14,66.35,600,0.949,bicubic -ig_resnext101_32x16d,75.71,24.29,92.9,7.1,194.03,224,0.875,bilinear -ig_resnext101_32x32d,76.84,23.16,93.19,6.81,468.53,224,0.875,bilinear -ig_resnext101_32x48d,76.87,23.13,93.32,6.68,828.41,224,0.875,bilinear +gluon_resnet101_v1s,68.72,31.28,87.9,12.1,44.67,224,0.875,bicubic +dpn107,68.71,31.29,88.13,11.87,86.92,224,0.875,bicubic +gluon_seresnext50_32x4d,68.67,31.33,88.32,11.68,27.56,224,0.875,bicubic +hrnet_w64,68.63,31.37,88.07,11.93,128.06,224,0.875,bilinear +dpn98,68.58,31.42,87.66,12.34,61.57,224,0.875,bicubic +ssl_resnet50,68.42,31.58,88.58,11.42,25.56,224,0.875,bilinear +dla102x2,68.34,31.66,87.87,12.13,41.75,224,0.875,bilinear +gluon_resnext50_32x4d,68.28,31.72,87.32,12.68,25.03,224,0.875,bicubic +tf_efficientnet_el,68.18,31.82,88.35,11.65,10.59,300,0.904,bicubic +dpn92,68.01,31.99,87.59,12.41,37.67,224,0.875,bicubic +gluon_resnet50_v1d,67.91,32.09,87.12,12.88,25.58,224,0.875,bicubic +seresnext50_32x4d,67.87,32.13,87.62,12.38,27.56,224,0.875,bilinear +resnext101_32x8d,67.85,32.15,87.48,12.52,88.79,224,0.875,bilinear +efficientnet_b2,67.8,32.2,88.2,11.8,9.11,260,0.89,bicubic +hrnet_w44,67.77,32.23,87.53,12.47,67.06,224,0.875,bilinear +hrnet_w48,67.77,32.23,87.42,12.58,77.47,224,0.875,bilinear +xception,67.67,32.33,87.57,12.43,22.86,299,0.8975,bicubic +dla169,67.61,32.39,87.56,12.44,53.99,224,0.875,bilinear +gluon_inception_v3,67.59,32.41,87.46,12.54,23.83,299,0.875,bicubic +hrnet_w40,67.59,32.41,87.13,12.87,57.56,224,0.875,bilinear +gluon_resnet101_v1c,67.56,32.44,87.16,12.84,44.57,224,0.875,bicubic +efficientnet_b1,67.55,32.45,87.29,12.71,7.79,240,0.882,bicubic +seresnet152,67.55,32.45,87.39,12.61,66.82,224,0.875,bilinear +res2net50_26w_8s,67.53,32.47,87.27,12.73,48.4,224,0.875,bilinear +tf_efficientnet_b1_ap,67.52,32.48,87.77,12.23,7.79,240,0.882,bicubic +tf_efficientnet_cc_b1_8e,67.48,32.52,87.31,12.69,39.72,240,0.882,bicubic +gluon_resnet101_v1b,67.45,32.55,87.23,12.77,44.55,224,0.875,bicubic +res2net101_26w_4s,67.45,32.55,87.01,12.99,45.21,224,0.875,bilinear +seresnet101,67.15,32.85,87.05,12.95,49.33,224,0.875,bilinear +gluon_resnet50_v1s,67.1,32.9,86.86,13.14,25.68,224,0.875,bicubic +dla60x,67.08,32.92,87.17,12.83,17.65,224,0.875,bilinear +dla60_res2net,67.03,32.97,87.14,12.86,21.15,224,0.875,bilinear +resnet152,67.02,32.98,87.57,12.43,60.19,224,0.875,bilinear +dla102x,67,33,86.77,13.23,26.77,224,0.875,bilinear +mixnet_l,66.97,33.03,86.94,13.06,7.33,224,0.875,bicubic +res2net50_26w_6s,66.91,33.09,86.9,13.1,37.05,224,0.875,bilinear +tf_efficientnet_b1,66.89,33.11,87.04,12.96,7.79,240,0.882,bicubic +resnext50_32x4d,66.88,33.12,86.36,13.64,25.03,224,0.875,bicubic +tf_efficientnet_em,66.87,33.13,86.98,13.02,6.9,240,0.882,bicubic +resnet50,66.81,33.19,87,13,25.56,224,0.875,bicubic +hrnet_w32,66.79,33.21,87.29,12.71,41.23,224,0.875,bilinear +tf_mixnet_l,66.78,33.22,86.46,13.54,7.33,224,0.875,bicubic +hrnet_w30,66.76,33.24,86.79,13.21,37.71,224,0.875,bilinear +wide_resnet101_2,66.68,33.32,87.04,12.96,126.89,224,0.875,bilinear 
+wide_resnet50_2,66.65,33.35,86.81,13.19,68.88,224,0.875,bilinear +dla60_res2next,66.64,33.36,87.02,12.98,17.33,224,0.875,bilinear +adv_inception_v3,66.6,33.4,86.56,13.44,23.83,299,0.875,bicubic +dla102,66.55,33.45,86.91,13.09,33.73,224,0.875,bilinear +gluon_resnet50_v1c,66.54,33.46,86.16,13.84,25.58,224,0.875,bicubic +tf_inception_v3,66.42,33.58,86.68,13.32,23.83,299,0.875,bicubic +seresnet50,66.24,33.76,86.33,13.67,28.09,224,0.875,bilinear +tf_efficientnet_cc_b0_8e,66.21,33.79,86.22,13.78,24.01,224,0.875,bicubic +tv_resnext50_32x4d,66.18,33.82,86.04,13.96,25.03,224,0.875,bilinear +res2net50_26w_4s,66.17,33.83,86.6,13.4,25.7,224,0.875,bilinear +inception_v3,66.12,33.88,86.34,13.66,27.16,299,0.875,bicubic +gluon_resnet50_v1b,66.04,33.96,86.27,13.73,25.56,224,0.875,bicubic +res2net50_14w_8s,66.02,33.98,86.24,13.76,25.06,224,0.875,bilinear +densenet161,65.85,34.15,86.46,13.54,28.68,224,0.875,bicubic +res2next50,65.85,34.15,85.83,14.17,24.67,224,0.875,bilinear +resnet101,65.68,34.32,85.98,14.02,44.55,224,0.875,bilinear +dpn68b,65.6,34.4,85.94,14.06,12.61,224,0.875,bicubic +tf_efficientnet_b0_ap,65.49,34.51,85.55,14.45,5.29,224,0.875,bicubic +res2net50_48w_2s,65.32,34.68,85.96,14.04,25.29,224,0.875,bilinear +densenet201,65.28,34.72,85.67,14.33,20.01,224,0.875,bicubic +tf_efficientnet_es,65.24,34.76,85.54,14.46,5.44,224,0.875,bicubic +dla60,65.22,34.78,85.75,14.25,22.33,224,0.875,bilinear +tf_efficientnet_cc_b0_4e,65.13,34.87,85.13,14.87,13.31,224,0.875,bicubic +seresnext26_32x4d,65.04,34.96,85.65,14.35,16.79,224,0.875,bicubic +hrnet_w18,64.91,35.09,85.75,14.25,21.3,224,0.875,bilinear +densenet169,64.78,35.22,85.25,14.75,14.15,224,0.875,bicubic +mixnet_m,64.69,35.31,85.47,14.53,5.01,224,0.875,bicubic +resnet26d,64.63,35.37,85.12,14.88,16.01,224,0.875,bicubic +efficientnet_b0,64.58,35.42,85.89,14.11,5.29,224,0.875,bicubic +tf_efficientnet_b0,64.29,35.71,85.25,14.75,5.29,224,0.875,bicubic +tf_mixnet_m,64.27,35.73,85.09,14.91,5.01,224,0.875,bicubic +dpn68,64.22,35.78,85.18,14.82,12.61,224,0.875,bicubic +tf_mixnet_s,63.59,36.41,84.27,15.73,4.13,224,0.875,bicubic +resnet26,63.45,36.55,84.27,15.73,16,224,0.875,bicubic +mixnet_s,63.38,36.62,84.71,15.29,4.13,224,0.875,bicubic +tv_resnet50,63.33,36.67,84.65,15.35,25.56,224,0.875,bilinear +mobilenetv3_rw,63.23,36.77,84.52,15.48,5.48,224,0.875,bicubic +semnasnet_100,63.12,36.88,84.53,15.47,3.89,224,0.875,bicubic +densenet121,62.94,37.06,84.26,15.74,7.98,224,0.875,bicubic +seresnet34,62.89,37.11,84.22,15.78,21.96,224,0.875,bilinear +hrnet_w18_small_v2,62.83,37.17,83.97,16.03,15.6,224,0.875,bilinear +resnet34,62.82,37.18,84.12,15.88,21.8,224,0.875,bilinear +swsl_resnet18,62.73,37.27,84.3,15.7,11.69,224,0.875,bilinear +gluon_resnet34_v1b,62.56,37.44,84,16,21.8,224,0.875,bicubic +dla34,62.51,37.49,83.92,16.08,15.78,224,0.875,bilinear +tf_mobilenetv3_large_100,62.47,37.53,83.96,16.04,5.48,224,0.875,bilinear +fbnetc_100,62.43,37.57,83.39,16.61,5.57,224,0.875,bilinear +mnasnet_100,61.91,38.09,83.71,16.29,4.38,224,0.875,bicubic +ssl_resnet18,61.49,38.51,83.33,16.67,11.69,224,0.875,bilinear +spnasnet_100,61.21,38.79,82.77,17.23,4.42,224,0.875,bilinear +tv_resnet34,61.2,38.8,82.72,17.28,21.8,224,0.875,bilinear +tf_mobilenetv3_large_075,60.38,39.62,81.96,18.04,3.99,224,0.875,bilinear +seresnet18,59.81,40.19,81.68,18.32,11.78,224,0.875,bicubic +tf_mobilenetv3_large_minimal_100,59.07,40.93,81.14,18.86,3.92,224,0.875,bilinear +hrnet_w18_small,58.97,41.03,81.34,18.66,13.19,224,0.875,bilinear +gluon_resnet18_v1b,58.32,41.68,80.96,19.04,11.69,224,0.875,bicubic 
+resnet18,57.18,42.82,80.19,19.81,11.69,224,0.875,bilinear +dla60x_c,56.02,43.98,78.96,21.04,1.34,224,0.875,bilinear +tf_mobilenetv3_small_100,54.51,45.49,77.08,22.92,2.54,224,0.875,bilinear +dla46x_c,53.08,46.92,76.84,23.16,1.08,224,0.875,bilinear +dla46_c,52.2,47.8,75.68,24.32,1.31,224,0.875,bilinear +tf_mobilenetv3_small_075,52.15,47.85,75.46,24.54,2.04,224,0.875,bilinear +tf_mobilenetv3_small_minimal_100,49.53,50.47,73.05,26.95,2.04,224,0.875,bilinear diff --git a/sotabench.py b/sotabench.py index cd25412f..5f6345f5 100644 --- a/sotabench.py +++ b/sotabench.py @@ -294,6 +294,17 @@ model_list = [ _entry('res2next50', 'Res2NeXt-50', '1904.01169'), _entry('dla60_res2net', 'Res2Net-DLA-60', '1904.01169'), _entry('dla60_res2next', 'Res2NeXt-DLA-60', '1904.01169'), + + ## HRNet official impl weights + _entry('hrnet_w18_small', 'HRNet-W18-C-Small-V1', '1908.07919'), + _entry('hrnet_w18_small_v2', 'HRNet-W18-C-Small-V2', '1908.07919'), + _entry('hrnet_w18', 'HRNet-W18-C', '1908.07919'), + _entry('hrnet_w30', 'HRNet-W30-C', '1908.07919'), + _entry('hrnet_w32', 'HRNet-W32-C', '1908.07919'), + _entry('hrnet_w40', 'HRNet-W40-C', '1908.07919'), + _entry('hrnet_w44', 'HRNet-W44-C', '1908.07919'), + _entry('hrnet_w48', 'HRNet-W48-C', '1908.07919'), + _entry('hrnet_w64', 'HRNet-W64-C', '1908.07919'), ] for m in model_list: diff --git a/timm/models/densenet.py b/timm/models/densenet.py index 1e49f6df..d1ac5857 100644 --- a/timm/models/densenet.py +++ b/timm/models/densenet.py @@ -10,7 +10,7 @@ import torch.nn.functional as F from .registry import register_model from .helpers import load_pretrained -from .adaptive_avgmax_pool import select_adaptive_pool2d +from .adaptive_avgmax_pool import SelectAdaptivePool2d from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD import re @@ -88,8 +88,8 @@ class DenseNet(nn.Module): def __init__(self, growth_rate=32, block_config=(6, 12, 24, 16), num_init_features=64, bn_size=4, drop_rate=0, num_classes=1000, in_chans=3, global_pool='avg'): - self.global_pool = global_pool self.num_classes = num_classes + self.drop_rate = drop_rate super(DenseNet, self).__init__() # First convolution @@ -117,32 +117,31 @@ class DenseNet(nn.Module): self.features.add_module('norm5', nn.BatchNorm2d(num_features)) # Linear layer - self.classifier = nn.Linear(num_features, num_classes) - self.num_features = num_features + self.global_pool = SelectAdaptivePool2d(pool_type=global_pool) + self.classifier = nn.Linear(self.num_features * self.global_pool.feat_mult(), num_classes) def get_classifier(self): return self.classifier def reset_classifier(self, num_classes, global_pool='avg'): - self.global_pool = global_pool self.num_classes = num_classes - del self.classifier - if num_classes: - self.classifier = nn.Linear(self.num_features, num_classes) - else: - self.classifier = None + self.global_pool = SelectAdaptivePool2d(pool_type=global_pool) + self.classifier = nn.Linear( + self.num_features * self.global_pool.feat_mult(), num_classes) if num_classes else None - def forward_features(self, x, pool=True): + def forward_features(self, x): x = self.features(x) x = F.relu(x, inplace=True) - if pool: - x = select_adaptive_pool2d(x, self.global_pool) - x = x.view(x.size(0), -1) return x def forward(self, x): - return self.classifier(self.forward_features(x, pool=True)) + x = self.forward_features(x) + x = self.global_pool(x).flatten(1) + if self.drop_rate > 0.: + x = F.dropout(x, p=self.drop_rate, training=self.training) + x = self.classifier(x) + return x def 
_filter_pretrained(state_dict): diff --git a/timm/models/dla.py b/timm/models/dla.py index 255a389d..cd560f44 100644 --- a/timm/models/dla.py +++ b/timm/models/dla.py @@ -276,8 +276,7 @@ class DLA(nn.Module): self.num_features = channels[-1] self.global_pool = SelectAdaptivePool2d(pool_type=global_pool) - self.fc = nn.Conv2d(self.num_features * self.global_pool.feat_mult(), num_classes, - kernel_size=1, stride=1, padding=0, bias=True) + self.fc = nn.Conv2d(self.num_features * self.global_pool.feat_mult(), num_classes, 1, bias=True) for m in self.modules(): if isinstance(m, nn.Conv2d): @@ -302,15 +301,14 @@ class DLA(nn.Module): return self.fc def reset_classifier(self, num_classes, global_pool='avg'): - self.global_pool = SelectAdaptivePool2d(pool_type=global_pool) self.num_classes = num_classes - del self.fc + self.global_pool = SelectAdaptivePool2d(pool_type=global_pool) if num_classes: - self.fc = nn.Linear(self.num_features * self.global_pool.feat_mult(), num_classes) + self.fc = nn.Conv2d(self.num_features * self.global_pool.feat_mult(), num_classes, 1, bias=True) else: self.fc = None - def forward_features(self, x, pool=True): + def forward_features(self, x): x = self.base_layer(x) x = self.level0(x) x = self.level1(x) @@ -318,17 +316,15 @@ class DLA(nn.Module): x = self.level3(x) x = self.level4(x) x = self.level5(x) - if pool: - x = self.global_pool(x) return x def forward(self, x): x = self.forward_features(x) + x = self.global_pool(x) if self.drop_rate > 0.: x = F.dropout(x, p=self.drop_rate, training=self.training) x = self.fc(x) - x = x.flatten(1) - return x + return x.flatten(1) @register_model diff --git a/timm/models/dpn.py b/timm/models/dpn.py index 1496a067..7f46e8e0 100644 --- a/timm/models/dpn.py +++ b/timm/models/dpn.py @@ -16,7 +16,7 @@ from collections import OrderedDict from .registry import register_model from .helpers import load_pretrained -from .adaptive_avgmax_pool import select_adaptive_pool2d +from .adaptive_avgmax_pool import SelectAdaptivePool2d from timm.data import IMAGENET_DPN_MEAN, IMAGENET_DPN_STD @@ -160,7 +160,6 @@ class DPN(nn.Module): super(DPN, self).__init__() self.num_classes = num_classes self.drop_rate = drop_rate - self.global_pool = global_pool self.b = b bw_factor = 1 if small else 4 @@ -218,32 +217,32 @@ class DPN(nn.Module): self.features = nn.Sequential(blocks) # Using 1x1 conv for the FC layer to allow the extra pooling scheme - self.classifier = nn.Conv2d(in_chs, num_classes, kernel_size=1, bias=True) + self.global_pool = SelectAdaptivePool2d(pool_type=global_pool) + self.classifier = nn.Conv2d( + self.num_features * self.global_pool.feat_mult(), num_classes, kernel_size=1, bias=True) def get_classifier(self): return self.classifier def reset_classifier(self, num_classes, global_pool='avg'): self.num_classes = num_classes - self.global_pool = global_pool - del self.classifier + self.global_pool = SelectAdaptivePool2d(pool_type=global_pool) if num_classes: - self.classifier = nn.Conv2d(self.num_features, num_classes, kernel_size=1, bias=True) + self.classifier = nn.Conv2d( + self.num_features * self.global_pool.feat_mult(), num_classes, kernel_size=1, bias=True) else: self.classifier = None - def forward_features(self, x, pool=True): - x = self.features(x) - if pool: - x = select_adaptive_pool2d(x, pool_type=self.global_pool) - return x + def forward_features(self, x): + return self.features(x) def forward(self, x): x = self.forward_features(x) + x = self.global_pool(x) if self.drop_rate > 0.: x = F.dropout(x, p=self.drop_rate, 
training=self.training) out = self.classifier(x) - return out.view(out.size(0), -1) + return out.flatten(1) @register_model diff --git a/timm/models/efficientnet.py b/timm/models/efficientnet.py index 9163a023..ac3c244c 100644 --- a/timm/models/efficientnet.py +++ b/timm/models/efficientnet.py @@ -211,8 +211,7 @@ class EfficientNet(nn.Module): def __init__(self, block_args, num_classes=1000, num_features=1280, in_chans=3, stem_size=32, channel_multiplier=1.0, channel_divisor=8, channel_min=None, pad_type='', act_layer=nn.ReLU, drop_rate=0., drop_connect_rate=0., - se_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None, - global_pool='avg', weight_init='goog'): + se_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None, global_pool='avg'): super(EfficientNet, self).__init__() norm_kwargs = norm_kwargs or {} @@ -245,11 +244,7 @@ class EfficientNet(nn.Module): # Classifier self.classifier = nn.Linear(self.num_features * self.global_pool.feat_mult(), self.num_classes) - for m in self.modules(): - if weight_init == 'goog': - efficientnet_init_goog(m) - else: - efficientnet_init_default(m) + efficientnet_init_weights(self) def as_sequential(self): layers = [self.conv_stem, self.bn1, self.act1] @@ -262,14 +257,10 @@ class EfficientNet(nn.Module): return self.classifier def reset_classifier(self, num_classes, global_pool='avg'): - self.global_pool = SelectAdaptivePool2d(pool_type=global_pool) self.num_classes = num_classes - del self.classifier - if num_classes: - self.classifier = nn.Linear( - self.num_features * self.global_pool.feat_mult(), num_classes) - else: - self.classifier = None + self.global_pool = SelectAdaptivePool2d(pool_type=global_pool) + self.classifier = nn.Linear( + self.num_features * self.global_pool.feat_mult(), num_classes) if num_classes else None def forward_features(self, x): x = self.conv_stem(x) @@ -300,7 +291,7 @@ class EfficientNetFeatures(nn.Module): def __init__(self, block_args, out_indices=(0, 1, 2, 3, 4), feature_location='pre_pwl', in_chans=3, stem_size=32, channel_multiplier=1.0, channel_divisor=8, channel_min=None, output_stride=32, pad_type='', act_layer=nn.ReLU, drop_rate=0., drop_connect_rate=0., - se_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None, weight_init='goog'): + se_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None): super(EfficientNetFeatures, self).__init__() norm_kwargs = norm_kwargs or {} @@ -326,12 +317,7 @@ class EfficientNetFeatures(nn.Module): self.feature_info = builder.features # builder provides info about feature channels for each block self._in_chs = builder.in_chs - for m in self.modules(): - if weight_init == 'goog': - efficientnet_init_goog(m) - else: - efficientnet_init_default(m) - + efficientnet_init_weights(self) if _DEBUG: for k, v in self.feature_info.items(): print('Feature idx: {}: Name: {}, Channels: {}'.format(k, v['name'], v['num_chs'])) diff --git a/timm/models/efficientnet_builder.py b/timm/models/efficientnet_builder.py index c2b3a801..db6f54f9 100644 --- a/timm/models/efficientnet_builder.py +++ b/timm/models/efficientnet_builder.py @@ -358,9 +358,13 @@ class EfficientNetBuilder: return stages -def efficientnet_init_goog(m, n=''): - # weight init as per Tensorflow Official impl - # https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mnasnet_model.py +def _init_weight_goog(m, n=''): + """ Weight initialization as per Tensorflow official implementations. 
+ + Handles layers in EfficientNet, EfficientNet-CondConv, MixNet, MnasNet, MobileNetV3, etc: + * https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mnasnet_model.py + * https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/efficientnet_model.py + """ if isinstance(m, CondConv2d): fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels init_weight_fn = get_condconv_initializer( @@ -386,7 +390,8 @@ def efficientnet_init_goog(m, n=''): m.bias.data.zero_() -def efficientnet_init_default(m, n=''): +def _init_weight_default(m, n=''): + """ Basic ResNet (Kaiming) style weight init""" if isinstance(m, CondConv2d): init_fn = get_condconv_initializer(partial( nn.init.kaiming_normal_, mode='fan_out', nonlinearity='relu'), m.num_experts, m.weight_shape) @@ -400,3 +405,8 @@ def efficientnet_init_default(m, n=''): nn.init.kaiming_uniform_(m.weight, mode='fan_in', nonlinearity='linear') +def efficientnet_init_weights(model: nn.Module, init_fn=None): + init_fn = init_fn or _init_weight_goog + for n, m in model.named_modules(): + init_fn(m, n) + diff --git a/timm/models/gluon_xception.py b/timm/models/gluon_xception.py index 9393e5ba..5a35d226 100644 --- a/timm/models/gluon_xception.py +++ b/timm/models/gluon_xception.py @@ -13,7 +13,7 @@ from collections import OrderedDict from .registry import register_model from .helpers import load_pretrained -from .adaptive_avgmax_pool import select_adaptive_pool2d +from .adaptive_avgmax_pool import SelectAdaptivePool2d from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD __all__ = ['Xception65', 'Xception71'] @@ -185,7 +185,6 @@ class Xception65(nn.Module): super(Xception65, self).__init__() self.num_classes = num_classes self.drop_rate = drop_rate - self.global_pool = global_pool norm_kwargs = norm_kwargs if norm_kwargs is not None else {} if output_stride == 32: entry_block3_stride = 2 @@ -249,21 +248,18 @@ class Xception65(nn.Module): 1536, self.num_features, 3, stride=1, dilation=exit_block_dilations[1], norm_layer=norm_layer, norm_kwargs=norm_kwargs) self.bn5 = norm_layer(num_features=self.num_features, **norm_kwargs) - self.fc = nn.Linear(in_features=self.num_features, out_features=num_classes) + self.global_pool = SelectAdaptivePool2d(pool_type=global_pool) + self.fc = nn.Linear(self.num_features * self.global_pool.feat_mult(), num_classes) def get_classifier(self): return self.fc def reset_classifier(self, num_classes, global_pool='avg'): self.num_classes = num_classes - self.global_pool = global_pool - del self.fc - if num_classes: - self.fc = nn.Linear(self.num_features, num_classes) - else: - self.fc = None + self.global_pool = SelectAdaptivePool2d(pool_type=global_pool) + self.fc = nn.Linear(self.num_features * self.global_pool.feat_mult(), num_classes) if num_classes else None - def forward_features(self, x, pool=True): + def forward_features(self, x): # Entry flow x = self.conv1(x) x = self.bn1(x) @@ -299,14 +295,11 @@ class Xception65(nn.Module): x = self.conv5(x) x = self.bn5(x) x = self.relu(x) - - if pool: - x = select_adaptive_pool2d(x, pool_type=self.global_pool) - x = x.view(x.size(0), -1) return x def forward(self, x): x = self.forward_features(x) + x = self.global_pool(x).flatten(1) if self.drop_rate: F.dropout(x, self.drop_rate, training=self.training) x = self.fc(x) @@ -322,7 +315,6 @@ class Xception71(nn.Module): super(Xception71, self).__init__() self.num_classes = num_classes self.drop_rate = drop_rate - self.global_pool = global_pool norm_kwargs = norm_kwargs if norm_kwargs is 
not None else {} if output_stride == 32: entry_block3_stride = 2 @@ -393,21 +385,18 @@ class Xception71(nn.Module): 1536, self.num_features, 3, stride=1, dilation=exit_block_dilations[1], norm_layer=norm_layer, norm_kwargs=norm_kwargs) self.bn5 = norm_layer(num_features=self.num_features, **norm_kwargs) - self.fc = nn.Linear(in_features=self.num_features, out_features=num_classes) + self.global_pool = SelectAdaptivePool2d(pool_type=global_pool) + self.fc = nn.Linear(self.num_features * self.global_pool.feat_mult(), num_classes) def get_classifier(self): return self.fc def reset_classifier(self, num_classes, global_pool='avg'): self.num_classes = num_classes - self.global_pool = global_pool - del self.fc - if num_classes: - self.fc = nn.Linear(self.num_features, num_classes) - else: - self.fc = None + self.global_pool = SelectAdaptivePool2d(pool_type=global_pool) + self.fc = nn.Linear(self.num_features * self.global_pool.feat_mult(), num_classes) if num_classes else None - def forward_features(self, x, pool=True): + def forward_features(self, x): # Entry flow x = self.conv1(x) x = self.bn1(x) @@ -443,14 +432,11 @@ class Xception71(nn.Module): x = self.conv5(x) x = self.bn5(x) x = self.relu(x) - - if pool: - x = select_adaptive_pool2d(x, pool_type=self.global_pool) - x = x.view(x.size(0), -1) return x def forward(self, x): x = self.forward_features(x) + x = self.global_pool(x).flatten(1) if self.drop_rate: F.dropout(x, self.drop_rate, training=self.training) x = self.fc(x) diff --git a/timm/models/hrnet.py b/timm/models/hrnet.py index 59ded4ab..99a2bd91 100644 --- a/timm/models/hrnet.py +++ b/timm/models/hrnet.py @@ -17,20 +17,18 @@ import os import logging import functools -import numpy as np - import torch import torch.nn as nn import torch._utils import torch.nn.functional as F +from .resnet import BasicBlock, Bottleneck # leveraging ResNet blocks w/ additional features like SE from .registry import register_model from .helpers import load_pretrained -from .helpers import load_pretrained from .adaptive_avgmax_pool import SelectAdaptivePool2d from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -BN_MOMENTUM = 0.1 +_BN_MOMENTUM = 0.1 logger = logging.getLogger(__name__) @@ -46,380 +44,353 @@ def _cfg(url='', **kwargs): default_cfgs = { - 'hrnet_w18_small': _cfg(url=''), - 'hrnet_w18_small_v2': _cfg(url=''), - 'hrnet_w18': _cfg(url=''), - 'hrnet_w30': _cfg(url=''), - 'hrnet_w32': _cfg(url=''), - 'hrnet_w40': _cfg(url=''), - 'hrnet_w44': _cfg(url=''), - 'hrnet_w48': _cfg(url=''), + 'hrnet_w18_small': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-hrnet/hrnet_w18_small_v1-f460c6bc.pth'), + 'hrnet_w18_small_v2': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-hrnet/hrnet_w18_small_v2-4c50a8cb.pth'), + 'hrnet_w18': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-hrnet/hrnetv2_w18-8cb57bb9.pth'), + 'hrnet_w30': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-hrnet/hrnetv2_w30-8d7f8dab.pth'), + 'hrnet_w32': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-hrnet/hrnetv2_w32-90d8c5fb.pth'), + 'hrnet_w40': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-hrnet/hrnetv2_w40-7cd397a4.pth'), + 'hrnet_w44': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-hrnet/hrnetv2_w44-c9ac8c18.pth'), + 'hrnet_w48': _cfg( + 
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-hrnet/hrnetv2_w48-abd2e6ab.pth'), + 'hrnet_w64': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-hrnet/hrnetv2_w64-b47cc881.pth'), } -cfg_cls_hrnet_w18_small = dict( - STAGE1=dict( - NUM_MODULES=1, - NUM_BRANCHES=1, - BLOCK='BOTTLENECK', - NUM_BLOCKS=(1,), - NUM_CHANNELS=(32,), - FUSE_METHOD='SUM', - ), - STAGE2=dict( - NUM_MODULES=1, - NUM_BRANCHES=2, - BLOCK='BASIC', - NUM_BLOCKS=(2, 2), - NUM_CHANNELS=(16, 32), - FUSE_METHOD='SUM' - ), - STAGE3=dict( - NUM_MODULES=1, - NUM_BRANCHES=3, - BLOCK='BASIC', - NUM_BLOCKS=(2, 2, 2), - NUM_CHANNELS=(16, 32, 64), - FUSE_METHOD='SUM' - ), - STAGE4=dict( - NUM_MODULES=1, - NUM_BRANCHES=4, - BLOCK='BASIC', - NUM_BLOCKS=(2, 2, 2, 2), - NUM_CHANNELS=(16, 32, 64, 128), - FUSE_METHOD='SUM', - ), -) - - -cfg_cls_hrnet_w18_small_v2 = dict( - STAGE1=dict( - NUM_MODULES=1, - NUM_BRANCHES=1, - BLOCK='BOTTLENECK', - NUM_BLOCKS=(2,), - NUM_CHANNELS=(64,), - FUSE_METHOD='SUM', - ), - STAGE2=dict( - NUM_MODULES=1, - NUM_BRANCHES=2, - BLOCK='BASIC', - NUM_BLOCKS=(2, 2), - NUM_CHANNELS=(18, 36), - FUSE_METHOD='SUM' - ), - STAGE3=dict( - NUM_MODULES=3, - NUM_BRANCHES=3, - BLOCK='BASIC', - NUM_BLOCKS=(2, 2, 2), - NUM_CHANNELS=(18, 36, 72), - FUSE_METHOD='SUM' - ), - STAGE4=dict( - NUM_MODULES=2, - NUM_BRANCHES=4, - BLOCK='BASIC', - NUM_BLOCKS=(2, 2, 2, 2), - NUM_CHANNELS=(18, 36, 72, 144), - FUSE_METHOD='SUM', - ), -) - -cfg_cls_hrnet_w18 = dict( - STAGE1=dict( - NUM_MODULES=1, - NUM_BRANCHES=1, - BLOCK='BOTTLENECK', - NUM_BLOCKS=(4,), - NUM_CHANNELS=(64,), - FUSE_METHOD='SUM', - ), - STAGE2=dict( - NUM_MODULES=1, - NUM_BRANCHES=2, - BLOCK='BASIC', - NUM_BLOCKS=(4, 4), - NUM_CHANNELS=(18, 36), - FUSE_METHOD='SUM' - ), - STAGE3=dict( - NUM_MODULES=4, - NUM_BRANCHES=3, - BLOCK='BASIC', - NUM_BLOCKS=(4, 4, 4), - NUM_CHANNELS=(18, 36, 72), - FUSE_METHOD='SUM' - ), - STAGE4=dict( - NUM_MODULES=3, - NUM_BRANCHES=4, - BLOCK='BASIC', - NUM_BLOCKS=(4, 4, 4, 4), - NUM_CHANNELS=(18, 36, 72, 144), - FUSE_METHOD='SUM', - ), -) - - -cfg_cls_hrnet_w30 = dict( - STAGE1=dict( - NUM_MODULES=1, - NUM_BRANCHES=1, - BLOCK='BOTTLENECK', - NUM_BLOCKS=(4,), - NUM_CHANNELS=(64,), - FUSE_METHOD='SUM', - ), - STAGE2=dict( - NUM_MODULES=1, - NUM_BRANCHES=2, - BLOCK='BASIC', - NUM_BLOCKS=(4, 4), - NUM_CHANNELS=(30, 60), - FUSE_METHOD='SUM' - ), - STAGE3=dict( - NUM_MODULES=4, - NUM_BRANCHES=3, - BLOCK='BASIC', - NUM_BLOCKS=(4, 4, 4), - NUM_CHANNELS=(30, 60, 120), - FUSE_METHOD='SUM' - ), - STAGE4=dict( - NUM_MODULES=3, - NUM_BRANCHES=4, - BLOCK='BASIC', - NUM_BLOCKS=(4, 4, 4, 4), - NUM_CHANNELS=(30, 60, 120, 240), - FUSE_METHOD='SUM', - ), -) - - -cfg_cls_hrnet_w32 = dict( - STAGE1=dict( - NUM_MODULES=1, - NUM_BRANCHES=1, - BLOCK='BOTTLENECK', - NUM_BLOCKS=(4,), - NUM_CHANNELS=(64,), - FUSE_METHOD='SUM', - ), - STAGE2=dict( - NUM_MODULES=1, - NUM_BRANCHES=2, - BLOCK='BASIC', - NUM_BLOCKS=(4, 4), - NUM_CHANNELS=(32, 64), - FUSE_METHOD='SUM' - ), - STAGE3=dict( - NUM_MODULES=4, - NUM_BRANCHES=3, - BLOCK='BASIC', - NUM_BLOCKS=(4, 4, 4), - NUM_CHANNELS=(32, 64, 128), - FUSE_METHOD='SUM' - ), - STAGE4=dict( - NUM_MODULES=3, - NUM_BRANCHES=4, - BLOCK='BASIC', - NUM_BLOCKS=(4, 4, 4, 4), - NUM_CHANNELS=(32, 64, 128, 256), - FUSE_METHOD='SUM', - ), -) - -cfg_cls_hrnet_w40 = dict( - STAGE1=dict( - NUM_MODULES=1, - NUM_BRANCHES=1, - BLOCK='BOTTLENECK', - NUM_BLOCKS=(4,), - NUM_CHANNELS=(64,), - FUSE_METHOD='SUM', - ), - STAGE2=dict( - NUM_MODULES=1, - NUM_BRANCHES=2, - BLOCK='BASIC', - 
NUM_BLOCKS=(4, 4), - NUM_CHANNELS=(40, 80), - FUSE_METHOD='SUM' - ), - STAGE3=dict( - NUM_MODULES=4, - NUM_BRANCHES=3, - BLOCK='BASIC', - NUM_BLOCKS=(4, 4, 4), - NUM_CHANNELS=(40, 80, 160), - FUSE_METHOD='SUM' - ), - STAGE4=dict( - NUM_MODULES=3, - NUM_BRANCHES=4, - BLOCK='BASIC', - NUM_BLOCKS=(4, 4, 4, 4), - NUM_CHANNELS=(40, 80, 160, 320), - FUSE_METHOD='SUM', - ), +cfg_cls = dict( + hrnet_w18_small=dict( + STEM_WIDTH=64, + STAGE1=dict( + NUM_MODULES=1, + NUM_BRANCHES=1, + BLOCK='BOTTLENECK', + NUM_BLOCKS=(1,), + NUM_CHANNELS=(32,), + FUSE_METHOD='SUM', + ), + STAGE2=dict( + NUM_MODULES=1, + NUM_BRANCHES=2, + BLOCK='BASIC', + NUM_BLOCKS=(2, 2), + NUM_CHANNELS=(16, 32), + FUSE_METHOD='SUM' + ), + STAGE3=dict( + NUM_MODULES=1, + NUM_BRANCHES=3, + BLOCK='BASIC', + NUM_BLOCKS=(2, 2, 2), + NUM_CHANNELS=(16, 32, 64), + FUSE_METHOD='SUM' + ), + STAGE4=dict( + NUM_MODULES=1, + NUM_BRANCHES=4, + BLOCK='BASIC', + NUM_BLOCKS=(2, 2, 2, 2), + NUM_CHANNELS=(16, 32, 64, 128), + FUSE_METHOD='SUM', + ), + ), + + hrnet_w18_small_v2 = dict( + STEM_WIDTH=64, + STAGE1=dict( + NUM_MODULES=1, + NUM_BRANCHES=1, + BLOCK='BOTTLENECK', + NUM_BLOCKS=(2,), + NUM_CHANNELS=(64,), + FUSE_METHOD='SUM', + ), + STAGE2=dict( + NUM_MODULES=1, + NUM_BRANCHES=2, + BLOCK='BASIC', + NUM_BLOCKS=(2, 2), + NUM_CHANNELS=(18, 36), + FUSE_METHOD='SUM' + ), + STAGE3=dict( + NUM_MODULES=3, + NUM_BRANCHES=3, + BLOCK='BASIC', + NUM_BLOCKS=(2, 2, 2), + NUM_CHANNELS=(18, 36, 72), + FUSE_METHOD='SUM' + ), + STAGE4=dict( + NUM_MODULES=2, + NUM_BRANCHES=4, + BLOCK='BASIC', + NUM_BLOCKS=(2, 2, 2, 2), + NUM_CHANNELS=(18, 36, 72, 144), + FUSE_METHOD='SUM', + ), + ), + + hrnet_w18 = dict( + STEM_WIDTH=64, + STAGE1=dict( + NUM_MODULES=1, + NUM_BRANCHES=1, + BLOCK='BOTTLENECK', + NUM_BLOCKS=(4,), + NUM_CHANNELS=(64,), + FUSE_METHOD='SUM', + ), + STAGE2=dict( + NUM_MODULES=1, + NUM_BRANCHES=2, + BLOCK='BASIC', + NUM_BLOCKS=(4, 4), + NUM_CHANNELS=(18, 36), + FUSE_METHOD='SUM' + ), + STAGE3=dict( + NUM_MODULES=4, + NUM_BRANCHES=3, + BLOCK='BASIC', + NUM_BLOCKS=(4, 4, 4), + NUM_CHANNELS=(18, 36, 72), + FUSE_METHOD='SUM' + ), + STAGE4=dict( + NUM_MODULES=3, + NUM_BRANCHES=4, + BLOCK='BASIC', + NUM_BLOCKS=(4, 4, 4, 4), + NUM_CHANNELS=(18, 36, 72, 144), + FUSE_METHOD='SUM', + ), + ), + + hrnet_w30 = dict( + STEM_WIDTH=64, + STAGE1=dict( + NUM_MODULES=1, + NUM_BRANCHES=1, + BLOCK='BOTTLENECK', + NUM_BLOCKS=(4,), + NUM_CHANNELS=(64,), + FUSE_METHOD='SUM', + ), + STAGE2=dict( + NUM_MODULES=1, + NUM_BRANCHES=2, + BLOCK='BASIC', + NUM_BLOCKS=(4, 4), + NUM_CHANNELS=(30, 60), + FUSE_METHOD='SUM' + ), + STAGE3=dict( + NUM_MODULES=4, + NUM_BRANCHES=3, + BLOCK='BASIC', + NUM_BLOCKS=(4, 4, 4), + NUM_CHANNELS=(30, 60, 120), + FUSE_METHOD='SUM' + ), + STAGE4=dict( + NUM_MODULES=3, + NUM_BRANCHES=4, + BLOCK='BASIC', + NUM_BLOCKS=(4, 4, 4, 4), + NUM_CHANNELS=(30, 60, 120, 240), + FUSE_METHOD='SUM', + ), + ), + + hrnet_w32 = dict( + STEM_WIDTH=64, + STAGE1=dict( + NUM_MODULES=1, + NUM_BRANCHES=1, + BLOCK='BOTTLENECK', + NUM_BLOCKS=(4,), + NUM_CHANNELS=(64,), + FUSE_METHOD='SUM', + ), + STAGE2=dict( + NUM_MODULES=1, + NUM_BRANCHES=2, + BLOCK='BASIC', + NUM_BLOCKS=(4, 4), + NUM_CHANNELS=(32, 64), + FUSE_METHOD='SUM' + ), + STAGE3=dict( + NUM_MODULES=4, + NUM_BRANCHES=3, + BLOCK='BASIC', + NUM_BLOCKS=(4, 4, 4), + NUM_CHANNELS=(32, 64, 128), + FUSE_METHOD='SUM' + ), + STAGE4=dict( + NUM_MODULES=3, + NUM_BRANCHES=4, + BLOCK='BASIC', + NUM_BLOCKS=(4, 4, 4, 4), + NUM_CHANNELS=(32, 64, 128, 256), + FUSE_METHOD='SUM', + ), + ), + + hrnet_w40 = dict( + STEM_WIDTH=64, + 
STAGE1=dict( + NUM_MODULES=1, + NUM_BRANCHES=1, + BLOCK='BOTTLENECK', + NUM_BLOCKS=(4,), + NUM_CHANNELS=(64,), + FUSE_METHOD='SUM', + ), + STAGE2=dict( + NUM_MODULES=1, + NUM_BRANCHES=2, + BLOCK='BASIC', + NUM_BLOCKS=(4, 4), + NUM_CHANNELS=(40, 80), + FUSE_METHOD='SUM' + ), + STAGE3=dict( + NUM_MODULES=4, + NUM_BRANCHES=3, + BLOCK='BASIC', + NUM_BLOCKS=(4, 4, 4), + NUM_CHANNELS=(40, 80, 160), + FUSE_METHOD='SUM' + ), + STAGE4=dict( + NUM_MODULES=3, + NUM_BRANCHES=4, + BLOCK='BASIC', + NUM_BLOCKS=(4, 4, 4, 4), + NUM_CHANNELS=(40, 80, 160, 320), + FUSE_METHOD='SUM', + ), + ), + + hrnet_w44 = dict( + STEM_WIDTH=64, + STAGE1=dict( + NUM_MODULES=1, + NUM_BRANCHES=1, + BLOCK='BOTTLENECK', + NUM_BLOCKS=(4,), + NUM_CHANNELS=(64,), + FUSE_METHOD='SUM', + ), + STAGE2=dict( + NUM_MODULES=1, + NUM_BRANCHES=2, + BLOCK='BASIC', + NUM_BLOCKS=(4, 4), + NUM_CHANNELS=(44, 88), + FUSE_METHOD='SUM' + ), + STAGE3=dict( + NUM_MODULES=4, + NUM_BRANCHES=3, + BLOCK='BASIC', + NUM_BLOCKS=(4, 4, 4), + NUM_CHANNELS=(44, 88, 176), + FUSE_METHOD='SUM' + ), + STAGE4=dict( + NUM_MODULES=3, + NUM_BRANCHES=4, + BLOCK='BASIC', + NUM_BLOCKS=(4, 4, 4, 4), + NUM_CHANNELS=(44, 88, 176, 352), + FUSE_METHOD='SUM', + ), + ), + + hrnet_w48 = dict( + STEM_WIDTH=64, + STAGE1=dict( + NUM_MODULES=1, + NUM_BRANCHES=1, + BLOCK='BOTTLENECK', + NUM_BLOCKS=(4,), + NUM_CHANNELS=(64,), + FUSE_METHOD='SUM', + ), + STAGE2=dict( + NUM_MODULES=1, + NUM_BRANCHES=2, + BLOCK='BASIC', + NUM_BLOCKS=(4, 4), + NUM_CHANNELS=(48, 96), + FUSE_METHOD='SUM' + ), + STAGE3=dict( + NUM_MODULES=4, + NUM_BRANCHES=3, + BLOCK='BASIC', + NUM_BLOCKS=(4, 4, 4), + NUM_CHANNELS=(48, 96, 192), + FUSE_METHOD='SUM' + ), + STAGE4=dict( + NUM_MODULES=3, + NUM_BRANCHES=4, + BLOCK='BASIC', + NUM_BLOCKS=(4, 4, 4, 4), + NUM_CHANNELS=(48, 96, 192, 384), + FUSE_METHOD='SUM', + ), + ), + + hrnet_w64 = dict( + STEM_WIDTH=64, + STAGE1=dict( + NUM_MODULES=1, + NUM_BRANCHES=1, + BLOCK='BOTTLENECK', + NUM_BLOCKS=(4,), + NUM_CHANNELS=(64,), + FUSE_METHOD='SUM', + ), + STAGE2=dict( + NUM_MODULES=1, + NUM_BRANCHES=2, + BLOCK='BASIC', + NUM_BLOCKS=(4, 4), + NUM_CHANNELS=(64, 128), + FUSE_METHOD='SUM' + ), + STAGE3=dict( + NUM_MODULES=4, + NUM_BRANCHES=3, + BLOCK='BASIC', + NUM_BLOCKS=(4, 4, 4), + NUM_CHANNELS=(64, 128, 256), + FUSE_METHOD='SUM' + ), + STAGE4=dict( + NUM_MODULES=3, + NUM_BRANCHES=4, + BLOCK='BASIC', + NUM_BLOCKS=(4, 4, 4, 4), + NUM_CHANNELS=(64, 128, 256, 512), + FUSE_METHOD='SUM', + ), + ) ) -cfg_cls_hrnet_w44 = dict( - STAGE1=dict( - NUM_MODULES=1, - NUM_BRANCHES=1, - BLOCK='BOTTLENECK', - NUM_BLOCKS=(4,), - NUM_CHANNELS=(64,), - FUSE_METHOD='SUM', - ), - STAGE2=dict( - NUM_MODULES=1, - NUM_BRANCHES=2, - BLOCK='BASIC', - NUM_BLOCKS=(4, 4), - NUM_CHANNELS=(44, 88), - FUSE_METHOD='SUM' - ), - STAGE3=dict( - NUM_MODULES=4, - NUM_BRANCHES=3, - BLOCK='BASIC', - NUM_BLOCKS=(4, 4, 4), - NUM_CHANNELS=(44, 88, 176), - FUSE_METHOD='SUM' - ), - STAGE4=dict( - NUM_MODULES=3, - NUM_BRANCHES=4, - BLOCK='BASIC', - NUM_BLOCKS=(4, 4, 4, 4), - NUM_CHANNELS=(44, 88, 176, 352), - FUSE_METHOD='SUM', - ), -) - - -cfg_cls_hrnet_w48 = dict( - STAGE1=dict( - NUM_MODULES=1, - NUM_BRANCHES=1, - BLOCK='BOTTLENECK', - NUM_BLOCKS=(4,), - NUM_CHANNELS=(64,), - FUSE_METHOD='SUM', - ), - STAGE2=dict( - NUM_MODULES=1, - NUM_BRANCHES=2, - BLOCK='BASIC', - NUM_BLOCKS=(4, 4), - NUM_CHANNELS=(48, 96), - FUSE_METHOD='SUM' - ), - STAGE3=dict( - NUM_MODULES=4, - NUM_BRANCHES=3, - BLOCK='BASIC', - NUM_BLOCKS=(4, 4, 4), - NUM_CHANNELS=(48, 96, 192), - FUSE_METHOD='SUM' - ), - STAGE4=dict( - NUM_MODULES=3, - 
NUM_BRANCHES=4, - BLOCK='BASIC', - NUM_BLOCKS=(4, 4, 4, 4), - NUM_CHANNELS=(48, 96, 192, 384), - FUSE_METHOD='SUM', - ), -) - - -def conv3x3(in_planes, out_planes, stride=1): - """3x3 convolution with padding""" - return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False) - - -class BasicBlock(nn.Module): - expansion = 1 - - def __init__(self, inplanes, planes, stride=1, downsample=None): - super(BasicBlock, self).__init__() - self.conv1 = conv3x3(inplanes, planes, stride) - self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM) - self.relu = nn.ReLU(inplace=True) - self.conv2 = conv3x3(planes, planes) - self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM) - self.downsample = downsample - self.stride = stride - - def forward(self, x): - residual = x - - out = self.conv1(x) - out = self.bn1(out) - out = self.relu(out) - - out = self.conv2(out) - out = self.bn2(out) - - if self.downsample is not None: - residual = self.downsample(x) - - out += residual - out = self.relu(out) - - return out - - -class Bottleneck(nn.Module): - expansion = 4 - - def __init__(self, inplanes, planes, stride=1, downsample=None): - super(Bottleneck, self).__init__() - self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) - self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM) - self.conv2 = nn.Conv2d( - planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) - self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM) - self.conv3 = nn.Conv2d( - planes, planes * self.expansion, kernel_size=1, bias=False) - self.bn3 = nn.BatchNorm2d( - planes * self.expansion, momentum=BN_MOMENTUM) - self.relu = nn.ReLU(inplace=True) - self.downsample = downsample - self.stride = stride - - def forward(self, x): - residual = x - - out = self.conv1(x) - out = self.bn1(out) - out = self.relu(out) - - out = self.conv2(out) - out = self.bn2(out) - out = self.relu(out) - - out = self.conv3(out) - out = self.bn3(out) - - if self.downsample is not None: - residual = self.downsample(x) - - out += residual - out = self.relu(out) - - return out - - class HighResolutionModule(nn.Module): def __init__(self, num_branches, blocks, num_blocks, num_inchannels, num_channels, fuse_method, multi_scale_output=True): @@ -466,11 +437,10 @@ class HighResolutionModule(nn.Module): nn.Conv2d( self.num_inchannels[branch_index], num_channels[branch_index] * block.expansion, kernel_size=1, stride=stride, bias=False), - nn.BatchNorm2d(num_channels[branch_index] * block.expansion, momentum=BN_MOMENTUM), + nn.BatchNorm2d(num_channels[branch_index] * block.expansion, momentum=_BN_MOMENTUM), ) - layers = [] - layers.append(block(self.num_inchannels[branch_index], num_channels[branch_index], stride, downsample)) + layers = [block(self.num_inchannels[branch_index], num_channels[branch_index], stride, downsample)] self.num_inchannels[branch_index] = num_channels[branch_index] * block.expansion for i in range(1, num_blocks[branch_index]): layers.append(block(self.num_inchannels[branch_index], num_channels[branch_index])) @@ -479,7 +449,6 @@ class HighResolutionModule(nn.Module): def _make_branches(self, num_branches, block, num_blocks, num_channels): branches = [] - for i in range(num_branches): branches.append(self._make_one_branch(i, block, num_blocks, num_channels)) @@ -498,7 +467,7 @@ class HighResolutionModule(nn.Module): if j > i: fuse_layer.append(nn.Sequential( nn.Conv2d(num_inchannels[j], num_inchannels[i], 1, 1, 0, bias=False), - nn.BatchNorm2d(num_inchannels[i], momentum=BN_MOMENTUM), + 
nn.BatchNorm2d(num_inchannels[i], momentum=_BN_MOMENTUM), nn.Upsample(scale_factor=2 ** (j - i), mode='nearest'))) elif j == i: fuse_layer.append(None) @@ -509,12 +478,12 @@ class HighResolutionModule(nn.Module): num_outchannels_conv3x3 = num_inchannels[i] conv3x3s.append(nn.Sequential( nn.Conv2d(num_inchannels[j], num_outchannels_conv3x3, 3, 2, 1, bias=False), - nn.BatchNorm2d(num_outchannels_conv3x3, momentum=BN_MOMENTUM))) + nn.BatchNorm2d(num_outchannels_conv3x3, momentum=_BN_MOMENTUM))) else: num_outchannels_conv3x3 = num_inchannels[j] conv3x3s.append(nn.Sequential( nn.Conv2d(num_inchannels[j], num_outchannels_conv3x3, 3, 2, 1, bias=False), - nn.BatchNorm2d(num_outchannels_conv3x3, momentum=BN_MOMENTUM), + nn.BatchNorm2d(num_outchannels_conv3x3, momentum=_BN_MOMENTUM), nn.ReLU(False))) fuse_layer.append(nn.Sequential(*conv3x3s)) fuse_layers.append(nn.ModuleList(fuse_layer)) @@ -552,13 +521,16 @@ blocks_dict = { class HighResolutionNet(nn.Module): - def __init__(self, cfg, in_chans=3, num_classes=1000, global_pool='avg'): + def __init__(self, cfg, in_chans=3, num_classes=1000, global_pool='avg', drop_rate=0.0): super(HighResolutionNet, self).__init__() - - self.conv1 = nn.Conv2d(in_chans, 64, kernel_size=3, stride=2, padding=1, bias=False) - self.bn1 = nn.BatchNorm2d(64, momentum=BN_MOMENTUM) - self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1, bias=False) - self.bn2 = nn.BatchNorm2d(64, momentum=BN_MOMENTUM) + self.num_classes = num_classes + self.drop_rate = drop_rate + + stem_width = cfg['STEM_WIDTH'] + self.conv1 = nn.Conv2d(in_chans, stem_width, kernel_size=3, stride=2, padding=1, bias=False) + self.bn1 = nn.BatchNorm2d(stem_width, momentum=_BN_MOMENTUM) + self.conv2 = nn.Conv2d(stem_width, 64, kernel_size=3, stride=2, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(64, momentum=_BN_MOMENTUM) self.relu = nn.ReLU(inplace=True) self.stage1_cfg = cfg['STAGE1'] @@ -590,9 +562,10 @@ class HighResolutionNet(nn.Module): self.stage4, pre_stage_channels = self._make_stage(self.stage4_cfg, num_channels, multi_scale_output=True) # Classification Head + self.num_features = 2048 self.incre_modules, self.downsamp_modules, self.final_layer = self._make_head(pre_stage_channels) - - self.classifier = nn.Linear(2048, num_classes) + self.global_pool = SelectAdaptivePool2d(pool_type=global_pool) + self.classifier = nn.Linear(self.num_features * self.global_pool.feat_mult(), num_classes) self.init_weights() @@ -616,7 +589,7 @@ class HighResolutionNet(nn.Module): downsamp_module = nn.Sequential( nn.Conv2d( in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=2, padding=1), - nn.BatchNorm2d(out_channels, momentum=BN_MOMENTUM), + nn.BatchNorm2d(out_channels, momentum=_BN_MOMENTUM), nn.ReLU(inplace=True) ) downsamp_modules.append(downsamp_module) @@ -625,9 +598,9 @@ class HighResolutionNet(nn.Module): final_layer = nn.Sequential( nn.Conv2d( in_channels=head_channels[3] * head_block.expansion, - out_channels=2048, kernel_size=1, stride=1, padding=0 + out_channels=self.num_features, kernel_size=1, stride=1, padding=0 ), - nn.BatchNorm2d(2048, momentum=BN_MOMENTUM), + nn.BatchNorm2d(self.num_features, momentum=_BN_MOMENTUM), nn.ReLU(inplace=True) ) @@ -643,7 +616,7 @@ class HighResolutionNet(nn.Module): if num_channels_cur_layer[i] != num_channels_pre_layer[i]: transition_layers.append(nn.Sequential( nn.Conv2d(num_channels_pre_layer[i], num_channels_cur_layer[i], 3, 1, 1, bias=False), - nn.BatchNorm2d(num_channels_cur_layer[i], momentum=BN_MOMENTUM), + 
nn.BatchNorm2d(num_channels_cur_layer[i], momentum=_BN_MOMENTUM), nn.ReLU(inplace=True))) else: transition_layers.append(None) @@ -654,7 +627,7 @@ class HighResolutionNet(nn.Module): outchannels = num_channels_cur_layer[i] if j == i - num_branches_pre else inchannels conv3x3s.append(nn.Sequential( nn.Conv2d(inchannels, outchannels, 3, 2, 1, bias=False), - nn.BatchNorm2d(outchannels, momentum=BN_MOMENTUM), + nn.BatchNorm2d(outchannels, momentum=_BN_MOMENTUM), nn.ReLU(inplace=True))) transition_layers.append(nn.Sequential(*conv3x3s)) @@ -665,11 +638,10 @@ class HighResolutionNet(nn.Module): if stride != 1 or inplanes != planes * block.expansion: downsample = nn.Sequential( nn.Conv2d(inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False), - nn.BatchNorm2d(planes * block.expansion, momentum=BN_MOMENTUM), + nn.BatchNorm2d(planes * block.expansion, momentum=_BN_MOMENTUM), ) - layers = [] - layers.append(block(inplanes, planes, stride, downsample)) + layers = [block(inplanes, planes, stride, downsample)] inplanes = planes * block.expansion for i in range(1, blocks): layers.append(block(inplanes, planes)) @@ -699,8 +671,7 @@ class HighResolutionNet(nn.Module): return nn.Sequential(*modules), num_inchannels - def init_weights(self, pretrained='', ): - logger.info('=> init weights from normal distribution') + def init_weights(self): for m in self.modules(): if isinstance(m, nn.Conv2d): nn.init.kaiming_normal_( @@ -709,7 +680,16 @@ class HighResolutionNet(nn.Module): nn.init.constant_(m.weight, 1) nn.init.constant_(m.bias, 0) - def forward(self, x): + def get_classifier(self): + return self.classifier + + def reset_classifier(self, num_classes, global_pool='avg'): + self.num_classes = num_classes + self.global_pool = SelectAdaptivePool2d(pool_type=global_pool) + self.classifier = nn.Linear( + self.num_features * self.global_pool.feat_mult(), num_classes) if num_classes else None + + def forward_features(self, x): x = self.conv1(x) x = self.bn1(x) x = self.relu(x) @@ -746,124 +726,79 @@ class HighResolutionNet(nn.Module): y = self.incre_modules[0](y_list[0]) for i in range(len(self.downsamp_modules)): y = self.incre_modules[i + 1](y_list[i + 1]) + self.downsamp_modules[i](y) - y = self.final_layer(y) - - if torch._C._get_tracing_state(): - y = y.flatten(start_dim=2).mean(dim=2) - else: - y = F.avg_pool2d(y, kernel_size=y.size()[2:]).view(y.size(0), -1) - - y = self.classifier(y) - return y - - -@register_model -def hrnet_w18_small(pretrained=True, **kwargs): - default_cfg = default_cfgs['hrnet_w18_small'] - model = HighResolutionNet(cfg_cls_hrnet_w18_small, **kwargs) - model.default_cfg = default_cfg + def forward(self, x): + x = self.forward_features(x) + x = self.global_pool(x).flatten(1) + if self.drop_rate > 0.: + x = F.dropout(x, p=self.drop_rate, training=self.training) + x = self.classifier(x) + return x + + +def _create_model(variant, pretrained, model_kwargs): + if model_kwargs.pop('features_only', False): + assert False, 'Not Implemented' # TODO + load_strict = False + model_kwargs.pop('num_classes', 0) + model_class = HighResolutionNet + else: + load_strict = True + model_class = HighResolutionNet + + model = model_class(cfg_cls[variant], **model_kwargs) + model.default_cfg = default_cfgs[variant] if pretrained: load_pretrained( model, - default_cfg, - num_classes=kwargs.get('num_classes', 0), - in_chans=kwargs.get('in_chans', 3)) + num_classes=model_kwargs.get('num_classes', 0), + in_chans=model_kwargs.get('in_chans', 3), + strict=load_strict) return model 
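For context on the consolidated factory above: every HRNet entrypoint now just forwards its variant name and kwargs to _create_model, and @register_model exposes it through timm's model registry. A minimal usage sketch (assuming the standard timm.create_model flow that these decorators feed into; the num_classes value is illustrative):

    import timm

    # Build a pretrained HRNet-W18 and retarget its head to 10 classes.
    # reset_classifier() rebuilds global_pool + classifier, as added in this diff.
    model = timm.create_model('hrnet_w18', pretrained=True)
    model.reset_classifier(num_classes=10)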
+@register_model +def hrnet_w18_small(pretrained=True, **kwargs): + return _create_model('hrnet_w18_small', pretrained, kwargs) + + @register_model def hrnet_w18_small_v2(pretrained=True, **kwargs): - default_cfg = default_cfgs['hrnet_w18_small_v2'] - model = HighResolutionNet(cfg_cls_hrnet_w18_small_v2, **kwargs) - model.default_cfg = default_cfg - if pretrained: - load_pretrained( - model, - default_cfg, - num_classes=kwargs.get('num_classes', 0), - in_chans=kwargs.get('in_chans', 3)) - return model + return _create_model('hrnet_w18_small_v2', pretrained, kwargs) + @register_model def hrnet_w18(pretrained=True, **kwargs): - default_cfg = default_cfgs['hrnet_w18'] - model = HighResolutionNet(cfg_cls_hrnet_w18, **kwargs) - model.default_cfg = default_cfg - if pretrained: - load_pretrained( - model, - default_cfg, - num_classes=kwargs.get('num_classes', 0), - in_chans=kwargs.get('in_chans', 3)) - return model + return _create_model('hrnet_w18', pretrained, kwargs) @register_model def hrnet_w30(pretrained=True, **kwargs): - default_cfg = default_cfgs['hrnet_w30'] - model = HighResolutionNet(cfg_cls_hrnet_w30, **kwargs) - model.default_cfg = default_cfg - if pretrained: - load_pretrained( - model, - default_cfg, - num_classes=kwargs.get('num_classes', 0), - in_chans=kwargs.get('in_chans', 3)) - return model + return _create_model('hrnet_w30', pretrained, kwargs) + @register_model def hrnet_w32(pretrained=True, **kwargs): - default_cfg = default_cfgs['hrnet_w32'] - model = HighResolutionNet(cfg_cls_hrnet_w32, **kwargs) - model.default_cfg = default_cfg - if pretrained: - load_pretrained( - model, - default_cfg, - num_classes=kwargs.get('num_classes', 0), - in_chans=kwargs.get('in_chans', 3)) - return model + return _create_model('hrnet_w32', pretrained, kwargs) + @register_model def hrnet_w40(pretrained=True, **kwargs): - default_cfg = default_cfgs['hrnet_w40'] - model = HighResolutionNet(cfg_cls_hrnet_w40, **kwargs) - model.default_cfg = default_cfg - if pretrained: - load_pretrained( - model, - default_cfg, - num_classes=kwargs.get('num_classes', 0), - in_chans=kwargs.get('in_chans', 3)) - return model + return _create_model('hrnet_w40', pretrained, kwargs) @register_model def hrnet_w44(pretrained=True, **kwargs): - default_cfg = default_cfgs['hrnet_w44'] - model = HighResolutionNet(cfg_cls_hrnet_w44, **kwargs) - model.default_cfg = default_cfg - if pretrained: - load_pretrained( - model, - default_cfg, - num_classes=kwargs.get('num_classes', 0), - in_chans=kwargs.get('in_chans', 3)) - return model + return _create_model('hrnet_w44', pretrained, kwargs) @register_model def hrnet_w48(pretrained=True, **kwargs): - default_cfg = default_cfgs['hrnet_w48'] - model = HighResolutionNet(cfg_cls_hrnet_w48, **kwargs) - model.default_cfg = default_cfg - if pretrained: - load_pretrained( - model, - default_cfg, - num_classes=kwargs.get('num_classes', 0), - in_chans=kwargs.get('in_chans', 3)) - return model + return _create_model('hrnet_w48', pretrained, kwargs) + + +@register_model +def hrnet_w64(pretrained=True, **kwargs): + return _create_model('hrnet_w64', pretrained, kwargs) diff --git a/timm/models/inception_resnet_v2.py b/timm/models/inception_resnet_v2.py index fe5679fe..da019075 100644 --- a/timm/models/inception_resnet_v2.py +++ b/timm/models/inception_resnet_v2.py @@ -8,7 +8,7 @@ import torch.nn.functional as F from .registry import register_model from .helpers import load_pretrained -from .adaptive_avgmax_pool import select_adaptive_pool2d +from .adaptive_avgmax_pool import 
diff --git a/timm/models/inception_resnet_v2.py b/timm/models/inception_resnet_v2.py
index fe5679fe..da019075 100644
--- a/timm/models/inception_resnet_v2.py
+++ b/timm/models/inception_resnet_v2.py
@@ -8,7 +8,7 @@ import torch.nn.functional as F
 
 from .registry import register_model
 from .helpers import load_pretrained
-from .adaptive_avgmax_pool import select_adaptive_pool2d
+from .adaptive_avgmax_pool import SelectAdaptivePool2d
 from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
 
 __all__ = ['InceptionResnetV2']
@@ -226,7 +226,6 @@ class InceptionResnetV2(nn.Module):
     def __init__(self, num_classes=1001, in_chans=3, drop_rate=0., global_pool='avg'):
         super(InceptionResnetV2, self).__init__()
         self.drop_rate = drop_rate
-        self.global_pool = global_pool
         self.num_classes = num_classes
         self.num_features = 1536
@@ -287,22 +286,20 @@ class InceptionResnetV2(nn.Module):
         )
         self.block8 = Block8(noReLU=True)
         self.conv2d_7b = BasicConv2d(2080, self.num_features, kernel_size=1, stride=1)
+        self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
         # NOTE some variants/checkpoints for this model may have 'last_linear' as the name for the FC
-        self.classif = nn.Linear(self.num_features, num_classes)
+        self.classif = nn.Linear(self.num_features * self.global_pool.feat_mult(), num_classes)
 
     def get_classifier(self):
         return self.classif
 
     def reset_classifier(self, num_classes, global_pool='avg'):
-        self.global_pool = global_pool
+        self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
         self.num_classes = num_classes
-        del self.classif
-        if num_classes:
-            self.classif = torch.nn.Linear(self.num_features, num_classes)
-        else:
-            self.classif = None
+        self.classif = nn.Linear(
+            self.num_features * self.global_pool.feat_mult(), num_classes) if num_classes else None
 
-    def forward_features(self, x, pool=True):
+    def forward_features(self, x):
         x = self.conv2d_1a(x)
         x = self.conv2d_2a(x)
         x = self.conv2d_2b(x)
@@ -318,14 +315,11 @@ class InceptionResnetV2(nn.Module):
         x = self.repeat_2(x)
         x = self.block8(x)
         x = self.conv2d_7b(x)
-        if pool:
-            x = select_adaptive_pool2d(x, self.global_pool)
-            #x = F.avg_pool2d(x, 8, count_include_pad=False)
-            x = x.view(x.size(0), -1)
         return x
 
     def forward(self, x):
-        x = self.forward_features(x, pool=True)
+        x = self.forward_features(x)
+        x = self.global_pool(x).flatten(1)
         if self.drop_rate > 0:
             x = F.dropout(x, p=self.drop_rate, training=self.training)
         x = self.classif(x)
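Classifier layers in this file (and the files below) are now sized as `num_features * self.global_pool.feat_mult()`, because pooling modes such as 'catavgmax' concatenate average- and max-pooled features and double the channel count. An illustrative sketch of why `feat_mult()` is needed (the idea, not the library's implementation):

    import torch
    import torch.nn.functional as F

    def catavgmax_pool2d(x):
        # 'catavgmax' concatenates avg and max pooling on the channel dim,
        # so a downstream Linear layer needs 2x the input features.
        return torch.cat([
            F.adaptive_avg_pool2d(x, 1),
            F.adaptive_max_pool2d(x, 1)], dim=1)

    x = torch.randn(2, 1536, 8, 8)    # InceptionResnetV2 feature shape
    print(catavgmax_pool2d(x).shape)  # torch.Size([2, 3072, 1, 1])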
diff --git a/timm/models/inception_v4.py b/timm/models/inception_v4.py
index e389eb88..8c3dee86 100644
--- a/timm/models/inception_v4.py
+++ b/timm/models/inception_v4.py
@@ -8,7 +8,7 @@ import torch.nn.functional as F
 
 from .registry import register_model
 from .helpers import load_pretrained
-from .adaptive_avgmax_pool import select_adaptive_pool2d
+from .adaptive_avgmax_pool import SelectAdaptivePool2d
 from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
 
 __all__ = ['InceptionV4']
@@ -244,7 +244,6 @@ class InceptionV4(nn.Module):
     def __init__(self, num_classes=1001, in_chans=3, drop_rate=0., global_pool='avg'):
         super(InceptionV4, self).__init__()
         self.drop_rate = drop_rate
-        self.global_pool = global_pool
         self.num_classes = num_classes
         self.num_features = 1536
@@ -272,25 +271,24 @@ class InceptionV4(nn.Module):
             Inception_C(),
             Inception_C(),
         )
-        self.last_linear = nn.Linear(self.num_features, num_classes)
+        self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
+        self.last_linear = nn.Linear(self.num_features * self.global_pool.feat_mult(), num_classes)
 
     def get_classifier(self):
-        return self.classif
+        return self.last_linear
 
     def reset_classifier(self, num_classes, global_pool='avg'):
-        self.global_pool = global_pool
+        self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
         self.num_classes = num_classes
-        self.classif = nn.Linear(self.num_features, num_classes)
+        self.last_linear = nn.Linear(
+            self.num_features * self.global_pool.feat_mult(), num_classes) if num_classes else None
 
-    def forward_features(self, x, pool=True):
-        x = self.features(x)
-        if pool:
-            x = select_adaptive_pool2d(x, self.global_pool)
-            x = x.view(x.size(0), -1)
-        return x
+    def forward_features(self, x):
+        return self.features(x)
 
     def forward(self, x):
         x = self.forward_features(x)
+        x = self.global_pool(x).flatten(1)
         if self.drop_rate > 0:
             x = F.dropout(x, p=self.drop_rate, training=self.training)
         x = self.last_linear(x)
diff --git a/timm/models/mobilenetv3.py b/timm/models/mobilenetv3.py
index a89adea4..a6b67532 100644
--- a/timm/models/mobilenetv3.py
+++ b/timm/models/mobilenetv3.py
@@ -75,8 +75,7 @@ class MobileNetV3(nn.Module):
     def __init__(self, block_args, num_classes=1000, in_chans=3, stem_size=16, num_features=1280, head_bias=True,
                  channel_multiplier=1.0, pad_type='', act_layer=nn.ReLU, drop_rate=0., drop_connect_rate=0.,
-                 se_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None,
-                 global_pool='avg', weight_init='goog'):
+                 se_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None, global_pool='avg'):
         super(MobileNetV3, self).__init__()
 
         self.num_classes = num_classes
@@ -107,11 +106,7 @@ class MobileNetV3(nn.Module):
         # Classifier
         self.classifier = nn.Linear(self.num_features * self.global_pool.feat_mult(), self.num_classes)
 
-        for m in self.modules():
-            if weight_init == 'goog':
-                efficientnet_init_goog(m)
-            else:
-                efficientnet_init_default(m)
+        efficientnet_init_weights(self)
 
     def as_sequential(self):
         layers = [self.conv_stem, self.bn1, self.act1]
@@ -126,12 +121,8 @@ class MobileNetV3(nn.Module):
     def reset_classifier(self, num_classes, global_pool='avg'):
         self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
         self.num_classes = num_classes
-        del self.classifier
-        if num_classes:
-            self.classifier = nn.Linear(
-                self.num_features * self.global_pool.feat_mult(), num_classes)
-        else:
-            self.classifier = None
+        self.classifier = nn.Linear(
+            self.num_features * self.global_pool.feat_mult(), num_classes) if self.num_classes else None
 
     def forward_features(self, x):
         x = self.conv_stem(x)
@@ -161,7 +152,7 @@ class MobileNetV3Features(nn.Module):
     def __init__(self, block_args, out_indices=(0, 1, 2, 3, 4), feature_location='pre_pwl',
                  in_chans=3, stem_size=16, channel_multiplier=1.0, output_stride=32, pad_type='',
                  act_layer=nn.ReLU, drop_rate=0., drop_connect_rate=0., se_kwargs=None,
-                 norm_layer=nn.BatchNorm2d, norm_kwargs=None, weight_init='goog'):
+                 norm_layer=nn.BatchNorm2d, norm_kwargs=None):
         super(MobileNetV3Features, self).__init__()
         norm_kwargs = norm_kwargs or {}
@@ -187,12 +178,7 @@ class MobileNetV3Features(nn.Module):
         self.feature_info = builder.features  # builder provides info about feature channels for each block
         self._in_chs = builder.in_chs
 
-        for m in self.modules():
-            if weight_init == 'goog':
-                efficientnet_init_goog(m)
-            else:
-                efficientnet_init_default(m)
-
+        efficientnet_init_weights(self)
         if _DEBUG:
             for k, v in self.feature_info.items():
                 print('Feature idx: {}: Name: {}, Channels: {}'.format(k, v['name'], v['num_chs']))
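Both MobileNetV3 classes now initialize weights through a single `efficientnet_init_weights(model)` call instead of repeating the per-module loop. A hedged sketch of the consolidated pattern (the real helper also supports the alternate 'goog'/default schemes; this simplified version is illustrative only):

    import math
    import torch.nn as nn

    def init_weights_sketch(model: nn.Module):
        # One walk over all submodules, with a layer-appropriate init for
        # each, roughly following the TF reference ('goog') scheme.
        for m in model.modules():
            if isinstance(m, nn.Conv2d):
                fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                nn.init.normal_(m.weight, 0, math.sqrt(2.0 / fan_out))
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.zeros_(m.bias)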
diff --git a/timm/models/nasnet.py b/timm/models/nasnet.py
index 9caee809..009c62d3 100644
--- a/timm/models/nasnet.py
+++ b/timm/models/nasnet.py
@@ -556,8 +556,18 @@ class NASNetALarge(nn.Module):
         self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
         self.last_linear = nn.Linear(self.num_features * self.global_pool.feat_mult(), num_classes)
 
-    def forward_features(self, input, pool=True):
-        x_conv0 = self.conv0(input)
+    def get_classifier(self):
+        return self.last_linear
+
+    def reset_classifier(self, num_classes, global_pool='avg'):
+        self.num_classes = num_classes
+        self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
+        del self.last_linear
+        self.last_linear = nn.Linear(
+            self.num_features * self.global_pool.feat_mult(), num_classes) if num_classes else None
+
+    def forward_features(self, x):
+        x_conv0 = self.conv0(x)
 
         x_stem_0 = self.cell_stem_0(x_conv0)
         x_stem_1 = self.cell_stem_1(x_conv0, x_stem_0)
@@ -586,13 +596,11 @@ class NASNetALarge(nn.Module):
         x_cell_16 = self.cell_16(x_cell_15, x_cell_14)
         x_cell_17 = self.cell_17(x_cell_16, x_cell_15)
         x = self.relu(x_cell_17)
-        if pool:
-            x = self.global_pool(x)
-            x = x.view(x.size(0), -1)
         return x
 
-    def forward(self, input):
-        x = self.forward_features(input)
+    def forward(self, x):
+        x = self.forward_features(x)
+        x = self.global_pool(x).flatten(1)
         if self.drop_rate > 0:
             x = F.dropout(x, self.drop_rate, training=self.training)
         x = self.last_linear(x)
diff --git a/timm/models/pnasnet.py b/timm/models/pnasnet.py
index e04a2b1f..396e6157 100644
--- a/timm/models/pnasnet.py
+++ b/timm/models/pnasnet.py
@@ -355,7 +355,7 @@ class PNASNet5Large(nn.Module):
         else:
             self.last_linear = None
 
-    def forward_features(self, x, pool=True):
+    def forward_features(self, x):
         x_conv_0 = self.conv_0(x)
         x_stem_0 = self.cell_stem_0(x_conv_0)
         x_stem_1 = self.cell_stem_1(x_conv_0, x_stem_0)
@@ -372,13 +372,11 @@ class PNASNet5Large(nn.Module):
         x_cell_10 = self.cell_10(x_cell_8, x_cell_9)
         x_cell_11 = self.cell_11(x_cell_9, x_cell_10)
         x = self.relu(x_cell_11)
-        if pool:
-            x = self.global_pool(x)
-            x = x.view(x.size(0), -1)
         return x
 
-    def forward(self, input):
-        x = self.forward_features(input)
+    def forward(self, x):
+        x = self.forward_features(x)
+        x = self.global_pool(x).flatten(1)
         if self.drop_rate > 0:
             x = F.dropout(x, self.drop_rate, training=self.training)
         x = self.last_linear(x)
diff --git a/timm/models/resnet.py b/timm/models/resnet.py
index c7d80dba..b90bb9d5 100644
--- a/timm/models/resnet.py
+++ b/timm/models/resnet.py
@@ -17,7 +17,7 @@ from .adaptive_avgmax_pool import SelectAdaptivePool2d
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 
-__all__ = ['ResNet']  # model_registry will add each entrypoint fn to this
+__all__ = ['ResNet', 'BasicBlock', 'Bottleneck']  # model_registry will add each entrypoint fn to this
 
 
 def _cfg(url='', **kwargs):
@@ -374,12 +374,9 @@ class ResNet(nn.Module):
         self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
         self.num_classes = num_classes
         del self.fc
-        if num_classes:
-            self.fc = nn.Linear(self.num_features * self.global_pool.feat_mult(), num_classes)
-        else:
-            self.fc = None
+        self.fc = nn.Linear(self.num_features * self.global_pool.feat_mult(), num_classes) if num_classes else None
 
-    def forward_features(self, x, pool=True):
+    def forward_features(self, x):
         x = self.conv1(x)
         x = self.bn1(x)
         x = self.relu(x)
@@ -389,14 +386,11 @@ class ResNet(nn.Module):
         x = self.layer2(x)
         x = self.layer3(x)
         x = self.layer4(x)
-
-        if pool:
-            x = self.global_pool(x)
-            x = x.view(x.size(0), -1)
         return x
 
     def forward(self, x):
         x = self.forward_features(x)
+        x = self.global_pool(x).flatten(1)
         if self.drop_rate > 0.:
             x = F.dropout(x, p=self.drop_rate, training=self.training)
         x = self.fc(x)
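With the `pool` argument gone, `forward_features` uniformly returns unpooled NCHW feature maps and `forward` applies `global_pool(x).flatten(1)` itself. A quick sketch of the new contract using the ResNet class above (constructed directly to avoid a weight download; shapes assume a 224x224 input):

    import torch
    from timm.models.resnet import ResNet, BasicBlock  # BasicBlock is now exported

    model = ResNet(BasicBlock, [2, 2, 2, 2])  # resnet18-shaped
    model.eval()

    with torch.no_grad():
        feats = model.forward_features(torch.randn(1, 3, 224, 224))
    print(feats.shape)  # torch.Size([1, 512, 7, 7]) -- spatial maps, not pooled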
diff --git a/timm/models/senet.py b/timm/models/senet.py
index 0fbcfb86..90ef5ae1 100644
--- a/timm/models/senet.py
+++ b/timm/models/senet.py
@@ -274,6 +274,7 @@ class SENet(nn.Module):
         super(SENet, self).__init__()
         self.inplanes = inplanes
         self.num_classes = num_classes
+        self.drop_rate = drop_rate
         if input_3x3:
             layer0_modules = [
                 ('conv1', nn.Conv2d(in_chans, 64, 3, stride=2, padding=1, bias=False)),
@@ -337,7 +338,6 @@ class SENet(nn.Module):
             downsample_padding=downsample_padding
         )
         self.avg_pool = SelectAdaptivePool2d(pool_type=global_pool)
-        self.drop_rate = drop_rate
         self.num_features = 512 * block.expansion
         self.last_linear = nn.Linear(self.num_features, num_classes)
@@ -366,26 +366,25 @@ class SENet(nn.Module):
     def get_classifier(self):
         return self.last_linear
 
-    def reset_classifier(self, num_classes):
+    def reset_classifier(self, num_classes, global_pool='avg'):
         self.num_classes = num_classes
+        self.avg_pool = SelectAdaptivePool2d(pool_type=global_pool)
         del self.last_linear
         if num_classes:
-            self.last_linear = nn.Linear(self.num_features, num_classes)
+            self.last_linear = nn.Linear(self.num_features * self.avg_pool.feat_mult(), num_classes)
         else:
             self.last_linear = None
 
-    def forward_features(self, x, pool=True):
+    def forward_features(self, x):
         x = self.layer0(x)
         x = self.layer1(x)
         x = self.layer2(x)
         x = self.layer3(x)
         x = self.layer4(x)
-        if pool:
-            x = self.avg_pool(x)
-            x = x.view(x.size(0), -1)
         return x
 
     def logits(self, x):
+        x = self.avg_pool(x).flatten(1)
         if self.drop_rate > 0.:
             x = F.dropout(x, p=self.drop_rate, training=self.training)
         x = self.last_linear(x)
diff --git a/timm/models/test_time_pool.py b/timm/models/test_time_pool.py
index 7d5bb571..ce6ddf07 100644
--- a/timm/models/test_time_pool.py
+++ b/timm/models/test_time_pool.py
@@ -20,7 +20,7 @@ class TestTimePoolHead(nn.Module):
         self.base.reset_classifier(0)  # delete original fc layer
 
     def forward(self, x):
-        x = self.base.forward_features(x, pool=False)
+        x = self.base.forward_features(x)
         x = F.avg_pool2d(x, kernel_size=self.original_pool, stride=1)
         x = self.fc(x)
         x = adaptive_avgmax_pool2d(x, 1)
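`TestTimePoolHead` is the main consumer of unpooled features: at a larger test resolution it slides the train-time pooling window over the feature map, applies the classifier as a 1x1 conv, and averages the resulting logit map. A minimal sketch of that ordering (illustrative only; a fixed 7x7 train-time window and a hypothetical `fc_conv` classifier are assumed):

    import torch
    import torch.nn.functional as F

    def test_time_pool_sketch(feats, fc_conv, train_pool=7):
        # feats: unpooled forward_features output at test resolution,
        # e.g. (N, C, 10, 10) rather than the train-time (N, C, 7, 7).
        x = F.avg_pool2d(feats, kernel_size=train_pool, stride=1)
        x = fc_conv(x)                                 # classifier as 1x1 conv
        return F.adaptive_avg_pool2d(x, 1).flatten(1)  # average the logit map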
diff --git a/timm/models/xception.py b/timm/models/xception.py
index e76ed9ff..2dc334fa 100644
--- a/timm/models/xception.py
+++ b/timm/models/xception.py
@@ -29,7 +29,7 @@ import torch.nn.functional as F
 
 from .registry import register_model
 from .helpers import load_pretrained
-from .adaptive_avgmax_pool import select_adaptive_pool2d
+from .adaptive_avgmax_pool import SelectAdaptivePool2d
 
 __all__ = ['Xception']
@@ -163,7 +163,8 @@ class Xception(nn.Module):
         self.conv4 = SeparableConv2d(1536, self.num_features, 3, 1, 1)
         self.bn4 = nn.BatchNorm2d(self.num_features)
 
-        self.fc = nn.Linear(self.num_features, num_classes)
+        self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
+        self.fc = nn.Linear(self.num_features * self.global_pool.feat_mult(), num_classes)
 
         # #------- init weights --------
         for m in self.modules():
@@ -178,15 +179,12 @@ class Xception(nn.Module):
 
     def reset_classifier(self, num_classes, global_pool='avg'):
         self.num_classes = num_classes
-        self.global_pool = global_pool
+        self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
         del self.fc
-        if num_classes:
-            self.fc = nn.Linear(self.num_features, num_classes)
-        else:
-            self.fc = None
+        self.fc = nn.Linear(self.num_features * self.global_pool.feat_mult(), num_classes) if num_classes else None
 
-    def forward_features(self, input, pool=True):
-        x = self.conv1(input)
+    def forward_features(self, x):
+        x = self.conv1(x)
         x = self.bn1(x)
         x = self.relu(x)
@@ -214,14 +212,11 @@ class Xception(nn.Module):
         x = self.conv4(x)
         x = self.bn4(x)
         x = self.relu(x)
-
-        if pool:
-            x = select_adaptive_pool2d(x, pool_type=self.global_pool)
-            x = x.view(x.size(0), -1)
         return x
 
-    def forward(self, input):
-        x = self.forward_features(input)
+    def forward(self, x):
+        x = self.forward_features(x)
+        x = self.global_pool(x).flatten(1)
         if self.drop_rate:
-            F.dropout(x, self.drop_rate, training=self.training)
+            x = F.dropout(x, self.drop_rate, training=self.training)
         x = self.fc(x)
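One behavioral fix folded into the Xception forward above: `F.dropout` is not in-place by default, so its result must be assigned back or the configured drop_rate is silently ignored. A quick demonstration:

    import torch
    import torch.nn.functional as F

    x = torch.ones(8)
    F.dropout(x, 0.5, training=True)      # no-op: returned tensor discarded
    print(x)                              # still all ones

    x = F.dropout(x, 0.5, training=True)  # correct: reassign the result
    print(x)                              # mix of 0.0 and 2.0 (inverted dropout scaling)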