# Convert Whisper transformer model from PyTorch to ggml format
#
# Usage: python convert-pt-to-ggml.py ~/.cache/whisper/medium.pt ~/path/to/repo/whisper/ ./models/whisper-medium [use-f32]
#
# You need to clone the original repo in ~/path/to/repo/whisper/
#
#  git clone https://github.com/openai/whisper ~/path/to/repo/whisper/
#
# It is used to load various assets needed by the algorithm:
#
#  - tokenizer
#  - mel filters
#
# Also, you need to have the original models in ~/.cache/whisper/
# See the original repo for more details.
#
# This script loads the specified model and whisper assets and saves them in ggml format.
# The output is a single binary file containing the following information:
#
#  - hparams
#  - mel filters
#  - tokenizer vocab
#  - model variables
#
# For each variable, write the following:
#
#  - Number of dimensions (int)
#  - Name length (int)
#  - Data type (int): 0 = float32, 1 = float16
#  - Dimensions (int[n_dims]), written in reverse order
#  - Name (char[name_length])
#  - Data (float32 or float16, one value per tensor element)
#

import io
import os
import sys
import struct
import json
import code

import torch
import numpy as np

# Upstream references for the tokenizer (language table and special tokens):
#   https://github.com/openai/whisper/blob/8cf36f3508c9acd341a45eb2364239a3d81458b9/whisper/tokenizer.py#L10-L110
#   https://github.com/openai/whisper/blob/8cf36f3508c9acd341a45eb2364239a3d81458b9/whisper/tokenizer.py#L273-L292

# Magic number written at the start of the output file ("ggml" in hex).
GGML_MAGIC = 0x67676d6c

# Vocabulary size used to recognise the multilingual checkpoints.
MULTILINGUAL_N_VOCAB = 51865

# Hyper-parameters written to the file header, in this exact order.
HPARAM_KEYS = (
    "n_vocab",
    "n_audio_ctx",
    "n_audio_state",
    "n_audio_head",
    "n_audio_layer",
    "n_text_ctx",
    "n_text_state",
    "n_text_head",
    "n_text_layer",
    "n_mels",
)

# Conv biases are reshaped from [n] to [n, 1] before being written.
CONV_BIAS_NAMES = frozenset((
    "encoder.conv1.bias",
    "encoder.conv2.bias",
))

# Tensors that are stored as float32 even when the f16 output is requested.
ALWAYS_F32_NAMES = CONV_BIAS_NAMES | frozenset((
    "encoder.positional_embedding",
    "decoder.positional_embedding",
))


# ref: https://github.com/openai/gpt-2/blob/master/src/encoder.py
def bytes_to_unicode():
    """
    Returns a dict mapping every byte value (0..255) to a unicode character.

    The reversible bpe codes work on unicode strings.
    This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
    When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
    This is a significant percentage of your normal, say, 32K bpe vocab.
    To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
    And avoids mapping to whitespace/control characters the bpe code barfs on.
    """
    bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
    cs = bs[:]
    n = 0
    for b in range(2**8):
        if b not in bs:
            # bytes without a printable representation are shifted past 255
            bs.append(b)
            cs.append(2**8+n)
            n += 1
    cs = [chr(n) for n in cs]
    return dict(zip(bs, cs))


def _load_checkpoint(fname_inp):
    """Load the PyTorch checkpoint at `fname_inp` onto the CPU.

    Prints an error (including the underlying cause) and exits with status 1
    when the file cannot be read or deserialized.
    """
    try:
        with open(fname_inp, "rb") as f:
            model_bytes = f.read()
        # NOTE: torch.load unpickles the file - only convert checkpoints from a trusted source.
        with io.BytesIO(model_bytes) as fp:
            return torch.load(fp, map_location="cpu")
    except Exception as err:
        print("Error: failed to load PyTorch model file: %s" % fname_inp)
        print("  cause: %s" % err)
        sys.exit(1)


def _load_mel_filters(dir_assets, n_mels):
    """Return the `mel_<n_mels>` filter bank from `mel_filters.npz` in `dir_assets` as a numpy array."""
    with np.load(os.path.join(dir_assets, "mel_filters.npz")) as f:
        return f[f"mel_{n_mels}"]


def _write_hparams(fout, hparams, use_f16):
    """Write the magic number, the hyper-parameters and the f16 flag, each as a native int."""
    fout.write(struct.pack("i", GGML_MAGIC))
    for key in HPARAM_KEYS:
        fout.write(struct.pack("i", hparams[key]))
    fout.write(struct.pack("i", int(use_f16)))


