diff --git a/post_quantization_validate.py b/post_quantization_validate.py index 73c09dd4..f026d5ab 100755 --- a/post_quantization_validate.py +++ b/post_quantization_validate.py @@ -20,6 +20,8 @@ from collections import OrderedDict from contextlib import suppress import torch.quantization +import torch.quantization.quantize_fx as quantize_fx +import copy #currently, quantization only runs on CPUs os.environ['CUDA_VISIBLE_DEVICES'] = ""