ggerganov committed
Commit a8c74e6 · unverified · 1 parent: 83d8317

talk-llama : add new example + sync ggml from llama.cpp (#664)


* talk-llama : talk with LLaMA AI

* talk-llama : disable EOS token

* talk-llama : add README instructions

* ggml : fix build in debug

.gitignore CHANGED
@@ -18,6 +18,7 @@ build-sanitize-thread/
 /stream
 /command
 /talk
+/talk-llama
 /bench
 
 arm_neon.h
@@ -32,3 +33,5 @@ examples/whisper.objc/whisper.objc.xcodeproj/xcuserdata/
 examples/whisper.objc/whisper.objc.xcodeproj/project.xcworkspace/xcuserdata
 
 extra/bench-gg.txt
+
+*.mlmodel*
Makefile CHANGED
@@ -36,7 +36,7 @@ LDFLAGS =
 
 # ref: https://github.com/ggerganov/whisper.cpp/issues/37
 ifneq ($(wildcard /usr/include/musl/*),)
-	CFLAGS   += -D_POSIX_SOURCE -D_GNU_SOURCE
+	CFLAGS   += -D_POSIX_SOURCE -D_GNU_SOURCE
 	CXXFLAGS += -D_POSIX_SOURCE -D_GNU_SOURCE
 endif
 
@@ -178,7 +178,7 @@ $(info I CC: $(CCV))
 $(info I CXX: $(CXXV))
 $(info )
 
-default: main
+default: main bench
 
 #
 # Build library
@@ -197,7 +197,7 @@ libwhisper.so: ggml.o whisper.o
 	$(CXX) $(CXXFLAGS) -shared -o libwhisper.so ggml.o whisper.o $(LDFLAGS)
 
 clean:
-	rm -f *.o main stream command talk bench libwhisper.a libwhisper.so
+	rm -f *.o main stream command talk talk-llama bench libwhisper.a libwhisper.so
 
 #
 # Examples
@@ -212,6 +212,9 @@ main: examples/main/main.cpp $(SRC_COMMON) ggml.o whisper.o
 	$(CXX) $(CXXFLAGS) examples/main/main.cpp $(SRC_COMMON) ggml.o whisper.o -o main $(LDFLAGS)
 	./main -h
 
+bench: examples/bench/bench.cpp ggml.o whisper.o
+	$(CXX) $(CXXFLAGS) examples/bench/bench.cpp ggml.o whisper.o -o bench $(LDFLAGS)
+
 stream: examples/stream/stream.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o whisper.o
 	$(CXX) $(CXXFLAGS) examples/stream/stream.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o whisper.o -o stream $(CC_SDL) $(LDFLAGS)
 
@@ -221,8 +224,8 @@ command: examples/command/command.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o whi
 talk: examples/talk/talk.cpp examples/talk/gpt-2.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o whisper.o
 	$(CXX) $(CXXFLAGS) examples/talk/talk.cpp examples/talk/gpt-2.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o whisper.o -o talk $(CC_SDL) $(LDFLAGS)
 
-bench: examples/bench/bench.cpp ggml.o whisper.o
-	$(CXX) $(CXXFLAGS) examples/bench/bench.cpp ggml.o whisper.o -o bench $(LDFLAGS)
+talk-llama: examples/talk-llama/talk-llama.cpp examples/talk-llama/llama.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o whisper.o
+	$(CXX) $(CXXFLAGS) examples/talk-llama/talk-llama.cpp examples/talk-llama/llama.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o whisper.o -o talk-llama $(CC_SDL) $(LDFLAGS)
 
 #
 # Audio samples
examples/CMakeLists.txt CHANGED
@@ -63,4 +63,5 @@ else()
 add_subdirectory(command)
 add_subdirectory(bench)
 add_subdirectory(talk)
+add_subdirectory(talk-llama)
 endif()
examples/talk-llama/.gitignore ADDED
@@ -0,0 +1,2 @@
+eleven-labs.py
+audio.mp3
examples/talk-llama/CMakeLists.txt ADDED
@@ -0,0 +1,10 @@
+if (WHISPER_SUPPORT_SDL2)
+    # talk-llama
+    set(TARGET talk-llama)
+
+    add_executable(${TARGET} talk-llama.cpp llama.cpp)
+
+    include(DefaultTargetOptions)
+
+    target_link_libraries(${TARGET} PRIVATE common common-sdl whisper ${SDL2_LIBRARIES} ${CMAKE_THREAD_LIBS_INIT})
+endif ()
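
For reference, building the new target through CMake would look roughly like the following. This is only a sketch based on the `WHISPER_SUPPORT_SDL2` option used above; exact paths and generator options depend on your setup.

```bash
# Sketch: configure with SDL2 support enabled and build only the new target.
# Assumes the SDL2 development package is already installed (see the README below).
mkdir -p build && cd build
cmake -DWHISPER_SUPPORT_SDL2=ON ..
cmake --build . --target talk-llama
```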
examples/talk-llama/README.md ADDED
@@ -0,0 +1,32 @@
+# talk-llama
+
+Talk with a LLaMA AI in your terminal
+
+[Demo Talk](https://user-images.githubusercontent.com/1991296/228024237-848f998c-c334-46a6-bef8-3271590da83b.mp4)
+
+## Building
+
+The `talk-llama` tool depends on the SDL2 library to capture audio from the microphone. You can build it like this:
+
+```bash
+# Install SDL2 on Linux
+sudo apt-get install libsdl2-dev
+
+# Install SDL2 on macOS
+brew install sdl2
+
+# Build the "talk-llama" executable
+make talk-llama
+
+# Run it
+./talk-llama -mw ./models/ggml-small.en.bin -ml ../llama.cpp/models/13B/ggml-model-q4_0.bin -p "Georgi" -t 8
+```
+
+- The `-mw` argument specifies the Whisper model that you would like to use. The `base` or `small` models are recommended for a real-time experience
+- The `-ml` argument specifies the LLaMA model that you would like to use. Read the instructions in https://github.com/ggerganov/llama.cpp for information about how to obtain a `ggml`-compatible LLaMA model
+
+## TTS
+
+For the best experience, this example needs a TTS tool to convert the generated text responses to voice.
+You can use any TTS engine that you would like - simply edit the [speak.sh](speak.sh) script to your needs.
+By default, it is configured to use macOS's `say`, but you can use whatever you wish.
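
The `speak.sh` script referenced above is not included in this excerpt. As a rough illustration of the default behaviour described here (macOS `say`), a minimal version could be as simple as the sketch below; the actual script in the repository may take different arguments.

```bash
#!/bin/bash
# Minimal sketch of a speak.sh-style helper, assuming the text to speak is
# passed as command-line arguments. The real script may also take a voice or
# speaker id - adjust to your TTS engine of choice.
say "$@"
```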
examples/talk-llama/llama.cpp ADDED
@@ -0,0 +1,1865 @@
1
+ #include "llama.h"
2
+
3
+ #include "ggml.h"
4
+
5
+ #include <cinttypes>
6
+ #include <fstream>
7
+ #include <random>
8
+ #include <map>
9
+ #include <unordered_map>
10
+ #include <queue>
11
+ #include <regex>
12
+ #include <cassert>
13
+ #include <cstring>
14
+
15
+ #define LLAMA_USE_SCRATCH
16
+ #define LLAMA_MAX_SCRATCH_BUFFERS 16
17
+
18
+ #define LLAMA_ASSERT(x) \
19
+ do { \
20
+ if (!(x)) { \
21
+ fprintf(stderr, "LLAMA_ASSERT: %s:%d: %s\n", __FILE__, __LINE__, #x); \
22
+ abort(); \
23
+ } \
24
+ } while (0)
25
+
26
+
27
+ // determine number of model parts based on the dimension
28
+ static const std::unordered_map<int, int> LLAMA_N_PARTS = {
29
+ { 4096, 1 },
30
+ { 5120, 2 },
31
+ { 6656, 4 },
32
+ { 8192, 8 },
33
+ };
34
+
35
+ // available llama models
36
+ enum e_model {
37
+ MODEL_UNKNOWN,
38
+ MODEL_7B,
39
+ MODEL_13B,
40
+ MODEL_30B,
41
+ MODEL_65B,
42
+ };
43
+
44
+ static const size_t MB = 1024*1024;
45
+
46
+ // computed for n_ctx == 2048
47
+ // TODO: dynamically determine these sizes
48
+ // needs modifications in ggml
49
+
50
+ static const std::map<e_model, size_t> MEM_REQ_SCRATCH0 = {
51
+ { MODEL_7B, 512ull*MB },
52
+ { MODEL_13B, 512ull*MB },
53
+ { MODEL_30B, 512ull*MB },
54
+ { MODEL_65B, 512ull*MB },
55
+ };
56
+
57
+ static const std::map<e_model, size_t> MEM_REQ_SCRATCH1 = {
58
+ { MODEL_7B, 512ull*MB },
59
+ { MODEL_13B, 512ull*MB },
60
+ { MODEL_30B, 512ull*MB },
61
+ { MODEL_65B, 512ull*MB },
62
+ };
63
+
64
+ // 2*n_embd*n_ctx*n_layer*sizeof(float16)
65
+ static const std::map<e_model, size_t> MEM_REQ_KV_SELF = {
66
+ { MODEL_7B, 1026ull*MB },
67
+ { MODEL_13B, 1608ull*MB },
68
+ { MODEL_30B, 3124ull*MB },
69
+ { MODEL_65B, 5120ull*MB },
70
+ };
71
+
72
+ // this is mostly needed for temporary mul_mat buffers to dequantize the data
73
+ // not actually needed if BLAS is disabled
74
+ static const std::map<e_model, size_t> MEM_REQ_EVAL = {
75
+ { MODEL_7B, 768ull*MB },
76
+ { MODEL_13B, 1024ull*MB },
77
+ { MODEL_30B, 1280ull*MB },
78
+ { MODEL_65B, 1536ull*MB },
79
+ };
80
+
81
+ // default hparams (LLaMA 7B)
82
+ struct llama_hparams {
83
+ int32_t n_vocab = 32000;
84
+ int32_t n_ctx = 512; // this is provided as user input?
85
+ int32_t n_embd = 4096;
86
+ int32_t n_mult = 256;
87
+ int32_t n_head = 32;
88
+ int32_t n_layer = 32;
89
+ int32_t n_rot = 64;
90
+ int32_t f16 = 1;
91
+ };
92
+
93
+ struct llama_layer {
94
+ // normalization
95
+ struct ggml_tensor * attention_norm;
96
+
97
+ // attention
98
+ struct ggml_tensor * wq;
99
+ struct ggml_tensor * wk;
100
+ struct ggml_tensor * wv;
101
+ struct ggml_tensor * wo;
102
+
103
+ // normalization
104
+ struct ggml_tensor * ffn_norm;
105
+
106
+ // ff
107
+ struct ggml_tensor * w1;
108
+ struct ggml_tensor * w2;
109
+ struct ggml_tensor * w3;
110
+ };
111
+
112
+ struct llama_kv_cache {
113
+ struct ggml_tensor * k;
114
+ struct ggml_tensor * v;
115
+
116
+ struct ggml_context * ctx;
117
+
118
+ std::vector<uint8_t> buf;
119
+
120
+ int n; // number of tokens currently in the cache
121
+ };
122
+
123
+ struct llama_model {
124
+ e_model type = MODEL_UNKNOWN;
125
+
126
+ llama_hparams hparams;
127
+
128
+ struct ggml_tensor * tok_embeddings;
129
+
130
+ struct ggml_tensor * norm;
131
+ struct ggml_tensor * output;
132
+
133
+ std::vector<llama_layer> layers;
134
+
135
+ // context
136
+ struct ggml_context * ctx;
137
+
138
+ // key + value cache for the self attention
139
+ // TODO: move to llama_state
140
+ struct llama_kv_cache kv_self;
141
+
142
+ // the model memory buffer
143
+ std::vector<uint8_t> buf;
144
+
145
+ // tensors
146
+ int n_loaded;
147
+ std::unordered_map<std::string, struct ggml_tensor *> tensors;
148
+ };
149
+
150
+ struct llama_vocab {
151
+ using id = int32_t;
152
+ using token = std::string;
153
+
154
+ struct token_score {
155
+ token tok;
156
+ float score;
157
+ };
158
+
159
+ std::unordered_map<token, id> token_to_id;
160
+ std::vector<token_score> id_to_token;
161
+ };
162
+
163
+ struct llama_context {
164
+ std::mt19937 rng;
165
+
166
+ int64_t t_load_us = 0;
167
+ int64_t t_start_us = 0;
168
+
169
+ int64_t t_sample_us = 0;
170
+ int64_t t_eval_us = 0;
171
+ int64_t t_p_eval_us = 0;
172
+
173
+ int32_t n_sample = 0; // number of tokens sampled
174
+ int32_t n_eval = 0; // number of eval calls
175
+ int32_t n_p_eval = 0; // number of tokens in eval calls for the prompt (with batch size > 1)
176
+
177
+ llama_model model;
178
+ llama_vocab vocab;
179
+
180
+ size_t mem_per_token = 0;
181
+
182
+ // decode output (2-dimensional array: [n_tokens][n_vocab])
183
+ std::vector<float> logits;
184
+ bool logits_all = false;
185
+
186
+ // input embedding (1-dimensional array: [n_embd])
187
+ std::vector<float> embedding;
188
+
189
+ // memory buffers used to evaluate the model
190
+ // TODO: move in llama_state
191
+ std::vector<uint8_t> buf_compute;
192
+ std::vector<uint8_t> buf_scratch[LLAMA_MAX_SCRATCH_BUFFERS];
193
+
194
+ int buf_last = 0;
195
+ size_t buf_max_size[LLAMA_MAX_SCRATCH_BUFFERS] = { 0 };
196
+
197
+ void use_buf(struct ggml_context * ctx, int i) {
198
+ #if defined(LLAMA_USE_SCRATCH)
199
+ size_t last_size = 0;
200
+
201
+ if (i == -1) {
202
+ last_size = ggml_set_scratch(ctx, { 0, 0, nullptr, });
203
+ } else {
204
+ auto & buf = buf_scratch[i];
205
+ last_size = ggml_set_scratch(ctx, { 0, buf.size(), buf.data(), });
206
+ }
207
+
208
+ if (buf_last >= 0) {
209
+ buf_max_size[buf_last] = std::max(buf_max_size[buf_last], last_size);
210
+ }
211
+
212
+ buf_last = i;
213
+ #else
214
+ (void) i;
215
+ (void) ctx;
216
+ #endif
217
+ }
218
+
219
+ size_t get_buf_max_mem(int i) const {
220
+ #if defined(LLAMA_USE_SCRATCH)
221
+ return buf_max_size[i];
222
+ #else
223
+ (void) i;
224
+ return 0;
225
+ #endif
226
+ }
227
+ };
228
+
229
+ //
230
+ // kv cache
231
+ //
232
+
233
+ static bool kv_cache_init(
234
+ const struct llama_hparams & hparams,
235
+ struct llama_kv_cache & cache,
236
+ ggml_type wtype,
237
+ int n_ctx) {
238
+ const int n_embd = hparams.n_embd;
239
+ const int n_layer = hparams.n_layer;
240
+
241
+ const int n_mem = n_layer*n_ctx;
242
+ const int n_elements = n_embd*n_mem;
243
+
244
+ cache.buf.resize(2u*n_elements*ggml_type_size(wtype) + 2u*MB);
245
+
246
+ struct ggml_init_params params;
247
+ params.mem_size = cache.buf.size();
248
+ params.mem_buffer = cache.buf.data();
249
+
250
+ cache.ctx = ggml_init(params);
251
+
252
+ if (!cache.ctx) {
253
+ fprintf(stderr, "%s: failed to allocate memory for kv cache\n", __func__);
254
+ return false;
255
+ }
256
+
257
+ cache.k = ggml_new_tensor_1d(cache.ctx, wtype, n_elements);
258
+ cache.v = ggml_new_tensor_1d(cache.ctx, wtype, n_elements);
259
+
260
+ return true;
261
+ }
262
+
263
+ static void kv_cache_free(struct llama_kv_cache & cache) {
264
+ if (cache.ctx) {
265
+ ggml_free(cache.ctx);
266
+ cache.ctx = nullptr;
267
+ }
268
+ }
269
+
270
+ struct llama_context_params llama_context_default_params() {
271
+ struct llama_context_params result = {
272
+ /*.n_ctx =*/ 512,
273
+ /*.n_parts =*/ -1,
274
+ /*.seed =*/ 0,
275
+ /*.f16_kv =*/ false,
276
+ /*.logits_all =*/ false,
277
+ /*.vocab_only =*/ false,
278
+ /*.use_mlock =*/ false,
279
+ /*.embedding =*/ false,
280
+ /*.progress_callback =*/ nullptr,
281
+ /*.progress_callback_user_data =*/ nullptr,
282
+ };
283
+
284
+ return result;
285
+ }
286
+
287
+ //
288
+ // model loading
289
+ //
290
+
291
+ static bool llama_model_load(
292
+ const std::string & fname,
293
+ llama_context & lctx,
294
+ int n_ctx,
295
+ int n_parts,
296
+ ggml_type memory_type,
297
+ bool vocab_only,
298
+ llama_progress_callback progress_callback,
299
+ void *progress_callback_user_data) {
300
+ fprintf(stderr, "%s: loading model from '%s' - please wait ...\n", __func__, fname.c_str());
301
+
302
+ const int64_t t_start_us = ggml_time_us();
303
+
304
+ lctx.t_start_us = t_start_us;
305
+
306
+ std::vector<char> f_buf(1024*1024);
307
+
308
+ auto & model = lctx.model;
309
+ auto & vocab = lctx.vocab;
310
+
311
+ auto fin = std::ifstream(fname, std::ios::binary);
312
+ fin.rdbuf()->pubsetbuf(f_buf.data(), f_buf.size());
313
+ if (!fin) {
314
+ fprintf(stderr, "%s: failed to open '%s'\n", __func__, fname.c_str());
315
+ return false;
316
+ }
317
+
318
+ // verify magic
319
+ {
320
+ uint32_t magic;
321
+ fin.read((char *) &magic, sizeof(magic));
322
+ if (magic == LLAMA_FILE_MAGIC_UNVERSIONED) {
323
+ fprintf(stderr, "%s: invalid model file '%s' (too old, regenerate your model files!)\n",
324
+ __func__, fname.c_str());
325
+ return false;
326
+ }
327
+ if (magic != LLAMA_FILE_MAGIC) {
328
+ fprintf(stderr, "%s: invalid model file '%s' (bad magic)\n", __func__, fname.c_str());
329
+ return false;
330
+ }
331
+
332
+ uint32_t format_version;
333
+ fin.read((char *) &format_version, sizeof(format_version));
334
+
335
+ if (format_version != LLAMA_FILE_VERSION) {
336
+ fprintf(stderr, "%s: invalid model file '%s' (unsupported format version %" PRIu32 ", expected %d)\n",
337
+ __func__, fname.c_str(), format_version, LLAMA_FILE_VERSION);
338
+ return false;
339
+ }
340
+ }
341
+
342
+ int n_ff = 0;
343
+
344
+ // load hparams
345
+ {
346
+ auto & hparams = model.hparams;
347
+
348
+ fin.read((char *) &hparams.n_vocab, sizeof(hparams.n_vocab));
349
+ //fin.read((char *) &hparams.n_ctx, sizeof(hparams.n_ctx));
350
+ fin.read((char *) &hparams.n_embd, sizeof(hparams.n_embd));
351
+ fin.read((char *) &hparams.n_mult, sizeof(hparams.n_mult));
352
+ fin.read((char *) &hparams.n_head, sizeof(hparams.n_head));
353
+ fin.read((char *) &hparams.n_layer, sizeof(hparams.n_layer));
354
+ fin.read((char *) &hparams.n_rot, sizeof(hparams.n_rot));
355
+ fin.read((char *) &hparams.f16, sizeof(hparams.f16));
356
+
357
+ hparams.n_ctx = n_ctx;
358
+
359
+ n_ff = ((2*(4*hparams.n_embd)/3 + hparams.n_mult - 1)/hparams.n_mult)*hparams.n_mult;
360
+
361
+ if (n_parts < 1) {
362
+ n_parts = LLAMA_N_PARTS.at(hparams.n_embd);
363
+ }
364
+
365
+ // temp warning to tell the user to use "--n_parts"
366
+ if (hparams.f16 == 4 && n_parts != 1) {
367
+ fprintf(stderr, "%s: GPTQ model detected - are you sure n_parts should be %d? we normally expect it to be 1\n", __func__, n_parts);
368
+ fprintf(stderr, "%s: use '--n_parts 1' if necessary\n", __func__);
369
+ }
370
+
371
+ if (hparams.n_layer == 32) {
372
+ model.type = e_model::MODEL_7B;
373
+ }
374
+
375
+ if (hparams.n_layer == 40) {
376
+ model.type = e_model::MODEL_13B;
377
+ }
378
+
379
+ if (hparams.n_layer == 60) {
380
+ model.type = e_model::MODEL_30B;
381
+ }
382
+
383
+ if (hparams.n_layer == 80) {
384
+ model.type = e_model::MODEL_65B;
385
+ }
386
+
387
+ fprintf(stderr, "%s: n_vocab = %d\n", __func__, hparams.n_vocab);
388
+ fprintf(stderr, "%s: n_ctx = %d\n", __func__, hparams.n_ctx);
389
+ fprintf(stderr, "%s: n_embd = %d\n", __func__, hparams.n_embd);
390
+ fprintf(stderr, "%s: n_mult = %d\n", __func__, hparams.n_mult);
391
+ fprintf(stderr, "%s: n_head = %d\n", __func__, hparams.n_head);
392
+ fprintf(stderr, "%s: n_layer = %d\n", __func__, hparams.n_layer);
393
+ fprintf(stderr, "%s: n_rot = %d\n", __func__, hparams.n_rot);
394
+ fprintf(stderr, "%s: f16 = %d\n", __func__, hparams.f16);
395
+ fprintf(stderr, "%s: n_ff = %d\n", __func__, n_ff);
396
+ fprintf(stderr, "%s: n_parts = %d\n", __func__, n_parts);
397
+ fprintf(stderr, "%s: type = %d\n", __func__, model.type);
398
+ }
399
+
400
+ // load vocab
401
+ {
402
+ std::string word;
403
+ vocab.id_to_token.resize(model.hparams.n_vocab);
404
+ std::vector<char> tmp(64);
405
+
406
+ for (int i = 0; i < model.hparams.n_vocab; i++) {
407
+ uint32_t len;
408
+ fin.read((char *) &len, sizeof(len));
409
+
410
+ word.resize(len);
411
+ if (len > 0) {
412
+ tmp.resize(len);
413
+ fin.read(tmp.data(), len);
414
+ word.assign(tmp.data(), len);
415
+ } else {
416
+ word.clear();
417
+ }
418
+
419
+ float score;
420
+ fin.read((char *) &score, sizeof(score));
421
+
422
+ vocab.token_to_id[word] = i;
423
+
424
+ auto &tok_score = vocab.id_to_token[i];
425
+ tok_score.tok = word;
426
+ tok_score.score = score;
427
+ }
428
+ }
429
+
430
+ if (vocab_only) {
431
+ return true;
432
+ }
433
+
434
+ // for the big tensors, we have the option to store the data in 16-bit floats or quantized
435
+ // in order to save memory and also to speed up the computation
436
+ // wtype is for per-layer weights, while vtype is for other weights
437
+ ggml_type wtype, vtype;
438
+ switch (model.hparams.f16) {
439
+ case 0: wtype = vtype = GGML_TYPE_F32; break;
440
+ case 1: wtype = vtype = GGML_TYPE_F16; break;
441
+ case 2: wtype = vtype = GGML_TYPE_Q4_0; break;
442
+ case 3: wtype = vtype = GGML_TYPE_Q4_1; break;
443
+ case 4: wtype = GGML_TYPE_Q4_1; vtype = GGML_TYPE_F16; break;
444
+ default:
445
+ {
446
+ fprintf(stderr, "%s: invalid model file '%s' (bad f16 value %d)\n",
447
+ __func__, fname.c_str(), model.hparams.f16);
448
+ return false;
449
+ }
450
+ }
451
+
452
+ auto & ctx = model.ctx;
453
+
454
+ size_t ctx_size = 0;
455
+
456
+ {
457
+ const auto & hparams = model.hparams;
458
+
459
+ const int n_embd = hparams.n_embd;
460
+ const int n_layer = hparams.n_layer;
461
+ const int n_ctx = hparams.n_ctx;
462
+ const int n_vocab = hparams.n_vocab;
463
+
464
+ ctx_size += n_embd*n_vocab*ggml_type_sizef(vtype); // tok_embeddings
465
+
466
+ ctx_size += n_embd*ggml_type_sizef(GGML_TYPE_F32); // norm
467
+
468
+ ctx_size += n_embd*n_vocab*ggml_type_sizef(vtype); // output
469
+
470
+ ctx_size += n_layer*(n_embd*ggml_type_sizef(GGML_TYPE_F32)); // attention_norm
471
+
472
+ ctx_size += n_layer*(n_embd*n_embd*ggml_type_sizef(wtype)); // wq
473
+ ctx_size += n_layer*(n_embd*n_embd*ggml_type_sizef(wtype)); // wk
474
+ ctx_size += n_layer*(n_embd*n_embd*ggml_type_sizef(wtype)); // wv
475
+ ctx_size += n_layer*(n_embd*n_embd*ggml_type_sizef(wtype)); // wo
476
+
477
+ ctx_size += n_layer*(n_embd*ggml_type_sizef(GGML_TYPE_F32)); // ffn_norm
478
+
479
+ ctx_size += n_layer*(n_ff*n_embd*ggml_type_sizef(wtype)); // w1
480
+ ctx_size += n_layer*(n_ff*n_embd*ggml_type_sizef(wtype)); // w2
481
+ ctx_size += n_layer*(n_ff*n_embd*ggml_type_sizef(wtype)); // w3
482
+
483
+ ctx_size += n_ctx*n_layer*n_embd*ggml_type_sizef(memory_type); // memory_k
484
+ ctx_size += n_ctx*n_layer*n_embd*ggml_type_sizef(memory_type); // memory_v
485
+
486
+ ctx_size += (5 + 10*n_layer)*256; // object overhead
487
+
488
+ fprintf(stderr, "%s: ggml ctx size = %6.2f MB\n", __func__, ctx_size/(1024.0*1024.0));
489
+ }
490
+
491
+ // print memory requirements
492
+ {
493
+ const size_t scale = memory_type == GGML_TYPE_F32 ? 2 : 1;
494
+
495
+ // this is the total memory required to run the inference
496
+ const size_t mem_required =
497
+ ctx_size +
498
+ MEM_REQ_SCRATCH0.at(model.type) +
499
+ MEM_REQ_SCRATCH1.at(model.type) +
500
+ MEM_REQ_EVAL.at (model.type);
501
+
502
+ // this is the memory required by one llama_state
503
+ const size_t mem_required_state =
504
+ scale*MEM_REQ_KV_SELF.at(model.type);
505
+
506
+ fprintf(stderr, "%s: mem required = %7.2f MB (+ %7.2f MB per state)\n", __func__,
507
+ mem_required / 1024.0 / 1024.0, mem_required_state / 1024.0 / 1024.0);
508
+ }
509
+
510
+ // create the ggml context
511
+ {
512
+ lctx.model.buf.resize(ctx_size);
513
+
514
+ struct ggml_init_params params = {
515
+ /*.mem_size =*/ lctx.model.buf.size(),
516
+ /*.mem_buffer =*/ lctx.model.buf.data(),
517
+ };
518
+
519
+ model.ctx = ggml_init(params);
520
+ if (!model.ctx) {
521
+ fprintf(stderr, "%s: ggml_init() failed\n", __func__);
522
+ return false;
523
+ }
524
+ }
525
+
526
+ // prepare memory for the weights
527
+ {
528
+ const auto & hparams = model.hparams;
529
+
530
+ const int n_embd = hparams.n_embd;
531
+ const int n_layer = hparams.n_layer;
532
+ const int n_vocab = hparams.n_vocab;
533
+
534
+ model.layers.resize(n_layer);
535
+
536
+ model.tok_embeddings = ggml_new_tensor_2d(ctx, vtype, n_embd, n_vocab);
537
+
538
+ model.norm = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
539
+ model.output = ggml_new_tensor_2d(ctx, vtype, n_embd, n_vocab);
540
+
541
+ // map by name
542
+ model.tensors["tok_embeddings.weight"] = model.tok_embeddings;
543
+
544
+ model.tensors["norm.weight"] = model.norm;
545
+ model.tensors["output.weight"] = model.output;
546
+
547
+ for (int i = 0; i < n_layer; ++i) {
548
+ auto & layer = model.layers[i];
549
+
550
+ layer.attention_norm = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
551
+
552
+ layer.wq = ggml_new_tensor_2d(ctx, wtype, n_embd, n_embd);
553
+ layer.wk = ggml_new_tensor_2d(ctx, wtype, n_embd, n_embd);
554
+ layer.wv = ggml_new_tensor_2d(ctx, wtype, n_embd, n_embd);
555
+ layer.wo = ggml_new_tensor_2d(ctx, wtype, n_embd, n_embd);
556
+
557
+ layer.ffn_norm = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
558
+
559
+ layer.w1 = ggml_new_tensor_2d(ctx, wtype, n_embd, n_ff);
560
+ layer.w2 = ggml_new_tensor_2d(ctx, wtype, n_ff, n_embd);
561
+ layer.w3 = ggml_new_tensor_2d(ctx, wtype, n_embd, n_ff);
562
+
563
+ // map by name
564
+ model.tensors["layers." + std::to_string(i) + ".attention_norm.weight"] = layer.attention_norm;
565
+
566
+ model.tensors["layers." + std::to_string(i) + ".attention.wq.weight"] = layer.wq;
567
+ model.tensors["layers." + std::to_string(i) + ".attention.wk.weight"] = layer.wk;
568
+ model.tensors["layers." + std::to_string(i) + ".attention.wv.weight"] = layer.wv;
569
+ model.tensors["layers." + std::to_string(i) + ".attention.wo.weight"] = layer.wo;
570
+
571
+ model.tensors["layers." + std::to_string(i) + ".ffn_norm.weight"] = layer.ffn_norm;
572
+
573
+ model.tensors["layers." + std::to_string(i) + ".feed_forward.w1.weight"] = layer.w1;
574
+ model.tensors["layers." + std::to_string(i) + ".feed_forward.w2.weight"] = layer.w2;
575
+ model.tensors["layers." + std::to_string(i) + ".feed_forward.w3.weight"] = layer.w3;
576
+ }
577
+ }
578
+
579
+ const size_t file_offset = fin.tellg();
580
+
581
+ fin.close();
582
+
583
+ std::vector<uint8_t> tmp;
584
+
585
+ if (progress_callback) {
586
+ progress_callback(0.0, progress_callback_user_data);
587
+ }
588
+
589
+ for (int i = 0; i < n_parts; ++i) {
590
+ const int part_id = i;
591
+ //const int part_id = n_parts - i - 1;
592
+
593
+ std::string fname_part = fname;
594
+ if (i > 0) {
595
+ fname_part += "." + std::to_string(i);
596
+ }
597
+
598
+ fprintf(stderr, "%s: loading model part %d/%d from '%s'\n", __func__, i+1, n_parts, fname_part.c_str());
599
+
600
+ fin = std::ifstream(fname_part, std::ios::binary);
601
+ fin.rdbuf()->pubsetbuf(f_buf.data(), f_buf.size());
602
+
603
+ fin.seekg(0, fin.end);
604
+ const size_t file_size = fin.tellg();
605
+
606
+ fin.seekg(file_offset);
607
+
608
+ // load weights
609
+ {
610
+ size_t total_size = 0;
611
+
612
+ model.n_loaded = 0;
613
+
614
+ fprintf(stderr, "%s: ", __func__);
615
+
616
+ while (true) {
617
+ int32_t n_dims;
618
+ int32_t length;
619
+ int32_t ftype;
620
+
621
+ fin.read(reinterpret_cast<char *>(&n_dims), sizeof(n_dims));
622
+ fin.read(reinterpret_cast<char *>(&length), sizeof(length));
623
+ fin.read(reinterpret_cast<char *>(&ftype), sizeof(ftype));
624
+
625
+ if (fin.eof()) {
626
+ break;
627
+ }
628
+
629
+ int32_t nelements = 1;
630
+ int32_t ne[2] = { 1, 1 };
631
+ for (int i = 0; i < n_dims; ++i) {
632
+ fin.read(reinterpret_cast<char *>(&ne[i]), sizeof(ne[i]));
633
+ nelements *= ne[i];
634
+ }
635
+
636
+ std::string name(length, 0);
637
+ fin.read(&name[0], length);
638
+
639
+ if (model.tensors.find(name.data()) == model.tensors.end()) {
640
+ fprintf(stderr, "%s: unknown tensor '%s' in model file\n", __func__, name.data());
641
+ return false;
642
+ }
643
+
644
+ // split_type = 0: split by columns
645
+ // split_type = 1: split by rows
646
+ int split_type = 0;
647
+
648
+ // split_type = 0:
649
+ // regex:
650
+ // - tok_embeddings.*
651
+ // - layers.*.attention.wo.weight
652
+ // - layers.*.feed_forward.w2.weight
653
+
654
+ // split_type = 1:
655
+ // regex:
656
+ // - output.*
657
+ // - layers.*.attention.wq.weight
658
+ // - layers.*.attention.wk.weight
659
+ // - layers.*.attention.wv.weight
660
+ // - layers.*.feed_forward.w1.weight
661
+ // - layers.*.feed_forward.w3.weight
662
+ if (name.find("tok_embeddings") != std::string::npos) {
663
+ split_type = 0;
664
+ } else if (name.find("layers") != std::string::npos) {
665
+ if (name.find("attention.wo.weight") != std::string::npos) {
666
+ split_type = 0;
667
+ } else if (name.find("feed_forward.w2.weight") != std::string::npos) {
668
+ split_type = 0;
669
+ } else {
670
+ split_type = 1;
671
+ }
672
+ } else if (name.find("output") != std::string::npos) {
673
+ split_type = 1;
674
+ }
675
+
676
+ auto tensor = model.tensors[name.data()];
677
+
678
+ if (n_dims == 1) {
679
+ if (ggml_nelements(tensor) != nelements) {
680
+ fprintf(stderr, "%s: tensor '%s' has wrong size in model file\n", __func__, name.data());
681
+ return false;
682
+ }
683
+ } else {
684
+ if (ggml_nelements(tensor)/n_parts != nelements) {
685
+ fprintf(stderr, "%s: tensor '%s' has wrong size in model file\n", __func__, name.data());
686
+ return false;
687
+ }
688
+ }
689
+
690
+ if (n_dims == 1) {
691
+ if (tensor->ne[0] != ne[0] || tensor->ne[1] != ne[1]) {
692
+ fprintf(stderr, "%s: tensor '%s' has wrong shape in model file: got [%d, %d], expected [%d, %d]\n",
693
+ __func__, name.data(), tensor->ne[0], tensor->ne[1], ne[0], ne[1]);
694
+ return false;
695
+ }
696
+ } else {
697
+ if (split_type == 0) {
698
+ if (tensor->ne[0]/n_parts != ne[0] || tensor->ne[1] != ne[1]) {
699
+ fprintf(stderr, "%s: tensor '%s' has wrong shape in model file: got [%d, %d], expected [%d, %d]\n",
700
+ __func__, name.data(), tensor->ne[0]/n_parts, tensor->ne[1], ne[0], ne[1]);
701
+ return false;
702
+ }
703
+ } else {
704
+ if (tensor->ne[0] != ne[0] || tensor->ne[1]/n_parts != ne[1]) {
705
+ fprintf(stderr, "%s: tensor '%s' has wrong shape in model file: got [%d, %d], expected [%d, %d]\n",
706
+ __func__, name.data(), tensor->ne[0], tensor->ne[1]/n_parts, ne[0], ne[1]);
707
+ return false;
708
+ }
709
+ }
710
+ }
711
+
712
+ if (0) {
713
+ static const char * ftype_str[] = { "f32", "f16", "q4_0", "q4_1", };
714
+ fprintf(stderr, "%24s - [%5d, %5d], type = %6s, split = %d\n", name.data(), ne[0], ne[1], ftype_str[ftype], split_type);
715
+ }
716
+
717
+ size_t bpe = 0;
718
+
719
+ switch (ftype) {
720
+ case 0: bpe = ggml_type_size(GGML_TYPE_F32); break;
721
+ case 1: bpe = ggml_type_size(GGML_TYPE_F16); break;
722
+ case 2: bpe = ggml_type_size(GGML_TYPE_Q4_0); assert(ne[0] % 64 == 0); break;
723
+ case 3: bpe = ggml_type_size(GGML_TYPE_Q4_1); assert(ne[0] % 64 == 0); break;
724
+ default:
725
+ {
726
+ fprintf(stderr, "%s: unknown ftype %d in model file\n", __func__, ftype);
727
+ return false;
728
+ }
729
+ };
730
+
731
+ if (n_dims == 1 || n_parts == 1) {
732
+ if ((nelements*bpe)/ggml_blck_size(tensor->type) != ggml_nbytes(tensor)) {
733
+ fprintf(stderr, "%s: tensor '%s' has wrong size in model file: got %zu, expected %zu\n",
734
+ __func__, name.data(), ggml_nbytes(tensor), nelements*bpe);
735
+ return false;
736
+ }
737
+
738
+ if (part_id == 0) {
739
+ fin.read(reinterpret_cast<char *>(tensor->data), ggml_nbytes(tensor));
740
+ } else {
741
+ fin.seekg(ggml_nbytes(tensor), std::ios::cur);
742
+ }
743
+
744
+ total_size += ggml_nbytes(tensor);
745
+ } else {
746
+ if ((nelements*bpe)/ggml_blck_size(tensor->type) != ggml_nbytes(tensor)/n_parts) {
747
+ fprintf(stderr, "%s: tensor '%s' has wrong size in model file: got %zu, expected %zu\n",
748
+ __func__, name.data(), ggml_nbytes(tensor)/n_parts, nelements*bpe);
749
+ return false;
750
+ }
751
+
752
+ if (split_type == 0) {
753
+ const int np0 = ne[0];
754
+
755
+ const size_t row_size = (tensor->ne[0]/ggml_blck_size(tensor->type))*ggml_type_size(tensor->type);
756
+ assert(row_size == tensor->nb[1]);
757
+
758
+ for (int i1 = 0; i1 < ne[1]; ++i1) {
759
+ const size_t offset_row = i1*row_size;
760
+ const size_t offset = offset_row + ((part_id*np0)/ggml_blck_size(tensor->type))*ggml_type_size(tensor->type);
761
+ fin.read(reinterpret_cast<char *>(tensor->data) + offset, row_size/n_parts);
762
+ }
763
+ } else {
764
+ const int np1 = ne[1];
765
+
766
+ const size_t row_size = (tensor->ne[0]/ggml_blck_size(tensor->type))*ggml_type_size(tensor->type);
767
+
768
+ for (int i1 = 0; i1 < ne[1]; ++i1) {
769
+ const size_t offset_row = (i1 + part_id*np1)*row_size;
770
+ fin.read(reinterpret_cast<char *>(tensor->data) + offset_row, row_size);
771
+ }
772
+ }
773
+
774
+ total_size += ggml_nbytes(tensor)/n_parts;
775
+ }
776
+
777
+ //fprintf(stderr, "%42s - [%5d, %5d], type = %6s, %6.2f MB\n", name.data(), ne[0], ne[1], ftype == 0 ? "float" : "f16", ggml_nbytes(tensor)/1024.0/1024.0);
778
+ model.n_loaded++;
779
+
780
+ // progress
781
+ if (progress_callback) {
782
+ double current_file_progress = double(size_t(fin.tellg()) - file_offset) / double(file_size - file_offset);
783
+ double current_progress = (double(i) + current_file_progress) / double(n_parts);
784
+ progress_callback(current_progress, progress_callback_user_data);
785
+ }
786
+ if (model.n_loaded % 8 == 0) {
787
+ fprintf(stderr, ".");
788
+ fflush(stderr);
789
+ }
790
+ }
791
+
792
+ fprintf(stderr, " done\n");
793
+
794
+ fprintf(stderr, "%s: model size = %8.2f MB / num tensors = %d\n", __func__, total_size/1024.0/1024.0, model.n_loaded);
795
+ if (model.n_loaded == 0) {
796
+ fprintf(stderr, "%s: WARN no tensors loaded from model file - assuming empty model for testing\n", __func__);
797
+ } else if (model.n_loaded != (int) model.tensors.size()) {
798
+ fprintf(stderr, "%s: ERROR not all tensors loaded from model file - expected %zu, got %d\n", __func__, model.tensors.size(), model.n_loaded);
799
+ return false;
800
+ }
801
+ }
802
+
803
+ fin.close();
804
+ }
805
+
806
+ lctx.t_load_us = ggml_time_us() - t_start_us;
807
+
808
+ if (progress_callback) {
809
+ progress_callback(1.0, progress_callback_user_data);
810
+ }
811
+
812
+ return true;
813
+ }
814
+
815
+ // evaluate the transformer
816
+ //
817
+ // - lctx: llama context
818
+ // - tokens: new batch of tokens to process
819
+ // - n_past: the context size so far
820
+ // - n_threads: number of threads to use
821
+ //
822
+ static bool llama_eval_internal(
823
+ llama_context & lctx,
824
+ const llama_token * tokens,
825
+ const int n_tokens,
826
+ const int n_past,
827
+ const int n_threads) {
828
+ const int64_t t_start_us = ggml_time_us();
829
+
830
+ const int N = n_tokens;
831
+
832
+ const auto & model = lctx.model;
833
+ const auto & hparams = model.hparams;
834
+
835
+ auto & kv_self = model.kv_self;
836
+
837
+ LLAMA_ASSERT(!!kv_self.ctx);
838
+
839
+ const int n_embd = hparams.n_embd;
840
+ const int n_layer = hparams.n_layer;
841
+ const int n_ctx = hparams.n_ctx;
842
+ const int n_head = hparams.n_head;
843
+ const int n_vocab = hparams.n_vocab;
844
+ const int n_rot = hparams.n_embd/hparams.n_head;
845
+
846
+ auto & mem_per_token = lctx.mem_per_token;
847
+ auto & buf_compute = lctx.buf_compute;
848
+
849
+ struct ggml_init_params params = {
850
+ /*.mem_size =*/ buf_compute.size(),
851
+ /*.mem_buffer =*/ buf_compute.data(),
852
+ };
853
+
854
+ struct ggml_context * ctx0 = ggml_init(params);
855
+
856
+ // for big prompts, if BLAS is enabled, it is better to use only one thread
857
+ // otherwise, the threads are spin-lock waiting for the BLAS calls and are degrading the performance
858
+ ggml_cgraph gf = {};
859
+ gf.n_threads = N > 255 && ggml_cpu_has_blas() ? 1 : n_threads;
860
+
861
+ struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
862
+ memcpy(embd->data, tokens, N*ggml_element_size(embd));
863
+
864
+ struct ggml_tensor * inpL = ggml_get_rows(ctx0, model.tok_embeddings, embd);
865
+
866
+ for (int il = 0; il < n_layer; ++il) {
867
+ struct ggml_tensor * inpSA = inpL;
868
+
869
+ struct ggml_tensor * cur;
870
+
871
+ lctx.use_buf(ctx0, 0);
872
+
873
+ // norm
874
+ {
875
+ cur = ggml_rms_norm(ctx0, inpL);
876
+
877
+ // cur = attention_norm*cur
878
+ cur = ggml_mul(ctx0,
879
+ ggml_repeat(ctx0, model.layers[il].attention_norm, cur),
880
+ cur);
881
+ }
882
+
883
+ // self-attention
884
+ {
885
+ struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur);
886
+ struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur);
887
+ struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur);
888
+
889
+ // store key and value to memory
890
+ if (N >= 1) {
891
+ struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k, N*n_embd, (ggml_element_size(kv_self.k)*n_embd)*(il*n_ctx + n_past));
892
+ struct ggml_tensor * v = ggml_view_1d(ctx0, kv_self.v, N*n_embd, (ggml_element_size(kv_self.v)*n_embd)*(il*n_ctx + n_past));
893
+
894
+ ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Kcur, k));
895
+ ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Vcur, v));
896
+ }
897
+
898
+ // Q = Qcur.contiguous().view(n_embd/n_head, n_head, N).permute(0, 2, 1, 3)
899
+ struct ggml_tensor * Q =
900
+ ggml_permute(ctx0,
901
+ ggml_rope(ctx0,
902
+ ggml_cpy(ctx0,
903
+ Qcur,
904
+ ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_embd/n_head, n_head, N)),
905
+ n_past, n_rot, 0),
906
+ 0, 2, 1, 3);
907
+
908
+ // K = Kmem.view(n_embd/n_head, n_head, n_past + N).permute(0, 2, 1, 3)
909
+ struct ggml_tensor * K =
910
+ ggml_permute(ctx0,
911
+ ggml_rope(ctx0,
912
+ ggml_reshape_3d(ctx0,
913
+ ggml_view_1d(ctx0, kv_self.k, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(kv_self.k)*n_embd),
914
+ n_embd/n_head, n_head, n_past + N),
915
+ n_past, n_rot, 1),
916
+ 0, 2, 1, 3);
917
+
918
+ // K * Q
919
+ struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
920
+
921
+ // KQ_scaled = KQ / sqrt(n_embd/n_head)
922
+ struct ggml_tensor * KQ_scaled =
923
+ ggml_scale(ctx0,
924
+ KQ,
925
+ ggml_new_f32(ctx0, 1.0f/sqrt(float(n_embd)/n_head)));
926
+
927
+ // KQ_masked = mask_past(KQ_scaled)
928
+ struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctx0, KQ_scaled, n_past);
929
+
930
+ // KQ = soft_max(KQ_masked)
931
+ struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ_masked);
932
+
933
+ // V_trans = Vmem.view(n_embd/n_head, n_head, n_past + N).permute(1, 2, 0, 3).contiguous()
934
+ struct ggml_tensor * V_trans =
935
+ ggml_cpy(ctx0,
936
+ ggml_permute(ctx0,
937
+ ggml_reshape_3d(ctx0,
938
+ ggml_view_1d(ctx0, kv_self.v, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(kv_self.v)*n_embd),
939
+ n_embd/n_head, n_head, n_past + N),
940
+ 1, 2, 0, 3),
941
+ ggml_new_tensor_3d(ctx0, kv_self.v->type, n_past + N, n_embd/n_head, n_head));
942
+
943
+ // KQV = transpose(V) * KQ_soft_max
944
+ struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V_trans, KQ_soft_max);
945
+
946
+ // KQV_merged = KQV.permute(0, 2, 1, 3)
947
+ struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
948
+
949
+ // cur = KQV_merged.contiguous().view(n_embd, N)
950
+ cur = ggml_cpy(ctx0,
951
+ KQV_merged,
952
+ ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N));
953
+
954
+ // projection (no bias)
955
+ cur = ggml_mul_mat(ctx0,
956
+ model.layers[il].wo,
957
+ cur);
958
+ }
959
+
960
+ lctx.use_buf(ctx0, 1);
961
+
962
+ struct ggml_tensor * inpFF = ggml_add(ctx0, cur, inpSA);
963
+
964
+ // feed-forward network
965
+ {
966
+ // norm
967
+ {
968
+ cur = ggml_rms_norm(ctx0, inpFF);
969
+
970
+ // cur = ffn_norm*cur
971
+ cur = ggml_mul(ctx0,
972
+ ggml_repeat(ctx0, model.layers[il].ffn_norm, cur),
973
+ cur);
974
+ }
975
+
976
+ struct ggml_tensor * tmp = ggml_mul_mat(ctx0,
977
+ model.layers[il].w3,
978
+ cur);
979
+
980
+ cur = ggml_mul_mat(ctx0,
981
+ model.layers[il].w1,
982
+ cur);
983
+
984
+ // SILU activation
985
+ cur = ggml_silu(ctx0, cur);
986
+
987
+ cur = ggml_mul(ctx0, cur, tmp);
988
+
989
+ cur = ggml_mul_mat(ctx0,
990
+ model.layers[il].w2,
991
+ cur);
992
+ }
993
+
994
+ cur = ggml_add(ctx0, cur, inpFF);
995
+
996
+ // input for next layer
997
+ inpL = cur;
998
+ }
999
+
1000
+ lctx.use_buf(ctx0, 0);
1001
+
1002
+ // used at the end to optionally extract the embeddings
1003
+ struct ggml_tensor * embeddings = NULL;
1004
+
1005
+ // norm
1006
+ {
1007
+
1008
+ inpL = ggml_rms_norm(ctx0, inpL);
1009
+
1010
+ // inpL = norm*inpL
1011
+ inpL = ggml_mul(ctx0,
1012
+ ggml_repeat(ctx0, model.norm, inpL),
1013
+ inpL);
1014
+
1015
+ embeddings = inpL;
1016
+ }
1017
+
1018
+ // lm_head
1019
+ inpL = ggml_mul_mat(ctx0, model.output, inpL);
1020
+
1021
+ lctx.use_buf(ctx0, -1);
1022
+
1023
+ // logits -> probs
1024
+ //inpL = ggml_soft_max(ctx0, inpL);
1025
+
1026
+ // run the computation
1027
+ ggml_build_forward_expand(&gf, inpL);
1028
+ ggml_graph_compute (ctx0, &gf);
1029
+
1030
+ //if (n_past%100 == 0) {
1031
+ // ggml_graph_print (&gf);
1032
+ // ggml_graph_dump_dot(&gf, NULL, "gpt-2.dot");
1033
+ //}
1034
+
1035
+ //embd_w.resize(n_vocab*N);
1036
+ //memcpy(embd_w.data(), ggml_get_data(inpL), sizeof(float)*n_vocab*N);
1037
+
1038
+ // extract logits
1039
+ {
1040
+ auto & logits_out = lctx.logits;
1041
+
1042
+ if (lctx.logits_all) {
1043
+ logits_out.resize(n_vocab * N);
1044
+ memcpy(logits_out.data(), (float *) ggml_get_data(inpL), sizeof(float)*n_vocab*N);
1045
+ } else {
1046
+ // return result for just the last token
1047
+ logits_out.resize(n_vocab);
1048
+ memcpy(logits_out.data(), (float *) ggml_get_data(inpL) + (n_vocab*(N-1)), sizeof(float)*n_vocab);
1049
+ }
1050
+ }
1051
+
1052
+ // extract embeddings
1053
+ if (lctx.embedding.size()) {
1054
+ auto & embedding_out = lctx.embedding;
1055
+
1056
+ embedding_out.resize(n_embd);
1057
+ memcpy(embedding_out.data(), (float *) ggml_get_data(embeddings) + (n_embd*(N - 1)), sizeof(float)*n_embd);
1058
+ }
1059
+
1060
+ if (mem_per_token == 0) {
1061
+ mem_per_token = ggml_used_mem(ctx0)/N;
1062
+ }
1063
+
1064
+ #if 0
1065
+ printf("\n%s: used_mem = %.3f MB, scratch -- %.3f MB %.3f MB\n", __func__,
1066
+ ggml_used_mem(ctx0)/1024.0/1024.0,
1067
+ lctx.get_buf_max_mem(0)/1024.0/1024.0,
1068
+ lctx.get_buf_max_mem(1)/1024.0/1024.0);
1069
+ #endif
1070
+
1071
+ ggml_free(ctx0);
1072
+
1073
+ // measure the performance only for the single-token evals
1074
+ if (N == 1) {
1075
+ lctx.t_eval_us += ggml_time_us() - t_start_us;
1076
+ lctx.n_eval++;
1077
+ }
1078
+ else if (N > 1) {
1079
+ lctx.t_p_eval_us += ggml_time_us() - t_start_us;
1080
+ lctx.n_p_eval += N;
1081
+ }
1082
+
1083
+ return true;
1084
+ }
1085
+
1086
+ //
1087
+ // tokenizer
1088
+ //
1089
+
1090
+ static size_t utf8_len(char src) {
1091
+ const size_t lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4 };
1092
+ uint8_t highbits = static_cast<uint8_t>(src) >> 4;
1093
+ return lookup[highbits];
1094
+ }
1095
+
1096
+ struct llama_sp_symbol {
1097
+ using index = int;
1098
+ index prev;
1099
+ index next;
1100
+ const char * text;
1101
+ size_t n;
1102
+ };
1103
+
1104
+ struct llama_sp_bigram {
1105
+ struct comparator {
1106
+ bool operator()(llama_sp_bigram & l, llama_sp_bigram & r) {
1107
+ return (l.score < r.score) || (l.score == r.score && l.left > r.left);
1108
+ }
1109
+ };
1110
+ using queue_storage = std::vector<llama_sp_bigram>;
1111
+ using queue = std::priority_queue<llama_sp_bigram, queue_storage, comparator>;
1112
+ llama_sp_symbol::index left;
1113
+ llama_sp_symbol::index right;
1114
+ float score;
1115
+ size_t size;
1116
+ };
1117
+
1118
+ // original implementation:
1119
+ // https://github.com/ggerganov/llama.cpp/commit/074bea2eb1f1349a0118239c4152914aecaa1be4
1120
+ struct llama_tokenizer {
1121
+ llama_tokenizer(const llama_vocab & vocab): vocab_(vocab) {}
1122
+
1123
+ void tokenize(const std::string & text, std::vector<llama_vocab::id> & output) {
1124
+ // split string into utf8 chars
1125
+ int index = 0;
1126
+ size_t offs = 0;
1127
+ while (offs < text.size()) {
1128
+ llama_sp_symbol sym;
1129
+ size_t char_len = std::min(text.size() - offs, utf8_len(text[offs]));
1130
+ sym.text = text.c_str() + offs;
1131
+ sym.n = char_len;
1132
+ offs += char_len;
1133
+ sym.prev = index - 1;
1134
+ sym.next = offs == text.size() ? -1 : index + 1;
1135
+ index++;
1136
+ symbols_.emplace_back(std::move(sym));
1137
+ }
1138
+
1139
+ // seed the work queue with all possible 2-character tokens.
1140
+ for (size_t i = 1; i < symbols_.size(); ++i) {
1141
+ try_add_bigram(i - 1, i);
1142
+ }
1143
+
1144
+ // keep substituting the highest frequency pairs for as long as we can.
1145
+ while (!work_queue_.empty()) {
1146
+ auto bigram = work_queue_.top();
1147
+ work_queue_.pop();
1148
+
1149
+ auto & left_sym = symbols_[bigram.left];
1150
+ auto & right_sym = symbols_[bigram.right];
1151
+
1152
+ // if one of the symbols already got merged, skip it.
1153
+ if (left_sym.n == 0 || right_sym.n == 0 ||
1154
+ left_sym.n + right_sym.n != bigram.size) {
1155
+ continue;
1156
+ }
1157
+
1158
+ // merge the right sym into the left one
1159
+ left_sym.n += right_sym.n;
1160
+ right_sym.n = 0;
1161
+
1162
+ //printf("left = '%*s' size = %zu\n", (int) left_sym.n, left_sym.text, bigram.size);
1163
+
1164
+ // remove the right sym from the chain
1165
+ left_sym.next = right_sym.next;
1166
+ if (right_sym.next >= 0) {
1167
+ symbols_[right_sym.next].prev = bigram.left;
1168
+ }
1169
+
1170
+ // find more substitutions
1171
+ try_add_bigram(left_sym.prev, bigram.left);
1172
+ try_add_bigram(bigram.left, left_sym.next);
1173
+ }
1174
+
1175
+ for (int i = 0; i != -1; i = symbols_[i].next) {
1176
+ auto & symbol = symbols_[i];
1177
+ auto token = vocab_.token_to_id.find(std::string(symbol.text, symbol.n));
1178
+
1179
+ if (token == vocab_.token_to_id.end()) {
1180
+ // output any symbols that did not form tokens as bytes.
1181
+ for (int j = 0; j < (int) symbol.n; ++j) {
1182
+ llama_vocab::id token_id = static_cast<uint8_t>(symbol.text[j]) + 3;
1183
+ output.push_back(token_id);
1184
+ }
1185
+ } else {
1186
+ output.push_back((*token).second);
1187
+ }
1188
+ }
1189
+ }
1190
+
1191
+ private:
1192
+ void try_add_bigram(int left, int right) {
1193
+ if (left == -1 || right == -1) {
1194
+ return;
1195
+ }
1196
+
1197
+ const std::string text = std::string(symbols_[left].text, symbols_[left].n + symbols_[right].n);
1198
+ auto token = vocab_.token_to_id.find(text);
1199
+
1200
+ if (token == vocab_.token_to_id.end()) {
1201
+ return;
1202
+ }
1203
+
1204
+ if (static_cast<size_t>((*token).second) >= vocab_.id_to_token.size()) {
1205
+ return;
1206
+ }
1207
+
1208
+ const auto &tok_score = vocab_.id_to_token[(*token).second];
1209
+
1210
+ llama_sp_bigram bigram;
1211
+ bigram.left = left;
1212
+ bigram.right = right;
1213
+ bigram.score = tok_score.score;
1214
+ bigram.size = text.size();
1215
+ work_queue_.push(bigram);
1216
+ }
1217
+
1218
+ const llama_vocab & vocab_;
1219
+ std::vector<llama_sp_symbol> symbols_;
1220
+ llama_sp_bigram::queue work_queue_;
1221
+ };
1222
+
1223
+ static std::vector<llama_vocab::id> llama_tokenize(const llama_vocab & vocab, const std::string & text, bool bos) {
1224
+ llama_tokenizer tokenizer(vocab);
1225
+ std::vector<llama_vocab::id> output;
1226
+
1227
+ if (text.size() == 0) {
1228
+ return output;
1229
+ }
1230
+
1231
+ if (bos) {
1232
+ output.push_back(1);
1233
+ }
1234
+
1235
+ tokenizer.tokenize(text, output);
1236
+ return output;
1237
+ }
1238
+
1239
+ //
1240
+ // sampling
1241
+ //
1242
+
1243
+ static void sample_top_k(std::vector<std::pair<double, llama_vocab::id>> & logits_id, int top_k) {
1244
+ // find the top k tokens
1245
+ std::partial_sort(
1246
+ logits_id.begin(),
1247
+ logits_id.begin() + top_k, logits_id.end(),
1248
+ [](const std::pair<double, llama_vocab::id> & a, const std::pair<double, llama_vocab::id> & b) {
1249
+ return a.first > b.first;
1250
+ });
1251
+
1252
+ logits_id.resize(top_k);
1253
+ }
1254
+
1255
+ static llama_vocab::id llama_sample_top_p_top_k(
1256
+ llama_context & lctx,
1257
+ const std::vector<llama_vocab::id> & last_n_tokens,
1258
+ int top_k,
1259
+ double top_p,
1260
+ double temp,
1261
+ double repeat_penalty) {
1262
+ auto & rng = lctx.rng;
1263
+
1264
+ const int n_logits = lctx.model.hparams.n_vocab;
1265
+
1266
+ const auto & logits = lctx.logits;
1267
+ const auto * plogits = logits.data() + logits.size() - n_logits;
1268
+
1269
+ std::vector<std::pair<double, llama_vocab::id>> logits_id;
1270
+ logits_id.reserve(n_logits);
1271
+
1272
+ {
1273
+ const double scale = 1.0/temp;
1274
+ for (int i = 0; i < n_logits; ++i) {
1275
+ // repetition penalty from ctrl paper (https://arxiv.org/abs/1909.05858)
1276
+ // credit https://github.com/facebookresearch/llama/compare/main...shawwn:llama:main
1277
+ if (std::find(last_n_tokens.begin(), last_n_tokens.end(), i) != last_n_tokens.end()) {
1278
+ // if score < 0 then repetition penalty has to be multiplied to reduce the previous token probability
1279
+ if (plogits[i] < 0.0) {
1280
+ logits_id.push_back(std::make_pair(plogits[i]*scale*repeat_penalty, i));
1281
+ } else {
1282
+ logits_id.push_back(std::make_pair(plogits[i]*scale/repeat_penalty, i));
1283
+ }
1284
+ } else {
1285
+ logits_id.push_back(std::make_pair(plogits[i]*scale, i));
1286
+ }
1287
+ }
1288
+ }
1289
+
1290
+ sample_top_k(logits_id, top_k);
1291
+
1292
+ double maxl = -std::numeric_limits<double>::infinity();
1293
+ for (const auto & kv : logits_id) {
1294
+ maxl = std::max(maxl, kv.first);
1295
+ }
1296
+
1297
+ // compute probs for the top k tokens
1298
+ std::vector<double> probs;
1299
+ probs.reserve(logits_id.size());
1300
+
1301
+ double sum = 0.0;
1302
+ for (const auto & kv : logits_id) {
1303
+ double p = exp(kv.first - maxl);
1304
+ probs.push_back(p);
1305
+ sum += p;
1306
+ }
1307
+
1308
+ // normalize the probs
1309
+ for (auto & p : probs) {
1310
+ p /= sum;
1311
+ }
1312
+
1313
+ if (top_p < 1.0f) {
1314
+ double cumsum = 0.0f;
1315
+ for (int i = 0; i < (int) probs.size(); i++) {
1316
+ cumsum += probs[i];
1317
+ if (cumsum >= top_p) {
1318
+ probs.resize(i + 1);
1319
+ logits_id.resize(i + 1);
1320
+ break;
1321
+ }
1322
+ }
1323
+
1324
+ cumsum = 1.0/cumsum;
1325
+ for (int i = 0; i < (int) probs.size(); i++) {
1326
+ probs[i] *= cumsum;
1327
+ }
1328
+ }
1329
+
1330
+ //printf("\n");
1331
+ //for (int i = 0; i < (int) 10; i++) {
1332
+ // printf("%d: '%s' %f\n", i, vocab.id_to_token.at(logits_id[i].second).c_str(), probs[i]);
1333
+ //}
1334
+ //printf("\n\n");
1335
+ //exit(0);
1336
+
1337
+ std::discrete_distribution<> dist(probs.begin(), probs.end());
1338
+ int idx = dist(rng);
1339
+
1340
+ return logits_id[idx].second;
1341
+ }
1342
+
1343
+ //
1344
+ // quantization
1345
+ //
1346
+
1347
+ // TODO: reuse code from the llama_model_load() somehow
1348
+ bool llama_model_quantize_internal(const std::string & fname_inp, const std::string & fname_out, int itype, int qk) {
1349
+ ggml_type type = GGML_TYPE_Q4_1;
1350
+
1351
+ switch (itype) {
1352
+ case 2: type = GGML_TYPE_Q4_0; break;
1353
+ case 3: type = GGML_TYPE_Q4_1; break;
1354
+ default: fprintf(stderr, "%s: invalid quantization type %d\n", __func__, itype); return 1;
1355
+ };
1356
+
1357
+ if (type != GGML_TYPE_Q4_0 && type != GGML_TYPE_Q4_1) {
1358
+ fprintf(stderr, "%s: invalid quantization type %d\n", __func__, type);
1359
+ return false;
1360
+ }
1361
+
1362
+ llama_vocab vocab;
1363
+
1364
+ printf("%s: loading model from '%s'\n", __func__, fname_inp.c_str());
1365
+
1366
+ auto finp = std::ifstream(fname_inp, std::ios::binary);
1367
+ if (!finp) {
1368
+ fprintf(stderr, "%s: failed to open '%s' for reading\n", __func__, fname_inp.c_str());
1369
+ return false;
1370
+ }
1371
+
1372
+ auto fout = std::ofstream(fname_out, std::ios::binary);
1373
+ if (!fout) {
1374
+ fprintf(stderr, "%s: failed to open '%s' for writing\n", __func__, fname_out.c_str());
1375
+ return false;
1376
+ }
1377
+
1378
+ // verify magic
1379
+ {
1380
+ uint32_t magic;
1381
+ finp.read((char *) &magic, sizeof(magic));
1382
+ if (magic == LLAMA_FILE_MAGIC_UNVERSIONED) {
1383
+ fprintf(stderr, "%s: invalid model file '%s' (too old, regenerate your model files!)\n",
1384
+ __func__, fname_inp.c_str());
1385
+ return false;
1386
+ }
1387
+ if (magic != LLAMA_FILE_MAGIC) {
1388
+ fprintf(stderr, "%s: invalid model file '%s' (bad magic)\n", __func__, fname_inp.c_str());
1389
+ return false;
1390
+ }
1391
+
1392
+ fout.write((char *) &magic, sizeof(magic));
1393
+
1394
+ uint32_t format_version;
1395
+ finp.read((char *) &format_version, sizeof(format_version));
1396
+
1397
+ if (format_version != LLAMA_FILE_VERSION) {
1398
+ fprintf(stderr, "%s: invalid model file '%s' (unsupported format version %" PRIu32 ", expected %d)\n",
1399
+ __func__, fname_inp.c_str(), format_version, LLAMA_FILE_VERSION);
1400
+ return false;
1401
+ }
1402
+
1403
+ fout.write((char *) &format_version, sizeof(format_version));
1404
+ }
1405
+
1406
+ llama_hparams hparams;
1407
+
1408
+ // load hparams
1409
+ {
1410
+ finp.read((char *) &hparams.n_vocab, sizeof(hparams.n_vocab));
1411
+ //finp.read((char *) &hparams.n_ctx, sizeof(hparams.n_ctx));
1412
+ finp.read((char *) &hparams.n_embd, sizeof(hparams.n_embd));
1413
+ finp.read((char *) &hparams.n_mult, sizeof(hparams.n_mult));
1414
+ finp.read((char *) &hparams.n_head, sizeof(hparams.n_head));
1415
+ finp.read((char *) &hparams.n_layer, sizeof(hparams.n_layer));
1416
+ finp.read((char *) &hparams.n_rot, sizeof(hparams.n_rot));
1417
+ finp.read((char *) &hparams.f16, sizeof(hparams.f16));
1418
+
1419
+ printf("%s: n_vocab = %d\n", __func__, hparams.n_vocab);
1420
+ printf("%s: n_ctx = %d\n", __func__, hparams.n_ctx);
1421
+ printf("%s: n_embd = %d\n", __func__, hparams.n_embd);
1422
+ printf("%s: n_mult = %d\n", __func__, hparams.n_mult);
1423
+ printf("%s: n_head = %d\n", __func__, hparams.n_head);
1424
+ printf("%s: n_layer = %d\n", __func__, hparams.n_layer);
1425
+ printf("%s: f16 = %d\n", __func__, hparams.f16);
1426
+
1427
+ fout.write((char *) &hparams.n_vocab, sizeof(hparams.n_vocab));
1428
+ //fout.write((char *) &hparams.n_ctx, sizeof(hparams.n_ctx));
1429
+ fout.write((char *) &hparams.n_embd, sizeof(hparams.n_embd));
1430
+ fout.write((char *) &hparams.n_mult, sizeof(hparams.n_mult));
1431
+ fout.write((char *) &hparams.n_head, sizeof(hparams.n_head));
1432
+ fout.write((char *) &hparams.n_layer, sizeof(hparams.n_layer));
1433
+ fout.write((char *) &hparams.n_rot, sizeof(hparams.n_rot));
1434
+ fout.write((char *) &itype, sizeof(hparams.f16));
1435
+ }
1436
+
1437
+ // load vocab
1438
+ {
1439
+ const int32_t n_vocab = hparams.n_vocab;
1440
+
1441
+ if (n_vocab != hparams.n_vocab) {
1442
+ fprintf(stderr, "%s: invalid model file '%s' (bad vocab size %d != %d)\n",
1443
+ __func__, fname_inp.c_str(), n_vocab, hparams.n_vocab);
1444
+ return false;
1445
+ }
1446
+
1447
+ std::string word;
1448
+ vocab.id_to_token.resize(n_vocab);
1449
+ for (int i = 0; i < n_vocab; i++) {
1450
+ uint32_t len;
1451
+ finp.read ((char *) &len, sizeof(len));
1452
+ fout.write((char *) &len, sizeof(len));
1453
+
1454
+ word.resize(len);
1455
+ finp.read ((char *) word.data(), len);
1456
+ fout.write((char *) word.data(), len);
1457
+
1458
+ float score;
1459
+ finp.read ((char *) &score, sizeof(score));
1460
+ fout.write((char *) &score, sizeof(score));
1461
+
1462
+ vocab.token_to_id[word] = i;
1463
+
1464
+ auto &tok_score = vocab.id_to_token[i];
1465
+ tok_score.tok = word;
1466
+ tok_score.score = score;
1467
+ }
1468
+ }
1469
+
1470
+ // load weights
1471
+ {
1472
+ size_t total_size_org = 0;
1473
+ size_t total_size_new = 0;
1474
+
1475
+ std::vector<float> work;
1476
+
1477
+ std::vector<uint8_t> data_u8;
1478
+ std::vector<ggml_fp16_t> data_f16;
1479
+ std::vector<float> data_f32;
1480
+
1481
+ std::vector<int64_t> hist_all(1 << 4, 0);
1482
+
1483
+ while (true) {
1484
+ int32_t n_dims;
1485
+ int32_t length;
1486
+ int32_t ftype;
1487
+
1488
+ finp.read(reinterpret_cast<char *>(&n_dims), sizeof(n_dims));
1489
+ finp.read(reinterpret_cast<char *>(&length), sizeof(length));
1490
+ finp.read(reinterpret_cast<char *>(&ftype), sizeof(ftype));
1491
+
1492
+ if (finp.eof()) {
1493
+ break;
1494
+ }
1495
+
1496
+ int32_t nelements = 1;
1497
+ int32_t ne[2] = { 1, 1 };
1498
+ for (int i = 0; i < n_dims; ++i) {
1499
+ finp.read (reinterpret_cast<char *>(&ne[i]), sizeof(ne[i]));
1500
+ nelements *= ne[i];
1501
+ }
1502
+
1503
+ std::string name(length, 0);
1504
+ finp.read (&name[0], length);
1505
+
1506
+ {
1507
+ static const char * ftype_str[] = { "f32", "f16", "q4_0", "q4_1", };
1508
+ printf("%48s - [%5d, %5d], type = %6s ", name.data(), ne[0], ne[1], ftype_str[ftype]);
1509
+ }
1510
+
1511
+ // regexes of tensor names to be quantized
1512
+ const std::vector<std::string> k_names = {
1513
+ ".*weight",
1514
+ };
1515
+
1516
+ bool quantize = false;
1517
+ for (const auto & s : k_names) {
1518
+ if (std::regex_match(name, std::regex(s))) {
1519
+ quantize = true;
1520
+ break;
1521
+ }
1522
+ }
1523
+
1524
+ // quantize only 2D tensors
1525
+ quantize &= (n_dims == 2);
1526
+
1527
+ if (quantize) {
1528
+ if (ftype != 0 && ftype != 1) {
1529
+ fprintf(stderr, "%s: unsupported ftype %d for integer quantization\n", __func__, ftype);
1530
+ return false;
1531
+ }
1532
+
1533
+ if (ftype == 1) {
1534
+ data_f16.resize(nelements);
1535
+ finp.read(reinterpret_cast<char *>(data_f16.data()), nelements * sizeof(ggml_fp16_t));
1536
+ data_f32.resize(nelements);
1537
+ for (int i = 0; i < nelements; ++i) {
1538
+ data_f32[i] = ggml_fp16_to_fp32(data_f16[i]);
1539
+ }
1540
+ } else {
1541
+ data_f32.resize(nelements);
1542
+ finp.read(reinterpret_cast<char *>(data_f32.data()), nelements * sizeof(float));
1543
+ }
1544
+
1545
+ ftype = itype;
1546
+ } else {
1547
+ const int bpe = (ftype == 0) ? sizeof(float) : sizeof(uint16_t);
1548
+
1549
+ data_u8.resize(nelements*bpe);
1550
+ finp.read(reinterpret_cast<char *>(data_u8.data()), nelements * bpe);
1551
+ }
1552
+
1553
+ fout.write(reinterpret_cast<char *>(&n_dims), sizeof(n_dims));
1554
+ fout.write(reinterpret_cast<char *>(&length), sizeof(length));
1555
+ fout.write(reinterpret_cast<char *>(&ftype), sizeof(ftype));
1556
+ for (int i = 0; i < n_dims; ++i) {
1557
+ fout.write(reinterpret_cast<char *>(&ne[i]), sizeof(ne[i]));
1558
+ }
1559
+ fout.write(&name[0], length);
1560
+
1561
+ if (quantize) {
1562
+ printf("quantizing .. ");
1563
+ work.resize(nelements); // for quantization
1564
+
1565
+ size_t cur_size = 0;
1566
+ std::vector<int64_t> hist_cur(1 << 4, 0);
1567
+
1568
+ switch (type) {
1569
+ case GGML_TYPE_Q4_0:
1570
+ {
1571
+ cur_size = ggml_quantize_q4_0(data_f32.data(), work.data(), nelements, ne[0], qk, hist_cur.data());
1572
+ } break;
1573
+ case GGML_TYPE_Q4_1:
1574
+ {
1575
+ cur_size = ggml_quantize_q4_1(data_f32.data(), work.data(), nelements, ne[0], qk, hist_cur.data());
1576
+ } break;
1577
+ default:
1578
+ {
1579
+ fprintf(stderr, "%s: unsupported quantization type %d\n", __func__, type);
1580
+ return false;
1581
+ }
1582
+ }
1583
+
1584
+ fout.write(reinterpret_cast<char *>(work.data()), cur_size);
1585
+ total_size_new += cur_size;
1586
+
1587
+ printf("size = %8.2f MB -> %8.2f MB | hist: ", nelements * sizeof(float)/1024.0/1024.0, cur_size/1024.0/1024.0);
1588
+ for (int i = 0; i < (int) hist_cur.size(); ++i) {
1589
+ hist_all[i] += hist_cur[i];
1590
+ }
1591
+
1592
+ for (int i = 0; i < (int) hist_cur.size(); ++i) {
1593
+ printf("%5.3f ", hist_cur[i] / (float)nelements);
1594
+ }
1595
+ printf("\n");
1596
+ } else {
1597
+ printf("size = %8.3f MB\n", data_u8.size()/1024.0/1024.0);
1598
+ fout.write(reinterpret_cast<char *>(data_u8.data()), data_u8.size());
1599
+ total_size_new += data_u8.size();
1600
+ }
1601
+
1602
+ total_size_org += nelements * sizeof(float);
1603
+ }
1604
+
1605
+ printf("%s: model size = %8.2f MB\n", __func__, total_size_org/1024.0/1024.0);
1606
+ printf("%s: quant size = %8.2f MB\n", __func__, total_size_new/1024.0/1024.0);
1607
+
1608
+ {
1609
+ int64_t sum_all = 0;
1610
+ for (int i = 0; i < (int) hist_all.size(); ++i) {
1611
+ sum_all += hist_all[i];
1612
+ }
1613
+
1614
+ printf("%s: hist: ", __func__);
1615
+ for (int i = 0; i < (int) hist_all.size(); ++i) {
1616
+ printf("%5.3f ", hist_all[i] / (float)sum_all);
1617
+ }
1618
+ printf("\n");
1619
+ }
1620
+ }
1621
+
1622
+ finp.close();
1623
+ fout.close();
1624
+
1625
+ return true;
1626
+ }
1627
+
1628
+ //
1629
+ // interface implementation
1630
+ //
1631
+
1632
+ struct llama_context * llama_init_from_file(
1633
+ const char * path_model,
1634
+ struct llama_context_params params) {
1635
+ ggml_time_init();
1636
+
1637
+ llama_context * ctx = new llama_context;
1638
+
1639
+ if (params.seed <= 0) {
1640
+ params.seed = time(NULL);
1641
+ }
1642
+
1643
+ ctx->rng = std::mt19937(params.seed);
1644
+ ctx->logits_all = params.logits_all;
1645
+
1646
+ ggml_type memory_type = params.f16_kv ? GGML_TYPE_F16 : GGML_TYPE_F32;
1647
+
1648
+ if (!llama_model_load(path_model, *ctx, params.n_ctx, params.n_parts, memory_type,
1649
+ params.vocab_only, params.progress_callback,
1650
+ params.progress_callback_user_data)) {
1651
+ fprintf(stderr, "%s: failed to load model\n", __func__);
1652
+ llama_free(ctx);
1653
+ return nullptr;
1654
+ }
1655
+
1656
+ if (params.use_mlock) {
1657
+ char *err;
1658
+ if (!ggml_mlock(ctx->model.ctx, &err)) {
1659
+ fprintf(stderr, "%s\n", err);
1660
+ free(err);
1661
+ llama_free(ctx);
1662
+ return nullptr;
1663
+ }
1664
+ }
1665
+
1666
+ // reserve memory for context buffers
1667
+ {
1668
+ if (!kv_cache_init(ctx->model.hparams, ctx->model.kv_self, memory_type, ctx->model.hparams.n_ctx)) {
1669
+ fprintf(stderr, "%s: kv_cache_init() failed for self-attention cache\n", __func__);
1670
+ llama_free(ctx);
1671
+ return nullptr;
1672
+ }
1673
+
1674
+ {
1675
+ const size_t memory_size = ggml_nbytes(ctx->model.kv_self.k) + ggml_nbytes(ctx->model.kv_self.v);
1676
+ fprintf(stderr, "%s: kv self size = %7.2f MB\n", __func__, memory_size / 1024.0 / 1024.0);
1677
+ }
1678
+
1679
+ const auto & hparams = ctx->model.hparams;
1680
+
1681
+ // resized during inference
1682
+ if (params.logits_all) {
1683
+ ctx->logits.reserve(hparams.n_ctx*hparams.n_vocab);
1684
+ } else {
1685
+ ctx->logits.reserve(hparams.n_ctx);
1686
+ }
1687
+
1688
+ if (params.embedding){
1689
+ ctx->embedding.resize(hparams.n_embd);
1690
+ }
1691
+
1692
+ ctx->buf_compute.resize(MEM_REQ_EVAL.at(ctx->model.type));
1693
+
1694
+ ctx->buf_scratch[0].resize(MEM_REQ_SCRATCH0.at(ctx->model.type));
1695
+ ctx->buf_scratch[1].resize(MEM_REQ_SCRATCH1.at(ctx->model.type));
1696
+ }
1697
+
1698
+ return ctx;
1699
+ }
1700
+
1701
+ void llama_free(struct llama_context * ctx) {
1702
+ kv_cache_free(ctx->model.kv_self);
1703
+
1704
+ if (ctx->model.ctx) {
1705
+ ggml_free(ctx->model.ctx);
1706
+ }
1707
+
1708
+ delete ctx;
1709
+ }
1710
+
1711
+ int llama_model_quantize(
1712
+ const char * fname_inp,
1713
+ const char * fname_out,
1714
+ int itype,
1715
+ int qk) {
1716
+ if (!llama_model_quantize_internal(fname_inp, fname_out, itype, qk)) {
1717
+ fprintf(stderr, "%s: failed to quantize\n", __func__);
1718
+ return 1;
1719
+ }
1720
+
1721
+ return 0;
1722
+ }
1723
+
1724
+ int llama_eval(
1725
+ struct llama_context * ctx,
1726
+ const llama_token * tokens,
1727
+ int n_tokens,
1728
+ int n_past,
1729
+ int n_threads) {
1730
+ if (!llama_eval_internal(*ctx, tokens, n_tokens, n_past, n_threads)) {
1731
+ fprintf(stderr, "%s: failed to eval\n", __func__);
1732
+ return 1;
1733
+ }
1734
+
1735
+ return 0;
1736
+ }
1737
+
1738
+ int llama_tokenize(
1739
+ struct llama_context * ctx,
1740
+ const char * text,
1741
+ llama_token * tokens,
1742
+ int n_max_tokens,
1743
+ bool add_bos) {
1744
+ auto res = llama_tokenize(ctx->vocab, text, add_bos);
1745
+
1746
+ if (n_max_tokens < (int) res.size()) {
1747
+ fprintf(stderr, "%s: too many tokens\n", __func__);
1748
+ return -((int) res.size());
1749
+ }
1750
+
1751
+ for (size_t i = 0; i < res.size(); i++) {
1752
+ tokens[i] = res[i];
1753
+ }
1754
+
1755
+ return res.size();
1756
+ }
1757
+
1758
+ int llama_n_vocab(struct llama_context * ctx) {
1759
+ return ctx->vocab.id_to_token.size();
1760
+ }
1761
+
1762
+ int llama_n_ctx(struct llama_context * ctx) {
1763
+ return ctx->model.hparams.n_ctx;
1764
+ }
1765
+
1766
+ int llama_n_embd(struct llama_context * ctx) {
1767
+ return ctx->model.hparams.n_embd;
1768
+ }
1769
+
1770
+ float * llama_get_logits(struct llama_context * ctx) {
1771
+ return ctx->logits.data();
1772
+ }
1773
+
1774
+ float * llama_get_embeddings(struct llama_context * ctx) {
1775
+ return ctx->embedding.data();
1776
+ }
1777
+
1778
+ const char * llama_token_to_str(struct llama_context * ctx, llama_token token) {
1779
+ if (token >= llama_n_vocab(ctx)) {
1780
+ return nullptr;
1781
+ }
1782
+
1783
+ return ctx->vocab.id_to_token[token].tok.c_str();
1784
+ }
1785
+
1786
+ llama_token llama_token_bos() {
1787
+ return 1;
1788
+ }
1789
+
1790
+ llama_token llama_token_eos() {
1791
+ return 2;
1792
+ }
1793
+
1794
+ llama_token llama_sample_top_p_top_k(
1795
+ llama_context * ctx,
1796
+ const llama_token * last_n_tokens_data,
1797
+ int last_n_tokens_size,
1798
+ int top_k,
1799
+ double top_p,
1800
+ double temp,
1801
+ double repeat_penalty) {
1802
+ const int64_t t_start_sample_us = ggml_time_us();
1803
+
1804
+ llama_token result = 0;
1805
+
1806
+ // TODO: avoid this ...
1807
+ const auto last_n_tokens = std::vector<llama_token>(last_n_tokens_data, last_n_tokens_data + last_n_tokens_size);
1808
+
1809
+ result = llama_sample_top_p_top_k(
1810
+ *ctx,
1811
+ last_n_tokens,
1812
+ top_k,
1813
+ top_p,
1814
+ temp,
1815
+ repeat_penalty);
1816
+
1817
+ ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
1818
+ ctx->n_sample++;
1819
+
1820
+ return result;
1821
+ }
1822
+
1823
+
1824
+ void llama_print_timings(struct llama_context * ctx) {
1825
+ const int64_t t_end_us = ggml_time_us();
1826
+
1827
+ const int32_t n_sample = std::max(1, ctx->n_sample);
1828
+ const int32_t n_eval = std::max(1, ctx->n_eval);
1829
+ const int32_t n_p_eval = std::max(1, ctx->n_p_eval);
1830
+
1831
+ fprintf(stderr, "\n");
1832
+ fprintf(stderr, "%s: load time = %8.2f ms\n", __func__, ctx->t_load_us / 1000.0f);
1833
+ fprintf(stderr, "%s: sample time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->t_sample_us, n_sample, 1e-3f * ctx->t_sample_us / n_sample);
1834
+ fprintf(stderr, "%s: prompt eval time = %8.2f ms / %5d tokens (%8.2f ms per token)\n", __func__, 1e-3f * ctx->t_p_eval_us, n_p_eval, 1e-3f * ctx->t_p_eval_us / n_p_eval);
1835
+ fprintf(stderr, "%s: eval time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->t_eval_us, n_eval, 1e-3f * ctx->t_eval_us / n_eval);
1836
+ fprintf(stderr, "%s: total time = %8.2f ms\n", __func__, (t_end_us - ctx->t_start_us)/1000.0f);
1837
+ }
1838
+
1839
+ void llama_reset_timings(struct llama_context * ctx) {
1840
+ ctx->t_start_us = ggml_time_us();
1841
+
1842
+ ctx->t_sample_us = ctx->n_sample = 0;
1843
+ ctx->t_eval_us = ctx->n_eval = 0;
1844
+ ctx->t_p_eval_us = ctx->n_p_eval = 0;
1845
+ }
1846
+
1847
+ const char * llama_print_system_info(void) {
1848
+ static std::string s;
1849
+
1850
+ s = "";
1851
+ s += "AVX = " + std::to_string(ggml_cpu_has_avx()) + " | ";
1852
+ s += "AVX2 = " + std::to_string(ggml_cpu_has_avx2()) + " | ";
1853
+ s += "AVX512 = " + std::to_string(ggml_cpu_has_avx512()) + " | ";
1854
+ s += "FMA = " + std::to_string(ggml_cpu_has_fma()) + " | ";
1855
+ s += "NEON = " + std::to_string(ggml_cpu_has_neon()) + " | ";
1856
+ s += "ARM_FMA = " + std::to_string(ggml_cpu_has_arm_fma()) + " | ";
1857
+ s += "F16C = " + std::to_string(ggml_cpu_has_f16c()) + " | ";
1858
+ s += "FP16_VA = " + std::to_string(ggml_cpu_has_fp16_va()) + " | ";
1859
+ s += "WASM_SIMD = " + std::to_string(ggml_cpu_has_wasm_simd()) + " | ";
1860
+ s += "BLAS = " + std::to_string(ggml_cpu_has_blas()) + " | ";
1861
+ s += "SSE3 = " + std::to_string(ggml_cpu_has_sse3()) + " | ";
1862
+ s += "VSX = " + std::to_string(ggml_cpu_has_vsx()) + " | ";
1863
+
1864
+ return s.c_str();
1865
+ }
examples/talk-llama/llama.h ADDED
@@ -0,0 +1,153 @@
1
+ #ifndef LLAMA_H
2
+ #define LLAMA_H
3
+
4
+ #include <stddef.h>
5
+ #include <stdint.h>
6
+ #include <stdbool.h>
7
+
8
+ #ifdef LLAMA_SHARED
9
+ # ifdef _WIN32
10
+ # ifdef LLAMA_BUILD
11
+ # define LLAMA_API __declspec(dllexport)
12
+ # else
13
+ # define LLAMA_API __declspec(dllimport)
14
+ # endif
15
+ # else
16
+ # define LLAMA_API __attribute__ ((visibility ("default")))
17
+ # endif
18
+ #else
19
+ # define LLAMA_API
20
+ #endif
21
+
22
+ #define LLAMA_FILE_VERSION 1
23
+ #define LLAMA_FILE_MAGIC 0x67676d66 // 'ggmf' in hex
24
+ #define LLAMA_FILE_MAGIC_UNVERSIONED 0x67676d6c // pre-versioned files
25
+
26
+ #ifdef __cplusplus
27
+ extern "C" {
28
+ #endif
29
+
30
+ //
31
+ // C interface
32
+ //
33
+ // TODO: show sample usage
34
+ //
35
+
36
+ struct llama_context;
37
+
38
+ typedef int llama_token;
39
+
40
+ typedef struct llama_token_data {
41
+ llama_token id; // token id
42
+
43
+ float p; // probability of the token
44
+ float plog; // log probability of the token
45
+
46
+ } llama_token_data;
47
+
48
+ typedef void (*llama_progress_callback)(double progress, void *ctx);
49
+
50
+ struct llama_context_params {
51
+ int n_ctx; // text context
52
+ int n_parts; // -1 for default
53
+ int seed; // RNG seed, 0 for random
54
+
55
+ bool f16_kv; // use fp16 for KV cache
56
+ bool logits_all; // the llama_eval() call computes all logits, not just the last one
57
+ bool vocab_only; // only load the vocabulary, no weights
58
+ bool use_mlock; // force system to keep model in RAM
59
+ bool embedding; // embedding mode only
60
+
61
+ // called with a progress value between 0 and 1, pass NULL to disable
62
+ llama_progress_callback progress_callback;
63
+ // context pointer passed to the progress callback
64
+ void * progress_callback_user_data;
65
+ };
66
+
67
+ LLAMA_API struct llama_context_params llama_context_default_params();
68
+
69
+ // Various functions for loading a ggml llama model.
70
+ // Allocate (almost) all memory needed for the model.
71
+ // Return NULL on failure
72
+ LLAMA_API struct llama_context * llama_init_from_file(
73
+ const char * path_model,
74
+ struct llama_context_params params);
75
+
76
+ // Frees all allocated memory
77
+ LLAMA_API void llama_free(struct llama_context * ctx);
78
+
79
+ // TODO: not great API - very likely to change
80
+ // Returns 0 on success
81
+ LLAMA_API int llama_model_quantize(
82
+ const char * fname_inp,
83
+ const char * fname_out,
84
+ int itype,
85
+ int qk);
86
+
87
+ // Run the llama inference to obtain the logits and probabilities for the next token.
88
+ // tokens + n_tokens is the provided batch of new tokens to process
89
+ // n_past is the number of tokens to use from previous eval calls
90
+ // Returns 0 on success
91
+ LLAMA_API int llama_eval(
92
+ struct llama_context * ctx,
93
+ const llama_token * tokens,
94
+ int n_tokens,
95
+ int n_past,
96
+ int n_threads);
97
+
98
+ // Convert the provided text into tokens.
99
+ // The tokens pointer must be large enough to hold the resulting tokens.
100
+ // Returns the number of tokens on success, no more than n_max_tokens
101
+ // Returns a negative number on failure - the number of tokens that would have been returned
102
+ // TODO: not sure if correct
103
+ LLAMA_API int llama_tokenize(
104
+ struct llama_context * ctx,
105
+ const char * text,
106
+ llama_token * tokens,
107
+ int n_max_tokens,
108
+ bool add_bos);
109
+
110
+ LLAMA_API int llama_n_vocab(struct llama_context * ctx);
111
+ LLAMA_API int llama_n_ctx (struct llama_context * ctx);
112
+ LLAMA_API int llama_n_embd (struct llama_context * ctx);
113
+
114
+ // Token logits obtained from the last call to llama_eval()
115
+ // The logits for the last token are stored in the last row
116
+ // Can be mutated in order to change the probabilities of the next token
117
+ // Rows: n_tokens
118
+ // Cols: n_vocab
119
+ LLAMA_API float * llama_get_logits(struct llama_context * ctx);
120
+
121
+ // Get the embeddings for the input
122
+ // shape: [n_embd] (1-dimensional)
123
+ LLAMA_API float * llama_get_embeddings(struct llama_context * ctx);
124
+
125
+ // Token Id -> String. Uses the vocabulary in the provided context
126
+ LLAMA_API const char * llama_token_to_str(struct llama_context * ctx, llama_token token);
127
+
128
+ // Special tokens
129
+ LLAMA_API llama_token llama_token_bos();
130
+ LLAMA_API llama_token llama_token_eos();
131
+
132
+ // TODO: improve the last_n_tokens interface ?
133
+ LLAMA_API llama_token llama_sample_top_p_top_k(
134
+ struct llama_context * ctx,
135
+ const llama_token * last_n_tokens_data,
136
+ int last_n_tokens_size,
137
+ int top_k,
138
+ double top_p,
139
+ double temp,
140
+ double repeat_penalty);
141
+
142
+ // Performance information
143
+ LLAMA_API void llama_print_timings(struct llama_context * ctx);
144
+ LLAMA_API void llama_reset_timings(struct llama_context * ctx);
145
+
146
+ // Print system information
147
+ LLAMA_API const char * llama_print_system_info(void);
148
+
149
+ #ifdef __cplusplus
150
+ }
151
+ #endif
152
+
153
+ #endif
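
Since the comment block above only says "TODO: show sample usage", here is a minimal, untested sketch of driving the C API declared in llama.h. The model path, thread count and sampling parameters are placeholders chosen for illustration, not values taken from this commit.

#include "llama.h"

#include <cstdio>
#include <vector>

int main() {
    // start from the default parameters and shrink the context for the sketch
    llama_context_params lparams = llama_context_default_params();
    lparams.n_ctx = 512;

    // placeholder path - point this at a real ggml LLaMA model file
    llama_context * ctx = llama_init_from_file("models/ggml-llama-7B.bin", lparams);
    if (ctx == nullptr) {
        return 1;
    }

    // tokenize a prompt (add_bos = true at the start of a conversation)
    std::vector<llama_token> tokens(256);
    int n = llama_tokenize(ctx, " Hello, how are you?", tokens.data(), (int) tokens.size(), true);
    if (n < 0) {
        return 1; // buffer too small: -n tokens would have been produced
    }
    tokens.resize(n);

    // evaluate the prompt, then sample one token from the logits kept inside the context
    if (llama_eval(ctx, tokens.data(), (int) tokens.size(), 0, 4) != 0) {
        return 1;
    }

    const llama_token id = llama_sample_top_p_top_k(ctx, tokens.data(), (int) tokens.size(),
                                                    40, 0.95, 0.80, 1.10);
    printf("next token: %s\n", llama_token_to_str(ctx, id));

    llama_print_timings(ctx);
    llama_free(ctx);

    return 0;
}
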
examples/talk-llama/speak.sh ADDED
@@ -0,0 +1,20 @@
1
+ #!/bin/bash
2
+
3
+ # Usage:
4
+ # speak.sh <voice_id> <text-to-speak>
5
+
6
+ # espeak
7
+ # Mac OS: brew install espeak
8
+ # Linux: apt-get install espeak
9
+ #
10
+ #espeak -v en-us+m$1 -s 225 -p 50 -a 200 -g 5 -k 5 "$2"
11
+
12
+ # for Mac
13
+ say "$2"
14
+
15
+ # Eleven Labs
16
+ #
17
+ #wd=$(dirname $0)
18
+ #script=$wd/eleven-labs.py
19
+ #python3 $script $1 "$2" >/dev/null 2>&1
20
+ #ffplay -autoexit -nodisp -loglevel quiet -hide_banner -i ./audio.mp3 >/dev/null 2>&1
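
As the usage comment at the top states, the script takes a voice id and the text to speak, e.g. ./speak.sh 2 "Hello there" (talk-llama invokes it with voice id 2). With the defaults above it simply forwards the text to the macOS say command; the espeak and Eleven Labs paths are left commented out.
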
examples/talk-llama/talk-llama.cpp ADDED
@@ -0,0 +1,529 @@
1
+ // Talk with AI
2
+ //
3
+
4
+ #include "common.h"
5
+ #include "common-sdl.h"
6
+ #include "whisper.h"
7
+ #include "llama.h"
8
+
9
+ #include <cassert>
10
+ #include <cstdio>
11
+ #include <fstream>
12
+ #include <regex>
13
+ #include <string>
14
+ #include <thread>
15
+ #include <vector>
16
+ #include <regex>
17
+
18
+ std::vector<llama_token> llama_tokenize(struct llama_context * ctx, const std::string & text, bool add_bos) {
19
+ // initialize to prompt number of chars, since n_tokens <= n_prompt_chars
20
+ std::vector<llama_token> res(text.size() + (int)add_bos);
21
+ int n = llama_tokenize(ctx, text.c_str(), res.data(), res.size(), add_bos);
22
+ assert(n >= 0);
23
+ res.resize(n);
24
+
25
+ return res;
26
+ }
27
+
28
+ // command-line parameters
29
+ struct whisper_params {
30
+ int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency());
31
+ int32_t voice_ms = 10000;
32
+ int32_t capture_id = -1;
33
+ int32_t max_tokens = 32;
34
+ int32_t audio_ctx = 0;
35
+
36
+ float vad_thold = 0.6f;
37
+ float freq_thold = 100.0f;
38
+
39
+ bool speed_up = false;
40
+ bool translate = false;
41
+ bool print_special = false;
42
+ bool print_energy = false;
43
+ bool no_timestamps = true;
44
+
45
+ std::string person = "Georgi";
46
+ std::string language = "en";
47
+ std::string model_wsp = "models/ggml-base.en.bin";
48
+ std::string model_llama = "models/ggml-llama-7B.bin";
49
+ std::string speak = "./examples/talk/speak.sh";
50
+ std::string fname_out;
51
+ };
52
+
53
+ void whisper_print_usage(int argc, char ** argv, const whisper_params & params);
54
+
55
+ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
56
+ for (int i = 1; i < argc; i++) {
57
+ std::string arg = argv[i];
58
+
59
+ if (arg == "-h" || arg == "--help") {
60
+ whisper_print_usage(argc, argv, params);
61
+ exit(0);
62
+ }
63
+ else if (arg == "-t" || arg == "--threads") { params.n_threads = std::stoi(argv[++i]); }
64
+ else if (arg == "-vms" || arg == "--voice-ms") { params.voice_ms = std::stoi(argv[++i]); }
65
+ else if (arg == "-c" || arg == "--capture") { params.capture_id = std::stoi(argv[++i]); }
66
+ else if (arg == "-mt" || arg == "--max-tokens") { params.max_tokens = std::stoi(argv[++i]); }
67
+ else if (arg == "-ac" || arg == "--audio-ctx") { params.audio_ctx = std::stoi(argv[++i]); }
68
+ else if (arg == "-vth" || arg == "--vad-thold") { params.vad_thold = std::stof(argv[++i]); }
69
+ else if (arg == "-fth" || arg == "--freq-thold") { params.freq_thold = std::stof(argv[++i]); }
70
+ else if (arg == "-su" || arg == "--speed-up") { params.speed_up = true; }
71
+ else if (arg == "-tr" || arg == "--translate") { params.translate = true; }
72
+ else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; }
73
+ else if (arg == "-pe" || arg == "--print-energy") { params.print_energy = true; }
74
+ else if (arg == "-p" || arg == "--person") { params.person = argv[++i]; }
75
+ else if (arg == "-l" || arg == "--language") { params.language = argv[++i]; }
76
+ else if (arg == "-mw" || arg == "--model-whisper") { params.model_wsp = argv[++i]; }
77
+ else if (arg == "-ml" || arg == "--model-llama") { params.model_llama = argv[++i]; }
78
+ else if (arg == "-s" || arg == "--speak") { params.speak = argv[++i]; }
79
+ else if (arg == "-f" || arg == "--file") { params.fname_out = argv[++i]; }
80
+ else {
81
+ fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
82
+ whisper_print_usage(argc, argv, params);
83
+ exit(0);
84
+ }
85
+ }
86
+
87
+ return true;
88
+ }
89
+
90
+ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & params) {
91
+ fprintf(stderr, "\n");
92
+ fprintf(stderr, "usage: %s [options]\n", argv[0]);
93
+ fprintf(stderr, "\n");
94
+ fprintf(stderr, "options:\n");
95
+ fprintf(stderr, " -h, --help [default] show this help message and exit\n");
96
+ fprintf(stderr, " -t N, --threads N [%-7d] number of threads to use during computation\n", params.n_threads);
97
+ fprintf(stderr, " -vms N, --voice-ms N [%-7d] voice duration in milliseconds\n", params.voice_ms);
98
+ fprintf(stderr, " -c ID, --capture ID [%-7d] capture device ID\n", params.capture_id);
99
+ fprintf(stderr, " -mt N, --max-tokens N [%-7d] maximum number of tokens per audio chunk\n", params.max_tokens);
100
+ fprintf(stderr, " -ac N, --audio-ctx N [%-7d] audio context size (0 - all)\n", params.audio_ctx);
101
+ fprintf(stderr, " -vth N, --vad-thold N [%-7.2f] voice activity detection threshold\n", params.vad_thold);
102
+ fprintf(stderr, " -fth N, --freq-thold N [%-7.2f] high-pass frequency cutoff\n", params.freq_thold);
103
+ fprintf(stderr, " -su, --speed-up [%-7s] speed up audio by x2 (reduced accuracy)\n", params.speed_up ? "true" : "false");
104
+ fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false");
105
+ fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false");
106
+ fprintf(stderr, " -pe, --print-energy [%-7s] print sound energy (for debugging)\n", params.print_energy ? "true" : "false");
107
+ fprintf(stderr, " -p NAME, --person NAME [%-7s] person name (for prompt selection)\n", params.person.c_str());
108
+ fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language\n", params.language.c_str());
109
+ fprintf(stderr, " -mw FILE, --model-whisper [%-7s] whisper model file\n", params.model_wsp.c_str());
110
+ fprintf(stderr, " -mg FILE, --model-llama [%-7s] llama model file\n", params.model_llama.c_str());
111
+ fprintf(stderr, " -s FILE, --speak TEXT [%-7s] command for TTS\n", params.speak.c_str());
112
+ fprintf(stderr, " -f FNAME, --file FNAME [%-7s] text output file name\n", params.fname_out.c_str());
113
+ fprintf(stderr, "\n");
114
+ }
115
+
116
+ std::string transcribe(
117
+ whisper_context * ctx,
118
+ const whisper_params & params,
119
+ const std::vector<float> & pcmf32,
120
+ const std::string prompt_text,
121
+ float & prob,
122
+ int64_t & t_ms) {
123
+ const auto t_start = std::chrono::high_resolution_clock::now();
124
+
125
+ prob = 0.0f;
126
+ t_ms = 0;
127
+
128
+ std::vector<whisper_token> prompt_tokens;
129
+
130
+ whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
131
+
132
+ prompt_tokens.resize(1024);
133
+ prompt_tokens.resize(whisper_tokenize(ctx, prompt_text.c_str(), prompt_tokens.data(), prompt_tokens.size()));
134
+
135
+ wparams.print_progress = false;
136
+ wparams.print_special = params.print_special;
137
+ wparams.print_realtime = false;
138
+ wparams.print_timestamps = !params.no_timestamps;
139
+ wparams.translate = params.translate;
140
+ wparams.no_context = true;
141
+ wparams.single_segment = true;
142
+ wparams.max_tokens = params.max_tokens;
143
+ wparams.language = params.language.c_str();
144
+ wparams.n_threads = params.n_threads;
145
+
146
+ wparams.prompt_tokens = prompt_tokens.empty() ? nullptr : prompt_tokens.data();
147
+ wparams.prompt_n_tokens = prompt_tokens.empty() ? 0 : prompt_tokens.size();
148
+
149
+ wparams.audio_ctx = params.audio_ctx;
150
+ wparams.speed_up = params.speed_up;
151
+
152
+ if (whisper_full(ctx, wparams, pcmf32.data(), pcmf32.size()) != 0) {
153
+ return "";
154
+ }
155
+
156
+ int prob_n = 0;
157
+ std::string result;
158
+
159
+ const int n_segments = whisper_full_n_segments(ctx);
160
+ for (int i = 0; i < n_segments; ++i) {
161
+ const char * text = whisper_full_get_segment_text(ctx, i);
162
+
163
+ result += text;
164
+
165
+ const int n_tokens = whisper_full_n_tokens(ctx, i);
166
+ for (int j = 0; j < n_tokens; ++j) {
167
+ const auto token = whisper_full_get_token_data(ctx, i, j);
168
+
169
+ prob += token.p;
170
+ ++prob_n;
171
+ }
172
+ }
173
+
174
+ if (prob_n > 0) {
175
+ prob /= prob_n;
176
+ }
177
+
178
+ const auto t_end = std::chrono::high_resolution_clock::now();
179
+ t_ms = std::chrono::duration_cast<std::chrono::milliseconds>(t_end - t_start).count();
180
+
181
+ return result;
182
+ }
183
+
184
+ const std::string k_prompt_whisper = R"(A conversation with a person called {1}.)";
185
+
186
+ // need to have leading ' '
187
+ const std::string k_prompt_llama = R"( Text transcript of a never ending dialog, where {0} interacts with an AI assistant named {1}.
188
+ {1} is helpful, kind, honest, friendly, good at writing and never fails to answer {0}’s requests immediately and with details and precision.
189
+ There are no annotations like (30 seconds passed...) or (to himself), just what {0} and {1} say aloud to each other.
190
+ The transcript only includes text, it does not include markup like HTML and Markdown.
191
+ {1} responds with short and concise answers.
192
+
193
+ {0}{4} Hello, {1}!
194
+ {1}{4} Hello {0}! How may I help you today?
195
+ {0}{4} What time is it?
196
+ {1}{4} It is {2} o'clock.
197
+ {0}{4} What year is it?
198
+ {1}{4} We are in {3}.
199
+ {0}{4} What is a cat?
200
+ {1}{4} A cat is a domestic species of small carnivorous mammal. It is the only domesticated species in the family Felidae.
201
+ {0}{4} Name a color.
202
+ {1}{4} Blue
203
+ {0}{4})";
204
+
205
+ int main(int argc, char ** argv) {
206
+ whisper_params params;
207
+
208
+ if (whisper_params_parse(argc, argv, params) == false) {
209
+ return 1;
210
+ }
211
+
212
+ if (whisper_lang_id(params.language.c_str()) == -1) {
213
+ fprintf(stderr, "error: unknown language '%s'\n", params.language.c_str());
214
+ whisper_print_usage(argc, argv, params);
215
+ exit(0);
216
+ }
217
+
218
+ // whisper init
219
+
220
+ struct whisper_context * ctx_wsp = whisper_init_from_file(params.model_wsp.c_str());
221
+
222
+ // llama init
223
+
224
+ auto lparams = llama_context_default_params();
225
+
226
+ // tune these to your liking
227
+ lparams.n_ctx = 512;
228
+ lparams.seed = 1;
229
+ lparams.f16_kv = true;
230
+
231
+ struct llama_context * ctx_llama = llama_init_from_file(params.model_llama.c_str(), lparams);
232
+
233
+ // print some info about the processing
234
+ {
235
+ fprintf(stderr, "\n");
236
+
237
+ if (!whisper_is_multilingual(ctx_wsp)) {
238
+ if (params.language != "en" || params.translate) {
239
+ params.language = "en";
240
+ params.translate = false;
241
+ fprintf(stderr, "%s: WARNING: model is not multilingual, ignoring language and translation options\n", __func__);
242
+ }
243
+ }
244
+ fprintf(stderr, "%s: processing, %d threads, lang = %s, task = %s, timestamps = %d ...\n",
245
+ __func__,
246
+ params.n_threads,
247
+ params.language.c_str(),
248
+ params.translate ? "translate" : "transcribe",
249
+ params.no_timestamps ? 0 : 1);
250
+
251
+ fprintf(stderr, "\n");
252
+ }
253
+
254
+
255
+ // init audio
256
+
257
+ audio_async audio(30*1000);
258
+ if (!audio.init(params.capture_id, WHISPER_SAMPLE_RATE)) {
259
+ fprintf(stderr, "%s: audio.init() failed!\n", __func__);
260
+ return 1;
261
+ }
262
+
263
+ audio.resume();
264
+
265
+ int n_iter = 0;
266
+
267
+ bool is_running = true;
268
+ bool force_speak = false;
269
+
270
+ float prob0 = 0.0f;
271
+
272
+ const std::string chat_symb = ":";
273
+ const std::string bot_name = "LLaMA";
274
+
275
+ std::vector<float> pcmf32_cur;
276
+ std::vector<float> pcmf32_prompt;
277
+
278
+ const std::string prompt_whisper = ::replace(k_prompt_whisper, "{1}", bot_name);
279
+
280
+ // construct the initial prompt for LLaMA inference
281
+ std::string prompt_llama = k_prompt_llama;
282
+
283
+ prompt_llama = ::replace(prompt_llama, "{0}", params.person);
284
+ prompt_llama = ::replace(prompt_llama, "{1}", bot_name);
285
+
286
+ {
287
+ // get time string
288
+ std::string time_str;
289
+ {
290
+ time_t t = time(0);
291
+ struct tm * now = localtime(&t);
292
+ char buf[128];
293
+ strftime(buf, sizeof(buf), "%H:%M", now);
294
+ time_str = buf;
295
+ }
296
+ prompt_llama = ::replace(prompt_llama, "{2}", time_str);
297
+ }
298
+
299
+ {
300
+ // get year string
301
+ std::string year_str;
302
+ {
303
+ time_t t = time(0);
304
+ struct tm * now = localtime(&t);
305
+ char buf[128];
306
+ strftime(buf, sizeof(buf), "%Y", now);
307
+ year_str = buf;
308
+ }
309
+ prompt_llama = ::replace(prompt_llama, "{3}", year_str);
310
+ }
311
+
312
+ prompt_llama = ::replace(prompt_llama, "{4}", chat_symb);
313
+
314
+ // evaluate the initial prompt
315
+
316
+ auto embd_inp = ::llama_tokenize(ctx_llama, prompt_llama, true);
317
+
318
+ printf("\n");
319
+ printf("%s : initializing - please wait ...\n", __func__);
320
+
321
+ if (llama_eval(ctx_llama, embd_inp.data(), embd_inp.size(), 0, params.n_threads)) {
322
+ fprintf(stderr, "%s : failed to eval\n", __func__);
323
+ return 1;
324
+ }
325
+
326
+ //fprintf(stdout, "\n");
327
+ //fprintf(stdout, "%s", prompt_llama.c_str());
328
+ //fflush(stdout);
329
+
330
+ printf("%s : done! start speaking in the microphone\n", __func__);
331
+ printf("\n");
332
+ printf("%s%s", params.person.c_str(), chat_symb.c_str());
333
+ fflush(stdout);
334
+
335
+ // clear audio buffer
336
+ audio.clear();
337
+
338
+ // text inference variables
339
+ const int voice_id = 2;
340
+ const int n_keep = embd_inp.size();
341
+ const int n_ctx = llama_n_ctx(ctx_llama);
342
+
343
+ int n_past = n_keep;
344
+ int n_prev = 64; // TODO arg
345
+
346
+ std::vector<llama_token> embd;
347
+
348
+ // reverse prompts for detecting when it's time to stop speaking
349
+ std::vector<std::string> antiprompts = {
350
+ params.person + chat_symb,
351
+ };
352
+
353
+ // main loop
354
+ while (is_running) {
355
+ // handle Ctrl + C
356
+ is_running = sdl_poll_events();
357
+
358
+ if (!is_running) {
359
+ break;
360
+ }
361
+
362
+ // delay
363
+ std::this_thread::sleep_for(std::chrono::milliseconds(100));
364
+
365
+ int64_t t_ms = 0;
366
+
367
+ {
368
+ audio.get(2000, pcmf32_cur);
369
+
370
+ if (::vad_simple(pcmf32_cur, WHISPER_SAMPLE_RATE, 1250, params.vad_thold, params.freq_thold, params.print_energy) || force_speak) {
371
+ //fprintf(stdout, "%s: Speech detected! Processing ...\n", __func__);
372
+
373
+ audio.get(params.voice_ms, pcmf32_cur);
374
+
375
+ std::string text_heard;
376
+
377
+ if (!force_speak) {
378
+ text_heard = ::trim(::transcribe(ctx_wsp, params, pcmf32_cur, prompt_whisper, prob0, t_ms));
379
+ }
380
+
381
+ // remove text between brackets using regex
382
+ {
383
+ std::regex re("\\[.*?\\]");
384
+ text_heard = std::regex_replace(text_heard, re, "");
385
+ }
386
+
387
+ // remove text between parentheses using regex
388
+ {
389
+ std::regex re("\\(.*?\\)");
390
+ text_heard = std::regex_replace(text_heard, re, "");
391
+ }
392
+
393
+ // remove all characters, except for letters, numbers, punctuation and ':', '\'', '-', ' '
394
+ text_heard = std::regex_replace(text_heard, std::regex("[^a-zA-Z0-9\\.,\\?!\\s\\:\\'\\-]"), "");
395
+
396
+ // take first line
397
+ text_heard = text_heard.substr(0, text_heard.find_first_of('\n'));
398
+
399
+ // remove leading and trailing whitespace
400
+ text_heard = std::regex_replace(text_heard, std::regex("^\\s+"), "");
401
+ text_heard = std::regex_replace(text_heard, std::regex("\\s+$"), "");
402
+
403
+ const std::vector<llama_token> tokens = llama_tokenize(ctx_llama, text_heard.c_str(), false);
404
+
405
+ if (text_heard.empty() || tokens.empty() || force_speak) {
406
+ //fprintf(stdout, "%s: Heard nothing, skipping ...\n", __func__);
407
+ audio.clear();
408
+
409
+ continue;
410
+ }
411
+
412
+ force_speak = false;
413
+
414
+ text_heard.insert(0, 1, ' ');
415
+ text_heard += "\n" + bot_name + chat_symb;
416
+ fprintf(stdout, "%s%s%s", "\033[1m", text_heard.c_str(), "\033[0m");
417
+ fflush(stdout);
418
+
419
+ embd = ::llama_tokenize(ctx_llama, text_heard, false);
420
+
421
+ // text inference
422
+ bool done = false;
423
+ std::string text_to_speak;
424
+ while (true) {
425
+ // predict
426
+ if (embd.size() > 0) {
427
+ if (n_past + (int) embd.size() > n_ctx) {
428
+ n_past = n_keep;
429
+
430
+ // insert n_left/2 tokens at the start of embd from last_n_tokens
431
+ embd.insert(embd.begin(), embd_inp.begin() + embd_inp.size() - n_prev, embd_inp.end());
432
+
433
+ //printf("\n---\n");
434
+ //printf("resetting: '");
435
+ //for (int i = 0; i < (int) embd.size(); i++) {
436
+ // printf("%s", llama_token_to_str(ctx_llama, embd[i]));
437
+ //}
438
+ //printf("'\n");
439
+ //printf("\n---\n");
440
+ }
441
+
442
+ if (llama_eval(ctx_llama, embd.data(), embd.size(), n_past, params.n_threads)) {
443
+ fprintf(stderr, "%s : failed to eval\n", __func__);
444
+ return 1;
445
+ }
446
+ }
447
+
448
+ //printf("n_iter = %d, n_past = %d, n_ctx = %d, n_keep = %d, n_prev = %d, embd.size() = %d\n", n_iter, n_past, n_ctx, n_keep, n_prev, (int) embd.size());
449
+
450
+ embd_inp.insert(embd_inp.end(), embd.begin(), embd.end());
451
+ n_past += embd.size();
452
+ embd.clear();
453
+
454
+ if (done) break;
455
+
456
+ {
457
+ // out of user input, sample next token
458
+ const float top_k = 5;
459
+ const float top_p = 0.80f;
460
+ const float temp = 0.30f;
461
+ const float repeat_penalty = 1.1764f;
462
+
463
+ const int repeat_last_n = 256;
464
+
465
+ llama_token id = 0;
466
+
467
+ {
468
+ auto logits = llama_get_logits(ctx_llama);
469
+ logits[llama_token_eos()] = 0;
470
+
471
+ id = llama_sample_top_p_top_k(ctx_llama,
472
+ embd_inp.data() + std::max(0, n_past - repeat_last_n),
473
+ repeat_last_n, top_k, top_p, temp, repeat_penalty);
474
+ }
475
+
476
+ if (id != llama_token_eos()) {
477
+ // add it to the context
478
+ embd.push_back(id);
479
+
480
+ text_to_speak += llama_token_to_str(ctx_llama, id);
481
+
482
+ printf("%s", llama_token_to_str(ctx_llama, id));
483
+ }
484
+ }
485
+
486
+ {
487
+ std::string last_output;
488
+ for (int i = embd_inp.size() - 16; i < (int) embd_inp.size(); i++) {
489
+ last_output += llama_token_to_str(ctx_llama, embd_inp[i]);
490
+ }
491
+ last_output += llama_token_to_str(ctx_llama, embd[0]);
492
+
493
+ for (std::string & antiprompt : antiprompts) {
494
+ if (last_output.find(antiprompt.c_str(), last_output.length() - antiprompt.length(), antiprompt.length()) != std::string::npos) {
495
+ done = true;
496
+ text_to_speak = ::replace(text_to_speak, antiprompt, "");
497
+ fflush(stdout);
498
+ break;
499
+ }
500
+ }
501
+ }
502
+
503
+ is_running = sdl_poll_events();
504
+
505
+ if (!is_running) {
506
+ break;
507
+ }
508
+ }
509
+
510
+ text_to_speak = ::replace(text_to_speak, "\"", "");
511
+ system((params.speak + " " + std::to_string(voice_id) + " \"" + text_to_speak + "\"").c_str());
512
+
513
+ audio.clear();
514
+
515
+ ++n_iter;
516
+ }
517
+ }
518
+ }
519
+
520
+ audio.pause();
521
+
522
+ whisper_print_timings(ctx_wsp);
523
+ whisper_free(ctx_wsp);
524
+
525
+ llama_print_timings(ctx_llama);
526
+ llama_free(ctx_llama);
527
+
528
+ return 0;
529
+ }
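
A typical invocation of the example above (assuming the accompanying Makefile gains a talk-llama target analogous to the existing talk one) would be something like ./talk-llama -mw models/ggml-base.en.bin -ml models/ggml-llama-7B.bin -p "Georgi" -t 8, i.e. the default model paths and person name from whisper_params plus an explicit thread count.
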
examples/talk/speak.sh CHANGED
@@ -7,7 +7,10 @@
7
  # Mac OS: brew install espeak
8
  # Linux: apt-get install espeak
9
  #
10
- espeak -v en-us+m$1 -s 175 -p 50 -a 200 -g 5 -k 5 "$2"
 
 
 
11
 
12
  # Eleven Labs
13
  #
 
7
  # Mac OS: brew install espeak
8
  # Linux: apt-get install espeak
9
  #
10
+ #espeak -v en-us+m$1 -s 175 -p 50 -a 200 -g 5 -k 5 "$2"
11
+
12
+ # Mac OS "say" command
13
+ say "$2"
14
 
15
  # Eleven Labs
16
  #
ggml.c CHANGED
The diff for this file is too large to render. See raw diff
 
ggml.h CHANGED
@@ -198,6 +198,8 @@ struct ggml_object;
198
  struct ggml_context;
199
 
200
  enum ggml_type {
 
 
201
  GGML_TYPE_I8,
202
  GGML_TYPE_I16,
203
  GGML_TYPE_I32,
@@ -226,7 +228,9 @@ enum ggml_op {
226
  GGML_OP_STEP,
227
  GGML_OP_RELU,
228
  GGML_OP_GELU,
 
229
  GGML_OP_NORM, // normalize
 
230
 
231
  GGML_OP_MUL_MAT,
232
 
@@ -326,7 +330,10 @@ void ggml_print_objects(const struct ggml_context * ctx);
326
  int ggml_nelements(const struct ggml_tensor * tensor);
327
  size_t ggml_nbytes (const struct ggml_tensor * tensor);
328
 
329
- size_t ggml_type_size (enum ggml_type type);
 
 
 
330
  size_t ggml_element_size(const struct ggml_tensor * tensor);
331
 
332
  struct ggml_context * ggml_init(struct ggml_init_params params);
@@ -336,6 +343,9 @@ size_t ggml_used_mem(const struct ggml_context * ctx);
336
 
337
  size_t ggml_set_scratch(struct ggml_context * ctx, struct ggml_scratch scratch);
338
 
 
 
 
339
  struct ggml_tensor * ggml_new_tensor(
340
  struct ggml_context * ctx,
341
  enum ggml_type type,
@@ -466,12 +476,20 @@ struct ggml_tensor * ggml_gelu(
466
  struct ggml_context * ctx,
467
  struct ggml_tensor * a);
468
 
 
 
 
 
469
  // normalize along rows
470
  // TODO: eps is hardcoded to 1e-5 for now
471
  struct ggml_tensor * ggml_norm(
472
  struct ggml_context * ctx,
473
  struct ggml_tensor * a);
474
 
 
 
 
 
475
  // A: m rows, n columns
476
  // B: p rows, n columns (i.e. we transpose it internally)
477
  // result is m columns, p rows
@@ -726,6 +744,13 @@ enum ggml_opt_result ggml_opt(
726
  struct ggml_opt_params params,
727
  struct ggml_tensor * f);
728
 
 
 
 
 
 
 
 
729
  //
730
  // system info
731
  //
 
198
  struct ggml_context;
199
 
200
  enum ggml_type {
201
+ GGML_TYPE_Q4_0,
202
+ GGML_TYPE_Q4_1,
203
  GGML_TYPE_I8,
204
  GGML_TYPE_I16,
205
  GGML_TYPE_I32,
 
228
  GGML_OP_STEP,
229
  GGML_OP_RELU,
230
  GGML_OP_GELU,
231
+ GGML_OP_SILU,
232
  GGML_OP_NORM, // normalize
233
+ GGML_OP_RMS_NORM,
234
 
235
  GGML_OP_MUL_MAT,
236
 
 
330
  int ggml_nelements(const struct ggml_tensor * tensor);
331
  size_t ggml_nbytes (const struct ggml_tensor * tensor);
332
 
333
+ int ggml_blck_size (enum ggml_type type);
334
+ size_t ggml_type_size (enum ggml_type type); // size in bytes for all elements in a block
335
+ float ggml_type_sizef(enum ggml_type type); // ggml_type_size()/ggml_blck_size() as float
336
+
337
  size_t ggml_element_size(const struct ggml_tensor * tensor);
338
 
339
  struct ggml_context * ggml_init(struct ggml_init_params params);
 
343
 
344
  size_t ggml_set_scratch(struct ggml_context * ctx, struct ggml_scratch scratch);
345
 
346
+ bool ggml_mlock_supported(void);
347
+ bool ggml_mlock(struct ggml_context * ctx, char ** err_p);
348
+
349
  struct ggml_tensor * ggml_new_tensor(
350
  struct ggml_context * ctx,
351
  enum ggml_type type,
 
476
  struct ggml_context * ctx,
477
  struct ggml_tensor * a);
478
 
479
+ struct ggml_tensor * ggml_silu(
480
+ struct ggml_context * ctx,
481
+ struct ggml_tensor * a);
482
+
483
  // normalize along rows
484
  // TODO: eps is hardcoded to 1e-5 for now
485
  struct ggml_tensor * ggml_norm(
486
  struct ggml_context * ctx,
487
  struct ggml_tensor * a);
488
 
489
+ struct ggml_tensor * ggml_rms_norm(
490
+ struct ggml_context * ctx,
491
+ struct ggml_tensor * a);
492
+
493
  // A: m rows, n columns
494
  // B: p rows, n columns (i.e. we transpose it internally)
495
  // result is m columns, p rows
 
744
  struct ggml_opt_params params,
745
  struct ggml_tensor * f);
746
 
747
+ //
748
+ // quantization
749
+ //
750
+
751
+ size_t ggml_quantize_q4_0(const float * src, void * dst, int n, int k, int qk, int64_t * hist);
752
+ size_t ggml_quantize_q4_1(const float * src, void * dst, int n, int k, int qk, int64_t * hist);
753
+
754
  //
755
  // system info
756
  //
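
The new quantization helpers take a row-major float buffer of n elements with row length k, a block size qk and a 16-bin histogram, mirroring how llama_model_quantize_internal in examples/talk-llama/llama.cpp calls them. A minimal, untested sketch with purely illustrative sizes:

#include "ggml.h"

#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
    const int k  = 64;      // row length (number of columns), a multiple of qk
    const int n  = 4 * k;   // total number of elements (4 rows of 64)
    const int qk = 32;      // quantization block size (illustrative)

    std::vector<float>   src(n, 0.25f);             // dummy weights
    std::vector<uint8_t> dst(n * sizeof(float));    // upper bound on the quantized size
    std::vector<int64_t> hist(1 << 4, 0);           // 16 bins, one per 4-bit quant value

    const size_t size = ggml_quantize_q4_0(src.data(), dst.data(), n, k, qk, hist.data());
    printf("q4_0: %d floats -> %zu bytes\n", n, size);

    return 0;
}
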
whisper.cpp CHANGED
@@ -636,6 +636,8 @@ struct whisper_context {
636
  whisper_model model;
637
  whisper_vocab vocab;
638
  whisper_state * state = nullptr;
 
 
639
  };
640
 
641
  template<typename T>
@@ -1597,7 +1599,7 @@ static bool whisper_encode_internal(
1597
  ggml_repeat(ctx0, layer.mlp_ln_w, cur),
1598
  cur),
1599
  ggml_repeat(ctx0, layer.mlp_ln_b, cur));
1600
- }
1601
 
1602
  #ifdef WHISPER_USE_FLASH_FF
1603
  wstate.use_buf(ctx0, 0);
@@ -1637,7 +1639,7 @@ static bool whisper_encode_internal(
1637
  ggml_repeat(ctx0, layer.mlp_1_b, cur),
1638
  cur);
1639
  #endif
1640
- }
1641
 
1642
  wstate.use_buf(ctx0, 3);
1643
 
@@ -1841,8 +1843,6 @@ static bool whisper_decode_internal(
1841
 
1842
  // self-attention
1843
  {
1844
- wstate.use_buf(ctx0, 1);
1845
-
1846
  struct ggml_tensor * Qcur = ggml_mul_mat(ctx0,
1847
  layer.attn_q_w,
1848
  cur);
@@ -1904,8 +1904,6 @@ static bool whisper_decode_internal(
1904
  // K * Q
1905
  struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
1906
 
1907
- wstate.use_buf(ctx0, 0);
1908
-
1909
  //struct ggml_tensor * KQ_scaled =
1910
  // ggml_scale(ctx0,
1911
  // KQ,
@@ -1914,20 +1912,16 @@ static bool whisper_decode_internal(
1914
 
1915
  struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctx0, KQ, n_past);
1916
 
1917
- wstate.use_buf(ctx0, 1);
1918
-
1919
  struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ_masked);
1920
 
1921
- wstate.use_buf(ctx0, 0);
1922
-
1923
  struct ggml_tensor * V_trans =
1924
- ggml_permute(ctx0,
1925
- ggml_reshape_3d(ctx0,
1926
- ggml_view_1d(ctx0, kv_self.v, (n_past + N)*n_state, il*n_ctx*ggml_element_size(kv_self.v)*n_state),
1927
- n_state/n_head, n_head, n_past + N),
1928
- 1, 2, 0, 3);
1929
-
1930
- wstate.use_buf(ctx0, 1);
1931
 
1932
  struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V_trans, KQ_soft_max);
1933
 
@@ -1964,8 +1958,6 @@ static bool whisper_decode_internal(
1964
 
1965
  cur = ggml_norm(ctx0, inpCA); // note: we use inpCA here
1966
 
1967
- wstate.use_buf(ctx0, 1);
1968
-
1969
  // cur = ln_0_w*cur + ln_0_b
1970
  cur = ggml_add(ctx0,
1971
  ggml_mul(ctx0,
@@ -1976,8 +1968,6 @@ static bool whisper_decode_internal(
1976
 
1977
  // cross-attention
1978
  {
1979
- wstate.use_buf(ctx0, 0);
1980
-
1981
  struct ggml_tensor * Qcur = ggml_mul_mat(ctx0,
1982
  layer.cross_attn_q_w,
1983
  cur);
@@ -2001,12 +1991,13 @@ static bool whisper_decode_internal(
2001
  ggml_view_1d(ctx0, wstate.kv_cross.v, M*n_state, il*M*ggml_element_size(wstate.kv_cross.v)*n_state),
2002
  n_state/n_head, n_head, M);
2003
 
2004
- struct ggml_tensor * V_trans = ggml_permute(ctx0, Vcross, 1, 2, 0, 3);
 
 
 
2005
 
2006
  // ------
2007
 
2008
- wstate.use_buf(ctx0, 1);
2009
-
2010
  struct ggml_tensor * Q =
2011
  ggml_permute(ctx0,
2012
  ggml_cpy(ctx0,
@@ -2016,8 +2007,6 @@ static bool whisper_decode_internal(
2016
 
2017
  struct ggml_tensor * K = ggml_permute(ctx0, Kcross, 0, 2, 1, 3);
2018
 
2019
- wstate.use_buf(ctx0, 0);
2020
-
2021
  // K * Q
2022
  struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
2023
 
@@ -2030,16 +2019,10 @@ static bool whisper_decode_internal(
2030
  // no masking for cross-attention
2031
  //struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctx0, KQ_scaled, n_past);
2032
 
2033
- wstate.use_buf(ctx0, 1);
2034
-
2035
  struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ);
2036
 
2037
- wstate.use_buf(ctx0, 0);
2038
-
2039
  struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V_trans, KQ_soft_max);
2040
 
2041
- wstate.use_buf(ctx0, 1);
2042
-
2043
  struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
2044
 
2045
  // cur = KQV_merged.contiguous().view(n_state, N)
@@ -2482,7 +2465,6 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
2482
 
2483
  const size_t scale = ctx->model.hparams.f16 ? 1 : 2;
2484
 
2485
-
2486
  if (!kv_cache_init(ctx->model.hparams, scale * MEM_REQ_KV_SELF.at(ctx->model.type), state->decoders[0].kv_self, ctx->wtype, ctx->model.hparams.n_text_ctx)) {
2487
  fprintf(stderr, "%s: kv_cache_init() failed for self-attention cache\n", __func__);
2488
  return nullptr;
@@ -2503,7 +2485,6 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
2503
  fprintf(stderr, "%s: kv cross size = %7.2f MB\n", __func__, memory_size / 1024.0 / 1024.0);
2504
  }
2505
 
2506
-
2507
  state->logits.reserve(ctx->vocab.n_vocab * ctx->model.hparams.n_text_ctx);
2508
 
2509
  state->logits_id.reserve(ctx->model.hparams.n_vocab);
@@ -2554,7 +2535,13 @@ struct whisper_context * whisper_init_from_file_no_state(const char * path_model
2554
  fin->close();
2555
  };
2556
 
2557
- return whisper_init_no_state(&loader);
 
 
 
 
 
 
2558
  }
2559
 
2560
  struct whisper_context * whisper_init_from_buffer_no_state(void * buffer, size_t buffer_size) {
 
636
  whisper_model model;
637
  whisper_vocab vocab;
638
  whisper_state * state = nullptr;
639
+
640
+ std::string path_model; // populated by whisper_init_from_file()
641
  };
642
 
643
  template<typename T>
 
1599
  ggml_repeat(ctx0, layer.mlp_ln_w, cur),
1600
  cur),
1601
  ggml_repeat(ctx0, layer.mlp_ln_b, cur));
1602
+ }
1603
 
1604
  #ifdef WHISPER_USE_FLASH_FF
1605
  wstate.use_buf(ctx0, 0);
 
1639
  ggml_repeat(ctx0, layer.mlp_1_b, cur),
1640
  cur);
1641
  #endif
1642
+ }
1643
 
1644
  wstate.use_buf(ctx0, 3);
1645
 
 
1843
 
1844
  // self-attention
1845
  {
 
 
1846
  struct ggml_tensor * Qcur = ggml_mul_mat(ctx0,
1847
  layer.attn_q_w,
1848
  cur);
 
1904
  // K * Q
1905
  struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
1906
 
 
 
1907
  //struct ggml_tensor * KQ_scaled =
1908
  // ggml_scale(ctx0,
1909
  // KQ,
 
1912
 
1913
  struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctx0, KQ, n_past);
1914
 
 
 
1915
  struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ_masked);
1916
 
 
 
1917
  struct ggml_tensor * V_trans =
1918
+ ggml_cpy(ctx0,
1919
+ ggml_permute(ctx0,
1920
+ ggml_reshape_3d(ctx0,
1921
+ ggml_view_1d(ctx0, kv_self.v, (n_past + N)*n_state, il*n_ctx*ggml_element_size(kv_self.v)*n_state),
1922
+ n_state/n_head, n_head, n_past + N),
1923
+ 1, 2, 0, 3),
1924
+ ggml_new_tensor_3d(ctx0, kv_self.v->type, n_past + N, n_state/n_head, n_head));
1925
 
1926
  struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V_trans, KQ_soft_max);
1927
 
 
1958
 
1959
  cur = ggml_norm(ctx0, inpCA); // note: we use inpCA here
1960
 
 
 
1961
  // cur = ln_0_w*cur + ln_0_b
1962
  cur = ggml_add(ctx0,
1963
  ggml_mul(ctx0,
 
1968
 
1969
  // cross-attention
1970
  {
 
 
1971
  struct ggml_tensor * Qcur = ggml_mul_mat(ctx0,
1972
  layer.cross_attn_q_w,
1973
  cur);
 
1991
  ggml_view_1d(ctx0, wstate.kv_cross.v, M*n_state, il*M*ggml_element_size(wstate.kv_cross.v)*n_state),
1992
  n_state/n_head, n_head, M);
1993
 
1994
+ struct ggml_tensor * V_trans =
1995
+ ggml_cpy(ctx0,
1996
+ ggml_permute(ctx0, Vcross, 1, 2, 0, 3),
1997
+ ggml_new_tensor_3d(ctx0, Vcross->type, M, n_state/n_head, n_head));
1998
 
1999
  // ------
2000
 
 
 
2001
  struct ggml_tensor * Q =
2002
  ggml_permute(ctx0,
2003
  ggml_cpy(ctx0,
 
2007
 
2008
  struct ggml_tensor * K = ggml_permute(ctx0, Kcross, 0, 2, 1, 3);
2009
 
 
 
2010
  // K * Q
2011
  struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
2012
 
 
2019
  // no masking for cross-attention
2020
  //struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctx0, KQ_scaled, n_past);
2021
 
 
 
2022
  struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ);
2023
 
 
 
2024
  struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V_trans, KQ_soft_max);
2025
 
 
 
2026
  struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
2027
 
2028
  // cur = KQV_merged.contiguous().view(n_state, N)
 
2465
 
2466
  const size_t scale = ctx->model.hparams.f16 ? 1 : 2;
2467
 
 
2468
  if (!kv_cache_init(ctx->model.hparams, scale * MEM_REQ_KV_SELF.at(ctx->model.type), state->decoders[0].kv_self, ctx->wtype, ctx->model.hparams.n_text_ctx)) {
2469
  fprintf(stderr, "%s: kv_cache_init() failed for self-attention cache\n", __func__);
2470
  return nullptr;
 
2485
  fprintf(stderr, "%s: kv cross size = %7.2f MB\n", __func__, memory_size / 1024.0 / 1024.0);
2486
  }
2487
 
 
2488
  state->logits.reserve(ctx->vocab.n_vocab * ctx->model.hparams.n_text_ctx);
2489
 
2490
  state->logits_id.reserve(ctx->model.hparams.n_vocab);
 
2535
  fin->close();
2536
  };
2537
 
2538
+ auto ctx = whisper_init_no_state(&loader);
2539
+
2540
+ if (ctx) {
2541
+ ctx->path_model = path_model;
2542
+ }
2543
+
2544
+ return ctx;
2545
  }
2546
 
2547
  struct whisper_context * whisper_init_from_buffer_no_state(void * buffer, size_t buffer_size) {