""" Distributed training/validation utils
|
|
|
|
Hacked together by / Copyright 2020 Ross Wightman
|
|
"""
|
|
import torch
|
|
from torch import distributed as dist
|
|
|
|
from .model import unwrap_model
|
|
|
|
|
|
def reduce_tensor(tensor, n):
    # sum the tensor across all ranks, then average over n processes
    rt = tensor.clone()
    dist.all_reduce(rt, op=dist.ReduceOp.SUM)
    rt /= n
    return rt
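

# Example usage for reduce_tensor (a minimal sketch, assuming the default
# process group has been initialized via dist.init_process_group() and that
# `loss` is a scalar tensor computed on this rank):
#
#   world_size = dist.get_world_size()
#   avg_loss = reduce_tensor(loss.detach(), world_size)
#
# Every rank receives the same averaged value, which is handy for logging
# metrics computed over data that is sharded across processes.
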
def distribute_bn(model, world_size, reduce=False):
    # ensure every node has the same running bn stats
    for bn_name, bn_buf in unwrap_model(model).named_buffers(recurse=True):
        if ('running_mean' in bn_name) or ('running_var' in bn_name):
            if reduce:
                # average bn stats across whole group
                torch.distributed.all_reduce(bn_buf, op=dist.ReduceOp.SUM)
                bn_buf /= float(world_size)
            else:
                # broadcast bn stats from rank 0 to whole group
                torch.distributed.broadcast(bn_buf, 0)
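

# Example usage for distribute_bn (a minimal sketch, assuming an initialized
# process group and a model that may be wrapped in DistributedDataParallel;
# the `use_reduce` flag below is a hypothetical stand-in for however the
# training script chooses between the two modes):
#
#   world_size = dist.get_world_size()
#   distribute_bn(model, world_size, reduce=use_reduce)
#
# Typically called at the end of each training epoch, before validation, so
# that every rank evaluates with identical BatchNorm running statistics.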