Update ClassifierHead module, add reset() method, update in_chs -> in_features for consistency

pull/1641/head
Ross Wightman 2 years ago
parent 8ab573cd26
commit ca38e1e73f

@ -38,13 +38,24 @@ def create_classifier(num_features, num_classes, pool_type='avg', use_conv=False
class ClassifierHead(nn.Module): class ClassifierHead(nn.Module):
"""Classifier head w/ configurable global pooling and dropout.""" """Classifier head w/ configurable global pooling and dropout."""
def __init__(self, in_chs, num_classes, pool_type='avg', drop_rate=0., use_conv=False): def __init__(self, in_features, num_classes, pool_type='avg', drop_rate=0., use_conv=False):
super(ClassifierHead, self).__init__() super(ClassifierHead, self).__init__()
self.drop_rate = drop_rate self.drop_rate = drop_rate
self.global_pool, num_pooled_features = _create_pool(in_chs, num_classes, pool_type, use_conv=use_conv) self.in_features = in_features
self.use_conv = use_conv
self.global_pool, num_pooled_features = _create_pool(in_features, num_classes, pool_type, use_conv=use_conv)
self.fc = _create_fc(num_pooled_features, num_classes, use_conv=use_conv) self.fc = _create_fc(num_pooled_features, num_classes, use_conv=use_conv)
self.flatten = nn.Flatten(1) if use_conv and pool_type else nn.Identity() self.flatten = nn.Flatten(1) if use_conv and pool_type else nn.Identity()
def reset(self, num_classes, global_pool=None):
if global_pool is not None:
if global_pool != self.global_pool.pool_type:
self.global_pool, _ = _create_pool(self.in_features, num_classes, global_pool, use_conv=self.use_conv)
self.flatten = nn.Flatten(1) if self.use_conv and global_pool else nn.Identity()
num_pooled_features = self.in_features * self.global_pool.feat_mult()
self.fc = _create_fc(num_pooled_features, num_classes, use_conv=self.use_conv)
def forward(self, x, pre_logits: bool = False): def forward(self, x, pre_logits: bool = False):
x = self.global_pool(x) x = self.global_pool(x)
if self.drop_rate: if self.drop_rate:

@ -913,7 +913,7 @@ class CspNet(nn.Module):
# Construct the head # Construct the head
self.num_features = prev_chs self.num_features = prev_chs
self.head = ClassifierHead( self.head = ClassifierHead(
in_chs=prev_chs, num_classes=num_classes, pool_type=global_pool, drop_rate=drop_rate) in_features=prev_chs, num_classes=num_classes, pool_type=global_pool, drop_rate=drop_rate)
named_apply(partial(_init_weights, zero_init_last=zero_init_last), self) named_apply(partial(_init_weights, zero_init_last=zero_init_last), self)

@ -496,7 +496,7 @@ class RegNet(nn.Module):
self.final_conv = get_act_layer(cfg.act_layer)() if final_act else nn.Identity() self.final_conv = get_act_layer(cfg.act_layer)() if final_act else nn.Identity()
self.num_features = prev_width self.num_features = prev_width
self.head = ClassifierHead( self.head = ClassifierHead(
in_chs=self.num_features, num_classes=num_classes, pool_type=global_pool, drop_rate=drop_rate) in_features=self.num_features, num_classes=num_classes, pool_type=global_pool, drop_rate=drop_rate)
named_apply(partial(_init_weights, zero_init_last=zero_init_last), self) named_apply(partial(_init_weights, zero_init_last=zero_init_last), self)

@ -216,7 +216,7 @@ class XceptionAligned(nn.Module):
num_chs=self.num_features, reduction=curr_stride, module='blocks.' + str(len(self.blocks) - 1))] num_chs=self.num_features, reduction=curr_stride, module='blocks.' + str(len(self.blocks) - 1))]
self.act = act_layer(inplace=True) if preact else nn.Identity() self.act = act_layer(inplace=True) if preact else nn.Identity()
self.head = ClassifierHead( self.head = ClassifierHead(
in_chs=self.num_features, num_classes=num_classes, pool_type=global_pool, drop_rate=drop_rate) in_features=self.num_features, num_classes=num_classes, pool_type=global_pool, drop_rate=drop_rate)
@torch.jit.ignore @torch.jit.ignore
def group_matcher(self, coarse=False): def group_matcher(self, coarse=False):

Loading…
Cancel
Save