Spaces:
Running
Running
Flash + language support (ref #2)
Browse files- Achieved big performance improvement + memory usage reduction
- Can now translate / transcribe different languages
Makefile
CHANGED
|
@@ -30,11 +30,16 @@ samples:
|
|
| 30 |
# runs it on all samples in the folder "./samples":
|
| 31 |
|
| 32 |
.PHONY: tiny.en
|
|
|
|
| 33 |
.PHONY: base.en
|
| 34 |
-
.PHONY:
|
| 35 |
.PHONY: small.en
|
|
|
|
|
|
|
|
|
|
|
|
|
| 36 |
|
| 37 |
-
tiny.en base.en
|
| 38 |
bash ./download-ggml-model.sh $@
|
| 39 |
@echo ""
|
| 40 |
@echo "==============================================="
|
|
|
|
| 30 |
# runs it on all samples in the folder "./samples":
|
| 31 |
|
| 32 |
.PHONY: tiny.en
|
| 33 |
+
.PHONY: tiny
|
| 34 |
.PHONY: base.en
|
| 35 |
+
.PHONY: base
|
| 36 |
.PHONY: small.en
|
| 37 |
+
.PHONY: small
|
| 38 |
+
.PHONY: medium.en
|
| 39 |
+
.PHONY: medium
|
| 40 |
+
.PHONY: large
|
| 41 |
|
| 42 |
+
tiny.en tiny base.en base small.en small medium.en medium large: main
|
| 43 |
bash ./download-ggml-model.sh $@
|
| 44 |
@echo ""
|
| 45 |
@echo "==============================================="
|
README.md
CHANGED
|
@@ -4,7 +4,8 @@ C/C++ port of [OpenAI's Whisper](https://github.com/openai/whisper) speech-to-te
|
|
| 4 |
|
| 5 |
- Plain C/C++ implementation without dependencies
|
| 6 |
- ARM_NEON and AVX intrinsics support
|
| 7 |
-
- F16 support
|
|
|
|
| 8 |
|
| 9 |
## Usage
|
| 10 |
|
|
@@ -27,9 +28,33 @@ For a quick demo, simply run `make base.en`:
|
|
| 27 |
```bash
|
| 28 |
$ make base.en
|
| 29 |
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
|
| 34 |
===============================================
|
| 35 |
Running base.en on all samples in ./samples ...
|
|
@@ -52,23 +77,24 @@ whisper_model_load: n_text_layer = 6
|
|
| 52 |
whisper_model_load: n_mels = 80
|
| 53 |
whisper_model_load: f16 = 1
|
| 54 |
whisper_model_load: type = 2
|
| 55 |
-
whisper_model_load: mem_required =
|
| 56 |
whisper_model_load: adding 1607 extra tokens
|
| 57 |
-
whisper_model_load: ggml ctx size =
|
| 58 |
-
whisper_model_load: memory size =
|
| 59 |
whisper_model_load: model size = 140.54 MB
|
| 60 |
log_mel_spectrogram: n_sample = 176000, n_len = 1100
|
| 61 |
log_mel_spectrogram: recording length: 11.000000 s
|
| 62 |
|
| 63 |
-
|
| 64 |
|
| 65 |
-
|
| 66 |
-
main: mel time = 38.69 ms
|
| 67 |
-
main: sample time = 2.36 ms
|
| 68 |
-
main: encode time = 875.63 ms / 145.94 ms per layer
|
| 69 |
-
main: decode time = 103.17 ms
|
| 70 |
-
main: total time = 1081.13 ms
|
| 71 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 72 |
```
|
| 73 |
|
| 74 |
The command downloads the `base.en` model converted to custom `ggml` format and runs the inference on all `.wav` samples in the folder `samples`.
|
|
@@ -81,13 +107,18 @@ make samples
|
|
| 81 |
|
| 82 |
This will download a few more audio files from Wikipedia and convert them to 16-bit WAV format via `ffmpeg`.
|
| 83 |
|
| 84 |
-
You can download and run the other
|
| 85 |
|
| 86 |
```
|
| 87 |
make tiny.en
|
|
|
|
| 88 |
make base.en
|
|
|
|
| 89 |
make small.en
|
|
|
|
| 90 |
make medium.en
|
|
|
|
|
|
|
| 91 |
```
|
| 92 |
|
| 93 |
For detailed usage instructions, run: `./main -h`
|
|
@@ -101,10 +132,8 @@ ffmpeg -i input.mp3 -ar 16000 -ac 1 -c:a pcm_s16le output.wav
|
|
| 101 |
|
| 102 |
## Limitations
|
| 103 |
|
| 104 |
-
- Only `.en` models are supported
|
| 105 |
- Very basic greedy sampling scheme - always pick up the top token
|
| 106 |
- No timestamps
|
| 107 |
-
- English only
|
| 108 |
- Inference only
|
| 109 |
- Runs on the CPU
|
| 110 |
- Only mono-channel 16-bit WAV is supported
|
|
@@ -113,10 +142,11 @@ ffmpeg -i input.mp3 -ar 16000 -ac 1 -c:a pcm_s16le output.wav
|
|
| 113 |
|
| 114 |
| Model | Disk | Mem |
|
| 115 |
| --- | --- | --- |
|
| 116 |
-
| tiny
|
| 117 |
-
| base
|
| 118 |
-
| small
|
| 119 |
-
| medium
|
|
|
|
| 120 |
|
| 121 |
## ggml format
|
| 122 |
|
|
|
|
| 4 |
|
| 5 |
- Plain C/C++ implementation without dependencies
|
| 6 |
- ARM_NEON and AVX intrinsics support
|
| 7 |
+
- Mixed F16 / F32 support
|
| 8 |
+
- Low memory usage (Flash Attention + Flash Forward)
|
| 9 |
|
| 10 |
## Usage
|
| 11 |
|
|
|
|
| 28 |
```bash
|
| 29 |
$ make base.en
|
| 30 |
|
| 31 |
+
gcc -pthread -O3 -mavx -mavx2 -mfma -mf16c -c ggml.c
|
| 32 |
+
g++ -pthread -O3 -std=c++11 -c main.cpp
|
| 33 |
+
g++ -o main ggml.o main.o
|
| 34 |
+
./main -h
|
| 35 |
+
|
| 36 |
+
usage: ./main [options]
|
| 37 |
+
|
| 38 |
+
options:
|
| 39 |
+
-h, --help show this help message and exit
|
| 40 |
+
-s SEED, --seed SEED RNG seed (default: -1)
|
| 41 |
+
-t N, --threads N number of threads to use during computation (default: 4)
|
| 42 |
+
-T N, --tokens N maximum number of tokens to generate per iteration (default: 64)
|
| 43 |
+
-v, --verbose verbose output
|
| 44 |
+
--translate translate from source language to english
|
| 45 |
+
-ps, --print_special print special tokens
|
| 46 |
+
-l LANG, --language LANG spoken language (default: en)
|
| 47 |
+
-m FNAME, --model FNAME model path (default: models/ggml-base.en.bin)
|
| 48 |
+
-f FNAME, --file FNAME input WAV file path (default: samples/jfk.wav)
|
| 49 |
+
|
| 50 |
+
bash ./download-ggml-model.sh base.en
|
| 51 |
+
Downloading ggml model base.en ...
|
| 52 |
+
models/ggml-base.en.bin 100%[=====================================>] 141.11M 7.84MB/s in 18s
|
| 53 |
+
Done! Model 'base.en' saved in 'models/ggml-base.en.bin'
|
| 54 |
+
You can now use it like this:
|
| 55 |
+
|
| 56 |
+
$ ./main -m models/ggml-base.en.bin -f samples/jfk.wav
|
| 57 |
+
|
| 58 |
|
| 59 |
===============================================
|
| 60 |
Running base.en on all samples in ./samples ...
|
|
|
|
| 77 |
whisper_model_load: n_mels = 80
|
| 78 |
whisper_model_load: f16 = 1
|
| 79 |
whisper_model_load: type = 2
|
| 80 |
+
whisper_model_load: mem_required = 611.00 MB
|
| 81 |
whisper_model_load: adding 1607 extra tokens
|
| 82 |
+
whisper_model_load: ggml ctx size = 163.43 MB
|
| 83 |
+
whisper_model_load: memory size = 22.83 MB
|
| 84 |
whisper_model_load: model size = 140.54 MB
|
| 85 |
log_mel_spectrogram: n_sample = 176000, n_len = 1100
|
| 86 |
log_mel_spectrogram: recording length: 11.000000 s
|
| 87 |
|
| 88 |
+
main: processing 176000 samples (11.0 sec), 4 threads, lang = english, task = transcribe ...
|
| 89 |
|
| 90 |
+
And so my fellow Americans ask not what your country can do for you. Ask what you can do for your country.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 91 |
|
| 92 |
+
main: load time = 71.89 ms
|
| 93 |
+
main: mel time = 36.95 ms
|
| 94 |
+
main: sample time = 2.10 ms
|
| 95 |
+
main: encode time = 700.94 ms / 116.82 ms per layer
|
| 96 |
+
main: decode time = 86.14 ms
|
| 97 |
+
main: total time = 898.72 ms
|
| 98 |
```
|
| 99 |
|
| 100 |
The command downloads the `base.en` model converted to custom `ggml` format and runs the inference on all `.wav` samples in the folder `samples`.
|
|
|
|
| 107 |
|
| 108 |
This will download a few more audio files from Wikipedia and convert them to 16-bit WAV format via `ffmpeg`.
|
| 109 |
|
| 110 |
+
You can download and run the other models as follows:
|
| 111 |
|
| 112 |
```
|
| 113 |
make tiny.en
|
| 114 |
+
make tiny
|
| 115 |
make base.en
|
| 116 |
+
make base
|
| 117 |
make small.en
|
| 118 |
+
make small
|
| 119 |
make medium.en
|
| 120 |
+
make medium
|
| 121 |
+
make large
|
| 122 |
```
|
| 123 |
|
| 124 |
For detailed usage instructions, run: `./main -h`
|
|
|
|
| 132 |
|
| 133 |
## Limitations
|
| 134 |
|
|
|
|
| 135 |
- Very basic greedy sampling scheme - always pick up the top token
|
| 136 |
- No timestamps
|
|
|
|
| 137 |
- Inference only
|
| 138 |
- Runs on the CPU
|
| 139 |
- Only mono-channel 16-bit WAV is supported
|
|
|
|
| 142 |
|
| 143 |
| Model | Disk | Mem |
|
| 144 |
| --- | --- | --- |
|
| 145 |
+
| tiny | 75 MB | ~460 MB |
|
| 146 |
+
| base | 142 MB | ~620 MB |
|
| 147 |
+
| small | 466 MB | ~1.3 GB |
|
| 148 |
+
| medium | 1.5 GB | ~2.8 GB |
|
| 149 |
+
| large | 2.9 GB | ~4.9 GB |
|
| 150 |
|
| 151 |
## ggml format
|
| 152 |
|
download-ggml-model.sh
CHANGED
|
@@ -6,7 +6,7 @@
|
|
| 6 |
ggml_path=$(dirname $(realpath $0))
|
| 7 |
|
| 8 |
# Whisper models
|
| 9 |
-
models=( "tiny.en" "base.en" "small.en" "medium.en" )
|
| 10 |
|
| 11 |
# list available models
|
| 12 |
function list_models {
|
|
|
|
| 6 |
ggml_path=$(dirname $(realpath $0))
|
| 7 |
|
| 8 |
# Whisper models
|
| 9 |
+
models=( "tiny.en" "tiny" "base.en" "base" "small.en" "small" "medium.en" "medium" "large" )
|
| 10 |
|
| 11 |
# list available models
|
| 12 |
function list_models {
|
ggml.c
CHANGED
|
@@ -20,7 +20,13 @@
|
|
| 20 |
#define UNUSED(x) (void)(x)
|
| 21 |
#define SWAP(x, y, T) do { T SWAP = x; x = y; y = SWAP; } while (0)
|
| 22 |
|
| 23 |
-
#define GGML_ASSERT(x)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 24 |
|
| 25 |
#ifdef GGML_USE_ACCELERATE
|
| 26 |
#include <Accelerate/Accelerate.h>
|
|
@@ -118,6 +124,16 @@ ggml_fp16_t ggml_fp32_to_fp16(float f) {
|
|
| 118 |
}
|
| 119 |
#endif
|
| 120 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 121 |
//
|
| 122 |
// timing
|
| 123 |
//
|
|
@@ -331,7 +347,6 @@ inline static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_t
|
|
| 331 |
|
| 332 |
// leftovers
|
| 333 |
for (int i = n32; i < n; ++i) {
|
| 334 |
-
GGML_ASSERT(false); // should not end up here
|
| 335 |
sumf += ggml_fp16_to_fp32(x[i])*ggml_fp16_to_fp32(y[i]);
|
| 336 |
}
|
| 337 |
#elif defined(__AVX2__)
|
|
@@ -375,7 +390,7 @@ inline static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_t
|
|
| 375 |
|
| 376 |
// leftovers
|
| 377 |
for (int i = n32; i < n; ++i) {
|
| 378 |
-
GGML_ASSERT(false);
|
| 379 |
sumf += ggml_fp16_to_fp32(x[i])*ggml_fp16_to_fp32(y[i]);
|
| 380 |
}
|
| 381 |
#else
|
|
@@ -558,12 +573,20 @@ inline static void ggml_vec_relu_f32 (const int n, float * y, const float * x) {
|
|
| 558 |
const ggml_float GELU_COEF_A = 0.044715;
|
| 559 |
const ggml_float SQRT_2_OVER_PI = 0.79788456080286535587989211986876;
|
| 560 |
|
| 561 |
-
inline static
|
|
|
|
|
|
|
|
|
|
|
|
|
| 562 |
for (int i = 0; i < n; ++i) {
|
| 563 |
-
|
| 564 |
-
|
| 565 |
-
|
| 566 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 567 |
}
|
| 568 |
}
|
| 569 |
|
|
@@ -641,6 +664,9 @@ const char * GGML_OP_LABEL[GGML_OP_COUNT] = {
|
|
| 641 |
"ROPE",
|
| 642 |
"CONV_1D_1S",
|
| 643 |
"CONV_1D_2S",
|
|
|
|
|
|
|
|
|
|
| 644 |
};
|
| 645 |
|
| 646 |
const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
|
|
@@ -678,6 +704,9 @@ const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
|
|
| 678 |
"rope(x)",
|
| 679 |
"conv_1d_1s(x)",
|
| 680 |
"conv_1d_2s(x)",
|
|
|
|
|
|
|
|
|
|
| 681 |
};
|
| 682 |
|
| 683 |
//
|
|
@@ -878,6 +907,24 @@ int ggml_up64(int n) {
|
|
| 878 |
////////////////////////////////////////////////////////////////////////////////
|
| 879 |
|
| 880 |
struct ggml_context * ggml_init(struct ggml_init_params params) {
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 881 |
// find non-used context in g_state
|
| 882 |
struct ggml_context * ctx = NULL;
|
| 883 |
|
|
@@ -900,7 +947,7 @@ struct ggml_context * ggml_init(struct ggml_init_params params) {
|
|
| 900 |
}
|
| 901 |
|
| 902 |
if (ctx == NULL) {
|
| 903 |
-
GGML_PRINT_DEBUG("%s
|
| 904 |
return NULL;
|
| 905 |
}
|
| 906 |
|
|
@@ -923,8 +970,8 @@ void ggml_free(struct ggml_context * ctx) {
|
|
| 923 |
if (&g_state.contexts[i].context == ctx) {
|
| 924 |
g_state.contexts[i].used = false;
|
| 925 |
|
| 926 |
-
GGML_PRINT_DEBUG("
|
| 927 |
-
i, ctx->n_objects, ctx->objects_end->offset + ctx->objects_end->size);
|
| 928 |
|
| 929 |
if (ctx->mem_buffer_owned) {
|
| 930 |
free(ctx->mem_buffer);
|
|
@@ -1010,6 +1057,7 @@ struct ggml_tensor * ggml_new_tensor_impl(
|
|
| 1010 |
/*.grad =*/ NULL,
|
| 1011 |
/*.src0 =*/ NULL,
|
| 1012 |
/*.src1 =*/ NULL,
|
|
|
|
| 1013 |
/*.n_tasks =*/ 0,
|
| 1014 |
/*.perf_runs =*/ 0,
|
| 1015 |
/*.perf_cycles =*/ 0,
|
|
@@ -1079,6 +1127,14 @@ struct ggml_tensor * ggml_new_tensor_4d(
|
|
| 1079 |
return ggml_new_tensor(ctx, type, 4, ne);
|
| 1080 |
}
|
| 1081 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1082 |
struct ggml_tensor * ggml_new_f32(struct ggml_context * ctx, float value) {
|
| 1083 |
struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1);
|
| 1084 |
|
|
@@ -1096,6 +1152,58 @@ struct ggml_tensor * ggml_set_zero(struct ggml_tensor * tensor) {
|
|
| 1096 |
return tensor;
|
| 1097 |
}
|
| 1098 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1099 |
struct ggml_tensor * ggml_set_f32(struct ggml_tensor * tensor, float value) {
|
| 1100 |
const int n = ggml_nrows(tensor);
|
| 1101 |
const int nc = tensor->ne[0];
|
|
@@ -1148,40 +1256,109 @@ struct ggml_tensor * ggml_set_f32(struct ggml_tensor * tensor, float value) {
|
|
| 1148 |
return tensor;
|
| 1149 |
}
|
| 1150 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1151 |
float ggml_get_f32_1d(const struct ggml_tensor * tensor, int i) {
|
| 1152 |
switch (tensor->type) {
|
| 1153 |
case GGML_TYPE_I8:
|
| 1154 |
{
|
| 1155 |
-
|
| 1156 |
return ((int8_t *)(tensor->data))[i];
|
| 1157 |
} break;
|
| 1158 |
case GGML_TYPE_I16:
|
| 1159 |
{
|
| 1160 |
-
|
| 1161 |
return ((int16_t *)(tensor->data))[i];
|
| 1162 |
} break;
|
| 1163 |
case GGML_TYPE_I32:
|
| 1164 |
{
|
| 1165 |
-
|
| 1166 |
return ((int32_t *)(tensor->data))[i];
|
| 1167 |
} break;
|
| 1168 |
case GGML_TYPE_F16:
|
| 1169 |
{
|
| 1170 |
-
|
| 1171 |
return ggml_fp16_to_fp32(((ggml_fp16_t *)(tensor->data))[i]);
|
| 1172 |
} break;
|
| 1173 |
case GGML_TYPE_F32:
|
| 1174 |
{
|
| 1175 |
-
|
| 1176 |
return ((float *)(tensor->data))[i];
|
| 1177 |
} break;
|
| 1178 |
case GGML_TYPE_COUNT:
|
| 1179 |
{
|
| 1180 |
-
|
| 1181 |
} break;
|
| 1182 |
}
|
| 1183 |
|
| 1184 |
-
assert(false);
|
| 1185 |
return 0.0f;
|
| 1186 |
}
|
| 1187 |
|
|
@@ -1189,32 +1366,32 @@ void ggml_set_f32_1d(const struct ggml_tensor * tensor, int i, float value) {
|
|
| 1189 |
switch (tensor->type) {
|
| 1190 |
case GGML_TYPE_I8:
|
| 1191 |
{
|
| 1192 |
-
|
| 1193 |
((int8_t *)(tensor->data))[i] = value;
|
| 1194 |
} break;
|
| 1195 |
case GGML_TYPE_I16:
|
| 1196 |
{
|
| 1197 |
-
|
| 1198 |
((int16_t *)(tensor->data))[i] = value;
|
| 1199 |
} break;
|
| 1200 |
case GGML_TYPE_I32:
|
| 1201 |
{
|
| 1202 |
-
|
| 1203 |
((int32_t *)(tensor->data))[i] = value;
|
| 1204 |
} break;
|
| 1205 |
case GGML_TYPE_F16:
|
| 1206 |
{
|
| 1207 |
-
|
| 1208 |
((ggml_fp16_t *)(tensor->data))[i] = ggml_fp32_to_fp16(value);
|
| 1209 |
} break;
|
| 1210 |
case GGML_TYPE_F32:
|
| 1211 |
{
|
| 1212 |
-
|
| 1213 |
((float *)(tensor->data))[i] = value;
|
| 1214 |
} break;
|
| 1215 |
case GGML_TYPE_COUNT:
|
| 1216 |
{
|
| 1217 |
-
|
| 1218 |
} break;
|
| 1219 |
}
|
| 1220 |
}
|
|
@@ -2308,6 +2485,70 @@ struct ggml_tensor * ggml_conv_1d_2s(
|
|
| 2308 |
return result;
|
| 2309 |
}
|
| 2310 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2311 |
////////////////////////////////////////////////////////////////////////////////
|
| 2312 |
|
| 2313 |
void ggml_set_param(
|
|
@@ -2415,7 +2656,7 @@ void ggml_compute_forward_dup_f32(
|
|
| 2415 |
GGML_ASSERT(false); // TODO: implement
|
| 2416 |
}
|
| 2417 |
} else {
|
| 2418 |
-
printf("%s: this is not optimal - fix me\n", __func__);
|
| 2419 |
|
| 2420 |
if (dst->type == GGML_TYPE_F32) {
|
| 2421 |
int id = 0;
|
|
@@ -4185,10 +4426,17 @@ void ggml_compute_forward_soft_max_f32(
|
|
| 4185 |
}
|
| 4186 |
|
| 4187 |
ggml_float sum = 0.0;
|
|
|
|
| 4188 |
for (int i = 0; i < nc; i++) {
|
| 4189 |
-
|
| 4190 |
-
|
| 4191 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4192 |
}
|
| 4193 |
|
| 4194 |
assert(sum > 0.0f);
|
|
@@ -4362,7 +4610,6 @@ void ggml_compute_forward_conv_1d_1s_f16_f32(
|
|
| 4362 |
GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
|
| 4363 |
GGML_ASSERT(nb10 == sizeof(float));
|
| 4364 |
|
| 4365 |
-
// WHISPER
|
| 4366 |
if (params->type == GGML_TASK_INIT) {
|
| 4367 |
// TODO: fix this memset (wsize is overestimated)
|
| 4368 |
memset(params->wdata, 0, params->wsize);
|
|
@@ -4483,7 +4730,6 @@ void ggml_compute_forward_conv_1d_1s_f32(
|
|
| 4483 |
GGML_ASSERT(nb00 == sizeof(float));
|
| 4484 |
GGML_ASSERT(nb10 == sizeof(float));
|
| 4485 |
|
| 4486 |
-
// WHISPER
|
| 4487 |
if (params->type == GGML_TASK_INIT) {
|
| 4488 |
// TODO: fix this memset (wsize is overestimated)
|
| 4489 |
memset(params->wdata, 0, params->wsize);
|
|
@@ -4630,7 +4876,6 @@ void ggml_compute_forward_conv_1d_2s_f16_f32(
|
|
| 4630 |
GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
|
| 4631 |
GGML_ASSERT(nb10 == sizeof(float));
|
| 4632 |
|
| 4633 |
-
// WHISPER
|
| 4634 |
if (params->type == GGML_TASK_INIT) {
|
| 4635 |
// TODO: fix this memset (wsize is overestimated)
|
| 4636 |
memset(params->wdata, 0, params->wsize);
|
|
@@ -4751,7 +4996,6 @@ void ggml_compute_forward_conv_1d_2s_f32(
|
|
| 4751 |
GGML_ASSERT(nb00 == sizeof(float));
|
| 4752 |
GGML_ASSERT(nb10 == sizeof(float));
|
| 4753 |
|
| 4754 |
-
// WHISPER
|
| 4755 |
if (params->type == GGML_TASK_INIT) {
|
| 4756 |
// TODO: fix this memset (wsize is overestimated)
|
| 4757 |
memset(params->wdata, 0, params->wsize);
|
|
@@ -4841,6 +5085,607 @@ void ggml_compute_forward_conv_1d_2s(
|
|
| 4841 |
}
|
| 4842 |
}
|
| 4843 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4844 |
/////////////////////////////////
|
| 4845 |
|
| 4846 |
void ggml_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor) {
|
|
@@ -4967,13 +5812,24 @@ void ggml_compute_forward(struct ggml_compute_params * params, struct ggml_tenso
|
|
| 4967 |
{
|
| 4968 |
ggml_compute_forward_conv_1d_2s(params, tensor->src0, tensor->src1, tensor);
|
| 4969 |
} break;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4970 |
case GGML_OP_NONE:
|
| 4971 |
{
|
| 4972 |
// nop
|
| 4973 |
} break;
|
| 4974 |
case GGML_OP_COUNT:
|
| 4975 |
{
|
| 4976 |
-
|
| 4977 |
} break;
|
| 4978 |
};
|
| 4979 |
}
|
|
@@ -5205,6 +6061,14 @@ void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor * tenso
|
|
| 5205 |
{
|
| 5206 |
GGML_ASSERT(false); // TODO: not implemented
|
| 5207 |
} break;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5208 |
case GGML_OP_NONE:
|
| 5209 |
{
|
| 5210 |
// nop
|
|
@@ -5246,6 +6110,12 @@ void ggml_visit_parents(struct ggml_cgraph * cgraph, struct ggml_tensor * node)
|
|
| 5246 |
ggml_visit_parents(cgraph, node->src1);
|
| 5247 |
}
|
| 5248 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5249 |
if (node->op == GGML_OP_NONE && node->grad == NULL) {
|
| 5250 |
// reached a leaf node, not part of the gradient graph (e.g. a constant)
|
| 5251 |
assert(cgraph->n_leafs < GGML_MAX_NODES);
|
|
@@ -5591,7 +6461,6 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
|
|
| 5591 |
case GGML_OP_CONV_1D_1S:
|
| 5592 |
case GGML_OP_CONV_1D_2S:
|
| 5593 |
{
|
| 5594 |
-
// WHISPER
|
| 5595 |
node->n_tasks = n_threads;
|
| 5596 |
|
| 5597 |
GGML_ASSERT(node->src0->ne[3] == 1);
|
|
@@ -5617,6 +6486,42 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
|
|
| 5617 |
GGML_ASSERT(false);
|
| 5618 |
}
|
| 5619 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5620 |
work_size = MAX(work_size, cur);
|
| 5621 |
} break;
|
| 5622 |
case GGML_OP_NONE:
|
|
|
|
| 20 |
#define UNUSED(x) (void)(x)
|
| 21 |
#define SWAP(x, y, T) do { T SWAP = x; x = y; y = SWAP; } while (0)
|
| 22 |
|
| 23 |
+
#define GGML_ASSERT(x) \
|
| 24 |
+
do { \
|
| 25 |
+
if (!(x)) { \
|
| 26 |
+
fprintf(stderr, "GGML_ASSERT: %s:%d: %s\n", __FILE__, __LINE__, #x); \
|
| 27 |
+
abort(); \
|
| 28 |
+
} \
|
| 29 |
+
} while (0)
|
| 30 |
|
| 31 |
#ifdef GGML_USE_ACCELERATE
|
| 32 |
#include <Accelerate/Accelerate.h>
|
|
|
|
| 124 |
}
|
| 125 |
#endif
|
| 126 |
|
| 127 |
+
//
|
| 128 |
+
// global data
|
| 129 |
+
//
|
| 130 |
+
|
| 131 |
+
// precomputed gelu table for f16 (128 KB)
|
| 132 |
+
static ggml_fp16_t table_gelu_f16[1 << 16];
|
| 133 |
+
|
| 134 |
+
// precomputed exp table for f16 (128 KB)
|
| 135 |
+
static ggml_fp16_t table_exp_f16[1 << 16];
|
| 136 |
+
|
| 137 |
//
|
| 138 |
// timing
|
| 139 |
//
|
|
|
|
| 347 |
|
| 348 |
// leftovers
|
| 349 |
for (int i = n32; i < n; ++i) {
|
|
|
|
| 350 |
sumf += ggml_fp16_to_fp32(x[i])*ggml_fp16_to_fp32(y[i]);
|
| 351 |
}
|
| 352 |
#elif defined(__AVX2__)
|
|
|
|
| 390 |
|
| 391 |
// leftovers
|
| 392 |
for (int i = n32; i < n; ++i) {
|
| 393 |
+
//GGML_ASSERT(false);
|
| 394 |
sumf += ggml_fp16_to_fp32(x[i])*ggml_fp16_to_fp32(y[i]);
|
| 395 |
}
|
| 396 |
#else
|
|
|
|
| 573 |
const ggml_float GELU_COEF_A = 0.044715;
|
| 574 |
const ggml_float SQRT_2_OVER_PI = 0.79788456080286535587989211986876;
|
| 575 |
|
| 576 |
+
inline static float ggml_gelu_f32(float x) {
|
| 577 |
+
return 0.5*x*(1.0 + tanh(SQRT_2_OVER_PI*x*(1.0 + GELU_COEF_A*x*x)));
|
| 578 |
+
}
|
| 579 |
+
|
| 580 |
+
inline static void ggml_vec_gelu_f32(const int n, float * y, const float * x) {
|
| 581 |
for (int i = 0; i < n; ++i) {
|
| 582 |
+
y[i] = ggml_gelu_f32(x[i]);
|
| 583 |
+
}
|
| 584 |
+
}
|
| 585 |
+
|
| 586 |
+
inline static void ggml_vec_gelu_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
|
| 587 |
+
const uint16_t * i16 = (const uint16_t *) x;
|
| 588 |
+
for (int i = 0; i < n; ++i) {
|
| 589 |
+
y[i] = table_gelu_f16[i16[i]];
|
| 590 |
}
|
| 591 |
}
|
| 592 |
|
|
|
|
| 664 |
"ROPE",
|
| 665 |
"CONV_1D_1S",
|
| 666 |
"CONV_1D_2S",
|
| 667 |
+
|
| 668 |
+
"FLASH_ATTN",
|
| 669 |
+
"FLASH_FF",
|
| 670 |
};
|
| 671 |
|
| 672 |
const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
|
|
|
|
| 704 |
"rope(x)",
|
| 705 |
"conv_1d_1s(x)",
|
| 706 |
"conv_1d_2s(x)",
|
| 707 |
+
|
| 708 |
+
"flash_attn(x)",
|
| 709 |
+
"flash_ff(x)",
|
| 710 |
};
|
| 711 |
|
| 712 |
//
|
|
|
|
| 907 |
////////////////////////////////////////////////////////////////////////////////
|
| 908 |
|
| 909 |
struct ggml_context * ggml_init(struct ggml_init_params params) {
|
| 910 |
+
static bool is_first_call = true;
|
| 911 |
+
if (is_first_call) {
|
| 912 |
+
const uint64_t t_start = ggml_time_us(); UNUSED(t_start);
|
| 913 |
+
|
| 914 |
+
for (int i = 0; i < (1 << 16); ++i) {
|
| 915 |
+
uint16_t ii = (uint16_t) i;
|
| 916 |
+
const float f = ggml_fp16_to_fp32(*(ggml_fp16_t *)(&ii));
|
| 917 |
+
table_gelu_f16[i] = ggml_fp32_to_fp16(ggml_gelu_f32(f));
|
| 918 |
+
table_exp_f16[i] = ggml_fp32_to_fp16(exp(f));
|
| 919 |
+
}
|
| 920 |
+
|
| 921 |
+
const uint64_t t_end = ggml_time_us(); UNUSED(t_end);
|
| 922 |
+
|
| 923 |
+
GGML_PRINT_DEBUG("%s: GELU table initialized in %f ms\n", __func__, (t_end - t_start)/1000.0f);
|
| 924 |
+
|
| 925 |
+
is_first_call = false;
|
| 926 |
+
}
|
| 927 |
+
|
| 928 |
// find non-used context in g_state
|
| 929 |
struct ggml_context * ctx = NULL;
|
| 930 |
|
|
|
|
| 947 |
}
|
| 948 |
|
| 949 |
if (ctx == NULL) {
|
| 950 |
+
GGML_PRINT_DEBUG("%s: no unused context found\n", __func__);
|
| 951 |
return NULL;
|
| 952 |
}
|
| 953 |
|
|
|
|
| 970 |
if (&g_state.contexts[i].context == ctx) {
|
| 971 |
g_state.contexts[i].used = false;
|
| 972 |
|
| 973 |
+
GGML_PRINT_DEBUG("%s: context %d with %d objects has been freed. memory used = %zu\n",
|
| 974 |
+
__func__, i, ctx->n_objects, ctx->objects_end->offset + ctx->objects_end->size);
|
| 975 |
|
| 976 |
if (ctx->mem_buffer_owned) {
|
| 977 |
free(ctx->mem_buffer);
|
|
|
|
| 1057 |
/*.grad =*/ NULL,
|
| 1058 |
/*.src0 =*/ NULL,
|
| 1059 |
/*.src1 =*/ NULL,
|
| 1060 |
+
/*.opt =*/ { NULL },
|
| 1061 |
/*.n_tasks =*/ 0,
|
| 1062 |
/*.perf_runs =*/ 0,
|
| 1063 |
/*.perf_cycles =*/ 0,
|
|
|
|
| 1127 |
return ggml_new_tensor(ctx, type, 4, ne);
|
| 1128 |
}
|
| 1129 |
|
| 1130 |
+
struct ggml_tensor * ggml_new_i32(struct ggml_context * ctx, int32_t value) {
|
| 1131 |
+
struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 1);
|
| 1132 |
+
|
| 1133 |
+
ggml_set_i32(result, value);
|
| 1134 |
+
|
| 1135 |
+
return result;
|
| 1136 |
+
}
|
| 1137 |
+
|
| 1138 |
struct ggml_tensor * ggml_new_f32(struct ggml_context * ctx, float value) {
|
| 1139 |
struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1);
|
| 1140 |
|
|
|
|
| 1152 |
return tensor;
|
| 1153 |
}
|
| 1154 |
|
| 1155 |
+
struct ggml_tensor * ggml_set_i32 (struct ggml_tensor * tensor, int32_t value) {
|
| 1156 |
+
const int n = ggml_nrows(tensor);
|
| 1157 |
+
const int nc = tensor->ne[0];
|
| 1158 |
+
const size_t n1 = tensor->nb[1];
|
| 1159 |
+
|
| 1160 |
+
char * const data = tensor->data;
|
| 1161 |
+
|
| 1162 |
+
switch (tensor->type) {
|
| 1163 |
+
case GGML_TYPE_I8:
|
| 1164 |
+
{
|
| 1165 |
+
assert(tensor->nb[0] == sizeof(int8_t));
|
| 1166 |
+
for (int i = 0; i < n; i++) {
|
| 1167 |
+
ggml_vec_set_i8(nc, (int8_t *)(data + i*n1), value);
|
| 1168 |
+
}
|
| 1169 |
+
} break;
|
| 1170 |
+
case GGML_TYPE_I16:
|
| 1171 |
+
{
|
| 1172 |
+
assert(tensor->nb[0] == sizeof(int16_t));
|
| 1173 |
+
for (int i = 0; i < n; i++) {
|
| 1174 |
+
ggml_vec_set_i16(nc, (int16_t *)(data + i*n1), value);
|
| 1175 |
+
}
|
| 1176 |
+
} break;
|
| 1177 |
+
case GGML_TYPE_I32:
|
| 1178 |
+
{
|
| 1179 |
+
assert(tensor->nb[0] == sizeof(int32_t));
|
| 1180 |
+
for (int i = 0; i < n; i++) {
|
| 1181 |
+
ggml_vec_set_i32(nc, (int32_t *)(data + i*n1), value);
|
| 1182 |
+
}
|
| 1183 |
+
} break;
|
| 1184 |
+
case GGML_TYPE_F16:
|
| 1185 |
+
{
|
| 1186 |
+
assert(tensor->nb[0] == sizeof(ggml_fp16_t));
|
| 1187 |
+
for (int i = 0; i < n; i++) {
|
| 1188 |
+
ggml_vec_set_f16(nc, (ggml_fp16_t *)(data + i*n1), value);
|
| 1189 |
+
}
|
| 1190 |
+
} break;
|
| 1191 |
+
case GGML_TYPE_F32:
|
| 1192 |
+
{
|
| 1193 |
+
assert(tensor->nb[0] == sizeof(float));
|
| 1194 |
+
for (int i = 0; i < n; i++) {
|
| 1195 |
+
ggml_vec_set_f32(nc, (float *)(data + i*n1), value);
|
| 1196 |
+
}
|
| 1197 |
+
} break;
|
| 1198 |
+
case GGML_TYPE_COUNT:
|
| 1199 |
+
{
|
| 1200 |
+
assert(false);
|
| 1201 |
+
} break;
|
| 1202 |
+
}
|
| 1203 |
+
|
| 1204 |
+
return tensor;
|
| 1205 |
+
}
|
| 1206 |
+
|
| 1207 |
struct ggml_tensor * ggml_set_f32(struct ggml_tensor * tensor, float value) {
|
| 1208 |
const int n = ggml_nrows(tensor);
|
| 1209 |
const int nc = tensor->ne[0];
|
|
|
|
| 1256 |
return tensor;
|
| 1257 |
}
|
| 1258 |
|
| 1259 |
+
int32_t ggml_get_i32_1d(const struct ggml_tensor * tensor, int i) {
|
| 1260 |
+
switch (tensor->type) {
|
| 1261 |
+
case GGML_TYPE_I8:
|
| 1262 |
+
{
|
| 1263 |
+
GGML_ASSERT(tensor->nb[0] == sizeof(int8_t));
|
| 1264 |
+
return ((int8_t *)(tensor->data))[i];
|
| 1265 |
+
} break;
|
| 1266 |
+
case GGML_TYPE_I16:
|
| 1267 |
+
{
|
| 1268 |
+
GGML_ASSERT(tensor->nb[0] == sizeof(int16_t));
|
| 1269 |
+
return ((int16_t *)(tensor->data))[i];
|
| 1270 |
+
} break;
|
| 1271 |
+
case GGML_TYPE_I32:
|
| 1272 |
+
{
|
| 1273 |
+
GGML_ASSERT(tensor->nb[0] == sizeof(int32_t));
|
| 1274 |
+
return ((int32_t *)(tensor->data))[i];
|
| 1275 |
+
} break;
|
| 1276 |
+
case GGML_TYPE_F16:
|
| 1277 |
+
{
|
| 1278 |
+
GGML_ASSERT(tensor->nb[0] == sizeof(ggml_fp16_t));
|
| 1279 |
+
return ggml_fp16_to_fp32(((ggml_fp16_t *)(tensor->data))[i]);
|
| 1280 |
+
} break;
|
| 1281 |
+
case GGML_TYPE_F32:
|
| 1282 |
+
{
|
| 1283 |
+
GGML_ASSERT(tensor->nb[0] == sizeof(float));
|
| 1284 |
+
return ((float *)(tensor->data))[i];
|
| 1285 |
+
} break;
|
| 1286 |
+
case GGML_TYPE_COUNT:
|
| 1287 |
+
{
|
| 1288 |
+
GGML_ASSERT(false);
|
| 1289 |
+
} break;
|
| 1290 |
+
}
|
| 1291 |
+
|
| 1292 |
+
return 0.0f;
|
| 1293 |
+
}
|
| 1294 |
+
|
| 1295 |
+
void ggml_set_i32_1d(const struct ggml_tensor * tensor, int i, int32_t value) {
|
| 1296 |
+
switch (tensor->type) {
|
| 1297 |
+
case GGML_TYPE_I8:
|
| 1298 |
+
{
|
| 1299 |
+
GGML_ASSERT(tensor->nb[0] == sizeof(int8_t));
|
| 1300 |
+
((int8_t *)(tensor->data))[i] = value;
|
| 1301 |
+
} break;
|
| 1302 |
+
case GGML_TYPE_I16:
|
| 1303 |
+
{
|
| 1304 |
+
GGML_ASSERT(tensor->nb[0] == sizeof(int16_t));
|
| 1305 |
+
((int16_t *)(tensor->data))[i] = value;
|
| 1306 |
+
} break;
|
| 1307 |
+
case GGML_TYPE_I32:
|
| 1308 |
+
{
|
| 1309 |
+
GGML_ASSERT(tensor->nb[0] == sizeof(int32_t));
|
| 1310 |
+
((int32_t *)(tensor->data))[i] = value;
|
| 1311 |
+
} break;
|
| 1312 |
+
case GGML_TYPE_F16:
|
| 1313 |
+
{
|
| 1314 |
+
GGML_ASSERT(tensor->nb[0] == sizeof(ggml_fp16_t));
|
| 1315 |
+
((ggml_fp16_t *)(tensor->data))[i] = ggml_fp32_to_fp16(value);
|
| 1316 |
+
} break;
|
| 1317 |
+
case GGML_TYPE_F32:
|
| 1318 |
+
{
|
| 1319 |
+
GGML_ASSERT(tensor->nb[0] == sizeof(float));
|
| 1320 |
+
((float *)(tensor->data))[i] = value;
|
| 1321 |
+
} break;
|
| 1322 |
+
case GGML_TYPE_COUNT:
|
| 1323 |
+
{
|
| 1324 |
+
GGML_ASSERT(false);
|
| 1325 |
+
} break;
|
| 1326 |
+
}
|
| 1327 |
+
}
|
| 1328 |
+
|
| 1329 |
float ggml_get_f32_1d(const struct ggml_tensor * tensor, int i) {
|
| 1330 |
switch (tensor->type) {
|
| 1331 |
case GGML_TYPE_I8:
|
| 1332 |
{
|
| 1333 |
+
GGML_ASSERT(tensor->nb[0] == sizeof(int8_t));
|
| 1334 |
return ((int8_t *)(tensor->data))[i];
|
| 1335 |
} break;
|
| 1336 |
case GGML_TYPE_I16:
|
| 1337 |
{
|
| 1338 |
+
GGML_ASSERT(tensor->nb[0] == sizeof(int16_t));
|
| 1339 |
return ((int16_t *)(tensor->data))[i];
|
| 1340 |
} break;
|
| 1341 |
case GGML_TYPE_I32:
|
| 1342 |
{
|
| 1343 |
+
GGML_ASSERT(tensor->nb[0] == sizeof(int32_t));
|
| 1344 |
return ((int32_t *)(tensor->data))[i];
|
| 1345 |
} break;
|
| 1346 |
case GGML_TYPE_F16:
|
| 1347 |
{
|
| 1348 |
+
GGML_ASSERT(tensor->nb[0] == sizeof(ggml_fp16_t));
|
| 1349 |
return ggml_fp16_to_fp32(((ggml_fp16_t *)(tensor->data))[i]);
|
| 1350 |
} break;
|
| 1351 |
case GGML_TYPE_F32:
|
| 1352 |
{
|
| 1353 |
+
GGML_ASSERT(tensor->nb[0] == sizeof(float));
|
| 1354 |
return ((float *)(tensor->data))[i];
|
| 1355 |
} break;
|
| 1356 |
case GGML_TYPE_COUNT:
|
| 1357 |
{
|
| 1358 |
+
GGML_ASSERT(false);
|
| 1359 |
} break;
|
| 1360 |
}
|
| 1361 |
|
|
|
|
| 1362 |
return 0.0f;
|
| 1363 |
}
|
| 1364 |
|
|
|
|
| 1366 |
switch (tensor->type) {
|
| 1367 |
case GGML_TYPE_I8:
|
| 1368 |
{
|
| 1369 |
+
GGML_ASSERT(tensor->nb[0] == sizeof(int8_t));
|
| 1370 |
((int8_t *)(tensor->data))[i] = value;
|
| 1371 |
} break;
|
| 1372 |
case GGML_TYPE_I16:
|
| 1373 |
{
|
| 1374 |
+
GGML_ASSERT(tensor->nb[0] == sizeof(int16_t));
|
| 1375 |
((int16_t *)(tensor->data))[i] = value;
|
| 1376 |
} break;
|
| 1377 |
case GGML_TYPE_I32:
|
| 1378 |
{
|
| 1379 |
+
GGML_ASSERT(tensor->nb[0] == sizeof(int32_t));
|
| 1380 |
((int32_t *)(tensor->data))[i] = value;
|
| 1381 |
} break;
|
| 1382 |
case GGML_TYPE_F16:
|
| 1383 |
{
|
| 1384 |
+
GGML_ASSERT(tensor->nb[0] == sizeof(ggml_fp16_t));
|
| 1385 |
((ggml_fp16_t *)(tensor->data))[i] = ggml_fp32_to_fp16(value);
|
| 1386 |
} break;
|
| 1387 |
case GGML_TYPE_F32:
|
| 1388 |
{
|
| 1389 |
+
GGML_ASSERT(tensor->nb[0] == sizeof(float));
|
| 1390 |
((float *)(tensor->data))[i] = value;
|
| 1391 |
} break;
|
| 1392 |
case GGML_TYPE_COUNT:
|
| 1393 |
{
|
| 1394 |
+
GGML_ASSERT(false);
|
| 1395 |
} break;
|
| 1396 |
}
|
| 1397 |
}
|
|
|
|
| 2485 |
return result;
|
| 2486 |
}
|
| 2487 |
|
| 2488 |
+
// ggml_flash_attn
|
| 2489 |
+
|
| 2490 |
+
struct ggml_tensor * ggml_flash_attn(
|
| 2491 |
+
struct ggml_context * ctx,
|
| 2492 |
+
struct ggml_tensor * q,
|
| 2493 |
+
struct ggml_tensor * k,
|
| 2494 |
+
struct ggml_tensor * v,
|
| 2495 |
+
bool masked) {
|
| 2496 |
+
assert(ggml_can_mul_mat(k, q));
|
| 2497 |
+
// TODO: check if vT can be multiplied by (k*qT)
|
| 2498 |
+
|
| 2499 |
+
bool is_node = false;
|
| 2500 |
+
|
| 2501 |
+
if (q->grad || k->grad || v->grad) {
|
| 2502 |
+
GGML_ASSERT(false); // TODO: implement backward
|
| 2503 |
+
is_node = true;
|
| 2504 |
+
}
|
| 2505 |
+
|
| 2506 |
+
//struct ggml_tensor * result = ggml_dup_tensor(ctx, q);
|
| 2507 |
+
struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, q->ne);
|
| 2508 |
+
|
| 2509 |
+
result->op = GGML_OP_FLASH_ATTN;
|
| 2510 |
+
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
|
| 2511 |
+
result->src0 = q;
|
| 2512 |
+
result->src1 = k;
|
| 2513 |
+
result->opt[0] = v;
|
| 2514 |
+
result->opt[1] = ggml_new_i32(ctx, masked ? 1 : 0);
|
| 2515 |
+
|
| 2516 |
+
return result;
|
| 2517 |
+
}
|
| 2518 |
+
|
| 2519 |
+
// ggml_flash_ff
|
| 2520 |
+
|
| 2521 |
+
struct ggml_tensor * ggml_flash_ff(
|
| 2522 |
+
struct ggml_context * ctx,
|
| 2523 |
+
struct ggml_tensor * a,
|
| 2524 |
+
struct ggml_tensor * b0,
|
| 2525 |
+
struct ggml_tensor * b1,
|
| 2526 |
+
struct ggml_tensor * c0,
|
| 2527 |
+
struct ggml_tensor * c1) {
|
| 2528 |
+
assert(ggml_can_mul_mat(b0, a));
|
| 2529 |
+
// TODO: more checks
|
| 2530 |
+
|
| 2531 |
+
bool is_node = false;
|
| 2532 |
+
|
| 2533 |
+
if (a->grad || b0->grad || b1->grad || c0->grad || c1->grad) {
|
| 2534 |
+
GGML_ASSERT(false); // TODO: implement backward
|
| 2535 |
+
is_node = true;
|
| 2536 |
+
}
|
| 2537 |
+
|
| 2538 |
+
//struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
|
| 2539 |
+
struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, a->ne);
|
| 2540 |
+
|
| 2541 |
+
result->op = GGML_OP_FLASH_FF;
|
| 2542 |
+
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
|
| 2543 |
+
result->src0 = a;
|
| 2544 |
+
result->src1 = b0;
|
| 2545 |
+
result->opt[0] = b1;
|
| 2546 |
+
result->opt[1] = c0;
|
| 2547 |
+
result->opt[2] = c1;
|
| 2548 |
+
|
| 2549 |
+
return result;
|
| 2550 |
+
}
|
| 2551 |
+
|
| 2552 |
////////////////////////////////////////////////////////////////////////////////
|
| 2553 |
|
| 2554 |
void ggml_set_param(
|
|
|
|
| 2656 |
GGML_ASSERT(false); // TODO: implement
|
| 2657 |
}
|
| 2658 |
} else {
|
| 2659 |
+
//printf("%s: this is not optimal - fix me\n", __func__);
|
| 2660 |
|
| 2661 |
if (dst->type == GGML_TYPE_F32) {
|
| 2662 |
int id = 0;
|
|
|
|
| 4426 |
}
|
| 4427 |
|
| 4428 |
ggml_float sum = 0.0;
|
| 4429 |
+
|
| 4430 |
for (int i = 0; i < nc; i++) {
|
| 4431 |
+
if (p[i] == -INFINITY) {
|
| 4432 |
+
p[i] = 0.0;
|
| 4433 |
+
} else {
|
| 4434 |
+
//const float val = (p[i] == -INFINITY) ? 0.0 : exp(p[i] - max);
|
| 4435 |
+
ggml_fp16_t s = ggml_fp32_to_fp16(p[i] - max);
|
| 4436 |
+
const float val = ggml_fp16_to_fp32(table_exp_f16[*(uint16_t *) &s]);
|
| 4437 |
+
sum += val;
|
| 4438 |
+
p[i] = val;
|
| 4439 |
+
}
|
| 4440 |
}
|
| 4441 |
|
| 4442 |
assert(sum > 0.0f);
|
|
|
|
| 4610 |
GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
|
| 4611 |
GGML_ASSERT(nb10 == sizeof(float));
|
| 4612 |
|
|
|
|
| 4613 |
if (params->type == GGML_TASK_INIT) {
|
| 4614 |
// TODO: fix this memset (wsize is overestimated)
|
| 4615 |
memset(params->wdata, 0, params->wsize);
|
|
|
|
| 4730 |
GGML_ASSERT(nb00 == sizeof(float));
|
| 4731 |
GGML_ASSERT(nb10 == sizeof(float));
|
| 4732 |
|
|
|
|
| 4733 |
if (params->type == GGML_TASK_INIT) {
|
| 4734 |
// TODO: fix this memset (wsize is overestimated)
|
| 4735 |
memset(params->wdata, 0, params->wsize);
|
|
|
|
| 4876 |
GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
|
| 4877 |
GGML_ASSERT(nb10 == sizeof(float));
|
| 4878 |
|
|
|
|
| 4879 |
if (params->type == GGML_TASK_INIT) {
|
| 4880 |
// TODO: fix this memset (wsize is overestimated)
|
| 4881 |
memset(params->wdata, 0, params->wsize);
|
|
|
|
| 4996 |
GGML_ASSERT(nb00 == sizeof(float));
|
| 4997 |
GGML_ASSERT(nb10 == sizeof(float));
|
| 4998 |
|
|
|
|
| 4999 |
if (params->type == GGML_TASK_INIT) {
|
| 5000 |
// TODO: fix this memset (wsize is overestimated)
|
| 5001 |
memset(params->wdata, 0, params->wsize);
|
|
|
|
| 5085 |
}
|
| 5086 |
}
|
| 5087 |
|
| 5088 |
+
// ggml_compute_forward_flash_attn
|
| 5089 |
+
|
| 5090 |
+
void ggml_compute_forward_flash_attn_f32(
|
| 5091 |
+
const struct ggml_compute_params * params,
|
| 5092 |
+
const struct ggml_tensor * q,
|
| 5093 |
+
const struct ggml_tensor * k,
|
| 5094 |
+
const struct ggml_tensor * v,
|
| 5095 |
+
const bool masked,
|
| 5096 |
+
struct ggml_tensor * dst) {
|
| 5097 |
+
int64_t t0 = ggml_perf_time_us();
|
| 5098 |
+
UNUSED(t0);
|
| 5099 |
+
|
| 5100 |
+
const int neq0 = q->ne[0];
|
| 5101 |
+
const int neq1 = q->ne[1];
|
| 5102 |
+
const int neq2 = q->ne[2];
|
| 5103 |
+
const int neq3 = q->ne[3];
|
| 5104 |
+
|
| 5105 |
+
const int nek0 = k->ne[0];
|
| 5106 |
+
const int nek1 = k->ne[1];
|
| 5107 |
+
//const int nek2 = k->ne[2];
|
| 5108 |
+
//const int nek3 = k->ne[3];
|
| 5109 |
+
|
| 5110 |
+
//const int nev0 = v->ne[0];
|
| 5111 |
+
const int nev1 = v->ne[1];
|
| 5112 |
+
//const int nev2 = v->ne[2];
|
| 5113 |
+
//const int nev3 = v->ne[3];
|
| 5114 |
+
|
| 5115 |
+
const int ne0 = dst->ne[0];
|
| 5116 |
+
const int ne1 = dst->ne[1];
|
| 5117 |
+
//const int ne2 = dst->ne[2];
|
| 5118 |
+
//const int ne3 = dst->ne[3];
|
| 5119 |
+
|
| 5120 |
+
const int nbk0 = k->nb[0];
|
| 5121 |
+
const int nbk1 = k->nb[1];
|
| 5122 |
+
const int nbk2 = k->nb[2];
|
| 5123 |
+
const int nbk3 = k->nb[3];
|
| 5124 |
+
|
| 5125 |
+
const int nbq0 = q->nb[0];
|
| 5126 |
+
const int nbq1 = q->nb[1];
|
| 5127 |
+
const int nbq2 = q->nb[2];
|
| 5128 |
+
const int nbq3 = q->nb[3];
|
| 5129 |
+
|
| 5130 |
+
const int nbv0 = v->nb[0];
|
| 5131 |
+
const int nbv1 = v->nb[1];
|
| 5132 |
+
const int nbv2 = v->nb[2];
|
| 5133 |
+
const int nbv3 = v->nb[3];
|
| 5134 |
+
|
| 5135 |
+
const int nb0 = dst->nb[0];
|
| 5136 |
+
const int nb1 = dst->nb[1];
|
| 5137 |
+
const int nb2 = dst->nb[2];
|
| 5138 |
+
const int nb3 = dst->nb[3];
|
| 5139 |
+
|
| 5140 |
+
const int ith = params->ith;
|
| 5141 |
+
const int nth = params->nth;
|
| 5142 |
+
|
| 5143 |
+
const int D = neq0;
|
| 5144 |
+
const int N = neq1;
|
| 5145 |
+
const int P = nek1 - N;
|
| 5146 |
+
const int M = P + N;
|
| 5147 |
+
|
| 5148 |
+
GGML_ASSERT(ne0 == D);
|
| 5149 |
+
GGML_ASSERT(ne1 == N);
|
| 5150 |
+
GGML_ASSERT(P >= 0);
|
| 5151 |
+
|
| 5152 |
+
GGML_ASSERT(nbq0 == sizeof(float));
|
| 5153 |
+
GGML_ASSERT(nbk0 == sizeof(float));
|
| 5154 |
+
GGML_ASSERT(nbv0 == sizeof(float));
|
| 5155 |
+
|
| 5156 |
+
GGML_ASSERT(neq0 == D);
|
| 5157 |
+
GGML_ASSERT(nek0 == D);
|
| 5158 |
+
GGML_ASSERT(nev1 == D);
|
| 5159 |
+
|
| 5160 |
+
GGML_ASSERT(neq1 == N);
|
| 5161 |
+
GGML_ASSERT(nek1 == N + P);
|
| 5162 |
+
GGML_ASSERT(nev1 == D);
|
| 5163 |
+
|
| 5164 |
+
// dst cannot be transposed or permuted
|
| 5165 |
+
GGML_ASSERT(nb0 == sizeof(float));
|
| 5166 |
+
GGML_ASSERT(nb0 <= nb1);
|
| 5167 |
+
GGML_ASSERT(nb1 <= nb2);
|
| 5168 |
+
GGML_ASSERT(nb2 <= nb3);
|
| 5169 |
+
|
| 5170 |
+
if (params->type == GGML_TASK_INIT) {
|
| 5171 |
+
return;
|
| 5172 |
+
}
|
| 5173 |
+
|
| 5174 |
+
if (params->type == GGML_TASK_FINALIZE) {
|
| 5175 |
+
return;
|
| 5176 |
+
}
|
| 5177 |
+
|
| 5178 |
+
// parallelize by q rows using ggml_vec_dot_f32
|
| 5179 |
+
|
| 5180 |
+
// total rows in q
|
| 5181 |
+
const int nr = neq1*neq2*neq3;
|
| 5182 |
+
|
| 5183 |
+
// rows per thread
|
| 5184 |
+
const int dr = (nr + nth - 1)/nth;
|
| 5185 |
+
|
| 5186 |
+
// row range for this thread
|
| 5187 |
+
const int ir0 = dr*ith;
|
| 5188 |
+
const int ir1 = MIN(ir0 + dr, nr);
|
| 5189 |
+
|
| 5190 |
+
const float scale = 1.0/sqrt((double) D);
|
| 5191 |
+
|
| 5192 |
+
//printf("P=%d N=%d D=%d ir0=%d ir1=%d scale = %f\n", P, N, D, ir0, ir1, scale);
|
| 5193 |
+
|
| 5194 |
+
for (int ir = ir0; ir < ir1; ++ir) {
|
| 5195 |
+
// q indices
|
| 5196 |
+
const int iq3 = ir/(neq2*neq1);
|
| 5197 |
+
const int iq2 = (ir - iq3*neq2*neq1)/neq1;
|
| 5198 |
+
const int iq1 = (ir - iq3*neq2*neq1 - iq2*neq1);
|
| 5199 |
+
|
| 5200 |
+
float * S = (float *) params->wdata + ith*(M + CACHE_LINE_SIZE_F32);
|
| 5201 |
+
|
| 5202 |
+
for (int ic = 0; ic < nek1; ++ic) {
|
| 5203 |
+
// k indices
|
| 5204 |
+
const int ik3 = iq3;
|
| 5205 |
+
const int ik2 = iq2;
|
| 5206 |
+
const int ik1 = ic;
|
| 5207 |
+
|
| 5208 |
+
// S indices
|
| 5209 |
+
const int i1 = ik1;
|
| 5210 |
+
|
| 5211 |
+
ggml_vec_dot_f32(neq0,
|
| 5212 |
+
S + i1,
|
| 5213 |
+
(float *) ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)),
|
| 5214 |
+
(float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)));
|
| 5215 |
+
}
|
| 5216 |
+
|
| 5217 |
+
// scale
|
| 5218 |
+
ggml_vec_scale_f32(nek1, S, scale);
|
| 5219 |
+
|
| 5220 |
+
if (masked) {
|
| 5221 |
+
for (int i = P; i < M; i++) {
|
| 5222 |
+
if (i > P + iq1) {
|
| 5223 |
+
S[i] = -INFINITY;
|
| 5224 |
+
}
|
| 5225 |
+
}
|
| 5226 |
+
}
|
| 5227 |
+
|
| 5228 |
+
// softmax
|
| 5229 |
+
{
|
| 5230 |
+
float max = -INFINITY;
|
| 5231 |
+
for (int i = 0; i < M; i++) {
|
| 5232 |
+
max = MAX(max, S[i]);
|
| 5233 |
+
}
|
| 5234 |
+
|
| 5235 |
+
ggml_float sum = 0.0;
|
| 5236 |
+
|
| 5237 |
+
for (int i = 0; i < M; i++) {
|
| 5238 |
+
if (S[i] == -INFINITY) {
|
| 5239 |
+
S[i] = 0.0;
|
| 5240 |
+
} else {
|
| 5241 |
+
//const float val = (S[i] == -INFINITY) ? 0.0 : exp(S[i] - max);
|
| 5242 |
+
ggml_fp16_t s = ggml_fp32_to_fp16(S[i] - max);
|
| 5243 |
+
const float val = ggml_fp16_to_fp32(table_exp_f16[*(uint16_t *) &s]);
|
| 5244 |
+
sum += val;
|
| 5245 |
+
S[i] = val;
|
| 5246 |
+
}
|
| 5247 |
+
}
|
| 5248 |
+
|
| 5249 |
+
assert(sum > 0.0f);
|
| 5250 |
+
|
| 5251 |
+
sum = 1.0/sum;
|
| 5252 |
+
ggml_vec_scale_f32(M, S, sum);
|
| 5253 |
+
}
|
| 5254 |
+
|
| 5255 |
+
for (int ic = 0; ic < nev1; ++ic) {
|
| 5256 |
+
// dst indices
|
| 5257 |
+
const int i1 = iq1;
|
| 5258 |
+
const int i2 = iq2;
|
| 5259 |
+
const int i3 = iq3;
|
| 5260 |
+
|
| 5261 |
+
ggml_vec_dot_f32(nek1,
|
| 5262 |
+
(float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)),
|
| 5263 |
+
(float *) ((char *) v->data + ( ic*nbv1 + i2*nbv2 + i3*nbv3)),
|
| 5264 |
+
S);
|
| 5265 |
+
}
|
| 5266 |
+
}
|
| 5267 |
+
}
|
| 5268 |
+
|
| 5269 |
+
void ggml_compute_forward_flash_attn_f16(
|
| 5270 |
+
const struct ggml_compute_params * params,
|
| 5271 |
+
const struct ggml_tensor * q,
|
| 5272 |
+
const struct ggml_tensor * k,
|
| 5273 |
+
const struct ggml_tensor * v,
|
| 5274 |
+
const bool masked,
|
| 5275 |
+
struct ggml_tensor * dst) {
|
| 5276 |
+
int64_t t0 = ggml_perf_time_us();
|
| 5277 |
+
UNUSED(t0);
|
| 5278 |
+
|
| 5279 |
+
const int neq0 = q->ne[0];
|
| 5280 |
+
const int neq1 = q->ne[1];
|
| 5281 |
+
const int neq2 = q->ne[2];
|
| 5282 |
+
const int neq3 = q->ne[3];
|
| 5283 |
+
|
| 5284 |
+
const int nek0 = k->ne[0];
|
| 5285 |
+
const int nek1 = k->ne[1];
|
| 5286 |
+
//const int nek2 = k->ne[2];
|
| 5287 |
+
//const int nek3 = k->ne[3];
|
| 5288 |
+
|
| 5289 |
+
//const int nev0 = v->ne[0];
|
| 5290 |
+
const int nev1 = v->ne[1];
|
| 5291 |
+
//const int nev2 = v->ne[2];
|
| 5292 |
+
//const int nev3 = v->ne[3];
|
| 5293 |
+
|
| 5294 |
+
const int ne0 = dst->ne[0];
|
| 5295 |
+
const int ne1 = dst->ne[1];
|
| 5296 |
+
//const int ne2 = dst->ne[2];
|
| 5297 |
+
//const int ne3 = dst->ne[3];
|
| 5298 |
+
|
| 5299 |
+
const int nbk0 = k->nb[0];
|
| 5300 |
+
const int nbk1 = k->nb[1];
|
| 5301 |
+
const int nbk2 = k->nb[2];
|
| 5302 |
+
const int nbk3 = k->nb[3];
|
| 5303 |
+
|
| 5304 |
+
const int nbq0 = q->nb[0];
|
| 5305 |
+
const int nbq1 = q->nb[1];
|
| 5306 |
+
const int nbq2 = q->nb[2];
|
| 5307 |
+
const int nbq3 = q->nb[3];
|
| 5308 |
+
|
| 5309 |
+
const int nbv0 = v->nb[0];
|
| 5310 |
+
const int nbv1 = v->nb[1];
|
| 5311 |
+
const int nbv2 = v->nb[2];
|
| 5312 |
+
const int nbv3 = v->nb[3];
|
| 5313 |
+
|
| 5314 |
+
const int nb0 = dst->nb[0];
|
| 5315 |
+
const int nb1 = dst->nb[1];
|
| 5316 |
+
const int nb2 = dst->nb[2];
|
| 5317 |
+
const int nb3 = dst->nb[3];
|
| 5318 |
+
|
| 5319 |
+
const int ith = params->ith;
|
| 5320 |
+
const int nth = params->nth;
|
| 5321 |
+
|
| 5322 |
+
const int D = neq0;
|
| 5323 |
+
const int N = neq1;
|
| 5324 |
+
const int P = nek1 - N;
|
| 5325 |
+
const int M = P + N;
|
| 5326 |
+
|
| 5327 |
+
GGML_ASSERT(ne0 == D);
|
| 5328 |
+
GGML_ASSERT(ne1 == N);
|
| 5329 |
+
GGML_ASSERT(P >= 0);
|
| 5330 |
+
|
| 5331 |
+
GGML_ASSERT(nbq0 == sizeof(ggml_fp16_t));
|
| 5332 |
+
GGML_ASSERT(nbk0 == sizeof(ggml_fp16_t));
|
| 5333 |
+
GGML_ASSERT(nbv0 == sizeof(ggml_fp16_t));
|
| 5334 |
+
|
| 5335 |
+
GGML_ASSERT(neq0 == D);
|
| 5336 |
+
GGML_ASSERT(nek0 == D);
|
| 5337 |
+
GGML_ASSERT(nev1 == D);
|
| 5338 |
+
|
| 5339 |
+
GGML_ASSERT(neq1 == N);
|
| 5340 |
+
GGML_ASSERT(nek1 == N + P);
|
| 5341 |
+
GGML_ASSERT(nev1 == D);
|
| 5342 |
+
|
| 5343 |
+
// dst cannot be transposed or permuted
|
| 5344 |
+
GGML_ASSERT(nb0 == sizeof(float));
|
| 5345 |
+
GGML_ASSERT(nb0 <= nb1);
|
| 5346 |
+
GGML_ASSERT(nb1 <= nb2);
|
| 5347 |
+
GGML_ASSERT(nb2 <= nb3);
|
| 5348 |
+
|
| 5349 |
+
if (params->type == GGML_TASK_INIT) {
|
| 5350 |
+
return;
|
| 5351 |
+
}
|
| 5352 |
+
|
| 5353 |
+
if (params->type == GGML_TASK_FINALIZE) {
|
| 5354 |
+
return;
|
| 5355 |
+
}
|
| 5356 |
+
|
| 5357 |
+
// parallelize by q rows using ggml_vec_dot_f32
|
| 5358 |
+
|
| 5359 |
+
// total rows in q
|
| 5360 |
+
const int nr = neq1*neq2*neq3;
|
| 5361 |
+
|
| 5362 |
+
// rows per thread
|
| 5363 |
+
const int dr = (nr + nth - 1)/nth;
|
| 5364 |
+
|
| 5365 |
+
// row range for this thread
|
| 5366 |
+
const int ir0 = dr*ith;
|
| 5367 |
+
const int ir1 = MIN(ir0 + dr, nr);
|
| 5368 |
+
|
| 5369 |
+
const float scale = 1.0/sqrt((double) D);
|
| 5370 |
+
|
| 5371 |
+
//printf("P=%d N=%d D=%d ir0=%d ir1=%d scale = %f\n", P, N, D, ir0, ir1, scale);
|
| 5372 |
+
|
| 5373 |
+
for (int ir = ir0; ir < ir1; ++ir) {
|
| 5374 |
+
// q indices
|
| 5375 |
+
const int iq3 = ir/(neq2*neq1);
|
| 5376 |
+
const int iq2 = (ir - iq3*neq2*neq1)/neq1;
|
| 5377 |
+
const int iq1 = (ir - iq3*neq2*neq1 - iq2*neq1);
|
| 5378 |
+
|
| 5379 |
+
float * S = (float *) params->wdata + ith*(2*M + CACHE_LINE_SIZE_F32);
|
| 5380 |
+
|
| 5381 |
+
for (int ic = 0; ic < nek1; ++ic) {
|
| 5382 |
+
// k indices
|
| 5383 |
+
const int ik3 = iq3;
|
| 5384 |
+
const int ik2 = iq2;
|
| 5385 |
+
const int ik1 = ic;
|
| 5386 |
+
|
| 5387 |
+
// S indices
|
| 5388 |
+
const int i1 = ik1;
|
| 5389 |
+
|
| 5390 |
+
ggml_vec_dot_f16(neq0,
|
| 5391 |
+
S + i1,
|
| 5392 |
+
(ggml_fp16_t *) ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)),
|
| 5393 |
+
(ggml_fp16_t *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)));
|
| 5394 |
+
}
|
| 5395 |
+
|
| 5396 |
+
// scale
|
| 5397 |
+
ggml_vec_scale_f32(nek1, S, scale);
|
| 5398 |
+
|
| 5399 |
+
if (masked) {
|
| 5400 |
+
for (int i = P; i < M; i++) {
|
| 5401 |
+
if (i > P + iq1) {
|
| 5402 |
+
S[i] = -INFINITY;
|
| 5403 |
+
}
|
| 5404 |
+
}
|
| 5405 |
+
}
|
| 5406 |
+
|
| 5407 |
+
// softmax
|
| 5408 |
+
{
|
| 5409 |
+
float max = -INFINITY;
|
| 5410 |
+
for (int i = 0; i < M; i++) {
|
| 5411 |
+
max = MAX(max, S[i]);
|
| 5412 |
+
}
|
| 5413 |
+
|
| 5414 |
+
ggml_float sum = 0.0;
|
| 5415 |
+
|
| 5416 |
+
for (int i = 0; i < M; i++) {
|
| 5417 |
+
if (S[i] == -INFINITY) {
|
| 5418 |
+
S[i] = 0.0;
|
| 5419 |
+
} else {
|
| 5420 |
+
//const float val = (S[i] == -INFINITY) ? 0.0 : exp(S[i] - max);
|
| 5421 |
+
ggml_fp16_t s = ggml_fp32_to_fp16(S[i] - max);
|
| 5422 |
+
const float val = ggml_fp16_to_fp32(table_exp_f16[*(uint16_t *) &s]);
|
| 5423 |
+
sum += val;
|
| 5424 |
+
S[i] = val;
|
| 5425 |
+
}
|
| 5426 |
+
}
|
| 5427 |
+
|
| 5428 |
+
assert(sum > 0.0f);
|
| 5429 |
+
|
| 5430 |
+
sum = 1.0/sum;
|
| 5431 |
+
ggml_vec_scale_f32(M, S, sum);
|
| 5432 |
+
}
|
| 5433 |
+
|
| 5434 |
+
ggml_fp16_t * S16 = (ggml_fp16_t *) ((float *) params->wdata + ith*(2*M + CACHE_LINE_SIZE_F32) + M);
|
| 5435 |
+
|
| 5436 |
+
for (int i = 0; i < M; i++) {
|
| 5437 |
+
S16[i] = ggml_fp32_to_fp16(S[i]);
|
| 5438 |
+
}
|
| 5439 |
+
|
| 5440 |
+
for (int ic = 0; ic < nev1; ++ic) {
|
| 5441 |
+
// dst indices
|
| 5442 |
+
const int i1 = iq1;
|
| 5443 |
+
const int i2 = iq2;
|
| 5444 |
+
const int i3 = iq3;
|
| 5445 |
+
|
| 5446 |
+
ggml_vec_dot_f16(nek1,
|
| 5447 |
+
(float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)),
|
| 5448 |
+
(ggml_fp16_t *) ((char *) v->data + ( ic*nbv1 + i2*nbv2 + i3*nbv3)),
|
| 5449 |
+
S16);
|
| 5450 |
+
}
|
| 5451 |
+
}
|
| 5452 |
+
}
|
| 5453 |
+
|
| 5454 |
+
void ggml_compute_forward_flash_attn(
|
| 5455 |
+
const struct ggml_compute_params * params,
|
| 5456 |
+
const struct ggml_tensor * q,
|
| 5457 |
+
const struct ggml_tensor * k,
|
| 5458 |
+
const struct ggml_tensor * v,
|
| 5459 |
+
const bool masked,
|
| 5460 |
+
struct ggml_tensor * dst) {
|
| 5461 |
+
switch (q->type) {
|
| 5462 |
+
case GGML_TYPE_F16:
|
| 5463 |
+
{
|
| 5464 |
+
ggml_compute_forward_flash_attn_f16(params, q, k, v, masked, dst);
|
| 5465 |
+
} break;
|
| 5466 |
+
case GGML_TYPE_F32:
|
| 5467 |
+
{
|
| 5468 |
+
ggml_compute_forward_flash_attn_f32(params, q, k, v, masked, dst);
|
| 5469 |
+
} break;
|
| 5470 |
+
case GGML_TYPE_I8:
|
| 5471 |
+
case GGML_TYPE_I16:
|
| 5472 |
+
case GGML_TYPE_I32:
|
| 5473 |
+
case GGML_TYPE_COUNT:
|
| 5474 |
+
{
|
| 5475 |
+
assert(false);
|
| 5476 |
+
} break;
|
| 5477 |
+
}
|
| 5478 |
+
}
|
| 5479 |
+
|
| 5480 |
+
// ggml_compute_forward_flash_ff
|
| 5481 |
+
|
| 5482 |
+
void ggml_compute_forward_flash_ff_f16(
|
| 5483 |
+
const struct ggml_compute_params * params,
|
| 5484 |
+
const struct ggml_tensor * a, // F16
|
| 5485 |
+
const struct ggml_tensor * b0, // F16 fc_w
|
| 5486 |
+
const struct ggml_tensor * b1, // F32 fc_b
|
| 5487 |
+
const struct ggml_tensor * c0, // F16 proj_w
|
| 5488 |
+
const struct ggml_tensor * c1, // F32 proj_b
|
| 5489 |
+
struct ggml_tensor * dst) {
|
| 5490 |
+
int64_t t0 = ggml_perf_time_us();
|
| 5491 |
+
UNUSED(t0);
|
| 5492 |
+
|
| 5493 |
+
const int nea0 = a->ne[0];
|
| 5494 |
+
const int nea1 = a->ne[1];
|
| 5495 |
+
const int nea2 = a->ne[2];
|
| 5496 |
+
const int nea3 = a->ne[3];
|
| 5497 |
+
|
| 5498 |
+
const int neb00 = b0->ne[0];
|
| 5499 |
+
const int neb01 = b0->ne[1];
|
| 5500 |
+
//const int neb02 = b0->ne[2];
|
| 5501 |
+
//const int neb03 = b0->ne[3];
|
| 5502 |
+
|
| 5503 |
+
const int neb10 = b1->ne[0];
|
| 5504 |
+
const int neb11 = b1->ne[1];
|
| 5505 |
+
//const int neb12 = b1->ne[2];
|
| 5506 |
+
//const int neb13 = b1->ne[3];
|
| 5507 |
+
|
| 5508 |
+
const int nec00 = c0->ne[0];
|
| 5509 |
+
const int nec01 = c0->ne[1];
|
| 5510 |
+
//const int nec02 = c0->ne[2];
|
| 5511 |
+
//const int nec03 = c0->ne[3];
|
| 5512 |
+
|
| 5513 |
+
const int nec10 = c1->ne[0];
|
| 5514 |
+
const int nec11 = c1->ne[1];
|
| 5515 |
+
//const int nec12 = c1->ne[2];
|
| 5516 |
+
//const int nec13 = c1->ne[3];
|
| 5517 |
+
|
| 5518 |
+
const int ne0 = dst->ne[0];
|
| 5519 |
+
const int ne1 = dst->ne[1];
|
| 5520 |
+
const int ne2 = dst->ne[2];
|
| 5521 |
+
//const int ne3 = dst->ne[3];
|
| 5522 |
+
|
| 5523 |
+
const int nba0 = a->nb[0];
|
| 5524 |
+
const int nba1 = a->nb[1];
|
| 5525 |
+
const int nba2 = a->nb[2];
|
| 5526 |
+
const int nba3 = a->nb[3];
|
| 5527 |
+
|
| 5528 |
+
const int nbb00 = b0->nb[0];
|
| 5529 |
+
const int nbb01 = b0->nb[1];
|
| 5530 |
+
const int nbb02 = b0->nb[2];
|
| 5531 |
+
const int nbb03 = b0->nb[3];
|
| 5532 |
+
|
| 5533 |
+
const int nbb10 = b1->nb[0];
|
| 5534 |
+
//const int nbb11 = b1->nb[1];
|
| 5535 |
+
//const int nbb12 = b1->nb[2];
|
| 5536 |
+
//const int nbb13 = b1->nb[3];
|
| 5537 |
+
|
| 5538 |
+
const int nbc00 = c0->nb[0];
|
| 5539 |
+
const int nbc01 = c0->nb[1];
|
| 5540 |
+
const int nbc02 = c0->nb[2];
|
| 5541 |
+
const int nbc03 = c0->nb[3];
|
| 5542 |
+
|
| 5543 |
+
const int nbc10 = c1->nb[0];
|
| 5544 |
+
//const int nbc11 = c1->nb[1];
|
| 5545 |
+
//const int nbc12 = c1->nb[2];
|
| 5546 |
+
//const int nbc13 = c1->nb[3];
|
| 5547 |
+
|
| 5548 |
+
const int nb0 = dst->nb[0];
|
| 5549 |
+
const int nb1 = dst->nb[1];
|
| 5550 |
+
const int nb2 = dst->nb[2];
|
| 5551 |
+
const int nb3 = dst->nb[3];
|
| 5552 |
+
|
| 5553 |
+
const int ith = params->ith;
|
| 5554 |
+
const int nth = params->nth;
|
| 5555 |
+
|
| 5556 |
+
const int D = nea0;
|
| 5557 |
+
//const int N = nea1;
|
| 5558 |
+
const int M = neb01;
|
| 5559 |
+
|
| 5560 |
+
GGML_ASSERT(ne0 == nea0);
|
| 5561 |
+
GGML_ASSERT(ne1 == nea1);
|
| 5562 |
+
GGML_ASSERT(ne2 == nea2);
|
| 5563 |
+
|
| 5564 |
+
GGML_ASSERT(nba0 == sizeof(ggml_fp16_t));
|
| 5565 |
+
GGML_ASSERT(nbb00 == sizeof(ggml_fp16_t));
|
| 5566 |
+
GGML_ASSERT(nbb10 == sizeof(float));
|
| 5567 |
+
GGML_ASSERT(nbc00 == sizeof(ggml_fp16_t));
|
| 5568 |
+
GGML_ASSERT(nbc10 == sizeof(float));
|
| 5569 |
+
|
| 5570 |
+
GGML_ASSERT(neb00 == D);
|
| 5571 |
+
GGML_ASSERT(neb01 == M);
|
| 5572 |
+
GGML_ASSERT(neb10 == M);
|
| 5573 |
+
GGML_ASSERT(neb11 == 1);
|
| 5574 |
+
|
| 5575 |
+
GGML_ASSERT(nec00 == M);
|
| 5576 |
+
GGML_ASSERT(nec01 == D);
|
| 5577 |
+
GGML_ASSERT(nec10 == D);
|
| 5578 |
+
GGML_ASSERT(nec11 == 1);
|
| 5579 |
+
|
| 5580 |
+
// dst cannot be transposed or permuted
|
| 5581 |
+
GGML_ASSERT(nb0 == sizeof(float));
|
| 5582 |
+
GGML_ASSERT(nb0 <= nb1);
|
| 5583 |
+
GGML_ASSERT(nb1 <= nb2);
|
| 5584 |
+
GGML_ASSERT(nb2 <= nb3);
|
| 5585 |
+
|
| 5586 |
+
if (params->type == GGML_TASK_INIT) {
|
| 5587 |
+
return;
|
| 5588 |
+
}
|
| 5589 |
+
|
| 5590 |
+
if (params->type == GGML_TASK_FINALIZE) {
|
| 5591 |
+
return;
|
| 5592 |
+
}
|
| 5593 |
+
|
| 5594 |
+
// parallelize by a rows using ggml_vec_dot_f32
|
| 5595 |
+
|
| 5596 |
+
// total rows in a
|
| 5597 |
+
const int nr = nea1*nea2*nea3;
|
| 5598 |
+
|
| 5599 |
+
// rows per thread
|
| 5600 |
+
const int dr = (nr + nth - 1)/nth;
|
| 5601 |
+
|
| 5602 |
+
// row range for this thread
|
| 5603 |
+
const int ir0 = dr*ith;
|
| 5604 |
+
const int ir1 = MIN(ir0 + dr, nr);
|
| 5605 |
+
|
| 5606 |
+
for (int ir = ir0; ir < ir1; ++ir) {
|
| 5607 |
+
// a indices
|
| 5608 |
+
const int ia3 = ir/(nea2*nea1);
|
| 5609 |
+
const int ia2 = (ir - ia3*nea2*nea1)/nea1;
|
| 5610 |
+
const int ia1 = (ir - ia3*nea2*nea1 - ia2*nea1);
|
| 5611 |
+
|
| 5612 |
+
float * S = (float *) params->wdata + ith*(2*M + CACHE_LINE_SIZE_F32);
|
| 5613 |
+
|
| 5614 |
+
for (int ic = 0; ic < neb01; ++ic) {
|
| 5615 |
+
// b0 indices
|
| 5616 |
+
const int ib03 = ia3;
|
| 5617 |
+
const int ib02 = ia2;
|
| 5618 |
+
const int ib01 = ic;
|
| 5619 |
+
|
| 5620 |
+
// S indices
|
| 5621 |
+
const int i1 = ib01;
|
| 5622 |
+
|
| 5623 |
+
ggml_vec_dot_f16(nea0,
|
| 5624 |
+
S + i1,
|
| 5625 |
+
(ggml_fp16_t *) ((char *) b0->data + (ib01*nbb01 + ib02*nbb02 + ib03*nbb03)),
|
| 5626 |
+
(ggml_fp16_t *) ((char *) a->data + ( ia1*nba1 + ia2*nba2 + ia3*nba3)));
|
| 5627 |
+
}
|
| 5628 |
+
|
| 5629 |
+
ggml_vec_add_f32(neb01, S, S, (float *) b1->data);
|
| 5630 |
+
//ggml_vec_gelu_f32(neb01, S, S);
|
| 5631 |
+
|
| 5632 |
+
ggml_fp16_t * S16 = (ggml_fp16_t *) ((float *) params->wdata + ith*(2*M + CACHE_LINE_SIZE_F32) + M);
|
| 5633 |
+
|
| 5634 |
+
for (int i = 0; i < M; i++) {
|
| 5635 |
+
S16[i] = ggml_fp32_to_fp16(S[i]);
|
| 5636 |
+
}
|
| 5637 |
+
|
| 5638 |
+
ggml_vec_gelu_f16(neb01, S16, S16);
|
| 5639 |
+
|
| 5640 |
+
{
|
| 5641 |
+
// dst indices
|
| 5642 |
+
const int i1 = ia1;
|
| 5643 |
+
const int i2 = ia2;
|
| 5644 |
+
const int i3 = ia3;
|
| 5645 |
+
|
| 5646 |
+
for (int ic = 0; ic < nec01; ++ic) {
|
| 5647 |
+
|
| 5648 |
+
ggml_vec_dot_f16(neb01,
|
| 5649 |
+
(float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)),
|
| 5650 |
+
(ggml_fp16_t *) ((char *) c0->data + ( ic*nbc01 + i2*nbc02 + i3*nbc03)),
|
| 5651 |
+
S16);
|
| 5652 |
+
}
|
| 5653 |
+
|
| 5654 |
+
ggml_vec_add_f32(nec01,
|
| 5655 |
+
(float *) ((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3)),
|
| 5656 |
+
(float *) ((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3)),
|
| 5657 |
+
(float *) c1->data);
|
| 5658 |
+
}
|
| 5659 |
+
}
|
| 5660 |
+
}
|
| 5661 |
+
|
| 5662 |
+
void ggml_compute_forward_flash_ff(
|
| 5663 |
+
const struct ggml_compute_params * params,
|
| 5664 |
+
const struct ggml_tensor * a,
|
| 5665 |
+
const struct ggml_tensor * b0,
|
| 5666 |
+
const struct ggml_tensor * b1,
|
| 5667 |
+
const struct ggml_tensor * c0,
|
| 5668 |
+
const struct ggml_tensor * c1,
|
| 5669 |
+
struct ggml_tensor * dst) {
|
| 5670 |
+
switch (b0->type) {
|
| 5671 |
+
case GGML_TYPE_F16:
|
| 5672 |
+
{
|
| 5673 |
+
ggml_compute_forward_flash_ff_f16(params, a, b0, b1, c0, c1, dst);
|
| 5674 |
+
} break;
|
| 5675 |
+
case GGML_TYPE_F32:
|
| 5676 |
+
{
|
| 5677 |
+
GGML_ASSERT(false); // TODO
|
| 5678 |
+
} break;
|
| 5679 |
+
case GGML_TYPE_I8:
|
| 5680 |
+
case GGML_TYPE_I16:
|
| 5681 |
+
case GGML_TYPE_I32:
|
| 5682 |
+
case GGML_TYPE_COUNT:
|
| 5683 |
+
{
|
| 5684 |
+
assert(false);
|
| 5685 |
+
} break;
|
| 5686 |
+
}
|
| 5687 |
+
}
|
| 5688 |
+
|
| 5689 |
/////////////////////////////////
|
| 5690 |
|
| 5691 |
void ggml_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor) {
|
|
|
|
| 5812 |
{
|
| 5813 |
ggml_compute_forward_conv_1d_2s(params, tensor->src0, tensor->src1, tensor);
|
| 5814 |
} break;
|
| 5815 |
+
case GGML_OP_FLASH_ATTN:
|
| 5816 |
+
{
|
| 5817 |
+
int32_t t = ggml_get_i32_1d(tensor->opt[1], 0);
|
| 5818 |
+
GGML_ASSERT(t == 0 || t == 1);
|
| 5819 |
+
bool masked = t != 0;
|
| 5820 |
+
ggml_compute_forward_flash_attn(params, tensor->src0, tensor->src1, tensor->opt[0], masked, tensor);
|
| 5821 |
+
} break;
|
| 5822 |
+
case GGML_OP_FLASH_FF:
|
| 5823 |
+
{
|
| 5824 |
+
ggml_compute_forward_flash_ff(params, tensor->src0, tensor->src1, tensor->opt[0], tensor->opt[1], tensor->opt[2], tensor);
|
| 5825 |
+
} break;
|
| 5826 |
case GGML_OP_NONE:
|
| 5827 |
{
|
| 5828 |
// nop
|
| 5829 |
} break;
|
| 5830 |
case GGML_OP_COUNT:
|
| 5831 |
{
|
| 5832 |
+
GGML_ASSERT(false);
|
| 5833 |
} break;
|
| 5834 |
};
|
| 5835 |
}
|
|
|
|
| 6061 |
{
|
| 6062 |
GGML_ASSERT(false); // TODO: not implemented
|
| 6063 |
} break;
|
| 6064 |
+
case GGML_OP_FLASH_ATTN:
|
| 6065 |
+
{
|
| 6066 |
+
GGML_ASSERT(false); // not supported
|
| 6067 |
+
} break;
|
| 6068 |
+
case GGML_OP_FLASH_FF:
|
| 6069 |
+
{
|
| 6070 |
+
GGML_ASSERT(false); // not supported
|
| 6071 |
+
} break;
|
| 6072 |
case GGML_OP_NONE:
|
| 6073 |
{
|
| 6074 |
// nop
|
|
|
|
| 6110 |
ggml_visit_parents(cgraph, node->src1);
|
| 6111 |
}
|
| 6112 |
|
| 6113 |
+
for (int i = 0; i < GGML_MAX_OPT; ++i) {
|
| 6114 |
+
if (node->opt[i]) {
|
| 6115 |
+
ggml_visit_parents(cgraph, node->opt[i]);
|
| 6116 |
+
}
|
| 6117 |
+
}
|
| 6118 |
+
|
| 6119 |
if (node->op == GGML_OP_NONE && node->grad == NULL) {
|
| 6120 |
// reached a leaf node, not part of the gradient graph (e.g. a constant)
|
| 6121 |
assert(cgraph->n_leafs < GGML_MAX_NODES);
|
|
|
|
| 6461 |
case GGML_OP_CONV_1D_1S:
|
| 6462 |
case GGML_OP_CONV_1D_2S:
|
| 6463 |
{
|
|
|
|
| 6464 |
node->n_tasks = n_threads;
|
| 6465 |
|
| 6466 |
GGML_ASSERT(node->src0->ne[3] == 1);
|
|
|
|
| 6486 |
GGML_ASSERT(false);
|
| 6487 |
}
|
| 6488 |
|
| 6489 |
+
work_size = MAX(work_size, cur);
|
| 6490 |
+
} break;
|
| 6491 |
+
case GGML_OP_FLASH_ATTN:
|
| 6492 |
+
{
|
| 6493 |
+
node->n_tasks = n_threads;
|
| 6494 |
+
|
| 6495 |
+
size_t cur = 0;
|
| 6496 |
+
|
| 6497 |
+
if (node->src1->type == GGML_TYPE_F32) {
|
| 6498 |
+
cur = sizeof(float)*node->src1->ne[1]*node->n_tasks; // TODO: this can become (n_tasks-1)
|
| 6499 |
+
cur += sizeof(float)*node->src1->ne[1]*node->n_tasks; // this is overestimated by x2
|
| 6500 |
+
}
|
| 6501 |
+
|
| 6502 |
+
if (node->src1->type == GGML_TYPE_F16) {
|
| 6503 |
+
cur = sizeof(float)*node->src1->ne[1]*node->n_tasks; // TODO: this can become (n_tasks-1)
|
| 6504 |
+
cur += sizeof(float)*node->src1->ne[1]*node->n_tasks; // this is overestimated by x2
|
| 6505 |
+
}
|
| 6506 |
+
|
| 6507 |
+
work_size = MAX(work_size, cur);
|
| 6508 |
+
} break;
|
| 6509 |
+
case GGML_OP_FLASH_FF:
|
| 6510 |
+
{
|
| 6511 |
+
node->n_tasks = n_threads;
|
| 6512 |
+
|
| 6513 |
+
size_t cur = 0;
|
| 6514 |
+
|
| 6515 |
+
if (node->src1->type == GGML_TYPE_F32) {
|
| 6516 |
+
cur = sizeof(float)*node->src1->ne[1]*node->n_tasks; // TODO: this can become (n_tasks-1)
|
| 6517 |
+
cur += sizeof(float)*node->src1->ne[1]*node->n_tasks; // this is overestimated by x2
|
| 6518 |
+
}
|
| 6519 |
+
|
| 6520 |
+
if (node->src1->type == GGML_TYPE_F16) {
|
| 6521 |
+
cur = sizeof(float)*node->src1->ne[1]*node->n_tasks; // TODO: this can become (n_tasks-1)
|
| 6522 |
+
cur += sizeof(float)*node->src1->ne[1]*node->n_tasks; // this is overestimated by x2
|
| 6523 |
+
}
|
| 6524 |
+
|
| 6525 |
work_size = MAX(work_size, cur);
|
| 6526 |
} break;
|
| 6527 |
case GGML_OP_NONE:
|
ggml.h
CHANGED
|
@@ -12,6 +12,7 @@ extern "C" {
|
|
| 12 |
#define GGML_MAX_NODES 4096
|
| 13 |
#define GGML_MAX_PARAMS 16
|
| 14 |
#define GGML_MAX_CONTEXTS 16
|
|
|
|
| 15 |
|
| 16 |
#ifdef __ARM_NEON
|
| 17 |
// we use the built-in 16-bit float type
|
|
@@ -71,6 +72,9 @@ enum ggml_op {
|
|
| 71 |
GGML_OP_CONV_1D_1S,
|
| 72 |
GGML_OP_CONV_1D_2S,
|
| 73 |
|
|
|
|
|
|
|
|
|
|
| 74 |
GGML_OP_COUNT,
|
| 75 |
};
|
| 76 |
|
|
@@ -93,6 +97,7 @@ struct ggml_tensor {
|
|
| 93 |
struct ggml_tensor * grad;
|
| 94 |
struct ggml_tensor * src0;
|
| 95 |
struct ggml_tensor * src1;
|
|
|
|
| 96 |
|
| 97 |
// thread scheduling
|
| 98 |
int n_tasks;
|
|
@@ -182,14 +187,19 @@ struct ggml_tensor * ggml_new_tensor_4d(
|
|
| 182 |
int ne2,
|
| 183 |
int ne3);
|
| 184 |
|
|
|
|
| 185 |
struct ggml_tensor * ggml_new_f32(struct ggml_context * ctx, float value);
|
| 186 |
|
| 187 |
struct ggml_tensor * ggml_dup_tensor (struct ggml_context * ctx, const struct ggml_tensor * src);
|
| 188 |
struct ggml_tensor * ggml_view_tensor(struct ggml_context * ctx, const struct ggml_tensor * src);
|
| 189 |
|
| 190 |
struct ggml_tensor * ggml_set_zero(struct ggml_tensor * tensor);
|
|
|
|
| 191 |
struct ggml_tensor * ggml_set_f32 (struct ggml_tensor * tensor, float value);
|
| 192 |
|
|
|
|
|
|
|
|
|
|
| 193 |
float ggml_get_f32_1d(const struct ggml_tensor * tensor, int i);
|
| 194 |
void ggml_set_f32_1d(const struct ggml_tensor * tensor, int i, float value);
|
| 195 |
|
|
@@ -399,6 +409,21 @@ struct ggml_tensor * ggml_conv_1d_2s(
|
|
| 399 |
struct ggml_tensor * a,
|
| 400 |
struct ggml_tensor * b);
|
| 401 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 402 |
//
|
| 403 |
// automatic differentiation
|
| 404 |
//
|
|
|
|
| 12 |
#define GGML_MAX_NODES 4096
|
| 13 |
#define GGML_MAX_PARAMS 16
|
| 14 |
#define GGML_MAX_CONTEXTS 16
|
| 15 |
+
#define GGML_MAX_OPT 4
|
| 16 |
|
| 17 |
#ifdef __ARM_NEON
|
| 18 |
// we use the built-in 16-bit float type
|
|
|
|
| 72 |
GGML_OP_CONV_1D_1S,
|
| 73 |
GGML_OP_CONV_1D_2S,
|
| 74 |
|
| 75 |
+
GGML_OP_FLASH_ATTN,
|
| 76 |
+
GGML_OP_FLASH_FF,
|
| 77 |
+
|
| 78 |
GGML_OP_COUNT,
|
| 79 |
};
|
| 80 |
|
|
|
|
| 97 |
struct ggml_tensor * grad;
|
| 98 |
struct ggml_tensor * src0;
|
| 99 |
struct ggml_tensor * src1;
|
| 100 |
+
struct ggml_tensor * opt[GGML_MAX_OPT];
|
| 101 |
|
| 102 |
// thread scheduling
|
| 103 |
int n_tasks;
|
|
|
|
| 187 |
int ne2,
|
| 188 |
int ne3);
|
| 189 |
|
| 190 |
+
struct ggml_tensor * ggml_new_i32(struct ggml_context * ctx, int32_t value);
|
| 191 |
struct ggml_tensor * ggml_new_f32(struct ggml_context * ctx, float value);
|
| 192 |
|
| 193 |
struct ggml_tensor * ggml_dup_tensor (struct ggml_context * ctx, const struct ggml_tensor * src);
|
| 194 |
struct ggml_tensor * ggml_view_tensor(struct ggml_context * ctx, const struct ggml_tensor * src);
|
| 195 |
|
| 196 |
struct ggml_tensor * ggml_set_zero(struct ggml_tensor * tensor);
|
| 197 |
+
struct ggml_tensor * ggml_set_i32 (struct ggml_tensor * tensor, int32_t value);
|
| 198 |
struct ggml_tensor * ggml_set_f32 (struct ggml_tensor * tensor, float value);
|
| 199 |
|
| 200 |
+
int32_t ggml_get_i32_1d(const struct ggml_tensor * tensor, int i);
|
| 201 |
+
void ggml_set_i32_1d(const struct ggml_tensor * tensor, int i, int32_t value);
|
| 202 |
+
|
| 203 |
float ggml_get_f32_1d(const struct ggml_tensor * tensor, int i);
|
| 204 |
void ggml_set_f32_1d(const struct ggml_tensor * tensor, int i, float value);
|
| 205 |
|
|
|
|
| 409 |
struct ggml_tensor * a,
|
| 410 |
struct ggml_tensor * b);
|
| 411 |
|
| 412 |
+
struct ggml_tensor * ggml_flash_attn(
|
| 413 |
+
struct ggml_context * ctx,
|
| 414 |
+
struct ggml_tensor * q,
|
| 415 |
+
struct ggml_tensor * k,
|
| 416 |
+
struct ggml_tensor * v,
|
| 417 |
+
bool masked);
|
| 418 |
+
|
| 419 |
+
struct ggml_tensor * ggml_flash_ff(
|
| 420 |
+
struct ggml_context * ctx,
|
| 421 |
+
struct ggml_tensor * a,
|
| 422 |
+
struct ggml_tensor * b0,
|
| 423 |
+
struct ggml_tensor * b1,
|
| 424 |
+
struct ggml_tensor * c0,
|
| 425 |
+
struct ggml_tensor * c1);
|
| 426 |
+
|
| 427 |
//
|
| 428 |
// automatic differentiation
|
| 429 |
//
|
main.cpp
CHANGED
|
@@ -1,5 +1,8 @@
|
|
| 1 |
#include "ggml.h"
|
| 2 |
|
|
|
|
|
|
|
|
|
|
| 3 |
// third-party utilities
|
| 4 |
// use your favorite implementations
|
| 5 |
#define DR_WAV_IMPLEMENTATION
|
|
@@ -16,6 +19,7 @@
|
|
| 16 |
#include <thread>
|
| 17 |
#include <vector>
|
| 18 |
|
|
|
|
| 19 |
enum e_model {
|
| 20 |
MODEL_UNKNOWN,
|
| 21 |
MODEL_TINY,
|
|
@@ -25,14 +29,116 @@ enum e_model {
|
|
| 25 |
MODEL_LARGE,
|
| 26 |
};
|
| 27 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 28 |
const size_t MB = 1024*1024;
|
| 29 |
|
| 30 |
const std::map<e_model, size_t> MEM_REQ_MODEL = {
|
| 31 |
-
{ MODEL_TINY,
|
| 32 |
-
{ MODEL_BASE,
|
| 33 |
-
{ MODEL_SMALL,
|
| 34 |
-
{ MODEL_MEDIUM,
|
| 35 |
-
{ MODEL_LARGE,
|
| 36 |
};
|
| 37 |
|
| 38 |
const std::map<e_model, size_t> MEM_REQ_ENCODE = {
|
|
@@ -44,11 +150,11 @@ const std::map<e_model, size_t> MEM_REQ_ENCODE = {
|
|
| 44 |
};
|
| 45 |
|
| 46 |
const std::map<e_model, size_t> MEM_REQ_ENCODE_LAYER = {
|
| 47 |
-
{ MODEL_TINY,
|
| 48 |
-
{ MODEL_BASE,
|
| 49 |
-
{ MODEL_SMALL,
|
| 50 |
-
{ MODEL_MEDIUM,
|
| 51 |
-
{ MODEL_LARGE,
|
| 52 |
};
|
| 53 |
|
| 54 |
const std::map<e_model, size_t> MEM_REQ_DECODE = {
|
|
@@ -102,6 +208,10 @@ struct whisper_vocab {
|
|
| 102 |
id token_solm = 50361; // ??
|
| 103 |
id token_beg = 50363;
|
| 104 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 105 |
bool is_multilingual() const {
|
| 106 |
return n_vocab == 51865;
|
| 107 |
}
|
|
@@ -109,16 +219,18 @@ struct whisper_vocab {
|
|
| 109 |
|
| 110 |
// command-line parameters
|
| 111 |
struct whisper_params {
|
| 112 |
-
int32_t seed = -1; // RNG seed
|
| 113 |
int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency());
|
| 114 |
|
|
|
|
| 115 |
int32_t max_tokens_per_iter = 64;
|
| 116 |
|
| 117 |
-
bool verbose
|
|
|
|
| 118 |
bool print_special_tokens = false;
|
| 119 |
|
| 120 |
-
std::string
|
| 121 |
-
|
| 122 |
std::string fname_inp = "samples/jfk.wav";
|
| 123 |
};
|
| 124 |
|
|
@@ -136,6 +248,15 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
|
|
| 136 |
params.max_tokens_per_iter = std::stoi(argv[++i]);
|
| 137 |
} else if (arg == "-v" || arg == "--verbose") {
|
| 138 |
params.verbose = true;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 139 |
} else if (arg == "-ps" || arg == "--print_special") {
|
| 140 |
params.print_special_tokens = true;
|
| 141 |
} else if (arg == "-m" || arg == "--model") {
|
|
@@ -160,16 +281,16 @@ void whisper_print_usage(int argc, char ** argv, const whisper_params & params)
|
|
| 160 |
fprintf(stderr, "usage: %s [options]\n", argv[0]);
|
| 161 |
fprintf(stderr, "\n");
|
| 162 |
fprintf(stderr, "options:\n");
|
| 163 |
-
fprintf(stderr, " -h,
|
| 164 |
-
fprintf(stderr, " -s SEED,
|
| 165 |
-
fprintf(stderr, " -t N,
|
| 166 |
-
fprintf(stderr, " -T N,
|
| 167 |
-
fprintf(stderr, " -v,
|
| 168 |
-
fprintf(stderr, "
|
| 169 |
-
fprintf(stderr, " -
|
| 170 |
-
fprintf(stderr, "
|
| 171 |
-
fprintf(stderr, " -
|
| 172 |
-
fprintf(stderr, "
|
| 173 |
fprintf(stderr, "\n");
|
| 174 |
}
|
| 175 |
|
|
@@ -417,6 +538,7 @@ bool whisper_model_load(const std::string & fname, whisper_model & model, whispe
|
|
| 417 |
printf("%s: f16 = %d\n", __func__, hparams.f16);
|
| 418 |
printf("%s: type = %d\n", __func__, model.type);
|
| 419 |
|
|
|
|
| 420 |
const size_t mem_required =
|
| 421 |
MEM_REQ_MODEL.at(model.type) +
|
| 422 |
MEM_REQ_ENCODE.at(model.type) +
|
|
@@ -609,11 +731,11 @@ bool whisper_model_load(const std::string & fname, whisper_model & model, whispe
|
|
| 609 |
ctx_size += n_text_layer*( n_text_state*ggml_type_size(GGML_TYPE_F32)); // cross_attn_ln_1_b
|
| 610 |
}
|
| 611 |
|
| 612 |
-
ctx_size += n_text_layer*n_text_ctx*n_text_state*ggml_type_size(
|
| 613 |
-
ctx_size += n_text_layer*n_text_ctx*n_text_state*ggml_type_size(
|
| 614 |
|
| 615 |
-
ctx_size += n_text_layer*n_audio_ctx*n_text_state*ggml_type_size(
|
| 616 |
-
ctx_size += n_text_layer*n_audio_ctx*n_text_state*ggml_type_size(
|
| 617 |
|
| 618 |
ctx_size += (15 + 15*n_audio_layer + 24*n_text_layer)*256; // object overhead
|
| 619 |
|
|
@@ -836,22 +958,24 @@ bool whisper_model_load(const std::string & fname, whisper_model & model, whispe
|
|
| 836 |
const int n_text_layer = hparams.n_text_layer;
|
| 837 |
const int n_text_ctx = hparams.n_text_ctx;
|
| 838 |
|
|
|
|
| 839 |
{
|
| 840 |
const int n_mem = n_text_layer*n_text_ctx;
|
| 841 |
const int n_elements = n_text_state*n_mem;
|
| 842 |
|
| 843 |
-
model.memory_k = ggml_new_tensor_1d(ctx,
|
| 844 |
-
model.memory_v = ggml_new_tensor_1d(ctx,
|
| 845 |
}
|
| 846 |
|
|
|
|
| 847 |
{
|
| 848 |
const int n_audio_ctx = hparams.n_audio_ctx;
|
| 849 |
|
| 850 |
const int n_mem = n_text_layer*n_audio_ctx;
|
| 851 |
const int n_elements = n_text_state*n_mem;
|
| 852 |
|
| 853 |
-
model.memory_cross_k = ggml_new_tensor_1d(ctx,
|
| 854 |
-
model.memory_cross_v = ggml_new_tensor_1d(ctx,
|
| 855 |
}
|
| 856 |
|
| 857 |
const size_t memory_size =
|
|
@@ -1057,14 +1181,14 @@ bool whisper_encode(
|
|
| 1057 |
Qcur),
|
| 1058 |
Qcur);
|
| 1059 |
|
| 1060 |
-
Qcur = ggml_scale(ctxL, Qcur, ggml_new_f32(ctxL, pow(float(n_state)/n_head, -0.25)));
|
| 1061 |
|
| 1062 |
-
// no bias for Key
|
| 1063 |
struct ggml_tensor * Kcur = ggml_mul_mat(ctxL,
|
| 1064 |
layer.attn_k_w,
|
| 1065 |
cur);
|
| 1066 |
|
| 1067 |
-
Kcur = ggml_scale(ctxL, Kcur, ggml_new_f32(ctxL, pow(float(n_state)/n_head, -0.25)));
|
| 1068 |
|
| 1069 |
struct ggml_tensor * Vcur = ggml_mul_mat(ctxL,
|
| 1070 |
layer.attn_v_w,
|
|
@@ -1078,49 +1202,57 @@ bool whisper_encode(
|
|
| 1078 |
|
| 1079 |
// ------
|
| 1080 |
|
|
|
|
| 1081 |
struct ggml_tensor * Q =
|
| 1082 |
ggml_permute(ctxL,
|
| 1083 |
ggml_cpy(ctxL,
|
| 1084 |
Qcur,
|
| 1085 |
-
ggml_new_tensor_3d(ctxL,
|
| 1086 |
0, 2, 1, 3);
|
| 1087 |
|
| 1088 |
struct ggml_tensor * K =
|
| 1089 |
ggml_permute(ctxL,
|
| 1090 |
ggml_cpy(ctxL,
|
| 1091 |
Kcur,
|
| 1092 |
-
ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, n_head, N)),
|
| 1093 |
0, 2, 1, 3);
|
| 1094 |
|
| 1095 |
-
|
| 1096 |
-
|
| 1097 |
-
|
| 1098 |
-
|
| 1099 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1100 |
|
| 1101 |
-
|
| 1102 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1103 |
|
| 1104 |
-
|
| 1105 |
-
|
| 1106 |
-
|
| 1107 |
-
|
| 1108 |
-
|
| 1109 |
-
|
| 1110 |
-
// 1, 2, 0, 3),
|
| 1111 |
-
// ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, N, n_state/n_head, n_head)
|
| 1112 |
-
// );
|
| 1113 |
|
| 1114 |
-
|
| 1115 |
-
|
| 1116 |
|
| 1117 |
-
|
| 1118 |
-
|
| 1119 |
-
|
| 1120 |
-
|
| 1121 |
-
|
| 1122 |
|
| 1123 |
-
struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctxL,
|
| 1124 |
|
| 1125 |
//struct ggml_tensor * V_trans =
|
| 1126 |
// ggml_permute(ctxL,
|
|
@@ -1138,10 +1270,11 @@ bool whisper_encode(
|
|
| 1138 |
Vcur,
|
| 1139 |
n_state/n_head, n_head, N),
|
| 1140 |
0, 2, 1, 3),
|
| 1141 |
-
ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, N, n_head)
|
| 1142 |
);
|
| 1143 |
|
| 1144 |
struct ggml_tensor * KQV = ggml_mul_mat(ctxL, ggml_transpose(ctxL, V), KQ_soft_max);
|
|
|
|
| 1145 |
|
| 1146 |
struct ggml_tensor * KQV_merged = ggml_permute(ctxL, KQV, 0, 2, 1, 3);
|
| 1147 |
|
|
@@ -1180,6 +1313,11 @@ bool whisper_encode(
|
|
| 1180 |
ggml_repeat(ctxL, layer.mlp_ln_b, cur));
|
| 1181 |
}
|
| 1182 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1183 |
// fully connected
|
| 1184 |
cur = ggml_mul_mat(ctxL,
|
| 1185 |
layer.mlp_0_w,
|
|
@@ -1200,6 +1338,7 @@ bool whisper_encode(
|
|
| 1200 |
cur = ggml_add(ctxL,
|
| 1201 |
ggml_repeat(ctxL, layer.mlp_1_b, cur),
|
| 1202 |
cur);
|
|
|
|
| 1203 |
}
|
| 1204 |
|
| 1205 |
// output from this layer
|
|
@@ -1368,7 +1507,7 @@ bool whisper_decode(
|
|
| 1368 |
((int32_t *) position->data)[i] = n_past + i;
|
| 1369 |
}
|
| 1370 |
|
| 1371 |
-
//
|
| 1372 |
struct ggml_tensor * cur =
|
| 1373 |
ggml_add(ctx0,
|
| 1374 |
ggml_get_rows(ctx0, model.d_te, embd),
|
|
@@ -1420,7 +1559,7 @@ bool whisper_decode(
|
|
| 1420 |
|
| 1421 |
Qcur = ggml_scale(ctxL, Qcur, ggml_new_f32(ctxL, pow(float(n_state)/n_head, -0.25)));
|
| 1422 |
|
| 1423 |
-
// no bias for Key
|
| 1424 |
struct ggml_tensor * Kcur = ggml_mul_mat(ctxL,
|
| 1425 |
layer.attn_k_w,
|
| 1426 |
cur);
|
|
@@ -1506,7 +1645,7 @@ bool whisper_decode(
|
|
| 1506 |
|
| 1507 |
// norm
|
| 1508 |
{
|
| 1509 |
-
cur = ggml_norm(ctxL, inpCA); //
|
| 1510 |
|
| 1511 |
// cur = ln_0_w*cur + ln_0_b
|
| 1512 |
cur = ggml_add(ctxL,
|
|
@@ -1589,7 +1728,6 @@ bool whisper_decode(
|
|
| 1589 |
cur);
|
| 1590 |
}
|
| 1591 |
|
| 1592 |
-
|
| 1593 |
// add the input
|
| 1594 |
cur = ggml_add(ctxL, cur, inpCA);
|
| 1595 |
|
|
@@ -1601,8 +1739,7 @@ bool whisper_decode(
|
|
| 1601 |
{
|
| 1602 |
cur = ggml_norm(ctxL, inpFF);
|
| 1603 |
|
| 1604 |
-
// cur =
|
| 1605 |
-
// [ 768, N]
|
| 1606 |
cur = ggml_add(ctxL,
|
| 1607 |
ggml_mul(ctxL,
|
| 1608 |
ggml_repeat(ctxL, layer.mlp_ln_w, cur),
|
|
@@ -1689,11 +1826,11 @@ bool whisper_decode(
|
|
| 1689 |
probs_out.resize(N*n_vocab);
|
| 1690 |
memcpy(probs_out.data(), ggml_get_data(cur), sizeof(float)*N*n_vocab);
|
| 1691 |
|
| 1692 |
-
|
| 1693 |
-
|
| 1694 |
-
|
| 1695 |
-
|
| 1696 |
-
|
| 1697 |
|
| 1698 |
ggml_free(ctx0);
|
| 1699 |
|
|
@@ -1981,8 +2118,36 @@ int main(int argc, char ** argv) {
|
|
| 1981 |
t_mel_us = ggml_time_us() - t_start_us;
|
| 1982 |
}
|
| 1983 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1984 |
std::vector<whisper_vocab::id> prompt_past = { };
|
| 1985 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1986 |
// main loop
|
| 1987 |
int seek = 0;
|
| 1988 |
while (true) {
|
|
@@ -2006,24 +2171,23 @@ int main(int argc, char ** argv) {
|
|
| 2006 |
std::vector<float> probs;
|
| 2007 |
std::vector<float> logits;
|
| 2008 |
|
| 2009 |
-
|
| 2010 |
-
// ref: https://github.com/openai/whisper/blob/15ab54826343c27cfaf44ce31e9c8fb63d0aa775/whisper/decoding.py#L506-L526
|
| 2011 |
-
// TODO: use different initial tokens for different tasks
|
| 2012 |
-
std::vector<whisper_vocab::id> prompt = { vocab.token_sot };
|
| 2013 |
|
| 2014 |
int n_past = 0;
|
| 2015 |
|
|
|
|
| 2016 |
if (prompt_past.size() > 0) {
|
| 2017 |
int n_take = std::min(model.hparams.n_text_ctx/2, int(prompt_past.size()));
|
| 2018 |
|
| 2019 |
prompt = { vocab.token_prev };
|
| 2020 |
-
prompt.insert(prompt.
|
| 2021 |
-
prompt.push_back(vocab.token_sot);
|
| 2022 |
|
| 2023 |
prompt_past.clear();
|
| 2024 |
-
prompt_past.insert(prompt_past.end(), prompt.begin() + 1, prompt.end()
|
| 2025 |
}
|
| 2026 |
|
|
|
|
|
|
|
| 2027 |
bool done = false;
|
| 2028 |
int seek_delta = 100*CHUNK_SIZE;
|
| 2029 |
whisper_vocab::id last_id = 0;
|
|
@@ -2049,6 +2213,16 @@ int main(int argc, char ** argv) {
|
|
| 2049 |
n_past += prompt.size();
|
| 2050 |
prompt.clear();
|
| 2051 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2052 |
{
|
| 2053 |
// sample next token
|
| 2054 |
const float temp = 1.0; // TODO
|
|
|
|
| 1 |
#include "ggml.h"
|
| 2 |
|
| 3 |
+
#define USE_FLASH_ATTN
|
| 4 |
+
#define USE_FLASH_FF
|
| 5 |
+
|
| 6 |
// third-party utilities
|
| 7 |
// use your favorite implementations
|
| 8 |
#define DR_WAV_IMPLEMENTATION
|
|
|
|
| 19 |
#include <thread>
|
| 20 |
#include <vector>
|
| 21 |
|
| 22 |
+
// available whisper models
|
| 23 |
enum e_model {
|
| 24 |
MODEL_UNKNOWN,
|
| 25 |
MODEL_TINY,
|
|
|
|
| 29 |
MODEL_LARGE,
|
| 30 |
};
|
| 31 |
|
| 32 |
+
const std::map<std::string, std::pair<int, std::string>> g_lang = {
|
| 33 |
+
{ "en", { 0, "english", } },
|
| 34 |
+
{ "zh", { 1, "chinese", } },
|
| 35 |
+
{ "de", { 2, "german", } },
|
| 36 |
+
{ "es", { 3, "spanish", } },
|
| 37 |
+
{ "ru", { 4, "russian", } },
|
| 38 |
+
{ "ko", { 5, "korean", } },
|
| 39 |
+
{ "fr", { 6, "french", } },
|
| 40 |
+
{ "ja", { 7, "japanese", } },
|
| 41 |
+
{ "pt", { 8, "portuguese", } },
|
| 42 |
+
{ "tr", { 9, "turkish", } },
|
| 43 |
+
{ "pl", { 10, "polish", } },
|
| 44 |
+
{ "ca", { 11, "catalan", } },
|
| 45 |
+
{ "nl", { 12, "dutch", } },
|
| 46 |
+
{ "ar", { 13, "arabic", } },
|
| 47 |
+
{ "sv", { 14, "swedish", } },
|
| 48 |
+
{ "it", { 15, "italian", } },
|
| 49 |
+
{ "id", { 16, "indonesian", } },
|
| 50 |
+
{ "hi", { 17, "hindi", } },
|
| 51 |
+
{ "fi", { 18, "finnish", } },
|
| 52 |
+
{ "vi", { 19, "vietnamese", } },
|
| 53 |
+
{ "iw", { 20, "hebrew", } },
|
| 54 |
+
{ "uk", { 21, "ukrainian", } },
|
| 55 |
+
{ "el", { 22, "greek", } },
|
| 56 |
+
{ "ms", { 23, "malay", } },
|
| 57 |
+
{ "cs", { 24, "czech", } },
|
| 58 |
+
{ "ro", { 25, "romanian", } },
|
| 59 |
+
{ "da", { 26, "danish", } },
|
| 60 |
+
{ "hu", { 27, "hungarian", } },
|
| 61 |
+
{ "ta", { 28, "tamil", } },
|
| 62 |
+
{ "no", { 29, "norwegian", } },
|
| 63 |
+
{ "th", { 30, "thai", } },
|
| 64 |
+
{ "ur", { 31, "urdu", } },
|
| 65 |
+
{ "hr", { 32, "croatian", } },
|
| 66 |
+
{ "bg", { 33, "bulgarian", } },
|
| 67 |
+
{ "lt", { 34, "lithuanian", } },
|
| 68 |
+
{ "la", { 35, "latin", } },
|
| 69 |
+
{ "mi", { 36, "maori", } },
|
| 70 |
+
{ "ml", { 37, "malayalam", } },
|
| 71 |
+
{ "cy", { 38, "welsh", } },
|
| 72 |
+
{ "sk", { 39, "slovak", } },
|
| 73 |
+
{ "te", { 40, "telugu", } },
|
| 74 |
+
{ "fa", { 41, "persian", } },
|
| 75 |
+
{ "lv", { 42, "latvian", } },
|
| 76 |
+
{ "bn", { 43, "bengali", } },
|
| 77 |
+
{ "sr", { 44, "serbian", } },
|
| 78 |
+
{ "az", { 45, "azerbaijani", } },
|
| 79 |
+
{ "sl", { 46, "slovenian", } },
|
| 80 |
+
{ "kn", { 47, "kannada", } },
|
| 81 |
+
{ "et", { 48, "estonian", } },
|
| 82 |
+
{ "mk", { 49, "macedonian", } },
|
| 83 |
+
{ "br", { 50, "breton", } },
|
| 84 |
+
{ "eu", { 51, "basque", } },
|
| 85 |
+
{ "is", { 52, "icelandic", } },
|
| 86 |
+
{ "hy", { 53, "armenian", } },
|
| 87 |
+
{ "ne", { 54, "nepali", } },
|
| 88 |
+
{ "mn", { 55, "mongolian", } },
|
| 89 |
+
{ "bs", { 56, "bosnian", } },
|
| 90 |
+
{ "kk", { 57, "kazakh", } },
|
| 91 |
+
{ "sq", { 58, "albanian", } },
|
| 92 |
+
{ "sw", { 59, "swahili", } },
|
| 93 |
+
{ "gl", { 60, "galician", } },
|
| 94 |
+
{ "mr", { 61, "marathi", } },
|
| 95 |
+
{ "pa", { 62, "punjabi", } },
|
| 96 |
+
{ "si", { 63, "sinhala", } },
|
| 97 |
+
{ "km", { 64, "khmer", } },
|
| 98 |
+
{ "sn", { 65, "shona", } },
|
| 99 |
+
{ "yo", { 66, "yoruba", } },
|
| 100 |
+
{ "so", { 67, "somali", } },
|
| 101 |
+
{ "af", { 68, "afrikaans", } },
|
| 102 |
+
{ "oc", { 69, "occitan", } },
|
| 103 |
+
{ "ka", { 70, "georgian", } },
|
| 104 |
+
{ "be", { 71, "belarusian", } },
|
| 105 |
+
{ "tg", { 72, "tajik", } },
|
| 106 |
+
{ "sd", { 73, "sindhi", } },
|
| 107 |
+
{ "gu", { 74, "gujarati", } },
|
| 108 |
+
{ "am", { 75, "amharic", } },
|
| 109 |
+
{ "yi", { 76, "yiddish", } },
|
| 110 |
+
{ "lo", { 77, "lao", } },
|
| 111 |
+
{ "uz", { 78, "uzbek", } },
|
| 112 |
+
{ "fo", { 79, "faroese", } },
|
| 113 |
+
{ "ht", { 80, "haitian creole", } },
|
| 114 |
+
{ "ps", { 81, "pashto", } },
|
| 115 |
+
{ "tk", { 82, "turkmen", } },
|
| 116 |
+
{ "nn", { 83, "nynorsk", } },
|
| 117 |
+
{ "mt", { 84, "maltese", } },
|
| 118 |
+
{ "sa", { 85, "sanskrit", } },
|
| 119 |
+
{ "lb", { 86, "luxembourgish", } },
|
| 120 |
+
{ "my", { 87, "myanmar", } },
|
| 121 |
+
{ "bo", { 88, "tibetan", } },
|
| 122 |
+
{ "tl", { 89, "tagalog", } },
|
| 123 |
+
{ "mg", { 90, "malagasy", } },
|
| 124 |
+
{ "as", { 91, "assamese", } },
|
| 125 |
+
{ "tt", { 92, "tatar", } },
|
| 126 |
+
{ "haw", { 93, "hawaiian", } },
|
| 127 |
+
{ "ln", { 94, "lingala", } },
|
| 128 |
+
{ "ha", { 95, "hausa", } },
|
| 129 |
+
{ "ba", { 96, "bashkir", } },
|
| 130 |
+
{ "jw", { 97, "javanese", } },
|
| 131 |
+
{ "su", { 98, "sundanese", } },
|
| 132 |
+
};
|
| 133 |
+
|
| 134 |
const size_t MB = 1024*1024;
|
| 135 |
|
| 136 |
const std::map<e_model, size_t> MEM_REQ_MODEL = {
|
| 137 |
+
{ MODEL_TINY, 86ull*MB },
|
| 138 |
+
{ MODEL_BASE, 165ull*MB },
|
| 139 |
+
{ MODEL_SMALL, 540ull*MB },
|
| 140 |
+
{ MODEL_MEDIUM, 1650ull*MB },
|
| 141 |
+
{ MODEL_LARGE, 3260ull*MB },
|
| 142 |
};
|
| 143 |
|
| 144 |
const std::map<e_model, size_t> MEM_REQ_ENCODE = {
|
|
|
|
| 150 |
};
|
| 151 |
|
| 152 |
const std::map<e_model, size_t> MEM_REQ_ENCODE_LAYER = {
|
| 153 |
+
{ MODEL_TINY, 64ull*MB },
|
| 154 |
+
{ MODEL_BASE, 84ull*MB },
|
| 155 |
+
{ MODEL_SMALL, 128ull*MB },
|
| 156 |
+
{ MODEL_MEDIUM, 172ull*MB },
|
| 157 |
+
{ MODEL_LARGE, 216ull*MB },
|
| 158 |
};
|
| 159 |
|
| 160 |
const std::map<e_model, size_t> MEM_REQ_DECODE = {
|
|
|
|
| 208 |
id token_solm = 50361; // ??
|
| 209 |
id token_beg = 50363;
|
| 210 |
|
| 211 |
+
// available tasks
|
| 212 |
+
const id token_translate = 50358;
|
| 213 |
+
const id token_transcribe = 50359;
|
| 214 |
+
|
| 215 |
bool is_multilingual() const {
|
| 216 |
return n_vocab == 51865;
|
| 217 |
}
|
|
|
|
| 219 |
|
| 220 |
// command-line parameters
|
| 221 |
struct whisper_params {
|
| 222 |
+
int32_t seed = -1; // RNG seed, not used currently
|
| 223 |
int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency());
|
| 224 |
|
| 225 |
+
// sampling parameter - used for the greedy strategy
|
| 226 |
int32_t max_tokens_per_iter = 64;
|
| 227 |
|
| 228 |
+
bool verbose = false;
|
| 229 |
+
bool translate = false;
|
| 230 |
bool print_special_tokens = false;
|
| 231 |
|
| 232 |
+
std::string language = "en";
|
| 233 |
+
std::string model = "models/ggml-base.en.bin";
|
| 234 |
std::string fname_inp = "samples/jfk.wav";
|
| 235 |
};
|
| 236 |
|
|
|
|
| 248 |
params.max_tokens_per_iter = std::stoi(argv[++i]);
|
| 249 |
} else if (arg == "-v" || arg == "--verbose") {
|
| 250 |
params.verbose = true;
|
| 251 |
+
} else if (arg == "--translate") {
|
| 252 |
+
params.translate = true;
|
| 253 |
+
} else if (arg == "-l" || arg == "--language") {
|
| 254 |
+
params.language = argv[++i];
|
| 255 |
+
if (g_lang.find(params.language) == g_lang.end()) {
|
| 256 |
+
fprintf(stderr, "error: unknown language '%s'\n", params.language.c_str());
|
| 257 |
+
whisper_print_usage(argc, argv, params);
|
| 258 |
+
exit(0);
|
| 259 |
+
}
|
| 260 |
} else if (arg == "-ps" || arg == "--print_special") {
|
| 261 |
params.print_special_tokens = true;
|
| 262 |
} else if (arg == "-m" || arg == "--model") {
|
|
|
|
| 281 |
fprintf(stderr, "usage: %s [options]\n", argv[0]);
|
| 282 |
fprintf(stderr, "\n");
|
| 283 |
fprintf(stderr, "options:\n");
|
| 284 |
+
fprintf(stderr, " -h, --help show this help message and exit\n");
|
| 285 |
+
fprintf(stderr, " -s SEED, --seed SEED RNG seed (default: -1)\n");
|
| 286 |
+
fprintf(stderr, " -t N, --threads N number of threads to use during computation (default: %d)\n", params.n_threads);
|
| 287 |
+
fprintf(stderr, " -T N, --tokens N maximum number of tokens to generate per iteration (default: %d)\n", params.max_tokens_per_iter);
|
| 288 |
+
fprintf(stderr, " -v, --verbose verbose output\n");
|
| 289 |
+
fprintf(stderr, " --translate translate from source language to english\n");
|
| 290 |
+
fprintf(stderr, " -ps, --print_special print special tokens\n");
|
| 291 |
+
fprintf(stderr, " -l LANG, --language LANG spoken language (default: %s)\n", params.language.c_str());
|
| 292 |
+
fprintf(stderr, " -m FNAME, --model FNAME model path (default: %s)\n", params.model.c_str());
|
| 293 |
+
fprintf(stderr, " -f FNAME, --file FNAME input WAV file path (default: %s)\n", params.fname_inp.c_str());
|
| 294 |
fprintf(stderr, "\n");
|
| 295 |
}
|
| 296 |
|
|
|
|
| 538 |
printf("%s: f16 = %d\n", __func__, hparams.f16);
|
| 539 |
printf("%s: type = %d\n", __func__, model.type);
|
| 540 |
|
| 541 |
+
// this is the total memory required to run the inference
|
| 542 |
const size_t mem_required =
|
| 543 |
MEM_REQ_MODEL.at(model.type) +
|
| 544 |
MEM_REQ_ENCODE.at(model.type) +
|
|
|
|
| 731 |
ctx_size += n_text_layer*( n_text_state*ggml_type_size(GGML_TYPE_F32)); // cross_attn_ln_1_b
|
| 732 |
}
|
| 733 |
|
| 734 |
+
ctx_size += n_text_layer*n_text_ctx*n_text_state*ggml_type_size(GGML_TYPE_F16); // memory_k
|
| 735 |
+
ctx_size += n_text_layer*n_text_ctx*n_text_state*ggml_type_size(GGML_TYPE_F16); // memory_v
|
| 736 |
|
| 737 |
+
ctx_size += n_text_layer*n_audio_ctx*n_text_state*ggml_type_size(GGML_TYPE_F16); // memory_cross_k
|
| 738 |
+
ctx_size += n_text_layer*n_audio_ctx*n_text_state*ggml_type_size(GGML_TYPE_F16); // memory_cross_v
|
| 739 |
|
| 740 |
ctx_size += (15 + 15*n_audio_layer + 24*n_text_layer)*256; // object overhead
|
| 741 |
|
|
|
|
| 958 |
const int n_text_layer = hparams.n_text_layer;
|
| 959 |
const int n_text_ctx = hparams.n_text_ctx;
|
| 960 |
|
| 961 |
+
// key/value memory for the self-attention layer
|
| 962 |
{
|
| 963 |
const int n_mem = n_text_layer*n_text_ctx;
|
| 964 |
const int n_elements = n_text_state*n_mem;
|
| 965 |
|
| 966 |
+
model.memory_k = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements);
|
| 967 |
+
model.memory_v = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements);
|
| 968 |
}
|
| 969 |
|
| 970 |
+
// key/value memory for the cross-attention layer
|
| 971 |
{
|
| 972 |
const int n_audio_ctx = hparams.n_audio_ctx;
|
| 973 |
|
| 974 |
const int n_mem = n_text_layer*n_audio_ctx;
|
| 975 |
const int n_elements = n_text_state*n_mem;
|
| 976 |
|
| 977 |
+
model.memory_cross_k = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements);
|
| 978 |
+
model.memory_cross_v = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements);
|
| 979 |
}
|
| 980 |
|
| 981 |
const size_t memory_size =
|
|
|
|
| 1181 |
Qcur),
|
| 1182 |
Qcur);
|
| 1183 |
|
| 1184 |
+
//Qcur = ggml_scale(ctxL, Qcur, ggml_new_f32(ctxL, pow(float(n_state)/n_head, -0.25)));
|
| 1185 |
|
| 1186 |
+
// note: no bias for Key
|
| 1187 |
struct ggml_tensor * Kcur = ggml_mul_mat(ctxL,
|
| 1188 |
layer.attn_k_w,
|
| 1189 |
cur);
|
| 1190 |
|
| 1191 |
+
//Kcur = ggml_scale(ctxL, Kcur, ggml_new_f32(ctxL, pow(float(n_state)/n_head, -0.25)));
|
| 1192 |
|
| 1193 |
struct ggml_tensor * Vcur = ggml_mul_mat(ctxL,
|
| 1194 |
layer.attn_v_w,
|
|
|
|
| 1202 |
|
| 1203 |
// ------
|
| 1204 |
|
| 1205 |
+
#ifdef USE_FLASH_ATTN
|
| 1206 |
struct ggml_tensor * Q =
|
| 1207 |
ggml_permute(ctxL,
|
| 1208 |
ggml_cpy(ctxL,
|
| 1209 |
Qcur,
|
| 1210 |
+
ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, n_head, N)),
|
| 1211 |
0, 2, 1, 3);
|
| 1212 |
|
| 1213 |
struct ggml_tensor * K =
|
| 1214 |
ggml_permute(ctxL,
|
| 1215 |
ggml_cpy(ctxL,
|
| 1216 |
Kcur,
|
| 1217 |
+
ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, n_head, N)),
|
| 1218 |
0, 2, 1, 3);
|
| 1219 |
|
| 1220 |
+
struct ggml_tensor * V =
|
| 1221 |
+
ggml_cpy(ctxL,
|
| 1222 |
+
ggml_permute(ctxL,
|
| 1223 |
+
ggml_reshape_3d(ctxL,
|
| 1224 |
+
Vcur,
|
| 1225 |
+
n_state/n_head, n_head, N),
|
| 1226 |
+
1, 2, 0, 3),
|
| 1227 |
+
ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, N, n_state/n_head, n_head)
|
| 1228 |
+
);
|
| 1229 |
|
| 1230 |
+
struct ggml_tensor * KQV = ggml_flash_attn(ctxL, Q, K, V, false);
|
| 1231 |
+
#else
|
| 1232 |
+
struct ggml_tensor * Q =
|
| 1233 |
+
ggml_permute(ctxL,
|
| 1234 |
+
ggml_cpy(ctxL,
|
| 1235 |
+
Qcur,
|
| 1236 |
+
ggml_new_tensor_3d(ctxL, GGML_TYPE_F32, n_state/n_head, n_head, N)),
|
| 1237 |
+
0, 2, 1, 3);
|
| 1238 |
|
| 1239 |
+
struct ggml_tensor * K =
|
| 1240 |
+
ggml_permute(ctxL,
|
| 1241 |
+
ggml_cpy(ctxL,
|
| 1242 |
+
Kcur,
|
| 1243 |
+
ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, n_head, N)),
|
| 1244 |
+
0, 2, 1, 3);
|
|
|
|
|
|
|
|
|
|
| 1245 |
|
| 1246 |
+
// K * Q
|
| 1247 |
+
struct ggml_tensor * KQ = ggml_mul_mat(ctxL, K, Q);
|
| 1248 |
|
| 1249 |
+
struct ggml_tensor * KQ_scaled =
|
| 1250 |
+
ggml_scale(ctxL,
|
| 1251 |
+
KQ,
|
| 1252 |
+
ggml_new_f32(ctxL, 1.0f/sqrt(float(n_state)/n_head))
|
| 1253 |
+
);
|
| 1254 |
|
| 1255 |
+
struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctxL, KQ_scaled);
|
| 1256 |
|
| 1257 |
//struct ggml_tensor * V_trans =
|
| 1258 |
// ggml_permute(ctxL,
|
|
|
|
| 1270 |
Vcur,
|
| 1271 |
n_state/n_head, n_head, N),
|
| 1272 |
0, 2, 1, 3),
|
| 1273 |
+
ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, N, n_head)
|
| 1274 |
);
|
| 1275 |
|
| 1276 |
struct ggml_tensor * KQV = ggml_mul_mat(ctxL, ggml_transpose(ctxL, V), KQ_soft_max);
|
| 1277 |
+
#endif
|
| 1278 |
|
| 1279 |
struct ggml_tensor * KQV_merged = ggml_permute(ctxL, KQV, 0, 2, 1, 3);
|
| 1280 |
|
|
|
|
| 1313 |
ggml_repeat(ctxL, layer.mlp_ln_b, cur));
|
| 1314 |
}
|
| 1315 |
|
| 1316 |
+
#ifdef USE_FLASH_FF
|
| 1317 |
+
cur = ggml_flash_ff(ctxL,
|
| 1318 |
+
ggml_cpy(ctxL, cur, ggml_new_tensor_2d(ctxL, GGML_TYPE_F16, n_state, N)),
|
| 1319 |
+
layer.mlp_0_w, layer.mlp_0_b, layer.mlp_1_w, layer.mlp_1_b);
|
| 1320 |
+
#else
|
| 1321 |
// fully connected
|
| 1322 |
cur = ggml_mul_mat(ctxL,
|
| 1323 |
layer.mlp_0_w,
|
|
|
|
| 1338 |
cur = ggml_add(ctxL,
|
| 1339 |
ggml_repeat(ctxL, layer.mlp_1_b, cur),
|
| 1340 |
cur);
|
| 1341 |
+
#endif
|
| 1342 |
}
|
| 1343 |
|
| 1344 |
// output from this layer
|
|
|
|
| 1507 |
((int32_t *) position->data)[i] = n_past + i;
|
| 1508 |
}
|
| 1509 |
|
| 1510 |
+
// token encoding + position encoding
|
| 1511 |
struct ggml_tensor * cur =
|
| 1512 |
ggml_add(ctx0,
|
| 1513 |
ggml_get_rows(ctx0, model.d_te, embd),
|
|
|
|
| 1559 |
|
| 1560 |
Qcur = ggml_scale(ctxL, Qcur, ggml_new_f32(ctxL, pow(float(n_state)/n_head, -0.25)));
|
| 1561 |
|
| 1562 |
+
// note: no bias for Key
|
| 1563 |
struct ggml_tensor * Kcur = ggml_mul_mat(ctxL,
|
| 1564 |
layer.attn_k_w,
|
| 1565 |
cur);
|
|
|
|
| 1645 |
|
| 1646 |
// norm
|
| 1647 |
{
|
| 1648 |
+
cur = ggml_norm(ctxL, inpCA); // note: we use inpCA here
|
| 1649 |
|
| 1650 |
// cur = ln_0_w*cur + ln_0_b
|
| 1651 |
cur = ggml_add(ctxL,
|
|
|
|
| 1728 |
cur);
|
| 1729 |
}
|
| 1730 |
|
|
|
|
| 1731 |
// add the input
|
| 1732 |
cur = ggml_add(ctxL, cur, inpCA);
|
| 1733 |
|
|
|
|
| 1739 |
{
|
| 1740 |
cur = ggml_norm(ctxL, inpFF);
|
| 1741 |
|
| 1742 |
+
// cur = mlp_ln_w*cur + mlp_ln_b
|
|
|
|
| 1743 |
cur = ggml_add(ctxL,
|
| 1744 |
ggml_mul(ctxL,
|
| 1745 |
ggml_repeat(ctxL, layer.mlp_ln_w, cur),
|
|
|
|
| 1826 |
probs_out.resize(N*n_vocab);
|
| 1827 |
memcpy(probs_out.data(), ggml_get_data(cur), sizeof(float)*N*n_vocab);
|
| 1828 |
|
| 1829 |
+
if (N > 1) {
|
| 1830 |
+
//const float mem_per_token = ggml_used_mem(ctx0)/1024.0/1024.0/N;
|
| 1831 |
+
//printf("%s: used_mem = %f MB / %f per token\n", __func__, ggml_used_mem(ctx0)/1024.0/1024.0, mem_per_token);
|
| 1832 |
+
//printf("%s: max mem = %f MB\n", __func__, mem_per_token*model.hparams.n_text_ctx);
|
| 1833 |
+
}
|
| 1834 |
|
| 1835 |
ggml_free(ctx0);
|
| 1836 |
|
|
|
|
| 2118 |
t_mel_us = ggml_time_us() - t_start_us;
|
| 2119 |
}
|
| 2120 |
|
| 2121 |
+
// print some info about the processing
|
| 2122 |
+
{
|
| 2123 |
+
printf("\n");
|
| 2124 |
+
if (!vocab.is_multilingual()) {
|
| 2125 |
+
if (params.language != "en" || params.translate) {
|
| 2126 |
+
params.language = "en";
|
| 2127 |
+
params.translate = false;
|
| 2128 |
+
printf("%s: WARNING: model is not multilingual, ignoring language and translation options\n", __func__);
|
| 2129 |
+
}
|
| 2130 |
+
}
|
| 2131 |
+
printf("%s: processing %d samples (%.1f sec), %d threads, lang = %s, task = %s ...\n",
|
| 2132 |
+
__func__, int(pcmf32.size()), float(pcmf32.size())/SAMPLE_RATE, params.n_threads,
|
| 2133 |
+
g_lang.at(params.language).second.c_str(),
|
| 2134 |
+
params.translate ? "translate" : "transcribe");
|
| 2135 |
+
}
|
| 2136 |
+
|
| 2137 |
+
// the accumulated text context so far
|
| 2138 |
std::vector<whisper_vocab::id> prompt_past = { };
|
| 2139 |
|
| 2140 |
+
// these tokens determine the task that will be performed
|
| 2141 |
+
std::vector<whisper_vocab::id> prompt_init = { vocab.token_sot };
|
| 2142 |
+
if (vocab.is_multilingual()) {
|
| 2143 |
+
prompt_init.push_back(vocab.token_sot + 1 + g_lang.at(params.language).first);
|
| 2144 |
+
if (params.translate) {
|
| 2145 |
+
prompt_init.push_back(vocab.token_translate);
|
| 2146 |
+
} else {
|
| 2147 |
+
prompt_init.push_back(vocab.token_transcribe);
|
| 2148 |
+
}
|
| 2149 |
+
}
|
| 2150 |
+
|
| 2151 |
// main loop
|
| 2152 |
int seek = 0;
|
| 2153 |
while (true) {
|
|
|
|
| 2171 |
std::vector<float> probs;
|
| 2172 |
std::vector<float> logits;
|
| 2173 |
|
| 2174 |
+
std::vector<whisper_vocab::id> prompt;
|
|
|
|
|
|
|
|
|
|
| 2175 |
|
| 2176 |
int n_past = 0;
|
| 2177 |
|
| 2178 |
+
// if we have already generated some text, use it as a prompt to condition the next generation
|
| 2179 |
if (prompt_past.size() > 0) {
|
| 2180 |
int n_take = std::min(model.hparams.n_text_ctx/2, int(prompt_past.size()));
|
| 2181 |
|
| 2182 |
prompt = { vocab.token_prev };
|
| 2183 |
+
prompt.insert(prompt.begin() + 1, prompt_past.end() - n_take, prompt_past.end());
|
|
|
|
| 2184 |
|
| 2185 |
prompt_past.clear();
|
| 2186 |
+
prompt_past.insert(prompt_past.end(), prompt.begin() + 1, prompt.end());
|
| 2187 |
}
|
| 2188 |
|
| 2189 |
+
prompt.insert(prompt.end(), prompt_init.begin(), prompt_init.end());
|
| 2190 |
+
|
| 2191 |
bool done = false;
|
| 2192 |
int seek_delta = 100*CHUNK_SIZE;
|
| 2193 |
whisper_vocab::id last_id = 0;
|
|
|
|
| 2213 |
n_past += prompt.size();
|
| 2214 |
prompt.clear();
|
| 2215 |
|
| 2216 |
+
// very basic greedy sampling strategy:
|
| 2217 |
+
//
|
| 2218 |
+
// - always take the most probable token
|
| 2219 |
+
// - if we have accumulated more than 'params.max_tokens_per_iter' -> pick most probable timestamp token
|
| 2220 |
+
// and advance the sliding window by that amount
|
| 2221 |
+
// - in the meantime, if we encounter 2 consecutive timestamp tokens, we advance the sliding window too
|
| 2222 |
+
//
|
| 2223 |
+
// more sophisticated sampling strategies could be implemented here, but we keep it simple
|
| 2224 |
+
// feel free to experiment!
|
| 2225 |
+
//
|
| 2226 |
{
|
| 2227 |
// sample next token
|
| 2228 |
const float temp = 1.0; // TODO
|