def _write_mel_filters(fout, filters):
    """Write the two filter dimensions followed by the filter values as float32, row by row."""
    fout.write(struct.pack("i", filters.shape[0]))
    fout.write(struct.pack("i", filters.shape[1]))
    # one bulk write instead of one struct.pack call per element; tobytes() emits C (row-major) order
    fout.write(np.asarray(filters, dtype=np.float32).tobytes())


def _write_vocab(fout, tokens):
    """Write the token count, then each token as (byte length, raw bytes).

    The vocab keys use the printable encoding produced by bytes_to_unicode();
    they are mapped back to the original bytes before being written.
    """
    byte_decoder = {v: k for k, v in bytes_to_unicode().items()}

    fout.write(struct.pack("i", len(tokens)))
    for key in tokens:
        text = bytes(byte_decoder[c] for c in key)
        fout.write(struct.pack("i", len(text)))
        fout.write(text)


def _write_tensor(fout, name, data, use_f16):
    """Write one model variable (header, dimensions, name, data) to `fout`.

    `data` is the numpy array of the variable. `fout` must be a real file
    object because the payload is written with ndarray.tofile().
    """
    # reshape conv bias from [n] to [n, 1]
    if name in CONV_BIAS_NAMES:
        data = data.reshape(data.shape[0], 1)
        print(" Reshaped variable: " + name + " to shape: ", data.shape)

    n_dims = len(data.shape)

    # looks like the whisper models are in f16 by default
    # so we need to convert the small tensors to f32 until we fully support f16 in ggml
    # ftype == 0 -> float32, ftype == 1 -> float16
    if use_f16 and n_dims >= 2 and name not in ALWAYS_F32_NAMES:
        # cast explicitly so the payload always matches the ftype in the header,
        # even if the checkpoint was not stored in f16
        data = data.astype(np.float16)
        ftype = 1
    else:
        if use_f16:
            print(" Converting to float32")
        data = data.astype(np.float32)
        ftype = 0

    # header
    name_bytes = name.encode('utf-8')
    fout.write(struct.pack("iii", n_dims, len(name_bytes), ftype))
    for dim in reversed(data.shape):
        fout.write(struct.pack("i", dim))
    fout.write(name_bytes)

    # data
    data.tofile(fout)


def main(argv=None):
    """Convert the checkpoint named on the command line to a ggml model file.

    `argv` defaults to sys.argv and is expected to hold:
    program name, model.pt, path-to-whisper-repo, dir-output and an optional
    extra argument that switches the output to 32-bit floats.
    Exits with status 1 on a usage error or when the checkpoint cannot be loaded.
    """
    if argv is None:
        argv = sys.argv

    if len(argv) < 4:
        print("Usage: convert-pt-to-ggml.py model.pt path-to-whisper-repo dir-output [use-f32]\n")
        sys.exit(1)

    fname_inp = argv[1]
    dir_whisper = argv[2]
    dir_out = argv[3]

    # use 16-bit or 32-bit floats: any extra argument selects f32
    use_f16 = len(argv) <= 4

    checkpoint = _load_checkpoint(fname_inp)

    hparams = checkpoint["dims"]
    print("hparams:", hparams)

    list_vars = checkpoint["model_state_dict"]

    # load mel filters
    dir_assets = os.path.join(dir_whisper, "whisper/assets")
    filters = _load_mel_filters(dir_assets, hparams["n_mels"])

    # load the tokenizer vocab
    multilingual = hparams["n_vocab"] == MULTILINGUAL_N_VOCAB
    dir_tokenizer = os.path.join(dir_assets, "multilingual" if multilingual else "gpt2")
    with open(os.path.join(dir_tokenizer, "vocab.json"), "r", encoding="utf8") as f:
        tokens = json.load(f)

    fname_out = os.path.join(dir_out, "ggml-model.bin" if use_f16 else "ggml-model-f32.bin")

    # all inputs are loaded before the output is created; the context manager
    # closes the file even if writing fails part-way
    with open(fname_out, "wb") as fout:
        _write_hparams(fout, hparams, use_f16)
        _write_mel_filters(fout, filters)
        _write_vocab(fout, tokens)

        for name, tensor in list_vars.items():
            data = tensor.squeeze().numpy()
            print("Processing variable: " + name + " with shape: ", data.shape)
            _write_tensor(fout, name, data, use_f16)

    print("Done. Output file: " + fname_out)
    print("")


if __name__ == "__main__":
    main()