Update README.md + minor stuff

- Changed default threads to 4
- Added GGML_PERF for enabling runtime performance timings
experiments/blocking
Georgi Gerganov 2 years ago
parent 0f4e99b1cc
commit f21b84cd21
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735

@ -1,24 +1,23 @@
# ggml
Tensor library in C for machine learning
Tensor library for machine learning
## Features
- Automatic differentiation (WIP)
- Written in C
- 16-bit float support
- Automatic differentiation (WIP in progress)
- ADAM and L-BFGS optimizers
- Optimized for Arm64 architectures (i.e. MacBook M1) via NEON intrinsics
- Optimized for Arm64 architectures (M1) via NEON intrinsics
- On x86 architectures utilzes AVX intrinsics
- No third-party dependencies
- Zero memory allocations during runtime
## Local GPT inference
## Example - GPT inference
Using ggml you can run [GPT-2](examples/gpt-2) and [GPT-J](examples/gpt-j) inference locally on your computer without any additional software or hardware. You don't even need to install python or any other third-party library.
With ggml you can efficiently run [GPT-2](examples/gpt-2) and [GPT-J](examples/gpt-j) inference on the CPU.
The example programs are implemented in C++. They run entirely on the CPU.
Here is how to use them:
Here is how to run the example programs:
```bash
# Build ggml + examples
@ -37,7 +36,7 @@ make -j4 gpt-2 gpt-j
./bin/gpt-j -m models/gpt-j-6B/ggml-model.bin -p "This is an example"
```
This is the inference speed for the different models on my MacBook M1 Pro:
The inference speeds that I get for the different models on my 32GB MacBook M1 Pro are as follows:
| Model | Size | Time / Token |
| --- | --- | --- |

