Fix model create fn not passing num_classes through. Fix #135

pull/136/head
Ross Wightman 4 years ago
parent 779cb0fcc0
commit ea300709f0

@ -390,7 +390,7 @@ def pnasnet5large(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
<https://arxiv.org/abs/1712.00559>`_ paper. <https://arxiv.org/abs/1712.00559>`_ paper.
""" """
default_cfg = default_cfgs['pnasnet5large'] default_cfg = default_cfgs['pnasnet5large']
model = PNASNet5Large(num_classes=1000, in_chans=in_chans, **kwargs) model = PNASNet5Large(num_classes=num_classes, in_chans=in_chans, **kwargs)
model.default_cfg = default_cfg model.default_cfg = default_cfg
if pretrained: if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans) load_pretrained(model, default_cfg, num_classes, in_chans)

@ -229,7 +229,7 @@ def res2next50(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
default_cfg = default_cfgs['res2next50'] default_cfg = default_cfgs['res2next50']
res2net_block_args = dict(scale=4) res2net_block_args = dict(scale=4)
model = ResNet(Bottle2neck, [3, 4, 6, 3], base_width=4, cardinality=8, model = ResNet(Bottle2neck, [3, 4, 6, 3], base_width=4, cardinality=8,
num_classes=1000, in_chans=in_chans, block_args=res2net_block_args, **kwargs) num_classes=num_classes, in_chans=in_chans, block_args=res2net_block_args, **kwargs)
model.default_cfg = default_cfg model.default_cfg = default_cfg
if pretrained: if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans) load_pretrained(model, default_cfg, num_classes, in_chans)

Loading…
Cancel
Save