From be7c784d21aa50e039ff3466ca6cf7636e742e71 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Mon, 27 Apr 2020 17:50:19 -0700 Subject: [PATCH] Small TResNet simplification, just use SelectAdaptivePool, don't notice any perf difference --- timm/models/tresnet.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/timm/models/tresnet.py b/timm/models/tresnet.py index 84b5cb31..48b3e1de 100644 --- a/timm/models/tresnet.py +++ b/timm/models/tresnet.py @@ -229,9 +229,9 @@ class TResNet(nn.Module): # head self.num_features = (self.planes * 8) * Bottleneck.expansion - self.global_pool = None - self.head = None - self.reset_classifier(num_classes, global_pool) + self.global_pool = SelectAdaptivePool2d(pool_type=global_pool, flatten=True) + self.head = nn.Sequential(OrderedDict([ + ('fc', nn.Linear(self.num_features * self.global_pool.feat_mult(), num_classes))])) # model initilization for m in self.modules(): @@ -273,11 +273,8 @@ class TResNet(nn.Module): return self.head.fc def reset_classifier(self, num_classes, global_pool='avg'): + self.global_pool = SelectAdaptivePool2d(pool_type=global_pool, flatten=True) self.num_classes = num_classes - if global_pool == 'avg': - self.global_pool = FastGlobalAvgPool2d(flatten=True) - else: - self.global_pool = SelectAdaptivePool2d(pool_type=global_pool, flatten=True) self.head = None if num_classes: self.head = nn.Sequential(OrderedDict([