Updated metrics.py with precision, recall, and F1 helpers and reported them from validate() in train.py

pull/914/head
Abeer Alessa 4 years ago
parent 02daf2ab94
commit a21cceb7c4

BIN  .DS_Store (binary file not shown)

BIN  timm/.DS_Store (binary file not shown)

@@ -5,7 +5,7 @@ from .cuda import ApexScaler, NativeScaler
 from .distributed import distribute_bn, reduce_tensor
 from .jit import set_jit_legacy
 from .log import setup_default_logging, FormatterNoInfo
-from .metrics import AverageMeter, accuracy
+from .metrics import AverageMeter, accuracy, precision, recall, f1_score
 from .misc import natural_key, add_bool_arg
 from .model import unwrap_model, get_state_dict, freeze, unfreeze
 from .model_ema import ModelEma, ModelEmaV2

@@ -3,6 +3,11 @@
 Hacked together by / Copyright 2020 Ross Wightman
 """
 import torch
+from sklearn.metrics import precision_score, recall_score, f1_score as sk_f1_score
+import warnings
+
+# Surface sklearn metric warnings (e.g. zero-division) every time they fire.
+warnings.filterwarnings('always')  # "error", "ignore", "always", "default", "module" or "once"
 
 class AverageMeter:
     """Computes and stores the average and current value"""
@@ -30,3 +35,21 @@ def accuracy(output, target, topk=(1,)):
     pred = pred.t()
     correct = pred.eq(target.reshape(1, -1).expand_as(pred))
     return [correct[:min(k, maxk)].reshape(-1).float().sum(0) * 100. / batch_size for k in topk]
+
+
+def precision(output, target):
+    """Batch precision (percent) of the top-1 predictions. Assumes binary labels."""
+    _, y_pred = output.topk(1, dim=1)
+    return precision_score(target.flatten().tolist(), y_pred.flatten().tolist()) * 100.
+
+
+def recall(output, target):
+    """Batch recall (percent) of the top-1 predictions. Assumes binary labels."""
+    _, y_pred = output.topk(1, dim=1)
+    return recall_score(target.flatten().tolist(), y_pred.flatten().tolist()) * 100.
+
+
+def f1_score(output, target):
+    """Batch F1 score (percent) of the top-1 predictions. Assumes binary labels."""
+    _, y_pred = output.topk(1, dim=1)
+    return sk_f1_score(target.flatten().tolist(), y_pred.flatten().tolist()) * 100.

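A quick sanity check of the new helpers (a minimal sketch, not part of the commit; assumes the binary-label behavior above and the `timm.utils` exports from the first hunk):

    import torch
    from timm.utils import precision, recall, f1_score

    # Dummy binary batch: 4 samples, 2 classes, logits plus integer targets.
    logits = torch.tensor([[0.2, 0.8],
                           [0.9, 0.1],
                           [0.3, 0.7],
                           [0.6, 0.4]])
    target = torch.tensor([1, 0, 0, 1])

    # Top-1 predictions are [1, 0, 1, 0]: 1 TP, 1 FP, 1 FN,
    # so precision, recall, and F1 all come out to 50.0 (percent).
    print(precision(logits, target), recall(logits, target), f1_score(logits, target))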
@@ -752,15 +752,16 @@ def train_one_epoch(
     return OrderedDict([('loss', losses_m.avg)])
 
 
 def validate(model, loader, loss_fn, args, amp_autocast=suppress, log_suffix=''):
     batch_time_m = AverageMeter()
     losses_m = AverageMeter()
     top1_m = AverageMeter()
-    top5_m = AverageMeter()
+    top_p = AverageMeter()
+    top_r = AverageMeter()
+    top_f = AverageMeter()
 
     model.eval()
 
     end = time.time()
     last_idx = len(loader) - 1
     with torch.no_grad():
@@ -771,35 +772,35 @@ def validate(model, loader, loss_fn, args, amp_autocast=suppress, log_suffix=''):
                 target = target.cuda()
             if args.channels_last:
                 input = input.contiguous(memory_format=torch.channels_last)
 
             with amp_autocast():
                 output = model(input)
             if isinstance(output, (tuple, list)):
                 output = output[0]
 
             # augmentation reduction
             reduce_factor = args.tta
             if reduce_factor > 1:
                 output = output.unfold(0, reduce_factor, reduce_factor).mean(dim=2)
                 target = target[0:target.size(0):reduce_factor]
 
             loss = loss_fn(output, target)
-            acc1, acc5 = accuracy(output, target, topk=(1, 5))
+            acc1 = accuracy(output, target, topk=(1,))[0]
+            # precision/recall/f1 are computed on this rank's batch only (not all-reduced).
+            f1 = f1_score(output, target)
+            prec = precision(output, target)
+            rec = recall(output, target)
 
             if args.distributed:
                 reduced_loss = reduce_tensor(loss.data, args.world_size)
                 acc1 = reduce_tensor(acc1, args.world_size)
-                acc5 = reduce_tensor(acc5, args.world_size)
             else:
                 reduced_loss = loss.data
 
             torch.cuda.synchronize()
 
             losses_m.update(reduced_loss.item(), input.size(0))
             top1_m.update(acc1.item(), output.size(0))
-            top5_m.update(acc5.item(), output.size(0))
+            top_p.update(prec, input.size(0))
+            top_r.update(rec, input.size(0))
+            top_f.update(f1, input.size(0))
 
             batch_time_m.update(time.time() - end)
             end = time.time()
             if args.local_rank == 0 and (last_batch or batch_idx % args.log_interval == 0):
                 log_name = 'Test' + log_suffix
@@ -808,13 +809,12 @@ def validate(model, loader, loss_fn, args, amp_autocast=suppress, log_suffix=''):
                     'Time: {batch_time.val:.3f} ({batch_time.avg:.3f}) '
                     'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f}) '
                     'Acc@1: {top1.val:>7.4f} ({top1.avg:>7.4f}) '
-                    'Acc@5: {top5.val:>7.4f} ({top5.avg:>7.4f})'.format(
+                    'Precision: {top_p.val:>7.3f} ({top_p.avg:>7.3f}) '
+                    'Recall: {top_r.val:>7.3f} ({top_r.avg:>7.3f}) '
+                    'F1: {top_f.val:>7.3f} ({top_f.avg:>7.3f})'.format(
                         log_name, batch_idx, last_idx, batch_time=batch_time_m,
-                        loss=losses_m, top1=top1_m, top5=top5_m))
+                        loss=losses_m, top1=top1_m, top_p=top_p, top_r=top_r, top_f=top_f))
 
-    metrics = OrderedDict([('loss', losses_m.avg), ('top1', top1_m.avg), ('top5', top5_m.avg)])
+    metrics = OrderedDict([
+        ('loss', losses_m.avg), ('top1', top1_m.avg),
+        ('f1', top_f.avg), ('precision', top_p.avg), ('recall', top_r.avg)])
 
     return metrics
 
 
 if __name__ == '__main__':
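With `return metrics` restored, callers can read the new keys from the returned OrderedDict. A minimal sketch of a call site (the names `loader_eval` and `validate_loss_fn` are assumptions based on the rest of train.py):

    eval_metrics = validate(model, loader_eval, validate_loss_fn, args)
    print('top1={top1:.2f} precision={precision:.2f} '
          'recall={recall:.2f} f1={f1:.2f}'.format(**eval_metrics))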
