From ed683187cbe9f5574a2f8b3106d96f2e07721f49 Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Sat, 31 Dec 2022 13:57:04 +0200
Subject: [PATCH] t5 : add example for text-to-text transfer transformer
 inference

---
 examples/CMakeLists.txt                   |  1 +
 examples/t5/CMakeLists.txt                |  6 ++++++
 examples/t5/README.md                     |  3 +++
 examples/t5/convert-flan-t5-pt-to-ggml.py | 25 +++++++++++++++++++++++
 examples/t5/main.cpp                      |  3 +++
 5 files changed, 38 insertions(+)
 create mode 100644 examples/t5/CMakeLists.txt
 create mode 100644 examples/t5/README.md
 create mode 100644 examples/t5/convert-flan-t5-pt-to-ggml.py
 create mode 100644 examples/t5/main.cpp

diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt
index 5f7f3a4..e4d8ff7 100644
--- a/examples/CMakeLists.txt
+++ b/examples/CMakeLists.txt
@@ -4,3 +4,4 @@ target_include_directories(ggml_utils PUBLIC ${CMAKE_CURRENT_SOURCE_DIR})
 add_subdirectory(gpt-2)
 add_subdirectory(gpt-j)
 add_subdirectory(whisper)
+add_subdirectory(t5)
diff --git a/examples/t5/CMakeLists.txt b/examples/t5/CMakeLists.txt
new file mode 100644
index 0000000..f1258c1
--- /dev/null
+++ b/examples/t5/CMakeLists.txt
@@ -0,0 +1,6 @@
+#
+# t5
+
+set(TEST_TARGET t5)
+add_executable(${TEST_TARGET} main.cpp)
+target_link_libraries(${TEST_TARGET} PRIVATE ggml ggml_utils)
diff --git a/examples/t5/README.md b/examples/t5/README.md
new file mode 100644
index 0000000..a12cb07
--- /dev/null
+++ b/examples/t5/README.md
@@ -0,0 +1,3 @@
+# t5
+
+ref: https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py
diff --git a/examples/t5/convert-flan-t5-pt-to-ggml.py b/examples/t5/convert-flan-t5-pt-to-ggml.py
new file mode 100644
index 0000000..32a3b23
--- /dev/null
+++ b/examples/t5/convert-flan-t5-pt-to-ggml.py
@@ -0,0 +1,25 @@
+import io
+import sys
+import torch
+
+import code
+
+from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
+
+if len(sys.argv) < 3:
+    print("Usage: convert-flan-t5-pt-to-ggml.py path-to-pt-model dir-output [use-f32]\n")
+    sys.exit(1)
+
+fname_inp=sys.argv[1] + "/pytorch_model.bin"
+
+try:
+    model_bytes = open(fname_inp, "rb").read()
+    with io.BytesIO(model_bytes) as fp:
+        checkpoint = torch.load(fp, map_location="cpu")
+except Exception:
+    print("Error: failed to load PyTorch model file: %s" % fname_inp)
+    sys.exit(1)
+
+# list all keys
+for k in checkpoint.keys():
+    print(k)
diff --git a/examples/t5/main.cpp b/examples/t5/main.cpp
new file mode 100644
index 0000000..33c14ce
--- /dev/null
+++ b/examples/t5/main.cpp
@@ -0,0 +1,3 @@
+int main() {
+    return 0;
+}