From ca38e1e73fe4f318e153f2cb66efe41de4b7cf5f Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Fri, 20 Jan 2023 14:44:05 -0800
Subject: [PATCH] Update ClassifierHead module, add reset() method, update
 in_chs -> in_features for consistency

---
 timm/layers/classifier.py       | 15 +++++++++++++--
 timm/models/cspnet.py           |  2 +-
 timm/models/regnet.py           |  2 +-
 timm/models/xception_aligned.py |  2 +-
 4 files changed, 16 insertions(+), 5 deletions(-)

diff --git a/timm/layers/classifier.py b/timm/layers/classifier.py
index 3ac33387..e885084c 100644
--- a/timm/layers/classifier.py
+++ b/timm/layers/classifier.py
@@ -38,13 +38,24 @@ def create_classifier(num_features, num_classes, pool_type='avg', use_conv=False
 class ClassifierHead(nn.Module):
     """Classifier head w/ configurable global pooling and dropout."""
 
-    def __init__(self, in_chs, num_classes, pool_type='avg', drop_rate=0., use_conv=False):
+    def __init__(self, in_features, num_classes, pool_type='avg', drop_rate=0., use_conv=False):
         super(ClassifierHead, self).__init__()
         self.drop_rate = drop_rate
-        self.global_pool, num_pooled_features = _create_pool(in_chs, num_classes, pool_type, use_conv=use_conv)
+        self.in_features = in_features
+        self.use_conv = use_conv
+
+        self.global_pool, num_pooled_features = _create_pool(in_features, num_classes, pool_type, use_conv=use_conv)
         self.fc = _create_fc(num_pooled_features, num_classes, use_conv=use_conv)
         self.flatten = nn.Flatten(1) if use_conv and pool_type else nn.Identity()
 
+    def reset(self, num_classes, global_pool=None):
+        if global_pool is not None:
+            if global_pool != self.global_pool.pool_type:
+                self.global_pool, _ = _create_pool(self.in_features, num_classes, global_pool, use_conv=self.use_conv)
+            self.flatten = nn.Flatten(1) if self.use_conv and global_pool else nn.Identity()
+        num_pooled_features = self.in_features * self.global_pool.feat_mult()
+        self.fc = _create_fc(num_pooled_features, num_classes, use_conv=self.use_conv)
+
     def forward(self, x, pre_logits: bool = False):
         x = self.global_pool(x)
         if self.drop_rate:
diff --git a/timm/models/cspnet.py b/timm/models/cspnet.py
index 26ec54d9..da9d1ae0 100644
--- a/timm/models/cspnet.py
+++ b/timm/models/cspnet.py
@@ -913,7 +913,7 @@ class CspNet(nn.Module):
         # Construct the head
         self.num_features = prev_chs
         self.head = ClassifierHead(
-            in_chs=prev_chs, num_classes=num_classes, pool_type=global_pool, drop_rate=drop_rate)
+            in_features=prev_chs, num_classes=num_classes, pool_type=global_pool, drop_rate=drop_rate)
 
         named_apply(partial(_init_weights, zero_init_last=zero_init_last), self)
 
diff --git a/timm/models/regnet.py b/timm/models/regnet.py
index 9d2528f6..63c9b57f 100644
--- a/timm/models/regnet.py
+++ b/timm/models/regnet.py
@@ -496,7 +496,7 @@
         self.final_conv = get_act_layer(cfg.act_layer)() if final_act else nn.Identity()
         self.num_features = prev_width
         self.head = ClassifierHead(
-            in_chs=self.num_features, num_classes=num_classes, pool_type=global_pool, drop_rate=drop_rate)
+            in_features=self.num_features, num_classes=num_classes, pool_type=global_pool, drop_rate=drop_rate)
 
         named_apply(partial(_init_weights, zero_init_last=zero_init_last), self)
 
diff --git a/timm/models/xception_aligned.py b/timm/models/xception_aligned.py
index e3348e64..6bb7085f 100644
--- a/timm/models/xception_aligned.py
+++ b/timm/models/xception_aligned.py
@@ -216,7 +216,7 @@
             num_chs=self.num_features, reduction=curr_stride, module='blocks.' + str(len(self.blocks) - 1))]
         self.act = act_layer(inplace=True) if preact else nn.Identity()
         self.head = ClassifierHead(
-            in_chs=self.num_features, num_classes=num_classes, pool_type=global_pool, drop_rate=drop_rate)
+            in_features=self.num_features, num_classes=num_classes, pool_type=global_pool, drop_rate=drop_rate)
 
     @torch.jit.ignore
     def group_matcher(self, coarse=False):