Small TResNet simplification, just use SelectAdaptivePool, don't notice any perf difference

pull/136/head
Ross Wightman 4 years ago
parent e3a98171b2
commit be7c784d21

@ -229,9 +229,9 @@ class TResNet(nn.Module):
# head # head
self.num_features = (self.planes * 8) * Bottleneck.expansion self.num_features = (self.planes * 8) * Bottleneck.expansion
self.global_pool = None self.global_pool = SelectAdaptivePool2d(pool_type=global_pool, flatten=True)
self.head = None self.head = nn.Sequential(OrderedDict([
self.reset_classifier(num_classes, global_pool) ('fc', nn.Linear(self.num_features * self.global_pool.feat_mult(), num_classes))]))
# model initilization # model initilization
for m in self.modules(): for m in self.modules():
@ -273,11 +273,8 @@ class TResNet(nn.Module):
return self.head.fc return self.head.fc
def reset_classifier(self, num_classes, global_pool='avg'): def reset_classifier(self, num_classes, global_pool='avg'):
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool, flatten=True)
self.num_classes = num_classes self.num_classes = num_classes
if global_pool == 'avg':
self.global_pool = FastGlobalAvgPool2d(flatten=True)
else:
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool, flatten=True)
self.head = None self.head = None
if num_classes: if num_classes:
self.head = nn.Sequential(OrderedDict([ self.head = nn.Sequential(OrderedDict([

Loading…
Cancel
Save