Update model prune loader to use pkgutil

pull/1662/head
Ross Wightman 2 years ago
parent 0f2803de7a
commit 7a0bd095cb

@ -1,4 +1,5 @@
import os
import pkgutil
from copy import deepcopy
from torch import nn as nn
@ -108,6 +109,5 @@ def adapt_model_from_string(parent_module, model_string):
def adapt_model_from_file(parent_module, model_variant):
adapt_file = os.path.join(os.path.dirname(__file__), '_pruned', model_variant + '.txt')
with open(adapt_file, 'r') as f:
return adapt_model_from_string(parent_module, f.read().strip())
adapt_data = pkgutil.get_data(__name__, os.path.join('_pruned', model_variant + '.txt'))
return adapt_model_from_string(parent_module, adapt_data.decode('utf-8').strip())

Loading…
Cancel
Save