|
|
@ -12,7 +12,7 @@ from torch import nn as nn
|
|
|
|
from torch.nn import functional as F
|
|
|
|
from torch.nn import functional as F
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
_USE_MEM_EFFICIENT_ISH = False
|
|
|
|
_USE_MEM_EFFICIENT_ISH = True
|
|
|
|
if _USE_MEM_EFFICIENT_ISH:
|
|
|
|
if _USE_MEM_EFFICIENT_ISH:
|
|
|
|
# This version reduces memory overhead of Swish during training by
|
|
|
|
# This version reduces memory overhead of Swish during training by
|
|
|
|
# recomputing torch.sigmoid(x) in backward instead of saving it.
|
|
|
|
# recomputing torch.sigmoid(x) in backward instead of saving it.
|
|
|
|