|
|
|
@ -515,52 +515,6 @@ def create_block(block: Union[str, nn.Module], **kwargs):
|
|
|
|
|
return _block_registry[block](**kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# class Stem(nn.Module):
|
|
|
|
|
#
|
|
|
|
|
# def __init__(self, in_chs, out_chs, kernel_size=3, stride=4, pool='maxpool',
|
|
|
|
|
# num_rep=3, num_act=None, chs_decay=0.5, layers: LayerFn = None):
|
|
|
|
|
# super().__init__()
|
|
|
|
|
# assert stride in (2, 4)
|
|
|
|
|
# if pool:
|
|
|
|
|
# assert stride == 4
|
|
|
|
|
# layers = layers or LayerFn()
|
|
|
|
|
#
|
|
|
|
|
# if isinstance(out_chs, (list, tuple)):
|
|
|
|
|
# num_rep = len(out_chs)
|
|
|
|
|
# stem_chs = out_chs
|
|
|
|
|
# else:
|
|
|
|
|
# stem_chs = [round(out_chs * chs_decay ** i) for i in range(num_rep)][::-1]
|
|
|
|
|
#
|
|
|
|
|
# self.stride = stride
|
|
|
|
|
# stem_strides = [2] + [1] * (num_rep - 1)
|
|
|
|
|
# if stride == 4 and not pool:
|
|
|
|
|
# # set last conv in stack to be strided if stride == 4 and no pooling layer
|
|
|
|
|
# stem_strides[-1] = 2
|
|
|
|
|
#
|
|
|
|
|
# num_act = num_rep if num_act is None else num_act
|
|
|
|
|
# # if num_act < num_rep, first convs in stack won't have bn + act
|
|
|
|
|
# stem_norm_acts = [False] * (num_rep - num_act) + [True] * num_act
|
|
|
|
|
# prev_chs = in_chs
|
|
|
|
|
# convs = []
|
|
|
|
|
# for i, (ch, s, na) in enumerate(zip(stem_chs, stem_strides, stem_norm_acts)):
|
|
|
|
|
# layer_fn = layers.conv_norm_act if na else create_conv2d
|
|
|
|
|
# convs.append(layer_fn(prev_chs, ch, kernel_size=kernel_size, stride=s))
|
|
|
|
|
# prev_chs = ch
|
|
|
|
|
# self.conv = nn.Sequential(*convs) if len(convs) > 1 else convs[0]
|
|
|
|
|
#
|
|
|
|
|
# if not pool:
|
|
|
|
|
# self.pool = nn.Identity()
|
|
|
|
|
# elif 'max' in pool.lower():
|
|
|
|
|
# self.pool = nn.MaxPool2d(3, 2, 1) if pool else nn.Identity()
|
|
|
|
|
# else:
|
|
|
|
|
# assert False, "Unknown pooling type"
|
|
|
|
|
#
|
|
|
|
|
# def forward(self, x):
|
|
|
|
|
# x = self.conv(x)
|
|
|
|
|
# x = self.pool(x)
|
|
|
|
|
# return x
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Stem(nn.Sequential):
|
|
|
|
|
|
|
|
|
|
def __init__(self, in_chs, out_chs, kernel_size=3, stride=4, pool='maxpool',
|
|
|
|
|