@ -1,7 +1,6 @@
# gpt-2
This is a C++ example running GPT-2 inference using the [ggml](https://github.com/ggerganov/ggml) library.
The enitre code of the example is in [main.cpp](main.cpp).
The program runs on the CPU - no video card is required.
@ -73,11 +72,11 @@ main: total time = 629.84 ms
## Downloading and converting the original models
You can download the original model files using the [download-model.sh](download-model.sh) Bash script.
The model is in Tensorflow format, so before using it with ggml, we need to convert it to appropriate format.
This is done via the [convert-ckpt-to-ggml.py](convert-ckpt-to-ggml.py) python script.
You can download the original model files using the [download-model.sh](download-model.sh) Bash script. The models are
in Tensorflow format, so in order to use them with ggml, you need to convert them to appropriate format. This is done
via the [convert-ckpt-to-ggml.py](convert-ckpt-to-ggml.py) python script.
Here is the entire process for the GPT-2 117M model:
Here is the entire process for the GPT-2 117M model (download from official site + conversion):
```
cd ggml/build
@ -99,14 +98,13 @@ Run the convert-ckpt-to-ggml.py script to convert the model to ggml format.
```
This conversion requires that you have python and Tensorflow installed on your computer.
Still, if you want to avoid this, you can download the already converted ggml models as
described below.
This conversion requires that you have python and Tensorflow installed on your computer. Still, if you want to avoid
this, you can download the already converted ggml models as described below.
## Downloading the ggml model directly
For convenience, I will be hosting the converted ggml model files in order to make it easier to run the examples.
This way, you can directly download a single binary file and start using it. No python or Tensorflow is required.
For convenience, I will be hosting the converted ggml model files in order to make it easier to run the examples. This
way, you can directly download a single binary file and start using it. No python or Tensorflow is required.
Here is how to get the 117M ggml model:
@ -123,4 +121,4 @@ You can now use it like this:
```
At some point, I might stop hosting these models. So in that case, simply revert to the manual process above.
At some point, I might decide to stop hosting these models. So in that case, simply revert to the manual process above.

@ -4,25 +4,23 @@ Local GPT-J inference on your computer using C/C++
No video card required. You just need to have 16 GB of RAM.
For example, you can run this on a 16 GB MacBook M1.
## Motivation
The GPT-J 6B model is the open-source alternative to OpenAI's GPT-3. It's basically a neural network that
allows you to generate coherent, human-like text given a certain context (prompt).
The GPT-J 6B model is the open-source alternative to OpenAI's GPT-3. It's basically a neural network that allows you to
generate coherent, human-like text given a certain context (prompt).
The GPT-J model is quite big - the compact version of the model uses 16-bit floating point representation
of the weights and is still 12 GB big. This means that in order to run inference on your computer, you
would need to have a video card with at least 12 GB of video RAM. Alternatively, you can try to run the
python implementations on the CPU, but that would probably not be very efficient as they are primarily
optimized for running on a GPU (or at least this is my guess - I don't have much experience with python).
The GPT-J model is quite big - the compact version of the model uses 16-bit floating point representation of the weights
and is still 12 GB big. This means that in order to run inference on your computer, you would need to have a video card
with at least 12 GB of video RAM. Alternatively, you can try to run the python implementations on the CPU, but that
would probably not be very efficient as they are primarily optimized for running on a GPU (or at least this is my guess -
I don't have much experience with python).
Looking on the internet, I couldn't find a dedicated CPU implementation that would allow me to run the model
without a high-end video card. So I decided to write my own inference using a custom build tensor library.
The tensor library (called [ggml](https://github.com/ggerganov/ggml), written in C) is in early development
stage, but it already allows me to run the GPT-J model.
I wanted to try and run the model on my MacBook, so I decided to implement the model inference from scratch using my own
custom build tensor library. The tensor library (called [ggml](https://github.com/ggerganov/ggml), written in C) is in
early development stage, but it already allows me to run the GPT-J model.
On my MacBook M1 Pro, I achieve an inference speed of about `125 ms/token` or about 2-3 words per second.
On my 32GB MacBook M1 Pro, I achieve an inference speed of about `125 ms/token` or about ~6 words per second (1 word
typically consists of 1 or 2 tokens).
Here is a sample run with prompt `int main(int argc, char ** argv) {`:
@ -68,51 +66,133 @@ main: total time = 33035.37 ms
real 0m33.171s
user 3m32.269s
sys 0m3.686s
sys 0m3.686s
$
```
It took ~6.2 seconds to load the model to memory. After that, it took ~26.4 seconds to generate 200
tokens of what looks like to be the beginning of a networking program in C. Pretty cool!
It took ~6.2 seconds to load the model to memory. After that, it took ~26.4 seconds to generate 200 tokens of what
looks like to be the beginning of a networking program in C. Pretty cool!
Here is another run, just for fun:
```
time ./bin/gpt-j -n 500 -t 8 -p "Ask HN: Inherited the worst code and tech team I have ever seen. How to fix it?
"
gptj_model_load: loading model from 'models/gpt-j-6B/ggml-model.bin' - please wait ...
gptj_model_load: n_vocab = 50400
gptj_model_load: n_ctx = 2048
gptj_model_load: n_embd = 4096
gptj_model_load: n_head = 16
gptj_model_load: n_layer = 28
gptj_model_load: n_rot = 64
gptj_model_load: f16 = 1
gptj_model_load: ggml ctx size = 13334.86 MB
gptj_model_load: memory_size = 1792.00 MB, n_mem = 57344
gptj_model_load: ................................... done
gptj_model_load: model size = 11542.79 MB / num tensors = 285
main: number of tokens in prompt = 24
Ask HN: Inherited the worst code and tech team I have ever seen. How to fix it?
I've inherited a team with some very strange and un-documented practices, one of them is that they use an old custom
application with a very slow tech stack written in Python that the team doesn't want to touch but also doesn't want to
throw away as it has some "legacy" code in it.
The problem is, the tech stack is very very slow.
They have a single web server on a VM that is slow.
The server is a little bit busy (not very busy though) and they have a lot of processes (30+ that are constantly being
spawned by the application)
They have an application that is single threaded and was written in Python and the team don't want to touch this, and
the application is very slow.
My task as a new member of the team is to fix this.
I'm a senior dev on the team (3 years on the project) and have been told that I will take the lead on this task. I know
next to nothing about Python. So here is what I have so far.
What I have done is I've been trying to debug the processes with the "ps" command. This way I can see what is running
and where. From what I see, the application spawns 10 processes a minute and some of them are used for nothing.
I have also started to look for the code. The application source is not in GitHub or any other repository, it is only on
our internal GitLab.
What I've found so far:
The application uses a custom SQLAlchemy implementation to interact with the data. I've looked at the source, it looks
like an object cache or something like that. But from what I've seen, the cache gets full every 20 minutes and then gets
cleared with a special command.
Another strange thing is that the application creates a file for every entry in the database (even if the entry already
exists). I've looked at the file to see if it contains something, but it seems to be a JSON file with lots of records.
The other strange thing is that I can only find the database tables in the GitLab repository and not the code. So I
can't really understand how the application is supposed to interact with the database.
I also found a "log" directory, but the code is encrypted with AES. From what I've found, it is in
main: mem per token = 16430420 bytes
main: load time = 3900.10 ms
main: sample time = 32.58 ms
main: predict time = 68049.91 ms / 130.11 ms per token
main: total time = 73020.05 ms
real 1m13.156s
user 9m1.328s
sys. 0m7.103s
```
## Implementation details
The high level implementation of the model is contained in the [main.cpp](main.cpp) file. The core
computations are performed by the `ggml` library.
The high level implementation of the model is contained in the [main.cpp](main.cpp) file. The core computations are
performed by the [ggml](https://github.com/ggerganov/ggml/blob/master/include/ggml/ggml.h) library.
#### Matrix multiplication
The most performance critical part of the implementation is of course the matrix multiplication routine.
99% of the time is spent here, so it is important to optimize this as much as possible.
The most performance critical part of the implementation is of course the matrix multiplication routine. 99% of the time
is spent here, so it was important to optimize this as much as possible.
On Arm64, I utilize the 128-bit NEON intrinsics for 16-bit floating point operations:
https://github.com/ggerganov/ggml/blob/fb558f78d905f85c54813602649ddd628ffe0f3a/src/ggml.c#L187-L243
These instructions allow each core to operate simultaneously on 64 floating point numbers. I'm no expert
in SIMD, but after quite some trials this was the most efficient code for dot product that I could come up
with. Combined with the parallel computation on 8 CPU threads, I think I got close to the maximum performance
that one could possibly get on the M1 CPU. Still, I'm curious to know if there is a more efficient way to
implement this.
These instructions allow each core to operate simultaneously on 64 16-bit floats. I'm no expert in SIMD, but after quite
some trials this was the most efficient code for dot product of a row and column that I could come up with. Combined
with the parallel computation on 8 CPU threads, I believe I'm close to the maximum performance that one could possibly
get on the M1 CPU. Still, I'm curious to know if there is a more efficient way to implement this.
#### Attempt to use the M1 GPU
One interesting property of the GPT-J transformer architecture is that it allows you to perform part
of the inference in parallel - i.e. the Feed-forward layer can be computed in parallel to the Self-Attention
layer:
One interesting property of the GPT-J transformer architecture is that it allows you to perform part of the inference in
parallel - i.e. the Feed-forward network can be computed in parallel to the Self-attention layer:
https://github.com/ggerganov/ggml/blob/fb558f78d905f85c54813602649ddd628ffe0f3a/examples/gpt-j/main.cpp#L507-L531
So I thought why not bring in the M1 GPU to compute half of the neural network in parallel to the CPU.
Thanks to the shared memory model, it was relatively easy to offload half of the computation to the GPU
using [Metal Performance Shaders](https://developer.apple.com/documentation/metalperformanceshaders).
However, to my surprise, I did not get any performance improvement at all. My conclusion was that the
8-thread NEON CPU computation is basically saturating the memory bandwidth of the M1 and since the CPU
and the GPU on the MacBook are sharing that bandwidth, it does not help to offload the computation to the
GPU. Another observation was that the MPS GPU matrix multiplication using 16-bit floats had the same
performance as the 8-thread NEON CPU implementation. Again, I explain this with a saturated memory channel.
But of course, I could be totally wrong and somehow my implementation wasn't utilizing the resources
correctly.
So I thought why not try and bring in the M1 GPU to compute half of the neural network in parallel to the CPU and
potentially gain some extra performance. Thanks to the M1's shared memory model, it was relatively easy to offload part
of the computation to the GPU using Apple's [Metal Performance
Shaders](https://developer.apple.com/documentation/metalperformanceshaders). The GPU shares the host memory, so there is
no need to copy the data back and forth as you would normally do with Cuda or OpenCL. The weight matrices are directly
available to be used by the GPU.
Another property of my implementation is that it does not perform any memory allocations once the model
is loaded into memory. All required memory is allocated at the start of the program.
However, to my surprise, using MPS together with the CPU did not lead to any performance improvement at all. My
conclusion was that the 8-thread NEON CPU computation is already saturating the memory bandwidth of the M1 and since
the CPU and the GPU on the MacBook are sharing that bandwidth, it does not help to offload the computation to the GPU.
Another observation was that the MPS GPU matrix multiplication using 16-bit floats had the same performance as the
8-thread NEON CPU implementation. Again, I explain this with a saturated memory channel. But of course, my explanation
could be totally wrong and somehow the implementation wasn't utilizing the resources correctly.
In the end, I decided to not use MPS or the GPU all together.
### Zero memory allocations
Another property of my implementation is that it does not perform any memory allocations once the model is loaded into
memory. All required memory is allocated at the start of the program with a single `malloc` (technically 2 calls, but
that is not important).
## Usage
@ -134,22 +214,26 @@ make -j4 gpt-j
```
To run the `gpt-j` tool, you need the 12GB `ggml-model.bin` file which contains the GPT-J model in
[ggml](https://github.com/ggerganov/ggml) format. In the instructions above, I download the binary file
[ggml](https://github.com/ggerganov/ggml) compatible format. In the instructions above, I download the binary file
directly from one of my servers, using the [download-ggml-model.sh](download-ggml-model.sh) script.
---
Alternatively, you can perform the conversion yourself.
Alternatively, if you don't want to download the 12GB ggml model file, you can perform the conversion yourself using
python.
First, you need to download the full GPT-J model from here: https://huggingface.co/EleutherAI/gpt-j-6B
Note that the full model is quite big - about 72 GB. After you download it, you need to make the
conversion using the [convert-h5-to-ggml.py](convert-h5-to-ggml.py) script. This will generate the
`ggml-model.bin` file, which you can then use with the `gpt-j` program.
Note that the full model is quite big - about 72 GB. After you download it, you need to convert it to ggml format using
the [convert-h5-to-ggml.py](convert-h5-to-ggml.py) script. This will generate the `ggml-model.bin` file, which you can
then use with the `gpt-j` program.
## GPT-2
I have also implemented a tool for CPU inference using the smaller GPT-2 models. They have worse
quality compared to GPT-J, but are much faster to execute.
I also implemented a tool for CPU inference using the smaller GPT-2 models. They have worse quality compared to GPT-J,
but are much faster to execute.
For example, the Small GPT-2 model is only 240 MB big and the inference speed on my MacBook is about 200 tokens/sec.
Checkout the GPT-2 example here: [gpt-2](https://github.com/ggerganov/ggml/tree/master/examples/gpt-2)
For more details, checkout the GPT-2 example here: [gpt-2](https://github.com/ggerganov/ggml/tree/master/examples/gpt-2)

@ -14,7 +14,7 @@
struct gpt_params {
int32_t seed = -1; // RNG seed
int32_t n_threads = std::min(8, (int32_t) std::thread::hardware_concurrency());
int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency());
int32_t n_predict = 200; // new tokens to predict
// sampling parameters

@ -48,6 +48,10 @@ set(TARGET ggml)
# endif()
#endif()
if (GGML_PERF)
set(GGML_EXTRA_FLAGS ${GGML_EXTRA_FLAGS} -DGGML_PERF)
endif()
add_library(${TARGET}
ggml.c
)

@ -12,12 +12,11 @@
#include <pthread.h>
#define GGML_DEBUG 0
#define GGML_MEM_ALIGN 16
#define MAX(a, b) ((a) > (b) ? (a) : (b))
#define MIN(a, b) ((a) < (b) ? (a) : (b))
#define GGML_MEM_ALIGN 16
#define UNUSED(x) (void)(x)
#define SWAP(x, y, T) do { T SWAP = x; x = y; y = SWAP; } while (0)
@ -117,7 +116,6 @@ ggml_fp16_t ggml_fp32_to_fp16(float f) {
// timing
//
// TODO: need to be able to disable these in performance critical code since they make slow system calls
int64_t ggml_time_ms(void) {
struct timespec ts;
clock_gettime(CLOCK_MONOTONIC, &ts);
@ -138,6 +136,18 @@ int64_t ggml_cycles_per_ms(void) {
return CLOCKS_PER_SEC/1000;
}
#ifdef GGML_PERF
#define ggml_perf_time_ms() ggml_time_ms()
#define ggml_perf_time_us() ggml_time_us()
#define ggml_perf_cycles() ggml_cycles()
#define ggml_perf_cycles_per_ms() ggml_cycles_per_ms()
#else
#define ggml_perf_time_ms() 0
#define ggml_perf_time_us() 0
#define ggml_perf_cycles() 0
#define ggml_perf_cycles_per_ms() 0
#endif
//
// cache line
//
@ -3053,7 +3063,7 @@ void ggml_compute_forward_mul_mat_f32(
const struct ggml_tensor * src0,
const struct ggml_tensor * src1,
struct ggml_tensor * dst) {
int64_t t0 = ggml_time_us();
int64_t t0 = ggml_perf_time_us();
UNUSED(t0);
const int ne00 = src0->ne[0];
@ -3232,7 +3242,7 @@ void ggml_compute_forward_mul_mat_f32(
}
}
//int64_t t1 = ggml_time_us();
//int64_t t1 = ggml_perf_time_us();
//static int64_t acc = 0;
//acc += t1 - t0;
//if (t1 - t0 > 10) {
@ -3251,7 +3261,7 @@ void ggml_compute_forward_mul_mat_f16_f32(
const struct ggml_tensor * src0,
const struct ggml_tensor * src1,
struct ggml_tensor * dst) {
int64_t t0 = ggml_time_us();
int64_t t0 = ggml_perf_time_us();
UNUSED(t0);
const int ne00 = src0->ne[0];
@ -4619,8 +4629,8 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
}
}
const int64_t perf_start_cycles = ggml_cycles();
const int64_t perf_start_time_us = ggml_time_us();
const int64_t perf_start_cycles = ggml_perf_cycles();
const int64_t perf_start_time_us = ggml_perf_time_us();
for (int i = 0; i < cgraph->n_nodes; i++) {
GGML_PRINT_DEBUG_5("%s: %d/%d\n", __func__, i, cgraph->n_nodes);
@ -4632,8 +4642,8 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
// continue;
//}
const int64_t perf_node_start_cycles = ggml_cycles();
const int64_t perf_node_start_time_us = ggml_time_us();
const int64_t perf_node_start_cycles = ggml_perf_cycles();
const int64_t perf_node_start_time_us = ggml_perf_time_us();
// INIT
struct ggml_compute_params params = {
@ -4706,8 +4716,8 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
// performance stats (node)
{
int64_t perf_cycles_cur = ggml_cycles() - perf_node_start_cycles;
int64_t perf_time_us_cur = ggml_time_us() - perf_node_start_time_us;
int64_t perf_cycles_cur = ggml_perf_cycles() - perf_node_start_cycles;
int64_t perf_time_us_cur = ggml_perf_time_us() - perf_node_start_time_us;
node->perf_runs++;
node->perf_cycles += perf_cycles_cur;
@ -4731,8 +4741,8 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
// performance stats (graph)
{
int64_t perf_cycles_cur = ggml_cycles() - perf_start_cycles;
int64_t perf_time_us_cur = ggml_time_us() - perf_start_time_us;
int64_t perf_cycles_cur = ggml_perf_cycles() - perf_start_cycles;
int64_t perf_time_us_cur = ggml_perf_time_us() - perf_start_time_us;
cgraph->perf_runs++;
cgraph->perf_cycles += perf_cycles_cur;

Loading…
Cancel
Save