|
|
|
@ -365,6 +365,7 @@ class ResNetV2(nn.Module):
|
|
|
|
|
return self.head.fc
|
|
|
|
|
|
|
|
|
|
def reset_classifier(self, num_classes, global_pool='avg'):
|
|
|
|
|
self.num_classes = num_classes
|
|
|
|
|
self.head = ClassifierHead(
|
|
|
|
|
self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate, use_conv=True)
|
|
|
|
|
|
|
|
|
@ -393,6 +394,7 @@ class ResNetV2(nn.Module):
|
|
|
|
|
self.stem.conv.weight.copy_(stem_conv_w)
|
|
|
|
|
self.norm.weight.copy_(tf2th(weights[f'{prefix}group_norm/gamma']))
|
|
|
|
|
self.norm.bias.copy_(tf2th(weights[f'{prefix}group_norm/beta']))
|
|
|
|
|
if self.head.fc.weight.shape[0] == weights[f'{prefix}head/conv2d/kernel'].shape[-1]:
|
|
|
|
|
self.head.fc.weight.copy_(tf2th(weights[f'{prefix}head/conv2d/kernel']))
|
|
|
|
|
self.head.fc.bias.copy_(tf2th(weights[f'{prefix}head/conv2d/bias']))
|
|
|
|
|
for i, (sname, stage) in enumerate(self.stages.named_children()):
|
|
|
|
|