Fix model create fn not passing num_classes through. Fix #135

pull/136/head
Ross Wightman 4 years ago
parent 779cb0fcc0
commit ea300709f0

@ -390,7 +390,7 @@ def pnasnet5large(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
<https://arxiv.org/abs/1712.00559>`_ paper.
"""
default_cfg = default_cfgs['pnasnet5large']
model = PNASNet5Large(num_classes=1000, in_chans=in_chans, **kwargs)
model = PNASNet5Large(num_classes=num_classes, in_chans=in_chans, **kwargs)
model.default_cfg = default_cfg
if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans)

@ -229,7 +229,7 @@ def res2next50(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
default_cfg = default_cfgs['res2next50']
res2net_block_args = dict(scale=4)
model = ResNet(Bottle2neck, [3, 4, 6, 3], base_width=4, cardinality=8,
num_classes=1000, in_chans=in_chans, block_args=res2net_block_args, **kwargs)
num_classes=num_classes, in_chans=in_chans, block_args=res2net_block_args, **kwargs)
model.default_cfg = default_cfg
if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans)

Loading…
Cancel
Save