|
|
|
@ -85,7 +85,7 @@ default_cfgs = dict(
|
|
|
|
|
# Mixer ImageNet-21K-P pretraining
|
|
|
|
|
mixer_b16_224_miil_in21k=_cfg(
|
|
|
|
|
url='https://miil-public-eu.oss-eu-central-1.aliyuncs.com/model-zoo/ImageNet_21K_P/models/timm/mixer_b16_224_miil_in21k.pth',
|
|
|
|
|
mean=(0, 0, 0), std=(1, 1, 1), crop_pct=0.875, interpolation='bilinear', num_classes=2,
|
|
|
|
|
mean=(0, 0, 0), std=(1, 1, 1), crop_pct=0.875, interpolation='bilinear', num_classes=1, #11221
|
|
|
|
|
),
|
|
|
|
|
mixer_b16_224_miil=_cfg(
|
|
|
|
|
url='https://miil-public-eu.oss-eu-central-1.aliyuncs.com/model-zoo/ImageNet_21K_P/models/timm/mixer_b16_224_miil.pth',
|
|
|
|
@ -287,8 +287,23 @@ class MlpMixer(nn.Module):
|
|
|
|
|
"""
|
|
|
|
|
self.norm = norm_layer(embed_dim)
|
|
|
|
|
self.head = nn.Linear(embed_dim, self.num_classes) if num_classes > 0 else nn.Identity()
|
|
|
|
|
|
|
|
|
|
self.init_weights(nlhb=nlhb)
|
|
|
|
|
#self.head = nn.Sequential(
|
|
|
|
|
# nn.Linear(embed_dim, self.num_classes),
|
|
|
|
|
# nn.ReLU(),
|
|
|
|
|
# nn.Dropout(p=0.3),
|
|
|
|
|
# nn.Linear(self.num_classes, 1024),
|
|
|
|
|
# nn.ReLU(),
|
|
|
|
|
# nn.Dropout(p=0.3),
|
|
|
|
|
# nn.Linear(1024, 512),
|
|
|
|
|
# nn.ReLU(),
|
|
|
|
|
# nn.Dropout(p=0.3),
|
|
|
|
|
# nn.Linear(512, 256),
|
|
|
|
|
# nn.ReLU(),
|
|
|
|
|
# nn.Dropout(p=0.3),
|
|
|
|
|
# nn.Linear(256, 1)
|
|
|
|
|
# )
|
|
|
|
|
self.sigmoid = nn.Sigmoid()
|
|
|
|
|
#self.init_weights(nlhb=nlhb)
|
|
|
|
|
|
|
|
|
|
def init_weights(self, nlhb=False):
|
|
|
|
|
head_bias = -math.log(self.num_classes) if nlhb else 0.
|
|
|
|
@ -304,17 +319,21 @@ class MlpMixer(nn.Module):
|
|
|
|
|
def forward_features(self, x):
|
|
|
|
|
#x = self.stem(x)
|
|
|
|
|
#print(x.shape)
|
|
|
|
|
print("In_Model")
|
|
|
|
|
x = self.blocks(x)
|
|
|
|
|
print(x)
|
|
|
|
|
x = self.norm(x)
|
|
|
|
|
print(x)
|
|
|
|
|
x = x.mean(dim=1)
|
|
|
|
|
print(x)
|
|
|
|
|
return x
|
|
|
|
|
|
|
|
|
|
def forward(self, x):
|
|
|
|
|
x = self.forward_features(x)
|
|
|
|
|
x = self.head(x)
|
|
|
|
|
#print("In_model")
|
|
|
|
|
#print(x.shape)
|
|
|
|
|
#print(x)
|
|
|
|
|
print(x)
|
|
|
|
|
x = self.sigmoid(x)
|
|
|
|
|
print(x)
|
|
|
|
|
return x
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|