@@ -23,12 +23,12 @@ from .device_env import DeviceEnv, DeviceEnvType, TensorList
 _PT_TO_XM_OP = {
-    ReduceOp.SUM: 'sum',
-    ReduceOp.PRODUCT: 'prod',
-    ReduceOp.MIN: 'min',
-    ReduceOp.MAX: 'max',
-    ReduceOp.BAND: 'and',
-    ReduceOp.BOR: 'or',
+    ReduceOp.SUM: xm.REDUCE_SUM,
+    ReduceOp.PRODUCT: xm.REDUCE_MUL,
+    ReduceOp.MIN: xm.REDUCE_MIN,
+    ReduceOp.MAX: xm.REDUCE_MAX,
+    ReduceOp.BAND: xm.REDUCE_AND,
+    ReduceOp.BOR: xm.REDUCE_OR,
 }
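
For context (not part of the patch itself), a minimal sketch of how this mapping is consumed, assuming a working torch_xla install: torch.distributed ReduceOp values are translated into torch_xla reduce-type constants before being handed to xm.all_reduce. In current torch_xla releases xm.REDUCE_MUL is the string 'mul', so the old 'prod' entry would not have been a valid reduce type, which looks like the motivation for this hunk. The helper name reduce_tensor below is hypothetical.

    import torch
    import torch_xla.core.xla_model as xm
    from torch.distributed import ReduceOp

    # Same mapping as in the hunk above: ReduceOp enum -> torch_xla reduce-type constant.
    _PT_TO_XM_OP = {
        ReduceOp.SUM: xm.REDUCE_SUM,
        ReduceOp.PRODUCT: xm.REDUCE_MUL,
        ReduceOp.MIN: xm.REDUCE_MIN,
        ReduceOp.MAX: xm.REDUCE_MAX,
        ReduceOp.BAND: xm.REDUCE_AND,
        ReduceOp.BOR: xm.REDUCE_OR,
    }

    def reduce_tensor(t: torch.Tensor, op: ReduceOp = ReduceOp.SUM) -> torch.Tensor:
        # xm.all_reduce expects one of the xm.REDUCE_* reduce types, not the ReduceOp enum.
        return xm.all_reduce(_PT_TO_XM_OP[op], t)
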
@@ -77,20 +77,16 @@ class DeviceEnvXla(DeviceEnv):
     def all_reduce(self, tensor: torch.Tensor, op=ReduceOp.SUM, average=False):
         assert isinstance(tensor, torch.Tensor)  # unlike in-place variant, lists/tuples not allowed
         op = _PT_TO_XM_OP[op]
-        scale = 1.0
-        if average:
-            scale /= self.world_size
+        scale = 1.0 / self.world_size if average else 1.0
         return xm.all_reduce(op, tensor, scale=scale)

     def all_reduce_(self, tensor: TensorList, op=ReduceOp.SUM, average=False):
         op = _PT_TO_XM_OP[op]
-        scale = 1.0
+        scale = 1.0 / self.world_size if average else 1.0
         wrapped = False
         if isinstance(tensor, torch.Tensor):
             tensor = [tensor]  # bare tensors are not operated on in-place
             wrapped = True
-        if average:
-            scale /= self.world_size
         xm.all_reduce(op, tensor, scale=scale)
         if wrapped:
             tensor = tensor[0]
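
A hedged usage sketch of the simplified scaling (the function name all_reduce_mean_ and its world_size argument are illustrative assumptions, not part of the diff): instead of initializing scale to 1.0 and conditionally dividing, the 1/world_size averaging factor is passed straight to xm.all_reduce via its scale argument, and the in-place variant relies on xm.all_reduce mutating a list of tensors, which is why the diff wraps a bare tensor in a single-element list.

    import torch
    import torch_xla.core.xla_model as xm

    def all_reduce_mean_(tensors, world_size: int):
        # In-place mean across replicas: xm.all_reduce operates in place when given a
        # list of tensors, and the averaging division is folded into the collective
        # through its `scale` argument rather than applied afterwards.
        if isinstance(tensors, torch.Tensor):
            tensors = [tensors]  # bare tensors are not reduced in place, so wrap them
        xm.all_reduce(xm.REDUCE_SUM, tensors, scale=1.0 / world_size)
        return tensors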