|
|
|
@ -23,12 +23,12 @@ from .device_env import DeviceEnv, DeviceEnvType, TensorList
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
_PT_TO_XM_OP = {
|
|
|
|
|
ReduceOp.SUM: xm.REDUCE_SUM,
|
|
|
|
|
ReduceOp.PRODUCT: xm.REDUCE_MUL,
|
|
|
|
|
ReduceOp.MIN: xm.REDUCE_MIN,
|
|
|
|
|
ReduceOp.MAX: xm.REDUCE_MAX,
|
|
|
|
|
ReduceOp.BAND: xm.REDUCE_AND,
|
|
|
|
|
ReduceOp.BOR: xm.REDUCE_OR,
|
|
|
|
|
ReduceOp.SUM: 'sum',
|
|
|
|
|
ReduceOp.PRODUCT: 'mul',
|
|
|
|
|
ReduceOp.MIN: 'min',
|
|
|
|
|
ReduceOp.MAX: 'max',
|
|
|
|
|
ReduceOp.BAND: 'and',
|
|
|
|
|
ReduceOp.BOR: 'or',
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|