From 592e2984af6a45cf9883b29d12a8f6c888ccabdf Mon Sep 17 00:00:00 2001
From: kira7005
Date: Thu, 14 Apr 2022 18:19:19 +0000
Subject: [PATCH] classes_number

---
 timm/data/bag_sampler.py |  1 +
 timm/loss/mil_ranking.py |  6 +++---
 timm/models/mlp_mixer.py | 31 +++++++++++++++++++++++++------
 3 files changed, 29 insertions(+), 9 deletions(-)

diff --git a/timm/data/bag_sampler.py b/timm/data/bag_sampler.py
index c1c3d8c9..c4f1535e 100644
--- a/timm/data/bag_sampler.py
+++ b/timm/data/bag_sampler.py
@@ -2,6 +2,7 @@ import torch
 import random
 from torch.utils.data.sampler import Sampler
 
+##TODO: Ensure that all videos are covered. Make a copy of indices and remove the indices that were already picked.
 class BagSampler(Sampler):
     def __init__(self, dataset):
         halfway_point = int(len(dataset)/2)
diff --git a/timm/loss/mil_ranking.py b/timm/loss/mil_ranking.py
index 7bdc8ee4..ccdfab78 100644
--- a/timm/loss/mil_ranking.py
+++ b/timm/loss/mil_ranking.py
@@ -9,15 +9,15 @@ class MilRankingLoss(nn.Module):
     #def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
     #    pass
     #def mil_ranking(y_true, y_pred):
-    def forward(self, y_true, y_pred):
+    def forward(self, y_pred, y_true):
         'Custom Objective function'
 
         y_true = torch.flatten(y_true)
         y_pred = torch.flatten(y_pred)
         print("MIL_Ranking")
-        print(y_true)
+        #print(y_true)
         #print(y_true.type)
-        print(y_pred)
+        #print(y_pred)
         #print(y_pred.type)
 
         n_seg = 32  # Because we have 32 segments per video.
diff --git a/timm/models/mlp_mixer.py b/timm/models/mlp_mixer.py
index 9aba4b7e..c008bfb5 100644
--- a/timm/models/mlp_mixer.py
+++ b/timm/models/mlp_mixer.py
@@ -85,7 +85,7 @@ default_cfgs = dict(
     # Mixer ImageNet-21K-P pretraining
     mixer_b16_224_miil_in21k=_cfg(
         url='https://miil-public-eu.oss-eu-central-1.aliyuncs.com/model-zoo/ImageNet_21K_P/models/timm/mixer_b16_224_miil_in21k.pth',
-        mean=(0, 0, 0), std=(1, 1, 1), crop_pct=0.875, interpolation='bilinear', num_classes=2,
+        mean=(0, 0, 0), std=(1, 1, 1), crop_pct=0.875, interpolation='bilinear', num_classes=1, #11221
     ),
     mixer_b16_224_miil=_cfg(
         url='https://miil-public-eu.oss-eu-central-1.aliyuncs.com/model-zoo/ImageNet_21K_P/models/timm/mixer_b16_224_miil.pth',
@@ -287,8 +287,23 @@ class MlpMixer(nn.Module):
         """
         self.norm = norm_layer(embed_dim)
         self.head = nn.Linear(embed_dim, self.num_classes) if num_classes > 0 else nn.Identity()
-
-        self.init_weights(nlhb=nlhb)
+        #self.head = nn.Sequential(
+        #    nn.Linear(embed_dim, self.num_classes),
+        #    nn.ReLU(),
+        #    nn.Dropout(p=0.3),
+        #    nn.Linear(self.num_classes, 1024),
+        #    nn.ReLU(),
+        #    nn.Dropout(p=0.3),
+        #    nn.Linear(1024, 512),
+        #    nn.ReLU(),
+        #    nn.Dropout(p=0.3),
+        #    nn.Linear(512, 256),
+        #    nn.ReLU(),
+        #    nn.Dropout(p=0.3),
+        #    nn.Linear(256, 1)
+        #    )
+        self.sigmoid = nn.Sigmoid()
+        #self.init_weights(nlhb=nlhb)
 
     def init_weights(self, nlhb=False):
         head_bias = -math.log(self.num_classes) if nlhb else 0.
@@ -304,17 +319,21 @@ class MlpMixer(nn.Module):
     def forward_features(self, x):
         #x = self.stem(x)
         #print(x.shape)
+        print("In_Model")
         x = self.blocks(x)
+        print(x)
         x = self.norm(x)
+        print(x)
         x = x.mean(dim=1)
+        print(x)
         return x
 
     def forward(self, x):
         x = self.forward_features(x)
         x = self.head(x)
-        #print("In_model")
-        #print(x.shape)
-        #print(x)
+        print(x)
+        x = self.sigmoid(x)
+        print(x)
         return x
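
Notes on the changes above (not part of the diff).

First, the new TODO in bag_sampler.py: below is a minimal sketch of the index-copying idea it describes, assuming (as the halfway_point logic in BagSampler suggests) that the dataset splits into two equal groups. The class name ExhaustiveBagSampler, the bag_size parameter, and the group semantics are illustrative, not part of the patch.

# Sketch only: copy the index pools and remove indices as they are drawn,
# so every video is covered before any index repeats.
import random
from torch.utils.data.sampler import Sampler


class ExhaustiveBagSampler(Sampler):
    def __init__(self, dataset, bag_size=30):
        halfway_point = int(len(dataset) / 2)
        self.first_half = list(range(halfway_point))                  # e.g. anomalous videos
        self.second_half = list(range(halfway_point, len(dataset)))   # e.g. normal videos
        self.bag_size = bag_size

    def __iter__(self):
        # Work on copies so the full pools are restored every epoch.
        remaining_a = self.first_half.copy()
        remaining_b = self.second_half.copy()
        while remaining_a and remaining_b:
            bag = []
            for pool in (remaining_a, remaining_b):
                picked = random.sample(pool, min(self.bag_size, len(pool)))
                for idx in picked:
                    pool.remove(idx)  # already-picked indices are not reused
                bag.extend(picked)
            yield bag

    def __len__(self):
        shorter = min(len(self.first_half), len(self.second_half))
        return (shorter + self.bag_size - 1) // self.bag_size

Used as a batch_sampler, each yielded list forms one bag of indices; the real BagSampler in this repo may assemble bags differently.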
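
Second, the forward(self, y_pred, y_true) swap in mil_ranking.py: timm's training loop calls the criterion as loss_fn(output, target), so the model output must be the first positional argument. A call-order sketch follows, assuming MilRankingLoss takes no constructor arguments; the tensor shapes only illustrate the 32-segments-per-video layout and are not taken from the patch.

# Illustrative call only: output (y_pred) first, labels (y_true) second.
import torch
from timm.loss.mil_ranking import MilRankingLoss

criterion = MilRankingLoss()
scores = torch.rand(60 * 32, 1)                                   # hypothetical per-segment scores in [0, 1]
labels = torch.cat([torch.ones(30 * 32), torch.zeros(30 * 32)])   # hypothetical bag labels
loss = criterion(scores, labels)                                  # matches loss_fn(output, target)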
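
Finally, the mlp_mixer.py change: with num_classes=1 the classifier head emits a single logit per segment embedding, and the newly added nn.Sigmoid squashes it into a [0, 1] score, the range the MIL ranking loss operates on. Standalone sketch below; the 768-dim embedding and batch of 32 segments are assumed values, not taken from the patch.

# Sketch of the new head + sigmoid path, isolated from the full MlpMixer.
import torch
import torch.nn as nn

embed_dim, num_classes = 768, 1         # assumed embedding width; num_classes=1 as in the patch
head = nn.Linear(embed_dim, num_classes)
sigmoid = nn.Sigmoid()

features = torch.randn(32, embed_dim)   # 32 segment embeddings after forward_features
scores = sigmoid(head(features))        # shape (32, 1), every value in [0, 1]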