From a21cceb7c4b4e5df2987eccae829faf01c012a57 Mon Sep 17 00:00:00 2001 From: Abeer Alessa Date: Wed, 13 Oct 2021 19:47:33 +0300 Subject: [PATCH] Updated metrics.py script and used it in train.py --- .DS_Store | Bin 0 -> 8196 bytes timm/.DS_Store | Bin 0 -> 6148 bytes timm/utils/__init__.py | 2 +- timm/utils/metrics.py | 48 +++++++++++++++++++++++++++++++++++++++++ train.py | 48 ++++++++++++++++++++--------------------- 5 files changed, 73 insertions(+), 25 deletions(-) create mode 100644 .DS_Store create mode 100644 timm/.DS_Store diff --git a/.DS_Store b/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..cd3aeda8f17c4e366cc0086feefd273a95b3e9ce GIT binary patch literal 8196 zcmeHM&2G~`5S~p_;-rN}P}E9D$On)(v{hRuJs_m?CrFh5AyUBsP-{1CW7Ua6;xv>} zlnW052M!#Ncod$12jKzWo3-oO8B|z4`*tG!YqyEX!j^ zG6L6g916LpK;}^_1M>3KAJXIk$Vwy!p+pIS2%(T6R76({5#1a_ z2uFP-(jSB}gp-gh!%oz)h;Aq%x_Agvs*|XMQm2{$&A>1Ne0R^&9PN=yfpUMphq_z8 zWRSp1CwgR2fL{j~n|jDO)K#$G6l@{JHh{3x+G0x4F2&j|de$NzZ8-|&gMzt?UL8Pt zi*~4i9_`RRwSjbj52y=^Z8}tJ{8Vg=!ygCwY@rEz9oWdg5i|2O!ORsPJRmJ#v)j+I z!W<>76kx@Ng#%dXsG}h{N@`D19lmvyaj2tx9vMseJ%{>+;xqG>w3CM&A1x~*D43h1 zUJ0CFF)q_F`B16Y}b0;=n`-SD=TqL2Fc{$5f-L;sIt$N_&hR(F!y- z=m8Kl+C&=|bp-{#Q>YR~c>lzAAP({DDHw2&!c0KN2gbo@&r07o|0WyFvqXGJuu;U= z`f|0{kUm;hI49;Ijd%)=R%jg_R?$k6>PiopizMPIuM~QXk#c3OiKlS>&nA2g&}U_} z3+5+j>DnJGowC5I>x6?WtIS~1YWYm7b6{quSIl9~OmJ9`_%8hOpr65CO1p^p??A7I zwLgcMyGwh+{G|{^3NgH6fTfiFs2opTg-Ty|{iHPW&~Y(>&2$f>ho2K3#(H+Y+m61b zez)76+YRRDe~8>je)K}YFiJ-0V(C?_;rDBS74*8RR_6sjJ+Xp@9pyKD-;48>{mgax zTjj|$zZF<+%jrn4TX&%H^10ip`~6kF*Q$4=omyC77)7JFRh~XNT3lM1F&F2Tk7vxI z<>mPq^Um$N$Hzru>gK}zjqT>)+oN~K?>}PGKoHSe8cl@Harz0l*ujHZr{(#rJq92e zAn0`C8ls*TM=nX!7&b}^)4;_9wdpJpxyy$j+h^_5LbLWoZxQ=_*eK#1% z`Jf1jS>6w(Z}{m&P>e!Teko=8ryPW@l6XFB-#?LA5hPhWRnAEigYx!ul0|afm$NL& zRh)rtaGaLY8n*Wri^HR%j(6DYEj!+#*Xwq?Cy$>lmn~=S!NIe$^YPV(#r5*zXAB1> zuyrfDXYdBTps*swGCL>c@e&sKtrg&Abo5&V7LIE@J&hAFM zjeQBLfK_0V6yW{ALt^v|mKxR8flOTifDKeDLzzDo*vB*I8!R=V1tzpBP`e6q#Sq#Z z^`7DR21||ForJl32-CALHx!|IM|@AwN%$IVZWXW!tShjsAKQHXKfnF{zwTt8tO8bn z|4ISTI31k!F(h-g)&|FCt%r1u#Kyc*qjEuJj$>8eqj(=l8QMH90DXg{M%2LUhk%m7 KW>$fJs=#mepSB+W literal 0 HcmV?d00001 diff --git a/timm/utils/__init__.py b/timm/utils/__init__.py index 11de9c9c..55045ae6 100644 --- a/timm/utils/__init__.py +++ b/timm/utils/__init__.py @@ -5,7 +5,7 @@ from .cuda import ApexScaler, NativeScaler from .distributed import distribute_bn, reduce_tensor from .jit import set_jit_legacy from .log import setup_default_logging, FormatterNoInfo -from .metrics import AverageMeter, accuracy +from .metrics import AverageMeter, accuracy , precision, recall, f1_score from .misc import natural_key, add_bool_arg from .model import unwrap_model, get_state_dict, freeze, unfreeze from .model_ema import ModelEma, ModelEmaV2 diff --git a/timm/utils/metrics.py b/timm/utils/metrics.py index 9fdbe13e..3a7ae510 100644 --- a/timm/utils/metrics.py +++ b/timm/utils/metrics.py @@ -3,6 +3,11 @@ Hacked together by / Copyright 2020 Ross Wightman """ +import torch +from sklearn.metrics import precision_score,accuracy_score ,recall_score ,log_loss,f1_score,confusion_matrix +import torch.nn.functional as F +import warnings +warnings.filterwarnings('always') # "error", "ignore", "always", "default", "module" or "once" class AverageMeter: """Computes and stores the average and current value""" @@ -30,3 +35,46 @@ def accuracy(output, target, topk=(1,)): pred = pred.t() correct = pred.eq(target.reshape(1, -1).expand_as(pred)) return [correct[:min(k, maxk)].reshape(-1).float().sum(0) * 100. / batch_size for k in topk] +def precision(output, target): + _, y_pred = output.topk(1, 1, True, True) + y_true = target + y_true=y_true.tolist() + y_pred=y_pred.tolist() + y_pred=sum(y_pred, []) + TP=0+0.00000000000000009 + FP=0 + for i in range(len(y_true)): + if y_true[i]==1 and y_pred[i]==1: + TP=TP+1 + if y_true[i]==0 and y_pred[i]==1: + FP=FP+1 + precision=precision_score(y_true,y_pred) + + return precision*100 + + +def recall(output, target): + y_pred = torch.ge(output, 0.0) + _, y_pred = output.topk(1, 1, True, True) + y_true = target + true_positive = len((y_true.flatten() == y_pred.flatten()).nonzero().flatten()) + y_true=y_true.tolist() + y_pred=y_pred.tolist() + y_pred=sum(y_pred, []) + TP=0+0.00000000000000009 + FN=0 + for i in range(len(y_true)): + if y_true[i]==1 and y_pred[i]==1: + TP=TP+1 + if y_true[i]==1 and y_pred[i]==0: + FN=FN+1 + recall =recall_score(y_true,y_pred) + return recall*100 + +def f1_scor(output, target): + y_pred = torch.ge(output, 0.5) + _, y_pred = output.topk(1, 1, True, True) + y_true = target + y_true=y_true.tolist() + y_pred=y_pred.tolist() + return f1_score(y_true,y_pred)*100 diff --git a/train.py b/train.py index 332dec0c..2954578b 100755 --- a/train.py +++ b/train.py @@ -752,15 +752,16 @@ def train_one_epoch( return OrderedDict([('loss', losses_m.avg)]) - def validate(model, loader, loss_fn, args, amp_autocast=suppress, log_suffix=''): + batch_time_m = AverageMeter() losses_m = AverageMeter() top1_m = AverageMeter() - top5_m = AverageMeter() - + top5_ = AverageMeter() + top_p = AverageMeter() + top_r = AverageMeter() + top_f = AverageMeter() model.eval() - end = time.time() last_idx = len(loader) - 1 with torch.no_grad(): @@ -771,35 +772,35 @@ def validate(model, loader, loss_fn, args, amp_autocast=suppress, log_suffix='') target = target.cuda() if args.channels_last: input = input.contiguous(memory_format=torch.channels_last) - with amp_autocast(): output = model(input) if isinstance(output, (tuple, list)): output = output[0] - # augmentation reduction reduce_factor = args.tta if reduce_factor > 1: output = output.unfold(0, reduce_factor, reduce_factor).mean(dim=2) - target = target[0:target.size(0):reduce_factor] - + target = target[0:target.size(0):reduce_factor] loss = loss_fn(output, target) - acc1, acc5 = accuracy(output, target, topk=(1, 5)) - +# acc1 = accuracy(output, target, topk=(1,1)) + acc1, acc5 = accuracy(output, target, topk=(1, k)) + f1 = f1_scor(output, target) + prec = precision(output.detach(), target) + rec = recall(output.detach(), target) if args.distributed: reduced_loss = reduce_tensor(loss.data, args.world_size) acc1 = reduce_tensor(acc1, args.world_size) - acc5 = reduce_tensor(acc5, args.world_size) + # acc5 = reduce_tensor(acc5, args.world_size) else: reduced_loss = loss.data - torch.cuda.synchronize() - - losses_m.update(reduced_loss.item(), input.size(0)) - top1_m.update(acc1.item(), output.size(0)) - top5_m.update(acc5.item(), output.size(0)) - - batch_time_m.update(time.time() - end) + losses_m.update(reduced_loss.item(),'acc', input.size(0)) + top1_m.update(acc1.item(),'acc', output.size(0)) + top_p.update(prec,'prec', input.size(0)) + top_r.update(rec,'rec', input.size(0)) + top_f.update(f1,'f1', input.size(0)) + # top5_m.update(acc5.item(), output.size(0)) + batch_time_m.update(time.time() - end,'acc') end = time.time() if args.local_rank == 0 and (last_batch or batch_idx % args.log_interval == 0): log_name = 'Test' + log_suffix @@ -808,13 +809,12 @@ def validate(model, loader, loss_fn, args, amp_autocast=suppress, log_suffix='') 'Time: {batch_time.val:.3f} ({batch_time.avg:.3f}) ' 'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f}) ' 'Acc@1: {top1.val:>7.4f} ({top1.avg:>7.4f}) ' - 'Acc@5: {top5.val:>7.4f} ({top5.avg:>7.4f})'.format( + 'Precision: {top_p.val:>7.3f} ({top_p.avg_pre:>7.3f}) ' + 'Recall: {top_r.val:>7.3f} ({top_r.avg_rec:>7.3f})' + 'F1: {top_f.val:>7.3f} ({top_f.avg_f1:>7.3f})'.format( log_name, batch_idx, last_idx, batch_time=batch_time_m, - loss=losses_m, top1=top1_m, top5=top5_m)) - - metrics = OrderedDict([('loss', losses_m.avg), ('top1', top1_m.avg), ('top5', top5_m.avg)]) - - return metrics + loss=losses_m, top1=top1_m,top_p=top_p,top_r=top_r,top_f=top_f)) + metrics = OrderedDict([('loss', losses_m.avg), ('top1', top1_m.avg), ('f1', top_f.avg_f1), ('Precision',top_p.avg_pre), ('Recall',top_r.avg_rec)]) if __name__ == '__main__':