model_freeze_and_pretraining

pull/1229/head
kira7005 3 years ago
parent 50fd5be983
commit 450b0d57f0

@@ -49,9 +49,9 @@ class TextDataset(Dataset):
         if len(features) == 0:
             print(idx)
             print(VideoPath)
-        features = torch.reshape(features, (16, 256))
+        #features = torch.reshape(features, (16, 256))
         # features = torch.reshape(features, (196, 768))
-        #features = torch.reshape(features, (1, 4096))
+        features = torch.reshape(features, (1, 4096))
         #print(VideoPath)
         if VideoPath.find('Normal') == -1:
             label = 0
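Both layouts hold the same 4096 values per clip; the change only trades the 16 x 256 token view for a single flat 4096-d row, which is what the new initial_fc in MlpMixer consumes. A minimal sketch of the equivalence (tensor names are illustrative, not from the repository):

    import torch

    feat = torch.randn(4096)                        # one clip-level feature vector
    tokens_16x256 = torch.reshape(feat, (16, 256))  # old layout: 16 tokens of width 256
    flat_1x4096 = torch.reshape(feat, (1, 4096))    # new layout: one flat 4096-d row
    assert tokens_16x256.numel() == flat_1x4096.numel() == 4096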

@@ -266,16 +266,18 @@ class MlpMixer(nn.Module):
         self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
         ##initial_fc and stem not needed
-        #self.initial_fc =nn.Linear(4096, 150528)
+        self.initial_fc =nn.Linear(4096, 150528)
         #self.stem = PatchEmbed(
         #    img_size=img_size, patch_size=patch_size, in_chans=in_chans,
         #    embed_dim=embed_dim, norm_layer=norm_layer if stem_norm else None)
         # FIXME drop_path (stochastic depth scaling rule or all the same?)
         #print("num_classes:",self.num_classes, "embed_dim:", embed_dim)
         self.blocks = nn.Sequential(*[
             block_layer(
                 embed_dim
-                ,16 #196 #self.stem.num_patches
+                ,196 #16 #self.stem.num_patches
                 , mlp_ratio, mlp_layer=mlp_layer, norm_layer=norm_layer,
                 act_layer=act_layer, drop=drop_rate, drop_path=drop_path_rate)
             for _ in range(num_blocks)])
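Un-commenting initial_fc projects the flat 4096-d clip feature to 150528 values, which is exactly 196 x 768, the token grid and embedding width the Mixer blocks now expect; that is also why the sequence length passed to block_layer changes from 16 to 196. A quick shape check (a sketch of that arithmetic, not repository code):

    import torch

    # initial_fc maps 4096 -> 150528, and 150528 == 196 * 768,
    # so its output reshapes exactly into the Mixer's 196-token, 768-dim grid.
    assert 196 * 768 == 150528
    tokens = torch.randn(1, 150528).reshape(196, 768)
    print(tokens.shape)   # torch.Size([196, 768])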
@@ -288,21 +290,21 @@ class MlpMixer(nn.Module):
         """
         self.norm = norm_layer(embed_dim)
         self.head = nn.Linear(embed_dim, self.num_classes) if num_classes > 0 else nn.Identity()
-        # self.head = nn.Sequential(
+        self.final_head = nn.Sequential(
             # nn.Linear(embed_dim, self.num_classes),
             # nn.ReLU(),
             # nn.Dropout(p=0.3),
-            # nn.Linear(self.num_classes, 1024),
-            # nn.ReLU(),
-            # nn.Dropout(p=0.3),
-            # nn.Linear(1024, 512),
-            # nn.ReLU(),
-            # nn.Dropout(p=0.3),
-            # nn.Linear(512, 256),
-            # nn.ReLU(),
-            # nn.Dropout(p=0.3),
-            # nn.Linear(256, 2)
-        # )
+            nn.Linear(self.num_classes, 1024),
+            nn.ReLU(),
+            nn.Dropout(p=0.3),
+            nn.Linear(1024, 512),
+            nn.ReLU(),
+            nn.Dropout(p=0.3),
+            nn.Linear(512, 256),
+            nn.ReLU(),
+            nn.Dropout(p=0.3),
+            nn.Linear(256, 2)
+        )
         #self.sigmoid = nn.Sigmoid()
         self.sm = nn.Softmax(dim=1)
         self.init_weights(nlhb=nlhb)
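Because self.head is kept and final_head is applied after it in forward, the effective classifier becomes embed_dim -> num_classes -> 1024 -> 512 -> 256 -> 2. A standalone sketch of that stack (embed_dim=768 and num_classes=1000 are assumed defaults, not values stated in this diff):

    import torch
    import torch.nn as nn

    embed_dim, num_classes = 768, 1000          # assumed for illustration
    head = nn.Linear(embed_dim, num_classes)    # original Mixer head, unchanged
    final_head = nn.Sequential(                 # newly added refinement MLP
        nn.Linear(num_classes, 1024), nn.ReLU(), nn.Dropout(p=0.3),
        nn.Linear(1024, 512), nn.ReLU(), nn.Dropout(p=0.3),
        nn.Linear(512, 256), nn.ReLU(), nn.Dropout(p=0.3),
        nn.Linear(256, 2))                      # binary normal / anomaly logits
    logits = final_head(head(torch.randn(1, embed_dim)))
    print(logits.shape)                         # torch.Size([1, 2])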
@@ -330,10 +332,12 @@ class MlpMixer(nn.Module):
         return x

     def forward(self, x):
-        #x = self.initial_fc(x)
-        #x = torch.reshape(x, (196, 768))
+        x = self.initial_fc(x)
+        x = nn.ReLU(x)
+        x = torch.reshape(x, (196, 768))
         x = self.forward_features(x)
         x = self.head(x)
+        x = self.final_head(x)
         #print(x)
         #x = self.sigmoid(x)
         #print(x)
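In the new forward, x = nn.ReLU(x) constructs an nn.ReLU module (the tensor lands in its inplace argument) instead of applying the activation, so the following torch.reshape receives a module rather than a tensor; the hard-coded reshape to (196, 768) also only works when the batch flattens to exactly 196 x 768 values. A runnable sketch of the intended sequence with those two points addressed (an assumption about intent, not the code in this commit):

    import torch
    import torch.nn.functional as F

    x = torch.randn(2, 150528)      # initial_fc output for a batch of two clips (150528 = 196 * 768)
    x = F.relu(x)                   # functional ReLU instead of constructing nn.ReLU(x)
    x = x.reshape(-1, 196, 768)     # keeps the batch dimension explicit
    print(x.shape)                  # torch.Size([2, 196, 768])

The resulting (B, 196, 768) tokens would then pass through forward_features, head, and final_head as in the diff.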

@@ -668,6 +668,10 @@ def train_one_epoch(
         lr_scheduler=None, saver=None, output_dir=None, amp_autocast=suppress,
         loss_scaler=None, model_ema=None, mixup_fn=None):

+    for param in model.forward_features.parameters():
+        param.requires_grad=False
+    print(model)
+
     if args.mixup_off_epoch and epoch >= args.mixup_off_epoch:
         if args.prefetcher and loader.mixup_enabled:
             loader.mixup_enabled = False
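In timm's MlpMixer, forward_features is a method (the diff itself calls self.forward_features(x)), so model.forward_features.parameters() would raise AttributeError rather than freeze anything; freezing by parameter name is the usual pattern. A sketch of that pattern, assuming the attribute names from this diff (initial_fc, head, and final_head stay trainable; add a module. prefix if the model is wrapped by DataParallel/DDP):

    def freeze_backbone(model, trainable_prefixes=("initial_fc", "head", "final_head")):
        # Freeze every parameter whose name does not start with a trainable prefix.
        for name, param in model.named_parameters():
            param.requires_grad = not name.startswith(trainable_prefixes)

With only the new layers receiving gradients, the Mixer blocks stay fixed while the classifier is pretrained, which matches the commit's model_freeze_and_pretraining intent.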
