ggerganov commited on
Commit
b462700
·
unverified ·
1 Parent(s): 89d94b1

talk-llama : sync llama.cpp (#2709)

Browse files
Files changed (36) hide show
  1. examples/talk-llama/CMakeLists.txt +17 -1
  2. examples/talk-llama/llama-adapter.cpp +334 -0
  3. examples/talk-llama/llama-adapter.h +66 -0
  4. examples/talk-llama/llama-arch.cpp +1434 -0
  5. examples/talk-llama/llama-arch.h +395 -0
  6. examples/talk-llama/llama-batch.cpp +368 -0
  7. examples/talk-llama/llama-batch.h +88 -0
  8. examples/talk-llama/llama-chat.cpp +567 -0
  9. examples/talk-llama/llama-chat.h +51 -0
  10. examples/talk-llama/llama-context.cpp +1771 -0
  11. examples/talk-llama/llama-context.h +128 -0
  12. examples/talk-llama/llama-cparams.cpp +1 -0
  13. examples/talk-llama/llama-cparams.h +37 -0
  14. examples/talk-llama/llama-grammar.cpp +16 -15
  15. examples/talk-llama/llama-grammar.h +5 -6
  16. examples/talk-llama/llama-hparams.cpp +71 -0
  17. examples/talk-llama/llama-hparams.h +140 -0
  18. examples/talk-llama/llama-impl.cpp +166 -0
  19. examples/talk-llama/llama-impl.h +16 -136
  20. examples/talk-llama/llama-kv-cache.cpp +718 -0
  21. examples/talk-llama/llama-kv-cache.h +218 -0
  22. examples/talk-llama/llama-mmap.cpp +589 -0
  23. examples/talk-llama/llama-mmap.h +67 -0
  24. examples/talk-llama/llama-model-loader.cpp +1010 -0
  25. examples/talk-llama/llama-model-loader.h +158 -0
  26. examples/talk-llama/llama-model.cpp +0 -0
  27. examples/talk-llama/llama-model.h +391 -0
  28. examples/talk-llama/llama-quant.cpp +929 -0
  29. examples/talk-llama/llama-quant.h +1 -0
  30. examples/talk-llama/llama-sampling.cpp +117 -4
  31. examples/talk-llama/llama-vocab.cpp +26 -29
  32. examples/talk-llama/llama-vocab.h +14 -2
  33. examples/talk-llama/llama.cpp +0 -0
  34. examples/talk-llama/llama.h +31 -9
  35. examples/talk-llama/talk-llama.cpp +1 -1
  36. examples/talk-llama/unicode.cpp +6 -0
examples/talk-llama/CMakeLists.txt CHANGED
@@ -1,10 +1,26 @@
1
  if (WHISPER_SDL2)
 
 
 
2
  set(TARGET whisper-talk-llama)
3
  add_executable(${TARGET} talk-llama.cpp
4
  llama.cpp
5
- llama-vocab.cpp
 
 
 
 
 
6
  llama-grammar.cpp
 
 
 
 
 
 
 
7
  llama-sampling.cpp
 
8
  unicode.cpp
9
  unicode-data.cpp)
10
  target_include_directories(${TARGET} PRIVATE ${SDL2_INCLUDE_DIRS})
 
1
  if (WHISPER_SDL2)
2
+ set(CMAKE_CXX_STANDARD 17)
3
+ set(CMAKE_CXX_STANDARD_REQUIRED ON)
4
+
5
  set(TARGET whisper-talk-llama)
6
  add_executable(${TARGET} talk-llama.cpp
7
  llama.cpp
8
+ llama-adapter.cpp
9
+ llama-arch.cpp
10
+ llama-batch.cpp
11
+ llama-chat.cpp
12
+ llama-context.cpp
13
+ llama-cparams.cpp
14
  llama-grammar.cpp
15
+ llama-hparams.cpp
16
+ llama-impl.cpp
17
+ llama-kv-cache.cpp
18
+ llama-mmap.cpp
19
+ llama-model-loader.cpp
20
+ llama-model.cpp
21
+ llama-quant.cpp
22
  llama-sampling.cpp
23
+ llama-vocab.cpp
24
  unicode.cpp
25
  unicode-data.cpp)
26
  target_include_directories(${TARGET} PRIVATE ${SDL2_INCLUDE_DIRS})
examples/talk-llama/llama-adapter.cpp ADDED
@@ -0,0 +1,334 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #include "llama-adapter.h"
2
+
3
+ #include "llama-model.h"
4
+
5
+ #include <algorithm>
6
+ #include <map>
7
+ #include <cassert>
8
+ #include <stdexcept>
9
+
10
+ // vec
11
+
12
+ struct ggml_tensor * llama_control_vector::tensor_for(int il) const {
13
+ if (il < 0 || il < layer_start || il > layer_end || (size_t) il >= tensors.size()) {
14
+ return nullptr;
15
+ }
16
+
17
+ return tensors[il];
18
+ }
19
+
20
+ struct ggml_tensor * llama_control_vector::apply_to(struct ggml_context * ctx, struct ggml_tensor * cur, int il) const {
21
+ ggml_tensor * layer_dir = tensor_for(il);
22
+ if (layer_dir != nullptr) {
23
+ cur = ggml_add(ctx, cur, layer_dir);
24
+ }
25
+
26
+ return cur;
27
+ }
28
+
29
+ static bool llama_control_vector_init(struct llama_control_vector & cvec, const llama_model & model) {
30
+ const auto & hparams = model.hparams;
31
+
32
+ GGML_ASSERT(cvec.tensors.empty());
33
+ GGML_ASSERT(cvec.ctxs.empty());
34
+ GGML_ASSERT(cvec.bufs.empty());
35
+
36
+ // create a context for each buffer type
37
+ std::map<ggml_backend_buffer_type_t, ggml_context *> ctx_map;
38
+ auto ctx_for_buft = [&](ggml_backend_buffer_type_t buft) -> ggml_context * {
39
+ auto it = ctx_map.find(buft);
40
+ if (it == ctx_map.end()) {
41
+ struct ggml_init_params params = {
42
+ /*.mem_size =*/ hparams.n_layer*ggml_tensor_overhead(),
43
+ /*.mem_buffer =*/ NULL,
44
+ /*.no_alloc =*/ true,
45
+ };
46
+
47
+ ggml_context * ctx = ggml_init(params);
48
+ if (!ctx) {
49
+ return nullptr;
50
+ }
51
+
52
+ ctx_map[buft] = ctx;
53
+ cvec.ctxs.emplace_back(ctx);
54
+
55
+ return ctx;
56
+ }
57
+
58
+ return it->second;
59
+ };
60
+
61
+ // make tensors
62
+ cvec.tensors.reserve(hparams.n_layer);
63
+ cvec.tensors.push_back(nullptr); // there's never a tensor for layer 0
64
+ for (size_t il = 1; il < hparams.n_layer; il++) {
65
+ ggml_backend_buffer_type_t buft = llama_model_select_buft(model, il);
66
+ ggml_context * ctx = ctx_for_buft(buft);
67
+ if (!ctx) {
68
+ LLAMA_LOG_ERROR("%s: failed to allocate context for control vector\n", __func__);
69
+ return false;
70
+ }
71
+ ggml_tensor * tensor = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hparams.n_embd);
72
+ cvec.tensors.push_back(tensor);
73
+ }
74
+
75
+ // allocate tensors / buffers and zero
76
+ cvec.bufs.reserve(ctx_map.size());
77
+ for (auto it : ctx_map) {
78
+ ggml_backend_buffer_type_t buft = it.first;
79
+ ggml_context * ctx = it.second;
80
+ ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft);
81
+ if (!buf) {
82
+ LLAMA_LOG_ERROR("%s: failed to allocate buffer for control vector\n", __func__);
83
+ return false;
84
+ }
85
+ ggml_backend_buffer_clear(buf, 0);
86
+ cvec.bufs.emplace_back(buf);
87
+ }
88
+
89
+ return true;
90
+ }
91
+
92
+ int32_t llama_control_vector_apply(
93
+ struct llama_control_vector & cvec,
94
+ const llama_model & model,
95
+ const float * data,
96
+ size_t len,
97
+ int32_t n_embd,
98
+ int32_t il_start,
99
+ int32_t il_end) {
100
+ const auto & hparams = model.hparams;
101
+
102
+ if (data == nullptr) {
103
+ // disable the current control vector (but leave allocated for later)
104
+ cvec.layer_start = -1;
105
+ cvec.layer_end = -1;
106
+ return 0;
107
+ }
108
+
109
+ if (n_embd != (int) hparams.n_embd) {
110
+ LLAMA_LOG_ERROR("%s: control vector n_embd does not match model\n", __func__);
111
+ return 1;
112
+ }
113
+
114
+ if (cvec.tensors.empty()) {
115
+ if (!llama_control_vector_init(cvec, model)) {
116
+ return 1;
117
+ }
118
+ }
119
+
120
+ cvec.layer_start = il_start;
121
+ cvec.layer_end = il_end;
122
+
123
+ for (size_t il = 1; il < hparams.n_layer; il++) {
124
+ assert(cvec.tensors[il] != nullptr);
125
+
126
+ const size_t off = n_embd * (il - 1); // buffer doesn't have data for layer 0, since it's never present
127
+ if (off + n_embd <= len) {
128
+ ggml_backend_tensor_set(cvec.tensors[il], data + off, 0, n_embd * ggml_element_size(cvec.tensors[il]));
129
+ }
130
+ }
131
+
132
+ return 0;
133
+ }
134
+
135
+ // lora
136
+
137
+ llama_lora_weight * llama_lora_adapter::get_weight(struct ggml_tensor * w) {
138
+ const std::string name(w->name);
139
+
140
+ const auto pos = ab_map.find(name);
141
+ if (pos != ab_map.end()) {
142
+ return &pos->second;
143
+ }
144
+
145
+ return nullptr;
146
+ }
147
+
148
+ void llama_lora_adapter_free(struct llama_lora_adapter * adapter) {
149
+ delete adapter;
150
+ }
151
+
152
+ static void llama_lora_adapter_init_impl(struct llama_model & model, const char * path_lora, struct llama_lora_adapter & adapter) {
153
+ LLAMA_LOG_INFO("%s: loading lora adapter from '%s' ...\n", __func__, path_lora);
154
+
155
+ ggml_context * ctx_init;
156
+ struct gguf_init_params meta_gguf_params = {
157
+ /* .no_alloc = */ true,
158
+ /* .ctx = */ &ctx_init,
159
+ };
160
+
161
+ gguf_context_ptr ctx_gguf { gguf_init_from_file(path_lora, meta_gguf_params) };
162
+ if (!ctx_gguf) {
163
+ throw std::runtime_error("failed to load lora adapter file from " + std::string(path_lora));
164
+ }
165
+
166
+ ggml_context_ptr ctx { ctx_init };
167
+
168
+ // check metadata
169
+ {
170
+ auto get_kv_str = [&](const std::string & key) -> std::string {
171
+ int id = gguf_find_key(ctx_gguf.get(), key.c_str());
172
+ return id < 0 ? "" : std::string(gguf_get_val_str(ctx_gguf.get(), id));
173
+ };
174
+ auto get_kv_f32 = [&](const std::string & key) -> float {
175
+ int id = gguf_find_key(ctx_gguf.get(), key.c_str());
176
+ return id < 0 ? 0.0f : gguf_get_val_f32(ctx_gguf.get(), id);
177
+ };
178
+ LLM_KV llm_kv = LLM_KV(LLM_ARCH_UNKNOWN);
179
+
180
+ auto general_type = get_kv_str(llm_kv(LLM_KV_GENERAL_TYPE));
181
+ if (general_type != "adapter") {
182
+ throw std::runtime_error("expect general.type to be 'adapter', but got: " + general_type);
183
+ }
184
+
185
+ auto general_arch_str = get_kv_str(llm_kv(LLM_KV_GENERAL_ARCHITECTURE));
186
+ auto general_arch = llm_arch_from_string(general_arch_str);
187
+ if (general_arch != model.arch) {
188
+ throw std::runtime_error("model arch and LoRA arch mismatch");
189
+ }
190
+
191
+ auto adapter_type = get_kv_str(llm_kv(LLM_KV_ADAPTER_TYPE));
192
+ if (adapter_type != "lora") {
193
+ throw std::runtime_error("expect adapter.type to be 'lora', but got: " + adapter_type);
194
+ }
195
+
196
+ adapter.alpha = get_kv_f32(llm_kv(LLM_KV_ADAPTER_LORA_ALPHA));
197
+ }
198
+
199
+ int n_tensors = gguf_get_n_tensors(ctx_gguf.get());
200
+
201
+ // contexts for each buffer type
202
+ std::map<ggml_backend_buffer_type_t, ggml_context *> ctx_map;
203
+ auto ctx_for_buft = [&](ggml_backend_buffer_type_t buft) -> ggml_context * {
204
+ auto it = ctx_map.find(buft);
205
+ if (it == ctx_map.end()) {
206
+ // add a new context
207
+ struct ggml_init_params params = {
208
+ /*.mem_size =*/ n_tensors*ggml_tensor_overhead(),
209
+ /*.mem_buffer =*/ NULL,
210
+ /*.no_alloc =*/ true,
211
+ };
212
+ ggml_context * buft_ctx = ggml_init(params);
213
+ if (!buft_ctx) {
214
+ return nullptr;
215
+ }
216
+ ctx_map[buft] = buft_ctx;
217
+ adapter.ctxs.emplace_back(buft_ctx);
218
+ return buft_ctx;
219
+ };
220
+ return it->second;
221
+ };
222
+
223
+ // bundle lora_a and lora_b into pairs
224
+ std::map<std::string, llama_lora_weight> ab_map;
225
+ auto str_endswith = [](const std::string & str, const std::string & suffix) {
226
+ return str.size() >= suffix.size() && str.compare(str.size()-suffix.size(), suffix.size(), suffix) == 0;
227
+ };
228
+
229
+ for (ggml_tensor * cur = ggml_get_first_tensor(ctx.get()); cur; cur = ggml_get_next_tensor(ctx.get(), cur)) {
230
+ std::string name(cur->name);
231
+ if (str_endswith(name, ".lora_a")) {
232
+ replace_all(name, ".lora_a", "");
233
+ if (ab_map.find(name) == ab_map.end()) {
234
+ ab_map[name] = llama_lora_weight(cur, nullptr);
235
+ } else {
236
+ ab_map[name].a = cur;
237
+ }
238
+ } else if (str_endswith(name, ".lora_b")) {
239
+ replace_all(name, ".lora_b", "");
240
+ if (ab_map.find(name) == ab_map.end()) {
241
+ ab_map[name] = llama_lora_weight(nullptr, cur);
242
+ } else {
243
+ ab_map[name].b = cur;
244
+ }
245
+ } else {
246
+ throw std::runtime_error("LoRA tensor '" + name + "' has unexpected suffix");
247
+ }
248
+ }
249
+
250
+ // add tensors
251
+ for (auto & it : ab_map) {
252
+ const std::string & name = it.first;
253
+ llama_lora_weight & w = it.second;
254
+
255
+ if (!w.a || !w.b) {
256
+ throw std::runtime_error("LoRA tensor pair for '" + name + "' is missing one component");
257
+ }
258
+
259
+ // device buft and device ctx
260
+ auto * model_tensor = llama_model_get_tensor(model, name.c_str());
261
+ if (!model_tensor) {
262
+ throw std::runtime_error("LoRA tensor '" + name + "' does not exist in base model");
263
+ }
264
+
265
+ struct ggml_context * dev_ctx = ctx_for_buft(ggml_backend_buffer_get_type(model_tensor->buffer));
266
+ // validate tensor shape
267
+ if (model_tensor->ne[0] != w.a->ne[0] || model_tensor->ne[1] != w.b->ne[1]) {
268
+ throw std::runtime_error("tensor '" + name + "' has incorrect shape");
269
+ }
270
+ if (w.a->ne[1] != w.b->ne[0]) {
271
+ throw std::runtime_error("lora_a tensor is not transposed (hint: adapter from \"finetune\" example is no longer supported)");
272
+ }
273
+
274
+ // save tensor to adapter
275
+ struct ggml_tensor * tensor_a = ggml_dup_tensor(dev_ctx, w.a);
276
+ struct ggml_tensor * tensor_b = ggml_dup_tensor(dev_ctx, w.b);
277
+ ggml_set_name(tensor_a, w.a->name);
278
+ ggml_set_name(tensor_b, w.b->name);
279
+ adapter.ab_map[name] = llama_lora_weight(tensor_a, tensor_b);
280
+ }
281
+
282
+ // allocate tensors / buffers and zero
283
+ {
284
+ adapter.ctxs.reserve(ctx_map.size());
285
+ adapter.bufs.reserve(ctx_map.size());
286
+ for (auto & it : ctx_map) {
287
+ ggml_backend_buffer_type_t buft = it.first;
288
+ ggml_context * ctx_dev = it.second;
289
+ ggml_backend_buffer_ptr buf { ggml_backend_alloc_ctx_tensors_from_buft(ctx_dev, buft) };
290
+ if (!buf) {
291
+ throw std::runtime_error("failed to allocate buffer for lora adapter\n");
292
+ }
293
+ LLAMA_LOG_INFO("%s: %10s LoRA buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf.get()), ggml_backend_buffer_get_size(buf.get())/1024.0/1024.0);
294
+ adapter.bufs.emplace_back(std::move(buf));
295
+ }
296
+ }
297
+
298
+ // set tensor data
299
+ {
300
+ llama_file gguf_file(path_lora, "rb");
301
+ std::vector<uint8_t> read_buf;
302
+ auto set_tensor = [&](struct ggml_tensor * orig, struct ggml_tensor * dev) {
303
+ size_t offs = gguf_get_data_offset(ctx_gguf.get()) + gguf_get_tensor_offset(ctx_gguf.get(), gguf_find_tensor(ctx_gguf.get(), orig->name));
304
+ size_t size = ggml_nbytes(orig);
305
+ read_buf.resize(size);
306
+ gguf_file.seek(offs, SEEK_SET);
307
+ gguf_file.read_raw(read_buf.data(), size);
308
+ ggml_backend_tensor_set(dev, read_buf.data(), 0, size);
309
+ };
310
+ for (auto & it : adapter.ab_map) {
311
+ auto orig = ab_map[it.first];
312
+ auto dev = it.second;
313
+ set_tensor(orig.a, dev.a);
314
+ set_tensor(orig.b, dev.b);
315
+ }
316
+ }
317
+
318
+ LLAMA_LOG_INFO("%s: loaded %zu tensors from lora file\n", __func__, adapter.ab_map.size()*2);
319
+ }
320
+
321
+ struct llama_lora_adapter * llama_lora_adapter_init(struct llama_model * model, const char * path_lora) {
322
+ struct llama_lora_adapter * adapter = new llama_lora_adapter();
323
+
324
+ try {
325
+ llama_lora_adapter_init_impl(*model, path_lora, *adapter);
326
+ return adapter;
327
+ } catch (const std::exception & err) {
328
+ LLAMA_LOG_ERROR("%s: failed to apply lora adapter: %s\n", __func__, err.what());
329
+
330
+ delete adapter;
331
+ }
332
+
333
+ return nullptr;
334
+ }
examples/talk-llama/llama-adapter.h ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include "llama-impl.h"
4
+ #include "llama-hparams.h"
5
+
6
+ #include "ggml-cpp.h"
7
+
8
+ #include <unordered_map>
9
+ #include <vector>
10
+
11
+ //
12
+ // llama_adapter_cvec
13
+ //
14
+
15
+ // TODO: rename to llama_adapter_cvec
16
+ struct llama_control_vector {
17
+ std::vector<ggml_context_ptr> ctxs;
18
+ std::vector<ggml_backend_buffer_ptr> bufs;
19
+
20
+ std::vector<struct ggml_tensor *> tensors; // per layer
21
+
22
+ int32_t layer_start = -1;
23
+ int32_t layer_end = -1;
24
+
25
+ struct ggml_tensor * tensor_for(int il) const;
26
+
27
+ struct ggml_tensor * apply_to(struct ggml_context * ctx, struct ggml_tensor * cur, int il) const;
28
+ };
29
+
30
+ int32_t llama_control_vector_apply(
31
+ struct llama_control_vector & cvec,
32
+ const llama_model & model,
33
+ const float * data,
34
+ size_t len,
35
+ int32_t n_embd,
36
+ int32_t il_start,
37
+ int32_t il_end);
38
+
39
+ //
40
+ // llama_adapter_lora
41
+ //
42
+
43
+ // TODO: rename to llama_adapter_lora_weight
44
+ struct llama_lora_weight {
45
+ struct ggml_tensor * a = nullptr;
46
+ struct ggml_tensor * b = nullptr;
47
+
48
+ llama_lora_weight() = default;
49
+ llama_lora_weight(struct ggml_tensor * a, struct ggml_tensor * b) : a(a), b(b) {}
50
+ };
51
+
52
+ // TODO: rename to llama_adapter_lora
53
+ struct llama_lora_adapter {
54
+ // map tensor name to lora_a_b
55
+ std::unordered_map<std::string, struct llama_lora_weight> ab_map;
56
+
57
+ std::vector<ggml_context_ptr> ctxs;
58
+ std::vector<ggml_backend_buffer_ptr> bufs;
59
+
60
+ float alpha;
61
+
62
+ llama_lora_adapter() = default;
63
+ ~llama_lora_adapter() = default;
64
+
65
+ llama_lora_weight * get_weight(struct ggml_tensor * w);
66
+ };
examples/talk-llama/llama-arch.cpp ADDED
@@ -0,0 +1,1434 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #include "llama-arch.h"
2
+
3
+ #include "llama-impl.h"
4
+
5
+ #include <map>
6
+
7
+ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
8
+ { LLM_ARCH_LLAMA, "llama" },
9
+ { LLM_ARCH_DECI, "deci" },
10
+ { LLM_ARCH_FALCON, "falcon" },
11
+ { LLM_ARCH_GROK, "grok" },
12
+ { LLM_ARCH_GPT2, "gpt2" },
13
+ { LLM_ARCH_GPTJ, "gptj" },
14
+ { LLM_ARCH_GPTNEOX, "gptneox" },
15
+ { LLM_ARCH_MPT, "mpt" },
16
+ { LLM_ARCH_BAICHUAN, "baichuan" },
17
+ { LLM_ARCH_STARCODER, "starcoder" },
18
+ { LLM_ARCH_REFACT, "refact" },
19
+ { LLM_ARCH_BERT, "bert" },
20
+ { LLM_ARCH_NOMIC_BERT, "nomic-bert" },
21
+ { LLM_ARCH_JINA_BERT_V2, "jina-bert-v2" },
22
+ { LLM_ARCH_BLOOM, "bloom" },
23
+ { LLM_ARCH_STABLELM, "stablelm" },
24
+ { LLM_ARCH_QWEN, "qwen" },
25
+ { LLM_ARCH_QWEN2, "qwen2" },
26
+ { LLM_ARCH_QWEN2MOE, "qwen2moe" },
27
+ { LLM_ARCH_QWEN2VL, "qwen2vl" },
28
+ { LLM_ARCH_PHI2, "phi2" },
29
+ { LLM_ARCH_PHI3, "phi3" },
30
+ { LLM_ARCH_PLAMO, "plamo" },
31
+ { LLM_ARCH_CODESHELL, "codeshell" },
32
+ { LLM_ARCH_ORION, "orion" },
33
+ { LLM_ARCH_INTERNLM2, "internlm2" },
34
+ { LLM_ARCH_MINICPM, "minicpm" },
35
+ { LLM_ARCH_MINICPM3, "minicpm3" },
36
+ { LLM_ARCH_GEMMA, "gemma" },
37
+ { LLM_ARCH_GEMMA2, "gemma2" },
38
+ { LLM_ARCH_STARCODER2, "starcoder2" },
39
+ { LLM_ARCH_MAMBA, "mamba" },
40
+ { LLM_ARCH_XVERSE, "xverse" },
41
+ { LLM_ARCH_COMMAND_R, "command-r" },
42
+ { LLM_ARCH_COHERE2, "cohere2" },
43
+ { LLM_ARCH_DBRX, "dbrx" },
44
+ { LLM_ARCH_OLMO, "olmo" },
45
+ { LLM_ARCH_OLMO2, "olmo2" },
46
+ { LLM_ARCH_OLMOE, "olmoe" },
47
+ { LLM_ARCH_OPENELM, "openelm" },
48
+ { LLM_ARCH_ARCTIC, "arctic" },
49
+ { LLM_ARCH_DEEPSEEK, "deepseek" },
50
+ { LLM_ARCH_DEEPSEEK2, "deepseek2" },
51
+ { LLM_ARCH_CHATGLM, "chatglm" },
52
+ { LLM_ARCH_BITNET, "bitnet" },
53
+ { LLM_ARCH_T5, "t5" },
54
+ { LLM_ARCH_T5ENCODER, "t5encoder" },
55
+ { LLM_ARCH_JAIS, "jais" },
56
+ { LLM_ARCH_NEMOTRON, "nemotron" },
57
+ { LLM_ARCH_EXAONE, "exaone" },
58
+ { LLM_ARCH_RWKV6, "rwkv6" },
59
+ { LLM_ARCH_GRANITE, "granite" },
60
+ { LLM_ARCH_GRANITE_MOE, "granitemoe" },
61
+ { LLM_ARCH_CHAMELEON, "chameleon" },
62
+ { LLM_ARCH_WAVTOKENIZER_DEC, "wavtokenizer-dec" },
63
+ { LLM_ARCH_UNKNOWN, "(unknown)" },
64
+ };
65
+
66
+ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
67
+ { LLM_KV_GENERAL_TYPE, "general.type" },
68
+ { LLM_KV_GENERAL_ARCHITECTURE, "general.architecture" },
69
+ { LLM_KV_GENERAL_QUANTIZATION_VERSION, "general.quantization_version" },
70
+ { LLM_KV_GENERAL_ALIGNMENT, "general.alignment" },
71
+ { LLM_KV_GENERAL_NAME, "general.name" },
72
+ { LLM_KV_GENERAL_AUTHOR, "general.author" },
73
+ { LLM_KV_GENERAL_VERSION, "general.version" },
74
+ { LLM_KV_GENERAL_URL, "general.url" },
75
+ { LLM_KV_GENERAL_DESCRIPTION, "general.description" },
76
+ { LLM_KV_GENERAL_LICENSE, "general.license" },
77
+ { LLM_KV_GENERAL_SOURCE_URL, "general.source.url" },
78
+ { LLM_KV_GENERAL_SOURCE_HF_REPO, "general.source.huggingface.repository" },
79
+
80
+ { LLM_KV_VOCAB_SIZE, "%s.vocab_size" },
81
+ { LLM_KV_CONTEXT_LENGTH, "%s.context_length" },
82
+ { LLM_KV_EMBEDDING_LENGTH, "%s.embedding_length" },
83
+ { LLM_KV_FEATURES_LENGTH, "%s.features_length" },
84
+ { LLM_KV_BLOCK_COUNT, "%s.block_count" },
85
+ { LLM_KV_LEADING_DENSE_BLOCK_COUNT, "%s.leading_dense_block_count" },
86
+ { LLM_KV_FEED_FORWARD_LENGTH, "%s.feed_forward_length" },
87
+ { LLM_KV_EXPERT_FEED_FORWARD_LENGTH, "%s.expert_feed_forward_length" },
88
+ { LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, "%s.expert_shared_feed_forward_length" },
89
+ { LLM_KV_USE_PARALLEL_RESIDUAL, "%s.use_parallel_residual" },
90
+ { LLM_KV_TENSOR_DATA_LAYOUT, "%s.tensor_data_layout" },
91
+ { LLM_KV_EXPERT_COUNT, "%s.expert_count" },
92
+ { LLM_KV_EXPERT_USED_COUNT, "%s.expert_used_count" },
93
+ { LLM_KV_EXPERT_SHARED_COUNT, "%s.expert_shared_count" },
94
+ { LLM_KV_EXPERT_WEIGHTS_SCALE, "%s.expert_weights_scale" },
95
+ { LLM_KV_EXPERT_WEIGHTS_NORM, "%s.expert_weights_norm" },
96
+ { LLM_KV_EXPERT_GATING_FUNC, "%s.expert_gating_func" },
97
+ { LLM_KV_POOLING_TYPE, "%s.pooling_type" },
98
+ { LLM_KV_LOGIT_SCALE, "%s.logit_scale" },
99
+ { LLM_KV_DECODER_START_TOKEN_ID, "%s.decoder_start_token_id" },
100
+ { LLM_KV_ATTN_LOGIT_SOFTCAPPING, "%s.attn_logit_softcapping" },
101
+ { LLM_KV_FINAL_LOGIT_SOFTCAPPING, "%s.final_logit_softcapping" },
102
+ { LLM_KV_SWIN_NORM, "%s.swin_norm" },
103
+ { LLM_KV_RESCALE_EVERY_N_LAYERS, "%s.rescale_every_n_layers" },
104
+ { LLM_KV_TIME_MIX_EXTRA_DIM, "%s.time_mix_extra_dim" },
105
+ { LLM_KV_TIME_DECAY_EXTRA_DIM, "%s.time_decay_extra_dim" },
106
+ { LLM_KV_RESIDUAL_SCALE, "%s.residual_scale" },
107
+ { LLM_KV_EMBEDDING_SCALE, "%s.embedding_scale" },
108
+
109
+ { LLM_KV_ATTENTION_HEAD_COUNT, "%s.attention.head_count" },
110
+ { LLM_KV_ATTENTION_HEAD_COUNT_KV, "%s.attention.head_count_kv" },
111
+ { LLM_KV_ATTENTION_MAX_ALIBI_BIAS, "%s.attention.max_alibi_bias" },
112
+ { LLM_KV_ATTENTION_CLAMP_KQV, "%s.attention.clamp_kqv" },
113
+ { LLM_KV_ATTENTION_KEY_LENGTH, "%s.attention.key_length" },
114
+ { LLM_KV_ATTENTION_VALUE_LENGTH, "%s.attention.value_length" },
115
+ { LLM_KV_ATTENTION_LAYERNORM_EPS, "%s.attention.layer_norm_epsilon" },
116
+ { LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, "%s.attention.layer_norm_rms_epsilon" },
117
+ { LLM_KV_ATTENTION_GROUPNORM_EPS, "%s.attention.group_norm_epsilon" },
118
+ { LLM_KV_ATTENTION_GROUPNORM_GROUPS, "%s.attention.group_norm_groups" },
119
+ { LLM_KV_ATTENTION_CAUSAL, "%s.attention.causal" },
120
+ { LLM_KV_ATTENTION_Q_LORA_RANK, "%s.attention.q_lora_rank" },
121
+ { LLM_KV_ATTENTION_KV_LORA_RANK, "%s.attention.kv_lora_rank" },
122
+ { LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT, "%s.attention.relative_buckets_count" },
123
+ { LLM_KV_ATTENTION_SLIDING_WINDOW, "%s.attention.sliding_window" },
124
+ { LLM_KV_ATTENTION_SCALE, "%s.attention.scale" },
125
+
126
+ { LLM_KV_ROPE_DIMENSION_COUNT, "%s.rope.dimension_count" },
127
+ { LLM_KV_ROPE_DIMENSION_SECTIONS, "%s.rope.dimension_sections" },
128
+ { LLM_KV_ROPE_FREQ_BASE, "%s.rope.freq_base" },
129
+ { LLM_KV_ROPE_SCALE_LINEAR, "%s.rope.scale_linear" },
130
+ { LLM_KV_ROPE_SCALING_TYPE, "%s.rope.scaling.type" },
131
+ { LLM_KV_ROPE_SCALING_FACTOR, "%s.rope.scaling.factor" },
132
+ { LLM_KV_ROPE_SCALING_ATTN_FACTOR, "%s.rope.scaling.attn_factor" },
133
+ { LLM_KV_ROPE_SCALING_ORIG_CTX_LEN, "%s.rope.scaling.original_context_length" },
134
+ { LLM_KV_ROPE_SCALING_FINETUNED, "%s.rope.scaling.finetuned" },
135
+ { LLM_KV_ROPE_SCALING_YARN_LOG_MUL, "%s.rope.scaling.yarn_log_multiplier" },
136
+
137
+ { LLM_KV_SPLIT_NO, "split.no" },
138
+ { LLM_KV_SPLIT_COUNT, "split.count" },
139
+ { LLM_KV_SPLIT_TENSORS_COUNT, "split.tensors.count" },
140
+
141
+ { LLM_KV_SSM_CONV_KERNEL, "%s.ssm.conv_kernel" },
142
+ { LLM_KV_SSM_INNER_SIZE, "%s.ssm.inner_size" },
143
+ { LLM_KV_SSM_STATE_SIZE, "%s.ssm.state_size" },
144
+ { LLM_KV_SSM_TIME_STEP_RANK, "%s.ssm.time_step_rank" },
145
+ { LLM_KV_SSM_DT_B_C_RMS, "%s.ssm.dt_b_c_rms" },
146
+
147
+ { LLM_KV_WKV_HEAD_SIZE, "%s.wkv.head_size" },
148
+
149
+ { LLM_KV_POSNET_EMBEDDING_LENGTH, "%s.posnet.embedding_length" },
150
+ { LLM_KV_POSNET_BLOCK_COUNT, "%s.posnet.block_count" },
151
+
152
+ { LLM_KV_CONVNEXT_EMBEDDING_LENGTH, "%s.convnext.embedding_length" },
153
+ { LLM_KV_CONVNEXT_BLOCK_COUNT, "%s.convnext.block_count" },
154
+
155
+ { LLM_KV_TOKENIZER_MODEL, "tokenizer.ggml.model" },
156
+ { LLM_KV_TOKENIZER_PRE, "tokenizer.ggml.pre" },
157
+ { LLM_KV_TOKENIZER_LIST, "tokenizer.ggml.tokens" },
158
+ { LLM_KV_TOKENIZER_TOKEN_TYPE, "tokenizer.ggml.token_type" },
159
+ { LLM_KV_TOKENIZER_TOKEN_TYPE_COUNT, "tokenizer.ggml.token_type_count" },
160
+ { LLM_KV_TOKENIZER_SCORES, "tokenizer.ggml.scores" },
161
+ { LLM_KV_TOKENIZER_MERGES, "tokenizer.ggml.merges" },
162
+ { LLM_KV_TOKENIZER_BOS_ID, "tokenizer.ggml.bos_token_id" },
163
+ { LLM_KV_TOKENIZER_EOS_ID, "tokenizer.ggml.eos_token_id" },
164
+ { LLM_KV_TOKENIZER_EOT_ID, "tokenizer.ggml.eot_token_id" },
165
+ { LLM_KV_TOKENIZER_EOM_ID, "tokenizer.ggml.eom_token_id" },
166
+ { LLM_KV_TOKENIZER_UNK_ID, "tokenizer.ggml.unknown_token_id" },
167
+ { LLM_KV_TOKENIZER_SEP_ID, "tokenizer.ggml.seperator_token_id" },
168
+ { LLM_KV_TOKENIZER_PAD_ID, "tokenizer.ggml.padding_token_id" },
169
+ { LLM_KV_TOKENIZER_CLS_ID, "tokenizer.ggml.cls_token_id" },
170
+ { LLM_KV_TOKENIZER_MASK_ID, "tokenizer.ggml.mask_token_id" },
171
+ { LLM_KV_TOKENIZER_ADD_BOS, "tokenizer.ggml.add_bos_token" },
172
+ { LLM_KV_TOKENIZER_ADD_EOS, "tokenizer.ggml.add_eos_token" },
173
+ { LLM_KV_TOKENIZER_ADD_PREFIX, "tokenizer.ggml.add_space_prefix" },
174
+ { LLM_KV_TOKENIZER_REMOVE_EXTRA_WS, "tokenizer.ggml.remove_extra_whitespaces" },
175
+ { LLM_KV_TOKENIZER_PRECOMPILED_CHARSMAP, "tokenizer.ggml.precompiled_charsmap" },
176
+ { LLM_KV_TOKENIZER_HF_JSON, "tokenizer.huggingface.json" },
177
+ { LLM_KV_TOKENIZER_RWKV, "tokenizer.rwkv.world" },
178
+ { LLM_KV_TOKENIZER_FIM_PRE_ID, "tokenizer.ggml.fim_pre_token_id" },
179
+ { LLM_KV_TOKENIZER_FIM_SUF_ID, "tokenizer.ggml.fim_suf_token_id" },
180
+ { LLM_KV_TOKENIZER_FIM_MID_ID, "tokenizer.ggml.fim_mid_token_id" },
181
+ { LLM_KV_TOKENIZER_FIM_PAD_ID, "tokenizer.ggml.fim_pad_token_id" },
182
+ { LLM_KV_TOKENIZER_FIM_REP_ID, "tokenizer.ggml.fim_rep_token_id" },
183
+ { LLM_KV_TOKENIZER_FIM_SEP_ID, "tokenizer.ggml.fim_sep_token_id" },
184
+
185
+ { LLM_KV_ADAPTER_TYPE, "adapter.type" },
186
+ { LLM_KV_ADAPTER_LORA_ALPHA, "adapter.lora.alpha" },
187
+
188
+ // deprecated
189
+ { LLM_KV_TOKENIZER_PREFIX_ID, "tokenizer.ggml.prefix_token_id" },
190
+ { LLM_KV_TOKENIZER_SUFFIX_ID, "tokenizer.ggml.suffix_token_id" },
191
+ { LLM_KV_TOKENIZER_MIDDLE_ID, "tokenizer.ggml.middle_token_id" },
192
+ };
193
+
194
+ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_NAMES = {
195
+ {
196
+ LLM_ARCH_LLAMA,
197
+ {
198
+ { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
199
+ { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
200
+ { LLM_TENSOR_OUTPUT, "output" },
201
+ { LLM_TENSOR_ROPE_FREQS, "rope_freqs" },
202
+ { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
203
+ { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
204
+ { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
205
+ { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
206
+ { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
207
+ { LLM_TENSOR_ATTN_ROT_EMBD, "blk.%d.attn_rot_embd" },
208
+ { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
209
+ { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
210
+ { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
211
+ { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
212
+ { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
213
+ { LLM_TENSOR_FFN_GATE_EXP, "blk.%d.ffn_gate.%d" },
214
+ { LLM_TENSOR_FFN_DOWN_EXP, "blk.%d.ffn_down.%d" },
215
+ { LLM_TENSOR_FFN_UP_EXP, "blk.%d.ffn_up.%d" },
216
+ { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
217
+ { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
218
+ { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
219
+ },
220
+ },
221
+ {
222
+ LLM_ARCH_DECI,
223
+ {
224
+ { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
225
+ { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
226
+ { LLM_TENSOR_OUTPUT, "output" },
227
+ { LLM_TENSOR_ROPE_FREQS, "rope_freqs" },
228
+ { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
229
+ { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
230
+ { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
231
+ { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
232
+ { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
233
+ { LLM_TENSOR_ATTN_ROT_EMBD, "blk.%d.attn_rot_embd" },
234
+ { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
235
+ { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
236
+ { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
237
+ { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
238
+ { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
239
+ { LLM_TENSOR_FFN_GATE_EXP, "blk.%d.ffn_gate.%d" },
240
+ { LLM_TENSOR_FFN_DOWN_EXP, "blk.%d.ffn_down.%d" },
241
+ { LLM_TENSOR_FFN_UP_EXP, "blk.%d.ffn_up.%d" },
242
+ { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
243
+ { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
244
+ { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
245
+ },
246
+ },
247
+ {
248
+ LLM_ARCH_BAICHUAN,
249
+ {
250
+ { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
251
+ { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
252
+ { LLM_TENSOR_OUTPUT, "output" },
253
+ { LLM_TENSOR_ROPE_FREQS, "rope_freqs" },
254
+ { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
255
+ { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
256
+ { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
257
+ { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
258
+ { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
259
+ { LLM_TENSOR_ATTN_ROT_EMBD, "blk.%d.attn_rot_embd" },
260
+ { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
261
+ { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
262
+ { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
263
+ { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
264
+ },
265
+ },
266
+ {
267
+ LLM_ARCH_FALCON,
268
+ {
269
+ { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
270
+ { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
271
+ { LLM_TENSOR_OUTPUT, "output" },
272
+ { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
273
+ { LLM_TENSOR_ATTN_NORM_2, "blk.%d.attn_norm_2" },
274
+ { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" },
275
+ { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
276
+ { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
277
+ { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
278
+ },
279
+ },
280
+ {
281
+ LLM_ARCH_GROK,
282
+ {
283
+ { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
284
+ { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
285
+ { LLM_TENSOR_OUTPUT, "output" },
286
+ { LLM_TENSOR_ROPE_FREQS, "rope_freqs" },
287
+ { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
288
+ { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
289
+ { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
290
+ { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
291
+ { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
292
+ { LLM_TENSOR_ATTN_ROT_EMBD, "blk.%d.attn_rot_embd" },
293
+ { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
294
+ { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
295
+ { LLM_TENSOR_FFN_GATE_EXP, "blk.%d.ffn_gate.%d" },
296
+ { LLM_TENSOR_FFN_DOWN_EXP, "blk.%d.ffn_down.%d" },
297
+ { LLM_TENSOR_FFN_UP_EXP, "blk.%d.ffn_up.%d" },
298
+ { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
299
+ { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
300
+ { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
301
+ { LLM_TENSOR_LAYER_OUT_NORM, "blk.%d.layer_output_norm" },
302
+ { LLM_TENSOR_ATTN_OUT_NORM, "blk.%d.attn_output_norm" },
303
+ },
304
+ },
305
+ {
306
+ LLM_ARCH_GPT2,
307
+ {
308
+ { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
309
+ { LLM_TENSOR_POS_EMBD, "position_embd" },
310
+ { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
311
+ { LLM_TENSOR_OUTPUT, "output" },
312
+ { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
313
+ { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" },
314
+ { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
315
+ { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
316
+ { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
317
+ { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
318
+ },
319
+ },
320
+ {
321
+ LLM_ARCH_GPTJ,
322
+ {
323
+ { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
324
+ },
325
+ },
326
+ {
327
+ LLM_ARCH_GPTNEOX,
328
+ {
329
+ { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
330
+ { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
331
+ { LLM_TENSOR_OUTPUT, "output" },
332
+ { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
333
+ { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" },
334
+ { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
335
+ { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
336
+ { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
337
+ { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
338
+ },
339
+ },
340
+ {
341
+ LLM_ARCH_MPT,
342
+ {
343
+ { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
344
+ { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
345
+ { LLM_TENSOR_OUTPUT, "output"},
346
+ { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
347
+ { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
348
+ { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" },
349
+ { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
350
+ { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
351
+ { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
352
+ { LLM_TENSOR_FFN_ACT, "blk.%d.ffn.act" },
353
+ { LLM_TENSOR_POS_EMBD, "position_embd" },
354
+ { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm"},
355
+ { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm"},
356
+ },
357
+ },
358
+ {
359
+ LLM_ARCH_STARCODER,
360
+ {
361
+ { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
362
+ { LLM_TENSOR_POS_EMBD, "position_embd" },
363
+ { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
364
+ { LLM_TENSOR_OUTPUT, "output" },
365
+ { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
366
+ { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" },
367
+ { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
368
+ { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
369
+ { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
370
+ { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
371
+ },
372
+ },
373
+ {
374
+ LLM_ARCH_REFACT,
375
+ {
376
+ { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
377
+ { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
378
+ { LLM_TENSOR_OUTPUT, "output" },
379
+ { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
380
+ { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
381
+ { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
382
+ { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
383
+ { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
384
+ { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
385
+ { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
386
+ { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
387
+ { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
388
+ },
389
+ },
390
+ {
391
+ LLM_ARCH_BERT,
392
+ {
393
+ { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
394
+ { LLM_TENSOR_TOKEN_EMBD_NORM, "token_embd_norm" },
395
+ { LLM_TENSOR_TOKEN_TYPES, "token_types" },
396
+ { LLM_TENSOR_POS_EMBD, "position_embd" },
397
+ { LLM_TENSOR_ATTN_OUT_NORM, "blk.%d.attn_output_norm" },
398
+ { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
399
+ { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
400
+ { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
401
+ { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
402
+ { LLM_TENSOR_LAYER_OUT_NORM, "blk.%d.layer_output_norm" },
403
+ { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
404
+ { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
405
+ { LLM_TENSOR_CLS, "cls" },
406
+ { LLM_TENSOR_CLS_OUT, "cls.output" },
407
+ },
408
+ },
409
+ {
410
+ LLM_ARCH_NOMIC_BERT,
411
+ {
412
+ { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
413
+ { LLM_TENSOR_TOKEN_EMBD_NORM, "token_embd_norm" },
414
+ { LLM_TENSOR_TOKEN_TYPES, "token_types" },
415
+ { LLM_TENSOR_ATTN_OUT_NORM, "blk.%d.attn_output_norm" },
416
+ { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" },
417
+ { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
418
+ { LLM_TENSOR_LAYER_OUT_NORM, "blk.%d.layer_output_norm" },
419
+ { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
420
+ { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
421
+ { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
422
+ },
423
+ },
424
+ {
425
+ LLM_ARCH_JINA_BERT_V2,
426
+ {
427
+ { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
428
+ { LLM_TENSOR_TOKEN_EMBD_NORM, "token_embd_norm" },
429
+ { LLM_TENSOR_TOKEN_TYPES, "token_types" },
430
+ { LLM_TENSOR_ATTN_NORM_2, "blk.%d.attn_norm_2" },
431
+ { LLM_TENSOR_ATTN_OUT_NORM, "blk.%d.attn_output_norm" },
432
+ { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
433
+ { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" },
434
+ { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
435
+ { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" },
436
+ { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
437
+ { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
438
+ { LLM_TENSOR_LAYER_OUT_NORM, "blk.%d.layer_output_norm" },
439
+ { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
440
+ { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
441
+ { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
442
+ { LLM_TENSOR_CLS, "cls" },
443
+ },
444
+ },
445
+ {
446
+ LLM_ARCH_BLOOM,
447
+ {
448
+ { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
449
+ { LLM_TENSOR_TOKEN_EMBD_NORM, "token_embd_norm" },
450
+ { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
451
+ { LLM_TENSOR_OUTPUT, "output" },
452
+ { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
453
+ { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" },
454
+ { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
455
+ { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
456
+ { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
457
+ { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
458
+ },
459
+ },
460
+ {
461
+ LLM_ARCH_STABLELM,
462
+ {
463
+ { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
464
+ { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
465
+ { LLM_TENSOR_OUTPUT, "output" },
466
+ { LLM_TENSOR_ROPE_FREQS, "rope_freqs" },
467
+ { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
468
+ { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
469
+ { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
470
+ { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
471
+ { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
472
+ { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
473
+ { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
474
+ { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
475
+ { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
476
+ { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" },
477
+ { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" },
478
+ },
479
+ },
480
+ {
481
+ LLM_ARCH_QWEN,
482
+ {
483
+ { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
484
+ { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
485
+ { LLM_TENSOR_OUTPUT, "output" },
486
+ { LLM_TENSOR_ROPE_FREQS, "rope_freqs" },
487
+ { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
488
+ { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" },
489
+ { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
490
+ { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
491
+ { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
492
+ { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
493
+ { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
494
+ },
495
+ },
496
+ {
497
+ LLM_ARCH_QWEN2,
498
+ {
499
+ { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
500
+ { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
501
+ { LLM_TENSOR_OUTPUT, "output" },
502
+ { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
503
+ { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
504
+ { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
505
+ { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
506
+ { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
507
+ { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
508
+ { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
509
+ { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
510
+ { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
511
+ },
512
+ },
513
+ {
514
+ LLM_ARCH_QWEN2VL,
515
+ {
516
+ { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
517
+ { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
518
+ { LLM_TENSOR_OUTPUT, "output" },
519
+ { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
520
+ { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
521
+ { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
522
+ { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
523
+ { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
524
+ { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
525
+ { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
526
+ { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
527
+ { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
528
+ },
529
+ },
530
+ {
531
+ LLM_ARCH_QWEN2MOE,
532
+ {
533
+ { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
534
+ { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
535
+ { LLM_TENSOR_OUTPUT, "output" },
536
+ { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
537
+ { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
538
+ { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
539
+ { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
540
+ { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
541
+ { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
542
+ { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
543
+ { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
544
+ { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
545
+ { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
546
+ { LLM_TENSOR_FFN_GATE_INP_SHEXP, "blk.%d.ffn_gate_inp_shexp" },
547
+ { LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" },
548
+ { LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" },
549
+ { LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" },
550
+ },
551
+ },
552
+ {
553
+ LLM_ARCH_PHI2,
554
+ {
555
+ { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
556
+ { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
557
+ { LLM_TENSOR_OUTPUT, "output" },
558
+ { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
559
+ { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" },
560
+ { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
561
+ { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
562
+ { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
563
+ { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
564
+ { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
565
+ { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
566
+ },
567
+ },
568
+ {
569
+ LLM_ARCH_PHI3,
570
+ {
571
+ { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
572
+ { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
573
+ { LLM_TENSOR_OUTPUT, "output" },
574
+ { LLM_TENSOR_ROPE_FACTORS_LONG, "rope_factors_long" },
575
+ { LLM_TENSOR_ROPE_FACTORS_SHORT, "rope_factors_short" },
576
+ { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
577
+ { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" },
578
+ { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
579
+ { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
580
+ { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
581
+ { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
582
+ { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
583
+ { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
584
+ { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
585
+ },
586
+ },
587
+ {
588
+ LLM_ARCH_PLAMO,
589
+ {
590
+ { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
591
+ { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
592
+ { LLM_TENSOR_OUTPUT, "output" },
593
+ { LLM_TENSOR_ROPE_FREQS, "rope_freqs" },
594
+ { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
595
+ { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
596
+ { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
597
+ { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
598
+ { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
599
+ { LLM_TENSOR_ATTN_ROT_EMBD, "blk.%d.attn_rot_embd" },
600
+ { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
601
+ { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
602
+ { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
603
+ },
604
+ },
605
+ {
606
+ LLM_ARCH_CODESHELL,
607
+ {
608
+ { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
609
+ { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
610
+ { LLM_TENSOR_OUTPUT, "output" },
611
+ { LLM_TENSOR_ROPE_FREQS, "rope_freqs" },
612
+ { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
613
+ { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
614
+ { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
615
+ { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
616
+ { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" },
617
+ { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
618
+ { LLM_TENSOR_ATTN_ROT_EMBD, "blk.%d.attn_rot_embd" },
619
+ { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
620
+ { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
621
+ { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
622
+ { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
623
+ },
624
+ },
625
+ {
626
+ LLM_ARCH_ORION,
627
+ {
628
+ { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
629
+ { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
630
+ { LLM_TENSOR_OUTPUT, "output" },
631
+ { LLM_TENSOR_ROPE_FREQS, "rope_freqs" },
632
+ { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
633
+ { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
634
+ { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
635
+ { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
636
+ { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
637
+ { LLM_TENSOR_ATTN_ROT_EMBD, "blk.%d.attn_rot_embd" },
638
+ { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
639
+ { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
640
+ { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
641
+ { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
642
+ },
643
+ },
644
+ {
645
+ LLM_ARCH_INTERNLM2,
646
+ {
647
+ { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
648
+ { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
649
+ { LLM_TENSOR_OUTPUT, "output" },
650
+ { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
651
+ { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
652
+ { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
653
+ { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
654
+ { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
655
+ { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
656
+ { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
657
+ { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
658
+ { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
659
+ },
660
+ },
661
+ {
662
+ LLM_ARCH_MINICPM,
663
+ {
664
+ { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
665
+ { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
666
+ { LLM_TENSOR_OUTPUT, "output" },
667
+ { LLM_TENSOR_ROPE_FREQS, "rope_freqs" },
668
+ { LLM_TENSOR_ROPE_FACTORS_LONG, "rope_factors_long" },
669
+ { LLM_TENSOR_ROPE_FACTORS_SHORT, "rope_factors_short" },
670
+ { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
671
+ { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
672
+ { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
673
+ { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
674
+ { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
675
+ { LLM_TENSOR_ATTN_ROT_EMBD, "blk.%d.attn_rot_embd" },
676
+ { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
677
+ { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
678
+ { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
679
+ { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
680
+ { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
681
+ { LLM_TENSOR_FFN_GATE_EXP, "blk.%d.ffn_gate.%d" },
682
+ { LLM_TENSOR_FFN_DOWN_EXP, "blk.%d.ffn_down.%d" },
683
+ { LLM_TENSOR_FFN_UP_EXP, "blk.%d.ffn_up.%d" },
684
+ },
685
+ },
686
+ {
687
+ LLM_ARCH_MINICPM3,
688
+ {
689
+ { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
690
+ { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
691
+ { LLM_TENSOR_OUTPUT, "output" },
692
+ { LLM_TENSOR_ROPE_FACTORS_LONG, "rope_factors_long" },
693
+ { LLM_TENSOR_ROPE_FACTORS_SHORT, "rope_factors_short" },
694
+ { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
695
+ { LLM_TENSOR_ATTN_Q_A_NORM, "blk.%d.attn_q_a_norm" },
696
+ { LLM_TENSOR_ATTN_KV_A_NORM, "blk.%d.attn_kv_a_norm" },
697
+ { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
698
+ { LLM_TENSOR_ATTN_Q_A, "blk.%d.attn_q_a" },
699
+ { LLM_TENSOR_ATTN_Q_B, "blk.%d.attn_q_b" },
700
+ { LLM_TENSOR_ATTN_KV_A_MQA, "blk.%d.attn_kv_a_mqa" },
701
+ { LLM_TENSOR_ATTN_KV_B, "blk.%d.attn_kv_b" },
702
+ { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
703
+ { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
704
+ { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
705
+ { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
706
+ { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
707
+ },
708
+ },
709
+ {
710
+ LLM_ARCH_GEMMA,
711
+ {
712
+ { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
713
+ { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
714
+ { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
715
+ { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
716
+ { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
717
+ { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
718
+ { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
719
+ { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
720
+ { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
721
+ { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
722
+ { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
723
+ },
724
+ },
725
+ {
726
+ LLM_ARCH_GEMMA2,
727
+ {
728
+ { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
729
+ { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
730
+ { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
731
+ { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
732
+ { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
733
+ { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
734
+ { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
735
+ { LLM_TENSOR_ATTN_POST_NORM, "blk.%d.post_attention_norm" },
736
+ { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
737
+ { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
738
+ { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
739
+ { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
740
+ { LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" },
741
+ },
742
+ },
743
+ {
744
+ LLM_ARCH_STARCODER2,
745
+ {
746
+ { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
747
+ { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
748
+ { LLM_TENSOR_OUTPUT, "output" },
749
+ { LLM_TENSOR_ROPE_FREQS, "rope_freqs" },
750
+ { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
751
+ { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
752
+ { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
753
+ { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
754
+ { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
755
+ { LLM_TENSOR_ATTN_ROT_EMBD, "blk.%d.attn_rot_embd" },
756
+ { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
757
+ { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
758
+ { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
759
+ },
760
+ },
761
+ {
762
+ LLM_ARCH_MAMBA,
763
+ {
764
+ { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
765
+ { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
766
+ { LLM_TENSOR_OUTPUT, "output" },
767
+ { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
768
+ { LLM_TENSOR_SSM_IN, "blk.%d.ssm_in" },
769
+ { LLM_TENSOR_SSM_CONV1D, "blk.%d.ssm_conv1d" },
770
+ { LLM_TENSOR_SSM_X, "blk.%d.ssm_x" },
771
+ { LLM_TENSOR_SSM_DT, "blk.%d.ssm_dt" },
772
+ { LLM_TENSOR_SSM_A, "blk.%d.ssm_a" },
773
+ { LLM_TENSOR_SSM_D, "blk.%d.ssm_d" },
774
+ { LLM_TENSOR_SSM_OUT, "blk.%d.ssm_out" },
775
+ },
776
+ },
777
+ {
778
+ LLM_ARCH_XVERSE,
779
+ {
780
+ { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
781
+ { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
782
+ { LLM_TENSOR_OUTPUT, "output" },
783
+ { LLM_TENSOR_ROPE_FREQS, "rope_freqs" },
784
+ { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
785
+ { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
786
+ { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
787
+ { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
788
+ { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
789
+ { LLM_TENSOR_ATTN_ROT_EMBD, "blk.%d.attn_rot_embd" },
790
+ { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
791
+ { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
792
+ { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
793
+ { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
794
+ },
795
+ },
796
+ {
797
+ LLM_ARCH_COMMAND_R,
798
+ {
799
+ { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
800
+ { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
801
+ { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
802
+ { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
803
+ { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
804
+ { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
805
+ { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
806
+ { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
807
+ { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
808
+ { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
809
+ { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" },
810
+ { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" },
811
+ },
812
+ },
813
+ {
814
+ LLM_ARCH_COHERE2,
815
+ {
816
+ { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
817
+ { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
818
+ { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
819
+ { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
820
+ { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
821
+ { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
822
+ { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
823
+ { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
824
+ { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
825
+ { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
826
+ },
827
+ },
828
+ {
829
+ LLM_ARCH_DBRX,
830
+ {
831
+ { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
832
+ { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
833
+ { LLM_TENSOR_OUTPUT, "output" },
834
+ { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" },
835
+ { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
836
+ { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
837
+ { LLM_TENSOR_ATTN_OUT_NORM, "blk.%d.attn_output_norm" },
838
+ { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
839
+ { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
840
+ { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
841
+ { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
842
+ },
843
+ },
844
+ {
845
+ LLM_ARCH_OLMO,
846
+ {
847
+ { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
848
+ { LLM_TENSOR_OUTPUT, "output" },
849
+ { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
850
+ { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
851
+ { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
852
+ { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
853
+ { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
854
+ { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
855
+ { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
856
+ },
857
+ },
858
+ {
859
+ LLM_ARCH_OLMO2,
860
+ {
861
+ { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
862
+ { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
863
+ { LLM_TENSOR_OUTPUT, "output" },
864
+ { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
865
+ { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
866
+ { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
867
+ { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
868
+ { LLM_TENSOR_ATTN_POST_NORM, "blk.%d.post_attention_norm" },
869
+ { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" },
870
+ { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" },
871
+ { LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" },
872
+ { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
873
+ { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
874
+ { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
875
+ },
876
+ },
877
+ {
878
+ LLM_ARCH_OLMOE,
879
+ {
880
+ { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
881
+ { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
882
+ { LLM_TENSOR_OUTPUT, "output" },
883
+ { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
884
+ { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
885
+ { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
886
+ { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
887
+ { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
888
+ { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" },
889
+ { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" },
890
+ { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
891
+ { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
892
+ { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
893
+ { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
894
+ { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
895
+ },
896
+ },
897
+ {
898
+ LLM_ARCH_OPENELM,
899
+ {
900
+ { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
901
+ { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
902
+ { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
903
+ { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" },
904
+ { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" },
905
+ { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" },
906
+ { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
907
+ { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
908
+ { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
909
+ { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
910
+ { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
911
+ },
912
+ },
913
+ {
914
+ LLM_ARCH_ARCTIC,
915
+ {
916
+ { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
917
+ { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
918
+ { LLM_TENSOR_OUTPUT, "output" },
919
+ { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
920
+ { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
921
+ { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
922
+ { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
923
+ { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
924
+ { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
925
+ { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
926
+ { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
927
+ { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
928
+ { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
929
+ { LLM_TENSOR_FFN_NORM_EXPS, "blk.%d.ffn_norm_exps" },
930
+ { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
931
+ { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
932
+ { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
933
+ },
934
+ },
935
+ {
936
+ LLM_ARCH_DEEPSEEK,
937
+ {
938
+ { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
939
+ { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
940
+ { LLM_TENSOR_OUTPUT, "output" },
941
+ { LLM_TENSOR_ROPE_FREQS, "rope_freqs" },
942
+ { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
943
+ { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
944
+ { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
945
+ { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
946
+ { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
947
+ { LLM_TENSOR_ATTN_ROT_EMBD, "blk.%d.attn_rot_embd" },
948
+ { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
949
+ { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
950
+ { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
951
+ { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
952
+ { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
953
+ { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
954
+ { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
955
+ { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
956
+ { LLM_TENSOR_FFN_GATE_INP_SHEXP, "blk.%d.ffn_gate_inp_shexp" },
957
+ { LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" },
958
+ { LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" },
959
+ { LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" },
960
+ },
961
+ },
962
+ {
963
+ LLM_ARCH_DEEPSEEK2,
964
+ {
965
+ { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
966
+ { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
967
+ { LLM_TENSOR_OUTPUT, "output" },
968
+ { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
969
+ { LLM_TENSOR_ATTN_Q_A_NORM, "blk.%d.attn_q_a_norm" },
970
+ { LLM_TENSOR_ATTN_KV_A_NORM, "blk.%d.attn_kv_a_norm" },
971
+ { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
972
+ { LLM_TENSOR_ATTN_Q_A, "blk.%d.attn_q_a" },
973
+ { LLM_TENSOR_ATTN_Q_B, "blk.%d.attn_q_b" },
974
+ { LLM_TENSOR_ATTN_KV_A_MQA, "blk.%d.attn_kv_a_mqa" },
975
+ { LLM_TENSOR_ATTN_KV_B, "blk.%d.attn_kv_b" },
976
+ { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
977
+ { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
978
+ { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
979
+ { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
980
+ { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
981
+ { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
982
+ { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
983
+ { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
984
+ { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
985
+ { LLM_TENSOR_FFN_GATE_INP_SHEXP, "blk.%d.ffn_gate_inp_shexp" },
986
+ { LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" },
987
+ { LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" },
988
+ { LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" },
989
+ { LLM_TENSOR_FFN_EXP_PROBS_B, "blk.%d.exp_probs_b" },
990
+ },
991
+ },
992
+ {
993
+ LLM_ARCH_CHATGLM,
994
+ {
995
+ { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
996
+ { LLM_TENSOR_ROPE_FREQS, "rope_freqs" },
997
+ { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
998
+ { LLM_TENSOR_OUTPUT, "output" },
999
+ { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
1000
+ { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" },
1001
+ { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
1002
+ { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
1003
+ { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
1004
+ { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
1005
+ },
1006
+ },
1007
+ {
1008
+ LLM_ARCH_BITNET,
1009
+ {
1010
+ { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
1011
+ { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
1012
+ { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
1013
+ { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
1014
+ { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
1015
+ { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
1016
+ { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
1017
+ { LLM_TENSOR_ATTN_SUB_NORM, "blk.%d.attn_sub_norm" },
1018
+ { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
1019
+ { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
1020
+ { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
1021
+ { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
1022
+ { LLM_TENSOR_FFN_SUB_NORM, "blk.%d.ffn_sub_norm" },
1023
+ },
1024
+ },
1025
+ {
1026
+ LLM_ARCH_T5,
1027
+ {
1028
+ { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
1029
+ { LLM_TENSOR_OUTPUT, "output" },
1030
+ { LLM_TENSOR_DEC_OUTPUT_NORM, "dec.output_norm" },
1031
+ { LLM_TENSOR_DEC_ATTN_NORM, "dec.blk.%d.attn_norm" },
1032
+ { LLM_TENSOR_DEC_ATTN_Q, "dec.blk.%d.attn_q" },
1033
+ { LLM_TENSOR_DEC_ATTN_K, "dec.blk.%d.attn_k" },
1034
+ { LLM_TENSOR_DEC_ATTN_V, "dec.blk.%d.attn_v" },
1035
+ { LLM_TENSOR_DEC_ATTN_OUT, "dec.blk.%d.attn_o" },
1036
+ { LLM_TENSOR_DEC_ATTN_REL_B, "dec.blk.%d.attn_rel_b" },
1037
+ { LLM_TENSOR_DEC_CROSS_ATTN_NORM, "dec.blk.%d.cross_attn_norm" },
1038
+ { LLM_TENSOR_DEC_CROSS_ATTN_Q, "dec.blk.%d.cross_attn_q" },
1039
+ { LLM_TENSOR_DEC_CROSS_ATTN_K, "dec.blk.%d.cross_attn_k" },
1040
+ { LLM_TENSOR_DEC_CROSS_ATTN_V, "dec.blk.%d.cross_attn_v" },
1041
+ { LLM_TENSOR_DEC_CROSS_ATTN_OUT, "dec.blk.%d.cross_attn_o" },
1042
+ { LLM_TENSOR_DEC_CROSS_ATTN_REL_B, "dec.blk.%d.cross_attn_rel_b" },
1043
+ { LLM_TENSOR_DEC_FFN_NORM, "dec.blk.%d.ffn_norm" },
1044
+ { LLM_TENSOR_DEC_FFN_GATE, "dec.blk.%d.ffn_gate" },
1045
+ { LLM_TENSOR_DEC_FFN_DOWN, "dec.blk.%d.ffn_down" },
1046
+ { LLM_TENSOR_DEC_FFN_UP, "dec.blk.%d.ffn_up" },
1047
+ { LLM_TENSOR_ENC_OUTPUT_NORM, "enc.output_norm" },
1048
+ { LLM_TENSOR_ENC_ATTN_NORM, "enc.blk.%d.attn_norm" },
1049
+ { LLM_TENSOR_ENC_ATTN_Q, "enc.blk.%d.attn_q" },
1050
+ { LLM_TENSOR_ENC_ATTN_K, "enc.blk.%d.attn_k" },
1051
+ { LLM_TENSOR_ENC_ATTN_V, "enc.blk.%d.attn_v" },
1052
+ { LLM_TENSOR_ENC_ATTN_OUT, "enc.blk.%d.attn_o" },
1053
+ { LLM_TENSOR_ENC_ATTN_REL_B, "enc.blk.%d.attn_rel_b" },
1054
+ { LLM_TENSOR_ENC_FFN_NORM, "enc.blk.%d.ffn_norm" },
1055
+ { LLM_TENSOR_ENC_FFN_GATE, "enc.blk.%d.ffn_gate" },
1056
+ { LLM_TENSOR_ENC_FFN_DOWN, "enc.blk.%d.ffn_down" },
1057
+ { LLM_TENSOR_ENC_FFN_UP, "enc.blk.%d.ffn_up" },
1058
+ },
1059
+ },
1060
+ {
1061
+ LLM_ARCH_T5ENCODER,
1062
+ {
1063
+ { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
1064
+ { LLM_TENSOR_OUTPUT, "output" },
1065
+ { LLM_TENSOR_ENC_OUTPUT_NORM, "enc.output_norm" },
1066
+ { LLM_TENSOR_ENC_ATTN_NORM, "enc.blk.%d.attn_norm" },
1067
+ { LLM_TENSOR_ENC_ATTN_Q, "enc.blk.%d.attn_q" },
1068
+ { LLM_TENSOR_ENC_ATTN_K, "enc.blk.%d.attn_k" },
1069
+ { LLM_TENSOR_ENC_ATTN_V, "enc.blk.%d.attn_v" },
1070
+ { LLM_TENSOR_ENC_ATTN_OUT, "enc.blk.%d.attn_o" },
1071
+ { LLM_TENSOR_ENC_ATTN_REL_B, "enc.blk.%d.attn_rel_b" },
1072
+ { LLM_TENSOR_ENC_FFN_NORM, "enc.blk.%d.ffn_norm" },
1073
+ { LLM_TENSOR_ENC_FFN_GATE, "enc.blk.%d.ffn_gate" },
1074
+ { LLM_TENSOR_ENC_FFN_DOWN, "enc.blk.%d.ffn_down" },
1075
+ { LLM_TENSOR_ENC_FFN_UP, "enc.blk.%d.ffn_up" },
1076
+ },
1077
+ },
1078
+ {
1079
+ LLM_ARCH_JAIS,
1080
+ {
1081
+ { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
1082
+ { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
1083
+ { LLM_TENSOR_OUTPUT, "output" },
1084
+ { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
1085
+ { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" },
1086
+ { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
1087
+ { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
1088
+ { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
1089
+ { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
1090
+ { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
1091
+ },
1092
+ },
1093
+ {
1094
+ LLM_ARCH_NEMOTRON,
1095
+ {
1096
+ { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
1097
+ { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
1098
+ { LLM_TENSOR_OUTPUT, "output" },
1099
+ { LLM_TENSOR_ROPE_FREQS, "rope_freqs" },
1100
+ { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
1101
+ { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
1102
+ { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
1103
+ { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
1104
+ { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
1105
+ { LLM_TENSOR_ATTN_ROT_EMBD, "blk.%d.attn_rot_embd" },
1106
+ { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
1107
+ { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
1108
+ { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
1109
+ },
1110
+ },
1111
+ {
1112
+ LLM_ARCH_EXAONE,
1113
+ {
1114
+ { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
1115
+ { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
1116
+ { LLM_TENSOR_OUTPUT, "output" },
1117
+ { LLM_TENSOR_ROPE_FREQS, "rope_freqs" },
1118
+ { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
1119
+ { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
1120
+ { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
1121
+ { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
1122
+ { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
1123
+ { LLM_TENSOR_ATTN_ROT_EMBD, "blk.%d.attn_rot_embd" },
1124
+ { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
1125
+ { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
1126
+ { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
1127
+ { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
1128
+ },
1129
+ },
1130
+ {
1131
+ LLM_ARCH_RWKV6,
1132
+ {
1133
+ { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
1134
+ { LLM_TENSOR_TOKEN_EMBD_NORM, "token_embd_norm" },
1135
+ { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
1136
+ { LLM_TENSOR_OUTPUT, "output" },
1137
+ { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
1138
+ { LLM_TENSOR_ATTN_NORM_2, "blk.%d.attn_norm_2" },
1139
+ { LLM_TENSOR_TIME_MIX_W1, "blk.%d.time_mix_w1" },
1140
+ { LLM_TENSOR_TIME_MIX_W2, "blk.%d.time_mix_w2" },
1141
+ { LLM_TENSOR_TIME_MIX_LERP_X, "blk.%d.time_mix_lerp_x" },
1142
+ { LLM_TENSOR_TIME_MIX_LERP_W, "blk.%d.time_mix_lerp_w" },
1143
+ { LLM_TENSOR_TIME_MIX_LERP_K, "blk.%d.time_mix_lerp_k" },
1144
+ { LLM_TENSOR_TIME_MIX_LERP_V, "blk.%d.time_mix_lerp_v" },
1145
+ { LLM_TENSOR_TIME_MIX_LERP_R, "blk.%d.time_mix_lerp_r" },
1146
+ { LLM_TENSOR_TIME_MIX_LERP_G, "blk.%d.time_mix_lerp_g" },
1147
+ { LLM_TENSOR_TIME_MIX_FIRST, "blk.%d.time_mix_first" },
1148
+ { LLM_TENSOR_TIME_MIX_DECAY, "blk.%d.time_mix_decay" },
1149
+ { LLM_TENSOR_TIME_MIX_DECAY_W1, "blk.%d.time_mix_decay_w1" },
1150
+ { LLM_TENSOR_TIME_MIX_DECAY_W2, "blk.%d.time_mix_decay_w2" },
1151
+ { LLM_TENSOR_TIME_MIX_KEY, "blk.%d.time_mix_key" },
1152
+ { LLM_TENSOR_TIME_MIX_VALUE, "blk.%d.time_mix_value" },
1153
+ { LLM_TENSOR_TIME_MIX_RECEPTANCE, "blk.%d.time_mix_receptance" },
1154
+ { LLM_TENSOR_TIME_MIX_GATE, "blk.%d.time_mix_gate" },
1155
+ { LLM_TENSOR_TIME_MIX_LN, "blk.%d.time_mix_ln" },
1156
+ { LLM_TENSOR_TIME_MIX_OUTPUT, "blk.%d.time_mix_output" },
1157
+ { LLM_TENSOR_CHANNEL_MIX_LERP_K, "blk.%d.channel_mix_lerp_k" },
1158
+ { LLM_TENSOR_CHANNEL_MIX_LERP_R, "blk.%d.channel_mix_lerp_r" },
1159
+ { LLM_TENSOR_CHANNEL_MIX_KEY, "blk.%d.channel_mix_key" },
1160
+ { LLM_TENSOR_CHANNEL_MIX_VALUE, "blk.%d.channel_mix_value" },
1161
+ { LLM_TENSOR_CHANNEL_MIX_RECEPTANCE, "blk.%d.channel_mix_receptance" },
1162
+ },
1163
+ },
1164
+ {
1165
+ LLM_ARCH_GRANITE,
1166
+ {
1167
+ { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
1168
+ { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
1169
+ { LLM_TENSOR_OUTPUT, "output" },
1170
+ { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
1171
+ { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
1172
+ { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
1173
+ { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
1174
+ { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
1175
+ { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
1176
+ { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
1177
+ { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
1178
+ { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
1179
+ },
1180
+ },
1181
+ {
1182
+ LLM_ARCH_GRANITE_MOE,
1183
+ {
1184
+ { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
1185
+ { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
1186
+ { LLM_TENSOR_OUTPUT, "output" },
1187
+ { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
1188
+ { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
1189
+ { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
1190
+ { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
1191
+ { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
1192
+ { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
1193
+ { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
1194
+ { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
1195
+ { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
1196
+ { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
1197
+ },
1198
+ },
1199
+ {
1200
+ LLM_ARCH_CHAMELEON,
1201
+ {
1202
+ { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
1203
+ { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
1204
+ { LLM_TENSOR_OUTPUT, "output" },
1205
+ { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
1206
+ { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
1207
+ { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
1208
+ { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
1209
+ { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
1210
+ { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
1211
+ { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
1212
+ { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
1213
+ { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
1214
+ { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" },
1215
+ { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" },
1216
+ },
1217
+ },
1218
+ {
1219
+ LLM_ARCH_WAVTOKENIZER_DEC,
1220
+ {
1221
+ { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
1222
+ { LLM_TENSOR_TOKEN_EMBD_NORM, "token_embd_norm" },
1223
+ { LLM_TENSOR_CONV1D, "conv1d" },
1224
+ { LLM_TENSOR_CONVNEXT_DW, "convnext.%d.dw" },
1225
+ { LLM_TENSOR_CONVNEXT_NORM, "convnext.%d.norm" },
1226
+ { LLM_TENSOR_CONVNEXT_PW1, "convnext.%d.pw1" },
1227
+ { LLM_TENSOR_CONVNEXT_PW2, "convnext.%d.pw2" },
1228
+ { LLM_TENSOR_CONVNEXT_GAMMA, "convnext.%d.gamma" },
1229
+ { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
1230
+ { LLM_TENSOR_OUTPUT, "output" },
1231
+ { LLM_TENSOR_POS_NET_CONV1, "posnet.%d.conv1" },
1232
+ { LLM_TENSOR_POS_NET_CONV2, "posnet.%d.conv2" },
1233
+ { LLM_TENSOR_POS_NET_NORM, "posnet.%d.norm" },
1234
+ { LLM_TENSOR_POS_NET_NORM1, "posnet.%d.norm1" },
1235
+ { LLM_TENSOR_POS_NET_NORM2, "posnet.%d.norm2" },
1236
+ { LLM_TENSOR_POS_NET_ATTN_NORM, "posnet.%d.attn_norm" },
1237
+ { LLM_TENSOR_POS_NET_ATTN_Q, "posnet.%d.attn_q" },
1238
+ { LLM_TENSOR_POS_NET_ATTN_K, "posnet.%d.attn_k" },
1239
+ { LLM_TENSOR_POS_NET_ATTN_V, "posnet.%d.attn_v" },
1240
+ { LLM_TENSOR_POS_NET_ATTN_OUT, "posnet.%d.attn_output" },
1241
+ },
1242
+ },
1243
+ {
1244
+ LLM_ARCH_UNKNOWN,
1245
+ {
1246
+ { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
1247
+ },
1248
+ },
1249
+ };
1250
+
1251
+ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
1252
+ {LLM_TENSOR_TOKEN_EMBD, {LLM_TENSOR_LAYER_INPUT, GGML_OP_GET_ROWS}},
1253
+ {LLM_TENSOR_POS_EMBD, {LLM_TENSOR_LAYER_INPUT, GGML_OP_GET_ROWS}},
1254
+ {LLM_TENSOR_TOKEN_EMBD_NORM, {LLM_TENSOR_LAYER_INPUT, GGML_OP_GET_ROWS}},
1255
+ {LLM_TENSOR_TOKEN_TYPES, {LLM_TENSOR_LAYER_INPUT, GGML_OP_GET_ROWS}},
1256
+ {LLM_TENSOR_OUTPUT, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}},
1257
+ {LLM_TENSOR_CLS, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}},
1258
+ {LLM_TENSOR_CLS_OUT, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}},
1259
+ {LLM_TENSOR_OUTPUT_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}},
1260
+ {LLM_TENSOR_DEC_OUTPUT_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}},
1261
+ {LLM_TENSOR_ENC_OUTPUT_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}},
1262
+ {LLM_TENSOR_ROPE_FREQS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ROPE}},
1263
+ {LLM_TENSOR_ROPE_FACTORS_LONG, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ROPE}},
1264
+ {LLM_TENSOR_ROPE_FACTORS_SHORT, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ROPE}},
1265
+ {LLM_TENSOR_ATTN_Q, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
1266
+ {LLM_TENSOR_ATTN_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
1267
+ {LLM_TENSOR_ATTN_V, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
1268
+ {LLM_TENSOR_ATTN_QKV, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
1269
+ {LLM_TENSOR_ATTN_OUT, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
1270
+ {LLM_TENSOR_FFN_GATE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
1271
+ {LLM_TENSOR_FFN_DOWN, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
1272
+ {LLM_TENSOR_FFN_UP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
1273
+ {LLM_TENSOR_FFN_DOWN_SHEXP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
1274
+ {LLM_TENSOR_FFN_GATE_SHEXP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
1275
+ {LLM_TENSOR_FFN_UP_SHEXP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
1276
+ {LLM_TENSOR_ATTN_Q_A, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
1277
+ {LLM_TENSOR_ATTN_Q_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
1278
+ {LLM_TENSOR_ATTN_KV_A_MQA, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
1279
+ {LLM_TENSOR_ATTN_KV_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
1280
+ {LLM_TENSOR_DEC_ATTN_Q, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
1281
+ {LLM_TENSOR_DEC_ATTN_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
1282
+ {LLM_TENSOR_ATTN_Q, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
1283
+ {LLM_TENSOR_ATTN_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
1284
+ {LLM_TENSOR_ATTN_V, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
1285
+ {LLM_TENSOR_ATTN_QKV, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
1286
+ {LLM_TENSOR_ATTN_OUT, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
1287
+ {LLM_TENSOR_FFN_GATE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
1288
+ {LLM_TENSOR_FFN_DOWN, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
1289
+ {LLM_TENSOR_FFN_UP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
1290
+ {LLM_TENSOR_FFN_DOWN_SHEXP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
1291
+ {LLM_TENSOR_FFN_GATE_SHEXP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
1292
+ {LLM_TENSOR_FFN_UP_SHEXP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
1293
+ {LLM_TENSOR_ATTN_Q_A, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
1294
+ {LLM_TENSOR_ATTN_Q_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
1295
+ {LLM_TENSOR_ATTN_KV_A_MQA, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
1296
+ {LLM_TENSOR_ATTN_KV_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
1297
+ {LLM_TENSOR_DEC_ATTN_Q, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
1298
+ {LLM_TENSOR_DEC_ATTN_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
1299
+ {LLM_TENSOR_DEC_ATTN_V, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
1300
+ {LLM_TENSOR_DEC_ATTN_OUT, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
1301
+ {LLM_TENSOR_DEC_CROSS_ATTN_Q, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
1302
+ {LLM_TENSOR_DEC_CROSS_ATTN_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
1303
+ {LLM_TENSOR_DEC_CROSS_ATTN_V, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
1304
+ {LLM_TENSOR_DEC_CROSS_ATTN_OUT, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
1305
+ {LLM_TENSOR_DEC_FFN_GATE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
1306
+ {LLM_TENSOR_DEC_FFN_DOWN, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
1307
+ {LLM_TENSOR_DEC_FFN_UP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
1308
+ {LLM_TENSOR_ENC_ATTN_Q, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
1309
+ {LLM_TENSOR_ENC_ATTN_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
1310
+ {LLM_TENSOR_ENC_ATTN_V, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
1311
+ {LLM_TENSOR_ENC_ATTN_OUT, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
1312
+ {LLM_TENSOR_ENC_FFN_GATE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
1313
+ {LLM_TENSOR_ENC_FFN_DOWN, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
1314
+ {LLM_TENSOR_ENC_FFN_UP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
1315
+ {LLM_TENSOR_FFN_GATE_INP_SHEXP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
1316
+ {LLM_TENSOR_FFN_GATE_INP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
1317
+ {LLM_TENSOR_SSM_IN, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
1318
+ {LLM_TENSOR_SSM_X, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
1319
+ {LLM_TENSOR_SSM_DT, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
1320
+ {LLM_TENSOR_SSM_OUT, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
1321
+ {LLM_TENSOR_TIME_MIX_W1, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
1322
+ {LLM_TENSOR_TIME_MIX_W2, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
1323
+ {LLM_TENSOR_TIME_MIX_DECAY_W1, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
1324
+ {LLM_TENSOR_TIME_MIX_DECAY_W2, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
1325
+ {LLM_TENSOR_TIME_MIX_KEY, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
1326
+ {LLM_TENSOR_TIME_MIX_VALUE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
1327
+ {LLM_TENSOR_TIME_MIX_RECEPTANCE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
1328
+ {LLM_TENSOR_TIME_MIX_GATE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
1329
+ {LLM_TENSOR_TIME_MIX_OUTPUT, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
1330
+ {LLM_TENSOR_CHANNEL_MIX_KEY, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
1331
+ {LLM_TENSOR_CHANNEL_MIX_RECEPTANCE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
1332
+ {LLM_TENSOR_CHANNEL_MIX_VALUE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
1333
+ {LLM_TENSOR_FFN_ACT, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_DIV}},
1334
+ {LLM_TENSOR_SSM_CONV1D, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_SSM_CONV}},
1335
+ {LLM_TENSOR_SSM_A, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_SSM_SCAN}},
1336
+ {LLM_TENSOR_SSM_D, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
1337
+ {LLM_TENSOR_TIME_MIX_LERP_X, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
1338
+ {LLM_TENSOR_TIME_MIX_LN, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
1339
+ {LLM_TENSOR_CHANNEL_MIX_LERP_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
1340
+ {LLM_TENSOR_CHANNEL_MIX_LERP_R, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
1341
+ {LLM_TENSOR_TIME_MIX_LERP_W, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
1342
+ {LLM_TENSOR_TIME_MIX_LERP_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
1343
+ {LLM_TENSOR_TIME_MIX_LERP_V, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
1344
+ {LLM_TENSOR_TIME_MIX_LERP_R, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
1345
+ {LLM_TENSOR_TIME_MIX_LERP_G, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
1346
+ {LLM_TENSOR_TIME_MIX_DECAY, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
1347
+ {LLM_TENSOR_TIME_MIX_FIRST, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_RWKV_WKV6}},
1348
+ {LLM_TENSOR_ATTN_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
1349
+ {LLM_TENSOR_ATTN_NORM_2, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
1350
+ {LLM_TENSOR_ATTN_OUT_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
1351
+ {LLM_TENSOR_ATTN_POST_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
1352
+ {LLM_TENSOR_FFN_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
1353
+ {LLM_TENSOR_FFN_POST_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
1354
+ {LLM_TENSOR_FFN_NORM_EXPS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
1355
+ {LLM_TENSOR_ATTN_Q_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
1356
+ {LLM_TENSOR_ATTN_K_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
1357
+ {LLM_TENSOR_LAYER_OUT_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
1358
+ {LLM_TENSOR_ATTN_Q_A_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
1359
+ {LLM_TENSOR_ATTN_KV_A_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
1360
+ {LLM_TENSOR_ATTN_SUB_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
1361
+ {LLM_TENSOR_FFN_SUB_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
1362
+ {LLM_TENSOR_DEC_ATTN_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
1363
+ {LLM_TENSOR_DEC_CROSS_ATTN_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
1364
+ {LLM_TENSOR_DEC_FFN_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
1365
+ {LLM_TENSOR_ENC_ATTN_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
1366
+ {LLM_TENSOR_ENC_FFN_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
1367
+ {LLM_TENSOR_DEC_ATTN_REL_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_GET_ROWS}},
1368
+ {LLM_TENSOR_ENC_ATTN_REL_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_GET_ROWS}},
1369
+ {LLM_TENSOR_FFN_DOWN_EXPS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT_ID}},
1370
+ {LLM_TENSOR_FFN_GATE_EXPS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT_ID}},
1371
+ {LLM_TENSOR_FFN_UP_EXPS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT_ID}},
1372
+ {LLM_TENSOR_FFN_EXP_PROBS_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
1373
+ // this tensor is loaded for T5, but never used
1374
+ {LLM_TENSOR_DEC_CROSS_ATTN_REL_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_NONE}},
1375
+ {LLM_TENSOR_CONV1D, {LLM_TENSOR_LAYER_INPUT, GGML_OP_IM2COL}},
1376
+ {LLM_TENSOR_POS_NET_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
1377
+ {LLM_TENSOR_POS_NET_NORM1, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
1378
+ {LLM_TENSOR_POS_NET_NORM2, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
1379
+ {LLM_TENSOR_POS_NET_CONV1, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_IM2COL}},
1380
+ {LLM_TENSOR_POS_NET_CONV2, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_IM2COL}},
1381
+ {LLM_TENSOR_POS_NET_ATTN_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
1382
+ {LLM_TENSOR_POS_NET_ATTN_Q, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
1383
+ {LLM_TENSOR_POS_NET_ATTN_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
1384
+ {LLM_TENSOR_POS_NET_ATTN_V, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
1385
+ {LLM_TENSOR_POS_NET_ATTN_OUT, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
1386
+ {LLM_TENSOR_CONVNEXT_DW, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_IM2COL}},
1387
+ {LLM_TENSOR_CONVNEXT_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
1388
+ {LLM_TENSOR_CONVNEXT_PW1, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
1389
+ {LLM_TENSOR_CONVNEXT_PW2, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
1390
+ {LLM_TENSOR_CONVNEXT_GAMMA, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
1391
+ };
1392
+
1393
+ LLM_KV::LLM_KV(llm_arch arch) : arch(arch) {}
1394
+
1395
+ std::string LLM_KV::operator()(llm_kv kv) const {
1396
+ return ::format(LLM_KV_NAMES.at(kv), LLM_ARCH_NAMES.at(arch));
1397
+ }
1398
+
1399
+ std::string LLM_TN_IMPL::str() const {
1400
+ if (LLM_TENSOR_NAMES.at(arch).find(tensor) == LLM_TENSOR_NAMES.at(arch).end()) {
1401
+ return "__missing__";
1402
+ }
1403
+
1404
+ std::string name = ::format(LLM_TENSOR_NAMES.at(arch).at(tensor), bid, xid);
1405
+
1406
+ if (suffix != nullptr) {
1407
+ name += ".";
1408
+ name += suffix;
1409
+ }
1410
+
1411
+ return name;
1412
+ }
1413
+
1414
+ const char * llm_arch_name(llm_arch arch) {
1415
+ auto it = LLM_ARCH_NAMES.find(arch);
1416
+ if (it == LLM_ARCH_NAMES.end()) {
1417
+ return "unknown";
1418
+ }
1419
+ return it->second;
1420
+ }
1421
+
1422
+ llm_arch llm_arch_from_string(const std::string & name) {
1423
+ for (const auto & kv : LLM_ARCH_NAMES) { // NOLINT
1424
+ if (kv.second == name) {
1425
+ return kv.first;
1426
+ }
1427
+ }
1428
+
1429
+ return LLM_ARCH_UNKNOWN;
1430
+ }
1431
+
1432
+ const llm_tensor_info & llm_tensor_info_for(llm_tensor tensor) {
1433
+ return LLM_TENSOR_INFOS.at(tensor);
1434
+ }
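Note (not part of the diff): a minimal sketch of how the lookup helpers added above could be exercised, assuming the internal header llama-arch.h is includable by the caller; the chosen arch and tensor are illustrative only.

// sketch: exercising the helpers defined above (assumes llama-arch.h is visible here)
#include "llama-arch.h"

#include <cstdio>

int main() {
    // string -> enum -> string round trip; unknown names fall back to LLM_ARCH_UNKNOWN / "unknown"
    const llm_arch arch = llm_arch_from_string("llama");
    std::printf("arch: %s\n", llm_arch_name(arch));

    // per-tensor metadata: the layer group it belongs to and the ggml op that consumes it
    const llm_tensor_info & info = llm_tensor_info_for(LLM_TENSOR_FFN_GATE);
    std::printf("layer=%d op=%d\n", (int) info.layer, (int) info.op);

    return 0;
}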
examples/talk-llama/llama-arch.h ADDED
@@ -0,0 +1,395 @@
1
+ #pragma once
2
+
3
+ #include "ggml.h" // ggml_op
4
+
5
+ #include <string>
6
+
7
+ //
8
+ // gguf constants (sync with gguf.py)
9
+ //
10
+
11
+ enum llm_arch {
12
+ LLM_ARCH_LLAMA,
13
+ LLM_ARCH_DECI,
14
+ LLM_ARCH_FALCON,
15
+ LLM_ARCH_BAICHUAN,
16
+ LLM_ARCH_GROK,
17
+ LLM_ARCH_GPT2,
18
+ LLM_ARCH_GPTJ,
19
+ LLM_ARCH_GPTNEOX,
20
+ LLM_ARCH_MPT,
21
+ LLM_ARCH_STARCODER,
22
+ LLM_ARCH_REFACT,
23
+ LLM_ARCH_BERT,
24
+ LLM_ARCH_NOMIC_BERT,
25
+ LLM_ARCH_JINA_BERT_V2,
26
+ LLM_ARCH_BLOOM,
27
+ LLM_ARCH_STABLELM,
28
+ LLM_ARCH_QWEN,
29
+ LLM_ARCH_QWEN2,
30
+ LLM_ARCH_QWEN2MOE,
31
+ LLM_ARCH_QWEN2VL,
32
+ LLM_ARCH_PHI2,
33
+ LLM_ARCH_PHI3,
34
+ LLM_ARCH_PLAMO,
35
+ LLM_ARCH_CODESHELL,
36
+ LLM_ARCH_ORION,
37
+ LLM_ARCH_INTERNLM2,
38
+ LLM_ARCH_MINICPM,
39
+ LLM_ARCH_MINICPM3,
40
+ LLM_ARCH_GEMMA,
41
+ LLM_ARCH_GEMMA2,
42
+ LLM_ARCH_STARCODER2,
43
+ LLM_ARCH_MAMBA,
44
+ LLM_ARCH_XVERSE,
45
+ LLM_ARCH_COMMAND_R,
46
+ LLM_ARCH_COHERE2,
47
+ LLM_ARCH_DBRX,
48
+ LLM_ARCH_OLMO,
49
+ LLM_ARCH_OLMO2,
50
+ LLM_ARCH_OLMOE,
51
+ LLM_ARCH_OPENELM,
52
+ LLM_ARCH_ARCTIC,
53
+ LLM_ARCH_DEEPSEEK,
54
+ LLM_ARCH_DEEPSEEK2,
55
+ LLM_ARCH_CHATGLM,
56
+ LLM_ARCH_BITNET,
57
+ LLM_ARCH_T5,
58
+ LLM_ARCH_T5ENCODER,
59
+ LLM_ARCH_JAIS,
60
+ LLM_ARCH_NEMOTRON,
61
+ LLM_ARCH_EXAONE,
62
+ LLM_ARCH_RWKV6,
63
+ LLM_ARCH_GRANITE,
64
+ LLM_ARCH_GRANITE_MOE,
65
+ LLM_ARCH_CHAMELEON,
66
+ LLM_ARCH_WAVTOKENIZER_DEC,
67
+ LLM_ARCH_UNKNOWN,
68
+ };
69
+
70
+ enum llm_kv {
71
+ LLM_KV_GENERAL_TYPE,
72
+ LLM_KV_GENERAL_ARCHITECTURE,
73
+ LLM_KV_GENERAL_QUANTIZATION_VERSION,
74
+ LLM_KV_GENERAL_ALIGNMENT,
75
+ LLM_KV_GENERAL_NAME,
76
+ LLM_KV_GENERAL_AUTHOR,
77
+ LLM_KV_GENERAL_VERSION,
78
+ LLM_KV_GENERAL_URL,
79
+ LLM_KV_GENERAL_DESCRIPTION,
80
+ LLM_KV_GENERAL_LICENSE,
81
+ LLM_KV_GENERAL_SOURCE_URL,
82
+ LLM_KV_GENERAL_SOURCE_HF_REPO,
83
+
84
+ LLM_KV_VOCAB_SIZE,
85
+ LLM_KV_CONTEXT_LENGTH,
86
+ LLM_KV_EMBEDDING_LENGTH,
87
+ LLM_KV_FEATURES_LENGTH,
88
+ LLM_KV_BLOCK_COUNT,
89
+ LLM_KV_LEADING_DENSE_BLOCK_COUNT,
90
+ LLM_KV_FEED_FORWARD_LENGTH,
91
+ LLM_KV_EXPERT_FEED_FORWARD_LENGTH,
92
+ LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH,
93
+ LLM_KV_USE_PARALLEL_RESIDUAL,
94
+ LLM_KV_TENSOR_DATA_LAYOUT,
95
+ LLM_KV_EXPERT_COUNT,
96
+ LLM_KV_EXPERT_USED_COUNT,
97
+ LLM_KV_EXPERT_SHARED_COUNT,
98
+ LLM_KV_EXPERT_WEIGHTS_SCALE,
99
+ LLM_KV_EXPERT_WEIGHTS_NORM,
100
+ LLM_KV_EXPERT_GATING_FUNC,
101
+ LLM_KV_POOLING_TYPE,
102
+ LLM_KV_LOGIT_SCALE,
103
+ LLM_KV_DECODER_START_TOKEN_ID,
104
+ LLM_KV_ATTN_LOGIT_SOFTCAPPING,
105
+ LLM_KV_FINAL_LOGIT_SOFTCAPPING,
106
+ LLM_KV_SWIN_NORM,
107
+ LLM_KV_RESCALE_EVERY_N_LAYERS,
108
+ LLM_KV_TIME_MIX_EXTRA_DIM,
109
+ LLM_KV_TIME_DECAY_EXTRA_DIM,
110
+ LLM_KV_RESIDUAL_SCALE,
111
+ LLM_KV_EMBEDDING_SCALE,
112
+
113
+ LLM_KV_ATTENTION_HEAD_COUNT,
114
+ LLM_KV_ATTENTION_HEAD_COUNT_KV,
115
+ LLM_KV_ATTENTION_MAX_ALIBI_BIAS,
116
+ LLM_KV_ATTENTION_CLAMP_KQV,
117
+ LLM_KV_ATTENTION_KEY_LENGTH,
118
+ LLM_KV_ATTENTION_VALUE_LENGTH,
119
+ LLM_KV_ATTENTION_LAYERNORM_EPS,
120
+ LLM_KV_ATTENTION_LAYERNORM_RMS_EPS,
121
+ LLM_KV_ATTENTION_GROUPNORM_EPS,
122
+ LLM_KV_ATTENTION_GROUPNORM_GROUPS,
123
+ LLM_KV_ATTENTION_CAUSAL,
124
+ LLM_KV_ATTENTION_Q_LORA_RANK,
125
+ LLM_KV_ATTENTION_KV_LORA_RANK,
126
+ LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT,
127
+ LLM_KV_ATTENTION_SLIDING_WINDOW,
128
+ LLM_KV_ATTENTION_SCALE,
129
+
130
+ LLM_KV_ROPE_DIMENSION_COUNT,
131
+ LLM_KV_ROPE_DIMENSION_SECTIONS,
132
+ LLM_KV_ROPE_FREQ_BASE,
133
+ LLM_KV_ROPE_SCALE_LINEAR,
134
+ LLM_KV_ROPE_SCALING_TYPE,
135
+ LLM_KV_ROPE_SCALING_FACTOR,
136
+ LLM_KV_ROPE_SCALING_ATTN_FACTOR,
137
+ LLM_KV_ROPE_SCALING_ORIG_CTX_LEN,
138
+ LLM_KV_ROPE_SCALING_FINETUNED,
139
+ LLM_KV_ROPE_SCALING_YARN_LOG_MUL,
140
+
141
+ LLM_KV_SPLIT_NO,
142
+ LLM_KV_SPLIT_COUNT,
143
+ LLM_KV_SPLIT_TENSORS_COUNT,
144
+
145
+ LLM_KV_SSM_INNER_SIZE,
146
+ LLM_KV_SSM_CONV_KERNEL,
147
+ LLM_KV_SSM_STATE_SIZE,
148
+ LLM_KV_SSM_TIME_STEP_RANK,
149
+ LLM_KV_SSM_DT_B_C_RMS,
150
+
151
+ LLM_KV_WKV_HEAD_SIZE,
152
+
153
+ LLM_KV_TOKENIZER_MODEL,
154
+ LLM_KV_TOKENIZER_PRE,
155
+ LLM_KV_TOKENIZER_LIST,
156
+ LLM_KV_TOKENIZER_TOKEN_TYPE,
157
+ LLM_KV_TOKENIZER_TOKEN_TYPE_COUNT,
158
+ LLM_KV_TOKENIZER_SCORES,
159
+ LLM_KV_TOKENIZER_MERGES,
160
+ LLM_KV_TOKENIZER_BOS_ID,
161
+ LLM_KV_TOKENIZER_EOS_ID,
162
+ LLM_KV_TOKENIZER_EOT_ID,
163
+ LLM_KV_TOKENIZER_EOM_ID,
164
+ LLM_KV_TOKENIZER_UNK_ID,
165
+ LLM_KV_TOKENIZER_SEP_ID,
166
+ LLM_KV_TOKENIZER_PAD_ID,
167
+ LLM_KV_TOKENIZER_CLS_ID,
168
+ LLM_KV_TOKENIZER_MASK_ID,
169
+ LLM_KV_TOKENIZER_ADD_BOS,
170
+ LLM_KV_TOKENIZER_ADD_EOS,
171
+ LLM_KV_TOKENIZER_ADD_PREFIX,
172
+ LLM_KV_TOKENIZER_REMOVE_EXTRA_WS,
173
+ LLM_KV_TOKENIZER_PRECOMPILED_CHARSMAP,
174
+ LLM_KV_TOKENIZER_HF_JSON,
175
+ LLM_KV_TOKENIZER_RWKV,
176
+ LLM_KV_TOKENIZER_FIM_PRE_ID,
177
+ LLM_KV_TOKENIZER_FIM_SUF_ID,
178
+ LLM_KV_TOKENIZER_FIM_MID_ID,
179
+ LLM_KV_TOKENIZER_FIM_PAD_ID,
180
+ LLM_KV_TOKENIZER_FIM_REP_ID,
181
+ LLM_KV_TOKENIZER_FIM_SEP_ID,
182
+
183
+ LLM_KV_ADAPTER_TYPE,
184
+ LLM_KV_ADAPTER_LORA_ALPHA,
185
+
186
+ LLM_KV_POSNET_EMBEDDING_LENGTH,
187
+ LLM_KV_POSNET_BLOCK_COUNT,
188
+
189
+ LLM_KV_CONVNEXT_EMBEDDING_LENGTH,
190
+ LLM_KV_CONVNEXT_BLOCK_COUNT,
191
+
192
+ // deprecated:
193
+ LLM_KV_TOKENIZER_PREFIX_ID,
194
+ LLM_KV_TOKENIZER_SUFFIX_ID,
195
+ LLM_KV_TOKENIZER_MIDDLE_ID,
196
+ };
197
+
198
+ enum llm_tensor {
199
+ LLM_TENSOR_TOKEN_EMBD,
200
+ LLM_TENSOR_TOKEN_EMBD_NORM,
201
+ LLM_TENSOR_TOKEN_TYPES,
202
+ LLM_TENSOR_POS_EMBD,
203
+ LLM_TENSOR_OUTPUT,
204
+ LLM_TENSOR_OUTPUT_NORM,
205
+ LLM_TENSOR_ROPE_FREQS,
206
+ LLM_TENSOR_ROPE_FACTORS_LONG,
207
+ LLM_TENSOR_ROPE_FACTORS_SHORT,
208
+ LLM_TENSOR_ATTN_Q,
209
+ LLM_TENSOR_ATTN_K,
210
+ LLM_TENSOR_ATTN_V,
211
+ LLM_TENSOR_ATTN_QKV,
212
+ LLM_TENSOR_ATTN_OUT,
213
+ LLM_TENSOR_ATTN_NORM,
214
+ LLM_TENSOR_ATTN_NORM_2,
215
+ LLM_TENSOR_ATTN_OUT_NORM,
216
+ LLM_TENSOR_ATTN_POST_NORM,
217
+ LLM_TENSOR_ATTN_ROT_EMBD,
218
+ LLM_TENSOR_FFN_GATE_INP,
219
+ LLM_TENSOR_FFN_GATE_INP_SHEXP,
220
+ LLM_TENSOR_FFN_NORM,
221
+ LLM_TENSOR_FFN_POST_NORM,
222
+ LLM_TENSOR_FFN_GATE,
223
+ LLM_TENSOR_FFN_DOWN,
224
+ LLM_TENSOR_FFN_UP,
225
+ LLM_TENSOR_FFN_ACT,
226
+ LLM_TENSOR_FFN_DOWN_EXP, // split experts for backward compatibility
227
+ LLM_TENSOR_FFN_GATE_EXP,
228
+ LLM_TENSOR_FFN_UP_EXP,
229
+ LLM_TENSOR_FFN_NORM_EXPS,
230
+ LLM_TENSOR_FFN_DOWN_EXPS, // merged experts
231
+ LLM_TENSOR_FFN_GATE_EXPS,
232
+ LLM_TENSOR_FFN_UP_EXPS,
233
+ LLM_TENSOR_FFN_DOWN_SHEXP,
234
+ LLM_TENSOR_FFN_GATE_SHEXP,
235
+ LLM_TENSOR_FFN_UP_SHEXP,
236
+ LLM_TENSOR_FFN_EXP_PROBS_B,
237
+ LLM_TENSOR_ATTN_Q_NORM,
238
+ LLM_TENSOR_ATTN_K_NORM,
239
+ LLM_TENSOR_LAYER_OUT_NORM,
240
+ LLM_TENSOR_SSM_IN,
241
+ LLM_TENSOR_SSM_CONV1D,
242
+ LLM_TENSOR_SSM_X,
243
+ LLM_TENSOR_SSM_DT,
244
+ LLM_TENSOR_SSM_A,
245
+ LLM_TENSOR_SSM_D,
246
+ LLM_TENSOR_SSM_OUT,
247
+ LLM_TENSOR_TIME_MIX_W1,
248
+ LLM_TENSOR_TIME_MIX_W2,
249
+ LLM_TENSOR_TIME_MIX_LERP_X,
250
+ LLM_TENSOR_TIME_MIX_LERP_W,
251
+ LLM_TENSOR_TIME_MIX_LERP_K,
252
+ LLM_TENSOR_TIME_MIX_LERP_V,
253
+ LLM_TENSOR_TIME_MIX_LERP_R,
254
+ LLM_TENSOR_TIME_MIX_LERP_G,
255
+ LLM_TENSOR_TIME_MIX_FIRST,
256
+ LLM_TENSOR_TIME_MIX_DECAY,
257
+ LLM_TENSOR_TIME_MIX_DECAY_W1,
258
+ LLM_TENSOR_TIME_MIX_DECAY_W2,
259
+ LLM_TENSOR_TIME_MIX_KEY,
260
+ LLM_TENSOR_TIME_MIX_VALUE,
261
+ LLM_TENSOR_TIME_MIX_RECEPTANCE,
262
+ LLM_TENSOR_TIME_MIX_GATE,
263
+ LLM_TENSOR_TIME_MIX_LN,
264
+ LLM_TENSOR_TIME_MIX_OUTPUT,
265
+ LLM_TENSOR_CHANNEL_MIX_LERP_K,
266
+ LLM_TENSOR_CHANNEL_MIX_LERP_R,
267
+ LLM_TENSOR_CHANNEL_MIX_KEY,
268
+ LLM_TENSOR_CHANNEL_MIX_RECEPTANCE,
269
+ LLM_TENSOR_CHANNEL_MIX_VALUE,
270
+ LLM_TENSOR_ATTN_Q_A,
271
+ LLM_TENSOR_ATTN_Q_B,
272
+ LLM_TENSOR_ATTN_KV_A_MQA,
273
+ LLM_TENSOR_ATTN_KV_B,
274
+ LLM_TENSOR_ATTN_Q_A_NORM,
275
+ LLM_TENSOR_ATTN_KV_A_NORM,
276
+ LLM_TENSOR_ATTN_SUB_NORM,
277
+ LLM_TENSOR_FFN_SUB_NORM,
278
+ LLM_TENSOR_DEC_ATTN_NORM,
279
+ LLM_TENSOR_DEC_ATTN_Q,
280
+ LLM_TENSOR_DEC_ATTN_K,
281
+ LLM_TENSOR_DEC_ATTN_V,
282
+ LLM_TENSOR_DEC_ATTN_OUT,
283
+ LLM_TENSOR_DEC_ATTN_REL_B,
284
+ LLM_TENSOR_DEC_CROSS_ATTN_NORM,
285
+ LLM_TENSOR_DEC_CROSS_ATTN_Q,
286
+ LLM_TENSOR_DEC_CROSS_ATTN_K,
287
+ LLM_TENSOR_DEC_CROSS_ATTN_V,
288
+ LLM_TENSOR_DEC_CROSS_ATTN_OUT,
289
+ LLM_TENSOR_DEC_CROSS_ATTN_REL_B,
290
+ LLM_TENSOR_DEC_FFN_NORM,
291
+ LLM_TENSOR_DEC_FFN_GATE,
292
+ LLM_TENSOR_DEC_FFN_DOWN,
293
+ LLM_TENSOR_DEC_FFN_UP,
294
+ LLM_TENSOR_DEC_OUTPUT_NORM,
295
+ LLM_TENSOR_ENC_ATTN_NORM,
296
+ LLM_TENSOR_ENC_ATTN_Q,
297
+ LLM_TENSOR_ENC_ATTN_K,
298
+ LLM_TENSOR_ENC_ATTN_V,
299
+ LLM_TENSOR_ENC_ATTN_OUT,
300
+ LLM_TENSOR_ENC_ATTN_REL_B,
301
+ LLM_TENSOR_ENC_FFN_NORM,
302
+ LLM_TENSOR_ENC_FFN_GATE,
303
+ LLM_TENSOR_ENC_FFN_DOWN,
304
+ LLM_TENSOR_ENC_FFN_UP,
305
+ LLM_TENSOR_ENC_OUTPUT_NORM,
306
+ LLM_TENSOR_CLS,
307
+ LLM_TENSOR_CLS_OUT,
308
+ LLM_TENSOR_CONV1D,
309
+ LLM_TENSOR_CONVNEXT_DW,
310
+ LLM_TENSOR_CONVNEXT_NORM,
311
+ LLM_TENSOR_CONVNEXT_PW1,
312
+ LLM_TENSOR_CONVNEXT_PW2,
313
+ LLM_TENSOR_CONVNEXT_GAMMA,
314
+ LLM_TENSOR_POS_NET_CONV1,
315
+ LLM_TENSOR_POS_NET_CONV2,
316
+ LLM_TENSOR_POS_NET_NORM,
317
+ LLM_TENSOR_POS_NET_NORM1,
318
+ LLM_TENSOR_POS_NET_NORM2,
319
+ LLM_TENSOR_POS_NET_ATTN_NORM,
320
+ LLM_TENSOR_POS_NET_ATTN_Q,
321
+ LLM_TENSOR_POS_NET_ATTN_K,
322
+ LLM_TENSOR_POS_NET_ATTN_V,
323
+ LLM_TENSOR_POS_NET_ATTN_OUT,
324
+ };
325
+
326
+ enum llm_tensor_layer {
327
+ LLM_TENSOR_LAYER_INPUT,
328
+ LLM_TENSOR_LAYER_REPEATING,
329
+ LLM_TENSOR_LAYER_OUTPUT,
330
+ };
331
+
332
+ struct LLM_KV {
333
+ LLM_KV(llm_arch arch);
334
+
335
+ llm_arch arch;
336
+
337
+ std::string operator()(llm_kv kv) const;
338
+ };
339
+
340
+ // helper to handle gguf constants
341
+ // usage:
342
+ //
343
+ // const auto tn = LLM_TN(LLM_ARCH_LLAMA);
344
+ //
345
+ // std::string name = tn(LLM_TENSOR_OUTPUT); -> "output"
346
+ // std::string name = tn(LLM_TENSOR_TOKEN_EMBD, "bias"); -> "token_embd.bias"
347
+ // std::string name = tn(LLM_TENSOR_ATTN_NORM, "weight", 3); -> "blk.3.attn_norm.weight"
348
+ //
349
+ struct LLM_TN_IMPL {
350
+ const llm_arch arch;
351
+ const llm_tensor tensor;
352
+ const char * const suffix;
353
+ const int bid;
354
+ const int xid;
355
+
356
+ std::string str() const;
357
+
358
+ operator std::string() const {
359
+ return str();
360
+ }
361
+
362
+ friend bool operator==(const std::string & str, const LLM_TN_IMPL & tn) {
363
+ return str == tn.str();
364
+ }
365
+
366
+ friend bool operator!=(const std::string & str, const LLM_TN_IMPL & tn) {
367
+ return str != tn.str();
368
+ }
369
+ };
370
+
371
+ struct LLM_TN {
372
+ LLM_TN(llm_arch arch) : arch(arch) {}
373
+
374
+ llm_arch arch;
375
+
376
+ LLM_TN_IMPL operator()(llm_tensor tensor, const char * suffix, int bid = -1, int xid = -1) const {
377
+ return { arch, tensor, suffix, bid, xid };
378
+ }
379
+
380
+ LLM_TN_IMPL operator()(llm_tensor tensor, int bid = -1, int xid = -1) const {
381
+ return { arch, tensor, nullptr, bid, xid };
382
+ }
383
+ };
384
+
385
+
386
+ struct llm_tensor_info {
387
+ llm_tensor_layer layer;
388
+ ggml_op op;
389
+ };
390
+
391
+ const char * llm_arch_name(llm_arch arch);
392
+
393
+ llm_arch llm_arch_from_string(const std::string & name);
394
+
395
+ const llm_tensor_info & llm_tensor_info_for(llm_tensor tensor);
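Note (not part of the diff): an illustrative sketch of the LLM_TN helper declared above, following the usage comment in this header; the arch and tensor choices are arbitrary examples.

// sketch: building GGUF tensor names with the LLM_TN helper declared above
#include "llama-arch.h"

#include <string>

static std::string example_tensor_names() {
    const LLM_TN tn(LLM_ARCH_LLAMA);

    const std::string out  = tn(LLM_TENSOR_OUTPUT,    "weight");    // "output.weight"
    const std::string norm = tn(LLM_TENSOR_ATTN_NORM, "weight", 3); // "blk.3.attn_norm.weight"
    const std::string miss = tn(LLM_TENSOR_DEC_ATTN_Q, "weight");   // "__missing__" when the arch has no such tensor

    return out + "\n" + norm + "\n" + miss;
}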
examples/talk-llama/llama-batch.cpp ADDED
@@ -0,0 +1,368 @@
1
+ #include "llama-batch.h"
2
+
3
+ #include <cstring>
4
+ #include <algorithm>
5
+
6
+ llama_ubatch llama_sbatch::reserve_ubatch(size_t n_ubatch, bool has_embd) {
7
+ // clear empty sequences
8
+ // the previous ubatch is assumed to be gone,
9
+ // so nothing should refer to values in these sequences anymore.
10
+ for (size_t i = seq.size(); i-- > 0;) {
11
+ if (seq[i].length == 0) {
12
+ seq.pop_back();
13
+ } else {
14
+ break;
15
+ }
16
+ }
17
+ ubatch_token.resize(!has_embd ? n_ubatch : 0);
18
+ ubatch_embd.resize(has_embd ? n_embd * n_ubatch : 0);
19
+ ubatch_pos.resize(n_ubatch);
20
+ ubatch_n_seq_id.resize(n_ubatch);
21
+ ubatch_seq_id.resize(n_ubatch);
22
+ ubatch_output.resize(n_ubatch);
23
+ llama_ubatch ubatch = {
24
+ /*equal_seqs =*/ true,
25
+ /*n_tokens =*/ 0,
26
+ /*n_seq_tokens =*/ 0,
27
+ /*n_seqs =*/ 0,
28
+ /*token =*/ !has_embd ? ubatch_token.data() : nullptr,
29
+ /*embd =*/ has_embd ? ubatch_embd.data() : nullptr,
30
+ /*pos =*/ ubatch_pos.data(),
31
+ /*n_seq_id =*/ ubatch_n_seq_id.data(),
32
+ /*seq_id =*/ ubatch_seq_id.data(),
33
+ /*output =*/ ubatch_output.data(),
34
+ };
35
+ return ubatch;
36
+ }
37
+
38
+ void llama_sbatch::add_seq_to_ubatch(llama_ubatch & ubatch, llama_sbatch_seq & seq, size_t length) {
39
+ GGML_ASSERT(batch != nullptr);
40
+ GGML_ASSERT(length <= seq.length);
41
+ // Can only add sequences of equal lengths to a batch,
42
+ // otherwise it isn't clear to which sequence a token belongs
43
+ GGML_ASSERT(seq.n_seq_id == 0 || ubatch.n_seqs == 0 || length == (size_t) ubatch.n_tokens / ubatch.n_seqs);
44
+ GGML_ASSERT((seq.n_seq_id != 0) == ubatch.equal_seqs);
45
+ // NOTE: loops are separated for cache-friendliness
46
+ if (batch->token) {
47
+ if (ubatch.equal_seqs) {
48
+ for (size_t i = 0; i < length; ++i) {
49
+ ubatch.token[ubatch.n_tokens + i] = batch->token[ids[seq.offset + i]];
50
+ }
51
+ } else {
52
+ // simple split
53
+ ubatch.token = batch->token + seq.offset;
54
+ }
55
+ } else {
56
+ ubatch.token = nullptr;
57
+ }
58
+ if (batch->embd) {
59
+ if (ubatch.equal_seqs) {
60
+ for (size_t i = 0; i < length; ++i) {
61
+ memcpy(
62
+ ubatch.embd + (n_embd * (ubatch.n_tokens + i)),
63
+ batch->embd + (n_embd * ids[seq.offset + i]),
64
+ n_embd * sizeof(float)
65
+ );
66
+ }
67
+ } else {
68
+ // simple split
69
+ ubatch.embd = batch->embd + (n_embd * seq.offset);
70
+ }
71
+ } else {
72
+ ubatch.embd = nullptr;
73
+ }
74
+ if (ubatch.equal_seqs) {
75
+ for (size_t i = 0; i < length; ++i) {
76
+ ubatch.pos[ubatch.n_tokens + i] = batch->pos[ids[seq.offset + i]];
77
+ }
78
+ } else {
79
+ // simple split
80
+ ubatch.pos = batch->pos + seq.offset;
81
+ }
82
+ if (ubatch.equal_seqs) {
83
+ ubatch.n_seq_id[ubatch.n_seqs] = seq.n_seq_id;
84
+ if (seq.seq_id) {
85
+ ubatch.seq_id[ubatch.n_seqs] = seq.seq_id;
86
+ }
87
+ } else {
88
+ // simple split
89
+ if (batch->n_seq_id) {
90
+ ubatch.n_seq_id = batch->n_seq_id + seq.offset;
91
+ } else {
92
+ for (size_t i = 0; i < length; ++i) {
93
+ ubatch.n_seq_id[ubatch.n_seqs + i] = 1;
94
+ }
95
+ }
96
+ if (batch->seq_id) {
97
+ ubatch.seq_id = batch->seq_id + seq.offset;
98
+ }
99
+ }
100
+ if (logits_all) {
101
+ for (size_t i = 0; i < length; ++i) {
102
+ ubatch.output[ubatch.n_tokens + i] = 1;
103
+ out_ids.push_back(ids[seq.offset + i]);
104
+ }
105
+ } else if (batch->logits) {
106
+ if (ubatch.equal_seqs) {
107
+ for (size_t i = 0; i < length; ++i) {
108
+ size_t id = ids[seq.offset + i];
109
+ int8_t is_output = batch->logits[id];
110
+ ubatch.output[ubatch.n_tokens + i] = is_output;
111
+ if (is_output) { out_ids.push_back(id); }
112
+ }
113
+ } else {
114
+ // simple split
115
+ ubatch.output = batch->logits + seq.offset;
116
+ for (size_t i = 0; i < length; ++i) {
117
+ if (ubatch.output[i] != 0) { out_ids.push_back(seq.offset + i); }
118
+ }
119
+ }
120
+ } else {
121
+ // only get last output
122
+ for (size_t i = 0; i < length; ++i) {
123
+ size_t id = ids[seq.offset + i];
124
+ int8_t is_last = id == ids.size() - 1;
125
+ ubatch.output[ubatch.n_tokens + i] = is_last;
126
+ if (is_last) { out_ids.push_back(id); }
127
+ }
128
+ }
129
+ if (ubatch.n_tokens == 0 && ubatch.n_seqs == 0) {
130
+ ubatch.n_seq_tokens = ubatch.equal_seqs ? length : 1;
131
+ }
132
+ ubatch.n_tokens += length;
133
+ ubatch.n_seqs += ubatch.equal_seqs ? 1 : length; // virtual sequences for simple splits
134
+ seq.offset += length;
135
+ seq.length -= length;
136
+ n_tokens -= length;
137
+ GGML_ASSERT(ubatch.n_tokens == ubatch.n_seq_tokens * ubatch.n_seqs);
138
+ }
139
+
140
+ llama_ubatch llama_sbatch::split_simple(size_t n_ubatch) {
141
+ n_ubatch = n_tokens < n_ubatch ? n_tokens : n_ubatch;
142
+ llama_ubatch ubatch = reserve_ubatch(n_ubatch, /* has_embd */ batch->embd != nullptr);
143
+ ubatch.equal_seqs = false;
144
+ if (!seq.empty()) {
145
+ llama_sbatch_seq & s = seq[0];
146
+ size_t length = s.length < n_ubatch ? s.length : n_ubatch;
147
+ GGML_ASSERT(seq.size() == 1 && s.n_seq_id == 0); // don't mix with other splits
148
+ add_seq_to_ubatch(ubatch, s, length);
149
+ }
150
+ return ubatch;
151
+ }
152
+
153
+ llama_ubatch llama_sbatch::split_equal(size_t n_ubatch) {
154
+ n_ubatch = n_tokens < n_ubatch ? n_tokens : n_ubatch;
155
+ llama_ubatch ubatch = reserve_ubatch(n_ubatch, /* has_embd */ batch->embd != nullptr);
156
+ if (!seq.empty()) {
157
+ size_t length = 0;
158
+ size_t n_tokens_in_ubatch = 0;
159
+ GGML_ASSERT(seq[0].n_seq_id > 0); // should not be mixed with simple splits
160
+ // smallest first, because it's easier to split this way;
161
+ // starting from the end to pop in constant time.
162
+ for (size_t i = seq.size(); i-- > 0;) {
163
+ llama_sbatch_seq & s = seq[i];
164
+ GGML_ASSERT(s.length > 0);
165
+ if (length == 0) {
166
+ length = s.length < n_ubatch ? s.length : n_ubatch;
167
+ }
168
+ add_seq_to_ubatch(ubatch, s, length);
169
+ n_tokens_in_ubatch += length;
170
+ // shared prompts can't be mixed with any of their sequences,
171
+ // so it's safer to compute them in their own ubatch
172
+ if (s.n_seq_id > 1) { break; }
173
+ // stop when there isn't enough space for another sequence
174
+ if (length + n_tokens_in_ubatch > n_ubatch) { break; }
175
+ }
176
+ }
177
+ return ubatch;
178
+ }
179
+
180
+ llama_ubatch llama_sbatch::split_seq(size_t n_ubatch) {
181
+ n_ubatch = n_tokens < n_ubatch ? n_tokens : n_ubatch;
182
+ llama_ubatch ubatch = reserve_ubatch(n_ubatch, /* has_embd */ batch->embd != nullptr);
183
+ if (!seq.empty()) {
184
+ llama_sbatch_seq & s = seq[seq.size() - 1];
185
+ size_t length = s.length < n_ubatch ? s.length : n_ubatch;
186
+ GGML_ASSERT(s.n_seq_id > 0); // should not be mixed with simple splits
187
+ add_seq_to_ubatch(ubatch, s, length);
188
+ }
189
+ return ubatch;
190
+ }
191
+
192
+ void llama_sbatch::from_batch(const llama_batch & batch, size_t n_embd, bool simple_split, bool logits_all) {
193
+ GGML_ASSERT(batch.n_tokens >= 0);
194
+ this->batch = &batch;
195
+ this->n_embd = n_embd;
196
+ this->logits_all = logits_all;
197
+
198
+ n_tokens = batch.n_tokens;
199
+ ids.resize(n_tokens);
200
+ out_ids.clear();
201
+ // TODO: reserve out_ids and seq
202
+
203
+ for (size_t i = 0; i < n_tokens; ++i) {
204
+ ids[i] = i;
205
+ }
206
+ if (simple_split) {
207
+ seq.resize(1);
208
+ llama_sbatch_seq & s = seq[0];
209
+ s.n_seq_id = 0;
210
+ s.seq_id = nullptr;
211
+ s.offset = 0;
212
+ s.length = n_tokens;
213
+ return;
214
+ }
215
+ std::sort(ids.begin(), ids.end(),
216
+ [&batch](size_t a, size_t b) {
217
+ int32_t n_seq_a = batch.n_seq_id ? batch.n_seq_id[a] : 1;
218
+ int32_t n_seq_b = batch.n_seq_id ? batch.n_seq_id[b] : 1;
219
+ // sort by seq_id, then by pos
220
+ if (n_seq_a == n_seq_b) {
221
+ if (batch.seq_id) {
222
+ for (int32_t i = 0; i < n_seq_a; ++i) {
223
+ llama_seq_id seq_id_a = batch.seq_id[a][i];
224
+ llama_seq_id seq_id_b = batch.seq_id[b][i];
225
+ // smaller seq_ids go first
226
+ if (seq_id_a != seq_id_b) {
227
+ return seq_id_a < seq_id_b;
228
+ }
229
+ }
230
+ }
231
+ // when all else is equal, sort by pos
232
+ if (batch.pos) {
233
+ return batch.pos[a] < batch.pos[b];
234
+ }
235
+ // no pos, sort by id
236
+ return a < b;
237
+ }
238
+ // shared prompts go first
239
+ return n_seq_a > n_seq_b;
240
+ }
241
+ );
242
+ // init seq
243
+ llama_sbatch_seq * last_seq = nullptr;
244
+
245
+ for (size_t i = 0; i < n_tokens; ++i) {
246
+ const size_t bi = ids[i];
247
+ const int32_t n_seqs = batch.n_seq_id[bi];
248
+ llama_seq_id * seq_ids = batch.seq_id[bi];
249
+ if (last_seq != nullptr) {
250
+ bool same = n_seqs == last_seq->n_seq_id;
251
+ for (int32_t j = 0; same && j < n_seqs; ++j) {
252
+ if (seq_ids[j] != last_seq->seq_id[j]) {
253
+ same = false;
254
+ }
255
+ }
256
+ if (same) {
257
+ last_seq->length += 1;
258
+ continue;
259
+ }
260
+ }
261
+ llama_sbatch_seq new_seq = {n_seqs, seq_ids, i, 1};
262
+ seq.push_back(new_seq);
263
+ last_seq = &seq.back();
264
+ }
265
+ // keep shared prompts first at the end, then sort by length descending.
266
+ std::sort(seq.begin(), seq.end(),
267
+ [](llama_sbatch_seq & a, llama_sbatch_seq & b) {
268
+ if (a.n_seq_id == b.n_seq_id) {
269
+ return a.length > b.length;
270
+ }
271
+ return a.n_seq_id < b.n_seq_id;
272
+ }
273
+ );
274
+ }
275
+
276
+ llama_batch_allocr::llama_batch_allocr(struct llama_batch in_batch, llama_pos p0) {
277
+ batch = in_batch;
278
+ GGML_ASSERT(batch.n_tokens > 0);
279
+ if (!batch.pos) {
280
+ pos.resize(batch.n_tokens);
281
+ for (int32_t i = 0; i < batch.n_tokens; i++) {
282
+ pos[i] = i + p0;
283
+ }
284
+ batch.pos = pos.data();
285
+ }
286
+ if (!batch.n_seq_id) {
287
+ n_seq_id.resize(batch.n_tokens);
288
+ for (int32_t i = 0; i < batch.n_tokens; i++) {
289
+ n_seq_id[i] = seq_id_0.size();
290
+ }
291
+ batch.n_seq_id = n_seq_id.data();
292
+ }
293
+ if (!batch.seq_id) {
294
+ seq_id.resize(batch.n_tokens + 1);
295
+ seq_id[batch.n_tokens] = NULL;
296
+ for (int32_t i = 0; i < batch.n_tokens; i++) {
297
+ seq_id[i] = seq_id_0.data();
298
+ }
299
+ batch.seq_id = seq_id.data();
300
+ }
301
+ if (!batch.logits) {
302
+ logits.resize(batch.n_tokens);
303
+ logits[logits.size() - 1] = true;
304
+ batch.logits = logits.data();
305
+ }
306
+ }
307
+
308
+ //
309
+ // interface implementation
310
+ //
311
+
312
+ struct llama_batch llama_batch_get_one(
313
+ llama_token * tokens,
314
+ int32_t n_tokens) {
315
+ return {
316
+ /*n_tokens =*/ n_tokens,
317
+ /*tokens =*/ tokens,
318
+ /*embd =*/ nullptr,
319
+ /*pos =*/ nullptr,
320
+ /*n_seq_id =*/ nullptr,
321
+ /*seq_id =*/ nullptr,
322
+ /*logits =*/ nullptr,
323
+ };
324
+ }
325
+
326
+ struct llama_batch llama_batch_init(int32_t n_tokens_alloc, int32_t embd, int32_t n_seq_max) {
327
+ llama_batch batch = {
328
+ /*n_tokens =*/ 0,
329
+ /*tokens =*/ nullptr,
330
+ /*embd =*/ nullptr,
331
+ /*pos =*/ nullptr,
332
+ /*n_seq_id =*/ nullptr,
333
+ /*seq_id =*/ nullptr,
334
+ /*logits =*/ nullptr,
335
+ };
336
+
337
+ if (embd) {
338
+ batch.embd = (float *) malloc(sizeof(float) * n_tokens_alloc * embd);
339
+ } else {
340
+ batch.token = (llama_token *) malloc(sizeof(llama_token) * n_tokens_alloc);
341
+ }
342
+
343
+ batch.pos = (llama_pos *) malloc(sizeof(llama_pos) * n_tokens_alloc);
344
+ batch.n_seq_id = (int32_t *) malloc(sizeof(int32_t) * n_tokens_alloc);
345
+ batch.seq_id = (llama_seq_id **) malloc(sizeof(llama_seq_id *) * (n_tokens_alloc + 1));
346
+ for (int i = 0; i < n_tokens_alloc; ++i) {
347
+ batch.seq_id[i] = (llama_seq_id *) malloc(sizeof(llama_seq_id) * n_seq_max);
348
+ }
349
+ batch.seq_id[n_tokens_alloc] = nullptr;
350
+
351
+ batch.logits = (int8_t *) malloc(sizeof(int8_t) * n_tokens_alloc);
352
+
353
+ return batch;
354
+ }
355
+
356
+ void llama_batch_free(struct llama_batch batch) {
357
+ if (batch.token) free(batch.token);
358
+ if (batch.embd) free(batch.embd);
359
+ if (batch.pos) free(batch.pos);
360
+ if (batch.n_seq_id) free(batch.n_seq_id);
361
+ if (batch.seq_id) {
362
+ for (int i = 0; batch.seq_id[i] != nullptr; ++i) {
363
+ free(batch.seq_id[i]);
364
+ }
365
+ free(batch.seq_id);
366
+ }
367
+ if (batch.logits) free(batch.logits);
368
+ }
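Note (not part of the diff): a hedged sketch of the expected init/fill/free pattern for the public batch API implemented above; the token ids and sequence ids are placeholders.

// sketch: init/fill/free pattern for the public llama_batch API (placeholder token ids)
#include "llama.h"

int main() {
    const int32_t n_tokens = 4;

    llama_batch batch = llama_batch_init(/*n_tokens_alloc =*/ n_tokens, /*embd =*/ 0, /*n_seq_max =*/ 1);

    for (int32_t i = 0; i < n_tokens; ++i) {
        batch.token   [i] = 1 + i;               // placeholder token ids
        batch.pos     [i] = i;
        batch.n_seq_id[i] = 1;
        batch.seq_id  [i][0] = 0;                // single sequence 0
        batch.logits  [i] = (i == n_tokens - 1); // request output only for the last token
    }
    batch.n_tokens = n_tokens;

    // ... llama_decode(ctx, batch) would normally run here ...

    llama_batch_free(batch);
    return 0;
}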
examples/talk-llama/llama-batch.h ADDED
@@ -0,0 +1,88 @@
1
+ #pragma once
2
+
3
+ #include "llama.h"
4
+
5
+ #include <array>
6
+ #include <vector>
7
+
8
+ // very similar to llama_batch,
9
+ // but has more metadata about sequences
10
+ struct llama_ubatch {
11
+ bool equal_seqs;
12
+ // TODO: whole_seqs for embeddings?
13
+
14
+ uint32_t n_tokens; // total tokens (n_seq_tokens * n_seqs)
15
+ uint32_t n_seq_tokens; // tokens per sequence
16
+ uint32_t n_seqs;
17
+
18
+ llama_token * token; // [n_tokens]
19
+ float * embd; // [n_embd, n_tokens]
20
+ llama_pos * pos; // [n_tokens]
21
+ int32_t * n_seq_id; // [n_seqs]
22
+ llama_seq_id ** seq_id; // [n_seqs]
23
+ int8_t * output; // [n_tokens]
24
+ };
25
+
26
+ struct llama_sbatch_seq {
27
+ int32_t n_seq_id;
28
+
29
+ llama_seq_id * seq_id;
30
+
31
+ size_t offset;
32
+ size_t length;
33
+ };
34
+
35
+ // sequence-length-aware batch splitting
36
+ struct llama_sbatch {
37
+ // tokens left in this batch
38
+ size_t n_tokens;
39
+
40
+ size_t n_embd;
41
+
42
+ bool logits_all; // TODO: remove once lctx.logits_all is removed too
43
+
44
+ // sorted indices into the batch
45
+ std::vector<size_t> ids;
46
+ // batch indices of the output
47
+ std::vector<size_t> out_ids;
48
+ std::vector<llama_sbatch_seq> seq;
49
+
50
+ const llama_batch * batch = nullptr;
51
+
52
+ // buffers for the ubatch
53
+ std::vector<llama_token> ubatch_token;
54
+ std::vector<float> ubatch_embd;
55
+ std::vector<llama_pos> ubatch_pos;
56
+ std::vector<int32_t> ubatch_n_seq_id;
57
+ std::vector<llama_seq_id *> ubatch_seq_id;
58
+ std::vector<int8_t> ubatch_output;
59
+
60
+ llama_ubatch reserve_ubatch(size_t n_ubatch, bool has_embd = false);
61
+
62
+ void add_seq_to_ubatch(llama_ubatch & ubatch, llama_sbatch_seq & seq, size_t length);
63
+
64
+ // simple split, unknown number of sequences of unequal lengths
65
+ llama_ubatch split_simple(size_t n_ubatch);
66
+
67
+ // make batches of equal-length sequences
68
+ llama_ubatch split_equal(size_t n_ubatch);
69
+
70
+ // sequence-wise split
71
+ llama_ubatch split_seq(size_t n_ubatch);
72
+
73
+ void from_batch(const llama_batch & batch, size_t n_embd, bool simple_split = false, bool logits_all = false);
74
+ };
75
+
76
+ // temporary allocate memory for the input batch if needed
77
+ struct llama_batch_allocr {
78
+ struct llama_batch batch;
79
+
80
+ std::array<llama_seq_id, 1> seq_id_0 = { 0 }; // default sequence id
81
+ std::vector<llama_pos> pos;
82
+ std::vector<int32_t> n_seq_id;
83
+ std::vector<llama_seq_id *> seq_id;
84
+ std::vector<int8_t> logits;
85
+
86
+ // optionally fulfill the batch returned by llama_batch_get_one
87
+ llama_batch_allocr(struct llama_batch in_batch, llama_pos p0);
88
+ };
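Note (not part of the diff): a rough sketch of how the internal llama_sbatch declared above is driven inside llama.cpp (load a llama_batch, then pull micro-batches until no tokens remain); n_embd and n_ubatch are placeholder parameters.

// sketch: driving the internal llama_sbatch (mirrors the split loop used during decode)
#include "llama-batch.h"

static void consume_batch(const llama_batch & batch, size_t n_embd, size_t n_ubatch) {
    llama_sbatch sbatch;
    sbatch.from_batch(batch, n_embd, /*simple_split =*/ true, /*logits_all =*/ false);

    while (sbatch.n_tokens > 0) {
        // simple split: one contiguous slice of at most n_ubatch tokens
        llama_ubatch ubatch = sbatch.split_simple(n_ubatch);
        // ... build and evaluate the compute graph for ubatch here ...
        (void) ubatch;
    }
}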
examples/talk-llama/llama-chat.cpp ADDED
@@ -0,0 +1,567 @@
1
+ #include "llama-chat.h"
2
+
3
+ #include "llama.h"
4
+
5
+ #include <map>
6
+ #include <sstream>
7
+
8
+ #if __cplusplus >= 202000L
9
+ #define LU8(x) (const char*)(u8##x)
10
+ #else
11
+ #define LU8(x) u8##x
12
+ #endif
13
+
14
+ // trim whitespace from the beginning and end of a string
15
+ static std::string trim(const std::string & str) {
16
+ size_t start = 0;
17
+ size_t end = str.size();
18
+ while (start < end && isspace(str[start])) {
19
+ start += 1;
20
+ }
21
+ while (end > start && isspace(str[end - 1])) {
22
+ end -= 1;
23
+ }
24
+ return str.substr(start, end - start);
25
+ }
26
+
27
+ static const std::map<std::string, llm_chat_template> LLM_CHAT_TEMPLATES = {
28
+ { "chatml", LLM_CHAT_TEMPLATE_CHATML },
29
+ { "llama2", LLM_CHAT_TEMPLATE_LLAMA_2 },
30
+ { "llama2-sys", LLM_CHAT_TEMPLATE_LLAMA_2_SYS },
31
+ { "llama2-sys-bos", LLM_CHAT_TEMPLATE_LLAMA_2_SYS_BOS },
32
+ { "llama2-sys-strip", LLM_CHAT_TEMPLATE_LLAMA_2_SYS_STRIP },
33
+ { "mistral-v1", LLM_CHAT_TEMPLATE_MISTRAL_V1 },
34
+ { "mistral-v3", LLM_CHAT_TEMPLATE_MISTRAL_V3 },
35
+ { "mistral-v3-tekken", LLM_CHAT_TEMPLATE_MISTRAL_V3_TEKKEN },
36
+ { "mistral-v7", LLM_CHAT_TEMPLATE_MISTRAL_V7 },
37
+ { "phi3", LLM_CHAT_TEMPLATE_PHI_3 },
38
+ { "falcon3", LLM_CHAT_TEMPLATE_FALCON_3 },
39
+ { "zephyr", LLM_CHAT_TEMPLATE_ZEPHYR },
40
+ { "monarch", LLM_CHAT_TEMPLATE_MONARCH },
41
+ { "gemma", LLM_CHAT_TEMPLATE_GEMMA },
42
+ { "orion", LLM_CHAT_TEMPLATE_ORION },
43
+ { "openchat", LLM_CHAT_TEMPLATE_OPENCHAT },
44
+ { "vicuna", LLM_CHAT_TEMPLATE_VICUNA },
45
+ { "vicuna-orca", LLM_CHAT_TEMPLATE_VICUNA_ORCA },
46
+ { "deepseek", LLM_CHAT_TEMPLATE_DEEPSEEK },
47
+ { "deepseek2", LLM_CHAT_TEMPLATE_DEEPSEEK_2 },
48
+ { "deepseek3", LLM_CHAT_TEMPLATE_DEEPSEEK_3 },
49
+ { "command-r", LLM_CHAT_TEMPLATE_COMMAND_R },
50
+ { "llama3", LLM_CHAT_TEMPLATE_LLAMA_3 },
51
+ { "chatglm3", LLM_CHAT_TEMPLATE_CHATGML_3 },
52
+ { "chatglm4", LLM_CHAT_TEMPLATE_CHATGML_4 },
53
+ { "minicpm", LLM_CHAT_TEMPLATE_MINICPM },
54
+ { "exaone3", LLM_CHAT_TEMPLATE_EXAONE_3 },
55
+ { "rwkv-world", LLM_CHAT_TEMPLATE_RWKV_WORLD },
56
+ { "granite", LLM_CHAT_TEMPLATE_GRANITE },
57
+ { "gigachat", LLM_CHAT_TEMPLATE_GIGACHAT },
58
+ { "megrez", LLM_CHAT_TEMPLATE_MEGREZ },
59
+ };
60
+
61
+ llm_chat_template llm_chat_template_from_str(const std::string & name) {
62
+ return LLM_CHAT_TEMPLATES.at(name);
63
+ }
64
+
65
+ llm_chat_template llm_chat_detect_template(const std::string & tmpl) {
66
+ try {
67
+ return llm_chat_template_from_str(tmpl);
68
+ } catch (const std::out_of_range &) {
69
+ // ignore
70
+ }
71
+
72
+ auto tmpl_contains = [&tmpl](const char * haystack) -> bool {
73
+ return tmpl.find(haystack) != std::string::npos;
74
+ };
75
+ if (tmpl_contains("<|im_start|>")) {
76
+ return LLM_CHAT_TEMPLATE_CHATML;
77
+ } else if (tmpl.find("mistral") == 0 || tmpl_contains("[INST]")) {
78
+ if (tmpl_contains("[SYSTEM_PROMPT]")) {
79
+ return LLM_CHAT_TEMPLATE_MISTRAL_V7;
80
+ } else if (
81
+ // catches official 'v1' template
82
+ tmpl_contains("' [INST] ' + system_message")
83
+ // catches official 'v3' and 'v3-tekken' templates
84
+ || tmpl_contains("[AVAILABLE_TOOLS]")
85
+ ) {
86
+ // Official mistral 'v1', 'v3' and 'v3-tekken' templates
87
+ // See: https://github.com/mistralai/cookbook/blob/main/concept-deep-dive/tokenization/chat_templates.md
88
+ // See: https://github.com/mistralai/cookbook/blob/main/concept-deep-dive/tokenization/templates.md
89
+ if (tmpl_contains(" [INST]")) {
90
+ return LLM_CHAT_TEMPLATE_MISTRAL_V1;
91
+ } else if (tmpl_contains("\"[INST]\"")) {
92
+ return LLM_CHAT_TEMPLATE_MISTRAL_V3_TEKKEN;
93
+ }
94
+ return LLM_CHAT_TEMPLATE_MISTRAL_V3;
95
+ } else {
96
+ // llama2 template and its variants
97
+ // [variant] support system message
98
+ // See: https://huggingface.co/blog/llama2#how-to-prompt-llama-2
99
+ bool support_system_message = tmpl_contains("<<SYS>>");
100
+ bool add_bos_inside_history = tmpl_contains("bos_token + '[INST]");
101
+ bool strip_message = tmpl_contains("content.strip()");
102
+ if (strip_message) {
103
+ return LLM_CHAT_TEMPLATE_LLAMA_2_SYS_STRIP;
104
+ } else if (add_bos_inside_history) {
105
+ return LLM_CHAT_TEMPLATE_LLAMA_2_SYS_BOS;
106
+ } else if (support_system_message) {
107
+ return LLM_CHAT_TEMPLATE_LLAMA_2_SYS;
108
+ } else {
109
+ return LLM_CHAT_TEMPLATE_LLAMA_2;
110
+ }
111
+ }
112
+ } else if (tmpl_contains("<|assistant|>") && tmpl_contains("<|end|>")) {
113
+ return LLM_CHAT_TEMPLATE_PHI_3;
114
+ } else if (tmpl_contains("<|assistant|>") && tmpl_contains("<|user|>")) {
115
+ return LLM_CHAT_TEMPLATE_FALCON_3;
116
+ } else if (tmpl_contains("<|user|>") && tmpl_contains("<|endoftext|>")) {
117
+ return LLM_CHAT_TEMPLATE_ZEPHYR;
118
+ } else if (tmpl_contains("bos_token + message['role']")) {
119
+ return LLM_CHAT_TEMPLATE_MONARCH;
120
+ } else if (tmpl_contains("<start_of_turn>")) {
121
+ return LLM_CHAT_TEMPLATE_GEMMA;
122
+ } else if (tmpl_contains("'\\n\\nAssistant: ' + eos_token")) {
123
+ // OrionStarAI/Orion-14B-Chat
124
+ return LLM_CHAT_TEMPLATE_ORION;
125
+ } else if (tmpl_contains("GPT4 Correct ")) {
126
+ // openchat/openchat-3.5-0106
127
+ return LLM_CHAT_TEMPLATE_OPENCHAT;
128
+ } else if (tmpl_contains("USER: ") && tmpl_contains("ASSISTANT: ")) {
129
+ // eachadea/vicuna-13b-1.1 (and Orca variant)
130
+ if (tmpl_contains("SYSTEM: ")) {
131
+ return LLM_CHAT_TEMPLATE_VICUNA_ORCA;
132
+ }
133
+ return LLM_CHAT_TEMPLATE_VICUNA;
134
+ } else if (tmpl_contains("### Instruction:") && tmpl_contains("<|EOT|>")) {
135
+ // deepseek-ai/deepseek-coder-33b-instruct
136
+ return LLM_CHAT_TEMPLATE_DEEPSEEK;
137
+ } else if (tmpl_contains("<|START_OF_TURN_TOKEN|>") && tmpl_contains("<|USER_TOKEN|>")) {
138
+ // CohereForAI/c4ai-command-r-plus
139
+ return LLM_CHAT_TEMPLATE_COMMAND_R;
140
+ } else if (tmpl_contains("<|start_header_id|>") && tmpl_contains("<|end_header_id|>")) {
141
+ return LLM_CHAT_TEMPLATE_LLAMA_3;
142
+ } else if (tmpl_contains("[gMASK]sop")) {
143
+ // chatglm3-6b
144
+ return LLM_CHAT_TEMPLATE_CHATGML_3;
145
+ } else if (tmpl_contains("[gMASK]<sop>")) {
146
+ return LLM_CHAT_TEMPLATE_CHATGML_4;
147
+ } else if (tmpl_contains(LU8("<用户>"))) {
148
+ // MiniCPM-3B-OpenHermes-2.5-v2-GGUF
149
+ return LLM_CHAT_TEMPLATE_MINICPM;
150
+ } else if (tmpl_contains("'Assistant: ' + message['content'] + eos_token")) {
151
+ return LLM_CHAT_TEMPLATE_DEEPSEEK_2;
152
+ } else if (tmpl_contains(LU8("'<|Assistant|>' + message['content'] + '<|end▁of▁sentence|>'"))) {
153
+ return LLM_CHAT_TEMPLATE_DEEPSEEK_3;
154
+ } else if (tmpl_contains("[|system|]") && tmpl_contains("[|assistant|]") && tmpl_contains("[|endofturn|]")) {
155
+ // ref: https://huggingface.co/LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct/discussions/8#66bae61b1893d14ee8ed85bb
156
+ // EXAONE-3.0-7.8B-Instruct
157
+ return LLM_CHAT_TEMPLATE_EXAONE_3;
158
+ } else if (tmpl_contains("rwkv-world")) {
159
+ return LLM_CHAT_TEMPLATE_RWKV_WORLD;
160
+ } else if (tmpl_contains("<|start_of_role|>")) {
161
+ return LLM_CHAT_TEMPLATE_GRANITE;
162
+ } else if (tmpl_contains("message['role'] + additional_special_tokens[0] + message['content'] + additional_special_tokens[1]")) {
163
+ return LLM_CHAT_TEMPLATE_GIGACHAT;
164
+ } else if (tmpl_contains("<|role_start|>")) {
165
+ return LLM_CHAT_TEMPLATE_MEGREZ;
166
+ }
167
+ return LLM_CHAT_TEMPLATE_UNKNOWN;
168
+ }
169
+
170
+ // Simple version of "llama_apply_chat_template" that only works with strings
171
+ // This function uses heuristic checks to determine the commonly used template. It is not a Jinja parser.
172
+ int32_t llm_chat_apply_template(
173
+ llm_chat_template tmpl,
174
+ const std::vector<const llama_chat_message *> & chat,
175
+ std::string & dest, bool add_ass) {
176
+ // Taken from the research: https://github.com/ggerganov/llama.cpp/issues/5527
177
+ std::stringstream ss;
178
+ if (tmpl == LLM_CHAT_TEMPLATE_CHATML) {
179
+ // chatml template
180
+ for (auto message : chat) {
181
+ ss << "<|im_start|>" << message->role << "\n" << message->content << "<|im_end|>\n";
182
+ }
183
+ if (add_ass) {
184
+ ss << "<|im_start|>assistant\n";
185
+ }
186
+ } else if (tmpl == LLM_CHAT_TEMPLATE_MISTRAL_V7) {
187
+ // Official mistral 'v7' template
188
+ // See: https://huggingface.co/mistralai/Mistral-Large-Instruct-2411#basic-instruct-template-v7
189
+ for (auto message : chat) {
190
+ std::string role(message->role);
191
+ std::string content(message->content);
192
+ if (role == "system") {
193
+ ss << "[SYSTEM_PROMPT] " << content << "[/SYSTEM_PROMPT]";
194
+ } else if (role == "user") {
195
+ ss << "[INST] " << content << "[/INST]";
196
+ }
197
+ else {
198
+ ss << " " << content << "</s>";
199
+ }
200
+ }
201
+ } else if (tmpl == LLM_CHAT_TEMPLATE_MISTRAL_V1
202
+ || tmpl == LLM_CHAT_TEMPLATE_MISTRAL_V3
203
+ || tmpl == LLM_CHAT_TEMPLATE_MISTRAL_V3_TEKKEN) {
204
+ // See: https://github.com/mistralai/cookbook/blob/main/concept-deep-dive/tokenization/chat_templates.md
205
+ // See: https://github.com/mistralai/cookbook/blob/main/concept-deep-dive/tokenization/templates.md
206
+ std::string leading_space = tmpl == LLM_CHAT_TEMPLATE_MISTRAL_V1 ? " " : "";
207
+ std::string trailing_space = tmpl == LLM_CHAT_TEMPLATE_MISTRAL_V3_TEKKEN ? "" : " ";
208
+ bool trim_assistant_message = tmpl == LLM_CHAT_TEMPLATE_MISTRAL_V3;
209
+ bool is_inside_turn = false;
210
+ for (auto message : chat) {
211
+ if (!is_inside_turn) {
212
+ ss << leading_space << "[INST]" << trailing_space;
213
+ is_inside_turn = true;
214
+ }
215
+ std::string role(message->role);
216
+ std::string content(message->content);
217
+ if (role == "system") {
218
+ ss << content << "\n\n";
219
+ } else if (role == "user") {
220
+ ss << content << leading_space << "[/INST]";
221
+ } else {
222
+ ss << trailing_space << (trim_assistant_message ? trim(content) : content) << "</s>";
223
+ is_inside_turn = false;
224
+ }
225
+ }
226
+ } else if (
227
+ tmpl == LLM_CHAT_TEMPLATE_LLAMA_2
228
+ || tmpl == LLM_CHAT_TEMPLATE_LLAMA_2_SYS
229
+ || tmpl == LLM_CHAT_TEMPLATE_LLAMA_2_SYS_BOS
230
+ || tmpl == LLM_CHAT_TEMPLATE_LLAMA_2_SYS_STRIP) {
231
+ // llama2 template and its variants
232
+ // [variant] support system message
233
+ // See: https://huggingface.co/blog/llama2#how-to-prompt-llama-2
234
+ bool support_system_message = tmpl != LLM_CHAT_TEMPLATE_LLAMA_2;
235
+ // [variant] add BOS inside history
236
+ bool add_bos_inside_history = tmpl == LLM_CHAT_TEMPLATE_LLAMA_2_SYS_BOS;
237
+ // [variant] trim spaces from the input message
238
+ bool strip_message = tmpl == LLM_CHAT_TEMPLATE_LLAMA_2_SYS_STRIP;
239
+ // construct the prompt
240
+ bool is_inside_turn = true; // skip BOS at the beginning
241
+ ss << "[INST] ";
242
+ for (auto message : chat) {
243
+ std::string content = strip_message ? trim(message->content) : message->content;
244
+ std::string role(message->role);
245
+ if (!is_inside_turn) {
246
+ is_inside_turn = true;
247
+ ss << (add_bos_inside_history ? "<s>[INST] " : "[INST] ");
248
+ }
249
+ if (role == "system") {
250
+ if (support_system_message) {
251
+ ss << "<<SYS>>\n" << content << "\n<</SYS>>\n\n";
252
+ } else {
253
+ // if the model does not support system message, we still include it in the first message, but without <<SYS>>
254
+ ss << content << "\n";
255
+ }
256
+ } else if (role == "user") {
257
+ ss << content << " [/INST]";
258
+ } else {
259
+ ss << content << "</s>";
260
+ is_inside_turn = false;
261
+ }
262
+ }
263
+ } else if (tmpl == LLM_CHAT_TEMPLATE_PHI_3) {
264
+ // Phi 3
265
+ for (auto message : chat) {
266
+ std::string role(message->role);
267
+ ss << "<|" << role << "|>\n" << message->content << "<|end|>\n";
268
+ }
269
+ if (add_ass) {
270
+ ss << "<|assistant|>\n";
271
+ }
272
+ } else if (tmpl == LLM_CHAT_TEMPLATE_FALCON_3) {
273
+ // Falcon 3
274
+ for (auto message : chat) {
275
+ std::string role(message->role);
276
+ ss << "<|" << role << "|>\n" << message->content << "\n";
277
+ }
278
+ if (add_ass) {
279
+ ss << "<|assistant|>\n";
280
+ }
281
+ } else if (tmpl == LLM_CHAT_TEMPLATE_ZEPHYR) {
282
+ // zephyr template
283
+ for (auto message : chat) {
284
+ ss << "<|" << message->role << "|>" << "\n" << message->content << "<|endoftext|>\n";
285
+ }
286
+ if (add_ass) {
287
+ ss << "<|assistant|>\n";
288
+ }
289
+ } else if (tmpl == LLM_CHAT_TEMPLATE_MONARCH) {
290
+ // mlabonne/AlphaMonarch-7B template (the <s> is included inside history)
291
+ for (auto message : chat) {
292
+ std::string bos = (message == chat.front()) ? "" : "<s>"; // skip BOS for first message
293
+ ss << bos << message->role << "\n" << message->content << "</s>\n";
294
+ }
295
+ if (add_ass) {
296
+ ss << "<s>assistant\n";
297
+ }
298
+ } else if (tmpl == LLM_CHAT_TEMPLATE_GEMMA) {
299
+ // google/gemma-7b-it
300
+ std::string system_prompt = "";
301
+ for (auto message : chat) {
302
+ std::string role(message->role);
303
+ if (role == "system") {
304
+ // there is no system message for gemma, but we will merge it with user prompt, so nothing is broken
305
+ system_prompt = trim(message->content);
306
+ continue;
307
+ }
308
+ // in gemma, "assistant" is "model"
309
+ role = role == "assistant" ? "model" : message->role;
310
+ ss << "<start_of_turn>" << role << "\n";
311
+ if (!system_prompt.empty() && role != "model") {
312
+ ss << system_prompt << "\n\n";
313
+ system_prompt = "";
314
+ }
315
+ ss << trim(message->content) << "<end_of_turn>\n";
316
+ }
317
+ if (add_ass) {
318
+ ss << "<start_of_turn>model\n";
319
+ }
320
+ } else if (tmpl == LLM_CHAT_TEMPLATE_ORION) {
321
+ // OrionStarAI/Orion-14B-Chat
322
+ std::string system_prompt = "";
323
+ for (auto message : chat) {
324
+ std::string role(message->role);
325
+ if (role == "system") {
326
+ // there is no system message support, we will merge it with user prompt
327
+ system_prompt = message->content;
328
+ continue;
329
+ } else if (role == "user") {
330
+ ss << "Human: ";
331
+ if (!system_prompt.empty()) {
332
+ ss << system_prompt << "\n\n";
333
+ system_prompt = "";
334
+ }
335
+ ss << message->content << "\n\nAssistant: </s>";
336
+ } else {
337
+ ss << message->content << "</s>";
338
+ }
339
+ }
340
+ } else if (tmpl == LLM_CHAT_TEMPLATE_OPENCHAT) {
341
+ // openchat/openchat-3.5-0106,
342
+ for (auto message : chat) {
343
+ std::string role(message->role);
344
+ if (role == "system") {
345
+ ss << message->content << "<|end_of_turn|>";
346
+ } else {
347
+ role[0] = toupper(role[0]);
348
+ ss << "GPT4 Correct " << role << ": " << message->content << "<|end_of_turn|>";
349
+ }
350
+ }
351
+ if (add_ass) {
352
+ ss << "GPT4 Correct Assistant:";
353
+ }
354
+ } else if (tmpl == LLM_CHAT_TEMPLATE_VICUNA || tmpl == LLM_CHAT_TEMPLATE_VICUNA_ORCA) {
355
+ // eachadea/vicuna-13b-1.1 (and Orca variant)
356
+ for (auto message : chat) {
357
+ std::string role(message->role);
358
+ if (role == "system") {
359
+ // Orca-Vicuna variant uses a system prefix
360
+ if (tmpl == LLM_CHAT_TEMPLATE_VICUNA_ORCA) {
361
+ ss << "SYSTEM: " << message->content << "\n";
362
+ } else {
363
+ ss << message->content << "\n\n";
364
+ }
365
+ } else if (role == "user") {
366
+ ss << "USER: " << message->content << "\n";
367
+ } else if (role == "assistant") {
368
+ ss << "ASSISTANT: " << message->content << "</s>\n";
369
+ }
370
+ }
371
+ if (add_ass) {
372
+ ss << "ASSISTANT:";
373
+ }
374
+ } else if (tmpl == LLM_CHAT_TEMPLATE_DEEPSEEK) {
375
+ // deepseek-ai/deepseek-coder-33b-instruct
376
+ for (auto message : chat) {
377
+ std::string role(message->role);
378
+ if (role == "system") {
379
+ ss << message->content;
380
+ } else if (role == "user") {
381
+ ss << "### Instruction:\n" << message->content << "\n";
382
+ } else if (role == "assistant") {
383
+ ss << "### Response:\n" << message->content << "\n<|EOT|>\n";
384
+ }
385
+ }
386
+ if (add_ass) {
387
+ ss << "### Response:\n";
388
+ }
389
+ } else if (tmpl == LLM_CHAT_TEMPLATE_COMMAND_R) {
390
+ // CohereForAI/c4ai-command-r-plus
391
+ for (auto message : chat) {
392
+ std::string role(message->role);
393
+ if (role == "system") {
394
+ ss << "<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>" << trim(message->content) << "<|END_OF_TURN_TOKEN|>";
395
+ } else if (role == "user") {
396
+ ss << "<|START_OF_TURN_TOKEN|><|USER_TOKEN|>" << trim(message->content) << "<|END_OF_TURN_TOKEN|>";
397
+ } else if (role == "assistant") {
398
+ ss << "<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>" << trim(message->content) << "<|END_OF_TURN_TOKEN|>";
399
+ }
400
+ }
401
+ if (add_ass) {
402
+ ss << "<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>";
403
+ }
404
+ } else if (tmpl == LLM_CHAT_TEMPLATE_LLAMA_3) {
405
+ // Llama 3
406
+ for (auto message : chat) {
407
+ std::string role(message->role);
408
+ ss << "<|start_header_id|>" << role << "<|end_header_id|>\n\n" << trim(message->content) << "<|eot_id|>";
409
+ }
410
+ if (add_ass) {
411
+ ss << "<|start_header_id|>assistant<|end_header_id|>\n\n";
412
+ }
413
+ } else if (tmpl == LLM_CHAT_TEMPLATE_CHATGML_3) {
414
+ // chatglm3-6b
415
+ ss << "[gMASK]" << "sop";
416
+ for (auto message : chat) {
417
+ std::string role(message->role);
418
+ ss << "<|" << role << "|>" << "\n " << message->content;
419
+ }
420
+ if (add_ass) {
421
+ ss << "<|assistant|>";
422
+ }
423
+ } else if (tmpl == LLM_CHAT_TEMPLATE_CHATGML_4) {
424
+ ss << "[gMASK]" << "<sop>";
425
+ for (auto message : chat) {
426
+ std::string role(message->role);
427
+ ss << "<|" << role << "|>" << "\n" << message->content;
428
+ }
429
+ if (add_ass) {
430
+ ss << "<|assistant|>";
431
+ }
432
+ } else if (tmpl == LLM_CHAT_TEMPLATE_MINICPM) {
433
+ // MiniCPM-3B-OpenHermes-2.5-v2-GGUF
434
+ for (auto message : chat) {
435
+ std::string role(message->role);
436
+ if (role == "user") {
437
+ ss << LU8("<用户>");
438
+ ss << trim(message->content);
439
+ ss << "<AI>";
440
+ } else {
441
+ ss << trim(message->content);
442
+ }
443
+ }
444
+ } else if (tmpl == LLM_CHAT_TEMPLATE_DEEPSEEK_2) {
445
+ // DeepSeek-V2
446
+ for (auto message : chat) {
447
+ std::string role(message->role);
448
+ if (role == "system") {
449
+ ss << message->content << "\n\n";
450
+ } else if (role == "user") {
451
+ ss << "User: " << message->content << "\n\n";
452
+ } else if (role == "assistant") {
453
+ ss << "Assistant: " << message->content << LU8("<|end▁of▁sentence|>");
454
+ }
455
+ }
456
+ if (add_ass) {
457
+ ss << "Assistant:";
458
+ }
459
+ } else if (tmpl == LLM_CHAT_TEMPLATE_DEEPSEEK_3) {
460
+ // DeepSeek-V3
461
+ for (auto message : chat) {
462
+ std::string role(message->role);
463
+ if (role == "system") {
464
+ ss << message->content << "\n\n";
465
+ } else if (role == "user") {
466
+ ss << LU8("<|User|>") << message->content;
467
+ } else if (role == "assistant") {
468
+ ss << LU8("<|Assistant|>") << message->content << LU8("<|end▁of▁sentence|>");
469
+ }
470
+ }
471
+ if (add_ass) {
472
+ ss << LU8("<|Assistant|>");
473
+ }
474
+ } else if (tmpl == LLM_CHAT_TEMPLATE_EXAONE_3) {
475
+ // ref: https://huggingface.co/LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct/discussions/8#66bae61b1893d14ee8ed85bb
476
+ // EXAONE-3.0-7.8B-Instruct
477
+ for (auto message : chat) {
478
+ std::string role(message->role);
479
+ if (role == "system") {
480
+ ss << "[|system|]" << trim(message->content) << "[|endofturn|]\n";
481
+ } else if (role == "user") {
482
+ ss << "[|user|]" << trim(message->content) << "\n";
483
+ } else if (role == "assistant") {
484
+ ss << "[|assistant|]" << trim(message->content) << "[|endofturn|]\n";
485
+ }
486
+ }
487
+ if (add_ass) {
488
+ ss << "[|assistant|]";
489
+ }
490
+ } else if (tmpl == LLM_CHAT_TEMPLATE_RWKV_WORLD) {
491
+ // this template requires the model to have "\n\n" as EOT token
492
+ for (auto message : chat) {
493
+ std::string role(message->role);
494
+ if (role == "user") {
495
+ ss << "User: " << message->content << "\n\nAssistant:";
496
+ } else {
497
+ ss << message->content << "\n\n";
498
+ }
499
+ }
500
+ } else if (tmpl == LLM_CHAT_TEMPLATE_GRANITE) {
501
+ // IBM Granite template
502
+ for (const auto & message : chat) {
503
+ std::string role(message->role);
504
+ ss << "<|start_of_role|>" << role << "<|end_of_role|>";
505
+ if (role == "assistant_tool_call") {
506
+ ss << "<|tool_call|>";
507
+ }
508
+ ss << message->content << "<|end_of_text|>\n";
509
+ }
510
+ if (add_ass) {
511
+ ss << "<|start_of_role|>assistant<|end_of_role|>\n";
512
+ }
513
+ } else if (tmpl == LLM_CHAT_TEMPLATE_GIGACHAT) {
514
+ // GigaChat template
515
+ bool has_system = !chat.empty() && std::string(chat[0]->role) == "system";
516
+
517
+ // Handle system message if present
518
+ if (has_system) {
519
+ ss << "<s>" << chat[0]->content << "<|message_sep|>";
520
+ } else {
521
+ ss << "<s>";
522
+ }
523
+
524
+ // Process remaining messages
525
+ for (size_t i = has_system ? 1 : 0; i < chat.size(); i++) {
526
+ std::string role(chat[i]->role);
527
+ if (role == "user") {
528
+ ss << "user<|role_sep|>" << chat[i]->content << "<|message_sep|>"
529
+ << "available functions<|role_sep|>[]<|message_sep|>";
530
+ } else if (role == "assistant") {
531
+ ss << "assistant<|role_sep|>" << chat[i]->content << "<|message_sep|>";
532
+ }
533
+ }
534
+
535
+ // Add generation prompt if needed
536
+ if (add_ass) {
537
+ ss << "assistant<|role_sep|>";
538
+ }
539
+ } else if (tmpl == LLM_CHAT_TEMPLATE_MEGREZ) {
540
+ // Megrez template
541
+ for (auto message : chat) {
542
+ std::string role(message->role);
543
+ ss << "<|role_start|>" << role << "<|role_end|>" << message->content << "<|turn_end|>";
544
+ }
545
+
546
+ if (add_ass) {
547
+ ss << "<|role_start|>assistant<|role_end|>";
548
+ }
549
+ } else {
550
+ // template not supported
551
+ return -1;
552
+ }
553
+ dest = ss.str();
554
+ return dest.size();
555
+ }
556
+
557
+ // public interface
558
+
559
+ int32_t llama_chat_builtin_templates(const char ** output, size_t len) {
560
+ auto it = LLM_CHAT_TEMPLATES.begin();
561
+ for (size_t i = 0; i < std::min(len, LLM_CHAT_TEMPLATES.size()); i++) {
562
+ output[i] = it->first.c_str();
563
+ std::advance(it, 1);
564
+ }
565
+ return (int32_t) LLM_CHAT_TEMPLATES.size();
566
+ }
567
+
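Taken together, the detection and formatting helpers above can be used roughly as follows (a minimal sketch; it assumes the declarations from llama-chat.h added below and the llama_chat_message struct from llama.h):

    // format a short conversation with the built-in "chatml" template (sketch)
    llama_chat_message msgs[] = {
        { "system", "You are a helpful assistant." },
        { "user",   "Hello!"                       },
    };
    std::vector<const llama_chat_message *> chat = { &msgs[0], &msgs[1] };

    llm_chat_template tmpl = llm_chat_detect_template("chatml");
    std::string prompt;
    int32_t res = llm_chat_apply_template(tmpl, chat, prompt, /*add_ass=*/true);
    // res < 0 means the template is not supported; otherwise prompt now holds
    // "<|im_start|>system\n...<|im_end|>\n<|im_start|>user\n...<|im_end|>\n<|im_start|>assistant\n"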
examples/talk-llama/llama-chat.h ADDED
@@ -0,0 +1,51 @@
1
+ #pragma once
2
+
3
+ #include <string>
4
+ #include <vector>
5
+ #include <cstdint>
6
+
7
+ enum llm_chat_template {
8
+ LLM_CHAT_TEMPLATE_CHATML,
9
+ LLM_CHAT_TEMPLATE_LLAMA_2,
10
+ LLM_CHAT_TEMPLATE_LLAMA_2_SYS,
11
+ LLM_CHAT_TEMPLATE_LLAMA_2_SYS_BOS,
12
+ LLM_CHAT_TEMPLATE_LLAMA_2_SYS_STRIP,
13
+ LLM_CHAT_TEMPLATE_MISTRAL_V1,
14
+ LLM_CHAT_TEMPLATE_MISTRAL_V3,
15
+ LLM_CHAT_TEMPLATE_MISTRAL_V3_TEKKEN,
16
+ LLM_CHAT_TEMPLATE_MISTRAL_V7,
17
+ LLM_CHAT_TEMPLATE_PHI_3,
18
+ LLM_CHAT_TEMPLATE_FALCON_3,
19
+ LLM_CHAT_TEMPLATE_ZEPHYR,
20
+ LLM_CHAT_TEMPLATE_MONARCH,
21
+ LLM_CHAT_TEMPLATE_GEMMA,
22
+ LLM_CHAT_TEMPLATE_ORION,
23
+ LLM_CHAT_TEMPLATE_OPENCHAT,
24
+ LLM_CHAT_TEMPLATE_VICUNA,
25
+ LLM_CHAT_TEMPLATE_VICUNA_ORCA,
26
+ LLM_CHAT_TEMPLATE_DEEPSEEK,
27
+ LLM_CHAT_TEMPLATE_DEEPSEEK_2,
28
+ LLM_CHAT_TEMPLATE_DEEPSEEK_3,
29
+ LLM_CHAT_TEMPLATE_COMMAND_R,
30
+ LLM_CHAT_TEMPLATE_LLAMA_3,
31
+ LLM_CHAT_TEMPLATE_CHATGML_3,
32
+ LLM_CHAT_TEMPLATE_CHATGML_4,
33
+ LLM_CHAT_TEMPLATE_MINICPM,
34
+ LLM_CHAT_TEMPLATE_EXAONE_3,
35
+ LLM_CHAT_TEMPLATE_RWKV_WORLD,
36
+ LLM_CHAT_TEMPLATE_GRANITE,
37
+ LLM_CHAT_TEMPLATE_GIGACHAT,
38
+ LLM_CHAT_TEMPLATE_MEGREZ,
39
+ LLM_CHAT_TEMPLATE_UNKNOWN,
40
+ };
41
+
42
+ struct llama_chat_message;
43
+
44
+ llm_chat_template llm_chat_template_from_str(const std::string & name);
45
+
46
+ llm_chat_template llm_chat_detect_template(const std::string & tmpl);
47
+
48
+ int32_t llm_chat_apply_template(
49
+ llm_chat_template tmpl,
50
+ const std::vector<const llama_chat_message *> & chat,
51
+ std::string & dest, bool add_ass);
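The function marked as the public interface in llama-chat.cpp above, llama_chat_builtin_templates, enumerates the supported template names; a minimal usage sketch (error handling and includes omitted):

    // list the names of the built-in chat templates (sketch)
    int32_t n_tmpl = llama_chat_builtin_templates(nullptr, 0); // with len == 0 only the count is returned
    std::vector<const char *> names(n_tmpl);
    llama_chat_builtin_templates(names.data(), names.size());
    for (const char * name : names) {
        printf("%s\n", name); // e.g. "chatml", "llama2", "llama3", ...
    }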
examples/talk-llama/llama-context.cpp ADDED
@@ -0,0 +1,1771 @@
1
+ #include "llama-context.h"
2
+
3
+ #include <cassert>
4
+ #include <cmath>
5
+ #include <cstring>
6
+ #include <stdexcept>
7
+
8
+ void llama_set_k_shift(struct llama_context & lctx) {
9
+ const int64_t kv_size = lctx.kv_self.size;
10
+
11
+ assert(ggml_backend_buffer_is_host(lctx.inp_K_shift->buffer));
12
+
13
+ int32_t * data = (int32_t *) lctx.inp_K_shift->data;
14
+
15
+ for (int i = 0; i < kv_size; ++i) {
16
+ data[i] = lctx.kv_self.cells[i].delta;
17
+ }
18
+ }
19
+
20
+ void llama_set_s_copy(struct llama_context & lctx) {
21
+ const int64_t kv_size = lctx.kv_self.size;
22
+
23
+ assert(ggml_backend_buffer_is_host(lctx.inp_s_copy->buffer));
24
+
25
+ int32_t * data = (int32_t *) lctx.inp_s_copy->data;
26
+
27
+ for (int i = 0; i < kv_size; ++i) {
28
+ data[i] = lctx.kv_self.cells[i].src;
29
+ }
30
+ }
31
+
32
+ // llama input
33
+
34
+ static int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t n_buckets, bool bidirectional) {
35
+ // TODO move to hparams if a T5 variant appears that uses a different value
36
+ const int64_t max_distance = 128;
37
+
38
+ if (bidirectional) {
39
+ n_buckets >>= 1;
40
+ }
41
+
42
+ const int64_t max_exact = n_buckets >> 1;
43
+
44
+ int32_t relative_position = x - y;
45
+ int32_t relative_bucket = 0;
46
+ if (bidirectional) {
47
+ relative_bucket += (relative_position > 0) * n_buckets;
48
+ relative_position = abs(relative_position);
49
+ } else {
50
+ relative_position = -std::min<int32_t>(relative_position, 0);
51
+ }
52
+ int32_t relative_position_if_large = floorf(max_exact + logf(1.0 * relative_position / max_exact) * (n_buckets - max_exact) / log(1.0 * max_distance / max_exact));
53
+ relative_position_if_large = std::min<int32_t>(relative_position_if_large, n_buckets - 1);
54
+ relative_bucket += (relative_position < max_exact ? relative_position : relative_position_if_large);
55
+ return relative_bucket;
56
+ }
57
+
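// Worked example of the bucket computation above, assuming n_buckets = 32 and
// bidirectional = false (the causal self-attention case), so max_exact = 16:
//   - a key 8 positions in the past   -> relative_position = 8 (< 16), bucket = 8
//   - a key 100 positions in the past -> relative_position_if_large
//       = floor(16 + ln(100/16) * (32 - 16) / ln(128/16)) = 30, bucket = 30
//   - distances of max_distance (128) and beyond saturate at bucket n_buckets - 1 = 31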
58
+ void llama_set_inputs(llama_context & lctx, const llama_ubatch & ubatch) {
59
+ //
60
+ // set input data
61
+ //
62
+
63
+ const auto & hparams = lctx.model.hparams;
64
+ const auto & cparams = lctx.cparams;
65
+ const auto & kv_self = lctx.kv_self;
66
+
67
+ if (ubatch.token) {
68
+ const int64_t n_tokens = ubatch.n_tokens;
69
+
70
+ ggml_backend_tensor_set(lctx.inp_tokens, ubatch.token, 0, n_tokens*ggml_element_size(lctx.inp_tokens));
71
+ }
72
+
73
+ if (ubatch.embd) {
74
+ const int64_t n_embd = hparams.n_embd;
75
+ const int64_t n_tokens = ubatch.n_tokens;
76
+
77
+ ggml_backend_tensor_set(lctx.inp_embd, ubatch.embd, 0, n_tokens*n_embd*ggml_element_size(lctx.inp_embd));
78
+ }
79
+
80
+ if (ubatch.pos && lctx.inp_pos) {
81
+ const int64_t n_tokens = ubatch.n_tokens;
82
+ auto n_pos = lctx.n_pos_per_token;
83
+ ggml_backend_tensor_set(lctx.inp_pos, ubatch.pos, 0, n_tokens*n_pos*ggml_element_size(lctx.inp_pos));
84
+ }
85
+
86
+ if (hparams.causal_attn || cparams.pooling_type == LLAMA_POOLING_TYPE_NONE) {
87
+ //GGML_ASSERT(lctx.inp_out_ids && "every model that can must skip unused outputs");
88
+
89
+ if (!lctx.inp_out_ids) {
90
+ LLAMA_LOG_WARN("%s: 'lctx.inp_out_ids' is not created\n", __func__);
91
+ } else {
92
+ const int64_t n_tokens = ubatch.n_tokens;
93
+
94
+ GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_out_ids->buffer));
95
+ int32_t * data = (int32_t *) lctx.inp_out_ids->data;
96
+
97
+ if (lctx.n_outputs == n_tokens) {
98
+ for (int i = 0; i < n_tokens; ++i) {
99
+ data[i] = i;
100
+ }
101
+ } else if (ubatch.output) {
102
+ int32_t n_outputs = 0;
103
+ for (int i = 0; i < n_tokens; ++i) {
104
+ if (ubatch.output[i]) {
105
+ data[n_outputs++] = i;
106
+ }
107
+ }
108
+ // the graph needs to have been passed the correct number of outputs
109
+ GGML_ASSERT(lctx.n_outputs == n_outputs);
110
+ } else if (lctx.n_outputs == 1) {
111
+ // only keep last output
112
+ data[0] = n_tokens - 1;
113
+ } else {
114
+ GGML_ASSERT(lctx.n_outputs == 0);
115
+ }
116
+ }
117
+ }
118
+
119
+ GGML_ASSERT(
120
+ // (!a || b) is a logical implication (a -> b)
121
+ // !hparams.causal_attn -> !cparams.causal_attn
122
+ (hparams.causal_attn || !cparams.causal_attn) &&
123
+ "causal attention is not supported by this model"
124
+ );
125
+
126
+ if (lctx.inp_KQ_mask || lctx.inp_KQ_mask_swa) {
127
+ // NOTE: hparams.causal_attn indicates the model is capable of generation and uses the kv cache.
128
+ if (cparams.causal_attn && !lctx.is_encoding) {
129
+ const int64_t n_kv = kv_self.n;
130
+ const int64_t n_tokens = ubatch.n_tokens;
131
+ const int64_t n_seq_tokens = ubatch.n_seq_tokens;
132
+ const int64_t n_seqs = ubatch.n_seqs;
133
+
134
+
135
+ float * data = nullptr;
136
+ float * data_swa = nullptr;
137
+
138
+ if (lctx.inp_KQ_mask) {
139
+ GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_KQ_mask->buffer));
140
+ data = (float *) lctx.inp_KQ_mask->data;
141
+ }
142
+
143
+ if (lctx.inp_KQ_mask_swa) {
144
+ GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_KQ_mask_swa->buffer));
145
+ data_swa = (float *) lctx.inp_KQ_mask_swa->data;
146
+ }
147
+
148
+ // For causal attention, use only the previous KV cells
149
+ // of the correct sequence for each token of the ubatch.
150
+ // It's assumed that if a token in the batch has multiple sequences, they are equivalent.
151
+ for (int h = 0; h < 1; ++h) {
152
+ for (int s = 0; s < n_seqs; ++s) {
153
+ const llama_seq_id seq_id = ubatch.seq_id[s][0];
154
+
155
+ for (int j = 0; j < n_seq_tokens; ++j) {
156
+ const llama_pos pos = ubatch.pos[s*n_seq_tokens + j];
157
+
158
+ for (int i = 0; i < n_kv; ++i) {
159
+ float f;
160
+ if (!kv_self.cells[i].has_seq_id(seq_id) || kv_self.cells[i].pos > pos) {
161
+ f = -INFINITY;
162
+ } else {
163
+ if (hparams.use_alibi) {
164
+ f = -std::abs(kv_self.cells[i].pos - pos);
165
+ } else {
166
+ f = 0.0f;
167
+ }
168
+ }
169
+
170
+ if (data) {
171
+ data[h*(n_kv*n_tokens) + s*(n_kv*n_seq_tokens) + j*n_kv + i] = f;
172
+ }
173
+
174
+ // may need to cut off old tokens for sliding window
175
+ if (data_swa) {
176
+ if (pos - kv_self.cells[i].pos >= (int32_t)hparams.n_swa) {
177
+ f = -INFINITY;
178
+ }
179
+ data_swa[h*(n_kv*n_tokens) + s*(n_kv*n_seq_tokens) + j*n_kv + i] = f;
180
+ }
181
+ }
182
+ }
183
+ }
184
+
185
+ if (data) {
186
+ for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
187
+ for (int j = 0; j < n_kv; ++j) {
188
+ data[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
189
+ }
190
+ }
191
+ }
192
+
193
+ if (data_swa) {
194
+ for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
195
+ for (int j = 0; j < n_kv; ++j) {
196
+ data_swa[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
197
+ }
198
+ }
199
+ }
200
+ }
201
+ } else {
202
+ const int64_t n_tokens = ubatch.n_tokens;
203
+ const int64_t n_seq_tokens = ubatch.n_seq_tokens;
204
+ const int64_t n_seqs = ubatch.n_seqs;
205
+ // when using kv cache, the mask needs to match the kv cache size
206
+ const int64_t n_stride = hparams.causal_attn && !lctx.is_encoding ? kv_self.n : n_tokens;
207
+
208
+ GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_KQ_mask->buffer));
209
+
210
+ float * data = (float *) lctx.inp_KQ_mask->data;
211
+
212
+ for (int h = 0; h < 1; ++h) {
213
+ for (int s1 = 0; s1 < n_seqs; ++s1) {
214
+ const llama_seq_id seq_id = ubatch.seq_id[s1][0];
215
+
216
+ for (int j = 0; j < n_seq_tokens; ++j) {
217
+ const int32_t tj = s1*n_seq_tokens + j;
218
+
219
+ for (int s0 = 0; s0 < n_seqs; ++s0) {
220
+ for (int i = 0; i < n_seq_tokens; ++i) {
221
+ const int32_t ti = s0*n_seq_tokens + i;
222
+ float f = -INFINITY;
223
+
224
+ for (int s = 0; s < ubatch.n_seq_id[s0]; ++s) {
225
+ if (ubatch.seq_id[s0][s] == seq_id) {
226
+ if (hparams.use_alibi) {
227
+ f = -std::abs(ubatch.pos[ti] - ubatch.pos[tj]);
228
+ } else {
229
+ f = 0.0f;
230
+ }
231
+ break;
232
+ }
233
+ }
234
+
235
+ data[h*(n_tokens*n_tokens) + tj*n_stride + ti] = f;
236
+ }
237
+ }
238
+
239
+ for (int i = n_tokens; i < n_stride; ++i) {
240
+ data[h*(n_tokens*n_tokens) + tj*n_stride + i] = -INFINITY;
241
+ }
242
+ }
243
+ }
244
+ }
245
+ }
246
+ }
247
+
248
+ if (cparams.embeddings && cparams.pooling_type == LLAMA_POOLING_TYPE_MEAN) {
249
+ const int64_t n_tokens = ubatch.n_tokens;
250
+ const int64_t n_seq_tokens = ubatch.n_seq_tokens;
251
+ const int64_t n_seqs = ubatch.n_seqs;
252
+
253
+ GGML_ASSERT(lctx.inp_mean);
254
+ GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_mean->buffer));
255
+
256
+ float * data = (float *) lctx.inp_mean->data;
257
+ memset(lctx.inp_mean->data, 0, n_tokens * n_tokens * ggml_element_size(lctx.inp_mean));
258
+
259
+ std::vector<uint64_t> sum(n_tokens, 0);
260
+
261
+ for (int s = 0; s < n_seqs; ++s) {
262
+ const llama_seq_id seq_id = ubatch.seq_id[s][0];
263
+
264
+ // TODO: adapt limits to n_seqs when ubatch.equal_seqs is true
265
+ GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == MEAN");
266
+
267
+ sum[seq_id] += ubatch.n_seq_tokens;
268
+ }
269
+
270
+ std::vector<float> div(n_tokens, 0.0f);
271
+ for (int i = 0; i < n_tokens; ++i) {
272
+ const uint64_t s = sum[i];
273
+ if (s > 0) {
274
+ div[i] = 1.0f/float(s);
275
+ }
276
+ }
277
+
278
+ for (int s = 0; s < n_seqs; ++s) {
279
+ const llama_seq_id seq_id = ubatch.seq_id[s][0];
280
+
281
+ for (int i = 0; i < n_seq_tokens; ++i) {
282
+ data[seq_id*n_tokens + s*n_seq_tokens + i] = div[seq_id];
283
+ }
284
+ }
285
+ }
286
+
287
+ if (cparams.embeddings && (
288
+ cparams.pooling_type == LLAMA_POOLING_TYPE_CLS ||
289
+ cparams.pooling_type == LLAMA_POOLING_TYPE_RANK)) {
290
+ const int64_t n_tokens = ubatch.n_tokens;
291
+ const int64_t n_seq_tokens = ubatch.n_seq_tokens;
292
+ const int64_t n_seqs = ubatch.n_seqs;
293
+
294
+ GGML_ASSERT(lctx.inp_cls);
295
+ GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_cls->buffer));
296
+
297
+ uint32_t * data = (uint32_t *) lctx.inp_cls->data;
298
+ memset(lctx.inp_cls->data, 0, n_tokens * ggml_element_size(lctx.inp_cls));
299
+
300
+ for (int s = 0; s < n_seqs; ++s) {
301
+ const llama_seq_id seq_id = ubatch.seq_id[s][0];
302
+
303
+ // TODO: adapt limits to n_seqs when ubatch.equal_seqs is true
304
+ GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == CLS or RANK");
305
+
306
+ for (int i = 0; i < n_seq_tokens; ++i) {
307
+ const llama_pos pos = ubatch.pos[s*n_seq_tokens + i];
308
+
309
+ if (pos == 0) {
310
+ data[seq_id] = s*n_seq_tokens + i;
311
+ }
312
+ }
313
+ }
314
+ }
315
+
316
+ if (cparams.embeddings && cparams.pooling_type == LLAMA_POOLING_TYPE_LAST) {
317
+ const int64_t n_tokens = ubatch.n_tokens;
318
+ const int64_t n_seq_tokens = ubatch.n_seq_tokens;
319
+ const int64_t n_seqs = ubatch.n_seqs;
320
+
321
+ GGML_ASSERT(lctx.inp_cls);
322
+ GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_cls->buffer));
323
+
324
+ uint32_t * data = (uint32_t *) lctx.inp_cls->data;
325
+ memset(lctx.inp_cls->data, 0, n_tokens * ggml_element_size(lctx.inp_cls));
326
+
327
+ std::vector<int> last_pos(n_tokens, -1);
328
+ std::vector<int> last_row(n_tokens, -1);
329
+
330
+ for (int s = 0; s < n_seqs; ++s) {
331
+ const llama_seq_id seq_id = ubatch.seq_id[s][0];
332
+
333
+ // TODO: adapt limits to n_seqs when ubatch.equal_seqs is true
334
+ GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == LAST");
335
+
336
+ for (int i = 0; i < n_seq_tokens; ++i) {
337
+ const llama_pos pos = ubatch.pos[s*n_seq_tokens + i];
338
+
339
+ if (pos >= last_pos[seq_id]) {
340
+ last_pos[seq_id] = pos;
341
+ last_row[seq_id] = s*n_seq_tokens + i;
342
+ }
343
+ }
344
+ }
345
+
346
+ for (int i = 0; i < n_tokens; ++i) {
347
+ if (last_row[i] >= 0) {
348
+ data[i] = last_row[i];
349
+ }
350
+ }
351
+ }
352
+
353
+ if (kv_self.recurrent) {
354
+ const int64_t n_kv = kv_self.n;
355
+
356
+ if (lctx.inp_s_mask) {
357
+ GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_s_mask->buffer));
358
+ float * data = (float *) lctx.inp_s_mask->data;
359
+
360
+ // clear unused states
361
+ for (int i = 0; i < n_kv; ++i) {
362
+ const uint32_t cell_id = i + kv_self.head;
363
+ llama_kv_cell & kv_cell = lctx.kv_self.cells[cell_id];
364
+
365
+ data[i] = (float) (kv_cell.src >= 0);
366
+
367
+ // only clear once
368
+ if (kv_cell.src < 0) {
369
+ kv_cell.src = cell_id;
370
+ }
371
+ }
372
+ }
373
+
374
+ if (lctx.inp_s_copy) {
375
+ GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_s_copy->buffer));
376
+ int32_t * data = (int32_t *) lctx.inp_s_copy->data;
377
+
378
+ // assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
379
+ for (uint32_t i = 0; i < n_kv; ++i) {
380
+ const uint32_t cell_id = i + kv_self.head;
381
+ llama_kv_cell & kv_cell = lctx.kv_self.cells[cell_id];
382
+
383
+ // prevent out-of-bound sources
384
+ if (kv_cell.src < 0 || (uint32_t) kv_cell.src >= kv_self.size) {
385
+ kv_cell.src = cell_id;
386
+ }
387
+
388
+ data[i] = kv_cell.src;
389
+
390
+ // ensure copy only happens once
391
+ if (kv_cell.src != (int32_t) cell_id) {
392
+ kv_cell.src = cell_id;
393
+ }
394
+ }
395
+ }
396
+ }
397
+
398
+ if (lctx.inp_pos_bucket) {
399
+ const int64_t n_tokens = ubatch.n_tokens;
400
+
401
+ GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_pos_bucket->buffer));
402
+ GGML_ASSERT(!ubatch.equal_seqs); // TODO: use ubatch.n_seqs instead of failing
403
+
404
+ int32_t * data = (int32_t *) lctx.inp_pos_bucket->data;
405
+
406
+ if (!lctx.is_encoding) {
407
+ const int64_t n_kv = kv_self.n;
408
+ for (int h = 0; h < 1; ++h) {
409
+ for (int j = 0; j < n_tokens; ++j) {
410
+ for (int i = 0; i < n_kv; ++i) {
411
+ data[h*(n_kv*n_tokens) + j*n_kv + i] = llama_relative_position_bucket(lctx.kv_self.cells[i].pos, ubatch.pos[j], hparams.n_rel_attn_bkts, lctx.is_encoding);
412
+ }
413
+ }
414
+ }
415
+ } else {
416
+ for (int h = 0; h < 1; ++h) {
417
+ for (int j = 0; j < n_tokens; ++j) {
418
+ for (int i = 0; i < n_tokens; ++i) {
419
+ data[h*(n_tokens*n_tokens) + j*n_tokens + i] = llama_relative_position_bucket(ubatch.pos[i], ubatch.pos[j], hparams.n_rel_attn_bkts, lctx.is_encoding);
420
+ }
421
+ }
422
+ }
423
+ }
424
+ }
425
+
426
+ if (!lctx.is_encoding && lctx.inp_embd_enc) {
427
+ assert(lctx.inp_embd_enc->type == GGML_TYPE_F32);
428
+ assert((size_t) ggml_nelements(lctx.inp_embd_enc) == lctx.embd_enc.size());
429
+
430
+ ggml_backend_tensor_set(lctx.inp_embd_enc, lctx.embd_enc.data(), 0, ggml_nbytes(lctx.inp_embd_enc));
431
+ }
432
+
433
+ if (!lctx.is_encoding && lctx.inp_KQ_mask_cross) {
434
+ const int64_t n_output_enc = lctx.embd_enc.size() / hparams.n_embd;
435
+ const int64_t n_tokens = ubatch.n_tokens;
436
+
437
+ GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_KQ_mask_cross->buffer));
438
+ GGML_ASSERT(!ubatch.equal_seqs); // TODO: use ubatch.n_seqs instead of failing
439
+
440
+ float * data = (float *) lctx.inp_KQ_mask_cross->data;
441
+
442
+ for (int h = 0; h < 1; ++h) {
443
+ for (int j = 0; j < n_tokens; ++j) {
444
+ for (int i = 0; i < n_output_enc; ++i) {
445
+ float f = -INFINITY;
446
+ for (int s = 0; s < ubatch.n_seq_id[j]; ++s) {
447
+ const llama_seq_id seq_id = ubatch.seq_id[j][s];
448
+ if (lctx.seq_ids_enc[i].find(seq_id) != lctx.seq_ids_enc[i].end()) {
449
+ f = 0.0f;
450
+ }
451
+ }
452
+ data[h*(n_output_enc*n_tokens) + j*n_output_enc + i] = f;
453
+ }
454
+ }
455
+
456
+ for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
457
+ for (int j = 0; j < n_output_enc; ++j) {
458
+ data[h*(n_output_enc*n_tokens) + i*n_output_enc + j] = -INFINITY;
459
+ }
460
+ }
461
+ }
462
+ }
463
+ }
464
+
465
+ // llama output
466
+
467
+ size_t llama_output_reserve(struct llama_context & lctx, size_t n_outputs) {
468
+ const auto & cparams = lctx.cparams;
469
+ const auto & hparams = lctx.model.hparams;
470
+
471
+ const size_t n_outputs_max = std::max(n_outputs, (size_t) cparams.n_seq_max);
472
+
473
+ const auto n_batch = cparams.n_batch;
474
+ const auto n_vocab = hparams.n_vocab;
475
+ const auto n_embd = hparams.n_embd;
476
+
477
+ // TODO: use a per-batch flag for logits presence instead
478
+ const bool has_logits = !cparams.embeddings;
479
+ const bool has_embd = cparams.embeddings && (cparams.pooling_type == LLAMA_POOLING_TYPE_NONE);
480
+
481
+ const size_t logits_size = has_logits ? n_vocab*n_outputs_max : 0;
482
+ const size_t embd_size = has_embd ? n_embd*n_outputs_max : 0;
483
+
484
+ if (lctx.output_ids.empty()) {
485
+ // init, never resized afterwards
486
+ lctx.output_ids.resize(n_batch);
487
+ }
488
+
489
+ const size_t prev_size = lctx.buf_output ? ggml_backend_buffer_get_size(lctx.buf_output.get()) : 0;
490
+ const size_t new_size = (logits_size + embd_size) * sizeof(float);
491
+
492
+ // alloc only when more than the current capacity is required
493
+ // TODO: also consider shrinking the buffer
494
+ if (!lctx.buf_output || prev_size < new_size) {
495
+ if (lctx.buf_output) {
496
+ #ifndef NDEBUG
497
+ // This doesn't happen often, but may be annoying in some cases (like the HellaSwag benchmark)
498
+ LLAMA_LOG_INFO("%s: reallocating output buffer from size %.02f MiB to %.02f MiB\n", __func__, prev_size / 1024.0 / 1024.0, new_size / 1024.0 / 1024.0);
499
+ #endif
500
+ lctx.buf_output = nullptr;
501
+ lctx.logits = nullptr;
502
+ lctx.embd = nullptr;
503
+ }
504
+
505
+ auto * buft = ggml_backend_cpu_buffer_type();
506
+ // try to use the host buffer of the device where the output tensor is allocated for faster transfer to system memory
507
+ auto * output_dev = lctx.model.dev_output.dev;
508
+ auto * output_dev_host_buft = output_dev ? ggml_backend_dev_host_buffer_type(output_dev) : nullptr;
509
+ if (output_dev_host_buft) {
510
+ buft = output_dev_host_buft;
511
+ }
512
+ lctx.buf_output.reset(ggml_backend_buft_alloc_buffer(buft, new_size));
513
+ if (lctx.buf_output == nullptr) {
514
+ LLAMA_LOG_ERROR("%s: failed to allocate output buffer of size %.2f MiB\n", __func__, new_size / (1024.0 * 1024.0));
515
+ return 0;
516
+ }
517
+ }
518
+
519
+ float * output_base = (float *) ggml_backend_buffer_get_base(lctx.buf_output.get());
520
+
521
+ lctx.logits = has_logits ? output_base : nullptr;
522
+ lctx.embd = has_embd ? output_base + logits_size : nullptr;
523
+
524
+ lctx.output_size = n_outputs_max;
525
+ lctx.logits_size = logits_size;
526
+ lctx.embd_size = embd_size;
527
+
528
+ // set all ids as invalid (negative)
529
+ std::fill(lctx.output_ids.begin(), lctx.output_ids.end(), -1);
530
+
531
+ ggml_backend_buffer_clear(lctx.buf_output.get(), 0);
532
+
533
+ lctx.n_outputs = 0;
534
+
535
+ return n_outputs_max;
536
+ }
537
+
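// Sizing example for the buffer reserved above (illustrative numbers, not from the code):
// with n_vocab = 32000, n_outputs_max = 512, logits enabled and embeddings disabled,
// new_size = 32000 * 512 * sizeof(float) = 65,536,000 bytes (~62.5 MiB), allocated in the
// host buffer of the output device when one is available, otherwise in a CPU buffer.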
538
+ void llama_output_reorder(struct llama_context & ctx) {
539
+ std::vector<size_t> & out_ids = ctx.sbatch.out_ids;
540
+ if (!out_ids.empty()) {
541
+ const uint32_t n_vocab = ctx.model.hparams.n_vocab;
542
+ const uint32_t n_embd = ctx.model.hparams.n_embd;
543
+
544
+ const int32_t n_outputs = ctx.n_outputs;
545
+ GGML_ASSERT((size_t) n_outputs == out_ids.size());
546
+
547
+ // TODO: is there something more efficient which also minimizes swaps?
548
+ // selection sort, to minimize swaps (from https://en.wikipedia.org/wiki/Selection_sort)
549
+ for (int32_t i = 0; i < n_outputs - 1; ++i) {
550
+ int32_t j_min = i;
551
+ for (int32_t j = i + 1; j < n_outputs; ++j) {
552
+ if (out_ids[j] < out_ids[j_min]) {
553
+ j_min = j;
554
+ }
555
+ }
556
+ if (j_min == i) { continue; }
557
+ std::swap(out_ids[i], out_ids[j_min]);
558
+ if (ctx.logits_size > 0) {
559
+ for (uint32_t k = 0; k < n_vocab; k++) {
560
+ std::swap(ctx.logits[i*n_vocab + k], ctx.logits[j_min*n_vocab + k]);
561
+ }
562
+ }
563
+ if (ctx.embd_size > 0) {
564
+ for (uint32_t k = 0; k < n_embd; k++) {
565
+ std::swap(ctx.embd[i*n_embd + k], ctx.embd[j_min*n_embd + k]);
566
+ }
567
+ }
568
+ }
569
+ std::fill(ctx.output_ids.begin(), ctx.output_ids.end(), -1);
570
+ for (int32_t i = 0; i < n_outputs; ++i) {
571
+ ctx.output_ids[out_ids[i]] = i;
572
+ }
573
+ out_ids.clear();
574
+ }
575
+ }
576
+
577
+ //
578
+ // interface implementation
579
+ //
580
+
581
+ void llama_free(struct llama_context * ctx) {
582
+ delete ctx;
583
+ }
584
+
585
+ uint32_t llama_n_ctx(const struct llama_context * ctx) {
586
+ return ctx->cparams.n_ctx;
587
+ }
588
+
589
+ uint32_t llama_n_batch(const struct llama_context * ctx) {
590
+ return ctx->cparams.n_batch;
591
+ }
592
+
593
+ uint32_t llama_n_ubatch(const struct llama_context * ctx) {
594
+ return ctx->cparams.n_ubatch;
595
+ }
596
+
597
+ uint32_t llama_n_seq_max(const struct llama_context * ctx) {
598
+ return ctx->kv_self.size;
599
+ }
600
+
601
+ const struct llama_model * llama_get_model(const struct llama_context * ctx) {
602
+ return &ctx->model;
603
+ }
604
+
605
+ enum llama_pooling_type llama_pooling_type(const struct llama_context * ctx) {
606
+ return ctx->cparams.pooling_type;
607
+ }
608
+
609
+ void llama_attach_threadpool(
610
+ struct llama_context * ctx,
611
+ ggml_threadpool_t threadpool,
612
+ ggml_threadpool_t threadpool_batch) {
613
+ ctx->threadpool = threadpool;
614
+ ctx->threadpool_batch = threadpool_batch ? threadpool_batch : threadpool;
615
+ }
616
+
617
+ void llama_detach_threadpool(struct llama_context * ctx) {
618
+ ctx->threadpool = nullptr;
619
+ ctx->threadpool_batch = nullptr;
620
+ }
621
+
622
+ void llama_set_n_threads(struct llama_context * ctx, int32_t n_threads, int32_t n_threads_batch) {
623
+ ctx->cparams.n_threads = n_threads;
624
+ ctx->cparams.n_threads_batch = n_threads_batch;
625
+ }
626
+
627
+ int32_t llama_n_threads(struct llama_context * ctx) {
628
+ return ctx->cparams.n_threads;
629
+ }
630
+
631
+ int32_t llama_n_threads_batch(struct llama_context * ctx) {
632
+ return ctx->cparams.n_threads_batch;
633
+ }
634
+
635
+ void llama_set_abort_callback(struct llama_context * ctx, bool (*abort_callback)(void * data), void * abort_callback_data) {
636
+ ctx->abort_callback = abort_callback;
637
+ ctx->abort_callback_data = abort_callback_data;
638
+
639
+ for (auto & backend : ctx->backends) {
640
+ auto * reg = ggml_backend_dev_backend_reg(ggml_backend_get_device(backend.get()));
641
+ auto * set_abort_callback_fn = (ggml_backend_set_abort_callback_t) ggml_backend_reg_get_proc_address(reg, "ggml_backend_set_abort_callback");
642
+ if (set_abort_callback_fn) {
643
+ set_abort_callback_fn(backend.get(), ctx->abort_callback, ctx->abort_callback_data);
644
+ }
645
+ }
646
+ }
647
+
648
+ void llama_set_embeddings(struct llama_context * ctx, bool embeddings) {
649
+ ctx->cparams.embeddings = embeddings;
650
+ }
651
+
652
+ void llama_set_causal_attn(struct llama_context * ctx, bool causal_attn) {
653
+ ctx->cparams.causal_attn = causal_attn;
654
+ }
655
+
656
+ void llama_synchronize(struct llama_context * ctx) {
657
+ ggml_backend_sched_synchronize(ctx->sched.get());
658
+
659
+ // FIXME: if multiple single tokens are evaluated without a synchronization,
660
+ // the stats will be added to the prompt evaluation stats
661
+ // this should only happen when using batch size 1 to evaluate a batch
662
+
663
+ // add the evaluation to the stats
664
+ if (ctx->n_queued_tokens == 1) {
665
+ if (!ctx->cparams.no_perf) {
666
+ ctx->t_eval_us += ggml_time_us() - ctx->t_compute_start_us;
667
+ }
668
+ ctx->n_eval++;
669
+ } else if (ctx->n_queued_tokens > 1) {
670
+ if (!ctx->cparams.no_perf) {
671
+ ctx->t_p_eval_us += ggml_time_us() - ctx->t_compute_start_us;
672
+ }
673
+ ctx->n_p_eval += ctx->n_queued_tokens;
674
+ }
675
+
676
+ // get a more accurate load time, upon first eval
677
+ if (ctx->n_queued_tokens > 0 && !ctx->has_evaluated_once) {
678
+ ctx->t_load_us = ggml_time_us() - ctx->t_start_us;
679
+ ctx->has_evaluated_once = true;
680
+ }
681
+
682
+ ctx->n_queued_tokens = 0;
683
+ ctx->t_compute_start_us = 0;
684
+ }
685
+
686
+ float * llama_get_logits(struct llama_context * ctx) {
687
+ llama_synchronize(ctx);
688
+
689
+ // reorder logits for backward compatibility
690
+ // TODO: maybe deprecate this
691
+ llama_output_reorder(*ctx);
692
+
693
+ return ctx->logits;
694
+ }
695
+
696
+ float * llama_get_logits_ith(struct llama_context * ctx, int32_t i) {
697
+ int32_t j = -1;
698
+
699
+ llama_synchronize(ctx);
700
+
701
+ try {
702
+ if (ctx->logits == nullptr) {
703
+ throw std::runtime_error("no logits");
704
+ }
705
+
706
+ if (i < 0) {
707
+ j = ctx->n_outputs + i;
708
+ if (j < 0) {
709
+ throw std::runtime_error(format("negative index out of range [0, %d)", ctx->n_outputs));
710
+ }
711
+ } else if ((size_t) i >= ctx->output_ids.size()) {
712
+ throw std::runtime_error(format("out of range [0, %zu)", ctx->output_ids.size()));
713
+ } else {
714
+ j = ctx->output_ids[i];
715
+ }
716
+
717
+ if (j < 0) {
718
+ throw std::runtime_error(format("batch.logits[%d] != true", i));
719
+ }
720
+ if (j >= ctx->n_outputs) {
721
+ // This should not happen
722
+ throw std::runtime_error(format("corrupt output buffer (j=%d, n_outputs=%d)", j, ctx->n_outputs));
723
+ }
724
+
725
+ return ctx->logits + j*ctx->model.hparams.n_vocab;
726
+ } catch (const std::exception & err) {
727
+ LLAMA_LOG_ERROR("%s: invalid logits id %d, reason: %s\n", __func__, i, err.what());
728
+ #ifndef NDEBUG
729
+ GGML_ABORT("fatal error");
730
+ #else
731
+ return nullptr;
732
+ #endif
733
+ }
734
+ }
735
+
736
+ float * llama_get_embeddings(struct llama_context * ctx) {
737
+ llama_synchronize(ctx);
738
+
739
+ // reorder embeddings for backward compatibility
740
+ // TODO: maybe deprecate this
741
+ llama_output_reorder(*ctx);
742
+
743
+ return ctx->embd;
744
+ }
745
+
746
+ float * llama_get_embeddings_ith(struct llama_context * ctx, int32_t i) {
747
+ int32_t j = -1;
748
+
749
+ llama_synchronize(ctx);
750
+
751
+ try {
752
+ if (ctx->embd == nullptr) {
753
+ throw std::runtime_error("no embeddings");
754
+ }
755
+
756
+ if (i < 0) {
757
+ j = ctx->n_outputs + i;
758
+ if (j < 0) {
759
+ throw std::runtime_error(format("negative index out of range [0, %d)", ctx->n_outputs));
760
+ }
761
+ } else if ((size_t) i >= ctx->output_ids.size()) {
762
+ throw std::runtime_error(format("out of range [0, %zu)", ctx->output_ids.size()));
763
+ } else {
764
+ j = ctx->output_ids[i];
765
+ }
766
+
767
+ if (j < 0) {
768
+ throw std::runtime_error(format("batch.logits[%d] != true", i));
769
+ }
770
+ if (j >= ctx->n_outputs) {
771
+ // This should not happen
772
+ throw std::runtime_error(format("corrupt output buffer (j=%d, n_outputs=%d)", j, ctx->n_outputs));
773
+ }
774
+
775
+ return ctx->embd + j*ctx->model.hparams.n_embd;
776
+ } catch (const std::exception & err) {
777
+ LLAMA_LOG_ERROR("%s: invalid embeddings id %d, reason: %s\n", __func__, i, err.what());
778
+ #ifndef NDEBUG
779
+ GGML_ABORT("fatal error");
780
+ #else
781
+ return nullptr;
782
+ #endif
783
+ }
784
+ }
785
+
786
+ float * llama_get_embeddings_seq(struct llama_context * ctx, llama_seq_id seq_id) {
787
+ llama_synchronize(ctx);
788
+
789
+ auto it = ctx->embd_seq.find(seq_id);
790
+ if (it == ctx->embd_seq.end()) {
791
+ return nullptr;
792
+ }
793
+
794
+ return it->second.data();
795
+ }
796
+
797
+ // llama state API
798
+
799
+ // deprecated
800
+ size_t llama_get_state_size(struct llama_context * ctx) {
801
+ return llama_state_get_size(ctx);
802
+ }
803
+
804
+ // deprecated
805
+ size_t llama_copy_state_data(struct llama_context * ctx, uint8_t * dst) {
806
+ return llama_state_get_data(ctx, dst, -1);
807
+ }
808
+
809
+ // deprecated
810
+ size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src) {
811
+ return llama_state_set_data(ctx, src, -1);
812
+ }
813
+
814
+ // deprecated
815
+ bool llama_load_session_file(struct llama_context * ctx, const char * path_session, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) {
816
+ return llama_state_load_file(ctx, path_session, tokens_out, n_token_capacity, n_token_count_out);
817
+ }
818
+
819
+ // deprecated
820
+ bool llama_save_session_file(struct llama_context * ctx, const char * path_session, const llama_token * tokens, size_t n_token_count) {
821
+ return llama_state_save_file(ctx, path_session, tokens, n_token_count);
822
+ }
823
+
824
+ // TODO: replace all non-fatal assertions with returned errors or exceptions
825
+ struct llama_data_write {
826
+ virtual void write(const void * src, size_t size) = 0;
827
+ virtual void write_tensor_data(const struct ggml_tensor * tensor, size_t offset, size_t size) = 0;
828
+ virtual size_t get_size_written() = 0;
829
+ virtual ~llama_data_write() = default;
830
+
831
+ void write_string(const std::string & str) {
832
+ uint32_t str_size = str.size();
833
+
834
+ write(&str_size, sizeof(str_size));
835
+ write(str.data(), str_size);
836
+ }
837
+
838
+ void write_model_info(const struct llama_context * ctx) {
839
+ const std::string arch_str = llm_arch_name(ctx->model.arch);
840
+ write_string(arch_str);
841
+ // TODO: add more model-specific info which should prevent loading the session file if not identical
842
+ }
843
+
844
+ //void write_rng(const std::mt19937 & rng) {
845
+ // std::ostringstream rng_ss;
846
+ // rng_ss << rng;
847
+
848
+ // const std::string & rng_str = rng_ss.str();
849
+
850
+ // write_string(rng_str);
851
+ //}
852
+
853
+ void write_output_ids(struct llama_context * ctx) {
854
+ llama_output_reorder(*ctx);
855
+
856
+ const uint32_t n_outputs = ctx->n_outputs;
857
+
858
+ std::vector<int32_t> output_pos;
859
+
860
+ const size_t n_batch = ctx->cparams.n_batch;
861
+ const auto & output_ids = ctx->output_ids;
862
+
863
+ GGML_ASSERT(n_outputs <= ctx->output_size);
864
+
865
+ output_pos.resize(n_outputs);
866
+
867
+ // build a more compact representation of the output ids
868
+ for (size_t i = 0; i < n_batch; ++i) {
869
+ // map an output id to a position in the batch
870
+ int32_t pos = output_ids[i];
871
+ if (pos >= 0) {
872
+ GGML_ASSERT((uint32_t) pos < n_outputs);
873
+ output_pos[pos] = i;
874
+ }
875
+ }
876
+
877
+ write(&n_outputs, sizeof(n_outputs));
878
+
879
+ if (n_outputs) {
880
+ write(output_pos.data(), n_outputs * sizeof(int32_t));
881
+ }
882
+ }
883
+
884
+ void write_logits(const struct llama_context * ctx) {
885
+ const uint64_t logits_size = std::min((uint64_t) ctx->logits_size, (uint64_t) ctx->n_outputs * ctx->model.hparams.n_vocab);
886
+
887
+ write(&logits_size, sizeof(logits_size));
888
+
889
+ if (logits_size) {
890
+ write(ctx->logits, logits_size * sizeof(float));
891
+ }
892
+ }
893
+
894
+ void write_embeddings(const struct llama_context * ctx) {
895
+ const uint64_t embeddings_size = std::min((uint64_t) ctx->embd_size, (uint64_t) ctx->n_outputs * ctx->model.hparams.n_embd);
896
+
897
+ write(&embeddings_size, sizeof(embeddings_size));
898
+
899
+ if (embeddings_size) {
900
+ write(ctx->embd, embeddings_size * sizeof(float));
901
+ }
902
+ }
903
+
904
+ void write_kv_cache_meta(const llama_kv_cache & kv_self, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id = -1) {
905
+ for (const auto & range : cell_ranges) {
906
+ for (uint32_t i = range.first; i < range.second; ++i) {
907
+ const auto & cell = kv_self.cells[i];
908
+ const llama_pos pos = cell.pos;
909
+ const uint32_t n_seq_id = seq_id == -1 ? cell.seq_id.size() : 0;
910
+
911
+ write(&pos, sizeof(pos));
912
+ write(&n_seq_id, sizeof(n_seq_id));
913
+
914
+ if (n_seq_id) {
915
+ for (auto seq_id : cell.seq_id) {
916
+ write(&seq_id, sizeof(seq_id));
917
+ }
918
+ }
919
+ }
920
+ }
921
+ }
922
+
923
+ void write_kv_cache_data(const struct llama_context * ctx, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges) {
924
+ const struct llama_kv_cache & kv_self = ctx->kv_self;
925
+ const struct llama_hparams & hparams = ctx->model.hparams;
926
+
927
+ const uint32_t v_trans = kv_self.v_trans ? 1 : 0;
928
+ const uint32_t n_layer = hparams.n_layer;
929
+
930
+ write(&v_trans, sizeof(v_trans));
931
+ write(&n_layer, sizeof(n_layer));
932
+
933
+ std::vector<uint8_t> tmp_buf;
934
+
935
+ // Iterate and write all the keys first, each row is a cell
936
+ // Get whole range at a time
937
+ for (uint32_t il = 0; il < n_layer; ++il) {
938
+ const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
939
+
940
+ // Write key type
941
+ const int32_t k_type_i = (int32_t)kv_self.k_l[il]->type;
942
+ write(&k_type_i, sizeof(k_type_i));
943
+
944
+ // Write row size of key
945
+ const uint64_t k_size_row = ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa);
946
+ write(&k_size_row, sizeof(k_size_row));
947
+
948
+ // Read each range of cells of k_size length each into tmp_buf and write out
949
+ for (const auto & range : cell_ranges) {
950
+ const size_t range_size = range.second - range.first;
951
+ const size_t buf_size = range_size * k_size_row;
952
+ write_tensor_data(kv_self.k_l[il], range.first * k_size_row, buf_size);
953
+ }
954
+ }
955
+
956
+ if (!kv_self.v_trans) {
957
+ for (uint32_t il = 0; il < n_layer; ++il) {
958
+ const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
959
+
960
+ // Write value type
961
+ const int32_t v_type_i = (int32_t)kv_self.v_l[il]->type;
962
+ write(&v_type_i, sizeof(v_type_i));
963
+
964
+ // Write row size of value
965
+ const uint64_t v_size_row = ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa);
966
+ write(&v_size_row, sizeof(v_size_row));
967
+
968
+ // Read each range of cells of v_size length each into tmp_buf and write out
969
+ for (const auto & range : cell_ranges) {
970
+ const size_t range_size = range.second - range.first;
971
+ const size_t buf_size = range_size * v_size_row;
972
+ write_tensor_data(kv_self.v_l[il], range.first * v_size_row, buf_size);
973
+ }
974
+ }
975
+ } else {
976
+ // When v is transposed, we also need the element size and get the element ranges from each row
977
+ const uint32_t kv_size = kv_self.size;
978
+ for (uint32_t il = 0; il < n_layer; ++il) {
979
+ const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
980
+
981
+ // Write value type
982
+ const int32_t v_type_i = (int32_t)kv_self.v_l[il]->type;
983
+ write(&v_type_i, sizeof(v_type_i));
984
+
985
+ // Write element size
986
+ const uint32_t v_size_el = ggml_type_size(kv_self.v_l[il]->type);
987
+ write(&v_size_el, sizeof(v_size_el));
988
+
989
+ // Write GQA embedding size
990
+ write(&n_embd_v_gqa, sizeof(n_embd_v_gqa));
991
+
992
+ // For each row, we get the element values of each cell
993
+ for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
994
+ // Read each range of cells of v_size_el length each into tmp_buf and write out
995
+ for (const auto & range : cell_ranges) {
996
+ const size_t range_size = range.second - range.first;
997
+ const size_t src_offset = (range.first + j * kv_size) * v_size_el;
998
+ const size_t buf_size = range_size * v_size_el;
999
+ write_tensor_data(kv_self.v_l[il], src_offset, buf_size);
1000
+ }
1001
+ }
1002
+ }
1003
+ }
1004
+ }
1005
+
1006
+ void write_kv_cache(const struct llama_context * ctx, llama_seq_id seq_id = -1) {
1007
+ const struct llama_kv_cache & kv_self = ctx->kv_self;
1008
+ std::vector<std::pair<uint32_t, uint32_t>> cell_ranges; // ranges, from inclusive, to exclusive
1009
+ uint32_t cell_count = 0;
1010
+
1011
+ // Count the number of cells with the specified seq_id
1012
+ // Find all the ranges of cells with this seq id (or all, when -1)
1013
+ uint32_t cell_range_begin = kv_self.size;
1014
+ for (uint32_t i = 0; i < kv_self.size; ++i) {
1015
+ const auto & cell = kv_self.cells[i];
1016
+ if ((seq_id == -1 && !cell.is_empty()) || cell.has_seq_id(seq_id)) {
1017
+ ++cell_count;
1018
+ if (cell_range_begin == kv_self.size) {
1019
+ cell_range_begin = i;
1020
+ }
1021
+ } else {
1022
+ if (cell_range_begin != kv_self.size) {
1023
+ cell_ranges.emplace_back(cell_range_begin, i);
1024
+ cell_range_begin = kv_self.size;
1025
+ }
1026
+ }
1027
+ }
1028
+ if (cell_range_begin != kv_self.size) {
1029
+ cell_ranges.emplace_back(cell_range_begin, kv_self.size);
1030
+ }
1031
+
1032
+ // DEBUG CHECK: Sum of cell counts in ranges should equal the total cell count
1033
+ uint32_t cell_count_check = 0;
1034
+ for (const auto & range : cell_ranges) {
1035
+ cell_count_check += range.second - range.first;
1036
+ }
1037
+ GGML_ASSERT(cell_count == cell_count_check);
1038
+
1039
+ write(&cell_count, sizeof(cell_count));
1040
+
1041
+ write_kv_cache_meta(kv_self, cell_ranges, seq_id);
1042
+ write_kv_cache_data(ctx, cell_ranges);
1043
+ }
1044
+ };
1045
+
1046
+ struct llama_data_read {
1047
+ virtual const uint8_t * read(size_t size) = 0;
1048
+ virtual void read_to(void * dst, size_t size) = 0;
1049
+ virtual size_t get_size_read() = 0;
1050
+ virtual ~llama_data_read() = default;
1051
+
1052
+ void read_string(std::string & str) {
1053
+ uint32_t str_size;
1054
+ read_to(&str_size, sizeof(str_size));
1055
+
1056
+ str.assign((const char *) read(str_size), str_size);
1057
+ }
1058
+
1059
+ // validate model information
1060
+ void read_model_info(const struct llama_context * ctx) {
1061
+ const std::string cur_arch_str = llm_arch_name(ctx->model.arch);
1062
+
1063
+ std::string arch_str;
1064
+ read_string(arch_str);
1065
+ if (cur_arch_str != arch_str) {
1066
+ throw std::runtime_error(format("wrong model arch: '%s' instead of '%s'", arch_str.c_str(), cur_arch_str.c_str()));
1067
+ }
1068
+ // TODO: add more info which needs to be identical but which is not verified otherwise
1069
+ }
1070
+
1071
+ //void read_rng(std::mt19937 & rng) {
1072
+ // std::string rng_str;
1073
+ // read_string(rng_str);
1074
+
1075
+ // std::istringstream rng_ss(rng_str);
1076
+ // rng_ss >> rng;
1077
+
1078
+ // if (rng_ss.fail()) {
1079
+ // throw std::runtime_error("failed to load RNG state");
1080
+ // }
1081
+ //}
1082
+
1083
+ void read_output_ids(struct llama_context * ctx) {
1084
+ std::vector<int32_t> output_pos;
1085
+
1086
+ uint32_t n_outputs;
1087
+ read_to(&n_outputs, sizeof(n_outputs));
1088
+
1089
+ if (n_outputs > llama_output_reserve(*ctx, n_outputs)) {
1090
+ throw std::runtime_error("could not reserve outputs");
1091
+ }
1092
+
1093
+ if (n_outputs) {
1094
+ output_pos.resize(n_outputs);
1095
+ read_to(output_pos.data(), n_outputs * sizeof(int32_t));
1096
+
1097
+ for (int32_t i = 0; i < (int32_t) output_pos.size(); ++i) {
1098
+ int32_t id = output_pos[i];
1099
+ if ((uint32_t) id >= ctx->cparams.n_batch) {
1100
+ throw std::runtime_error(format("invalid output id, %d does not fit in batch size of %u", id, ctx->cparams.n_batch));
1101
+ }
1102
+ ctx->output_ids[id] = i;
1103
+ }
1104
+
1105
+ ctx->n_outputs = n_outputs;
1106
+ }
1107
+ }
1108
+
1109
+ void read_logits(struct llama_context * ctx) {
1110
+ uint64_t logits_size;
1111
+ read_to(&logits_size, sizeof(logits_size));
1112
+
1113
+ if (ctx->logits_size < logits_size) {
1114
+ throw std::runtime_error("logits buffer too small");
1115
+ }
1116
+
1117
+ if (logits_size) {
1118
+ read_to(ctx->logits, logits_size * sizeof(float));
1119
+ }
1120
+ }
1121
+
1122
+ void read_embeddings(struct llama_context * ctx) {
1123
+ uint64_t embeddings_size;
1124
+ read_to(&embeddings_size, sizeof(embeddings_size));
1125
+
1126
+ if (ctx->embd_size < embeddings_size) {
1127
+ throw std::runtime_error("embeddings buffer too small");
1128
+ }
1129
+
1130
+ if (embeddings_size) {
1131
+ read_to(ctx->embd, embeddings_size * sizeof(float));
1132
+ }
1133
+ }
1134
+
1135
+ bool read_kv_cache_meta(struct llama_context * ctx, uint32_t cell_count, llama_seq_id dest_seq_id = -1) {
1136
+ struct llama_kv_cache & kv_self = ctx->kv_self;
1137
+
1138
+ if (dest_seq_id != -1) {
1139
+ // single sequence
1140
+
1141
+ llama_kv_cache_seq_rm(kv_self, dest_seq_id, -1, -1);
1142
+
1143
+ llama_ubatch batch = ctx->sbatch.reserve_ubatch(cell_count, /* has_embd */ false);
1144
+ batch.n_tokens = cell_count;
1145
+ batch.n_seq_tokens = cell_count;
1146
+ batch.n_seqs = 1;
1147
+
1148
+ for (uint32_t i = 0; i < cell_count; ++i) {
1149
+ llama_pos pos;
1150
+ uint32_t n_seq_id;
1151
+
1152
+ read_to(&pos, sizeof(pos));
1153
+ read_to(&n_seq_id, sizeof(n_seq_id));
1154
+
1155
+ if (n_seq_id != 0) {
1156
+ LLAMA_LOG_ERROR("%s: invalid seq_id-agnostic kv cell\n", __func__);
1157
+ return false;
1158
+ }
1159
+
1160
+ batch.pos[i] = pos;
1161
+ }
1162
+ batch.n_seq_id[0] = 1;
1163
+ batch.seq_id[0] = &dest_seq_id;
1164
+ if (!llama_kv_cache_find_slot(kv_self, batch)) {
1165
+ LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__);
1166
+ return false;
1167
+ }
1168
+
1169
+ // DEBUG CHECK: kv_self.head should be our first cell, kv_self.head + cell_count - 1 should be our last cell (verify seq_id and pos values)
1170
+ // Assume that this is one contiguous block of cells
1171
+ GGML_ASSERT(kv_self.head + cell_count <= kv_self.size);
1172
+ GGML_ASSERT(kv_self.cells[kv_self.head].pos == batch.pos[0]);
1173
+ GGML_ASSERT(kv_self.cells[kv_self.head + cell_count - 1].pos == batch.pos[cell_count - 1]);
1174
+ GGML_ASSERT(kv_self.cells[kv_self.head].has_seq_id(dest_seq_id));
1175
+ GGML_ASSERT(kv_self.cells[kv_self.head + cell_count - 1].has_seq_id(dest_seq_id));
1176
+ } else {
1177
+ // whole KV cache restore
1178
+
1179
+ if (cell_count > kv_self.size) {
1180
+ LLAMA_LOG_ERROR("%s: not enough cells in kv cache\n", __func__);
1181
+ return false;
1182
+ }
1183
+
1184
+ llama_kv_cache_clear(kv_self);
1185
+
1186
+ for (uint32_t i = 0; i < cell_count; ++i) {
1187
+ llama_kv_cell & cell = kv_self.cells[i];
1188
+
1189
+ llama_pos pos;
1190
+ uint32_t n_seq_id;
1191
+
1192
+ read_to(&pos, sizeof(pos));
1193
+ read_to(&n_seq_id, sizeof(n_seq_id));
1194
+
1195
+ cell.pos = pos;
1196
+
1197
+ for (uint32_t j = 0; j < n_seq_id; ++j) {
1198
+ llama_seq_id seq_id;
1199
+ read_to(&seq_id, sizeof(seq_id));
1200
+
1201
+ if (seq_id < 0 || (uint32_t) seq_id >= llama_n_seq_max(ctx)) {
1202
+ LLAMA_LOG_ERROR("%s: invalid seq_id, %d is out of range [0, %u)\n", __func__, seq_id, llama_n_seq_max(ctx));
1203
+ return false;
1204
+ }
1205
+
1206
+ cell.seq_id.insert(seq_id);
1207
+
1208
+ if (kv_self.recurrent) {
1209
+ int32_t & tail = kv_self.cells[seq_id].tail;
1210
+ if (tail != -1) {
1211
+ LLAMA_LOG_ERROR("%s: duplicate tail for seq_id %d in cell %d and %d\n", __func__, seq_id, i, tail);
1212
+ return false;
1213
+ }
1214
+ tail = i;
1215
+ }
1216
+ }
1217
+ }
1218
+
1219
+ kv_self.head = 0;
1220
+ kv_self.used = cell_count;
1221
+ }
1222
+
1223
+ if (kv_self.recurrent) {
1224
+ for (uint32_t i = 0; i < cell_count; ++i) {
1225
+ uint32_t cell_id = kv_self.head + i;
1226
+ // make sure the recurrent states will keep their restored state
1227
+ kv_self.cells[cell_id].src = cell_id;
1228
+ }
1229
+ }
1230
+
1231
+ return true;
1232
+ }
1233
+
1234
+ bool read_kv_cache_data(struct llama_context * ctx, uint32_t cell_count) {
1235
+ const struct llama_hparams & hparams = ctx->model.hparams;
1236
+ struct llama_kv_cache & kv_self = ctx->kv_self;
1237
+ uint32_t v_trans;
1238
+ uint32_t n_layer;
1239
+ read_to(&v_trans, sizeof(v_trans));
1240
+ read_to(&n_layer, sizeof(n_layer));
1241
+
1242
+ if (n_layer != hparams.n_layer) {
1243
+ LLAMA_LOG_ERROR("%s: mismatched layer count (%u instead of %u)\n", __func__, n_layer, hparams.n_layer);
1244
+ return false;
1245
+ }
1246
+ if (cell_count > kv_self.size) {
1247
+ LLAMA_LOG_ERROR("%s: not enough cells in kv cache to restore state (%u > %u)\n", __func__, cell_count, kv_self.size);
1248
+ return false;
1249
+ }
1250
+ if (kv_self.v_trans != (bool) v_trans) {
1251
+ LLAMA_LOG_ERROR("%s: incompatible V transposition\n", __func__);
1252
+ return false;
1253
+ }
1254
+
1255
+ // For each layer, read the keys for each cell, one row is one cell, read as one contiguous block
1256
+ for (uint32_t il = 0; il < n_layer; ++il) {
1257
+ const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
1258
+
1259
+ // Read type of key
1260
+ int32_t k_type_i_ref;
1261
+ read_to(&k_type_i_ref, sizeof(k_type_i_ref));
1262
+ const int32_t k_type_i = (int32_t)kv_self.k_l[il]->type;
1263
+ if (k_type_i != k_type_i_ref) {
1264
+ LLAMA_LOG_ERROR("%s: mismatched key type (%d != %d, layer %d)\n", __func__, k_type_i, k_type_i_ref, il);
1265
+ return false;
1266
+ }
1267
+
1268
+ // Read row size of key
1269
+ uint64_t k_size_row_ref;
1270
+ read_to(&k_size_row_ref, sizeof(k_size_row_ref));
1271
+ const size_t k_size_row = ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa);
1272
+ if (k_size_row != k_size_row_ref) {
1273
+ LLAMA_LOG_ERROR("%s: mismatched key row size (%zu != %zu, layer %d)\n", __func__, k_size_row, (size_t) k_size_row_ref, il);
1274
+ return false;
1275
+ }
1276
+
1277
+ if (cell_count) {
1278
+ // Read and set the keys for the whole cell range
1279
+ ggml_backend_tensor_set(kv_self.k_l[il], read(cell_count * k_size_row), kv_self.head * k_size_row, cell_count * k_size_row);
1280
+ }
1281
+ }
1282
+
1283
+ if (!kv_self.v_trans) {
1284
+ for (uint32_t il = 0; il < n_layer; ++il) {
1285
+ const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
1286
+
1287
+ // Read type of value
1288
+ int32_t v_type_i_ref;
1289
+ read_to(&v_type_i_ref, sizeof(v_type_i_ref));
1290
+ const int32_t v_type_i = (int32_t)kv_self.v_l[il]->type;
1291
+ if (v_type_i != v_type_i_ref) {
1292
+ LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il);
1293
+ return false;
1294
+ }
1295
+
1296
+ // Read row size of value
1297
+ uint64_t v_size_row_ref;
1298
+ read_to(&v_size_row_ref, sizeof(v_size_row_ref));
1299
+ const size_t v_size_row = ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa);
1300
+ if (v_size_row != v_size_row_ref) {
1301
+ LLAMA_LOG_ERROR("%s: mismatched value row size (%zu != %zu, layer %d)\n", __func__, v_size_row, (size_t) v_size_row_ref, il);
1302
+ return false;
1303
+ }
1304
+
1305
+ if (cell_count) {
1306
+ // Read and set the values for the whole cell range
1307
+ ggml_backend_tensor_set(kv_self.v_l[il], read(cell_count * v_size_row), kv_self.head * v_size_row, cell_count * v_size_row);
1308
+ }
1309
+ }
1310
+ } else {
1311
+ // For each layer, read the values for each cell (transposed)
1312
+ for (uint32_t il = 0; il < n_layer; ++il) {
1313
+ const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
1314
+
1315
+ // Read type of value
1316
+ int32_t v_type_i_ref;
1317
+ read_to(&v_type_i_ref, sizeof(v_type_i_ref));
1318
+ const int32_t v_type_i = (int32_t)kv_self.v_l[il]->type;
1319
+ if (v_type_i != v_type_i_ref) {
1320
+ LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il);
1321
+ return false;
1322
+ }
1323
+
1324
+ // Read element size of value
1325
+ uint32_t v_size_el_ref;
1326
+ read_to(&v_size_el_ref, sizeof(v_size_el_ref));
1327
+ const size_t v_size_el = ggml_type_size(kv_self.v_l[il]->type);
1328
+ if (v_size_el != v_size_el_ref) {
1329
+ LLAMA_LOG_ERROR("%s: mismatched value element size (%zu != %zu, layer %d)\n", __func__, v_size_el, (size_t) v_size_el_ref, il);
1330
+ return false;
1331
+ }
1332
+
1333
+ // Read GQA embedding size
1334
+ uint32_t n_embd_v_gqa_ref;
1335
+ read_to(&n_embd_v_gqa_ref, sizeof(n_embd_v_gqa_ref));
1336
+ if (n_embd_v_gqa != n_embd_v_gqa_ref) {
1337
+ LLAMA_LOG_ERROR("%s: mismatched GQA embedding size (%u != %u, layer %d)\n", __func__, n_embd_v_gqa, n_embd_v_gqa_ref, il);
1338
+ return false;
1339
+ }
1340
+
1341
+ if (cell_count) {
1342
+ // For each row in the transposed matrix, read the values for the whole cell range
1343
+ for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
1344
+ const size_t dst_offset = (kv_self.head + j * kv_self.size) * v_size_el;
1345
+ ggml_backend_tensor_set(kv_self.v_l[il], read(cell_count * v_size_el), dst_offset, cell_count * v_size_el);
1346
+ }
1347
+ }
1348
+ }
1349
+ }
1350
+ return true;
1351
+ }
1352
+
1353
+ void read_kv_cache(struct llama_context * ctx, llama_seq_id seq_id = -1) {
1354
+ uint32_t cell_count;
1355
+ read_to(&cell_count, sizeof(cell_count));
1356
+
1357
+ bool res = read_kv_cache_meta(ctx, cell_count, seq_id) && read_kv_cache_data(ctx, cell_count);
1358
+
1359
+ if (!res) {
1360
+ if (seq_id == -1) {
1361
+ llama_kv_cache_clear(ctx);
1362
+ } else {
1363
+ llama_kv_cache_seq_rm(ctx, seq_id, -1, -1);
1364
+ }
1365
+ throw std::runtime_error("failed to restore kv cache");
1366
+ }
1367
+ }
1368
+ };
1369
+
1370
+ struct llama_data_write_dummy : llama_data_write {
1371
+ size_t size_written = 0;
1372
+
1373
+ llama_data_write_dummy() {}
1374
+
1375
+ void write(const void * /* src */, size_t size) override {
1376
+ size_written += size;
1377
+ }
1378
+
1379
+ void write_tensor_data(const struct ggml_tensor * /* tensor */, size_t /* offset */, size_t size) override {
1380
+ size_written += size;
1381
+ }
1382
+
1383
+ size_t get_size_written() override {
1384
+ return size_written;
1385
+ }
1386
+ };
1387
+
1388
+ struct llama_data_write_buffer : llama_data_write {
1389
+ uint8_t * ptr;
1390
+ size_t buf_size = 0;
1391
+ size_t size_written = 0;
1392
+
1393
+ llama_data_write_buffer(uint8_t * p, size_t len) : ptr(p), buf_size(len) {}
1394
+
1395
+ void write(const void * src, size_t size) override {
1396
+ if (size > buf_size) {
1397
+ throw std::runtime_error("unexpectedly reached end of buffer");
1398
+ }
1399
+ memcpy(ptr, src, size);
1400
+ ptr += size;
1401
+ size_written += size;
1402
+ buf_size -= size;
1403
+ }
1404
+
1405
+ void write_tensor_data(const struct ggml_tensor * tensor, size_t offset, size_t size) override {
1406
+ if (size > buf_size) {
1407
+ throw std::runtime_error("unexpectedly reached end of buffer");
1408
+ }
1409
+ ggml_backend_tensor_get(tensor, ptr, offset, size);
1410
+ ptr += size;
1411
+ size_written += size;
1412
+ buf_size -= size;
1413
+ }
1414
+
1415
+ size_t get_size_written() override {
1416
+ return size_written;
1417
+ }
1418
+ };
1419
+
1420
+ struct llama_data_read_buffer : llama_data_read {
1421
+ const uint8_t * ptr;
1422
+ size_t buf_size = 0;
1423
+ size_t size_read = 0;
1424
+
1425
+ llama_data_read_buffer(const uint8_t * p, size_t len) : ptr(p), buf_size(len) {}
1426
+
1427
+ const uint8_t * read(size_t size) override {
1428
+ const uint8_t * base_ptr = ptr;
1429
+ if (size > buf_size) {
1430
+ throw std::runtime_error("unexpectedly reached end of buffer");
1431
+ }
1432
+ ptr += size;
1433
+ size_read += size;
1434
+ buf_size -= size;
1435
+ return base_ptr;
1436
+ }
1437
+
1438
+ void read_to(void * dst, size_t size) override {
1439
+ memcpy(dst, read(size), size);
1440
+ }
1441
+
1442
+ size_t get_size_read() override {
1443
+ return size_read;
1444
+ }
1445
+ };
1446
+
1447
+ struct llama_data_write_file : llama_data_write {
1448
+ llama_file * file;
1449
+ size_t size_written = 0;
1450
+ std::vector<uint8_t> temp_buffer;
1451
+
1452
+ llama_data_write_file(llama_file * f) : file(f) {}
1453
+
1454
+ void write(const void * src, size_t size) override {
1455
+ file->write_raw(src, size);
1456
+ size_written += size;
1457
+ }
1458
+
1459
+ void write_tensor_data(const struct ggml_tensor * tensor, size_t offset, size_t size) override {
1460
+ temp_buffer.resize(size);
1461
+ ggml_backend_tensor_get(tensor, temp_buffer.data(), offset, size);
1462
+ write(temp_buffer.data(), temp_buffer.size());
1463
+ }
1464
+
1465
+ size_t get_size_written() override {
1466
+ return size_written;
1467
+ }
1468
+ };
1469
+
1470
+ struct llama_data_read_file : llama_data_read {
1471
+ llama_file * file;
1472
+ size_t size_read = 0;
1473
+ std::vector<uint8_t> temp_buffer;
1474
+
1475
+ llama_data_read_file(llama_file * f) : file(f) {}
1476
+
1477
+ void read_to(void * dst, size_t size) override {
1478
+ file->read_raw(dst, size);
1479
+ size_read += size;
1480
+ }
1481
+
1482
+ const uint8_t * read(size_t size) override {
1483
+ temp_buffer.resize(size);
1484
+ read_to(temp_buffer.data(), size);
1485
+ return temp_buffer.data();
1486
+ }
1487
+
1488
+ size_t get_size_read() override {
1489
+ return size_read;
1490
+ }
1491
+ };
1492
+
1493
+ /** copy state data into either a buffer or a file, depending on the passed-in data context
1494
+ *
1495
+ * file context:
1496
+ * llama_file file("/path", "wb");
1497
+ * llama_data_write_file data_ctx(&file);
1498
+ * llama_state_get_data_internal(ctx, data_ctx);
1499
+ *
1500
+ * buffer context:
1501
+ * std::vector<uint8_t> buf(max_size, 0);
1502
+ * llama_data_write_buffer data_ctx(buf.data(), max_size);
1503
+ * llama_state_get_data_internal(ctx, data_ctx);
1504
+ *
1505
+ */
1506
+ static size_t llama_state_get_data_internal(struct llama_context * ctx, llama_data_write & data_ctx) {
1507
+ llama_synchronize(ctx);
1508
+
1509
+ data_ctx.write_model_info(ctx);
1510
+
1511
+ // copy outputs
1512
+ data_ctx.write_output_ids(ctx);
1513
+ data_ctx.write_logits(ctx);
1514
+ data_ctx.write_embeddings(ctx);
1515
+
1516
+ data_ctx.write_kv_cache(ctx);
1517
+
1518
+ return data_ctx.get_size_written();
1519
+ }
1520
+
1521
+ size_t llama_state_get_data(struct llama_context * ctx, uint8_t * dst, size_t size) {
1522
+ llama_data_write_buffer data_ctx(dst, size);
1523
+ try {
1524
+ return llama_state_get_data_internal(ctx, data_ctx);
1525
+ } catch (const std::exception & err) {
1526
+ LLAMA_LOG_ERROR("%s: error saving state: %s\n", __func__, err.what());
1527
+ return 0;
1528
+ }
1529
+ }
1530
+
1531
+ // Returns the *actual* size of the state.
1532
+ // Intended to be used when saving the state to a buffer.
1533
+ size_t llama_state_get_size(struct llama_context * ctx) {
1534
+ llama_data_write_dummy data_ctx;
1535
+ try {
1536
+ return llama_state_get_data_internal(ctx, data_ctx);
1537
+ } catch (const std::exception & err) {
1538
+ LLAMA_LOG_ERROR("%s: error getting state size: %s\n", __func__, err.what());
1539
+ return 0;
1540
+ }
1541
+ }
1542
+
1543
+ static size_t llama_state_set_data_internal(struct llama_context * ctx, llama_data_read & data_ctx) {
1544
+ llama_synchronize(ctx);
1545
+
1546
+ data_ctx.read_model_info(ctx);
1547
+
1548
+ // set outputs
1549
+ data_ctx.read_output_ids(ctx);
1550
+ data_ctx.read_logits(ctx);
1551
+ data_ctx.read_embeddings(ctx);
1552
+
1553
+ data_ctx.read_kv_cache(ctx);
1554
+
1555
+ return data_ctx.get_size_read();
1556
+ }
1557
+
1558
+ // Sets the state reading from the specified source address
1559
+ size_t llama_state_set_data(struct llama_context * ctx, const uint8_t * src, size_t size) {
1560
+ llama_data_read_buffer data_ctx(src, size);
1561
+ try {
1562
+ return llama_state_set_data_internal(ctx, data_ctx);
1563
+ } catch (const std::exception & err) {
1564
+ LLAMA_LOG_ERROR("%s: error loading state: %s\n", __func__, err.what());
1565
+ return 0;
1566
+ }
1567
+ }
1568
+
1569
+ static bool llama_state_load_file_internal(struct llama_context * ctx, const char * path_session, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) {
1570
+ llama_file file(path_session, "rb");
1571
+
1572
+ // sanity checks
1573
+ {
1574
+ const uint32_t magic = file.read_u32();
1575
+ const uint32_t version = file.read_u32();
1576
+
1577
+ if (magic != LLAMA_SESSION_MAGIC || version != LLAMA_SESSION_VERSION) {
1578
+ LLAMA_LOG_ERROR("%s: unknown (magic, version) for session file: %08x, %08x\n", __func__, magic, version);
1579
+ return false;
1580
+ }
1581
+ }
1582
+
1583
+ // load the prompt
1584
+ {
1585
+ const uint32_t n_token_count = file.read_u32();
1586
+
1587
+ if (n_token_count > n_token_capacity) {
1588
+ LLAMA_LOG_ERROR("%s: token count in session file exceeded capacity! %u > %zu\n", __func__, n_token_count, n_token_capacity);
1589
+ return false;
1590
+ }
1591
+
1592
+ file.read_raw(tokens_out, sizeof(llama_token) * n_token_count);
1593
+ *n_token_count_out = n_token_count;
1594
+ }
1595
+
1596
+ // restore the context state
1597
+ {
1598
+ const size_t n_state_size_cur = file.size() - file.tell();
1599
+
1600
+ llama_data_read_file data_ctx(&file);
1601
+ const size_t n_read = llama_state_set_data_internal(ctx, data_ctx);
1602
+
1603
+ if (n_read != n_state_size_cur) {
1604
+ LLAMA_LOG_ERROR("%s: did not read all of the session file data! size %zu, got %zu\n", __func__, n_state_size_cur, n_read);
1605
+ return false;
1606
+ }
1607
+ }
1608
+ return true;
1609
+ }
1610
+
1611
+ bool llama_state_load_file(struct llama_context * ctx, const char * path_session, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) {
1612
+ try {
1613
+ return llama_state_load_file_internal(ctx, path_session, tokens_out, n_token_capacity, n_token_count_out);
1614
+ } catch (const std::exception & err) {
1615
+ LLAMA_LOG_ERROR("%s: error loading session file: %s\n", __func__, err.what());
1616
+ return false;
1617
+ }
1618
+ }
1619
+
1620
+ static bool llama_state_save_file_internal(struct llama_context * ctx, const char * path_session, const llama_token * tokens, size_t n_token_count) {
1621
+ llama_file file(path_session, "wb");
1622
+
1623
+ file.write_u32(LLAMA_SESSION_MAGIC);
1624
+ file.write_u32(LLAMA_SESSION_VERSION);
1625
+
1626
+ // save the prompt
1627
+ file.write_u32((uint32_t) n_token_count);
1628
+ file.write_raw(tokens, sizeof(llama_token) * n_token_count);
1629
+
1630
+ // save the context state using stream saving
1631
+ llama_data_write_file data_ctx(&file);
1632
+ llama_state_get_data_internal(ctx, data_ctx);
1633
+
1634
+ return true;
1635
+ }
1636
+
1637
+ bool llama_state_save_file(struct llama_context * ctx, const char * path_session, const llama_token * tokens, size_t n_token_count) {
1638
+ try {
1639
+ return llama_state_save_file_internal(ctx, path_session, tokens, n_token_count);
1640
+ } catch (const std::exception & err) {
1641
+ LLAMA_LOG_ERROR("%s: error saving session file: %s\n", __func__, err.what());
1642
+ return false;
1643
+ }
1644
+ }
1645
+
1646
+ static size_t llama_state_seq_get_data_internal(struct llama_context * ctx, llama_data_write & data_ctx, llama_seq_id seq_id) {
1647
+ llama_synchronize(ctx);
1648
+
1649
+ data_ctx.write_kv_cache(ctx, seq_id);
1650
+
1651
+ return data_ctx.get_size_written();
1652
+ }
1653
+
1654
+ size_t llama_state_seq_get_size(struct llama_context * ctx, llama_seq_id seq_id) {
1655
+ llama_data_write_dummy data_ctx;
1656
+ return llama_state_seq_get_data_internal(ctx, data_ctx, seq_id);
1657
+ }
1658
+
1659
+ size_t llama_state_seq_get_data(struct llama_context * ctx, uint8_t * dst, size_t size, llama_seq_id seq_id) {
1660
+ llama_data_write_buffer data_ctx(dst, size);
1661
+ try {
1662
+ return llama_state_seq_get_data_internal(ctx, data_ctx, seq_id);
1663
+ } catch (const std::exception & err) {
1664
+ LLAMA_LOG_ERROR("%s: error saving sequence state: %s\n", __func__, err.what());
1665
+ return 0;
1666
+ }
1667
+ }
1668
+
1669
+ static size_t llama_state_seq_set_data_internal(struct llama_context * ctx, llama_data_read & data_ctx, llama_seq_id dest_seq_id) {
1670
+ llama_synchronize(ctx);
1671
+
1672
+ data_ctx.read_kv_cache(ctx, dest_seq_id);
1673
+
1674
+ return data_ctx.get_size_read();
1675
+ }
1676
+
1677
+ size_t llama_state_seq_set_data(struct llama_context * ctx, const uint8_t * src, size_t size, llama_seq_id dest_seq_id) {
1678
+ llama_data_read_buffer data_ctx(src, size);
1679
+ try {
1680
+ return llama_state_seq_set_data_internal(ctx, data_ctx, dest_seq_id);
1681
+ } catch (const std::exception & err) {
1682
+ LLAMA_LOG_ERROR("%s: error loading sequence state: %s\n", __func__, err.what());
1683
+ return 0;
1684
+ }
1685
+ }
1686
+
1687
+ static size_t llama_state_seq_save_file_internal(struct llama_context * ctx, const char * filepath, llama_seq_id seq_id, const llama_token * tokens, size_t n_token_count) {
1688
+ llama_file file(filepath, "wb");
1689
+
1690
+ file.write_u32(LLAMA_STATE_SEQ_MAGIC);
1691
+ file.write_u32(LLAMA_STATE_SEQ_VERSION);
1692
+
1693
+ // save the prompt
1694
+ file.write_u32((uint32_t) n_token_count);
1695
+ file.write_raw(tokens, sizeof(llama_token) * n_token_count);
1696
+
1697
+ // save the context state using stream saving
1698
+ llama_data_write_file data_ctx(&file);
1699
+ llama_state_seq_get_data_internal(ctx, data_ctx, seq_id);
1700
+
1701
+ const size_t res = file.tell();
1702
+ GGML_ASSERT(res == sizeof(uint32_t) * 3 + sizeof(llama_token) * n_token_count + data_ctx.get_size_written());
1703
+ return res;
1704
+ }
1705
+
1706
+ static size_t llama_state_seq_load_file_internal(struct llama_context * ctx, const char * filepath, llama_seq_id dest_seq_id, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) {
1707
+ llama_file file(filepath, "rb");
1708
+
1709
+ // version checks
1710
+ {
1711
+ const uint32_t magic = file.read_u32();
1712
+ const uint32_t version = file.read_u32();
1713
+
1714
+ if (magic != LLAMA_STATE_SEQ_MAGIC || version != LLAMA_STATE_SEQ_VERSION) {
1715
+ LLAMA_LOG_ERROR("%s: unknown (magic, version) for sequence state file: %08x, %08x\n", __func__, magic, version);
1716
+ return 0;
1717
+ }
1718
+ }
1719
+
1720
+ // load the prompt
1721
+ {
1722
+ const uint32_t n_token_count = file.read_u32();
1723
+
1724
+ if (n_token_count > n_token_capacity) {
1725
+ LLAMA_LOG_ERROR("%s: token count in sequence state file exceeded capacity! %u > %zu\n", __func__, n_token_count, n_token_capacity);
1726
+ return 0;
1727
+ }
1728
+
1729
+ file.read_raw(tokens_out, sizeof(llama_token) * n_token_count);
1730
+ *n_token_count_out = n_token_count;
1731
+ }
1732
+
1733
+ // restore the context state
1734
+ {
1735
+ const size_t state_size = file.size() - file.tell();
1736
+ llama_data_read_file data_ctx(&file);
1737
+ const size_t nread = llama_state_seq_set_data_internal(ctx, data_ctx, dest_seq_id);
1738
+ if (!nread) {
1739
+ LLAMA_LOG_ERROR("%s: failed to restore sequence state\n", __func__);
1740
+ return 0;
1741
+ }
1742
+ GGML_ASSERT(nread <= state_size);
1743
+ GGML_ASSERT(nread + sizeof(uint32_t) * 3 + sizeof(llama_token) * *n_token_count_out == file.tell());
1744
+ }
1745
+
1746
+ return file.tell();
1747
+ }
1748
+
1749
+ size_t llama_state_seq_save_file(struct llama_context * ctx, const char * filepath, llama_seq_id seq_id, const llama_token * tokens, size_t n_token_count) {
1750
+ try {
1751
+ return llama_state_seq_save_file_internal(ctx, filepath, seq_id, tokens, n_token_count);
1752
+ } catch (const std::exception & err) {
1753
+ LLAMA_LOG_ERROR("%s: error saving sequence state file: %s\n", __func__, err.what());
1754
+ return 0;
1755
+ }
1756
+ }
1757
+
1758
+ size_t llama_state_seq_load_file(struct llama_context * ctx, const char * filepath, llama_seq_id dest_seq_id, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) {
1759
+ try {
1760
+ return llama_state_seq_load_file_internal(ctx, filepath, dest_seq_id, tokens_out, n_token_capacity, n_token_count_out);
1761
+ } catch (const std::exception & err) {
1762
+ LLAMA_LOG_ERROR("%s: error loading sequence state file: %s\n", __func__, err.what());
1763
+ return 0;
1764
+ }
1765
+ }
1766
+
1767
+ const std::vector<std::pair<std::string, struct ggml_tensor *>> & llama_internal_get_tensor_map(
1768
+ struct llama_context * ctx
1769
+ ) {
1770
+ return ctx->model.tensors_by_name;
1771
+ }
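The state functions above cover both whole-context and per-sequence save/restore. A minimal usage sketch of the file-based pair, assuming a llama_context that has already evaluated the prompt held in tokens; the helper name, file path and capacity below are illustrative, not part of the commit:

    #include "llama.h"

    #include <cstdio>
    #include <vector>

    // illustrative helper, not part of the commit
    static void save_and_restore_session(llama_context * ctx, const std::vector<llama_token> & tokens) {
        // write the prompt tokens plus the full context state (outputs + KV cache) to a session file
        if (!llama_state_save_file(ctx, "session.bin", tokens.data(), tokens.size())) {
            fprintf(stderr, "failed to save session\n");
            return;
        }

        // restore: the prompt tokens come back together with the context state
        std::vector<llama_token> restored(4096); // illustrative capacity
        size_t n_restored = 0;
        if (llama_state_load_file(ctx, "session.bin", restored.data(), restored.size(), &n_restored)) {
            restored.resize(n_restored);
        }
    }

Both calls return false (and log through LLAMA_LOG_ERROR) on failure, so the caller only needs to check the boolean result.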
examples/talk-llama/llama-context.h ADDED
@@ -0,0 +1,128 @@
1
+ #pragma once
2
+
3
+ #include "llama.h"
4
+ #include "llama-batch.h"
5
+ #include "llama-cparams.h"
6
+ #include "llama-model.h"
7
+ #include "llama-kv-cache.h"
8
+ #include "llama-adapter.h"
9
+
10
+ #include "ggml-cpp.h"
11
+
12
+ #include <map>
13
+ #include <unordered_map>
14
+ #include <vector>
15
+ #include <set>
16
+
17
+ struct llama_context {
18
+ llama_context(const llama_model & model)
19
+ : model(model)
20
+ , t_start_us(model.t_start_us)
21
+ , t_load_us(model.t_load_us) {}
22
+
23
+ const struct llama_model & model;
24
+
25
+ struct llama_cparams cparams;
26
+ struct llama_sbatch sbatch; // TODO: revisit if needed
27
+ struct llama_kv_cache kv_self;
28
+ struct llama_control_vector cvec;
29
+
30
+ std::unordered_map<struct llama_lora_adapter *, float> lora_adapters;
31
+
32
+ std::vector<ggml_backend_ptr> backends;
33
+ std::vector<std::pair<ggml_backend_t, ggml_backend_set_n_threads_t>> set_n_threads_fns;
34
+
35
+ ggml_backend_t backend_cpu = nullptr;
36
+
37
+ ggml_threadpool_t threadpool = nullptr;
38
+ ggml_threadpool_t threadpool_batch = nullptr;
39
+
40
+ bool has_evaluated_once = false;
41
+
42
+ mutable int64_t t_start_us;
43
+ mutable int64_t t_load_us;
44
+ mutable int64_t t_p_eval_us = 0;
45
+ mutable int64_t t_eval_us = 0;
46
+
47
+ mutable int64_t t_compute_start_us = 0;
48
+ mutable int64_t n_queued_tokens = 0;
49
+
50
+ mutable int32_t n_p_eval = 0; // number of tokens in eval calls for the prompt (with batch size > 1)
51
+ mutable int32_t n_eval = 0; // number of eval calls
52
+
53
+ // host buffer for the model output (logits and embeddings)
54
+ ggml_backend_buffer_ptr buf_output;
55
+
56
+ // decode output (2-dimensional array: [n_outputs][n_vocab])
57
+ size_t logits_size = 0; // capacity (of floats) for logits
58
+ float * logits = nullptr;
59
+
60
+ std::vector<int32_t> output_ids; // map batch token positions to ids of the logits and embd buffers
61
+ size_t output_size = 0; // capacity (of token positions) for the output buffers
62
+ int32_t n_outputs = 0; // number of actually-used outputs in the current ubatch or last logical batch
63
+
64
+ bool logits_all = false;
65
+
66
+ // embeddings output (2-dimensional array: [n_outputs][n_embd])
67
+ // populated only when pooling_type == LLAMA_POOLING_TYPE_NONE
68
+ size_t embd_size = 0; // capacity (of floats) for embeddings
69
+ float * embd = nullptr;
70
+
71
+ // sequence embeddings output (map of [n_embd] vectors)
72
+ // populated only when pooling_type != LLAMA_POOLING_TYPE_NONE
73
+ std::map<llama_seq_id, std::vector<float>> embd_seq;
74
+
75
+ // whether we are computing encoder output or decoder output
76
+ bool is_encoding = false;
77
+
78
+ // TODO: find a better way to accommodate multi-dimensional position encoding methods
79
+ // number of position ids each token gets; 1 for each token in most cases.
80
+ // when using m-rope, each token gets 3 position ids, representing a 3-dimensional coordinate.
81
+ int n_pos_per_token = 1;
82
+
83
+ // output of the encoder part of the encoder-decoder models
84
+ std::vector<float> embd_enc;
85
+ std::vector<std::set<llama_seq_id>> seq_ids_enc;
86
+
87
+ // memory buffers used to evaluate the model
88
+ std::vector<uint8_t> buf_compute_meta;
89
+ ggml_backend_sched_ptr sched;
90
+
91
+ ggml_abort_callback abort_callback = nullptr;
92
+ void * abort_callback_data = nullptr;
93
+
94
+ // input tensors
95
+ struct ggml_tensor * inp_tokens; // I32 [n_batch]
96
+ struct ggml_tensor * inp_embd; // F32 [n_embd, n_batch]
97
+ struct ggml_tensor * inp_pos; // I32 [n_batch]
98
+ struct ggml_tensor * inp_out_ids; // I32 [n_outputs]
99
+ struct ggml_tensor * inp_KQ_mask; // F32 [kv_size, n_batch]
100
+ struct ggml_tensor * inp_KQ_mask_swa; // F32 [kv_size, n_batch]
101
+ struct ggml_tensor * inp_K_shift; // I32 [kv_size]
102
+ struct ggml_tensor * inp_mean; // F32 [n_batch, n_batch]
103
+ struct ggml_tensor * inp_cls; // I32 [n_batch]
104
+ struct ggml_tensor * inp_s_copy; // I32 [kv_size]
105
+ struct ggml_tensor * inp_s_mask; // F32 [1, n_kv]
106
+ struct ggml_tensor * inp_s_seq; // I32 [n_kv, n_batch]
107
+ struct ggml_tensor * inp_pos_bucket; // I32 [n_batch|n_kv, n_batch]
108
+ struct ggml_tensor * inp_embd_enc; // F32 [n_embd, n_outputs_enc]
109
+ struct ggml_tensor * inp_KQ_mask_cross; // F32 [n_outputs_enc, n_batch]
110
+ };
111
+
112
+ // TODO: make these methods of llama_context
113
+ void llama_set_k_shift(struct llama_context & lctx);
114
+
115
+ void llama_set_s_copy(struct llama_context & lctx);
116
+
117
+ void llama_set_inputs(llama_context & lctx, const llama_ubatch & ubatch);
118
+
119
+ // Make sure enough space is available for outputs.
120
+ // Returns max number of outputs for which space was reserved.
121
+ size_t llama_output_reserve(struct llama_context & lctx, size_t n_outputs);
122
+
123
+ // make the outputs have the same order they had in the user-provided batch
124
+ void llama_output_reorder(struct llama_context & ctx);
125
+
126
+ // For internal test use
127
+ // TODO: remove
128
+ const std::vector<std::pair<std::string, struct ggml_tensor *>> & llama_internal_get_tensor_map(struct llama_context * ctx);
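The embd_seq map in the struct above is what llama_get_embeddings_seq (defined earlier in llama-context.cpp) reads from; it is only populated when a pooling type other than LLAMA_POOLING_TYPE_NONE is used. A hedged caller-side sketch, with the embedding dimension passed in explicitly and a hypothetical helper name:

    #include "llama.h"

    #include <cstdio>

    // illustrative helper, not part of the commit: print the pooled embedding of one sequence
    static void print_seq_embedding(llama_context * ctx, llama_seq_id seq_id, int n_embd) {
        const float * e = llama_get_embeddings_seq(ctx, seq_id); // nullptr if the sequence has no pooled output
        if (e == nullptr) {
            fprintf(stderr, "no pooled embedding for seq %d\n", seq_id);
            return;
        }
        for (int i = 0; i < n_embd; ++i) {
            printf("%s%.6f", i == 0 ? "" : " ", e[i]);
        }
        printf("\n");
    }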
examples/talk-llama/llama-cparams.cpp ADDED
@@ -0,0 +1 @@
1
+ #include "llama-cparams.h"
examples/talk-llama/llama-cparams.h ADDED
@@ -0,0 +1,37 @@
1
+ #pragma once
2
+
3
+ #include "llama.h"
4
+
5
+ #include <cstdint>
6
+
7
+ struct llama_cparams {
8
+ uint32_t n_ctx; // context size used during inference
9
+ uint32_t n_batch;
10
+ uint32_t n_ubatch;
11
+ uint32_t n_seq_max;
12
+ int n_threads; // number of threads to use for generation
13
+ int n_threads_batch; // number of threads to use for batch processing
14
+
15
+ float rope_freq_base;
16
+ float rope_freq_scale;
17
+
18
+ uint32_t n_ctx_orig_yarn;
19
+ // These hyperparameters are not exposed in GGUF, because all
20
+ // existing YaRN models use the same values for them.
21
+ float yarn_ext_factor;
22
+ float yarn_attn_factor;
23
+ float yarn_beta_fast;
24
+ float yarn_beta_slow;
25
+ float defrag_thold;
26
+
27
+ bool embeddings;
28
+ bool causal_attn;
29
+ bool offload_kqv;
30
+ bool flash_attn;
31
+ bool no_perf;
32
+
33
+ enum llama_pooling_type pooling_type;
34
+
35
+ ggml_backend_sched_eval_callback cb_eval;
36
+ void * cb_eval_user_data;
37
+ };
examples/talk-llama/llama-grammar.cpp CHANGED
@@ -1,5 +1,6 @@
1
  #include "llama-grammar.h"
2
 
 
3
  #include "llama-vocab.h"
4
  #include "llama-sampling.h"
5
 
@@ -822,15 +823,11 @@ llama_grammar_stacks & llama_grammar_get_stacks(struct llama_grammar * grammar)
822
  return grammar->stacks;
823
  }
824
 
825
- void llama_grammar_accept(
826
- const llama_grammar_rules & rules,
827
- const llama_grammar_stacks & stacks,
828
- const uint32_t chr,
829
- llama_grammar_stacks & stacks_new) {
830
- stacks_new.clear();
831
- stacks_new.reserve(stacks.size());
832
 
833
- for (const auto & stack : stacks) {
834
  if (stack.empty()) {
835
  continue;
836
  }
@@ -844,9 +841,11 @@ void llama_grammar_accept(
844
  if (!llama_grammar_is_end_of_sequence(pos)) {
845
  new_stack.push_back(pos);
846
  }
847
- llama_grammar_advance_stack(rules, new_stack, stacks_new);
848
  }
849
  }
 
 
850
  }
851
 
852
  llama_grammar_candidates llama_grammar_reject_candidates_for_stack(
@@ -1051,7 +1050,12 @@ void llama_grammar_free_impl(struct llama_grammar * grammar) {
1051
  }
1052
 
1053
  struct llama_grammar * llama_grammar_clone_impl(const struct llama_grammar & grammar) {
1054
- llama_grammar * result = new llama_grammar { grammar.vocab, grammar.rules, grammar.stacks, grammar.partial_utf8, };
 
 
 
 
 
1055
 
1056
  // redirect elements in stacks to point to new rules
1057
  for (size_t is = 0; is < result->stacks.size(); is++) {
@@ -1059,7 +1063,7 @@ struct llama_grammar * llama_grammar_clone_impl(const struct llama_grammar & gra
1059
  for (size_t ir0 = 0; ir0 < grammar.rules.size(); ir0++) {
1060
  for (size_t ir1 = 0; ir1 < grammar.rules[ir0].size(); ir1++) {
1061
  if (grammar.stacks[is][ie] == &grammar.rules[ir0][ir1]) {
1062
- result->stacks[is][ie] = &result->rules[ir0][ir1];
1063
  }
1064
  }
1065
  }
@@ -1126,11 +1130,8 @@ void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token
1126
  const auto decoded = decode_utf8(piece, grammar.partial_utf8);
1127
  const auto & code_points = decoded.first;
1128
 
1129
- llama_grammar_stacks stacks_new;
1130
-
1131
  for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
1132
- llama_grammar_accept(grammar.rules, grammar.stacks, *it, stacks_new);
1133
- grammar.stacks = std::move(stacks_new);
1134
  }
1135
 
1136
  grammar.partial_utf8 = decoded.second;
 
1
  #include "llama-grammar.h"
2
 
3
+ #include "llama-impl.h"
4
  #include "llama-vocab.h"
5
  #include "llama-sampling.h"
6
 
 
823
  return grammar->stacks;
824
  }
825
 
826
+ void llama_grammar_accept(struct llama_grammar * grammar, uint32_t chr) {
827
+ llama_grammar_stacks stacks_new;
828
+ stacks_new.reserve(grammar->stacks.size());
 
 
 
 
829
 
830
+ for (const auto & stack : grammar->stacks) {
831
  if (stack.empty()) {
832
  continue;
833
  }
 
841
  if (!llama_grammar_is_end_of_sequence(pos)) {
842
  new_stack.push_back(pos);
843
  }
844
+ llama_grammar_advance_stack(grammar->rules, new_stack, stacks_new);
845
  }
846
  }
847
+
848
+ grammar->stacks = std::move(stacks_new);
849
  }
850
 
851
  llama_grammar_candidates llama_grammar_reject_candidates_for_stack(
 
1050
  }
1051
 
1052
  struct llama_grammar * llama_grammar_clone_impl(const struct llama_grammar & grammar) {
1053
+ llama_grammar * result = new llama_grammar {
1054
+ grammar.vocab,
1055
+ grammar.rules,
1056
+ grammar.stacks,
1057
+ grammar.partial_utf8,
1058
+ };
1059
 
1060
  // redirect elements in stacks to point to new rules
1061
  for (size_t is = 0; is < result->stacks.size(); is++) {
 
1063
  for (size_t ir0 = 0; ir0 < grammar.rules.size(); ir0++) {
1064
  for (size_t ir1 = 0; ir1 < grammar.rules[ir0].size(); ir1++) {
1065
  if (grammar.stacks[is][ie] == &grammar.rules[ir0][ir1]) {
1066
+ result->stacks[is][ie] = &result->rules[ir0][ir1];
1067
  }
1068
  }
1069
  }
 
1130
  const auto decoded = decode_utf8(piece, grammar.partial_utf8);
1131
  const auto & code_points = decoded.first;
1132
 
 
 
1133
  for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
1134
+ llama_grammar_accept(&grammar, *it);
 
1135
  }
1136
 
1137
  grammar.partial_utf8 = decoded.second;
examples/talk-llama/llama-grammar.h CHANGED
@@ -1,8 +1,10 @@
1
  #pragma once
2
 
3
- #include "llama-impl.h"
4
 
5
  #include <map>
 
 
6
 
7
  struct llama_vocab;
8
 
@@ -58,6 +60,7 @@ using llama_grammar_rules = std::vector<llama_grammar_rule>;
58
  using llama_grammar_stacks = std::vector<llama_grammar_stack>;
59
  using llama_grammar_candidates = std::vector<llama_grammar_candidate>;
60
 
 
61
  const llama_grammar_rules & llama_grammar_get_rules (const struct llama_grammar * grammar);
62
  llama_grammar_stacks & llama_grammar_get_stacks( struct llama_grammar * grammar);
63
 
@@ -65,11 +68,7 @@ const llama_grammar_rules & llama_grammar_get_rules (const struct llama_grammar
65
  // be positioned at a character range (see `llama_grammar_advance_stack`), and
66
  // produces the N possible stacks if the given char is accepted at those
67
  // positions
68
- void llama_grammar_accept(
69
- const llama_grammar_rules & rules,
70
- const llama_grammar_stacks & stacks,
71
- uint32_t chr,
72
- llama_grammar_stacks & stacks_new);
73
 
74
  std::vector<llama_grammar_candidate> llama_grammar_reject_candidates_for_stack(
75
  const llama_grammar_rules & rules,
 
1
  #pragma once
2
 
3
+ #include "llama.h"
4
 
5
  #include <map>
6
+ #include <string>
7
+ #include <vector>
8
 
9
  struct llama_vocab;
10
 
 
60
  using llama_grammar_stacks = std::vector<llama_grammar_stack>;
61
  using llama_grammar_candidates = std::vector<llama_grammar_candidate>;
62
 
63
+ // TODO: remove, needed for tests atm
64
  const llama_grammar_rules & llama_grammar_get_rules (const struct llama_grammar * grammar);
65
  llama_grammar_stacks & llama_grammar_get_stacks( struct llama_grammar * grammar);
66
 
 
68
  // be positioned at a character range (see `llama_grammar_advance_stack`), and
69
  // produces the N possible stacks if the given char is accepted at those
70
  // positions
71
+ void llama_grammar_accept(struct llama_grammar * grammar, uint32_t chr);
 
 
 
 
72
 
73
  std::vector<llama_grammar_candidate> llama_grammar_reject_candidates_for_stack(
74
  const llama_grammar_rules & rules,
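The signature change above means callers no longer thread rules, stacks and a scratch stacks_new through every call; llama_grammar_accept now swaps the grammar's own stacks internally. A hedged sketch of driving the new form from already-decoded code points (the helper name is hypothetical, and the empty-stacks check reflects how rejection is observed elsewhere in the grammar code):

    #include "llama-grammar.h"

    #include <cstdint>
    #include <vector>

    // illustrative helper, not part of the commit: advance a grammar over decoded code points
    static bool grammar_accepts(struct llama_grammar * grammar, const std::vector<uint32_t> & code_points) {
        for (const uint32_t chr : code_points) {
            llama_grammar_accept(grammar, chr); // replaces the old (rules, stacks, chr, stacks_new) form
            if (llama_grammar_get_stacks(grammar).empty()) {
                return false; // no stack could accept this character
            }
        }
        return true;
    }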
examples/talk-llama/llama-hparams.cpp ADDED
@@ -0,0 +1,71 @@
1
+ #include "llama-hparams.h"
2
+
3
+ #include "ggml.h"
4
+
5
+ uint32_t llama_hparams::n_head(uint32_t il) const {
6
+ if (il < n_layer) {
7
+ return n_head_arr[il];
8
+ }
9
+
10
+ GGML_ABORT("fatal error");
11
+ }
12
+
13
+ uint32_t llama_hparams::n_head_kv(uint32_t il) const {
14
+ if (il < n_layer) {
15
+ return n_head_kv_arr[il];
16
+ }
17
+
18
+ GGML_ABORT("fatal error");
19
+ }
20
+
21
+ uint32_t llama_hparams::n_ff(uint32_t il) const {
22
+ if (il < n_layer) {
23
+ return n_ff_arr[il];
24
+ }
25
+
26
+ GGML_ABORT("fatal error");
27
+ }
28
+
29
+ uint32_t llama_hparams::n_gqa(uint32_t il) const {
30
+ const uint32_t n_head = this->n_head(il);
31
+ const uint32_t n_head_kv = this->n_head_kv(il);
32
+
33
+ if (n_head_kv == 0) {
34
+ return 0;
35
+ }
36
+
37
+ return n_head/n_head_kv;
38
+ }
39
+
40
+ uint32_t llama_hparams::n_embd_k_gqa(uint32_t il) const {
41
+ const uint32_t n_head_kv = this->n_head_kv(il);
42
+
43
+ return n_embd_head_k * n_head_kv;
44
+ }
45
+
46
+ uint32_t llama_hparams::n_embd_v_gqa(uint32_t il) const {
47
+ const uint32_t n_head_kv = this->n_head_kv(il);
48
+
49
+ return n_embd_head_v * n_head_kv;
50
+ }
51
+
52
+ uint32_t llama_hparams::n_embd_k_s() const {
53
+ if (wkv_head_size != 0) {
54
+ // for RWKV models
55
+ return 2 * n_embd;
56
+ }
57
+
58
+ // TODO: maybe support other convolution strides than 1
59
+ // NOTE: since the first column of the conv_state is shifted out each time, it's not actually needed
60
+ return (ssm_d_conv > 0 ? ssm_d_conv - 1 : 0) * ssm_d_inner;
61
+ }
62
+
63
+ uint32_t llama_hparams::n_embd_v_s() const {
64
+ if (wkv_head_size != 0) {
65
+ // corresponds to RWKV's wkv_states size
66
+ return n_embd * wkv_head_size;
67
+ }
68
+
69
+ // corresponds to Mamba's ssm_states size
70
+ return ssm_d_state * ssm_d_inner;
71
+ }
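Concretely, these helpers reduce the per-layer KV sizing to a couple of multiplications. A worked example with illustrative (not model-specific) numbers: n_head = 32, n_head_kv = 8 and n_embd_head_k = 128 give n_gqa = 32/8 = 4 query heads per KV head and n_embd_k_gqa = 128 * 8 = 1024 K values per token per layer. The same arithmetic as a standalone sketch:

    #include <cstdint>
    #include <cstdio>

    // illustrative numbers, not taken from any specific model in this commit
    int main() {
        const uint32_t n_head        = 32;
        const uint32_t n_head_kv     = 8;
        const uint32_t n_embd_head_k = 128;

        const uint32_t n_gqa        = n_head / n_head_kv;        // 4 query heads share each KV head
        const uint32_t n_embd_k_gqa = n_embd_head_k * n_head_kv; // 1024 K values per token per layer

        printf("n_gqa = %u, n_embd_k_gqa = %u\n", n_gqa, n_embd_k_gqa);
        return 0;
    }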
examples/talk-llama/llama-hparams.h ADDED
@@ -0,0 +1,140 @@
1
+ #pragma once
2
+
3
+ #include "llama.h"
4
+
5
+ #include <array>
6
+
7
+ // bump if necessary
8
+ #define LLAMA_MAX_LAYERS 512
9
+ #define LLAMA_MAX_EXPERTS 256 // DeepSeekV3
10
+
11
+ enum llama_expert_gating_func_type {
12
+ LLAMA_EXPERT_GATING_FUNC_TYPE_NONE = 0,
13
+ LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX = 1,
14
+ LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID = 2,
15
+ };
16
+
17
+ struct llama_hparams_posnet {
18
+ uint32_t n_embd;
19
+ uint32_t n_layer;
20
+ };
21
+
22
+ struct llama_hparams_convnext {
23
+ uint32_t n_embd;
24
+ uint32_t n_layer;
25
+ };
26
+
27
+ struct llama_hparams {
28
+ bool vocab_only;
29
+ bool rope_finetuned;
30
+ bool use_par_res;
31
+ bool swin_norm;
32
+
33
+ uint32_t n_vocab = 0;
34
+ uint32_t n_ctx_train; // context size the model was trained on
35
+ uint32_t n_embd;
36
+ uint32_t n_embd_features = 0;
37
+ uint32_t n_layer;
38
+ uint32_t n_rot;
39
+ uint32_t n_swa = 0; // sliding window attention (SWA)
40
+ uint32_t n_embd_head_k; // dimension of keys (d_k). d_q is assumed to be the same, but there are n_head q heads, and only n_head_kv k-v heads
41
+ uint32_t n_embd_head_v; // dimension of values (d_v) aka n_embd_head
42
+ uint32_t n_expert = 0;
43
+ uint32_t n_expert_used = 0;
44
+ uint32_t n_vocab_type = 0; // for BERT-style token types
45
+ uint32_t n_rel_attn_bkts = 0;
46
+
47
+ // for WavTokenizer
48
+ struct llama_hparams_posnet posnet;
49
+ struct llama_hparams_convnext convnext;
50
+
51
+ std::array<uint32_t, LLAMA_MAX_LAYERS> n_head_arr;
52
+ std::array<uint32_t, LLAMA_MAX_LAYERS> n_head_kv_arr;
53
+ std::array<uint32_t, LLAMA_MAX_LAYERS> n_ff_arr;
54
+
55
+ uint32_t n_layer_dense_lead = 0;
56
+ uint32_t n_lora_q = 0;
57
+ uint32_t n_lora_kv = 0;
58
+ uint32_t n_ff_exp = 0;
59
+ uint32_t n_ff_shexp = 0;
60
+ uint32_t n_expert_shared = 0;
61
+ uint32_t n_norm_groups = 0;
62
+
63
+ float expert_weights_scale = 0.0;
64
+ bool expert_weights_norm = false;
65
+ uint32_t expert_gating_func = LLAMA_EXPERT_GATING_FUNC_TYPE_NONE;
66
+
67
+ float f_norm_eps;
68
+ float f_norm_rms_eps;
69
+ float f_norm_group_eps;
70
+
71
+ float f_attn_logit_softcapping = 50.0f;
72
+ float f_final_logit_softcapping = 30.0f;
73
+
74
+ // for RWKV
75
+ uint32_t rescale_every_n_layers = 0;
76
+ uint32_t time_mix_extra_dim = 0;
77
+ uint32_t time_decay_extra_dim = 0;
78
+ uint32_t wkv_head_size = 0;
79
+
80
+ float rope_attn_factor = 1.0f;
81
+ float rope_freq_base_train;
82
+ float rope_freq_scale_train;
83
+ uint32_t n_ctx_orig_yarn;
84
+ float rope_yarn_log_mul;
85
+
86
+ std::array<int, 4> rope_sections;
87
+
88
+ // for State Space Models
89
+ uint32_t ssm_d_conv = 0;
90
+ uint32_t ssm_d_inner = 0;
91
+ uint32_t ssm_d_state = 0;
92
+ uint32_t ssm_dt_rank = 0;
93
+
94
+ bool ssm_dt_b_c_rms = false;
95
+
96
+ float f_clamp_kqv = 0.0f;
97
+ float f_max_alibi_bias = 0.0f;
98
+ float f_logit_scale = 0.0f;
99
+
100
+ // Additional scale factors (Granite/Granite MoE)
101
+ float f_residual_scale = 0.0f;
102
+ float f_embedding_scale = 0.0f;
103
+ float f_attention_scale = 0.0f;
104
+
105
+ bool causal_attn = true;
106
+ bool use_alibi = false;
107
+ bool attn_soft_cap = false;
108
+
109
+ // needed by encoder-decoder models (e.g. T5, FLAN-T5)
110
+ // ref: https://github.com/ggerganov/llama.cpp/pull/8141
111
+ llama_token dec_start_token_id = LLAMA_TOKEN_NULL;
112
+
113
+ enum llama_pooling_type pooling_type = LLAMA_POOLING_TYPE_NONE;
114
+ enum llama_rope_type rope_type = LLAMA_ROPE_TYPE_NONE;
115
+ enum llama_rope_scaling_type rope_scaling_type_train = LLAMA_ROPE_SCALING_TYPE_NONE;
116
+
117
+ uint32_t n_head(uint32_t il = 0) const;
118
+
119
+ uint32_t n_head_kv(uint32_t il = 0) const;
120
+
121
+ uint32_t n_ff(uint32_t il = 0) const;
122
+
123
+ uint32_t n_gqa(uint32_t il = 0) const;
124
+
125
+ // dimension of key embeddings across all k-v heads
126
+ uint32_t n_embd_k_gqa(uint32_t il = 0) const;
127
+
128
+ // dimension of value embeddings across all k-v heads
129
+ uint32_t n_embd_v_gqa(uint32_t il = 0) const;
130
+
131
+ // dimension of the rolling state embeddings
132
+ // corresponds to Mamba's conv_states size or RWKV's token_shift states size
133
+ uint32_t n_embd_k_s() const;
134
+
135
+ // dimension of the recurrent state embeddings
136
+ uint32_t n_embd_v_s() const;
137
+ };
138
+
139
+ static_assert(std::is_trivially_copyable<llama_hparams>::value, "llama_hparams must be trivially copyable");
140
+
examples/talk-llama/llama-impl.cpp ADDED
@@ -0,0 +1,166 @@
1
+ #include "llama-impl.h"
2
+
3
+ #include "llama.h"
4
+
5
+ #include <cinttypes>
6
+ #include <climits>
7
+ #include <cstdarg>
8
+ #include <cstring>
9
+ #include <vector>
10
+ #include <sstream>
11
+
12
+ struct llama_logger_state {
13
+ ggml_log_callback log_callback = llama_log_callback_default;
14
+ void * log_callback_user_data = nullptr;
15
+ };
16
+
17
+ static llama_logger_state g_logger_state;
18
+
19
+ time_meas::time_meas(int64_t & t_acc, bool disable) : t_start_us(disable ? -1 : ggml_time_us()), t_acc(t_acc) {}
20
+
21
+ time_meas::~time_meas() {
22
+ if (t_start_us >= 0) {
23
+ t_acc += ggml_time_us() - t_start_us;
24
+ }
25
+ }
26
+
27
+ void llama_log_set(ggml_log_callback log_callback, void * user_data) {
28
+ ggml_log_set(log_callback, user_data);
29
+ g_logger_state.log_callback = log_callback ? log_callback : llama_log_callback_default;
30
+ g_logger_state.log_callback_user_data = user_data;
31
+ }
32
+
33
+ static void llama_log_internal_v(ggml_log_level level, const char * format, va_list args) {
34
+ va_list args_copy;
35
+ va_copy(args_copy, args);
36
+ char buffer[128];
37
+ int len = vsnprintf(buffer, 128, format, args);
38
+ if (len < 128) {
39
+ g_logger_state.log_callback(level, buffer, g_logger_state.log_callback_user_data);
40
+ } else {
41
+ char * buffer2 = new char[len + 1];
42
+ vsnprintf(buffer2, len + 1, format, args_copy);
43
+ buffer2[len] = 0;
44
+ g_logger_state.log_callback(level, buffer2, g_logger_state.log_callback_user_data);
45
+ delete[] buffer2;
46
+ }
47
+ va_end(args_copy);
48
+ }
49
+
50
+ void llama_log_internal(ggml_log_level level, const char * format, ...) {
51
+ va_list args;
52
+ va_start(args, format);
53
+ llama_log_internal_v(level, format, args);
54
+ va_end(args);
55
+ }
56
+
57
+ void llama_log_callback_default(ggml_log_level level, const char * text, void * user_data) {
58
+ (void) level;
59
+ (void) user_data;
60
+ fputs(text, stderr);
61
+ fflush(stderr);
62
+ }
63
+
64
+ void replace_all(std::string & s, const std::string & search, const std::string & replace) {
65
+ if (search.empty()) {
66
+ return;
67
+ }
68
+ std::string builder;
69
+ builder.reserve(s.length());
70
+ size_t pos = 0;
71
+ size_t last_pos = 0;
72
+ while ((pos = s.find(search, last_pos)) != std::string::npos) {
73
+ builder.append(s, last_pos, pos - last_pos);
74
+ builder.append(replace);
75
+ last_pos = pos + search.length();
76
+ }
77
+ builder.append(s, last_pos, std::string::npos);
78
+ s = std::move(builder);
79
+ }
80
+
81
+ std::string format(const char * fmt, ...) {
82
+ va_list ap;
83
+ va_list ap2;
84
+ va_start(ap, fmt);
85
+ va_copy(ap2, ap);
86
+ int size = vsnprintf(NULL, 0, fmt, ap);
87
+ GGML_ASSERT(size >= 0 && size < INT_MAX); // NOLINT
88
+ std::vector<char> buf(size + 1);
89
+ int size2 = vsnprintf(buf.data(), size + 1, fmt, ap2);
90
+ GGML_ASSERT(size2 == size);
91
+ va_end(ap2);
92
+ va_end(ap);
93
+ return std::string(buf.data(), size);
94
+ }
95
+
96
+ std::string llama_format_tensor_shape(const std::vector<int64_t> & ne) {
97
+ char buf[256];
98
+ snprintf(buf, sizeof(buf), "%5" PRId64, ne.at(0));
99
+ for (size_t i = 1; i < ne.size(); i++) {
100
+ snprintf(buf + strlen(buf), sizeof(buf) - strlen(buf), ", %5" PRId64, ne.at(i));
101
+ }
102
+ return buf;
103
+ }
104
+
105
+ std::string llama_format_tensor_shape(const struct ggml_tensor * t) {
106
+ char buf[256];
107
+ snprintf(buf, sizeof(buf), "%5" PRId64, t->ne[0]);
108
+ for (int i = 1; i < GGML_MAX_DIMS; i++) {
109
+ snprintf(buf + strlen(buf), sizeof(buf) - strlen(buf), ", %5" PRId64, t->ne[i]);
110
+ }
111
+ return buf;
112
+ }
113
+
114
+ static std::string gguf_data_to_str(enum gguf_type type, const void * data, int i) {
115
+ switch (type) {
116
+ case GGUF_TYPE_UINT8: return std::to_string(((const uint8_t *)data)[i]);
117
+ case GGUF_TYPE_INT8: return std::to_string(((const int8_t *)data)[i]);
118
+ case GGUF_TYPE_UINT16: return std::to_string(((const uint16_t *)data)[i]);
119
+ case GGUF_TYPE_INT16: return std::to_string(((const int16_t *)data)[i]);
120
+ case GGUF_TYPE_UINT32: return std::to_string(((const uint32_t *)data)[i]);
121
+ case GGUF_TYPE_INT32: return std::to_string(((const int32_t *)data)[i]);
122
+ case GGUF_TYPE_UINT64: return std::to_string(((const uint64_t *)data)[i]);
123
+ case GGUF_TYPE_INT64: return std::to_string(((const int64_t *)data)[i]);
124
+ case GGUF_TYPE_FLOAT32: return std::to_string(((const float *)data)[i]);
125
+ case GGUF_TYPE_FLOAT64: return std::to_string(((const double *)data)[i]);
126
+ case GGUF_TYPE_BOOL: return ((const bool *)data)[i] ? "true" : "false";
127
+ default: return format("unknown type %d", type);
128
+ }
129
+ }
130
+
131
+ std::string gguf_kv_to_str(const struct gguf_context * ctx_gguf, int i) {
132
+ const enum gguf_type type = gguf_get_kv_type(ctx_gguf, i);
133
+
134
+ switch (type) {
135
+ case GGUF_TYPE_STRING:
136
+ return gguf_get_val_str(ctx_gguf, i);
137
+ case GGUF_TYPE_ARRAY:
138
+ {
139
+ const enum gguf_type arr_type = gguf_get_arr_type(ctx_gguf, i);
140
+ int arr_n = gguf_get_arr_n(ctx_gguf, i);
141
+ const void * data = gguf_get_arr_data(ctx_gguf, i);
142
+ std::stringstream ss;
143
+ ss << "[";
144
+ for (int j = 0; j < arr_n; j++) {
145
+ if (arr_type == GGUF_TYPE_STRING) {
146
+ std::string val = gguf_get_arr_str(ctx_gguf, i, j);
147
+ // escape quotes
148
+ replace_all(val, "\\", "\\\\");
149
+ replace_all(val, "\"", "\\\"");
150
+ ss << '"' << val << '"';
151
+ } else if (arr_type == GGUF_TYPE_ARRAY) {
152
+ ss << "???";
153
+ } else {
154
+ ss << gguf_data_to_str(arr_type, data, j);
155
+ }
156
+ if (j < arr_n - 1) {
157
+ ss << ", ";
158
+ }
159
+ }
160
+ ss << "]";
161
+ return ss.str();
162
+ }
163
+ default:
164
+ return gguf_data_to_str(type, gguf_get_val_data(ctx_gguf, i), 0);
165
+ }
166
+ }
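Note: llama_log_internal_v above formats into a 128-byte stack buffer and only falls back to a heap allocation for longer messages, so short log lines avoid an allocation. A hedged usage sketch of the public hook it feeds (the callback name and stream are placeholders, not part of this commit):

    #include "llama.h"       // llama_log_set
    #include "llama-impl.h"  // LLAMA_LOG_INFO, format

    #include <cstdio>

    // hypothetical callback: forward llama.cpp log text to the stream passed as user_data
    static void example_log_cb(ggml_log_level level, const char * text, void * user_data) {
        (void) level;
        std::fputs(text, (FILE *) user_data);
    }

    void install_logger_example() {
        llama_log_set(example_log_cb, stderr);
        LLAMA_LOG_INFO("%s: %s\n", __func__, format("logger installed, n_threads=%d", 4).c_str());
    }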
examples/talk-llama/llama-impl.h CHANGED
@@ -1,10 +1,9 @@
1
  #pragma once
2
 
3
- #include "llama.h"
4
 
5
  #include <string>
6
  #include <vector>
7
- #include <stdexcept>
8
 
9
  #ifdef __GNUC__
10
  #ifdef __MINGW32__
@@ -35,147 +34,28 @@ void llama_log_callback_default(ggml_log_level level, const char * text, void *
35
  // helpers
36
  //
37
 
38
- struct time_meas {
39
- time_meas(int64_t & t_acc, bool disable = false) : t_start_us(disable ? -1 : ggml_time_us()), t_acc(t_acc) {}
 
 
 
40
 
41
- ~time_meas() {
42
- if (t_start_us >= 0) {
43
- t_acc += ggml_time_us() - t_start_us;
44
- }
45
- }
46
 
47
  const int64_t t_start_us;
48
 
49
  int64_t & t_acc;
50
  };
51
 
52
- static void replace_all(std::string & s, const std::string & search, const std::string & replace) {
53
- if (search.empty()) {
54
- return;
55
- }
56
- std::string builder;
57
- builder.reserve(s.length());
58
- size_t pos = 0;
59
- size_t last_pos = 0;
60
- while ((pos = s.find(search, last_pos)) != std::string::npos) {
61
- builder.append(s, last_pos, pos - last_pos);
62
- builder.append(replace);
63
- last_pos = pos + search.length();
64
- }
65
- builder.append(s, last_pos, std::string::npos);
66
- s = std::move(builder);
67
- }
68
-
69
- const std::vector<std::pair<std::string, struct ggml_tensor *>> & llama_internal_get_tensor_map(
70
- struct llama_context * ctx
71
- );
72
-
73
- // the ring buffer works similarly to std::deque, but with a fixed capacity
74
- template<typename T>
75
- struct ring_buffer {
76
- ring_buffer(size_t cap) : capacity(cap), data(cap) {}
77
-
78
- T & front() {
79
- if (sz == 0) {
80
- throw std::runtime_error("ring buffer is empty");
81
- }
82
- return data[first];
83
- }
84
-
85
- const T & front() const {
86
- if (sz == 0) {
87
- throw std::runtime_error("ring buffer is empty");
88
- }
89
- return data[first];
90
- }
91
-
92
- T & back() {
93
- if (sz == 0) {
94
- throw std::runtime_error("ring buffer is empty");
95
- }
96
- return data[pos];
97
- }
98
-
99
- const T & back() const {
100
- if (sz == 0) {
101
- throw std::runtime_error("ring buffer is empty");
102
- }
103
- return data[pos];
104
- }
105
 
106
- void push_back(const T & value) {
107
- if (capacity == 0) {
108
- throw std::runtime_error("ring buffer: capacity is zero");
109
- }
110
 
111
- if (sz == capacity) {
112
- // advance the start when buffer is full
113
- first = (first + 1) % capacity;
114
- } else {
115
- sz++;
116
- }
117
- data[pos] = value;
118
- pos = (pos + 1) % capacity;
119
- }
120
 
121
- T pop_front() {
122
- if (sz == 0) {
123
- throw std::runtime_error("ring buffer is empty");
124
- }
125
- T value = data[first];
126
- first = (first + 1) % capacity;
127
- sz--;
128
- return value;
129
- }
130
-
131
- //T & operator[](size_t i) {
132
- // if (i >= sz) {
133
- // throw std::runtime_error("ring buffer: index out of bounds");
134
- // }
135
- // return data[(first + i) % capacity];
136
- //}
137
-
138
- //const T & at(size_t i) const {
139
- // if (i >= sz) {
140
- // throw std::runtime_error("ring buffer: index out of bounds");
141
- // }
142
- // return data[(first + i) % capacity];
143
- //}
144
-
145
- const T & rat(size_t i) const {
146
- if (i >= sz) {
147
- throw std::runtime_error("ring buffer: index out of bounds");
148
- }
149
- return data[(first + sz - i - 1) % capacity];
150
- }
151
-
152
- std::vector<T> to_vector() const {
153
- std::vector<T> result;
154
- result.reserve(sz);
155
- for (size_t i = 0; i < sz; i++) {
156
- result.push_back(data[(first + i) % capacity]);
157
- }
158
- return result;
159
- }
160
-
161
- void clear() {
162
- // here only reset the status of the buffer
163
- sz = 0;
164
- first = 0;
165
- pos = 0;
166
- }
167
-
168
- bool empty() const {
169
- return sz == 0;
170
- }
171
-
172
- size_t size() const {
173
- return sz;
174
- }
175
-
176
- size_t capacity = 0;
177
- size_t sz = 0;
178
- size_t first = 0;
179
- size_t pos = 0;
180
- std::vector<T> data;
181
- };
 
1
  #pragma once
2
 
3
+ #include "ggml.h" // for ggml_log_level
4
 
5
  #include <string>
6
  #include <vector>
 
7
 
8
  #ifdef __GNUC__
9
  #ifdef __MINGW32__
 
34
  // helpers
35
  //
36
 
37
+ template <typename T>
38
+ struct no_init {
39
+ T value;
40
+ no_init() { /* do nothing */ }
41
+ };
42
 
43
+ struct time_meas {
44
+ time_meas(int64_t & t_acc, bool disable = false);
45
+ ~time_meas();
 
 
46
 
47
  const int64_t t_start_us;
48
 
49
  int64_t & t_acc;
50
  };
51
 
52
+ void replace_all(std::string & s, const std::string & search, const std::string & replace);
53
 
54
+ // TODO: rename to llama_format ?
55
+ LLAMA_ATTRIBUTE_FORMAT(1, 2)
56
+ std::string format(const char * fmt, ...);
 
57
 
58
+ std::string llama_format_tensor_shape(const std::vector<int64_t> & ne);
59
+ std::string llama_format_tensor_shape(const struct ggml_tensor * t);
 
 
 
 
 
 
 
60
 
61
+ std::string gguf_kv_to_str(const struct gguf_context * ctx_gguf, int i);
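Note: time_meas is now only declared here, with its definition moved to llama-impl.cpp. A hedged sketch of the RAII timing pattern it implements (the counter and function names are placeholders):

    #include "llama-impl.h"

    static int64_t t_example_us = 0;   // hypothetical accumulator for total time spent

    void timed_work_example(bool timing_disabled) {
        time_meas tm(t_example_us, timing_disabled); // records ggml_time_us() unless disabled
        // ... work to be measured ...
    }   // destructor adds the elapsed microseconds to t_example_us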
examples/talk-llama/llama-kv-cache.cpp ADDED
@@ -0,0 +1,718 @@
1
+ #include "llama-kv-cache.h"
2
+
3
+ #include "llama-impl.h"
4
+ #include "llama-batch.h"
5
+ #include "llama-cparams.h"
6
+ #include "llama-model.h"
7
+
8
+ #include <algorithm>
9
+ #include <limits>
10
+ #include <map>
11
+
12
+ static const llama_kv_cache_slot_info llama_kv_cache_slot_info_failed{false};
13
+
14
+ uint32_t llama_kv_cache_get_padding(const struct llama_cparams & cparams) {
15
+ // the FA kernels require padding to avoid extra runtime boundary checks
16
+ return cparams.flash_attn ? 256u : 32u;
17
+ }
18
+
19
+ bool llama_kv_cache_init(
20
+ struct llama_kv_cache & cache,
21
+ const llama_model & model,
22
+ const llama_cparams & cparams,
23
+ ggml_type type_k,
24
+ ggml_type type_v,
25
+ uint32_t kv_size,
26
+ bool offload) {
27
+ const struct llama_hparams & hparams = model.hparams;
28
+
29
+ const int32_t n_layer = hparams.n_layer;
30
+
31
+ cache.has_shift = false;
32
+
33
+ cache.recurrent = llama_model_is_recurrent(&model);
34
+ cache.v_trans = !cache.recurrent && !cparams.flash_attn;
35
+ cache.can_shift = !cache.recurrent && model.arch != LLM_ARCH_DEEPSEEK2; // not supported due to MLA
36
+
37
+ LLAMA_LOG_INFO("%s: kv_size = %d, offload = %d, type_k = '%s', type_v = '%s', n_layer = %d, can_shift = %d\n",
38
+ __func__, kv_size, offload, ggml_type_name(type_k), ggml_type_name(type_v), n_layer, cache.can_shift);
39
+
40
+ cache.head = 0;
41
+ cache.size = kv_size;
42
+ cache.used = 0;
43
+
44
+ cache.type_k = type_k;
45
+ cache.type_v = type_v;
46
+
47
+ cache.cells.clear();
48
+ cache.cells.resize(kv_size);
49
+
50
+ // create a context for each buffer type
51
+ std::map<ggml_backend_buffer_type_t, ggml_context *> ctx_map;
52
+ auto ctx_for_buft = [&](ggml_backend_buffer_type_t buft) -> ggml_context * {
53
+ auto it = ctx_map.find(buft);
54
+ if (it == ctx_map.end()) {
55
+ struct ggml_init_params params = {
56
+ /*.mem_size =*/ size_t(2u*n_layer*ggml_tensor_overhead()),
57
+ /*.mem_buffer =*/ NULL,
58
+ /*.no_alloc =*/ true,
59
+ };
60
+ ggml_context * ctx = ggml_init(params);
61
+ if (!ctx) {
62
+ return nullptr;
63
+ }
64
+ ctx_map[buft] = ctx;
65
+ cache.ctxs.emplace_back(ctx);
66
+ return ctx;
67
+ }
68
+ return it->second;
69
+ };
70
+
71
+ cache.k_l.reserve(n_layer);
72
+ cache.v_l.reserve(n_layer);
73
+
74
+ for (int i = 0; i < n_layer; i++) {
75
+ const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s();
76
+ const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s();
77
+
78
+ LLAMA_LOG_DEBUG("%s: layer %d: n_embd_k_gqa = %d, n_embd_v_gqa = %d\n", __func__, i, n_embd_k_gqa, n_embd_v_gqa);
79
+
80
+ ggml_backend_buffer_type_t buft;
81
+ if (offload) {
82
+ auto * dev = model.dev_layer.at(i).dev;
83
+ buft = ggml_backend_dev_buffer_type(dev);
84
+ } else {
85
+ buft = ggml_backend_cpu_buffer_type();
86
+ }
87
+ ggml_context * ctx = ctx_for_buft(buft);
88
+
89
+ if (!ctx) {
90
+ LLAMA_LOG_ERROR("%s: failed to create ggml context for kv cache\n", __func__);
91
+ return false;
92
+ }
93
+
94
+ ggml_tensor * k = ggml_new_tensor_1d(ctx, type_k, n_embd_k_gqa*kv_size);
95
+ ggml_tensor * v = ggml_new_tensor_1d(ctx, type_v, n_embd_v_gqa*kv_size);
96
+ ggml_format_name(k, "cache_k_l%d", i);
97
+ ggml_format_name(v, "cache_v_l%d", i);
98
+ cache.k_l.push_back(k);
99
+ cache.v_l.push_back(v);
100
+ }
101
+
102
+ // allocate tensors and initialize the buffers to avoid NaNs in the padding
103
+ for (auto it : ctx_map) {
104
+ auto * buft = it.first;
105
+ auto * ctx = it.second;
106
+
107
+ ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft);
108
+ if (!buf) {
109
+ LLAMA_LOG_ERROR("%s: failed to allocate buffer for kv cache\n", __func__);
110
+ return false;
111
+ }
112
+ ggml_backend_buffer_clear(buf, 0);
113
+ LLAMA_LOG_INFO("%s: %10s KV buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf), ggml_backend_buffer_get_size(buf)/1024.0/1024.0);
114
+ cache.bufs.emplace_back(buf);
115
+ }
116
+
117
+ return true;
118
+ }
119
+
120
+ struct llama_kv_cache_slot_info llama_kv_cache_find_slot(
121
+ struct llama_kv_cache & cache,
122
+ const struct llama_ubatch & ubatch) {
123
+ const uint32_t n_tokens = ubatch.n_tokens;
124
+ const uint32_t n_seqs = ubatch.n_seqs;
125
+ const uint32_t n_seq_tokens = ubatch.n_seq_tokens;
126
+
127
+ if (cache.recurrent) {
128
+ // For recurrent state architectures (like Mamba or RWKV),
129
+ // each cache cell can store the state for a whole sequence.
130
+ // A slot should always be contiguous.
131
+
132
+ // can only process batches with an equal number of new tokens in each sequence
133
+ GGML_ASSERT(ubatch.equal_seqs);
134
+
135
+ int32_t min = cache.size - 1;
136
+ int32_t max = 0;
137
+
138
+ // everything should fit if all seq_ids are smaller than the max
139
+ for (uint32_t s = 0; s < n_seqs; ++s) {
140
+ const uint32_t n_seq_id = ubatch.n_seq_id[s];
141
+ for (uint32_t j = 0; j < n_seq_id; ++j) {
142
+ const llama_seq_id seq_id = ubatch.seq_id[s][j];
143
+
144
+ if (seq_id < 0 || (uint32_t) seq_id >= cache.size) {
145
+ // too big seq_id
146
+ // TODO: would it be possible to resize the cache instead?
147
+ LLAMA_LOG_ERROR("%s: seq_id=%d >= n_seq_max=%d Try using a bigger --parallel value\n", __func__, seq_id, cache.size);
148
+ return llama_kv_cache_slot_info_failed;
149
+ }
150
+ if (j > 0) {
151
+ llama_kv_cell & seq = cache.cells[seq_id];
152
+ if (seq.tail >= 0) {
153
+ llama_kv_cell & cell = cache.cells[seq.tail];
154
+ // clear cells from seq_ids that become shared
155
+ // (should not normally happen, but let's handle it anyway)
156
+ cell.seq_id.erase(seq_id);
157
+ seq.tail = -1;
158
+ if (cell.seq_id.empty()) {
159
+ cell.pos = -1;
160
+ cell.src = -1;
161
+ cache.used -= 1;
162
+ }
163
+ }
164
+ }
165
+ }
166
+ }
167
+
168
+ #ifndef NDEBUG
169
+ {
170
+ std::vector<int32_t> tails_verif;
171
+ tails_verif.assign(cache.size, -1);
172
+ for (uint32_t i = 0; i < cache.size; ++i) {
173
+ llama_kv_cell & cell = cache.cells[i];
174
+ for (llama_seq_id seq_id : cell.seq_id) {
175
+ if (tails_verif[seq_id] != -1) {
176
+ LLAMA_LOG_ERROR("%s: duplicate tail for seq_id %d in cell %d and %d\n", __func__, seq_id, i, tails_verif[seq_id]);
177
+ }
178
+ tails_verif[seq_id] = i;
179
+ }
180
+ }
181
+ for (uint32_t i = 0; i < cache.size; ++i) {
182
+ if (tails_verif[i] != cache.cells[i].tail) {
183
+ LLAMA_LOG_ERROR("%s: wrong tail for seq_id %d, (%d instead of %d)\n", __func__, i, cache.cells[i].tail, tails_verif[i]);
184
+ }
185
+ }
186
+ }
187
+ #endif
188
+
189
+ // find next empty cell
190
+ uint32_t next_empty_cell = cache.head;
191
+
192
+ for (uint32_t i = 0; i < cache.size; ++i) {
193
+ if (next_empty_cell >= cache.size) { next_empty_cell -= cache.size; }
194
+ llama_kv_cell & cell = cache.cells[next_empty_cell];
195
+ if (cell.is_empty()) { break; }
196
+ next_empty_cell += 1;
197
+ }
198
+
199
+ // find usable cell range
200
+ for (uint32_t s = 0; s < n_seqs; ++s) {
201
+ const llama_seq_id seq_id = ubatch.seq_id[s][0];
202
+ llama_kv_cell & seq_meta = cache.cells[seq_id];
203
+ bool has_cell = false;
204
+ if (seq_meta.tail >= 0) {
205
+ llama_kv_cell & cell = cache.cells[seq_meta.tail];
206
+ GGML_ASSERT(cell.has_seq_id(seq_id));
207
+ // does this seq_id "own" the cell?
208
+ if (cell.seq_id.size() == 1) { has_cell = true; }
209
+ }
210
+ if (!has_cell) {
211
+ llama_kv_cell & empty_cell = cache.cells[next_empty_cell];
212
+ GGML_ASSERT(empty_cell.is_empty());
213
+ // copy old tail into the empty cell
214
+ if (seq_meta.tail >= 0) {
215
+ llama_kv_cell & orig_cell = cache.cells[seq_meta.tail];
216
+ empty_cell.pos = orig_cell.pos;
217
+ empty_cell.src = orig_cell.src;
218
+ orig_cell.seq_id.erase(seq_id);
219
+ empty_cell.seq_id.insert(seq_id); // will be overwritten
220
+ }
221
+ seq_meta.tail = next_empty_cell;
222
+ // find next empty cell
223
+ if (s + 1 < n_seqs) {
224
+ next_empty_cell += 1;
225
+ for (uint32_t i = 0; i < cache.size; ++i) {
226
+ if (next_empty_cell >= cache.size) { next_empty_cell -= cache.size; }
227
+ llama_kv_cell & cell = cache.cells[next_empty_cell];
228
+ if (cell.is_empty()) { break; }
229
+ next_empty_cell += 1;
230
+ }
231
+ }
232
+ }
233
+ if (min > seq_meta.tail) { min = seq_meta.tail; }
234
+ if (max < seq_meta.tail) { max = seq_meta.tail; }
235
+ }
236
+
237
+ // gather and re-order
238
+ for (uint32_t s = 0; s < n_seqs; ++s) {
239
+ int32_t dst_id = s + min;
240
+ int32_t src_id = cache.cells[ubatch.seq_id[s][0]].tail;
241
+ if (dst_id != src_id) {
242
+ llama_kv_cell & dst_cell = cache.cells[dst_id];
243
+ llama_kv_cell & src_cell = cache.cells[src_id];
244
+
245
+ std::swap(dst_cell.pos, src_cell.pos);
246
+ std::swap(dst_cell.src, src_cell.src);
247
+ std::swap(dst_cell.seq_id, src_cell.seq_id);
248
+
249
+ // swap tails (assuming they NEVER overlap)
250
+ for (const llama_seq_id seq_id : src_cell.seq_id) {
251
+ cache.cells[seq_id].tail = src_id;
252
+ }
253
+ for (const llama_seq_id seq_id : dst_cell.seq_id) {
254
+ cache.cells[seq_id].tail = dst_id;
255
+ }
256
+ }
257
+ }
258
+
259
+ // update the pos of the used seqs
260
+ for (uint32_t s = 0; s < n_seqs; ++s) {
261
+ const llama_pos last_pos = ubatch.pos[n_seq_tokens * s + n_seq_tokens - 1];
262
+ int32_t cell_id = s + min;
263
+ llama_kv_cell & cell = cache.cells[cell_id];
264
+
265
+ if (cell.pos >= 0 && last_pos != cell.pos + (llama_pos) n_seq_tokens) {
266
+ // What should happen when the pos backtracks or skips a value?
267
+ // Clearing the state mid-batch would require special-casing which isn't done.
268
+ LLAMA_LOG_WARN("%s: non-consecutive token position %d after %d for sequence %d with %u new tokens\n",
269
+ __func__, last_pos, cell.pos, ubatch.seq_id[s][0], n_seq_tokens);
270
+ }
271
+ cell.pos = last_pos;
272
+ cell.seq_id.clear();
273
+ for (int32_t j = 0; j < ubatch.n_seq_id[s]; ++j) {
274
+ const llama_seq_id seq_id = ubatch.seq_id[s][j];
275
+ cell.seq_id.insert(seq_id);
276
+ cache.cells[seq_id].tail = cell_id;
277
+ }
278
+ }
279
+
280
+ // allow getting the range of used cells, from head to head + n
281
+ cache.head = min;
282
+ cache.n = max - min + 1;
283
+ cache.used = std::count_if(cache.cells.begin(), cache.cells.end(),
284
+ [](const llama_kv_cell& cell){ return !cell.is_empty(); });
285
+
286
+ // sanity check
287
+ return llama_kv_cache_slot_info(cache.n >= n_seqs);
288
+ }
289
+ // otherwise, one cell per token.
290
+
291
+ if (n_tokens > cache.size) {
292
+ LLAMA_LOG_ERROR("%s: n_tokens=%d > cache.size=%d\n", __func__, n_tokens, cache.size);
293
+ return llama_kv_cache_slot_info_failed;
294
+ }
295
+
296
+ uint32_t n_tested = 0;
297
+
298
+ while (true) {
299
+ if (cache.head + n_tokens > cache.size) {
300
+ n_tested += cache.size - cache.head;
301
+ cache.head = 0;
302
+ continue;
303
+ }
304
+
305
+ bool found = true;
306
+ for (uint32_t i = 0; i < n_tokens; i++) {
307
+ if (cache.cells[cache.head + i].pos >= 0) {
308
+ found = false;
309
+ cache.head += i + 1;
310
+ n_tested += i + 1;
311
+ break;
312
+ }
313
+ }
314
+
315
+ if (found) {
316
+ break;
317
+ }
318
+
319
+ if (n_tested >= cache.size) {
320
+ //LLAMA_LOG_ERROR("%s: failed to find a slot for %d tokens\n", __func__, n_tokens);
321
+ return llama_kv_cache_slot_info_failed;
322
+ }
323
+ }
324
+
325
+ for (uint32_t s = 0; s < n_seqs; s++) {
326
+ for (uint32_t i = 0; i < n_seq_tokens; ++i) {
327
+ uint32_t k = s*n_seq_tokens + i;
328
+ cache.cells[cache.head + k].pos = ubatch.pos[k];
329
+
330
+ for (int32_t j = 0; j < ubatch.n_seq_id[s]; j++) {
331
+ cache.cells[cache.head + k].seq_id.insert(ubatch.seq_id[s][j]);
332
+ }
333
+ }
334
+ }
335
+
336
+ cache.used += n_tokens;
337
+
338
+ return llama_kv_cache_slot_info(cache.head, cache.head + n_tokens);
339
+ }
340
+
341
+ uint32_t llama_kv_cache_cell_max(const struct llama_kv_cache & cache) {
342
+ for (uint32_t i = cache.size; i > 0; --i) {
343
+ const llama_kv_cell & cell = cache.cells[i - 1];
344
+
345
+ if (cell.pos >= 0 && !cell.is_empty()) {
346
+ return i;
347
+ }
348
+ }
349
+
350
+ return 0;
351
+ }
352
+
353
+ void llama_kv_cache_clear(struct llama_kv_cache & cache) {
354
+ for (int32_t i = 0; i < (int32_t) cache.size; ++i) {
355
+ cache.cells[i].pos = -1;
356
+ cache.cells[i].seq_id.clear();
357
+ cache.cells[i].src = -1;
358
+ cache.cells[i].tail = -1;
359
+ }
360
+ cache.head = 0;
361
+ cache.used = 0;
362
+
363
+ for (auto & buf : cache.bufs) {
364
+ ggml_backend_buffer_clear(buf.get(), 0);
365
+ }
366
+ }
367
+
368
+ bool llama_kv_cache_seq_rm(
369
+ struct llama_kv_cache & cache,
370
+ llama_seq_id seq_id,
371
+ llama_pos p0,
372
+ llama_pos p1) {
373
+ uint32_t new_head = cache.size;
374
+
375
+ if (p0 < 0) p0 = 0;
376
+ if (p1 < 0) p1 = std::numeric_limits<llama_pos>::max();
377
+
378
+ // models like Mamba or RWKV can't have a state partially erased
379
+ if (cache.recurrent) {
380
+ if (seq_id >= (int64_t) cache.size) {
381
+ // could be fatal
382
+ return false;
383
+ }
384
+ if (0 <= seq_id) {
385
+ int32_t & tail_id = cache.cells[seq_id].tail;
386
+ if (tail_id >= 0) {
387
+ const llama_kv_cell & cell = cache.cells[tail_id];
388
+ // partial intersection is invalid
389
+ if ((0 < p0 && p0 <= cell.pos) || (0 < p1 && p1 <= cell.pos)) {
390
+ return false;
391
+ }
392
+ // invalidate tails which will be cleared
393
+ if (p0 <= cell.pos && cell.pos < p1) {
394
+ tail_id = -1;
395
+ }
396
+ }
397
+ } else {
398
+ // if seq_id is negative, the range should include everything or nothing
399
+ if (p0 != p1 && (p0 != 0 || p1 != std::numeric_limits<llama_pos>::max())) {
400
+ return false;
401
+ }
402
+ }
403
+ }
404
+
405
+ for (uint32_t i = 0; i < cache.size; ++i) {
406
+ if (cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) {
407
+ if (seq_id < 0) {
408
+ cache.cells[i].seq_id.clear();
409
+ } else if (cache.cells[i].has_seq_id(seq_id)) {
410
+ cache.cells[i].seq_id.erase(seq_id);
411
+ } else {
412
+ continue;
413
+ }
414
+ if (cache.cells[i].is_empty()) {
415
+ // keep count of the number of used cells
416
+ if (cache.cells[i].pos >= 0) cache.used--;
417
+
418
+ cache.cells[i].pos = -1;
419
+ cache.cells[i].src = -1;
420
+ if (new_head == cache.size) new_head = i;
421
+ }
422
+ }
423
+ }
424
+
425
+ // If we freed up a slot, set head to it so searching can start there.
426
+ if (new_head != cache.size && new_head < cache.head) cache.head = new_head;
427
+
428
+ return true;
429
+ }
430
+
431
+ void llama_kv_cache_seq_cp(
432
+ struct llama_kv_cache & cache,
433
+ llama_seq_id seq_id_src,
434
+ llama_seq_id seq_id_dst,
435
+ llama_pos p0,
436
+ llama_pos p1) {
437
+ if (p0 < 0) p0 = 0;
438
+ if (p1 < 0) p1 = std::numeric_limits<llama_pos>::max();
439
+
440
+ if (cache.recurrent) {
441
+ if ((uint32_t) seq_id_dst < cache.size && (uint32_t) seq_id_src < cache.size) {
442
+ llama_kv_cell & tail_src = cache.cells[seq_id_src];
443
+ llama_kv_cell & tail_dst = cache.cells[seq_id_dst];
444
+ if (tail_dst.tail >= 0) {
445
+ // clear destination seq_id if it wasn't empty
446
+ llama_kv_cell & cell_dst = cache.cells[tail_dst.tail];
447
+
448
+ cell_dst.seq_id.erase(seq_id_dst);
449
+ tail_dst.tail = -1;
450
+ if (cell_dst.seq_id.empty()) {
451
+ cell_dst.pos = -1;
452
+ cell_dst.delta = -1;
453
+ cell_dst.src = -1;
454
+ cache.used -= 1;
455
+ }
456
+ }
457
+ if (tail_src.tail >= 0) {
458
+ llama_kv_cell & cell_src = cache.cells[tail_src.tail];
459
+
460
+ cell_src.seq_id.insert(seq_id_dst);
461
+ tail_dst.tail = tail_src.tail;
462
+ }
463
+ }
464
+
465
+ return;
466
+ }
467
+ // otherwise, this is the KV cache of a Transformer-like model
468
+
469
+ cache.head = 0;
470
+
471
+ for (uint32_t i = 0; i < cache.size; ++i) {
472
+ if (cache.cells[i].has_seq_id(seq_id_src) && cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) {
473
+ cache.cells[i].seq_id.insert(seq_id_dst);
474
+ }
475
+ }
476
+ }
477
+
478
+ void llama_kv_cache_seq_keep(struct llama_kv_cache & cache, llama_seq_id seq_id) {
479
+ uint32_t new_head = cache.size;
480
+
481
+ for (uint32_t i = 0; i < cache.size; ++i) {
482
+ if (cache.recurrent && (llama_seq_id) i != seq_id) {
483
+ cache.cells[i].tail = -1;
484
+ }
485
+ if (!cache.cells[i].has_seq_id(seq_id)) {
486
+ if (cache.cells[i].pos >= 0) cache.used--;
487
+ cache.cells[i].pos = -1;
488
+ cache.cells[i].src = -1;
489
+ cache.cells[i].seq_id.clear();
490
+ if (new_head == cache.size) new_head = i;
491
+ } else {
492
+ cache.cells[i].seq_id.clear();
493
+ cache.cells[i].seq_id.insert(seq_id);
494
+ }
495
+ }
496
+
497
+ // If we freed up a slot, set head to it so searching can start there.
498
+ if (new_head != cache.size && new_head < cache.head) cache.head = new_head;
499
+ }
500
+
501
+ void llama_kv_cache_seq_add(
502
+ struct llama_kv_cache & cache,
503
+ llama_seq_id seq_id,
504
+ llama_pos p0,
505
+ llama_pos p1,
506
+ llama_pos delta) {
507
+ uint32_t new_head = cache.size;
508
+
509
+ if (p0 < 0) p0 = 0;
510
+ if (p1 < 0) p1 = std::numeric_limits<llama_pos>::max();
511
+ // If there is no range then return early to avoid looping over the cache.
512
+ if (p0 == p1) return;
513
+
514
+ if (cache.recurrent) {
515
+ // for Mamba-like or RWKV models, only the pos needs to be shifted
516
+ if (0 <= seq_id && seq_id < (int64_t) cache.size) {
517
+ const int32_t tail_id = cache.cells[seq_id].tail;
518
+ if (tail_id >= 0) {
519
+ llama_kv_cell & cell = cache.cells[tail_id];
520
+ if (cell.has_seq_id(seq_id) && p0 <= cell.pos && cell.pos < p1) {
521
+ cell.pos += delta;
522
+ }
523
+ }
524
+ }
525
+ return;
526
+ }
527
+
528
+ for (uint32_t i = 0; i < cache.size; ++i) {
529
+ if (cache.cells[i].has_seq_id(seq_id) && cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) {
530
+ cache.has_shift = true;
531
+ cache.cells[i].pos += delta;
532
+ cache.cells[i].delta += delta;
533
+
534
+ if (cache.cells[i].pos < 0) {
535
+ if (!cache.cells[i].is_empty()) {
536
+ cache.used--;
537
+ }
538
+ cache.cells[i].pos = -1;
539
+ cache.cells[i].seq_id.clear();
540
+ if (new_head == cache.size) {
541
+ new_head = i;
542
+ }
543
+ }
544
+ }
545
+ }
546
+
547
+ // If we freed up a slot, set head to it so searching can start there.
548
+ // Otherwise we just start the next search from the beginning.
549
+ cache.head = new_head != cache.size ? new_head : 0;
550
+ }
551
+
552
+ void llama_kv_cache_seq_div(
553
+ struct llama_kv_cache & cache,
554
+ llama_seq_id seq_id,
555
+ llama_pos p0,
556
+ llama_pos p1,
557
+ int d) {
558
+ if (p0 < 0) p0 = 0;
559
+ if (p1 < 0) p1 = std::numeric_limits<llama_pos>::max();
560
+ // If there is no range then return early to avoid looping over the cache.
561
+ if (p0 == p1) return;
562
+
563
+ if (cache.recurrent) {
564
+ // for Mamba-like or RWKV models, only the pos needs to be changed
565
+ if (0 <= seq_id && seq_id < (int64_t) cache.size) {
566
+ const int32_t tail_id = cache.cells[seq_id].tail;
567
+ if (tail_id >= 0) {
568
+ llama_kv_cell & cell = cache.cells[tail_id];
569
+ if (cell.has_seq_id(seq_id) && p0 <= cell.pos && cell.pos < p1) {
570
+ cell.pos /= d;
571
+ }
572
+ }
573
+ }
574
+ return;
575
+ }
576
+
577
+ for (uint32_t i = 0; i < cache.size; ++i) {
578
+ if (cache.cells[i].has_seq_id(seq_id) && cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) {
579
+ cache.has_shift = true;
580
+
581
+ {
582
+ llama_pos p_old = cache.cells[i].pos;
583
+ cache.cells[i].pos /= d;
584
+ cache.cells[i].delta += cache.cells[i].pos - p_old;
585
+ }
586
+ }
587
+ }
588
+ }
589
+
590
+ llama_pos llama_kv_cache_seq_pos_max(struct llama_kv_cache & cache, llama_seq_id seq_id) {
591
+ llama_pos result = 0;
592
+
593
+ for (uint32_t i = 0; i < cache.size; ++i) {
594
+ if (cache.cells[i].has_seq_id(seq_id)) {
595
+ result = std::max(result, cache.cells[i].pos);
596
+ }
597
+ }
598
+
599
+ return result;
600
+ }
601
+
602
+ void llama_kv_cache_defrag(struct llama_kv_cache & cache) {
603
+ if (!cache.recurrent) {
604
+ cache.do_defrag = true;
605
+ }
606
+ }
607
+
608
+ int32_t llama_get_kv_cache_token_count(const struct llama_kv_cache & kv) {
609
+ int result = 0;
610
+
611
+ for (uint32_t i = 0; i < kv.size; i++) {
612
+ result += kv.cells[i].seq_id.size();
613
+ }
614
+
615
+ return result;
616
+ }
617
+
618
+ int32_t llama_get_kv_cache_used_cells(const struct llama_kv_cache & kv) {
619
+ return kv.used;
620
+ }
621
+
622
+ bool llama_kv_cache_can_shift(const struct llama_kv_cache & kv) {
623
+ return kv.can_shift;
624
+ }
625
+
626
+ //
627
+ // kv cache view
628
+ //
629
+
630
+ struct llama_kv_cache_view llama_kv_cache_view_init(const struct llama_kv_cache & kv, int32_t n_seq_max) {
631
+ struct llama_kv_cache_view result = {
632
+ /*.n_cells = */ 0,
633
+ /*.n_seq_max = */ n_seq_max,
634
+ /*.token_count = */ 0,
635
+ /*.used_cells = */ llama_get_kv_cache_used_cells(kv),
636
+ /*.max_contiguous = */ 0,
637
+ /*.max_contiguous_idx = */ -1,
638
+ /*.cells = */ nullptr,
639
+ /*.cells_sequences = */ nullptr,
640
+ };
641
+
642
+ return result;
643
+ }
644
+
645
+ void llama_kv_cache_view_free(struct llama_kv_cache_view * view) {
646
+ if (view->cells != nullptr) {
647
+ free(view->cells);
648
+ view->cells = nullptr;
649
+ }
650
+ if (view->cells_sequences != nullptr) {
651
+ free(view->cells_sequences);
652
+ view->cells_sequences = nullptr;
653
+ }
654
+ }
655
+
656
+ void llama_kv_cache_view_update(struct llama_kv_cache_view * view, const struct llama_kv_cache & kv) {
657
+ if (uint32_t(view->n_cells) < kv.size || view->cells == nullptr) {
658
+ view->n_cells = int32_t(kv.size);
659
+ void * p = realloc(view->cells, sizeof(struct llama_kv_cache_view_cell) * view->n_cells);
660
+ GGML_ASSERT(p != nullptr && "Failed to alloc kv_cache_view cells");
661
+ view->cells = (struct llama_kv_cache_view_cell *)p;
662
+ p = realloc(view->cells_sequences, sizeof(llama_seq_id) * view->n_seq_max * view->n_cells);
663
+ GGML_ASSERT(p != nullptr && "Failed to alloc kv_cache_view cells sequences");
664
+ view->cells_sequences = (llama_seq_id *)p;
665
+ }
666
+
667
+ const std::vector<llama_kv_cell> & kv_cells = kv.cells;
668
+ llama_kv_cache_view_cell * c_curr = view->cells;
669
+ llama_seq_id * cs_curr = view->cells_sequences;
670
+ int32_t used_cells = 0;
671
+ int32_t token_count = 0;
672
+ int32_t curr_contig_idx = -1;
673
+ uint32_t max_contig = 0;
674
+ int32_t max_contig_idx = -1;
675
+
676
+ for (int32_t i = 0; i < int32_t(kv.size); i++, c_curr++, cs_curr += view->n_seq_max) {
677
+ const size_t curr_size = kv_cells[i].seq_id.size();
678
+ token_count += curr_size;
679
+ c_curr->pos = kv_cells[i].pos + kv_cells[i].delta;
680
+
681
+ if (curr_size > 0) {
682
+ if (curr_contig_idx >= 0 && uint32_t(i - curr_contig_idx) > max_contig) {
683
+ max_contig = i - curr_contig_idx;
684
+ max_contig_idx = curr_contig_idx;
685
+ }
686
+ curr_contig_idx = -1;
687
+ } else if (curr_contig_idx < 0) {
688
+ curr_contig_idx = i;
689
+ }
690
+
691
+ int seq_idx = 0;
692
+ for (const llama_seq_id it : kv_cells[i].seq_id) {
693
+ if (seq_idx >= view->n_seq_max) {
694
+ break;
695
+ }
696
+ cs_curr[seq_idx] = it;
697
+ seq_idx++;
698
+ }
699
+ if (seq_idx != 0) {
700
+ used_cells++;
701
+ }
702
+ for (; seq_idx < view->n_seq_max; seq_idx++) {
703
+ cs_curr[seq_idx] = -1;
704
+ }
705
+ }
706
+ if (curr_contig_idx >= 0 && kv_cells.size() - curr_contig_idx > max_contig) {
707
+ max_contig_idx = curr_contig_idx;
708
+ max_contig = kv_cells.size() - curr_contig_idx;
709
+ }
710
+ view->max_contiguous = max_contig;
711
+ view->max_contiguous_idx = max_contig_idx;
712
+ view->token_count = token_count;
713
+ view->used_cells = used_cells;
714
+ if (uint32_t(used_cells) != kv.used) {
715
+ LLAMA_LOG_ERROR("%s: used cells mismatch. kv_cache says %d but we calculated %d\n",
716
+ __func__, kv.used, used_cells);
717
+ }
718
+ }
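Note: llama_kv_cache_seq_add above marks cache.has_shift and accumulates the per-cell delta instead of moving any KV data, so the actual K-shift can be applied later when the graph runs. A hedged usage sketch (the sequence id and offsets are placeholders):

    #include "llama-kv-cache.h"

    // drop the first 32 positions of sequence 0 and slide the remaining ones back
    void shift_context_example(struct llama_kv_cache & cache) {
        const llama_seq_id seq_id = 0;
        llama_kv_cache_seq_rm (cache, seq_id,  0, 32);      // erase cells with pos in [0, 32)
        llama_kv_cache_seq_add(cache, seq_id, 32, -1, -32); // remaining cells: pos -= 32, delta -= 32
        // cache.has_shift is now true; the recorded deltas are consumed by the later K-shift
    }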
examples/talk-llama/llama-kv-cache.h ADDED
@@ -0,0 +1,218 @@
1
+ #pragma once
2
+
3
+ #include "llama.h"
4
+
5
+ #include "ggml-cpp.h"
6
+
7
+ #include <set>
8
+ #include <vector>
9
+
10
+ struct llama_kv_cell {
11
+ llama_pos pos = -1;
12
+ llama_pos delta = 0;
13
+ int32_t src = -1; // used by recurrent state models to copy states
14
+ int32_t tail = -1;
15
+
16
+ std::set<llama_seq_id> seq_id;
17
+
18
+ bool has_seq_id(const llama_seq_id & id) const {
19
+ return seq_id.find(id) != seq_id.end();
20
+ }
21
+
22
+ bool is_empty() const {
23
+ return seq_id.empty();
24
+ }
25
+
26
+ bool is_same_seq(const llama_kv_cell & other) const {
27
+ return seq_id == other.seq_id;
28
+ }
29
+ };
30
+
31
+ // ring-buffer of cached KV data
32
+ struct llama_kv_cache {
33
+ bool has_shift = false;
34
+ bool do_defrag = false;
35
+ bool recurrent = false; // with recurrent state models, a cell can hold the state for more than one past token
36
+ bool v_trans = true; // the value tensor is transposed
37
+ bool can_shift = false;
38
+
39
+ // Note: The value of head isn't only used to optimize searching
40
+ // for a free KV slot. llama_decode_internal also uses it, so it
41
+ // cannot be freely changed after a slot has been allocated.
42
+ uint32_t head = 0;
43
+ uint32_t size = 0;
44
+ uint32_t used = 0; // used cells (i.e. at least one seq_id)
45
+
46
+ // computed before each graph build
47
+ uint32_t n = 0;
48
+
49
+ ggml_type type_k = GGML_TYPE_F16;
50
+ ggml_type type_v = GGML_TYPE_F16;
51
+
52
+ std::vector<llama_kv_cell> cells;
53
+
54
+ std::vector<struct ggml_tensor *> k_l; // per layer
55
+ std::vector<struct ggml_tensor *> v_l;
56
+
57
+ std::vector<ggml_context_ptr> ctxs;
58
+ std::vector<ggml_backend_buffer_ptr> bufs;
59
+
60
+ size_t total_size() const {
61
+ size_t size = 0;
62
+ for (const auto & buf : bufs) {
63
+ size += ggml_backend_buffer_get_size(buf.get());
64
+ }
65
+
66
+ return size;
67
+ }
68
+
69
+ // TODO: better data structures to reduce the cost of this operation
70
+ llama_pos max_pos() const {
71
+ llama_pos max_pos = -1;
72
+ for (const auto & cell : cells) {
73
+ max_pos = std::max(max_pos, cell.pos);
74
+ }
75
+
76
+ return max_pos;
77
+ }
78
+ };
79
+
80
+ // a structure that holds information about the slot found in llama_kv_cache_find_slot
81
+ struct llama_kv_cache_slot_info {
82
+ std::pair<uint32_t, uint32_t> boundaries; // slot boundaries [begin, end)
83
+ bool found = false; // the slot was found
84
+
85
+ explicit llama_kv_cache_slot_info(bool found_) : found{found_} {}
86
+ llama_kv_cache_slot_info(uint32_t begin, uint32_t end) : boundaries{begin, end}, found{true} {}
87
+
88
+ operator bool() const { return found; }
89
+ };
90
+
91
+ // TODO: maybe not needed
92
+ uint32_t llama_kv_cache_get_padding(const struct llama_cparams & cparams);
93
+
94
+ bool llama_kv_cache_init(
95
+ struct llama_kv_cache & cache,
96
+ const llama_model & model,
97
+ const llama_cparams & cparams,
98
+ ggml_type type_k,
99
+ ggml_type type_v,
100
+ uint32_t kv_size,
101
+ bool offload);
102
+
103
+ // find an empty slot of size "n_tokens" in the cache
104
+ // updates the cache head
105
+ // returns a structure holding information about the slot found
106
+ // Note: On success, it's important that cache.head points
107
+ // to the first cell of the slot.
108
+ struct llama_kv_cache_slot_info llama_kv_cache_find_slot(
109
+ struct llama_kv_cache & cache,
110
+ const struct llama_ubatch & batch);
111
+
112
+ // find how many cells are currently in use
113
+ uint32_t llama_kv_cache_cell_max(const struct llama_kv_cache & cache);
114
+
115
+ void llama_kv_cache_clear(struct llama_kv_cache & cache);
116
+
117
+ bool llama_kv_cache_seq_rm(
118
+ struct llama_kv_cache & cache,
119
+ llama_seq_id seq_id,
120
+ llama_pos p0,
121
+ llama_pos p1);
122
+
123
+ void llama_kv_cache_seq_cp(
124
+ struct llama_kv_cache & cache,
125
+ llama_seq_id seq_id_src,
126
+ llama_seq_id seq_id_dst,
127
+ llama_pos p0,
128
+ llama_pos p1);
129
+
130
+ void llama_kv_cache_seq_keep(
131
+ struct llama_kv_cache & cache,
132
+ llama_seq_id seq_id);
133
+
134
+ void llama_kv_cache_seq_add(
135
+ struct llama_kv_cache & cache,
136
+ llama_seq_id seq_id,
137
+ llama_pos p0,
138
+ llama_pos p1,
139
+ llama_pos delta);
140
+
141
+ void llama_kv_cache_seq_div(
142
+ struct llama_kv_cache & cache,
143
+ llama_seq_id seq_id,
144
+ llama_pos p0,
145
+ llama_pos p1,
146
+ int d);
147
+
148
+ llama_pos llama_kv_cache_seq_pos_max(
149
+ struct llama_kv_cache & cache,
150
+ llama_seq_id seq_id);
151
+
152
+ void llama_kv_cache_defrag(struct llama_kv_cache & cache);
153
+
154
+ int32_t llama_get_kv_cache_token_count(const struct llama_kv_cache & kv);
155
+
156
+ int32_t llama_get_kv_cache_used_cells(const struct llama_kv_cache & kv);
157
+
158
+ bool llama_kv_cache_can_shift(const struct llama_kv_cache & kv);
159
+
160
+ //
161
+ // kv cache view
162
+ //
163
+
164
+ struct llama_kv_cache_view llama_kv_cache_view_init(const struct llama_kv_cache & kv, int32_t n_seq_max);
165
+
166
+ void llama_kv_cache_view_update(struct llama_kv_cache_view * view, const struct llama_kv_cache & kv);
167
+
168
+ //
169
+ // kv cache restore
170
+ //
171
+
172
+ // saves the kv_cache state for future recovery.
173
+ // used to roll back llama_kv_cache_find_slot changes.
174
+ struct llama_kv_slot_restorer {
175
+ struct llama_kv_cache_state {
176
+ uint32_t head = 0;
177
+ uint32_t n = 0;
178
+ } old_state;
179
+
180
+ // for non-recurrent models only
181
+ // list of slots to restore
182
+ std::vector<std::pair<uint32_t, uint32_t>> slot_boundaries;
183
+
184
+ bool do_restore = false;
185
+
186
+ explicit llama_kv_slot_restorer(const struct llama_kv_cache & cache) {
187
+ old_state.head = cache.head;
188
+ old_state.n = cache.n;
189
+ }
190
+
191
+ // saves slot information for future restoration
192
+ void save(const struct llama_kv_cache_slot_info & slot) {
193
+ if (slot) {
194
+ do_restore = true;
195
+ if (slot.boundaries.first != slot.boundaries.second) {
196
+ slot_boundaries.push_back(slot.boundaries);
197
+ }
198
+ }
199
+ }
200
+
201
+ // must be explicitly called to restore the kv_cache state
202
+ // and roll back changes from all llama_kv_cache_find_slot calls
203
+ void restore(struct llama_kv_cache & cache) {
204
+ if (do_restore) {
205
+ cache.head = old_state.head;
206
+ cache.n = old_state.n;
207
+
208
+ if (cache.recurrent) { // recurrent models like Mamba or RWKV can't have a state partially erased
209
+ llama_kv_cache_seq_rm(cache, -1, -1, -1);
210
+ } else {
211
+ for (auto & slot : slot_boundaries) {
212
+ llama_kv_cache_seq_rm(cache, -1, slot.first, slot.second);
213
+ }
214
+ }
215
+ }
216
+ }
217
+ };
218
+
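Note: a hedged sketch of the save/restore protocol llama_kv_slot_restorer is meant for (the decode function and its result flag are placeholders): remember each slot returned by llama_kv_cache_find_slot and roll the cache back if the batch fails part-way.

    #include "llama-kv-cache.h"

    // hypothetical decode step showing the rollback pattern
    bool decode_with_rollback_example(struct llama_kv_cache & cache, const struct llama_ubatch & ubatch) {
        llama_kv_slot_restorer restorer(cache);   // snapshot head and n

        const auto slot = llama_kv_cache_find_slot(cache, ubatch);
        if (!slot) {
            return false;                         // nothing to undo yet
        }
        restorer.save(slot);                      // remember the reserved range

        const bool compute_ok = true;             // placeholder for the actual graph computation
        if (!compute_ok) {
            restorer.restore(cache);              // reset head/n and clear the reserved cells
            return false;
        }
        return true;
    }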
examples/talk-llama/llama-mmap.cpp ADDED
@@ -0,0 +1,589 @@
1
+ #include "llama-mmap.h"
2
+
3
+ #include "llama-impl.h"
4
+
5
+ #include "ggml.h"
6
+
7
+ #include <cstring>
8
+ #include <climits>
9
+ #include <stdexcept>
10
+
11
+ #ifdef __has_include
12
+ #if __has_include(<unistd.h>)
13
+ #include <unistd.h>
14
+ #if defined(_POSIX_MAPPED_FILES)
15
+ #include <sys/mman.h>
16
+ #include <fcntl.h>
17
+ #endif
18
+ #if defined(_POSIX_MEMLOCK_RANGE)
19
+ #include <sys/resource.h>
20
+ #endif
21
+ #endif
22
+ #endif
23
+
24
+ #if defined(_WIN32)
25
+ #define WIN32_LEAN_AND_MEAN
26
+ #ifndef NOMINMAX
27
+ #define NOMINMAX
28
+ #endif
29
+ #include <windows.h>
30
+ #ifndef PATH_MAX
31
+ #define PATH_MAX MAX_PATH
32
+ #endif
33
+ #include <io.h>
34
+ #endif
35
+
36
+ // TODO: consider moving to llama-impl.h if needed in more places
37
+ #if defined(_WIN32)
38
+ std::string llama_format_win_err(DWORD err) {
39
+ LPSTR buf;
40
+ size_t size = FormatMessageA(FORMAT_MESSAGE_ALLOCATE_BUFFER | FORMAT_MESSAGE_FROM_SYSTEM | FORMAT_MESSAGE_IGNORE_INSERTS,
41
+ NULL, err, MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT), (LPSTR)&buf, 0, NULL);
42
+ if (!size) {
43
+ return "FormatMessageA failed";
44
+ }
45
+ std::string ret(buf, size);
46
+ LocalFree(buf);
47
+ return ret;
48
+ }
49
+ #endif
50
+
51
+ // llama_file
52
+
53
+ struct llama_file::impl {
54
+ #if defined(_WIN32)
55
+ HANDLE fp_win32;
56
+ std::string GetErrorMessageWin32(DWORD error_code) const {
57
+ std::string ret;
58
+ LPSTR lpMsgBuf = NULL;
59
+ DWORD bufLen = FormatMessageA(FORMAT_MESSAGE_ALLOCATE_BUFFER | FORMAT_MESSAGE_FROM_SYSTEM | FORMAT_MESSAGE_IGNORE_INSERTS,
60
+ NULL, error_code, MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT), (LPSTR)&lpMsgBuf, 0, NULL);
61
+ if (!bufLen) {
62
+ ret = format("Win32 error code: %lx", error_code);
63
+ } else {
64
+ ret = lpMsgBuf;
65
+ LocalFree(lpMsgBuf);
66
+ }
67
+
68
+ return ret;
69
+ }
70
+
71
+ impl(const char * fname, const char * mode) {
72
+ fp = ggml_fopen(fname, mode);
73
+ if (fp == NULL) {
74
+ throw std::runtime_error(format("failed to open %s: %s", fname, strerror(errno)));
75
+ }
76
+ fp_win32 = (HANDLE) _get_osfhandle(_fileno(fp));
77
+ seek(0, SEEK_END);
78
+ size = tell();
79
+ seek(0, SEEK_SET);
80
+ }
81
+
82
+ size_t tell() const {
83
+ LARGE_INTEGER li;
84
+ li.QuadPart = 0;
85
+ BOOL ret = SetFilePointerEx(fp_win32, li, &li, FILE_CURRENT);
86
+ if (!ret) {
87
+ throw std::runtime_error(format("read error: %s", GetErrorMessageWin32(GetLastError()).c_str()));
88
+ }
89
+
90
+ return li.QuadPart;
91
+ }
92
+
93
+ void seek(size_t offset, int whence) const {
94
+ static_assert(SEEK_SET == FILE_BEGIN, "SEEK_SET != FILE_BEGIN");
95
+ static_assert(SEEK_CUR == FILE_CURRENT, "SEEK_CUR != FILE_CURRENT");
96
+ static_assert(SEEK_END == FILE_END, "SEEK_END != FILE_END");
97
+
98
+ LARGE_INTEGER li;
99
+ li.QuadPart = offset;
100
+ BOOL ret = SetFilePointerEx(fp_win32, li, NULL, whence);
101
+ if (!ret) {
102
+ throw std::runtime_error(format("read error: %s", GetErrorMessageWin32(GetLastError()).c_str()));
103
+ }
104
+ }
105
+
106
+ void read_raw(void * ptr, size_t len) const {
107
+ size_t bytes_read = 0;
108
+ while (bytes_read < len) {
109
+ size_t chunk_size = std::min<size_t>(len - bytes_read, 64*1024*1024);
110
+ DWORD chunk_read = 0;
111
+ BOOL result = ReadFile(fp_win32, reinterpret_cast<char*>(ptr) + bytes_read, chunk_size, &chunk_read, NULL);
112
+ if (!result) {
113
+ throw std::runtime_error(format("read error: %s", GetErrorMessageWin32(GetLastError()).c_str()));
114
+ }
115
+ if (chunk_read < chunk_size || chunk_read == 0) {
116
+ throw std::runtime_error("unexpectedly reached end of file");
117
+ }
118
+
119
+ bytes_read += chunk_read;
120
+ }
121
+ }
122
+
123
+ uint32_t read_u32() const {
124
+ uint32_t val;
125
+ read_raw(&val, sizeof(val));
126
+ return val;
127
+ }
128
+
129
+ void write_raw(const void * ptr, size_t len) const {
130
+ size_t bytes_written = 0;
131
+ while (bytes_written < len) {
132
+ size_t chunk_size = std::min<size_t>(len - bytes_written, 64*1024*1024);
133
+ DWORD chunk_written = 0;
134
+ BOOL result = WriteFile(fp_win32, reinterpret_cast<char const*>(ptr) + bytes_written, chunk_size, &chunk_written, NULL);
135
+ if (!result) {
136
+ throw std::runtime_error(format("write error: %s", GetErrorMessageWin32(GetLastError()).c_str()));
137
+ }
138
+ if (chunk_written < chunk_size || chunk_written == 0) {
139
+ throw std::runtime_error("unexpectedly failed to write bytes");
140
+ }
141
+
142
+ bytes_written += chunk_written;
143
+ }
144
+ }
145
+
146
+ void write_u32(uint32_t val) const {
147
+ write_raw(&val, sizeof(val));
148
+ }
149
+
150
+ ~impl() {
151
+ if (fp) {
152
+ std::fclose(fp);
153
+ }
154
+ }
155
+ #else
156
+ impl(const char * fname, const char * mode) {
157
+ fp = ggml_fopen(fname, mode);
158
+ if (fp == NULL) {
159
+ throw std::runtime_error(format("failed to open %s: %s", fname, strerror(errno)));
160
+ }
161
+ seek(0, SEEK_END);
162
+ size = tell();
163
+ seek(0, SEEK_SET);
164
+ }
165
+
166
+ size_t tell() const {
167
+ // TODO: this ifdef is never true?
168
+ #ifdef _WIN32
169
+ __int64 ret = _ftelli64(fp);
170
+ #else
171
+ long ret = std::ftell(fp);
172
+ #endif
173
+ if (ret == -1) {
174
+ throw std::runtime_error(format("ftell error: %s", strerror(errno)));
175
+ }
176
+
177
+ return (size_t) ret;
178
+ }
179
+
180
+ void seek(size_t offset, int whence) const {
181
+ // TODO: this ifdef is never true?
182
+ #ifdef _WIN32
183
+ int ret = _fseeki64(fp, (__int64) offset, whence);
184
+ #else
185
+ int ret = std::fseek(fp, (long) offset, whence);
186
+ #endif
187
+ if (ret != 0) {
188
+ throw std::runtime_error(format("seek error: %s", strerror(errno)));
189
+ }
190
+ }
191
+
192
+ void read_raw(void * ptr, size_t len) const {
193
+ if (len == 0) {
194
+ return;
195
+ }
196
+ errno = 0;
197
+ std::size_t ret = std::fread(ptr, len, 1, fp);
198
+ if (ferror(fp)) {
199
+ throw std::runtime_error(format("read error: %s", strerror(errno)));
200
+ }
201
+ if (ret != 1) {
202
+ throw std::runtime_error("unexpectedly reached end of file");
203
+ }
204
+ }
205
+
206
+ uint32_t read_u32() const {
207
+ uint32_t ret;
208
+ read_raw(&ret, sizeof(ret));
209
+ return ret;
210
+ }
211
+
212
+ void write_raw(const void * ptr, size_t len) const {
213
+ if (len == 0) {
214
+ return;
215
+ }
216
+ errno = 0;
217
+ size_t ret = std::fwrite(ptr, len, 1, fp);
218
+ if (ret != 1) {
219
+ throw std::runtime_error(format("write error: %s", strerror(errno)));
220
+ }
221
+ }
222
+
223
+ void write_u32(uint32_t val) const {
224
+ write_raw(&val, sizeof(val));
225
+ }
226
+
227
+ ~impl() {
228
+ if (fp) {
229
+ std::fclose(fp);
230
+ }
231
+ }
232
+ #endif
233
+
234
+ FILE * fp;
235
+ size_t size;
236
+ };
237
+
238
+ llama_file::llama_file(const char * fname, const char * mode) : pimpl(std::make_unique<impl>(fname, mode)) {}
239
+ llama_file::~llama_file() = default;
240
+
241
+ size_t llama_file::tell() const { return pimpl->tell(); }
242
+ size_t llama_file::size() const { return pimpl->size; }
243
+
244
+ int llama_file::file_id() const {
245
+ #ifdef _WIN32
246
+ return _fileno(pimpl->fp);
247
+ #else
248
+ #if defined(fileno)
249
+ return fileno(pimpl->fp);
250
+ #else
251
+ return ::fileno(pimpl->fp);
252
+ #endif
253
+ #endif
254
+ }
255
+
256
+ void llama_file::seek(size_t offset, int whence) const { pimpl->seek(offset, whence); }
257
+ void llama_file::read_raw(void * ptr, size_t len) const { pimpl->read_raw(ptr, len); }
258
+
259
+ uint32_t llama_file::read_u32() const { return pimpl->read_u32(); }
260
+
261
+ void llama_file::write_raw(const void * ptr, size_t len) const { pimpl->write_raw(ptr, len); }
262
+ void llama_file::write_u32(uint32_t val) const { pimpl->write_u32(val); }
263
+
264
+ // llama_mmap
265
+
266
+ struct llama_mmap::impl {
267
+ #ifdef _POSIX_MAPPED_FILES
268
+ std::vector<std::pair<size_t, size_t>> mapped_fragments;
269
+
270
+ impl(struct llama_file * file, size_t prefetch, bool numa) {
271
+ size = file->size();
272
+ int fd = file->file_id();
273
+ int flags = MAP_SHARED;
274
+ if (numa) { prefetch = 0; }
275
+ #ifdef __linux__
276
+ if (posix_fadvise(fd, 0, 0, POSIX_FADV_SEQUENTIAL)) {
277
+ LLAMA_LOG_WARN("warning: posix_fadvise(.., POSIX_FADV_SEQUENTIAL) failed: %s\n",
278
+ strerror(errno));
279
+ }
280
+ if (prefetch) { flags |= MAP_POPULATE; }
281
+ #endif
282
+ addr = mmap(NULL, file->size(), PROT_READ, flags, fd, 0);
283
+ if (addr == MAP_FAILED) {
284
+ throw std::runtime_error(format("mmap failed: %s", strerror(errno)));
285
+ }
286
+
287
+ if (prefetch > 0) {
288
+ if (posix_madvise(addr, std::min(file->size(), prefetch), POSIX_MADV_WILLNEED)) {
289
+ LLAMA_LOG_WARN("warning: posix_madvise(.., POSIX_MADV_WILLNEED) failed: %s\n",
290
+ strerror(errno));
291
+ }
292
+ }
293
+ if (numa) {
294
+ if (posix_madvise(addr, file->size(), POSIX_MADV_RANDOM)) {
295
+ LLAMA_LOG_WARN("warning: posix_madvise(.., POSIX_MADV_RANDOM) failed: %s\n",
296
+ strerror(errno));
297
+ }
298
+ }
299
+
300
+ mapped_fragments.emplace_back(0, file->size());
301
+ }
302
+
303
+ static void align_range(size_t * first, size_t * last, size_t page_size) {
304
+ size_t offset_in_page = *first & (page_size - 1);
305
+ size_t offset_to_page = offset_in_page == 0 ? 0 : page_size - offset_in_page;
306
+ *first += offset_to_page;
307
+
308
+ *last = *last & ~(page_size - 1);
309
+
310
+ if (*last <= *first) {
311
+ *last = *first;
312
+ }
313
+ }
314
+
315
+ void unmap_fragment(size_t first, size_t last) {
316
+ int page_size = sysconf(_SC_PAGESIZE);
317
+ align_range(&first, &last, page_size);
318
+ size_t len = last - first;
319
+
320
+ if (len == 0) {
321
+ return;
322
+ }
323
+
324
+ GGML_ASSERT(first % page_size == 0);
325
+ GGML_ASSERT(last % page_size == 0);
326
+ GGML_ASSERT(last > first);
327
+
328
+ void * next_page_start = (uint8_t *) addr + first;
329
+
330
+ if (munmap(next_page_start, len)) {
331
+ LLAMA_LOG_WARN("warning: munmap failed: %s\n", strerror(errno));
332
+ }
333
+
334
+ std::vector<std::pair<size_t, size_t>> new_mapped_fragments;
335
+ for (const auto & frag : mapped_fragments) {
336
+ if (frag.first < first && frag.second > last) {
337
+ new_mapped_fragments.emplace_back(frag.first, first);
338
+ new_mapped_fragments.emplace_back(last, frag.second);
339
+ } else if (frag.first < first && frag.second > first) {
340
+ new_mapped_fragments.emplace_back(frag.first, first);
341
+ } else if (frag.first < last && frag.second > last) {
342
+ new_mapped_fragments.emplace_back(last, frag.second);
343
+ } else if (frag.first >= first && frag.second <= last) {
344
+ } else {
345
+ new_mapped_fragments.push_back(frag);
346
+ }
347
+ }
348
+ mapped_fragments = std::move(new_mapped_fragments);
349
+ }
350
+
351
+ ~impl() {
352
+ for (const auto & frag : mapped_fragments) {
353
+ if (munmap((char *) addr + frag.first, frag.second - frag.first)) {
354
+ LLAMA_LOG_WARN("warning: munmap failed: %s\n", strerror(errno));
355
+ }
356
+ }
357
+ }
358
+ #elif defined(_WIN32)
359
+ impl(struct llama_file * file, size_t prefetch, bool numa) {
360
+ GGML_UNUSED(numa);
361
+
362
+ size = file->size();
363
+
364
+ HANDLE hFile = (HANDLE) _get_osfhandle(file->file_id());
365
+
366
+ HANDLE hMapping = CreateFileMappingA(hFile, NULL, PAGE_READONLY, 0, 0, NULL);
367
+
368
+ if (hMapping == NULL) {
369
+ DWORD error = GetLastError();
370
+ throw std::runtime_error(format("CreateFileMappingA failed: %s", llama_format_win_err(error).c_str()));
371
+ }
372
+
373
+ addr = MapViewOfFile(hMapping, FILE_MAP_READ, 0, 0, 0);
374
+ DWORD error = GetLastError();
375
+ CloseHandle(hMapping);
376
+
377
+ if (addr == NULL) {
378
+ throw std::runtime_error(format("MapViewOfFile failed: %s", llama_format_win_err(error).c_str()));
379
+ }
380
+
381
+ if (prefetch > 0) {
382
+ #if _WIN32_WINNT >= 0x602
383
+ BOOL (WINAPI *pPrefetchVirtualMemory) (HANDLE, ULONG_PTR, PWIN32_MEMORY_RANGE_ENTRY, ULONG);
384
+ HMODULE hKernel32 = GetModuleHandleW(L"kernel32.dll");
385
+
386
+ pPrefetchVirtualMemory = (decltype(pPrefetchVirtualMemory))(void *) GetProcAddress(hKernel32, "PrefetchVirtualMemory");
387
+
388
+ if (pPrefetchVirtualMemory) {
389
+ WIN32_MEMORY_RANGE_ENTRY range;
390
+ range.VirtualAddress = addr;
391
+ range.NumberOfBytes = (SIZE_T) std::min(size, prefetch);
392
+ if (!pPrefetchVirtualMemory(GetCurrentProcess(), 1, &range, 0)) {
393
+ LLAMA_LOG_WARN("warning: PrefetchVirtualMemory failed: %s\n",
394
+ llama_format_win_err(GetLastError()).c_str());
395
+ }
396
+ }
397
+ #else
398
+ throw std::runtime_error("PrefetchVirtualMemory unavailable");
399
+ #endif
400
+ }
401
+ }
402
+
403
+ void unmap_fragment(size_t first, size_t last) {
404
+ GGML_UNUSED(first);
405
+ GGML_UNUSED(last);
406
+ }
407
+
408
+ ~impl() {
409
+ if (!UnmapViewOfFile(addr)) {
410
+ LLAMA_LOG_WARN("warning: UnmapViewOfFile failed: %s\n",
411
+ llama_format_win_err(GetLastError()).c_str());
412
+ }
413
+ }
414
+ #else
415
+ impl(struct llama_file * file, size_t prefetch, bool numa) {
416
+ GGML_UNUSED(file);
417
+ GGML_UNUSED(prefetch);
418
+ GGML_UNUSED(numa);
419
+
420
+ throw std::runtime_error("mmap not supported");
421
+ }
422
+
423
+ void unmap_fragment(size_t first, size_t last) {
424
+ GGML_UNUSED(first);
425
+ GGML_UNUSED(last);
426
+
427
+ throw std::runtime_error("mmap not supported");
428
+ }
429
+ #endif
430
+
431
+ void * addr;
432
+ size_t size;
433
+ };
434
+
435
+ llama_mmap::llama_mmap(struct llama_file * file, size_t prefetch, bool numa) : pimpl(std::make_unique<impl>(file, prefetch, numa)) {}
436
+ llama_mmap::~llama_mmap() = default;
437
+
438
+ size_t llama_mmap::size() const { return pimpl->size; }
439
+ void * llama_mmap::addr() const { return pimpl->addr; }
440
+
441
+ void llama_mmap::unmap_fragment(size_t first, size_t last) { pimpl->unmap_fragment(first, last); }
442
+
443
+ #if defined(_POSIX_MEMLOCK_RANGE) || defined(_WIN32)
444
+ const bool llama_mmap::SUPPORTED = true;
445
+ #else
446
+ const bool llama_mmap::SUPPORTED = false;
447
+ #endif
448
+
449
+ // llama_mlock
450
+
451
+ struct llama_mlock::impl {
452
+ #ifdef _POSIX_MEMLOCK_RANGE
453
+ static size_t lock_granularity() {
454
+ return (size_t) sysconf(_SC_PAGESIZE);
455
+ }
456
+
457
+ bool raw_lock(const void * addr, size_t size) const {
458
+ if (!mlock(addr, size)) {
459
+ return true;
460
+ }
461
+
462
+ #ifdef __APPLE__
463
+ #define MLOCK_SUGGESTION \
464
+ "Try increasing the sysctl values 'vm.user_wire_limit' and 'vm.global_user_wire_limit' and/or " \
465
+ "decreasing 'vm.global_no_user_wire_amount'. Also try increasing RLIMIT_MEMLOCK (ulimit -l).\n"
466
+ #else
467
+ #define MLOCK_SUGGESTION \
468
+ "Try increasing RLIMIT_MEMLOCK ('ulimit -l' as root).\n"
469
+ #endif
470
+
471
+ char* errmsg = std::strerror(errno);
472
+ bool suggest = (errno == ENOMEM);
473
+
474
+ struct rlimit lock_limit;
475
+ if (suggest && getrlimit(RLIMIT_MEMLOCK, &lock_limit)) {
476
+ suggest = false;
477
+ }
478
+ if (suggest && (lock_limit.rlim_max > lock_limit.rlim_cur + size)) {
479
+ suggest = false;
480
+ }
481
+
482
+ LLAMA_LOG_WARN("warning: failed to mlock %zu-byte buffer (after previously locking %zu bytes): %s\n%s",
483
+ size, this->size, errmsg, suggest ? MLOCK_SUGGESTION : "");
484
+ return false;
485
+ }
486
+
487
+ static void raw_unlock(void * addr, size_t size) {
488
+ if (munlock(addr, size)) {
489
+ LLAMA_LOG_WARN("warning: failed to munlock buffer: %s\n", std::strerror(errno));
490
+ }
491
+ }
492
+ #elif defined(_WIN32)
493
+ static size_t lock_granularity() {
494
+ SYSTEM_INFO si;
495
+ GetSystemInfo(&si);
496
+ return (size_t) si.dwPageSize;
497
+ }
498
+
499
+ bool raw_lock(void * ptr, size_t len) const {
500
+ for (int tries = 1; ; tries++) {
501
+ if (VirtualLock(ptr, len)) {
502
+ return true;
503
+ }
504
+ if (tries == 2) {
505
+ LLAMA_LOG_WARN("warning: failed to VirtualLock %zu-byte buffer (after previously locking %zu bytes): %s\n",
506
+ len, size, llama_format_win_err(GetLastError()).c_str());
507
+ return false;
508
+ }
509
+
510
+ SIZE_T min_ws_size, max_ws_size;
511
+ if (!GetProcessWorkingSetSize(GetCurrentProcess(), &min_ws_size, &max_ws_size)) {
512
+ LLAMA_LOG_WARN("warning: GetProcessWorkingSetSize failed: %s\n",
513
+ llama_format_win_err(GetLastError()).c_str());
514
+ return false;
515
+ }
516
+ size_t increment = len + 1048576;
517
+ min_ws_size += increment;
518
+ max_ws_size += increment;
519
+ if (!SetProcessWorkingSetSize(GetCurrentProcess(), min_ws_size, max_ws_size)) {
520
+ LLAMA_LOG_WARN("warning: SetProcessWorkingSetSize failed: %s\n",
521
+ llama_format_win_err(GetLastError()).c_str());
522
+ return false;
523
+ }
524
+ }
525
+ }
526
+
527
+ static void raw_unlock(void * ptr, size_t len) {
528
+ if (!VirtualUnlock(ptr, len)) {
529
+ LLAMA_LOG_WARN("warning: failed to VirtualUnlock buffer: %s\n",
530
+ llama_format_win_err(GetLastError()).c_str());
531
+ }
532
+ }
533
+ #else
534
+ static size_t lock_granularity() {
535
+ return (size_t) 65536;
536
+ }
537
+
538
+ bool raw_lock(const void * addr, size_t len) const {
539
+ LLAMA_LOG_WARN("warning: mlock not supported on this system\n");
540
+ return false;
541
+ }
542
+
543
+ static void raw_unlock(const void * addr, size_t len) {}
544
+ #endif
545
+
546
+ impl() : addr(NULL), size(0), failed_already(false) {}
547
+
548
+ void init(void * ptr) {
549
+ GGML_ASSERT(addr == NULL && size == 0);
550
+ addr = ptr;
551
+ }
552
+
553
+ void grow_to(size_t target_size) {
554
+ GGML_ASSERT(addr);
555
+ if (failed_already) {
556
+ return;
557
+ }
558
+ size_t granularity = lock_granularity();
559
+ target_size = (target_size + granularity - 1) & ~(granularity - 1);
560
+ if (target_size > size) {
561
+ if (raw_lock((uint8_t *) addr + size, target_size - size)) {
562
+ size = target_size;
563
+ } else {
564
+ failed_already = true;
565
+ }
566
+ }
567
+ }
568
+
569
+ void * addr;
570
+ size_t size;
571
+
572
+ bool failed_already;
573
+ };
574
+
575
+ llama_mlock::llama_mlock() : pimpl(std::make_unique<impl>()) {}
576
+ llama_mlock::~llama_mlock() = default;
577
+
578
+ void llama_mlock::init(void * ptr) { pimpl->init(ptr); }
579
+ void llama_mlock::grow_to(size_t target_size) { pimpl->grow_to(target_size); }
580
+
581
+ #if defined(_POSIX_MEMLOCK_RANGE) || defined(_WIN32)
582
+ const bool llama_mlock::SUPPORTED = true;
583
+ #else
584
+ const bool llama_mlock::SUPPORTED = false;
585
+ #endif
586
+
587
+ size_t llama_path_max() {
588
+ return PATH_MAX;
589
+ }
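
Editor's note: the grow_to() logic above rounds the requested lock size up to the platform's lock granularity before extending the locked region. A minimal standalone sketch of that rounding (not part of the commit; the 4096-byte page size is only an assumed example) is:

#include <cstddef>
#include <cstdio>

// Round `target` up to the next multiple of `granularity`.
// `granularity` is assumed to be a power of two (e.g. the page size),
// which is what makes the bit-mask trick valid.
static size_t round_up(size_t target, size_t granularity) {
    return (target + granularity - 1) & ~(granularity - 1);
}

int main() {
    const size_t page = 4096; // assumed page size, for illustration only
    std::printf("%zu -> %zu\n", (size_t) 1,    round_up(1,    page)); // 1    -> 4096
    std::printf("%zu -> %zu\n", (size_t) 4096, round_up(4096, page)); // 4096 -> 4096
    std::printf("%zu -> %zu\n", (size_t) 4097, round_up(4097, page)); // 4097 -> 8192
    return 0;
}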
examples/talk-llama/llama-mmap.h ADDED
@@ -0,0 +1,67 @@
1
+ #pragma once
2
+
3
+ #include <memory>
4
+ #include <vector>
5
+
6
+ struct llama_file;
7
+ struct llama_mmap;
8
+ struct llama_mlock;
9
+
10
+ using llama_files = std::vector<std::unique_ptr<llama_file>>;
11
+ using llama_mmaps = std::vector<std::unique_ptr<llama_mmap>>;
12
+ using llama_mlocks = std::vector<std::unique_ptr<llama_mlock>>;
13
+
14
+ struct llama_file {
15
+ llama_file(const char * fname, const char * mode);
16
+ ~llama_file();
17
+
18
+ size_t tell() const;
19
+ size_t size() const;
20
+
21
+ int file_id() const; // fileno overload
22
+
23
+ void seek(size_t offset, int whence) const;
24
+
25
+ void read_raw(void * ptr, size_t len) const;
26
+ uint32_t read_u32() const;
27
+
28
+ void write_raw(const void * ptr, size_t len) const;
29
+ void write_u32(uint32_t val) const;
30
+
31
+ private:
32
+ struct impl;
33
+ std::unique_ptr<impl> pimpl;
34
+ };
35
+
36
+ struct llama_mmap {
37
+ llama_mmap(const llama_mmap &) = delete;
38
+ llama_mmap(struct llama_file * file, size_t prefetch = (size_t) -1, bool numa = false);
39
+ ~llama_mmap();
40
+
41
+ size_t size() const;
42
+ void * addr() const;
43
+
44
+ void unmap_fragment(size_t first, size_t last);
45
+
46
+ static const bool SUPPORTED;
47
+
48
+ private:
49
+ struct impl;
50
+ std::unique_ptr<impl> pimpl;
51
+ };
52
+
53
+ struct llama_mlock {
54
+ llama_mlock();
55
+ ~llama_mlock();
56
+
57
+ void init(void * ptr);
58
+ void grow_to(size_t target_size);
59
+
60
+ static const bool SUPPORTED;
61
+
62
+ private:
63
+ struct impl;
64
+ std::unique_ptr<impl> pimpl;
65
+ };
66
+
67
+ size_t llama_path_max();
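
Editor's note: a sketch of how a caller could use the pimpl-based API declared above, assuming it is compiled in the same target as llama-mmap.cpp; the model path is hypothetical and prefetch/NUMA defaults are left at their declared values:

// map a file and (best effort) pin the mapping in RAM
#include "llama-mmap.h"

#include <cstdio>
#include <exception>

int main() {
    const char * path = "model.gguf"; // hypothetical path, for illustration only
    try {
        llama_file file(path, "rb");
        llama_mmap mapping(&file); // default: prefetch the whole file, no NUMA hint

        llama_mlock lock;
        if (llama_mlock::SUPPORTED) {
            lock.init(mapping.addr());
            lock.grow_to(mapping.size()); // best effort - logs a warning on failure
        }

        std::printf("mapped %zu bytes at %p\n", mapping.size(), mapping.addr());
    } catch (const std::exception & e) {
        std::fprintf(stderr, "error: %s\n", e.what());
        return 1;
    }
    return 0;
}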
examples/talk-llama/llama-model-loader.cpp ADDED
@@ -0,0 +1,1010 @@
1
+ #include "llama-model-loader.h"
2
+
3
+ #include "ggml.h"
4
+
5
+ #include <array>
6
+ #include <cinttypes>
7
+ #include <cstring>
8
+ #include <future>
9
+
10
+ const char * llama_file_version_name(llama_fver version) {
11
+ switch (version) {
12
+ case GGUF_FILE_VERSION_V1: return "GGUF V1 (support until nov 2023)";
13
+ case GGUF_FILE_VERSION_V2: return "GGUF V2";
14
+ case GGUF_FILE_VERSION_V3: return "GGUF V3 (latest)";
15
+ }
16
+
17
+ return "unknown";
18
+ }
19
+
20
+ namespace GGUFMeta {
21
+ template <typename T, gguf_type gt_, T (*gfun)(const gguf_context *, const int)>
22
+ struct GKV_Base_Type {
23
+ static constexpr gguf_type gt = gt_;
24
+
25
+ static T getter(const gguf_context * ctx, const int kid) {
26
+ return gfun(ctx, kid);
27
+ }
28
+ };
29
+
30
+ template<typename T> struct GKV_Base;
31
+
32
+ template<> struct GKV_Base<bool >: GKV_Base_Type<bool, GGUF_TYPE_BOOL, gguf_get_val_bool> {};
33
+ template<> struct GKV_Base<uint8_t >: GKV_Base_Type<uint8_t, GGUF_TYPE_UINT8, gguf_get_val_u8 > {};
34
+ template<> struct GKV_Base<uint16_t >: GKV_Base_Type<uint16_t, GGUF_TYPE_UINT16, gguf_get_val_u16 > {};
35
+ template<> struct GKV_Base<uint32_t >: GKV_Base_Type<uint32_t, GGUF_TYPE_UINT32, gguf_get_val_u32 > {};
36
+ template<> struct GKV_Base<uint64_t >: GKV_Base_Type<uint64_t, GGUF_TYPE_UINT64, gguf_get_val_u64 > {};
37
+ template<> struct GKV_Base<int8_t >: GKV_Base_Type<int8_t, GGUF_TYPE_INT8, gguf_get_val_i8 > {};
38
+ template<> struct GKV_Base<int16_t >: GKV_Base_Type<int16_t, GGUF_TYPE_INT16, gguf_get_val_i16 > {};
39
+ template<> struct GKV_Base<int32_t >: GKV_Base_Type<int32_t, GGUF_TYPE_INT32, gguf_get_val_i32 > {};
40
+ template<> struct GKV_Base<int64_t >: GKV_Base_Type<int64_t, GGUF_TYPE_INT64, gguf_get_val_i64 > {};
41
+ template<> struct GKV_Base<float >: GKV_Base_Type<float, GGUF_TYPE_FLOAT32, gguf_get_val_f32 > {};
42
+ template<> struct GKV_Base<double >: GKV_Base_Type<double, GGUF_TYPE_FLOAT64, gguf_get_val_f64 > {};
43
+ template<> struct GKV_Base<const char *>: GKV_Base_Type<const char *, GGUF_TYPE_STRING, gguf_get_val_str > {};
44
+
45
+ template<> struct GKV_Base<std::string> {
46
+ static constexpr gguf_type gt = GGUF_TYPE_STRING;
47
+
48
+ static std::string getter(const gguf_context * ctx, const int kid) {
49
+ return gguf_get_val_str(ctx, kid);
50
+ }
51
+ };
52
+
53
+ struct ArrayInfo {
54
+ const gguf_type gt;
55
+ const size_t length;
56
+ const void * data;
57
+ };
58
+
59
+ template<> struct GKV_Base<ArrayInfo> {
60
+ public:
61
+ static constexpr gguf_type gt = GGUF_TYPE_ARRAY;
62
+ static ArrayInfo getter(const gguf_context *ctx, const int k) {
63
+ return ArrayInfo {
64
+ gguf_get_arr_type(ctx, k),
65
+ size_t(gguf_get_arr_n(ctx, k)),
66
+ gguf_get_arr_data(ctx, k),
67
+ };
68
+ }
69
+ };
70
+
71
+ template<typename T>
72
+ class GKV : public GKV_Base<T> {
73
+ GKV() = delete;
74
+
75
+ public:
76
+ static T get_kv(const gguf_context * ctx, const int k) {
77
+ const enum gguf_type kt = gguf_get_kv_type(ctx, k);
78
+
79
+ if (kt != GKV::gt) {
80
+ throw std::runtime_error(format("key %s has wrong type %s but expected type %s",
81
+ gguf_get_key(ctx, k), gguf_type_name(kt), gguf_type_name(GKV::gt)));
82
+ }
83
+ return GKV::getter(ctx, k);
84
+ }
85
+
86
+ static const char * override_type_to_str(const llama_model_kv_override_type ty) {
87
+ switch (ty) {
88
+ case LLAMA_KV_OVERRIDE_TYPE_BOOL: return "bool";
89
+ case LLAMA_KV_OVERRIDE_TYPE_INT: return "int";
90
+ case LLAMA_KV_OVERRIDE_TYPE_FLOAT: return "float";
91
+ case LLAMA_KV_OVERRIDE_TYPE_STR: return "str";
92
+ }
93
+ return "unknown";
94
+ }
95
+
96
+ static bool validate_override(const llama_model_kv_override_type expected_type, const struct llama_model_kv_override * ovrd) {
97
+ if (!ovrd) { return false; }
98
+ if (ovrd->tag == expected_type) {
99
+ LLAMA_LOG_INFO("%s: Using metadata override (%5s) '%s' = ",
100
+ __func__, override_type_to_str(ovrd->tag), ovrd->key);
101
+ switch (ovrd->tag) {
102
+ case LLAMA_KV_OVERRIDE_TYPE_BOOL: {
103
+ LLAMA_LOG_INFO("%s\n", ovrd->val_bool ? "true" : "false");
104
+ } break;
105
+ case LLAMA_KV_OVERRIDE_TYPE_INT: {
106
+ LLAMA_LOG_INFO("%" PRId64 "\n", ovrd->val_i64);
107
+ } break;
108
+ case LLAMA_KV_OVERRIDE_TYPE_FLOAT: {
109
+ LLAMA_LOG_INFO("%.6f\n", ovrd->val_f64);
110
+ } break;
111
+ case LLAMA_KV_OVERRIDE_TYPE_STR: {
112
+ LLAMA_LOG_INFO("%s\n", ovrd->val_str);
113
+ } break;
114
+ default:
115
+ // Shouldn't be possible to end up here, but just in case...
116
+ throw std::runtime_error(
117
+ format("Unsupported attempt to override %s type for metadata key %s\n",
118
+ override_type_to_str(ovrd->tag), ovrd->key));
119
+ }
120
+ return true;
121
+ }
122
+ LLAMA_LOG_WARN("%s: Warning: Bad metadata override type for key '%s', expected %s but got %s\n",
123
+ __func__, ovrd->key, override_type_to_str(expected_type), override_type_to_str(ovrd->tag));
124
+ return false;
125
+ }
126
+
127
+ template<typename OT>
128
+ static typename std::enable_if<std::is_same<OT, bool>::value, bool>::type
129
+ try_override(OT & target, const struct llama_model_kv_override * ovrd) {
130
+ if (validate_override(LLAMA_KV_OVERRIDE_TYPE_BOOL, ovrd)) {
131
+ target = ovrd->val_bool;
132
+ return true;
133
+ }
134
+ return false;
135
+ }
136
+
137
+ template<typename OT>
138
+ static typename std::enable_if<!std::is_same<OT, bool>::value && std::is_integral<OT>::value, bool>::type
139
+ try_override(OT & target, const struct llama_model_kv_override * ovrd) {
140
+ if (validate_override(LLAMA_KV_OVERRIDE_TYPE_INT, ovrd)) {
141
+ target = ovrd->val_i64;
142
+ return true;
143
+ }
144
+ return false;
145
+ }
146
+
147
+ template<typename OT>
148
+ static typename std::enable_if<std::is_floating_point<OT>::value, bool>::type
149
+ try_override(T & target, const struct llama_model_kv_override * ovrd) {
150
+ if (validate_override(LLAMA_KV_OVERRIDE_TYPE_FLOAT, ovrd)) {
151
+ target = ovrd->val_f64;
152
+ return true;
153
+ }
154
+ return false;
155
+ }
156
+
157
+ template<typename OT>
158
+ static typename std::enable_if<std::is_same<OT, std::string>::value, bool>::type
159
+ try_override(T & target, const struct llama_model_kv_override * ovrd) {
160
+ if (validate_override(LLAMA_KV_OVERRIDE_TYPE_STR, ovrd)) {
161
+ target = ovrd->val_str;
162
+ return true;
163
+ }
164
+ return false;
165
+ }
166
+
167
+ static bool set(const gguf_context * ctx, const int k, T & target, const struct llama_model_kv_override * ovrd = nullptr) {
168
+ if (try_override<T>(target, ovrd)) {
169
+ return true;
170
+ }
171
+ if (k < 0) { return false; }
172
+ target = get_kv(ctx, k);
173
+ return true;
174
+ }
175
+
176
+ static bool set(const gguf_context * ctx, const char * key, T & target, const struct llama_model_kv_override * ovrd = nullptr) {
177
+ return set(ctx, gguf_find_key(ctx, key), target, ovrd);
178
+ }
179
+
180
+ static bool set(const gguf_context * ctx, const std::string & key, T & target, const struct llama_model_kv_override * ovrd = nullptr) {
181
+ return set(ctx, key.c_str(), target, ovrd);
182
+ }
183
+ };
184
+ }
185
+
186
+ template<typename T>
187
+ typename std::enable_if<std::is_integral<T>::value, bool>::type
188
+ llama_model_loader::get_arr_n(const std::string & key, T & result, bool required) {
189
+ const int kid = gguf_find_key(meta.get(), key.c_str());
190
+
191
+ if (kid < 0) {
192
+ if (required) {
193
+ throw std::runtime_error(format("key not found in model: %s", key.c_str()));
194
+ }
195
+ return false;
196
+ }
197
+
198
+ struct GGUFMeta::ArrayInfo arr_info =
199
+ GGUFMeta::GKV<GGUFMeta::ArrayInfo>::get_kv(meta.get(), kid);
200
+
201
+
202
+ result = arr_info.length;
203
+ return true;
204
+ }
205
+
206
+ template<typename T>
207
+ typename std::enable_if<std::is_integral<T>::value, bool>::type
208
+ llama_model_loader::get_arr_n(enum llm_kv kid, T & result, bool required) {
209
+ return get_arr_n(llm_kv(kid), result, required);
210
+ }
211
+
212
+ template bool llama_model_loader::get_arr_n(enum llm_kv kid, uint32_t & result, bool required);
213
+
214
+ template<typename T>
215
+ bool llama_model_loader::get_arr(const std::string & key, std::vector<T> & result, bool required) {
216
+ const int kid = gguf_find_key(meta.get(), key.c_str());
217
+
218
+ if (kid < 0 || gguf_get_kv_type(meta.get(), kid) != GGUF_TYPE_ARRAY) {
219
+ if (required) {
220
+ throw std::runtime_error(format("array key not found in model: %s", key.c_str()));
221
+ }
222
+ return false;
223
+ }
224
+
225
+ struct GGUFMeta::ArrayInfo arr_info =
226
+ GGUFMeta::GKV<GGUFMeta::ArrayInfo>::get_kv(meta.get(), kid);
227
+
228
+ switch (arr_info.gt) {
229
+ case GGUF_TYPE_FLOAT32: GGML_ASSERT((std::is_same<T, float>::value)); break;
230
+ case GGUF_TYPE_INT32: GGML_ASSERT(
231
+ (std::is_same<T, int32_t>::value) ||
232
+ (std::is_same<T, uint32_t>::value)); break;
233
+ default:
234
+ throw std::runtime_error(format("%s is not a float32, int32 array", key.c_str()));
235
+ }
236
+
237
+ result.resize(arr_info.length);
238
+ result.assign((const T*)arr_info.data, (const T *)arr_info.data + arr_info.length);
239
+
240
+ return true;
241
+ }
242
+
243
+ template<typename T, size_t N_MAX>
244
+ bool llama_model_loader::get_arr(const std::string & key, std::array<T, N_MAX> & result, bool required) {
245
+ const int kid = gguf_find_key(meta.get(), key.c_str());
246
+
247
+ if (kid < 0 || gguf_get_kv_type(meta.get(), kid) != GGUF_TYPE_ARRAY) {
248
+ if (required) {
249
+ throw std::runtime_error(format("array key not found in model: %s", key.c_str()));
250
+ }
251
+ return false;
252
+ }
253
+
254
+ struct GGUFMeta::ArrayInfo arr_info =
255
+ GGUFMeta::GKV<GGUFMeta::ArrayInfo>::get_kv(meta.get(), kid);
256
+
257
+ switch (arr_info.gt) {
258
+ case GGUF_TYPE_FLOAT32: GGML_ASSERT((std::is_same<T, float>::value)); break;
259
+ case GGUF_TYPE_INT32: GGML_ASSERT(
260
+ (std::is_same<T, int32_t>::value) ||
261
+ (std::is_same<T, uint32_t>::value)); break;
262
+ default:
263
+ throw std::runtime_error(format("%s is not a float32, int32 array", key.c_str()));
264
+ }
265
+
266
+ if (arr_info.length > N_MAX) {
267
+ throw std::runtime_error(format("array length %u for key %s exceeds max %u", (uint32_t) arr_info.length, key.c_str(), (uint32_t) N_MAX));
268
+ }
269
+
270
+ std::copy((const T*)arr_info.data, (const T *)arr_info.data + arr_info.length, result.begin());
271
+
272
+ return true;
273
+ }
274
+
275
+ template<typename T>
276
+ bool llama_model_loader::get_arr(enum llm_kv kid, T & result, bool required) {
277
+ return get_arr(llm_kv(kid), result, required);
278
+ }
279
+
280
+ template<typename T>
281
+ bool llama_model_loader::get_key(const std::string & key, T & result, bool required) {
282
+ auto it = kv_overrides.find(key);
283
+
284
+ const struct llama_model_kv_override * override =
285
+ it != kv_overrides.end() ? &it->second : nullptr;
286
+
287
+ const bool found = GGUFMeta::GKV<T>::set(meta.get(), key, result, override);
288
+
289
+ if (required && !found) {
290
+ throw std::runtime_error(format("key not found in model: %s", key.c_str()));
291
+ }
292
+
293
+ return found;
294
+ }
295
+
296
+ template<typename T>
297
+ bool llama_model_loader::get_key(enum llm_kv kid, T & result, bool required) {
298
+ return get_key(llm_kv(kid), result, required);
299
+ }
300
+
301
+ template bool llama_model_loader::get_key<bool> (enum llm_kv kid, bool & result, bool required);
302
+ template bool llama_model_loader::get_key<float> (enum llm_kv kid, float & result, bool required);
303
+ template bool llama_model_loader::get_key<uint32_t> (enum llm_kv kid, uint32_t & result, bool required);
304
+ template bool llama_model_loader::get_key<std::string>(enum llm_kv kid, std::string & result, bool required);
305
+
306
+ template<>
307
+ bool llama_model_loader::get_key(enum llm_kv kid, enum llama_pooling_type & result, bool required) {
308
+ uint32_t tmp;
309
+ const bool found = get_key(kid, tmp, required);
310
+ if (found) {
311
+ result = (enum llama_pooling_type) tmp;
312
+ } else {
313
+ result = LLAMA_POOLING_TYPE_UNSPECIFIED;
314
+ }
315
+ return found;
316
+ }
317
+
318
+ // get array of n <= N_MAX elements, or a single element repeated n times
319
+ template<typename T, size_t N_MAX>
320
+ bool llama_model_loader::get_key_or_arr(const std::string & key, std::array<T, N_MAX> & result, uint32_t n, bool required) {
321
+ const int kid = gguf_find_key(meta.get(), key.c_str());
322
+
323
+ if (kid < 0) {
324
+ if (required) {
325
+ throw std::runtime_error(format("key not found in model: %s", key.c_str()));
326
+ }
327
+ return false;
328
+ }
329
+
330
+ if (n > N_MAX) {
331
+ throw std::runtime_error(format("n > N_MAX: %u > %u for key %s", (uint32_t) n, (uint32_t) N_MAX, key.c_str()));
332
+ }
333
+
334
+ if (gguf_get_kv_type(meta.get(), kid) == GGUF_TYPE_ARRAY) {
335
+ struct GGUFMeta::ArrayInfo arr_info =
336
+ GGUFMeta::GKV<GGUFMeta::ArrayInfo>::get_kv(meta.get(), kid);
337
+
338
+ if (n != arr_info.length) {
339
+ throw std::runtime_error(format("key %s has wrong array length; expected %u, got %u", key.c_str(), n, (uint32_t) arr_info.length));
340
+ }
341
+
342
+ return get_arr(key, result, required);
343
+ }
344
+
345
+ T value;
346
+
347
+ bool ok = get_key(key, value, required);
348
+ if (!ok) {
349
+ return false;
350
+ }
351
+
352
+ for (uint32_t i = 0; i < n; i++) {
353
+ result[i] = value;
354
+ }
355
+
356
+ return true;
357
+ }
358
+
359
+ template<typename T>
360
+ bool llama_model_loader::get_key_or_arr(enum llm_kv kid, T & result, uint32_t n, bool required) {
361
+ return get_key_or_arr(llm_kv(kid), result, n, required);
362
+ }
363
+
364
+ // TODO: this is not very clever - figure out something better
365
+ template bool llama_model_loader::get_key_or_arr<std::array<int, 4>>(enum llm_kv kid, std::array<int, 4> & result, uint32_t n, bool required);
366
+ template bool llama_model_loader::get_key_or_arr<std::array<uint32_t, 512>>(enum llm_kv kid, std::array<uint32_t, 512> & result, uint32_t n, bool required);
367
+
368
+ llama_model_loader::llama_model_loader(const std::string & fname, bool use_mmap, bool check_tensors, const struct llama_model_kv_override * param_overrides_p) {
369
+ int trace = 0;
370
+ if (getenv("LLAMA_TRACE")) {
371
+ trace = atoi(getenv("LLAMA_TRACE"));
372
+ }
373
+
374
+ if (param_overrides_p != nullptr) {
375
+ for (const struct llama_model_kv_override * p = param_overrides_p; p->key[0] != 0; p++) {
376
+ kv_overrides.insert({std::string(p->key), *p});
377
+ }
378
+ }
379
+
380
+ struct ggml_context * ctx = NULL;
381
+ struct gguf_init_params params = {
382
+ /*.no_alloc = */ true,
383
+ /*.ctx = */ &ctx,
384
+ };
385
+
386
+ meta.reset(gguf_init_from_file(fname.c_str(), params));
387
+ if (!meta) {
388
+ throw std::runtime_error(format("%s: failed to load model from %s\n", __func__, fname.c_str()));
389
+ }
390
+
391
+ get_key(llm_kv(LLM_KV_GENERAL_ARCHITECTURE), arch_name, false);
392
+ llm_kv = LLM_KV(llm_arch_from_string(arch_name));
393
+
394
+ files.emplace_back(new llama_file(fname.c_str(), "rb"));
395
+ contexts.emplace_back(ctx);
396
+
397
+ // Save tensors data offset of the main file.
398
+ // For subsidiary files, `meta` tensor data offset must not be used,
399
+ // so we build a unified tensors index for weights.
400
+ for (ggml_tensor * cur = ggml_get_first_tensor(ctx); cur; cur = ggml_get_next_tensor(ctx, cur)) {
401
+ std::string tensor_name = std::string(cur->name);
402
+ // make sure there are no duplicated tensor names
403
+ if (weights_map.find(tensor_name) != weights_map.end()) {
404
+ throw std::runtime_error(format("invalid model: tensor '%s' is duplicated", ggml_get_name(cur)));
405
+ }
406
+ n_elements += ggml_nelements(cur);
407
+ n_bytes += ggml_nbytes(cur);
408
+ weights_map.emplace(tensor_name, llama_tensor_weight(files.back().get(), 0, meta.get(), cur));
409
+ }
410
+ uint16_t n_split = 0;
411
+ get_key(llm_kv(LLM_KV_SPLIT_COUNT), n_split, false);
412
+
413
+ // Load additional GGML contexts
414
+ if (n_split > 1) {
415
+ uint16_t idx = 0;
416
+ get_key(llm_kv(LLM_KV_SPLIT_NO), idx);
417
+ if (idx != 0) {
418
+ throw std::runtime_error(format("illegal split file: %d, model must be loaded with the first split", idx));
419
+ }
420
+
421
+ std::vector<char> split_prefix(llama_path_max(), 0);
422
+ if (!llama_split_prefix(split_prefix.data(), split_prefix.size(), fname.c_str(), idx, n_split)) {
423
+ throw std::runtime_error(format("invalid split file: %s", fname.c_str()));
424
+ }
425
+
426
+ if (trace > 0) {
427
+ LLAMA_LOG_INFO("%s: loading additional %d GGUFs\n", __func__, n_split);
428
+ }
429
+
430
+ std::vector<char> split_path(llama_path_max(), 0);
431
+ for (idx = 1; idx < n_split; idx++) {
432
+ llama_split_path(split_path.data(), split_path.size(), split_prefix.data(), idx, n_split);
433
+
434
+ struct gguf_init_params split_params = {
435
+ /*.no_alloc = */ true,
436
+ /*.ctx = */ &ctx,
437
+ };
438
+ gguf_context_ptr ctx_gguf { gguf_init_from_file(split_path.data(), split_params) };
439
+ if (!ctx_gguf) {
440
+ throw std::runtime_error(format("%s: failed to load GGUF split from %s\n", __func__, split_path.data()));
441
+ }
442
+
443
+ files.emplace_back(new llama_file(split_path.data(), "rb"));
444
+ contexts.emplace_back(ctx);
445
+
446
+ // Save tensors data offset info of the shard.
447
+ for (ggml_tensor * cur = ggml_get_first_tensor(ctx); cur; cur = ggml_get_next_tensor(ctx, cur)) {
448
+ std::string tensor_name = std::string(cur->name);
449
+ // make sure there are no duplicated tensor names
450
+ if (weights_map.find(tensor_name) != weights_map.end()) {
451
+ throw std::runtime_error(format("invalid model: tensor '%s' is duplicated", ggml_get_name(cur)));
452
+ }
453
+ n_elements += ggml_nelements(cur);
454
+ n_bytes += ggml_nbytes(cur);
455
+ weights_map.emplace(tensor_name, llama_tensor_weight(files.back().get(), idx, ctx_gguf.get(), cur));
456
+ }
457
+ }
458
+
459
+ get_key(llm_kv(LLM_KV_SPLIT_TENSORS_COUNT), n_tensors);
460
+
461
+ // sanity check
462
+ {
463
+ const int n_tensors_loaded = (int) weights_map.size();
464
+ if (n_tensors != n_tensors_loaded) {
465
+ throw std::runtime_error(format("corrupted model: %d tensors expected but %d found", n_tensors, n_tensors_loaded));
466
+ }
467
+ }
468
+
469
+ LLAMA_LOG_INFO("%s: additional %d GGUFs metadata loaded.\n", __func__, n_split - 1);
470
+ }
471
+
472
+ n_kv = gguf_get_n_kv(meta.get());
473
+ n_tensors = weights_map.size();
474
+
475
+ fver = (enum llama_fver) gguf_get_version(meta.get());
476
+
477
+ LLAMA_LOG_INFO("%s: loaded meta data with %d key-value pairs and %d tensors from %s (version %s)\n",
478
+ __func__, n_kv, n_tensors, fname.c_str(), llama_file_version_name(fver));
479
+
480
+ // determine file type based on the number of tensors for each quantization and print meta data
481
+ // TODO: make optional
482
+ {
483
+ std::map<enum ggml_type, uint32_t> n_type;
484
+
485
+ uint32_t n_type_max = 0;
486
+ enum ggml_type type_max = GGML_TYPE_F32;
487
+
488
+ for (const auto & it : weights_map) {
489
+ const llama_tensor_weight & w = it.second;
490
+ const ggml_tensor * tensor = w.tensor;
491
+
492
+ enum ggml_type type = tensor->type;
493
+
494
+ n_type[type]++;
495
+
496
+ if (n_type_max < n_type[type]) {
497
+ n_type_max = n_type[type];
498
+ type_max = type;
499
+ }
500
+
501
+ if (trace > 0) {
502
+ const uint16_t sid = w.idx;
503
+ LLAMA_LOG_INFO("%s: - tensor split %2d: %32s %-8s [ %s ]\n", __func__, sid, ggml_get_name(tensor), ggml_type_name(type), llama_format_tensor_shape(tensor).c_str());
504
+ }
505
+ }
506
+
507
+ switch (type_max) {
508
+ case GGML_TYPE_F32: ftype = LLAMA_FTYPE_ALL_F32; break;
509
+ case GGML_TYPE_F16: ftype = LLAMA_FTYPE_MOSTLY_F16; break;
510
+ case GGML_TYPE_BF16: ftype = LLAMA_FTYPE_MOSTLY_BF16; break;
511
+ case GGML_TYPE_Q4_0: ftype = LLAMA_FTYPE_MOSTLY_Q4_0; break;
512
+ case GGML_TYPE_Q4_1: ftype = LLAMA_FTYPE_MOSTLY_Q4_1; break;
513
+ case GGML_TYPE_Q5_0: ftype = LLAMA_FTYPE_MOSTLY_Q5_0; break;
514
+ case GGML_TYPE_Q5_1: ftype = LLAMA_FTYPE_MOSTLY_Q5_1; break;
515
+ case GGML_TYPE_Q8_0: ftype = LLAMA_FTYPE_MOSTLY_Q8_0; break;
516
+ case GGML_TYPE_Q2_K: ftype = LLAMA_FTYPE_MOSTLY_Q2_K; break;
517
+ case GGML_TYPE_Q3_K: ftype = LLAMA_FTYPE_MOSTLY_Q3_K_M; break;
518
+ case GGML_TYPE_Q4_K: ftype = LLAMA_FTYPE_MOSTLY_Q4_K_M; break;
519
+ case GGML_TYPE_Q5_K: ftype = LLAMA_FTYPE_MOSTLY_Q5_K_M; break;
520
+ case GGML_TYPE_Q6_K: ftype = LLAMA_FTYPE_MOSTLY_Q6_K; break;
521
+ case GGML_TYPE_TQ1_0: ftype = LLAMA_FTYPE_MOSTLY_TQ1_0; break;
522
+ case GGML_TYPE_TQ2_0: ftype = LLAMA_FTYPE_MOSTLY_TQ2_0; break;
523
+ case GGML_TYPE_IQ2_XXS: ftype = LLAMA_FTYPE_MOSTLY_IQ2_XXS; break;
524
+ case GGML_TYPE_IQ2_XS: ftype = LLAMA_FTYPE_MOSTLY_IQ2_XS; break;
525
+ case GGML_TYPE_IQ2_S: ftype = LLAMA_FTYPE_MOSTLY_IQ2_S; break;
526
+ case GGML_TYPE_IQ3_XXS: ftype = LLAMA_FTYPE_MOSTLY_IQ3_XXS; break;
527
+ case GGML_TYPE_IQ1_S: ftype = LLAMA_FTYPE_MOSTLY_IQ1_S; break;
528
+ case GGML_TYPE_IQ1_M: ftype = LLAMA_FTYPE_MOSTLY_IQ1_M; break;
529
+ case GGML_TYPE_IQ4_NL: ftype = LLAMA_FTYPE_MOSTLY_IQ4_NL; break;
530
+ case GGML_TYPE_IQ4_XS: ftype = LLAMA_FTYPE_MOSTLY_IQ4_XS; break;
531
+ case GGML_TYPE_IQ3_S: ftype = LLAMA_FTYPE_MOSTLY_IQ3_S; break;
532
+ default:
533
+ {
534
+ LLAMA_LOG_WARN("%s: unknown type %s\n", __func__, ggml_type_name(type_max));
535
+ ftype = LLAMA_FTYPE_ALL_F32;
536
+ } break;
537
+ }
538
+
539
+ // this is a way to mark that we have "guessed" the file type
540
+ ftype = (llama_ftype) (ftype | LLAMA_FTYPE_GUESSED);
541
+
542
+ {
543
+ const int kid = gguf_find_key(meta.get(), "general.file_type"); // TODO: use LLM_KV
544
+ if (kid >= 0) {
545
+ ftype = (llama_ftype) gguf_get_val_u32(meta.get(), kid);
546
+ }
547
+ }
548
+
549
+ LLAMA_LOG_INFO("%s: Dumping metadata keys/values. Note: KV overrides do not apply in this output.\n", __func__);
550
+
551
+ for (int i = 0; i < n_kv; i++) {
552
+ const char * name = gguf_get_key(meta.get(), i);
553
+ const enum gguf_type type = gguf_get_kv_type(meta.get(), i);
554
+ const std::string type_name =
555
+ type == GGUF_TYPE_ARRAY
556
+ ? format("%s[%s,%d]", gguf_type_name(type), gguf_type_name(gguf_get_arr_type(meta.get(), i)), gguf_get_arr_n(meta.get(), i))
557
+ : gguf_type_name(type);
558
+
559
+ std::string value = gguf_kv_to_str(meta.get(), i);
560
+ const size_t MAX_VALUE_LEN = 40;
561
+ if (value.size() > MAX_VALUE_LEN) {
562
+ value = format("%s...", value.substr(0, MAX_VALUE_LEN - 3).c_str());
563
+ }
564
+ replace_all(value, "\n", "\\n");
565
+
566
+ LLAMA_LOG_INFO("%s: - kv %3d: %42s %-16s = %s\n", __func__, i, name, type_name.c_str(), value.c_str());
567
+ }
568
+
569
+ // print type counts
570
+ for (auto & kv : n_type) {
571
+ if (kv.second == 0) {
572
+ continue;
573
+ }
574
+
575
+ LLAMA_LOG_INFO("%s: - type %4s: %4d tensors\n", __func__, ggml_type_name(kv.first), kv.second);
576
+ }
577
+ }
578
+
579
+ if (!llama_mmap::SUPPORTED) {
580
+ LLAMA_LOG_WARN("%s: mmap is not supported on this platform\n", __func__);
581
+ use_mmap = false;
582
+ }
583
+
584
+ this->use_mmap = use_mmap;
585
+ this->check_tensors = check_tensors;
586
+ }
587
+
588
+ std::string llama_model_loader::get_arch_name() const {
589
+ return arch_name;
590
+ }
591
+
592
+ enum llm_arch llama_model_loader::get_arch() const {
593
+ return llm_kv.arch;
594
+ }
595
+
596
+ const llama_model_loader::llama_tensor_weight * llama_model_loader::get_weight(const char * name) const {
597
+ auto pos = weights_map.find(name);
598
+ if (pos != weights_map.end()) {
599
+ return &pos->second;
600
+ }
601
+
602
+ return nullptr;
603
+ }
604
+
605
+ const llama_model_loader::llama_tensor_weight & llama_model_loader::require_weight(const char * name) const {
606
+ const llama_tensor_weight * weight = get_weight(name);
607
+ if (!weight) {
608
+ throw std::runtime_error(format("%s: tensor '%s' not found", __func__, name));
609
+ }
610
+ return *weight;
611
+ }
612
+
613
+ struct ggml_tensor * llama_model_loader::get_tensor_meta(const char * name) const {
614
+ const auto * weight = get_weight(name);
615
+ if (!weight) {
616
+ return nullptr;
617
+ }
618
+ return weight->tensor;
619
+ }
620
+
621
+ struct ggml_tensor * llama_model_loader::require_tensor_meta(const std::string & name) const {
622
+ struct ggml_tensor * tensor = get_tensor_meta(name.c_str());
623
+ if (!tensor) {
624
+ throw std::runtime_error(format("%s: tensor '%s' not found", __func__, name.c_str()));
625
+ }
626
+ return tensor;
627
+ }
628
+
629
+ const struct ggml_tensor * llama_model_loader::check_tensor_dims(const std::string & name, const std::vector<int64_t> & ne, bool required) const {
630
+ const struct ggml_tensor * cur = get_tensor_meta(name.c_str());
631
+
632
+ if (cur == NULL) {
633
+ if (!required) {
634
+ return NULL;
635
+ }
636
+ throw std::runtime_error(format("%s: tensor '%s' not found", __func__, name.c_str()));
637
+ }
638
+
639
+ {
640
+ bool is_ok = true;
641
+ for (size_t i = 0; i < GGML_MAX_DIMS; ++i) {
642
+ if ((i < ne.size() && ne[i] != cur->ne[i]) || (i >= ne.size() && cur->ne[i] != 1)) {
643
+ is_ok = false;
644
+ break;
645
+ }
646
+ }
647
+ if (!is_ok) {
648
+ throw std::runtime_error(
649
+ format("%s: tensor '%s' has wrong shape; expected %s, got %s",
650
+ __func__, name.c_str(),
651
+ llama_format_tensor_shape(ne).c_str(),
652
+ llama_format_tensor_shape(cur).c_str()));
653
+ }
654
+ }
655
+
656
+ return cur;
657
+ }
658
+
659
+ struct ggml_tensor * llama_model_loader::create_tensor(struct ggml_context * ctx, const std::string & name, const std::initializer_list<int64_t> & ne, int flags) {
660
+ const struct ggml_tensor * cur = check_tensor_dims(name, ne, !(flags & TENSOR_NOT_REQUIRED));
661
+
662
+ if (cur == NULL) {
663
+ return NULL;
664
+ }
665
+
666
+ bool duplicated = flags & TENSOR_DUPLICATED;
667
+
668
+ struct ggml_tensor * tensor = ggml_dup_tensor(ctx, cur);
669
+ ggml_set_name(tensor, ggml_get_name(cur));
670
+
671
+ if (duplicated) {
672
+ size_data += ggml_nbytes(cur);
673
+ } else {
674
+ n_created++;
675
+ }
676
+
677
+ return tensor;
678
+
679
+ }
680
+
681
+ struct ggml_tensor * llama_model_loader::create_tensor_as_view(struct ggml_context * ctx, struct ggml_tensor * base, const std::string & name, const std::initializer_list<int64_t> & ne, size_t offset, bool required) {
682
+ const struct ggml_tensor * cur = check_tensor_dims(name, ne, required);
683
+
684
+ if (cur == NULL) {
685
+ return NULL;
686
+ }
687
+
688
+ if (cur->type != base->type) {
689
+ throw std::runtime_error(format("%s: tensor '%s' has wrong type; expected %s, got %s", __func__, name.c_str(), ggml_type_name(base->type), ggml_type_name(cur->type)));
690
+ }
691
+
692
+ std::array<int64_t, GGML_MAX_DIMS> dims;
693
+ for (size_t i = 0; i < GGML_MAX_DIMS; ++i) {
694
+ dims[i] = i < ne.size() ? ne.begin()[i] : 1;
695
+ }
696
+
697
+ struct ggml_tensor * tensor = ggml_view_4d(ctx, base,
698
+ dims[0], dims[1], dims[2], dims[3],
699
+ cur->nb[1], cur->nb[2], cur->nb[3],
700
+ offset);
701
+
702
+ ggml_set_name(tensor, name.c_str());
703
+
704
+ n_created++;
705
+
706
+ return tensor;
707
+ }
708
+
709
+ void llama_model_loader::done_getting_tensors() const {
710
+ if (n_created != n_tensors) {
711
+ throw std::runtime_error(format("%s: wrong number of tensors; expected %d, got %d", __func__, n_tensors, n_created));
712
+ }
713
+ }
714
+
715
+ void llama_model_loader::init_mappings(bool prefetch, llama_mlocks * mlock_mmaps) {
716
+ if (use_mmap) {
717
+ mappings.reserve(files.size());
718
+ mmaps_used.reserve(files.size());
719
+ for (const auto & file : files) {
720
+ auto * reg = ggml_backend_dev_backend_reg(ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU));
721
+ auto * is_numa_fn = (decltype(ggml_is_numa) *) ggml_backend_reg_get_proc_address(reg, "ggml_backend_cpu_is_numa");
722
+ std::unique_ptr<llama_mmap> mapping(new llama_mmap(file.get(), prefetch ? -1 : 0, is_numa_fn()));
723
+ mmaps_used.emplace_back(mapping->size(), 0);
724
+ if (mlock_mmaps) {
725
+ std::unique_ptr<llama_mlock> mlock_mmap(new llama_mlock());
726
+ mlock_mmap->init(mapping->addr());
727
+ mlock_mmaps->emplace_back(std::move(mlock_mmap));
728
+ }
729
+ mappings.emplace_back(std::move(mapping));
730
+ }
731
+ }
732
+
733
+ // compute the total size of all tensors for progress reporting
734
+ for (const auto & it : weights_map) {
735
+ size_data += ggml_nbytes(it.second.tensor);
736
+ }
737
+ }
738
+
739
+ void llama_model_loader::get_mapping_range(size_t * first, size_t * last, void ** addr, int idx, ggml_context * ctx) const {
740
+ GGML_ASSERT(!mappings.empty());
741
+ const auto & mapping = mappings.at(idx);
742
+
743
+ *first = mapping->size();
744
+ *last = 0;
745
+ *addr = mapping->addr();
746
+ for (ggml_tensor * tensor = ggml_get_first_tensor(ctx); tensor; tensor = ggml_get_next_tensor(ctx, tensor)) {
747
+ const auto * weight = get_weight(ggml_get_name(tensor));
748
+ if (!weight || weight->idx != idx) {
749
+ continue;
750
+ }
751
+ *first = std::min(*first, weight->offs);
752
+ *last = std::max(*last, weight->offs + ggml_nbytes(tensor));
753
+ }
754
+ }
755
+
756
+ void llama_model_loader::load_data_for(struct ggml_tensor * cur) const {
757
+ const auto & w = require_weight(ggml_get_name(cur));
758
+
759
+ if (use_mmap) {
760
+ const auto & mapping = mappings.at(w.idx);
761
+ if (cur->data == nullptr) {
762
+ cur->data = (uint8_t *)mapping->addr() + w.offs;
763
+ } else {
764
+ memcpy(cur->data, (uint8_t *)mapping->addr() + w.offs, ggml_nbytes(cur));
765
+ }
766
+ } else {
767
+ GGML_ASSERT(cur->data != nullptr);
768
+ GGML_ASSERT(w.idx < files.size());
769
+ const auto & file = files.at(w.idx);
770
+ file->seek(w.offs, SEEK_SET);
771
+ file->read_raw(cur->data, ggml_nbytes(cur));
772
+ }
773
+
774
+ if (check_tensors && !ggml_validate_row_data(cur->type, cur->data, ggml_nbytes(cur))) {
775
+ throw std::runtime_error(format("tensor '%s' has invalid data", ggml_get_name(cur)));
776
+ }
777
+ }
778
+
779
+ bool llama_model_loader::load_all_data(
780
+ struct ggml_context * ctx,
781
+ llama_buf_map & bufs,
782
+ llama_mlocks * lmlocks,
783
+ llama_progress_callback progress_callback,
784
+ void * progress_callback_user_data) {
785
+ GGML_ASSERT(size_data != 0 && "call init_mappings() first");
786
+
787
+ std::vector<no_init<uint8_t>> read_buf;
788
+ std::vector<std::future<std::pair<ggml_tensor *, bool>>> validation_result;
789
+
790
+ // 4 staging buffers for async uploads, each sized 1MB seems to be a good default for single NVMe drives.
791
+ // NVMe raid configurations might require more / larger buffers.
792
+ constexpr size_t n_buffers = 4;
793
+ constexpr size_t buffer_size = 1 * 1024 * 1024; // 1MB
794
+
795
+ std::vector<ggml_backend_buffer_t> host_buffers;
796
+ std::vector<ggml_backend_event_t> events;
797
+ std::vector<void *> host_ptrs;
798
+ size_t buffer_idx = 0; // buffer to use for async loads
799
+ ggml_backend_t upload_backend = [&](const char * func) -> ggml_backend_t {
800
+ if (use_mmap || check_tensors) {
801
+ return nullptr;
802
+ }
803
+ // When not using mmapped I/O, use async uploads from pinned memory to GPU memory.
804
+ // First determine if the backend supports the necessary features for async uploads.
805
+ auto * buf = bufs.count(0) ? bufs.at(0) : nullptr;
806
+ if (!buf) {
807
+ LLAMA_LOG_DEBUG("%s: no buffer found for async uploads\n", func);
808
+ return nullptr;
809
+ }
810
+
811
+ auto * buft = ggml_backend_buffer_get_type(buf);
812
+ auto * dev = ggml_backend_buft_get_device(buft);
813
+ if (!dev) {
814
+ LLAMA_LOG_DEBUG("%s: no device found for buffer type %s for async uploads\n", func,
815
+ ggml_backend_buft_name(buft));
816
+ return nullptr;
817
+ }
818
+
819
+ if (buft != ggml_backend_dev_buffer_type(dev)) {
820
+ LLAMA_LOG_DEBUG("%s: buffer type %s is not the default buffer type for device %s for async uploads\n", func,
821
+ ggml_backend_buft_name(buft), ggml_backend_dev_name(dev));
822
+ return nullptr;
823
+ }
824
+
825
+ ggml_backend_dev_props props;
826
+ ggml_backend_dev_get_props(dev, &props);
827
+ if (!props.caps.async || !props.caps.host_buffer || !props.caps.events) {
828
+ LLAMA_LOG_DEBUG("%s: device %s does not support async, host buffers or events\n", func,
829
+ ggml_backend_dev_name(dev));
830
+ return nullptr;
831
+ }
832
+
833
+ auto * host_buft = ggml_backend_dev_host_buffer_type(dev);
834
+ if (!host_buft) {
835
+ LLAMA_LOG_DEBUG("%s: no host buffer type found for device %s\n", func,
836
+ ggml_backend_dev_name(dev));
837
+ return nullptr;
838
+ }
839
+
840
+ // If the backend is supported, create pinned memory buffers and events for synchronisation.
841
+ for (size_t idx = 0; idx < n_buffers; ++idx) {
842
+ auto * buf = ggml_backend_buft_alloc_buffer(host_buft, buffer_size);
843
+ if (!buf) {
844
+ LLAMA_LOG_DEBUG("%s: failed to allocate host buffer for async uploads for device %s\n", func,
845
+ ggml_backend_dev_name(dev));
846
+ return nullptr;
847
+ }
848
+
849
+ host_buffers.emplace_back(buf);
850
+ host_ptrs.emplace_back(ggml_backend_buffer_get_base(buf));
851
+
852
+ auto * event = ggml_backend_event_new(dev);
853
+ if (!event) {
854
+ LLAMA_LOG_DEBUG("%s: failed to create event for async uploads for device %s\n", func,
855
+ ggml_backend_dev_name(dev));
856
+ return nullptr;
857
+ }
858
+
859
+ events.emplace_back(event);
860
+ }
861
+
862
+ ggml_backend_t backend = ggml_backend_dev_init(dev, nullptr);
863
+ if (!backend) {
864
+ LLAMA_LOG_DEBUG("%s: failed to initialize backend for device %s for async uploads\n", func,
865
+ ggml_backend_dev_name(dev));
866
+ return nullptr;
867
+ }
868
+
869
+ return backend;
870
+ }(__func__);
871
+
872
+ if (upload_backend) {
873
+ LLAMA_LOG_DEBUG("%s: using async uploads for device %s, buffer type %s, backend %s\n", __func__,
874
+ ggml_backend_dev_name(ggml_backend_get_device(upload_backend)),
875
+ ggml_backend_buft_name(ggml_backend_buffer_get_type(bufs.at(0))),
876
+ ggml_backend_name(upload_backend));
877
+ }
878
+
879
+ for (struct ggml_tensor * cur = ggml_get_first_tensor(ctx); cur != NULL; cur = ggml_get_next_tensor(ctx, cur)) {
880
+ const auto * weight = get_weight(ggml_get_name(cur));
881
+ if (weight == nullptr) {
882
+ // this can happen with split experts models
883
+ continue;
884
+ }
885
+
886
+ if (progress_callback) {
887
+ if (!progress_callback((float) size_done / size_data, progress_callback_user_data)) {
888
+ return false;
889
+ }
890
+ }
891
+
892
+ size_t n_size = ggml_nbytes(cur);
893
+
894
+ if (use_mmap) {
895
+ const auto & mapping = mappings.at(weight->idx);
896
+ ggml_backend_buffer_t buf_mmap = nullptr;
897
+ if (bufs.count(weight->idx)) {
898
+ buf_mmap = bufs.at(weight->idx);
899
+ }
900
+ uint8_t * data = (uint8_t *) mapping->addr() + weight->offs;
901
+
902
+ if (check_tensors) {
903
+ validation_result.emplace_back(std::async(std::launch::async, [cur, data, n_size] {
904
+ return std::make_pair(cur, ggml_validate_row_data(cur->type, data, n_size));
905
+ }));
906
+ }
907
+
908
+ GGML_ASSERT(buf_mmap || cur->data); // either we have a buffer to allocate the tensor in, or it is already allocated
909
+ if (buf_mmap && cur->data == nullptr) {
910
+ ggml_backend_tensor_alloc(buf_mmap, cur, data);
911
+ if (lmlocks) {
912
+ const auto & lmlock = lmlocks->at(weight->idx);
913
+ lmlock->grow_to(weight->offs + n_size);
914
+ }
915
+
916
+ auto & mmap_used = mmaps_used[weight->idx];
917
+ mmap_used.first = std::min(mmap_used.first, weight->offs);
918
+ mmap_used.second = std::max(mmap_used.second, weight->offs + n_size);
919
+ } else {
920
+ ggml_backend_tensor_set(cur, data, 0, n_size);
921
+ }
922
+ } else {
923
+ const auto & file = files.at(weight->idx);
924
+ if (ggml_backend_buffer_is_host(cur->buffer)) {
925
+ file->seek(weight->offs, SEEK_SET);
926
+ file->read_raw(cur->data, n_size);
927
+ if (check_tensors) {
928
+ validation_result.emplace_back(std::async(std::launch::async, [cur, n_size] {
929
+ return std::make_pair(cur, ggml_validate_row_data(cur->type, cur->data, n_size));
930
+ }));
931
+ }
932
+ } else {
933
+ // If upload_backend is valid load the tensor in chunks to pinned memory and upload the buffers asynchronously to the GPU.
934
+ if (upload_backend) {
935
+ file->seek(weight->offs, SEEK_SET);
936
+
937
+ size_t bytes_read = 0;
938
+
939
+ while (bytes_read < n_size) {
940
+ size_t read_iteration = std::min<size_t>(buffer_size, n_size - bytes_read);
941
+
942
+ ggml_backend_event_synchronize(events[buffer_idx]);
943
+ file->read_raw(host_ptrs[buffer_idx], read_iteration);
944
+ ggml_backend_tensor_set_async(upload_backend, cur, host_ptrs[buffer_idx], bytes_read, read_iteration);
945
+ ggml_backend_event_record(events[buffer_idx], upload_backend);
946
+
947
+ bytes_read += read_iteration;
948
+ ++buffer_idx;
949
+ buffer_idx %= n_buffers;
950
+ }
951
+ } else {
952
+ read_buf.resize(n_size);
953
+ file->seek(weight->offs, SEEK_SET);
954
+ file->read_raw(read_buf.data(), n_size);
955
+ ggml_backend_tensor_set(cur, read_buf.data(), 0, n_size);
956
+ if (check_tensors && !ggml_validate_row_data(cur->type, read_buf.data(), n_size)) {
957
+ throw std::runtime_error(format("tensor '%s' has invalid data", ggml_get_name(cur)));
958
+ }
959
+ }
960
+ }
961
+ }
962
+
963
+ size_done += n_size;
964
+ }
965
+
966
+ // free temporary resources used for async uploads
967
+ for (auto * event : events) {
968
+ ggml_backend_event_synchronize(event);
969
+ ggml_backend_event_free(event);
970
+ }
971
+ for (auto * buf : host_buffers) {
972
+ ggml_backend_buffer_free(buf);
973
+ }
974
+ ggml_backend_free(upload_backend);
975
+
976
+ // check validation results
977
+ bool validation_failed = false;
978
+ for (auto & future : validation_result) {
979
+ auto result = future.get();
980
+ if (!result.second) {
981
+ LLAMA_LOG_ERROR("%s: tensor '%s' has invalid data\n", __func__, ggml_get_name(result.first));
982
+ validation_failed = true;
983
+ }
984
+ }
985
+ if (validation_failed) {
986
+ throw std::runtime_error("found tensors with invalid data");
987
+ }
988
+
989
+ // check if this is the last call and do final cleanup
990
+ if (size_done >= size_data) {
991
+ // unmap offloaded tensors and metadata
992
+ if (use_mmap) {
993
+ for (uint32_t idx = 0; idx < mappings.size(); idx++) {
994
+ const auto & mmap_used = mmaps_used.at(idx);
995
+ auto & mapping = mappings.at(idx);
996
+ mapping->unmap_fragment(0, mmap_used.first);
997
+ if (mmap_used.second != 0) {
998
+ mapping->unmap_fragment(mmap_used.second, mapping->size());
999
+ }
1000
+ }
1001
+ }
1002
+ if (progress_callback) {
1003
+ // Even though the model is done loading, we still honor
1004
+ // cancellation since we need to free allocations.
1005
+ return progress_callback(1.0f, progress_callback_user_data);
1006
+ }
1007
+ }
1008
+
1009
+ return true;
1010
+ }
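
Editor's note: the chunked upload path in load_all_data() above cycles through a small pool of pinned staging buffers, waiting on each buffer's event before reusing it. A stripped-down, standalone sketch of the same round-robin pattern follows; plain heap buffers and memcpy stand in for the pinned host buffers and the ggml_backend_* async calls, and the 4 x 1 MB sizes simply mirror the defaults used above:

#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <cstring>
#include <vector>

// Stream `src` into `dst` through a small pool of staging buffers,
// the way load_all_data() streams tensor data through pinned host memory.
static void staged_copy(const std::vector<uint8_t> & src, std::vector<uint8_t> & dst) {
    constexpr size_t n_buffers   = 4;
    constexpr size_t buffer_size = 1 * 1024 * 1024; // 1 MB per staging buffer

    std::vector<std::vector<uint8_t>> staging(n_buffers, std::vector<uint8_t>(buffer_size));

    size_t offs       = 0;
    size_t buffer_idx = 0; // staging buffer to use for the next chunk

    while (offs < src.size()) {
        const size_t chunk = std::min(buffer_size, src.size() - offs);

        // real loader: ggml_backend_event_synchronize(events[buffer_idx]);
        std::memcpy(staging[buffer_idx].data(), src.data() + offs, chunk);

        // real loader: ggml_backend_tensor_set_async(...) + ggml_backend_event_record(...)
        std::memcpy(dst.data() + offs, staging[buffer_idx].data(), chunk);

        offs += chunk;
        buffer_idx = (buffer_idx + 1) % n_buffers;
    }
}

int main() {
    std::vector<uint8_t> src(5 * 1024 * 1024, 0xAB); // 5 MB of dummy "tensor data"
    std::vector<uint8_t> dst(src.size());
    staged_copy(src, dst);
    std::printf("copied %zu bytes, match: %d\n", dst.size(), src == dst);
    return 0;
}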
examples/talk-llama/llama-model-loader.h ADDED
@@ -0,0 +1,158 @@
1
+ #pragma once
2
+
3
+ #include "llama.h"
4
+
5
+ #include "llama-impl.h"
6
+ #include "llama-arch.h"
7
+ #include "llama-mmap.h"
8
+
9
+ #include "ggml-cpp.h"
10
+
11
+ #include <cstddef>
12
+ #include <map>
13
+ #include <stdexcept>
14
+ #include <unordered_map>
15
+
16
+ using llama_buf_map = std::unordered_map<uint32_t, ggml_backend_buffer_t>;
17
+
18
+ enum llama_fver {
19
+ GGUF_FILE_VERSION_V1 = 1,
20
+ GGUF_FILE_VERSION_V2 = 2,
21
+ GGUF_FILE_VERSION_V3 = 3,
22
+ };
23
+
24
+ const char * llama_file_version_name(llama_fver version);
25
+
26
+ struct llama_model_loader {
27
+ // Holds information on a model weight
28
+ struct llama_tensor_weight {
29
+ uint16_t idx; // source file index
30
+ size_t offs; // tensor data offset in the original file
31
+
32
+ ggml_tensor * tensor;
33
+
34
+ llama_tensor_weight(const llama_file * file, uint16_t idx, const struct gguf_context * gguf_ctx, ggml_tensor * tensor) : idx(idx), tensor(tensor) {
35
+ const int tensor_idx = gguf_find_tensor(gguf_ctx, ggml_get_name(tensor));
36
+ if (tensor_idx < 0) {
37
+ throw std::runtime_error(format("tensor '%s' not found in the model", ggml_get_name(tensor)));
38
+ }
39
+
40
+ offs = gguf_get_data_offset(gguf_ctx) + gguf_get_tensor_offset(gguf_ctx, tensor_idx);
41
+ if (offs + ggml_nbytes(tensor) < offs || offs + ggml_nbytes(tensor) > file->size()) {
42
+ throw std::runtime_error(format("tensor '%s' data is not within the file bounds, model is corrupted or incomplete", ggml_get_name(tensor)));
43
+ }
44
+ }
45
+ };
46
+
47
+ // custom comparator to sort weights more nicely by layer
48
+ struct weight_name_comparer {
49
+ bool operator()(const std::string & a, const std::string & b) const {
50
+ int a_layer = -1;
51
+ int b_layer = -1;
52
+ sscanf(a.c_str(), "blk.%d.", &a_layer);
53
+ sscanf(b.c_str(), "blk.%d.", &b_layer);
54
+ if (a_layer != b_layer) {
55
+ return a_layer < b_layer;
56
+ }
57
+ return a < b;
58
+ }
59
+ };
60
+
61
+ static const int TENSOR_NOT_REQUIRED = 1;
62
+ static const int TENSOR_DUPLICATED = 2;
63
+
64
+ int n_kv = 0;
65
+ int n_tensors = 0;
66
+ int n_created = 0;
67
+
68
+ uint64_t n_elements = 0;
69
+ size_t n_bytes = 0;
70
+
71
+ bool use_mmap = false;
72
+ bool check_tensors;
73
+
74
+ llama_files files;
75
+ llama_ftype ftype;
76
+ llama_fver fver;
77
+
78
+ llama_mmaps mappings;
79
+
80
+ std::map<std::string, struct llama_tensor_weight, weight_name_comparer> weights_map;
81
+ std::unordered_map<std::string, struct llama_model_kv_override> kv_overrides;
82
+
83
+ gguf_context_ptr meta;
84
+ std::vector<ggml_context_ptr> contexts;
85
+
86
+ std::string arch_name;
87
+ LLM_KV llm_kv = LLM_KV(LLM_ARCH_UNKNOWN);
88
+
89
+ size_t size_done = 0;
90
+ size_t size_data = 0;
91
+ std::vector<std::pair<size_t, size_t>> mmaps_used;
92
+
93
+ llama_model_loader(const std::string & fname, bool use_mmap, bool check_tensors, const struct llama_model_kv_override * param_overrides_p);
94
+
95
+ template<typename T>
96
+ typename std::enable_if<std::is_integral<T>::value, bool>::type
97
+ get_arr_n(const std::string & key, T & result, bool required = true);
98
+
99
+ template<typename T>
100
+ typename std::enable_if<std::is_integral<T>::value, bool>::type
101
+ get_arr_n(enum llm_kv kid, T & result, bool required = true);
102
+
103
+ template<typename T>
104
+ bool get_arr(const std::string & key, std::vector<T> & result, bool required = true);
105
+
106
+ template<typename T, size_t N_MAX>
107
+ bool get_arr(const std::string & key, std::array<T, N_MAX> & result, bool required = true);
108
+
109
+ template<typename T>
110
+ bool get_arr(enum llm_kv kid, T & result, bool required = true);
111
+
112
+ template<typename T>
113
+ bool get_key(const std::string & key, T & result, bool required = true);
114
+
115
+ template<typename T>
116
+ bool get_key(enum llm_kv kid, T & result, bool required = true);
117
+
118
+ template<typename T, size_t N_MAX>
119
+ bool get_key_or_arr(const std::string & key, std::array<T, N_MAX> & result, uint32_t n, bool required = true);
120
+
121
+ template<typename T>
122
+ bool get_key_or_arr(enum llm_kv kid, T & result, uint32_t n, bool required = true);
123
+
124
+ std::string get_arch_name() const;
125
+
126
+ enum llm_arch get_arch() const;
127
+
128
+ const llama_tensor_weight * get_weight(const char * name) const;
129
+
130
+ const llama_tensor_weight & require_weight(const char * name) const;
131
+
132
+ struct ggml_tensor * get_tensor_meta(const char * name) const;
133
+
134
+ struct ggml_tensor * require_tensor_meta(const std::string & name) const;
135
+
136
+ const struct ggml_tensor * check_tensor_dims(const std::string & name, const std::vector<int64_t> & ne, bool required) const;
137
+
138
+ struct ggml_tensor * create_tensor(struct ggml_context * ctx, const std::string & name, const std::initializer_list<int64_t> & ne, int flags = 0);
139
+
140
+ struct ggml_tensor * create_tensor_as_view(struct ggml_context * ctx, struct ggml_tensor * base, const std::string & name, const std::initializer_list<int64_t> & ne, size_t offset, bool required = true);
141
+
142
+ void done_getting_tensors() const;
143
+
144
+ void init_mappings(bool prefetch = true, llama_mlocks * mlock_mmaps = nullptr);
145
+
146
+ void get_mapping_range(size_t * first, size_t * last, void ** addr, int idx, ggml_context * ctx) const;
147
+
148
+ // for backwards compatibility, does not support ggml-backend
149
+ void load_data_for(struct ggml_tensor * cur) const;
150
+
151
+ // Returns false if cancelled by progress_callback
152
+ bool load_all_data(
153
+ struct ggml_context * ctx,
154
+ llama_buf_map & bufs,
155
+ llama_mlocks * lmlocks,
156
+ llama_progress_callback progress_callback,
157
+ void * progress_callback_user_data);
158
+ };
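
Editor's note: a rough sketch of the call sequence the rest of llama.cpp uses against this interface during model loading, assuming it is built inside the llama.cpp sources; LLM_KV_EMBEDDING_LENGTH is only an example key, and the tensor-creation and data-loading steps are left as comments because they need a full ggml context and buffer setup:

#include "llama-model-loader.h"

#include <cstdint>
#include <string>

static void load_sketch(const std::string & fname) {
    // 1. open the GGUF file(s) and read the metadata
    llama_model_loader ml(fname, /*use_mmap=*/true, /*check_tensors=*/false, /*param_overrides_p=*/nullptr);

    // 2. read hyperparameters via get_key()/get_arr_n()
    uint32_t n_embd = 0;
    ml.get_key(LLM_KV_EMBEDDING_LENGTH, n_embd); // throws if the key is missing

    // 3. create tensor metadata, then verify every tensor was accounted for
    // ml.create_tensor(ctx, name, {...});
    // ml.done_getting_tensors();

    // 4. map the files and stream tensor data into the backend buffers
    ml.init_mappings(/*prefetch=*/true, /*mlock_mmaps=*/nullptr);
    // ml.load_all_data(ctx, bufs, nullptr, progress_cb, progress_cb_user_data);

    LLAMA_LOG_INFO("arch = %s, n_embd = %u\n", ml.get_arch_name().c_str(), n_embd);
}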
examples/talk-llama/llama-model.cpp ADDED
The diff for this file is too large to render. See raw diff
 
examples/talk-llama/llama-model.h ADDED
@@ -0,0 +1,391 @@
1
+ #pragma once
2
+
3
+ #include "llama.h"
4
+ #include "llama-arch.h"
5
+ #include "llama-hparams.h"
6
+ #include "llama-vocab.h"
7
+ #include "llama-mmap.h"
8
+
9
+ #include "ggml-cpp.h"
10
+
11
+ #include <vector>
12
+
13
+ // available models
14
+ // TODO: this enum does not follow the enum naming convention
15
+ enum llm_type {
16
+ MODEL_UNKNOWN,
17
+ MODEL_14M,
18
+ MODEL_17M,
19
+ MODEL_22M,
20
+ MODEL_33M,
21
+ MODEL_60M,
22
+ MODEL_70M,
23
+ MODEL_80M,
24
+ MODEL_109M,
25
+ MODEL_137M,
26
+ MODEL_160M,
27
+ MODEL_220M,
28
+ MODEL_250M,
29
+ MODEL_270M,
30
+ MODEL_335M,
31
+ MODEL_410M,
32
+ MODEL_450M,
33
+ MODEL_770M,
34
+ MODEL_780M,
35
+ MODEL_0_5B,
36
+ MODEL_1B,
37
+ MODEL_1_3B,
38
+ MODEL_1_4B,
39
+ MODEL_1_5B,
40
+ MODEL_1_6B,
41
+ MODEL_2B,
42
+ MODEL_2_8B,
43
+ MODEL_3B,
44
+ MODEL_4B,
45
+ MODEL_6B,
46
+ MODEL_6_9B,
47
+ MODEL_7B,
48
+ MODEL_8B,
49
+ MODEL_9B,
50
+ MODEL_11B,
51
+ MODEL_12B,
52
+ MODEL_13B,
53
+ MODEL_14B,
54
+ MODEL_15B,
55
+ MODEL_16B,
56
+ MODEL_20B,
57
+ MODEL_30B,
58
+ MODEL_32B,
59
+ MODEL_34B,
60
+ MODEL_35B,
61
+ MODEL_40B,
62
+ MODEL_65B,
63
+ MODEL_70B,
64
+ MODEL_236B,
65
+ MODEL_314B,
66
+ MODEL_671B,
67
+ MODEL_SMALL,
68
+ MODEL_MEDIUM,
69
+ MODEL_LARGE,
70
+ MODEL_XL,
71
+ MODEL_A1_7B,
72
+ MODEL_A2_7B,
73
+ MODEL_8x7B,
74
+ MODEL_8x22B,
75
+ MODEL_16x12B,
76
+ MODEL_10B_128x3_66B,
77
+ MODEL_57B_A14B,
78
+ MODEL_27B,
79
+ };
80
+
81
+ struct llama_layer_posnet {
82
+ // resnet
83
+ struct ggml_tensor * norm1 = nullptr;
84
+ struct ggml_tensor * norm1_b = nullptr;
85
+
86
+ struct ggml_tensor * conv1 = nullptr;
87
+ struct ggml_tensor * conv1_b = nullptr;
88
+
89
+ struct ggml_tensor * norm2 = nullptr;
90
+ struct ggml_tensor * norm2_b = nullptr;
91
+
92
+ struct ggml_tensor * conv2 = nullptr;
93
+ struct ggml_tensor * conv2_b = nullptr;
94
+
95
+ // attention
96
+ struct ggml_tensor * attn_norm = nullptr;
97
+ struct ggml_tensor * attn_norm_b = nullptr;
98
+
99
+ struct ggml_tensor * attn_q = nullptr;
100
+ struct ggml_tensor * attn_q_b = nullptr;
101
+
102
+ struct ggml_tensor * attn_k = nullptr;
103
+ struct ggml_tensor * attn_k_b = nullptr;
104
+
105
+ struct ggml_tensor * attn_v = nullptr;
106
+ struct ggml_tensor * attn_v_b = nullptr;
107
+
108
+ struct ggml_tensor * attn_o = nullptr;
109
+ struct ggml_tensor * attn_o_b = nullptr;
110
+
111
+ // normalize
112
+ struct ggml_tensor * norm = nullptr;
113
+ struct ggml_tensor * norm_b = nullptr;
114
+ };
115
+
116
+ struct llama_layer_convnext {
117
+ struct ggml_tensor * dw = nullptr;
118
+ struct ggml_tensor * dw_b = nullptr;
119
+
120
+ struct ggml_tensor * norm = nullptr;
121
+ struct ggml_tensor * norm_b = nullptr;
122
+
123
+ struct ggml_tensor * pw1 = nullptr;
124
+ struct ggml_tensor * pw1_b = nullptr;
125
+
126
+ struct ggml_tensor * pw2 = nullptr;
127
+ struct ggml_tensor * pw2_b = nullptr;
128
+
129
+ struct ggml_tensor * gamma = nullptr;
130
+ };
131
+
132
+ struct llama_layer {
133
+ // normalization
134
+ struct ggml_tensor * attn_norm = nullptr;
135
+ struct ggml_tensor * attn_norm_b = nullptr;
136
+ struct ggml_tensor * attn_norm_2 = nullptr;
137
+ struct ggml_tensor * attn_norm_2_b = nullptr;
138
+ struct ggml_tensor * attn_q_norm = nullptr;
139
+ struct ggml_tensor * attn_q_norm_b = nullptr;
140
+ struct ggml_tensor * attn_k_norm = nullptr;
141
+ struct ggml_tensor * attn_k_norm_b = nullptr;
142
+ struct ggml_tensor * attn_out_norm = nullptr;
143
+ struct ggml_tensor * attn_out_norm_b = nullptr;
144
+ struct ggml_tensor * attn_q_a_norm = nullptr;
145
+ struct ggml_tensor * attn_kv_a_norm = nullptr;
146
+ struct ggml_tensor * attn_sub_norm = nullptr;
147
+ struct ggml_tensor * attn_post_norm = nullptr;
148
+ struct ggml_tensor * ffn_sub_norm = nullptr;
149
+ struct ggml_tensor * attn_norm_cross = nullptr;
150
+ struct ggml_tensor * attn_norm_enc = nullptr;
151
+
152
+ // attention
153
+ struct ggml_tensor * wq = nullptr;
154
+ struct ggml_tensor * wk = nullptr;
155
+ struct ggml_tensor * wv = nullptr;
156
+ struct ggml_tensor * wo = nullptr;
157
+ struct ggml_tensor * wqkv = nullptr;
158
+ struct ggml_tensor * wq_a = nullptr;
159
+ struct ggml_tensor * wq_b = nullptr;
160
+ struct ggml_tensor * wkv_a_mqa = nullptr;
161
+ struct ggml_tensor * wkv_b = nullptr;
162
+ struct ggml_tensor * wq_cross = nullptr;
163
+ struct ggml_tensor * wk_cross = nullptr;
164
+ struct ggml_tensor * wv_cross = nullptr;
165
+ struct ggml_tensor * wo_cross = nullptr;
166
+ struct ggml_tensor * wq_enc = nullptr;
167
+ struct ggml_tensor * wk_enc = nullptr;
168
+ struct ggml_tensor * wv_enc = nullptr;
169
+ struct ggml_tensor * wo_enc = nullptr;
170
+
171
+ // attention bias
172
+ struct ggml_tensor * bq = nullptr;
173
+ struct ggml_tensor * bk = nullptr;
174
+ struct ggml_tensor * bv = nullptr;
175
+ struct ggml_tensor * bo = nullptr;
176
+ struct ggml_tensor * bqkv = nullptr;
177
+
178
+ // relative position bias
179
+ struct ggml_tensor * attn_rel_b = nullptr;
180
+ struct ggml_tensor * attn_rel_b_enc = nullptr;
181
+ struct ggml_tensor * attn_rel_b_cross = nullptr;
182
+
183
+ // normalization
184
+ struct ggml_tensor * ffn_norm = nullptr;
185
+ struct ggml_tensor * ffn_norm_b = nullptr;
186
+ struct ggml_tensor * ffn_post_norm = nullptr;
187
+ struct ggml_tensor * layer_out_norm = nullptr;
188
+ struct ggml_tensor * layer_out_norm_b = nullptr;
189
+ struct ggml_tensor * ffn_norm_exps = nullptr;
190
+ struct ggml_tensor * ffn_norm_enc = nullptr;
191
+
192
+ // ff
193
+ struct ggml_tensor * ffn_gate = nullptr; // w1
194
+ struct ggml_tensor * ffn_down = nullptr; // w2
195
+ struct ggml_tensor * ffn_up = nullptr; // w3
196
+ struct ggml_tensor * ffn_gate_enc = nullptr;
197
+ struct ggml_tensor * ffn_down_enc = nullptr;
198
+ struct ggml_tensor * ffn_up_enc = nullptr;
199
+
200
+ // ff MoE
201
+ struct ggml_tensor * ffn_gate_inp = nullptr;
202
+ struct ggml_tensor * ffn_gate_exps = nullptr;
203
+ struct ggml_tensor * ffn_down_exps = nullptr;
204
+ struct ggml_tensor * ffn_up_exps = nullptr;
205
+
206
+ // ff shared expert (shexp)
207
+ struct ggml_tensor * ffn_gate_inp_shexp = nullptr;
208
+ struct ggml_tensor * ffn_gate_shexp = nullptr;
209
+ struct ggml_tensor * ffn_down_shexp = nullptr;
210
+ struct ggml_tensor * ffn_up_shexp = nullptr;
211
+
212
+ // ff bias
213
+ struct ggml_tensor * ffn_gate_b = nullptr;
214
+ struct ggml_tensor * ffn_down_b = nullptr; // b2
215
+ struct ggml_tensor * ffn_up_b = nullptr; // b3
216
+ struct ggml_tensor * ffn_act = nullptr;
217
+ struct ggml_tensor * ffn_exp_probs_b = nullptr;
218
+
219
+ // mamba proj
220
+ struct ggml_tensor * ssm_in = nullptr;
221
+ struct ggml_tensor * ssm_x = nullptr;
222
+ struct ggml_tensor * ssm_dt = nullptr;
223
+ struct ggml_tensor * ssm_out = nullptr;
224
+
225
+ // mamba
226
+ struct ggml_tensor * ssm_conv1d = nullptr;
227
+ struct ggml_tensor * ssm_a = nullptr;
228
+ struct ggml_tensor * ssm_d = nullptr;
229
+
230
+ // mamba bias
231
+ struct ggml_tensor * ssm_conv1d_b = nullptr;
232
+ struct ggml_tensor * ssm_dt_b = nullptr;
233
+
234
+ // rwkv
235
+ struct ggml_tensor * time_mix_w1 = nullptr;
236
+ struct ggml_tensor * time_mix_w2 = nullptr;
237
+ struct ggml_tensor * time_mix_lerp_x = nullptr;
238
+ struct ggml_tensor * time_mix_lerp_w = nullptr;
239
+ struct ggml_tensor * time_mix_lerp_k = nullptr;
240
+ struct ggml_tensor * time_mix_lerp_v = nullptr;
241
+ struct ggml_tensor * time_mix_lerp_r = nullptr;
242
+ struct ggml_tensor * time_mix_lerp_g = nullptr;
243
+
244
+ struct ggml_tensor * time_mix_first = nullptr;
245
+ struct ggml_tensor * time_mix_decay = nullptr;
246
+ struct ggml_tensor * time_mix_decay_w1 = nullptr;
247
+ struct ggml_tensor * time_mix_decay_w2 = nullptr;
248
+ struct ggml_tensor * time_mix_key = nullptr;
249
+ struct ggml_tensor * time_mix_value = nullptr;
250
+ struct ggml_tensor * time_mix_receptance = nullptr;
251
+ struct ggml_tensor * time_mix_gate = nullptr;
252
+
253
+ struct ggml_tensor * time_mix_ln = nullptr;
254
+ struct ggml_tensor * time_mix_ln_b = nullptr;
255
+ struct ggml_tensor * time_mix_output = nullptr;
256
+
257
+ struct ggml_tensor * channel_mix_lerp_k = nullptr;
258
+ struct ggml_tensor * channel_mix_lerp_r = nullptr;
259
+
260
+ struct ggml_tensor * channel_mix_key = nullptr;
261
+ struct ggml_tensor * channel_mix_receptance = nullptr;
262
+ struct ggml_tensor * channel_mix_value = nullptr;
263
+
264
+ // long rope factors
265
+ struct ggml_tensor * rope_long = nullptr;
266
+ struct ggml_tensor * rope_short = nullptr;
267
+ struct ggml_tensor * rope_freqs = nullptr;
268
+
269
+ // bitnet scale
270
+ struct ggml_tensor * wq_scale = nullptr;
271
+ struct ggml_tensor * wk_scale = nullptr;
272
+ struct ggml_tensor * wv_scale = nullptr;
273
+ struct ggml_tensor * wo_scale = nullptr;
274
+ struct ggml_tensor * ffn_gate_scale = nullptr;
275
+ struct ggml_tensor * ffn_up_scale = nullptr;
276
+ struct ggml_tensor * ffn_down_scale = nullptr;
277
+
278
+ struct llama_layer_posnet posnet;
279
+
280
+ struct llama_layer_convnext convnext;
281
+ };
282
+
283
+ struct llama_model {
284
+ llm_type type = MODEL_UNKNOWN;
285
+ llm_arch arch = LLM_ARCH_UNKNOWN;
286
+
287
+ llama_ftype ftype = LLAMA_FTYPE_ALL_F32;
288
+
289
+ std::string name = "n/a";
290
+
291
+ llama_hparams hparams = {};
292
+ llama_vocab vocab;
293
+
294
+ struct ggml_tensor * tok_embd = nullptr;
295
+ struct ggml_tensor * type_embd = nullptr;
296
+ struct ggml_tensor * pos_embd = nullptr;
297
+ struct ggml_tensor * tok_norm = nullptr;
298
+ struct ggml_tensor * tok_norm_b = nullptr;
299
+
300
+ struct ggml_tensor * output_norm = nullptr;
301
+ struct ggml_tensor * output_norm_b = nullptr;
302
+ struct ggml_tensor * output = nullptr;
303
+ struct ggml_tensor * output_b = nullptr;
304
+ struct ggml_tensor * output_norm_enc = nullptr;
305
+
306
+ // classifier
307
+ struct ggml_tensor * cls = nullptr;
308
+ struct ggml_tensor * cls_b = nullptr;
309
+ struct ggml_tensor * cls_out = nullptr;
310
+ struct ggml_tensor * cls_out_b = nullptr;
311
+
312
+ struct ggml_tensor * conv1d = nullptr;
313
+ struct ggml_tensor * conv1d_b = nullptr;
314
+
315
+ std::vector<llama_layer> layers;
316
+
317
+ // gguf metadata
318
+ std::unordered_map<std::string, std::string> gguf_kv;
319
+
320
+ llama_split_mode split_mode;
321
+ int main_gpu;
322
+ int n_gpu_layers;
323
+
324
+ std::vector<std::string> rpc_servers;
325
+
326
+ // list of devices used in this model
327
+ std::vector<ggml_backend_dev_t> devices;
328
+
329
+
330
+ // lists of buffer types used for each layer
331
+ using buft_list_t = std::vector<std::pair<ggml_backend_dev_t, ggml_backend_buffer_type_t>>;
332
+ buft_list_t cpu_buft_list;
333
+ std::map<ggml_backend_dev_t, buft_list_t> gpu_buft_list;
334
+
335
+ struct layer_dev {
336
+ ggml_backend_dev_t dev;
337
+ buft_list_t * buft_list;
338
+ };
339
+
340
+ layer_dev dev_input = {};
341
+ layer_dev dev_output = {};
342
+ std::vector<layer_dev> dev_layer;
343
+
344
+ // contexts where the model tensors metadata is stored
345
+ std::vector<ggml_context_ptr> ctxs;
346
+
347
+ // the model memory buffers for the tensor data
348
+ std::vector<ggml_backend_buffer_ptr> bufs;
349
+
350
+ // model memory mapped files
351
+ llama_mmaps mappings;
352
+
353
+ // objects representing data potentially being locked in memory
354
+ llama_mlocks mlock_bufs;
355
+ llama_mlocks mlock_mmaps;
356
+
357
+ // for quantize-stats only
358
+ std::vector<std::pair<std::string, struct ggml_tensor *>> tensors_by_name;
359
+
360
+ int64_t t_load_us = 0;
361
+ int64_t t_start_us = 0;
362
+
363
+ // total number of parameters in the model
364
+ uint64_t n_elements = 0;
365
+
366
+ // total size of all the tensors in the model in bytes
367
+ size_t n_bytes = 0;
368
+ };
369
+
370
+ const char * llm_type_name(llm_type type);
371
+
372
+ std::string llama_model_arch_name (const llama_model & model);
373
+ std::string llama_model_type_name (const llama_model & model);
374
+ std::string llama_model_ftype_name(const llama_model & model);
375
+
376
+ // used by llama_adapter_cvec
377
+ ggml_backend_buffer_type_t llama_model_select_buft(const llama_model & model, int il);
378
+
379
+ // used by llama_adapter_lora
380
+ struct ggml_tensor * llama_model_get_tensor(const struct llama_model & model, const char * name);
381
+
382
+ size_t llama_model_max_nodes(const llama_model & model);
383
+
384
+ struct llama_model_loader;
385
+
386
+ // TODO: become llama_model methods
387
+ void llm_load_stats (llama_model_loader & ml, llama_model & model);
388
+ void llm_load_arch (llama_model_loader & ml, llama_model & model);
389
+ void llm_load_hparams (llama_model_loader & ml, llama_model & model);
390
+ void llm_load_vocab (llama_model_loader & ml, llama_model & model);
391
+ void llm_load_print_meta(llama_model_loader & ml, llama_model & model);
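
Note on the new llm_load_* stages declared above: they split model loading into metadata-only steps that callers drive in sequence. A minimal sketch of the call order (the wrapper name load_metadata_only and the flag values are illustrative assumptions; llama-quant.cpp below drives the same sequence with its own arguments):

    #include "llama-model.h"
    #include "llama-model-loader.h"

    #include <string>

    // read only GGUF metadata: architecture, hyperparameters and size stats,
    // without loading any tensor data
    static void load_metadata_only(const std::string & fname) {
        llama_model_loader ml(fname, /*use_mmap*/ true, /*check_tensors*/ false, /*kv_overrides*/ nullptr);

        llama_model model;
        llm_load_arch   (ml, model); // determine model.arch
        llm_load_hparams(ml, model); // fill model.hparams
        llm_load_stats  (ml, model); // fill model.n_elements / model.n_bytes
    }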
examples/talk-llama/llama-quant.cpp ADDED
@@ -0,0 +1,929 @@
1
+ #include "llama-quant.h"
2
+
3
+ #include "llama-impl.h"
4
+ #include "llama-model.h"
5
+ #include "llama-model-loader.h"
6
+
7
+ #include <algorithm>
8
+ #include <cmath>
9
+ #include <cstring>
10
+ #include <fstream>
11
+ #include <mutex>
12
+ #include <thread>
13
+ #include <unordered_map>
14
+
15
+ // TODO: replace with ggml API call
16
+ #define QK_K 256
17
+
18
+ static void zeros(std::ofstream & file, size_t n) {
19
+ char zero = 0;
20
+ for (size_t i = 0; i < n; ++i) {
21
+ file.write(&zero, 1);
22
+ }
23
+ }
24
+
25
+ struct quantize_state_impl {
26
+ const llama_model & model;
27
+ const llama_model_quantize_params * params;
28
+
29
+ int n_attention_wv = 0;
30
+ int n_ffn_down = 0;
31
+ int n_ffn_gate = 0;
32
+ int n_ffn_up = 0;
33
+ int i_attention_wv = 0;
34
+ int i_ffn_down = 0;
35
+ int i_ffn_gate = 0;
36
+ int i_ffn_up = 0;
37
+
38
+ int n_k_quantized = 0;
39
+ int n_fallback = 0;
40
+
41
+ bool has_imatrix = false;
42
+
43
+ // used to figure out if a model shares tok_embd with the output weight
44
+ bool has_output = false;
45
+
46
+ quantize_state_impl(const llama_model & model, const llama_model_quantize_params * params)
47
+ : model(model)
48
+ , params(params)
49
+ {}
50
+ };
51
+
52
+ static void llama_tensor_dequantize_impl(
53
+ struct ggml_tensor * tensor, std::vector<no_init<float>> & output, std::vector<std::thread> & workers,
54
+ const size_t nelements, const int nthread
55
+ ) {
56
+ if (output.size() < nelements) {
57
+ output.resize(nelements);
58
+ }
59
+ float * f32_output = (float *) output.data();
60
+
61
+ const ggml_type_traits * qtype = ggml_get_type_traits(tensor->type);
62
+ if (ggml_is_quantized(tensor->type)) {
63
+ if (qtype->to_float == NULL) {
64
+ throw std::runtime_error(format("type %s unsupported for integer quantization: no dequantization available", ggml_type_name(tensor->type)));
65
+ }
66
+ } else if (tensor->type != GGML_TYPE_F16 &&
67
+ tensor->type != GGML_TYPE_BF16) {
68
+ throw std::runtime_error(format("cannot dequantize/convert tensor type %s", ggml_type_name(tensor->type)));
69
+ }
70
+
71
+ if (nthread < 2) {
72
+ if (tensor->type == GGML_TYPE_F16) {
73
+ ggml_fp16_to_fp32_row((ggml_fp16_t *)tensor->data, f32_output, nelements);
74
+ } else if (tensor->type == GGML_TYPE_BF16) {
75
+ ggml_bf16_to_fp32_row((ggml_bf16_t *)tensor->data, f32_output, nelements);
76
+ } else if (ggml_is_quantized(tensor->type)) {
77
+ qtype->to_float(tensor->data, f32_output, nelements);
78
+ } else {
79
+ GGML_ABORT("fatal error"); // unreachable
80
+ }
81
+ return;
82
+ }
83
+
84
+ size_t block_size;
85
+ if (tensor->type == GGML_TYPE_F16 ||
86
+ tensor->type == GGML_TYPE_BF16) {
87
+ block_size = 1;
88
+ } else {
89
+ block_size = (size_t)ggml_blck_size(tensor->type);
90
+ }
91
+
92
+ size_t block_size_bytes = ggml_type_size(tensor->type);
93
+
94
+ GGML_ASSERT(nelements % block_size == 0);
95
+ size_t nblocks = nelements / block_size;
96
+ size_t blocks_per_thread = nblocks / nthread;
97
+ size_t spare_blocks = nblocks - (blocks_per_thread * nthread); // if blocks aren't divisible by thread count
98
+
99
+ size_t in_buff_offs = 0;
100
+ size_t out_buff_offs = 0;
101
+
102
+ for (int tnum = 0; tnum < nthread; tnum++) {
103
+ size_t thr_blocks = blocks_per_thread + (tnum == nthread - 1 ? spare_blocks : 0); // num blocks for this thread
104
+ size_t thr_elems = thr_blocks * block_size; // number of elements for this thread
105
+ size_t thr_block_bytes = thr_blocks * block_size_bytes; // number of input bytes for this thread
106
+
107
+ auto compute = [qtype] (ggml_type typ, uint8_t * inbuf, float * outbuf, int nels) {
108
+ if (typ == GGML_TYPE_F16) {
109
+ ggml_fp16_to_fp32_row((ggml_fp16_t *)inbuf, outbuf, nels);
110
+ } else if (typ == GGML_TYPE_BF16) {
111
+ ggml_bf16_to_fp32_row((ggml_bf16_t *)inbuf, outbuf, nels);
112
+ } else {
113
+ qtype->to_float(inbuf, outbuf, nels);
114
+ }
115
+ };
116
+ workers.emplace_back(compute, tensor->type, (uint8_t *) tensor->data + in_buff_offs, f32_output + out_buff_offs, thr_elems);
117
+ in_buff_offs += thr_block_bytes;
118
+ out_buff_offs += thr_elems;
119
+ }
120
+ for (auto & w : workers) { w.join(); }
121
+ workers.clear();
122
+ }
123
+
124
+ static ggml_type llama_tensor_get_type(quantize_state_impl & qs, ggml_type new_type, const ggml_tensor * tensor, llama_ftype ftype) {
125
+ const std::string name = ggml_get_name(tensor);
126
+
127
+ // TODO: avoid hardcoded tensor names - use the TN_* constants
128
+ const llm_arch arch = qs.model.arch;
129
+ const auto tn = LLM_TN(arch);
130
+
131
+ auto use_more_bits = [](int i_layer, int n_layers) -> bool {
132
+ return i_layer < n_layers/8 || i_layer >= 7*n_layers/8 || (i_layer - n_layers/8)%3 == 2;
133
+ };
134
+ const int n_expert = std::max(1, (int)qs.model.hparams.n_expert);
135
+ auto layer_info = [n_expert] (int i_layer, int n_layer, const char * name) {
136
+ if (n_expert > 1) {
137
+ // Believe it or not, "experts" in the FFN of Mixtral-8x7B are not consecutive, but occasionally randomly
138
+ // sprinkled in the model. Hence, simply dividing i_ffn_down by n_expert does not work
139
+ // for getting the current layer as I initially thought, and we need to resort to parsing the
140
+ // tensor name.
141
+ if (sscanf(name, "blk.%d.", &i_layer) != 1) {
142
+ throw std::runtime_error(format("Failed to determine layer for tensor %s", name));
143
+ }
144
+ if (i_layer < 0 || i_layer >= n_layer) {
145
+ throw std::runtime_error(format("Bad layer %d for tensor %s. Must be in [0, %d)", i_layer, name, n_layer));
146
+ }
147
+ }
148
+ return std::make_pair(i_layer, n_layer);
149
+ };
150
+
151
+ // for arches that share the same tensor between the token embeddings and the output, we quantize the token embeddings
152
+ // with the quantization of the output tensor
153
+ if (name == tn(LLM_TENSOR_OUTPUT, "weight") || (!qs.has_output && name == tn(LLM_TENSOR_TOKEN_EMBD, "weight"))) {
154
+ if (qs.params->output_tensor_type < GGML_TYPE_COUNT) {
155
+ new_type = qs.params->output_tensor_type;
156
+ } else {
157
+ int nx = tensor->ne[0];
158
+ if (arch == LLM_ARCH_FALCON || nx % QK_K != 0) {
159
+ new_type = GGML_TYPE_Q8_0;
160
+ }
161
+ else if (ftype == LLAMA_FTYPE_MOSTLY_IQ2_XXS || ftype == LLAMA_FTYPE_MOSTLY_IQ2_XS || ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS ||
162
+ ftype == LLAMA_FTYPE_MOSTLY_IQ1_S || ftype == LLAMA_FTYPE_MOSTLY_IQ2_S || ftype == LLAMA_FTYPE_MOSTLY_IQ2_M ||
163
+ ftype == LLAMA_FTYPE_MOSTLY_IQ1_M) {
164
+ new_type = GGML_TYPE_Q5_K;
165
+ }
166
+ else if (new_type != GGML_TYPE_Q8_0) {
167
+ new_type = GGML_TYPE_Q6_K;
168
+ }
169
+ }
170
+ } else if (name == "token_embd.weight") {
171
+ if (qs.params->token_embedding_type < GGML_TYPE_COUNT) {
172
+ new_type = qs.params->token_embedding_type;
173
+ } else {
174
+ if (ftype == LLAMA_FTYPE_MOSTLY_IQ2_XXS || ftype == LLAMA_FTYPE_MOSTLY_IQ2_XS ||
175
+ ftype == LLAMA_FTYPE_MOSTLY_IQ1_S || ftype == LLAMA_FTYPE_MOSTLY_IQ1_M) {
176
+ new_type = GGML_TYPE_Q2_K;
177
+ }
178
+ else if (ftype == LLAMA_FTYPE_MOSTLY_IQ2_S || ftype == LLAMA_FTYPE_MOSTLY_IQ2_M) {
179
+ new_type = GGML_TYPE_IQ3_S;
180
+ }
181
+ else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS) {
182
+ new_type = GGML_TYPE_IQ3_S;
183
+ }
184
+ else if (ftype == LLAMA_FTYPE_MOSTLY_TQ1_0 || ftype == LLAMA_FTYPE_MOSTLY_TQ2_0) {
185
+ new_type = GGML_TYPE_Q4_K;
186
+ }
187
+ }
188
+ } else if (ftype == LLAMA_FTYPE_MOSTLY_IQ2_XXS || ftype == LLAMA_FTYPE_MOSTLY_IQ2_XS || ftype == LLAMA_FTYPE_MOSTLY_IQ1_S ||
189
+ ftype == LLAMA_FTYPE_MOSTLY_IQ2_S || ftype == LLAMA_FTYPE_MOSTLY_IQ2_M || ftype == LLAMA_FTYPE_MOSTLY_IQ1_M) {
190
+ if (name.find("attn_v.weight") != std::string::npos) {
191
+ if (qs.model.hparams.n_gqa() >= 4 || qs.model.hparams.n_expert >= 4) new_type = GGML_TYPE_Q4_K;
192
+ else new_type = ftype == LLAMA_FTYPE_MOSTLY_IQ2_S || ftype == LLAMA_FTYPE_MOSTLY_IQ2_M ? GGML_TYPE_IQ3_S : GGML_TYPE_Q2_K;
193
+ ++qs.i_attention_wv;
194
+ }
195
+ else if (qs.model.hparams.n_expert == 8 && name.find("attn_k.weight") != std::string::npos) {
196
+ new_type = GGML_TYPE_Q4_K;
197
+ }
198
+ else if (name.find("ffn_down") != std::string::npos) {
199
+ if (qs.i_ffn_down < qs.n_ffn_down/8) {
200
+ new_type = ftype == LLAMA_FTYPE_MOSTLY_IQ2_S || ftype == LLAMA_FTYPE_MOSTLY_IQ2_M ? GGML_TYPE_IQ3_S : GGML_TYPE_Q2_K;
201
+ }
202
+ ++qs.i_ffn_down;
203
+ }
204
+ else if (name.find("attn_output.weight") != std::string::npos) {
205
+ if (qs.model.hparams.n_expert == 8) {
206
+ new_type = GGML_TYPE_Q5_K;
207
+ } else {
208
+ if (ftype == LLAMA_FTYPE_MOSTLY_IQ1_S || ftype == LLAMA_FTYPE_MOSTLY_IQ1_M) new_type = GGML_TYPE_IQ2_XXS;
209
+ else if (ftype == LLAMA_FTYPE_MOSTLY_IQ2_S || ftype == LLAMA_FTYPE_MOSTLY_IQ2_M) new_type = GGML_TYPE_IQ3_S;
210
+ }
211
+ }
212
+ } else if (name.find("attn_v.weight") != std::string::npos) {
213
+ if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K) {
214
+ new_type = qs.model.hparams.n_gqa() >= 4 ? GGML_TYPE_Q4_K : GGML_TYPE_Q3_K;
215
+ }
216
+ else if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K_S && qs.model.hparams.n_gqa() >= 4) {
217
+ new_type = GGML_TYPE_Q4_K;
218
+ }
219
+ else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS) {
220
+ new_type = qs.model.hparams.n_gqa() >= 4 ? GGML_TYPE_Q4_K : !qs.has_imatrix ? GGML_TYPE_IQ3_S : GGML_TYPE_IQ3_XXS;
221
+ }
222
+ else if ((ftype == LLAMA_FTYPE_MOSTLY_IQ3_XS || ftype == LLAMA_FTYPE_MOSTLY_IQ3_S) && qs.model.hparams.n_gqa() >= 4) {
223
+ new_type = GGML_TYPE_Q4_K;
224
+ }
225
+ else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_M) {
226
+ new_type = GGML_TYPE_Q4_K;
227
+ }
228
+ else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M) {
229
+ new_type = qs.i_attention_wv < 2 ? GGML_TYPE_Q5_K : GGML_TYPE_Q4_K;
230
+ }
231
+ else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_L) new_type = GGML_TYPE_Q5_K;
232
+ else if ((ftype == LLAMA_FTYPE_MOSTLY_IQ4_NL || ftype == LLAMA_FTYPE_MOSTLY_IQ4_XS) && qs.model.hparams.n_gqa() >= 4) {
233
+ new_type = GGML_TYPE_Q5_K;
234
+ }
235
+ else if ((ftype == LLAMA_FTYPE_MOSTLY_Q4_K_M || ftype == LLAMA_FTYPE_MOSTLY_Q5_K_M) &&
236
+ use_more_bits(qs.i_attention_wv, qs.n_attention_wv)) new_type = GGML_TYPE_Q6_K;
237
+ else if (ftype == LLAMA_FTYPE_MOSTLY_Q4_K_S && qs.i_attention_wv < 4) new_type = GGML_TYPE_Q5_K;
238
+ if (qs.model.type == MODEL_70B) {
239
+ // In the 70B model we have 8 heads sharing the same attn_v weights. As a result, the attn_v.weight tensor is
240
+ // 8x smaller compared to attn_q.weight. Hence, we can get a nice boost in quantization accuracy with
241
+ // nearly negligible increase in model size by quantizing this tensor with more bits:
242
+ if (new_type == GGML_TYPE_Q3_K || new_type == GGML_TYPE_Q4_K) new_type = GGML_TYPE_Q5_K;
243
+ }
244
+ if (qs.model.hparams.n_expert == 8) {
245
+ // for the 8-expert model, bumping this to Q8_0 trades just ~128MB
246
+ // TODO: explore better strategies
247
+ new_type = GGML_TYPE_Q8_0;
248
+ }
249
+ ++qs.i_attention_wv;
250
+ } else if (name.find("attn_k.weight") != std::string::npos) {
251
+ if (qs.model.hparams.n_expert == 8) {
252
+ // for the 8-expert model, bumping this to Q8_0 trades just ~128MB
253
+ // TODO: explore better strategies
254
+ new_type = GGML_TYPE_Q8_0;
255
+ }
256
+ else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XS) {
257
+ new_type = GGML_TYPE_IQ3_XXS;
258
+ }
259
+ else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS) {
260
+ new_type = GGML_TYPE_IQ2_S;
261
+ }
262
+ } else if (name.find("attn_q.weight") != std::string::npos) {
263
+ if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XS) {
264
+ new_type = GGML_TYPE_IQ3_XXS;
265
+ }
266
+ else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS) {
267
+ new_type = GGML_TYPE_IQ2_S;
268
+ }
269
+ } else if (name.find("ffn_down") != std::string::npos) {
270
+ auto info = layer_info(qs.i_ffn_down, qs.n_ffn_down, name.c_str());
271
+ int i_layer = info.first, n_layer = info.second;
272
+ if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K) new_type = GGML_TYPE_Q3_K;
273
+ else if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K_S) {
274
+ if (i_layer < n_layer/8) new_type = GGML_TYPE_Q4_K;
275
+ }
276
+ else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS && !qs.has_imatrix) {
277
+ new_type = i_layer < n_layer/8 ? GGML_TYPE_Q4_K : GGML_TYPE_Q3_K;
278
+ }
279
+ else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M) {
280
+ new_type = i_layer < n_layer/16 ? GGML_TYPE_Q5_K
281
+ : arch != LLM_ARCH_FALCON || use_more_bits(i_layer, n_layer) ? GGML_TYPE_Q4_K
282
+ : GGML_TYPE_Q3_K;
283
+ }
284
+ else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_M && (i_layer < n_layer/8 ||
285
+ (qs.model.hparams.n_expert == 8 && use_more_bits(i_layer, n_layer)))) {
286
+ new_type = GGML_TYPE_Q4_K;
287
+ }
288
+ else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_L) {
289
+ new_type = arch == LLM_ARCH_FALCON ? GGML_TYPE_Q4_K : GGML_TYPE_Q5_K;
290
+ }
291
+ else if (ftype == LLAMA_FTYPE_MOSTLY_Q4_K_M) {
292
+ if (arch == LLM_ARCH_FALCON) {
293
+ new_type = i_layer < n_layer/16 ? GGML_TYPE_Q6_K :
294
+ use_more_bits(i_layer, n_layer) ? GGML_TYPE_Q5_K : GGML_TYPE_Q4_K;
295
+ } else {
296
+ if (use_more_bits(i_layer, n_layer)) new_type = GGML_TYPE_Q6_K;
297
+ }
298
+ }
299
+ else if (i_layer < n_layer/8 && (ftype == LLAMA_FTYPE_MOSTLY_IQ4_NL || ftype == LLAMA_FTYPE_MOSTLY_IQ4_XS) && !qs.has_imatrix) {
300
+ new_type = GGML_TYPE_Q5_K;
301
+ }
302
+ else if (ftype == LLAMA_FTYPE_MOSTLY_Q5_K_M && use_more_bits(i_layer, n_layer)) new_type = GGML_TYPE_Q6_K;
303
+ else if (ftype == LLAMA_FTYPE_MOSTLY_Q4_K_S && arch != LLM_ARCH_FALCON && i_layer < n_layer/8) {
304
+ new_type = GGML_TYPE_Q5_K;
305
+ }
306
+ else if ((ftype == LLAMA_FTYPE_MOSTLY_Q4_0 || ftype == LLAMA_FTYPE_MOSTLY_Q5_0)
307
+ && qs.has_imatrix && i_layer < n_layer/8) {
308
+ // Guard against craziness in the first few ffn_down layers that can happen even with imatrix for Q4_0/Q5_0.
309
+ // We only do it when an imatrix is provided because a) we want to make sure that one can always get the
310
+ // same quantization as before imatrix stuff, and b) Q4_1/Q5_1 do go crazy on ffn_down without an imatrix.
311
+ new_type = ftype == LLAMA_FTYPE_MOSTLY_Q4_0 ? GGML_TYPE_Q4_1 : GGML_TYPE_Q5_1;
312
+ }
313
+ ++qs.i_ffn_down;
314
+ } else if (name.find("attn_output.weight") != std::string::npos) {
315
+ if (arch != LLM_ARCH_FALCON) {
316
+ if (qs.model.hparams.n_expert == 8) {
317
+ if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K || ftype == LLAMA_FTYPE_MOSTLY_IQ3_XS || ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS ||
318
+ ftype == LLAMA_FTYPE_MOSTLY_Q3_K_S || ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M || ftype == LLAMA_FTYPE_MOSTLY_IQ4_NL ||
319
+ ftype == LLAMA_FTYPE_MOSTLY_Q4_K_S || ftype == LLAMA_FTYPE_MOSTLY_Q4_K_M || ftype == LLAMA_FTYPE_MOSTLY_IQ3_S ||
320
+ ftype == LLAMA_FTYPE_MOSTLY_IQ3_M || ftype == LLAMA_FTYPE_MOSTLY_IQ4_XS) {
321
+ new_type = GGML_TYPE_Q5_K;
322
+ }
323
+ } else {
324
+ if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K ) new_type = GGML_TYPE_Q3_K;
325
+ else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS) new_type = GGML_TYPE_IQ3_S;
326
+ else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M ) new_type = GGML_TYPE_Q4_K;
327
+ else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_L ) new_type = GGML_TYPE_Q5_K;
328
+ else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_M ) new_type = GGML_TYPE_Q4_K;
329
+ }
330
+ } else {
331
+ if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_L) new_type = GGML_TYPE_Q4_K;
332
+ }
333
+ }
334
+ else if (name.find("attn_qkv.weight") != std::string::npos) {
335
+ if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M || ftype == LLAMA_FTYPE_MOSTLY_Q3_K_L || ftype == LLAMA_FTYPE_MOSTLY_IQ3_M) {
336
+ new_type = GGML_TYPE_Q4_K;
337
+ }
338
+ else if (ftype == LLAMA_FTYPE_MOSTLY_Q4_K_M) new_type = GGML_TYPE_Q5_K;
339
+ else if (ftype == LLAMA_FTYPE_MOSTLY_Q5_K_M) new_type = GGML_TYPE_Q6_K;
340
+ }
341
+ else if (name.find("ffn_gate") != std::string::npos) {
342
+ auto info = layer_info(qs.i_ffn_gate, qs.n_ffn_gate, name.c_str());
343
+ int i_layer = info.first, n_layer = info.second;
344
+ if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XS && (i_layer >= n_layer/8 && i_layer < 7*n_layer/8)) {
345
+ new_type = GGML_TYPE_IQ3_XXS;
346
+ }
347
+ ++qs.i_ffn_gate;
348
+ }
349
+ else if (name.find("ffn_up") != std::string::npos) {
350
+ auto info = layer_info(qs.i_ffn_up, qs.n_ffn_up, name.c_str());
351
+ int i_layer = info.first, n_layer = info.second;
352
+ if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XS && (i_layer >= n_layer/8 && i_layer < 7*n_layer/8)) {
353
+ new_type = GGML_TYPE_IQ3_XXS;
354
+ }
355
+ ++qs.i_ffn_up;
356
+ }
357
+
358
+ // if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K) new_type = GGML_TYPE_Q3_K;
359
+ //}
360
+ // IK: let's remove this, else Q2_K is almost the same as Q3_K_S
361
+ //else if (name.find("ffn_gate") != std::string::npos || name.find("ffn_up") != std::string::npos) {
362
+ // if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K) new_type = GGML_TYPE_Q3_K;
363
+ //}
364
+ // This can be used to reduce the size of the Q5_K_S model.
365
+ // The associated PPL increase is fully in line with the size reduction
366
+ //else {
367
+ // if (ftype == LLAMA_FTYPE_MOSTLY_Q5_K_S) new_type = GGML_TYPE_Q4_K;
368
+ //}
369
+ bool convert_incompatible_tensor = false;
370
+ if (new_type == GGML_TYPE_Q2_K || new_type == GGML_TYPE_Q3_K || new_type == GGML_TYPE_Q4_K ||
371
+ new_type == GGML_TYPE_Q5_K || new_type == GGML_TYPE_Q6_K || new_type == GGML_TYPE_IQ4_XS ||
372
+ new_type == GGML_TYPE_IQ2_XS || new_type == GGML_TYPE_IQ2_XXS || new_type == GGML_TYPE_IQ2_S ||
373
+ new_type == GGML_TYPE_IQ3_XXS || new_type == GGML_TYPE_IQ1_S || new_type == GGML_TYPE_IQ3_S ||
374
+ new_type == GGML_TYPE_IQ1_M) {
375
+ int nx = tensor->ne[0];
376
+ int ny = tensor->ne[1];
377
+ if (nx % QK_K != 0) {
378
+ LLAMA_LOG_WARN("\n\n%s : tensor cols %d x %d are not divisible by %d, required for %s", __func__, nx, ny, QK_K, ggml_type_name(new_type));
379
+ convert_incompatible_tensor = true;
380
+ } else {
381
+ ++qs.n_k_quantized;
382
+ }
383
+ }
384
+ if (convert_incompatible_tensor) {
385
+ switch (new_type) {
386
+ case GGML_TYPE_TQ1_0:
387
+ case GGML_TYPE_TQ2_0: new_type = GGML_TYPE_Q4_0; break; // TODO: use a symmetric type instead
388
+ case GGML_TYPE_IQ2_XXS:
389
+ case GGML_TYPE_IQ2_XS:
390
+ case GGML_TYPE_IQ2_S:
391
+ case GGML_TYPE_IQ3_XXS:
392
+ case GGML_TYPE_IQ3_S:
393
+ case GGML_TYPE_IQ1_S:
394
+ case GGML_TYPE_IQ1_M:
395
+ case GGML_TYPE_Q2_K:
396
+ case GGML_TYPE_Q3_K:
397
+ case GGML_TYPE_IQ4_XS: new_type = GGML_TYPE_IQ4_NL; break;
398
+ case GGML_TYPE_Q4_K: new_type = GGML_TYPE_Q5_0; break;
399
+ case GGML_TYPE_Q5_K: new_type = GGML_TYPE_Q5_1; break;
400
+ case GGML_TYPE_Q6_K: new_type = GGML_TYPE_Q8_0; break;
401
+ default: throw std::runtime_error("\nUnsupported tensor size encountered\n");
402
+ }
403
+ if (tensor->ne[0] % ggml_blck_size(new_type) != 0) {
404
+ new_type = GGML_TYPE_F16;
405
+ }
406
+ LLAMA_LOG_WARN(" - using fallback quantization %s\n", ggml_type_name(new_type));
407
+ ++qs.n_fallback;
408
+ }
409
+
410
+ return new_type;
411
+ }
412
+
413
+ static size_t llama_tensor_quantize_impl(enum ggml_type new_type, const float * f32_data, void * new_data, const int64_t chunk_size, int64_t nrows, int64_t n_per_row, const float * imatrix, std::vector<std::thread> & workers, const int nthread) {
414
+ if (nthread < 2) {
415
+ // single-thread
416
+ size_t new_size = ggml_quantize_chunk(new_type, f32_data, new_data, 0, nrows, n_per_row, imatrix);
417
+ if (!ggml_validate_row_data(new_type, new_data, new_size)) {
418
+ throw std::runtime_error("quantized data validation failed");
419
+ }
420
+ return new_size;
421
+ }
422
+
423
+ std::mutex mutex;
424
+ int64_t counter = 0;
425
+ size_t new_size = 0;
426
+ bool valid = true;
427
+ auto compute = [&mutex, &counter, &new_size, &valid, new_type, f32_data, new_data, chunk_size,
428
+ nrows, n_per_row, imatrix]() {
429
+ const int64_t nrows_per_chunk = chunk_size / n_per_row;
430
+ size_t local_size = 0;
431
+ while (true) {
432
+ std::unique_lock<std::mutex> lock(mutex);
433
+ int64_t first_row = counter; counter += nrows_per_chunk;
434
+ if (first_row >= nrows) {
435
+ if (local_size > 0) {
436
+ new_size += local_size;
437
+ }
438
+ break;
439
+ }
440
+ lock.unlock();
441
+ const int64_t this_nrow = std::min(nrows - first_row, nrows_per_chunk);
442
+ size_t this_size = ggml_quantize_chunk(new_type, f32_data, new_data, first_row * n_per_row, this_nrow, n_per_row, imatrix);
443
+ local_size += this_size;
444
+
445
+ // validate the quantized data
446
+ const size_t row_size = ggml_row_size(new_type, n_per_row);
447
+ void * this_data = (char *) new_data + first_row * row_size;
448
+ if (!ggml_validate_row_data(new_type, this_data, this_size)) {
449
+ std::unique_lock<std::mutex> lock(mutex);
450
+ valid = false;
451
+ break;
452
+ }
453
+ }
454
+ };
455
+ for (int it = 0; it < nthread - 1; ++it) {
456
+ workers.emplace_back(compute);
457
+ }
458
+ compute();
459
+ for (auto & w : workers) { w.join(); }
460
+ workers.clear();
461
+ if (!valid) {
462
+ throw std::runtime_error("quantized data validation failed");
463
+ }
464
+ return new_size;
465
+ }
466
+
467
+ static void llama_model_quantize_impl(const std::string & fname_inp, const std::string & fname_out, const llama_model_quantize_params * params) {
468
+ ggml_type default_type;
469
+ llama_ftype ftype = params->ftype;
470
+
471
+ switch (params->ftype) {
472
+ case LLAMA_FTYPE_MOSTLY_Q4_0: default_type = GGML_TYPE_Q4_0; break;
473
+ case LLAMA_FTYPE_MOSTLY_Q4_1: default_type = GGML_TYPE_Q4_1; break;
474
+ case LLAMA_FTYPE_MOSTLY_Q5_0: default_type = GGML_TYPE_Q5_0; break;
475
+ case LLAMA_FTYPE_MOSTLY_Q5_1: default_type = GGML_TYPE_Q5_1; break;
476
+ case LLAMA_FTYPE_MOSTLY_Q8_0: default_type = GGML_TYPE_Q8_0; break;
477
+ case LLAMA_FTYPE_MOSTLY_F16: default_type = GGML_TYPE_F16; break;
478
+ case LLAMA_FTYPE_MOSTLY_BF16: default_type = GGML_TYPE_BF16; break;
479
+ case LLAMA_FTYPE_ALL_F32: default_type = GGML_TYPE_F32; break;
480
+
481
+ // K-quants
482
+ case LLAMA_FTYPE_MOSTLY_Q2_K_S:
483
+ case LLAMA_FTYPE_MOSTLY_Q2_K: default_type = GGML_TYPE_Q2_K; break;
484
+ case LLAMA_FTYPE_MOSTLY_IQ3_XS: default_type = GGML_TYPE_IQ3_S; break;
485
+ case LLAMA_FTYPE_MOSTLY_Q3_K_S:
486
+ case LLAMA_FTYPE_MOSTLY_Q3_K_M:
487
+ case LLAMA_FTYPE_MOSTLY_Q3_K_L: default_type = GGML_TYPE_Q3_K; break;
488
+ case LLAMA_FTYPE_MOSTLY_Q4_K_S:
489
+ case LLAMA_FTYPE_MOSTLY_Q4_K_M: default_type = GGML_TYPE_Q4_K; break;
490
+ case LLAMA_FTYPE_MOSTLY_Q5_K_S:
491
+ case LLAMA_FTYPE_MOSTLY_Q5_K_M: default_type = GGML_TYPE_Q5_K; break;
492
+ case LLAMA_FTYPE_MOSTLY_Q6_K: default_type = GGML_TYPE_Q6_K; break;
493
+ case LLAMA_FTYPE_MOSTLY_TQ1_0: default_type = GGML_TYPE_TQ1_0; break;
494
+ case LLAMA_FTYPE_MOSTLY_TQ2_0: default_type = GGML_TYPE_TQ2_0; break;
495
+ case LLAMA_FTYPE_MOSTLY_IQ2_XXS: default_type = GGML_TYPE_IQ2_XXS; break;
496
+ case LLAMA_FTYPE_MOSTLY_IQ2_XS: default_type = GGML_TYPE_IQ2_XS; break;
497
+ case LLAMA_FTYPE_MOSTLY_IQ2_S: default_type = GGML_TYPE_IQ2_XS; break;
498
+ case LLAMA_FTYPE_MOSTLY_IQ2_M: default_type = GGML_TYPE_IQ2_S; break;
499
+ case LLAMA_FTYPE_MOSTLY_IQ3_XXS: default_type = GGML_TYPE_IQ3_XXS; break;
500
+ case LLAMA_FTYPE_MOSTLY_IQ1_S: default_type = GGML_TYPE_IQ1_S; break;
501
+ case LLAMA_FTYPE_MOSTLY_IQ1_M: default_type = GGML_TYPE_IQ1_M; break;
502
+ case LLAMA_FTYPE_MOSTLY_IQ4_NL: default_type = GGML_TYPE_IQ4_NL; break;
503
+ case LLAMA_FTYPE_MOSTLY_IQ4_XS: default_type = GGML_TYPE_IQ4_XS; break;
504
+ case LLAMA_FTYPE_MOSTLY_IQ3_S: default_type = GGML_TYPE_IQ3_S; break;
505
+ case LLAMA_FTYPE_MOSTLY_IQ3_M: default_type = GGML_TYPE_IQ3_S; break;
506
+
507
+ default: throw std::runtime_error(format("invalid output file type %d\n", ftype));
508
+ }
509
+
510
+ int nthread = params->nthread;
511
+
512
+ if (nthread <= 0) {
513
+ nthread = std::thread::hardware_concurrency();
514
+ }
515
+
516
+ // mmap consistently increases speed on Linux, and also increases speed on Windows with a
517
+ // hot cache. It may cause a slowdown on macOS, possibly related to free memory.
518
+ #if defined(__linux__) || defined(_WIN32)
519
+ constexpr bool use_mmap = true;
520
+ #else
521
+ constexpr bool use_mmap = false;
522
+ #endif
523
+
524
+ llama_model_kv_override * kv_overrides = nullptr;
525
+ if (params->kv_overrides) {
526
+ auto v = (std::vector<llama_model_kv_override>*)params->kv_overrides;
527
+ kv_overrides = v->data();
528
+ }
529
+ llama_model_loader ml(fname_inp, use_mmap, /*check_tensors*/ true, kv_overrides);
530
+ ml.init_mappings(false); // no prefetching
531
+
532
+ llama_model model;
533
+ llm_load_arch (ml, model);
534
+ llm_load_hparams(ml, model);
535
+ llm_load_stats (ml, model);
536
+
537
+ struct quantize_state_impl qs(model, params);
538
+
539
+ if (params->only_copy) {
540
+ ftype = model.ftype;
541
+ }
542
+ const std::unordered_map<std::string, std::vector<float>> * imatrix_data = nullptr;
543
+ if (params->imatrix) {
544
+ imatrix_data = static_cast<const std::unordered_map<std::string, std::vector<float>>*>(params->imatrix);
545
+ if (imatrix_data) {
546
+ LLAMA_LOG_INFO("================================ Have weights data with %d entries\n",int(imatrix_data->size()));
547
+ qs.has_imatrix = true;
548
+ // check imatrix for nans or infs
549
+ for (const auto & kv : *imatrix_data) {
550
+ for (float f : kv.second) {
551
+ if (!std::isfinite(f)) {
552
+ throw std::runtime_error(format("imatrix contains non-finite value %f\n", f));
553
+ }
554
+ }
555
+ }
556
+ }
557
+ }
558
+
559
+ const size_t align = GGUF_DEFAULT_ALIGNMENT;
560
+ gguf_context_ptr ctx_out { gguf_init_empty() };
561
+
562
+ // copy the KV pairs from the input file
563
+ gguf_set_kv (ctx_out.get(), ml.meta.get());
564
+ gguf_set_val_u32(ctx_out.get(), "general.quantization_version", GGML_QNT_VERSION); // TODO: use LLM_KV
565
+ gguf_set_val_u32(ctx_out.get(), "general.file_type", ftype); // TODO: use LLM_KV
566
+
567
+ // Remove split metadata
568
+ gguf_remove_key(ctx_out.get(), ml.llm_kv(LLM_KV_SPLIT_NO).c_str());
569
+ gguf_remove_key(ctx_out.get(), ml.llm_kv(LLM_KV_SPLIT_COUNT).c_str());
570
+ gguf_remove_key(ctx_out.get(), ml.llm_kv(LLM_KV_SPLIT_TENSORS_COUNT).c_str());
571
+
572
+ if (params->kv_overrides) {
573
+ const std::vector<llama_model_kv_override> & overrides = *(const std::vector<llama_model_kv_override> *)params->kv_overrides;
574
+ for (const auto & o : overrides) {
575
+ if (o.key[0] == 0) break;
576
+ if (o.tag == LLAMA_KV_OVERRIDE_TYPE_FLOAT) {
577
+ gguf_set_val_f32(ctx_out.get(), o.key, o.val_f64);
578
+ } else if (o.tag == LLAMA_KV_OVERRIDE_TYPE_INT) {
579
+ gguf_set_val_i32(ctx_out.get(), o.key, o.val_i64);
580
+ } else if (o.tag == LLAMA_KV_OVERRIDE_TYPE_BOOL) {
581
+ gguf_set_val_bool(ctx_out.get(), o.key, o.val_bool);
582
+ } else if (o.tag == LLAMA_KV_OVERRIDE_TYPE_STR) {
583
+ gguf_set_val_str(ctx_out.get(), o.key, o.val_str);
584
+ } else {
585
+ LLAMA_LOG_WARN("%s: unknown KV override type for key %s\n", __func__, o.key);
586
+ }
587
+ }
588
+ }
589
+
590
+ // make a list of weights
591
+ std::vector<const llama_model_loader::llama_tensor_weight *> tensors;
592
+ tensors.reserve(ml.weights_map.size());
593
+ for (const auto & it : ml.weights_map) {
594
+ tensors.push_back(&it.second);
595
+ }
596
+
597
+ // keep_split requires that the weights are sorted by split index
598
+ if (params->keep_split) {
599
+ std::sort(tensors.begin(), tensors.end(), [](const llama_model_loader::llama_tensor_weight * a, const llama_model_loader::llama_tensor_weight * b) {
600
+ if (a->idx == b->idx) {
601
+ return a->offs < b->offs;
602
+ }
603
+ return a->idx < b->idx;
604
+ });
605
+ }
606
+
607
+ for (const auto * it : tensors) {
608
+ const struct ggml_tensor * tensor = it->tensor;
609
+
610
+ const std::string name = ggml_get_name(tensor);
611
+
612
+ // TODO: avoid hardcoded tensor names - use the TN_* constants
613
+ if (name.find("attn_v.weight") != std::string::npos ||
614
+ name.find("attn_qkv.weight") != std::string::npos ||
615
+ name.find("attn_kv_b.weight")!= std::string::npos) {
616
+ ++qs.n_attention_wv;
617
+ } else if (name == LLM_TN(model.arch)(LLM_TENSOR_OUTPUT, "weight")) {
618
+ qs.has_output = true;
619
+ }
620
+ }
621
+
622
+ qs.n_ffn_down = qs.n_ffn_gate = qs.n_ffn_up = (int)model.hparams.n_layer;
623
+
624
+ // sanity checks
625
+ {
626
+ const auto & n_head_kv_iter = model.hparams.n_head_kv_arr.begin();
627
+ // attention layers have a non-zero number of kv heads
628
+ int32_t n_attn_layer = model.hparams.n_layer - std::count(n_head_kv_iter, n_head_kv_iter + model.hparams.n_layer, 0);
629
+ if (llama_model_has_encoder(&model)) {
630
+ n_attn_layer *= 3;
631
+ }
632
+ GGML_ASSERT((qs.n_attention_wv == n_attn_layer) && "n_attention_wv is unexpected");
633
+ }
634
+
635
+ size_t total_size_org = 0;
636
+ size_t total_size_new = 0;
637
+
638
+ std::vector<std::thread> workers;
639
+ workers.reserve(nthread);
640
+
641
+ int idx = 0;
642
+
643
+ std::vector<no_init<uint8_t>> read_data;
644
+ std::vector<no_init<uint8_t>> work;
645
+ std::vector<no_init<float>> f32_conv_buf;
646
+
647
+ uint16_t n_split = 1;
648
+
649
+ // Assume split index is continuous
650
+ if (params->keep_split) {
651
+ for (const auto * it : tensors) {
652
+ n_split = std::max(uint16_t(it->idx + 1), n_split);
653
+ }
654
+ }
655
+ std::vector<gguf_context_ptr> ctx_outs(n_split);
656
+ ctx_outs[0] = std::move(ctx_out);
657
+
658
+ // populate the original tensors so we get the initial metadata
659
+ for (const auto * it : tensors) {
660
+ uint16_t i_split = params->keep_split ? it->idx : 0;
661
+ struct ggml_tensor * tensor = it->tensor;
662
+ if (!ctx_outs[i_split]) {
663
+ ctx_outs[i_split].reset(gguf_init_empty());
664
+ }
665
+ gguf_add_tensor(ctx_outs[i_split].get(), tensor);
666
+ }
667
+
668
+ // Set split info if needed
669
+ if (n_split > 1) {
670
+ for (size_t i = 0; i < ctx_outs.size(); ++i) {
671
+ gguf_set_val_u16(ctx_outs[i].get(), ml.llm_kv(LLM_KV_SPLIT_NO).c_str(), i);
672
+ gguf_set_val_u16(ctx_outs[i].get(), ml.llm_kv(LLM_KV_SPLIT_COUNT).c_str(), n_split);
673
+ gguf_set_val_i32(ctx_outs[i].get(), ml.llm_kv(LLM_KV_SPLIT_TENSORS_COUNT).c_str(), ml.n_tensors);
674
+ }
675
+ }
676
+
677
+ int cur_split = -1;
678
+ std::ofstream fout;
679
+ auto close_ofstream = [&]() {
680
+ // Write metadata and close file handler
681
+ if (fout.is_open()) {
682
+ fout.seekp(0);
683
+ std::vector<uint8_t> data(gguf_get_meta_size(ctx_outs[cur_split].get()));
684
+ gguf_get_meta_data(ctx_outs[cur_split].get(), data.data());
685
+ fout.write((const char *) data.data(), data.size());
686
+ fout.close();
687
+ }
688
+ };
689
+ auto new_ofstream = [&](int index) {
690
+ cur_split = index;
691
+ GGML_ASSERT(ctx_outs[cur_split] && "Found uninitialized gguf_context");
692
+ std::string fname = fname_out;
693
+ if (params->keep_split) {
694
+ std::vector<char> split_path(llama_path_max(), 0);
695
+ llama_split_path(split_path.data(), split_path.size(), fname_out.c_str(), cur_split, n_split);
696
+ fname = std::string(split_path.data());
697
+ }
698
+
699
+ fout = std::ofstream(fname, std::ios::binary);
700
+ fout.exceptions(std::ofstream::failbit); // fail fast on write errors
701
+ const size_t meta_size = gguf_get_meta_size(ctx_outs[cur_split].get());
702
+ // placeholder for the meta data
703
+ ::zeros(fout, meta_size);
704
+ };
705
+
706
+ const auto tn = LLM_TN(model.arch);
707
+ new_ofstream(0);
708
+ for (const auto * it : tensors) {
709
+ const auto & weight = *it;
710
+ struct ggml_tensor * tensor = weight.tensor;
711
+ if (weight.idx != cur_split && params->keep_split) {
712
+ close_ofstream();
713
+ new_ofstream(weight.idx);
714
+ }
715
+
716
+ const std::string name = ggml_get_name(tensor);
717
+
718
+ if (!ml.use_mmap) {
719
+ if (read_data.size() < ggml_nbytes(tensor)) {
720
+ read_data.resize(ggml_nbytes(tensor));
721
+ }
722
+ tensor->data = read_data.data();
723
+ }
724
+ ml.load_data_for(tensor);
725
+
726
+ LLAMA_LOG_INFO("[%4d/%4d] %36s - [%s], type = %6s, ",
727
+ ++idx, ml.n_tensors,
728
+ ggml_get_name(tensor),
729
+ llama_format_tensor_shape(tensor).c_str(),
730
+ ggml_type_name(tensor->type));
731
+
732
+ // This used to be a regex, but <regex> has an extreme cost to compile times.
733
+ bool quantize = name.rfind("weight") == name.size() - 6; // ends with 'weight'?
734
+
735
+ // quantize only 2D and 3D tensors (experts)
736
+ quantize &= (ggml_n_dims(tensor) >= 2);
737
+
738
+ // do not quantize norm tensors
739
+ quantize &= name.find("_norm.weight") == std::string::npos;
740
+
741
+ quantize &= params->quantize_output_tensor || name != "output.weight";
742
+ quantize &= !params->only_copy;
743
+
744
+ // do not quantize expert gating tensors
745
+ // NOTE: can't use LLM_TN here because the layer number is not known
746
+ quantize &= name.find("ffn_gate_inp.weight") == std::string::npos;
747
+
748
+ // do not quantize positional embeddings and token types (BERT)
749
+ quantize &= name != LLM_TN(model.arch)(LLM_TENSOR_POS_EMBD, "weight");
750
+ quantize &= name != LLM_TN(model.arch)(LLM_TENSOR_TOKEN_TYPES, "weight");
751
+
752
+ // do not quantize Mamba's small yet 2D weights
753
+ // NOTE: can't use LLM_TN here because the layer number is not known
754
+ quantize &= name.find("ssm_conv1d.weight") == std::string::npos;
755
+
756
+ // do not quantize RWKV's time_mix_first tensors
757
+ quantize &= name.find("time_mix_first.weight") == std::string::npos;
758
+ quantize &= name.find("time_mix_w1.weight") == std::string::npos;
759
+ quantize &= name.find("time_mix_w2.weight") == std::string::npos;
760
+ quantize &= name.find("time_mix_decay_w1.weight") == std::string::npos;
761
+ quantize &= name.find("time_mix_decay_w2.weight") == std::string::npos;
762
+
763
+ // do not quantize relative position bias (T5)
764
+ quantize &= name.find("attn_rel_b.weight") == std::string::npos;
765
+
766
+ enum ggml_type new_type;
767
+ void * new_data;
768
+ size_t new_size;
769
+
770
+ if (quantize) {
771
+ new_type = default_type;
772
+
773
+ // get more optimal quantization type based on the tensor shape, layer, etc.
774
+ if (!params->pure && ggml_is_quantized(default_type)) {
775
+ new_type = llama_tensor_get_type(qs, new_type, tensor, ftype);
776
+ }
777
+ if (params->token_embedding_type < GGML_TYPE_COUNT && strcmp(tensor->name, "token_embd.weight") == 0) {
778
+ new_type = params->token_embedding_type;
779
+ }
780
+ if (params->output_tensor_type < GGML_TYPE_COUNT && strcmp(tensor->name, "output.weight") == 0) {
781
+ new_type = params->output_tensor_type;
782
+ }
783
+
784
+ // If we've decided to quantize to the same type the tensor is already
785
+ // in then there's nothing to do.
786
+ quantize = tensor->type != new_type;
787
+ }
788
+
789
+ if (!quantize) {
790
+ new_type = tensor->type;
791
+ new_data = tensor->data;
792
+ new_size = ggml_nbytes(tensor);
793
+ LLAMA_LOG_INFO("size = %8.3f MB\n", ggml_nbytes(tensor)/1024.0/1024.0);
794
+ } else {
795
+ const int64_t nelements = ggml_nelements(tensor);
796
+
797
+ const float * imatrix = nullptr;
798
+ if (imatrix_data) {
799
+ auto it = imatrix_data->find(tensor->name);
800
+ if (it == imatrix_data->end()) {
801
+ LLAMA_LOG_INFO("\n====== %s: did not find weights for %s\n", __func__, tensor->name);
802
+ } else {
803
+ if (it->second.size() == (size_t)tensor->ne[0]*tensor->ne[2]) {
804
+ imatrix = it->second.data();
805
+ } else {
806
+ LLAMA_LOG_INFO("\n====== %s: imatrix size %d is different from tensor size %d for %s\n", __func__,
807
+ int(it->second.size()), int(tensor->ne[0]*tensor->ne[2]), tensor->name);
808
+
809
+ // this can happen when quantizing an old mixtral model with split tensors with a new incompatible imatrix
810
+ // this is a significant error and it may be a good idea to abort the process if this happens,
811
+ // since many people will miss the error and not realize that most of the model is being quantized without an imatrix
812
+ // tok_embd should be ignored in this case, since it always causes this warning
813
+ if (name != tn(LLM_TENSOR_TOKEN_EMBD, "weight")) {
814
+ throw std::runtime_error(format("imatrix size %d is different from tensor size %d for %s",
815
+ int(it->second.size()), int(tensor->ne[0]*tensor->ne[2]), tensor->name));
816
+ }
817
+ }
818
+ }
819
+ }
820
+ if ((new_type == GGML_TYPE_IQ2_XXS ||
821
+ new_type == GGML_TYPE_IQ2_XS ||
822
+ new_type == GGML_TYPE_IQ2_S ||
823
+ new_type == GGML_TYPE_IQ1_S ||
824
+ (new_type == GGML_TYPE_IQ1_M && strcmp(tensor->name, "token_embd.weight") && strcmp(tensor->name, "output.weight")) ||
825
+ (new_type == GGML_TYPE_Q2_K && params->ftype == LLAMA_FTYPE_MOSTLY_Q2_K_S && strcmp(tensor->name, "token_embd.weight") != 0)) && !imatrix) {
826
+ LLAMA_LOG_ERROR("\n\n============================================================\n");
827
+ LLAMA_LOG_ERROR("Missing importance matrix for tensor %s in a very low-bit quantization\n", tensor->name);
828
+ LLAMA_LOG_ERROR("The result will be garbage, so bailing out\n");
829
+ LLAMA_LOG_ERROR("============================================================\n\n");
830
+ throw std::runtime_error(format("Missing importance matrix for tensor %s in a very low-bit quantization", tensor->name));
831
+ }
832
+
833
+ float * f32_data;
834
+
835
+ if (tensor->type == GGML_TYPE_F32) {
836
+ f32_data = (float *) tensor->data;
837
+ } else if (ggml_is_quantized(tensor->type) && !params->allow_requantize) {
838
+ throw std::runtime_error(format("requantizing from type %s is disabled", ggml_type_name(tensor->type)));
839
+ } else {
840
+ llama_tensor_dequantize_impl(tensor, f32_conv_buf, workers, nelements, nthread);
841
+ f32_data = (float *) f32_conv_buf.data();
842
+ }
843
+
844
+ LLAMA_LOG_INFO("converting to %s .. ", ggml_type_name(new_type));
845
+ fflush(stdout);
846
+
847
+ if (work.size() < (size_t)nelements * 4) {
848
+ work.resize(nelements * 4); // upper bound on size
849
+ }
850
+ new_data = work.data();
851
+
852
+ const int64_t n_per_row = tensor->ne[0];
853
+ const int64_t nrows = tensor->ne[1];
854
+
855
+ static const int64_t min_chunk_size = 32 * 512;
856
+ const int64_t chunk_size = (n_per_row >= min_chunk_size ? n_per_row : n_per_row * ((min_chunk_size + n_per_row - 1)/n_per_row));
857
+
858
+ const int64_t nelements_matrix = tensor->ne[0] * tensor->ne[1];
859
+ const int64_t nchunk = (nelements_matrix + chunk_size - 1)/chunk_size;
860
+ const int64_t nthread_use = nthread > 1 ? std::max((int64_t)1, std::min((int64_t)nthread, nchunk)) : 1;
861
+
862
+ // quantize each expert separately since they have different importance matrices
863
+ new_size = 0;
864
+ for (int64_t i03 = 0; i03 < tensor->ne[2]; ++i03) {
865
+ const float * f32_data_03 = f32_data + i03 * nelements_matrix;
866
+ void * new_data_03 = (char *)new_data + ggml_row_size(new_type, n_per_row) * i03 * nrows;
867
+ const float * imatrix_03 = imatrix ? imatrix + i03 * n_per_row : nullptr;
868
+
869
+ new_size += llama_tensor_quantize_impl(new_type, f32_data_03, new_data_03, chunk_size, nrows, n_per_row, imatrix_03, workers, nthread_use);
870
+ }
871
+ LLAMA_LOG_INFO("size = %8.2f MiB -> %8.2f MiB\n", ggml_nbytes(tensor)/1024.0/1024.0, new_size/1024.0/1024.0);
872
+ }
873
+ total_size_org += ggml_nbytes(tensor);
874
+ total_size_new += new_size;
875
+
876
+ // update the gguf meta data as we go
877
+ gguf_set_tensor_type(ctx_outs[cur_split].get(), name.c_str(), new_type);
878
+ gguf_set_tensor_data(ctx_outs[cur_split].get(), name.c_str(), new_data, new_size);
879
+
880
+ // write tensor data + padding
881
+ fout.write((const char *) new_data, new_size);
882
+ zeros(fout, GGML_PAD(new_size, align) - new_size);
883
+ }
884
+ close_ofstream();
885
+
886
+ LLAMA_LOG_INFO("%s: model size = %8.2f MB\n", __func__, total_size_org/1024.0/1024.0);
887
+ LLAMA_LOG_INFO("%s: quant size = %8.2f MB\n", __func__, total_size_new/1024.0/1024.0);
888
+
889
+ if (qs.n_fallback > 0) {
890
+ LLAMA_LOG_WARN("%s: WARNING: %d of %d tensor(s) required fallback quantization\n",
891
+ __func__, qs.n_fallback, qs.n_k_quantized + qs.n_fallback);
892
+ }
893
+ }
894
+
895
+ //
896
+ // interface implementation
897
+ //
898
+
899
+ struct llama_model_quantize_params llama_model_quantize_default_params() {
900
+ struct llama_model_quantize_params result = {
901
+ /*.nthread =*/ 0,
902
+ /*.ftype =*/ LLAMA_FTYPE_MOSTLY_Q5_1,
903
+ /*.output_tensor_type =*/ GGML_TYPE_COUNT,
904
+ /*.token_embedding_type =*/ GGML_TYPE_COUNT,
905
+ /*.allow_requantize =*/ false,
906
+ /*.quantize_output_tensor =*/ true,
907
+ /*.only_copy =*/ false,
908
+ /*.pure =*/ false,
909
+ /*.keep_split =*/ false,
910
+ /*.imatrix =*/ nullptr,
911
+ /*.kv_overrides =*/ nullptr,
912
+ };
913
+
914
+ return result;
915
+ }
916
+
917
+ uint32_t llama_model_quantize(
918
+ const char * fname_inp,
919
+ const char * fname_out,
920
+ const llama_model_quantize_params * params) {
921
+ try {
922
+ llama_model_quantize_impl(fname_inp, fname_out, params);
923
+ } catch (const std::exception & err) {
924
+ LLAMA_LOG_ERROR("%s: failed to quantize: %s\n", __func__, err.what());
925
+ return 1;
926
+ }
927
+
928
+ return 0;
929
+ }
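
The quantization pipeline above is reached through the public entry points kept at the bottom of the file and declared in llama.h. A minimal caller sketch (the file names are placeholders; the ftype and nthread values are illustrative):

    #include "llama.h"

    int main() {
        llama_model_quantize_params params = llama_model_quantize_default_params();
        params.ftype   = LLAMA_FTYPE_MOSTLY_Q4_K_M; // target file type
        params.nthread = 8;                         // <= 0 falls back to hardware_concurrency()

        // returns 0 on success, non-zero on failure (see llama_model_quantize above)
        const uint32_t rc = llama_model_quantize("model-f16.gguf", "model-q4_k_m.gguf", &params);
        return rc == 0 ? 0 : 1;
    }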
examples/talk-llama/llama-quant.h ADDED
@@ -0,0 +1 @@
 
 
1
+ #pragma once
examples/talk-llama/llama-sampling.cpp CHANGED
@@ -1,5 +1,6 @@
1
  #include "llama-sampling.h"
2
 
 
3
  #include "llama-vocab.h"
4
  #include "llama-grammar.h"
5
 
@@ -14,6 +15,118 @@
14
  #include <numeric>
15
  #include <random>
16
  #include <unordered_map>
 
17
 
18
  static int llama_sample_dist(llama_token_data_array * cur_p, std::mt19937 & rng) {
19
  // iterator for the probabilities
@@ -144,7 +257,7 @@ static void llama_sampler_top_k_impl(llama_token_data_array * cur_p, int32_t k)
144
  for (int i = 0; i < (int)cur_p->size; ++i) {
145
  const float val = cur_p->data[i].logit;
146
  int ib = int(bucket_scale * val + bucket_inter); //nbuckets * (val - bucket_low) / (bucket_high - bucket_low);
147
- ib = std::max(0, std::min(nbuckets-1, ib));
148
  bucket_idx[i] = ib;
149
  ++histo[ib];
150
  }
@@ -167,13 +280,13 @@ static void llama_sampler_top_k_impl(llama_token_data_array * cur_p, int32_t k)
167
  for (int i = 0; i < (int)cur_p->size; ++i) {
168
  int j = bucket_idx[i];
169
  if (j >= ib) {
170
- *bucket_ptrs[nbuckets-1-j]++ = cur_p->data[i];
171
  }
172
  }
173
 
174
  ptr = tmp_tokens.data();
175
  int ndone = 0;
176
- for (int j = nbuckets-1; j > ib; --j) {
177
  std::sort(ptr, ptr + histo[j], comp);
178
  ptr += histo[j];
179
  ndone += histo[j];
@@ -1719,7 +1832,7 @@ static void llama_sampler_dry_apply(struct llama_sampler * smpl, llama_token_dat
1719
  ctx->dry_repeat_count[last - k] = std::min(n, rep_limit);
1720
  if (n > 0) {
1721
  lt = k;
1722
- rt = k+n-1;
1723
  }
1724
  } else {
1725
  // If k is inside the current Z-box, consider two cases.
 
1
  #include "llama-sampling.h"
2
 
3
+ #include "llama-impl.h"
4
  #include "llama-vocab.h"
5
  #include "llama-grammar.h"
6
 
 
15
  #include <numeric>
16
  #include <random>
17
  #include <unordered_map>
18
+ #include <stdexcept>
19
+
20
+ // the ring buffer works similarly to std::deque, but with a fixed capacity
21
+ template<typename T>
22
+ struct ring_buffer {
23
+ ring_buffer(size_t cap) : capacity(cap), data(cap) {}
24
+
25
+ T & front() {
26
+ if (sz == 0) {
27
+ throw std::runtime_error("ring buffer is empty");
28
+ }
29
+ return data[first];
30
+ }
31
+
32
+ const T & front() const {
33
+ if (sz == 0) {
34
+ throw std::runtime_error("ring buffer is empty");
35
+ }
36
+ return data[first];
37
+ }
38
+
39
+ T & back() {
40
+ if (sz == 0) {
41
+ throw std::runtime_error("ring buffer is empty");
42
+ }
43
+ return data[pos];
44
+ }
45
+
46
+ const T & back() const {
47
+ if (sz == 0) {
48
+ throw std::runtime_error("ring buffer is empty");
49
+ }
50
+ return data[pos];
51
+ }
52
+
53
+ void push_back(const T & value) {
54
+ if (capacity == 0) {
55
+ throw std::runtime_error("ring buffer: capacity is zero");
56
+ }
57
+
58
+ if (sz == capacity) {
59
+ // advance the start when buffer is full
60
+ first = (first + 1) % capacity;
61
+ } else {
62
+ sz++;
63
+ }
64
+ data[pos] = value;
65
+ pos = (pos + 1) % capacity;
66
+ }
67
+
68
+ T pop_front() {
69
+ if (sz == 0) {
70
+ throw std::runtime_error("ring buffer is empty");
71
+ }
72
+ T value = data[first];
73
+ first = (first + 1) % capacity;
74
+ sz--;
75
+ return value;
76
+ }
77
+
78
+ //T & operator[](size_t i) {
79
+ // if (i >= sz) {
80
+ // throw std::runtime_error("ring buffer: index out of bounds");
81
+ // }
82
+ // return data[(first + i) % capacity];
83
+ //}
84
+
85
+ //const T & at(size_t i) const {
86
+ // if (i >= sz) {
87
+ // throw std::runtime_error("ring buffer: index out of bounds");
88
+ // }
89
+ // return data[(first + i) % capacity];
90
+ //}
91
+
92
+ const T & rat(size_t i) const {
93
+ if (i >= sz) {
94
+ throw std::runtime_error("ring buffer: index out of bounds");
95
+ }
96
+ return data[(first + sz - i - 1) % capacity];
97
+ }
98
+
99
+ std::vector<T> to_vector() const {
100
+ std::vector<T> result;
101
+ result.reserve(sz);
102
+ for (size_t i = 0; i < sz; i++) {
103
+ result.push_back(data[(first + i) % capacity]);
104
+ }
105
+ return result;
106
+ }
107
+
108
+ void clear() {
109
+ // here only reset the status of the buffer
110
+ sz = 0;
111
+ first = 0;
112
+ pos = 0;
113
+ }
114
+
115
+ bool empty() const {
116
+ return sz == 0;
117
+ }
118
+
119
+ size_t size() const {
120
+ return sz;
121
+ }
122
+
123
+ size_t capacity = 0;
124
+ size_t sz = 0;
125
+ size_t first = 0;
126
+ size_t pos = 0;
127
+
128
+ std::vector<T> data;
129
+ };
130
 
131
  static int llama_sample_dist(llama_token_data_array * cur_p, std::mt19937 & rng) {
132
  // iterator for the probabilities
 
257
  for (int i = 0; i < (int)cur_p->size; ++i) {
258
  const float val = cur_p->data[i].logit;
259
  int ib = int(bucket_scale * val + bucket_inter); //nbuckets * (val - bucket_low) / (bucket_high - bucket_low);
260
+ ib = std::max(0, std::min(nbuckets - 1, ib));
261
  bucket_idx[i] = ib;
262
  ++histo[ib];
263
  }
 
280
  for (int i = 0; i < (int)cur_p->size; ++i) {
281
  int j = bucket_idx[i];
282
  if (j >= ib) {
283
+ *bucket_ptrs[nbuckets - 1 - j]++ = cur_p->data[i];
284
  }
285
  }
286
 
287
  ptr = tmp_tokens.data();
288
  int ndone = 0;
289
+ for (int j = nbuckets - 1; j > ib; --j) {
290
  std::sort(ptr, ptr + histo[j], comp);
291
  ptr += histo[j];
292
  ndone += histo[j];
 
1832
  ctx->dry_repeat_count[last - k] = std::min(n, rep_limit);
1833
  if (n > 0) {
1834
  lt = k;
1835
+ rt = k + n - 1;
1836
  }
1837
  } else {
1838
  // If k is inside the current Z-box, consider two cases.
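The rt = k + n - 1 fix above keeps the Z-box bounds [lt, rt] inclusive of the last matching position. For background only (not code from the commit), the DRY repeat-length pass is built on the classic Z-algorithm; a textbook version over a token sequence, using the more common half-open box [l, r), looks like this:

    #include <algorithm>
    #include <cstdio>
    #include <vector>

    // z[i] = length of the longest prefix of `s` that also starts at position i (z[0] is set to the full length by convention)
    static std::vector<int> z_function(const std::vector<int> & s) {
        const int n = (int) s.size();
        std::vector<int> z(n, 0);
        if (n > 0) {
            z[0] = n;
        }
        int l = 0, r = 0; // rightmost match box, half-open [l, r)
        for (int i = 1; i < n; ++i) {
            if (i < r) {
                z[i] = std::min(r - i, z[i - l]);
            }
            while (i + z[i] < n && s[z[i]] == s[i + z[i]]) {
                ++z[i];
            }
            if (i + z[i] > r) {
                l = i;
                r = i + z[i];
            }
        }
        return z;
    }

    int main() {
        for (int v : z_function({ 7, 7, 9, 7, 7, 9, 7 })) {
            std::printf("%d ", v); // 7 1 0 4 1 0 1
        }
        std::printf("\n");
        return 0;
    }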
examples/talk-llama/llama-vocab.cpp CHANGED
@@ -1,5 +1,7 @@
1
  #include "llama-vocab.h"
2
 
 
 
3
  #include "unicode.h"
4
 
5
  #include <algorithm>
@@ -16,22 +18,6 @@
16
  // helpers
17
  //
18
 
19
- LLAMA_ATTRIBUTE_FORMAT(1, 2)
20
- static std::string format(const char * fmt, ...) {
21
- va_list ap;
22
- va_list ap2;
23
- va_start(ap, fmt);
24
- va_copy(ap2, ap);
25
- int size = vsnprintf(NULL, 0, fmt, ap);
26
- GGML_ASSERT(size >= 0 && size < INT_MAX); // NOLINT
27
- std::vector<char> buf(size + 1);
28
- int size2 = vsnprintf(buf.data(), size + 1, fmt, ap2);
29
- GGML_ASSERT(size2 == size);
30
- va_end(ap2);
31
- va_end(ap);
32
- return std::string(buf.data(), size);
33
- }
34
-
35
  struct naive_trie {
36
  naive_trie() : has_value(false), value(0) {
37
  }
@@ -396,6 +382,13 @@ struct llm_tokenizer_bpe : llm_tokenizer {
396
  "\\p{N}+",
397
  };
398
  break;
 
 
 
 
 
 
 
399
  case LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_CODER:
400
  regex_exprs = {
401
  "[\r\n]",
@@ -504,7 +497,7 @@ struct llm_tokenizer_bpe_session {
504
 
505
  bool append_bos(std::vector<llama_vocab::id> & output) const {
506
  if (vocab.tokenizer_add_bos) {
507
- GGML_ASSERT(vocab.special_bos_id != -1);
508
  output.push_back(vocab.special_bos_id);
509
  return true;
510
  }
@@ -513,7 +506,7 @@ struct llm_tokenizer_bpe_session {
513
 
514
  bool append_eos(std::vector<llama_vocab::id> & output) const {
515
  if (vocab.tokenizer_add_eos) {
516
- GGML_ASSERT(vocab.special_eos_id != -1);
517
  output.push_back(vocab.special_eos_id);
518
  return true;
519
  }
@@ -1410,7 +1403,7 @@ static void tokenizer_st_partition(const llama_vocab & vocab, std::forward_list<
1410
  if (source == 0) {
1411
  buffer.erase_after(buffer.before_begin());
1412
  } else {
1413
- buffer.erase_after(std::next(buffer.begin(), (source-1)));
1414
  }
1415
 
1416
  // repeat for the right side
@@ -1424,7 +1417,7 @@ static void tokenizer_st_partition(const llama_vocab & vocab, std::forward_list<
1424
  if (source == 0) {
1425
  buffer.erase_after(buffer.before_begin());
1426
  } else {
1427
- buffer.erase_after(std::next(buffer.begin(), (source-1)));
1428
  }
1429
  break;
1430
  }
@@ -1461,7 +1454,7 @@ std::vector<llama_vocab::id> llama_tokenize_internal(
1461
  bool is_prev_special = true; // prefix with space if first token
1462
 
1463
  if (add_special && vocab.tokenizer_add_bos) {
1464
- GGML_ASSERT(vocab.special_bos_id != -1);
1465
  output.push_back(vocab.special_bos_id);
1466
  is_prev_special = true;
1467
  }
@@ -1496,7 +1489,7 @@ std::vector<llama_vocab::id> llama_tokenize_internal(
1496
  }
1497
 
1498
  if (add_special && vocab.tokenizer_add_eos) {
1499
- GGML_ASSERT(vocab.special_eos_id != -1);
1500
  output.push_back(vocab.special_eos_id);
1501
  }
1502
  } break;
@@ -1529,7 +1522,7 @@ std::vector<llama_vocab::id> llama_tokenize_internal(
1529
  case LLAMA_VOCAB_TYPE_WPM:
1530
  {
1531
  if (add_special) {
1532
- GGML_ASSERT(vocab.special_cls_id != -1);
1533
  output.push_back(vocab.special_cls_id);
1534
  }
1535
 
@@ -1549,14 +1542,14 @@ std::vector<llama_vocab::id> llama_tokenize_internal(
1549
  }
1550
 
1551
  if (add_special) {
1552
- GGML_ASSERT(vocab.special_sep_id != -1);
1553
  output.push_back(vocab.special_sep_id);
1554
  }
1555
  } break;
1556
  case LLAMA_VOCAB_TYPE_UGM:
1557
  {
1558
  if (add_special && vocab.tokenizer_add_bos) {
1559
- GGML_ASSERT(vocab.special_bos_id != -1);
1560
  output.push_back(vocab.special_bos_id);
1561
  }
1562
  llm_tokenizer_ugm_session session(vocab);
@@ -1581,7 +1574,7 @@ std::vector<llama_vocab::id> llama_tokenize_internal(
1581
  }
1582
 
1583
  if (add_special && vocab.tokenizer_add_eos) {
1584
- GGML_ASSERT(vocab.special_eos_id != -1);
1585
  output.push_back(vocab.special_eos_id);
1586
  }
1587
  } break;
@@ -1649,7 +1642,7 @@ llama_token_attr llama_token_get_attr_impl(const struct llama_vocab & vocab, lla
1649
  }
1650
 
1651
  bool llama_token_is_eog_impl(const struct llama_vocab & vocab, llama_token token) {
1652
- return token != -1 && vocab.special_eog_ids.count(token) > 0;
1653
  }
1654
 
1655
  bool llama_token_is_control_impl(const struct llama_vocab & vocab, llama_token token) {
@@ -1657,7 +1650,7 @@ bool llama_token_is_control_impl(const struct llama_vocab & vocab, llama_token t
1657
  }
1658
 
1659
  llama_token llama_token_bos_impl(const struct llama_vocab & vocab) {
1660
- return vocab.special_bos_id;
1661
  }
1662
 
1663
  llama_token llama_token_eos_impl(const struct llama_vocab & vocab) {
@@ -1867,6 +1860,10 @@ int32_t llama_detokenize_impl(
1867
  int32_t text_len_max,
1868
  bool remove_special,
1869
  bool unparse_special) {
 
 
 
 
1870
  GGML_ASSERT(vocab.tokenizer && "Tokenizer not initialized. Call llama_vocab::init_tokenizer() first.");
1871
 
1872
  int32_t avail = text_len_max;
@@ -1884,7 +1881,7 @@ int32_t llama_detokenize_impl(
1884
  }
1885
 
1886
  if (remove_special && vocab.tokenizer_add_eos) {
1887
- if (n_tokens > 0 && tokens[n_tokens-1] == vocab.special_eos_id) {
1888
  n_tokens--;
1889
  }
1890
  }
 
1
  #include "llama-vocab.h"
2
 
3
+ #include "llama-impl.h"
4
+
5
  #include "unicode.h"
6
 
7
  #include <algorithm>
 
18
  // helpers
19
  //
20
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
  struct naive_trie {
22
  naive_trie() : has_value(false), value(0) {
23
  }
 
382
  "\\p{N}+",
383
  };
384
  break;
385
+ case LLAMA_VOCAB_PRE_TYPE_DEEPSEEK3_LLM:
386
+ regex_exprs = {
387
+ "\\p{N}{1,3}",
388
+ "[一-龥぀-ゟ゠-ヿ]+",
389
+ "[!\"#$%&'()*+,\\-./:;<=>?@\\[\\\\\\]^_`{|}~][A-Za-z]+|[^\r\n\\p{L}\\p{P}\\p{S}]?[\\p{L}\\p{M}]+| ?[\\p{P}\\p{S}]+[\r\n]*|\\s*[\r\n]+|\\s+(?!\\S)|\\s+",
390
+ };
391
+ break;
392
  case LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_CODER:
393
  regex_exprs = {
394
  "[\r\n]",
 
497
 
498
  bool append_bos(std::vector<llama_vocab::id> & output) const {
499
  if (vocab.tokenizer_add_bos) {
500
+ GGML_ASSERT(vocab.special_bos_id != LLAMA_TOKEN_NULL);
501
  output.push_back(vocab.special_bos_id);
502
  return true;
503
  }
 
506
 
507
  bool append_eos(std::vector<llama_vocab::id> & output) const {
508
  if (vocab.tokenizer_add_eos) {
509
+ GGML_ASSERT(vocab.special_eos_id != LLAMA_TOKEN_NULL);
510
  output.push_back(vocab.special_eos_id);
511
  return true;
512
  }
 
1403
  if (source == 0) {
1404
  buffer.erase_after(buffer.before_begin());
1405
  } else {
1406
+ buffer.erase_after(std::next(buffer.begin(), (source - 1)));
1407
  }
1408
 
1409
  // repeat for the right side
 
1417
  if (source == 0) {
1418
  buffer.erase_after(buffer.before_begin());
1419
  } else {
1420
+ buffer.erase_after(std::next(buffer.begin(), (source - 1)));
1421
  }
1422
  break;
1423
  }
 
1454
  bool is_prev_special = true; // prefix with space if first token
1455
 
1456
  if (add_special && vocab.tokenizer_add_bos) {
1457
+ GGML_ASSERT(vocab.special_bos_id != LLAMA_TOKEN_NULL);
1458
  output.push_back(vocab.special_bos_id);
1459
  is_prev_special = true;
1460
  }
 
1489
  }
1490
 
1491
  if (add_special && vocab.tokenizer_add_eos) {
1492
+ GGML_ASSERT(vocab.special_eos_id != LLAMA_TOKEN_NULL);
1493
  output.push_back(vocab.special_eos_id);
1494
  }
1495
  } break;
 
1522
  case LLAMA_VOCAB_TYPE_WPM:
1523
  {
1524
  if (add_special) {
1525
+ GGML_ASSERT(vocab.special_cls_id != LLAMA_TOKEN_NULL);
1526
  output.push_back(vocab.special_cls_id);
1527
  }
1528
 
 
1542
  }
1543
 
1544
  if (add_special) {
1545
+ GGML_ASSERT(vocab.special_sep_id != LLAMA_TOKEN_NULL);
1546
  output.push_back(vocab.special_sep_id);
1547
  }
1548
  } break;
1549
  case LLAMA_VOCAB_TYPE_UGM:
1550
  {
1551
  if (add_special && vocab.tokenizer_add_bos) {
1552
+ GGML_ASSERT(vocab.special_bos_id != LLAMA_TOKEN_NULL);
1553
  output.push_back(vocab.special_bos_id);
1554
  }
1555
  llm_tokenizer_ugm_session session(vocab);
 
1574
  }
1575
 
1576
  if (add_special && vocab.tokenizer_add_eos) {
1577
+ GGML_ASSERT(vocab.special_eos_id != LLAMA_TOKEN_NULL);
1578
  output.push_back(vocab.special_eos_id);
1579
  }
1580
  } break;
 
1642
  }
1643
 
1644
  bool llama_token_is_eog_impl(const struct llama_vocab & vocab, llama_token token) {
1645
+ return token != LLAMA_TOKEN_NULL && vocab.special_eog_ids.count(token) > 0;
1646
  }
1647
 
1648
  bool llama_token_is_control_impl(const struct llama_vocab & vocab, llama_token token) {
 
1650
  }
1651
 
1652
  llama_token llama_token_bos_impl(const struct llama_vocab & vocab) {
1653
+ return vocab.type != LLAMA_VOCAB_TYPE_WPM ? vocab.special_bos_id : vocab.special_cls_id;
1654
  }
1655
 
1656
  llama_token llama_token_eos_impl(const struct llama_vocab & vocab) {
 
1860
  int32_t text_len_max,
1861
  bool remove_special,
1862
  bool unparse_special) {
1863
+ if (vocab.type == LLAMA_VOCAB_TYPE_NONE) {
1864
+ return 0;
1865
+ }
1866
+
1867
  GGML_ASSERT(vocab.tokenizer && "Tokenizer not initialized. Call llama_vocab::init_tokenizer() first.");
1868
 
1869
  int32_t avail = text_len_max;
 
1881
  }
1882
 
1883
  if (remove_special && vocab.tokenizer_add_eos) {
1884
+ if (n_tokens > 0 && tokens[n_tokens - 1] == vocab.special_eos_id) {
1885
  n_tokens--;
1886
  }
1887
  }
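The vocab hunks above replace bare -1 checks with the LLAMA_TOKEN_NULL constant. A hedged caller-side sketch of the same convention, assuming the public llama_token_bos/llama_token_eos accessors from llama.h; the helper itself is illustrative:

    #include <vector>

    #include "llama.h"

    // illustrative: only add special tokens that the model's vocab actually defines
    static void add_special_tokens(const llama_model * model, std::vector<llama_token> & toks) {
        const llama_token bos = llama_token_bos(model);
        if (bos != LLAMA_TOKEN_NULL) {
            toks.insert(toks.begin(), bos);
        }
        const llama_token eos = llama_token_eos(model);
        if (eos != LLAMA_TOKEN_NULL) {
            toks.push_back(eos);
        }
    }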
examples/talk-llama/llama-vocab.h CHANGED
@@ -1,6 +1,6 @@
1
  #pragma once
2
 
3
- #include "llama-impl.h"
4
 
5
  #include <string>
6
  #include <vector>
@@ -8,6 +8,18 @@
8
  #include <map>
9
  #include <set>
10
 
 
 
 
 
 
 
 
 
 
 
 
 
11
  struct llm_tokenizer;
12
 
13
  struct llama_vocab {
@@ -45,7 +57,7 @@ struct llama_vocab {
45
  id special_unk_id = 0;
46
  id special_sep_id = LLAMA_TOKEN_NULL;
47
  id special_pad_id = LLAMA_TOKEN_NULL;
48
- id special_cls_id = LLAMA_TOKEN_NULL;
49
  id special_mask_id = LLAMA_TOKEN_NULL;
50
 
51
  id linefeed_id = 13;
 
1
  #pragma once
2
 
3
+ #include "llama.h"
4
 
5
  #include <string>
6
  #include <vector>
 
8
  #include <map>
9
  #include <set>
10
 
11
+ static const char * llama_model_vocab_type_name(enum llama_vocab_type type){
12
+ switch (type) {
13
+ case LLAMA_VOCAB_TYPE_NONE: return "no vocab";
14
+ case LLAMA_VOCAB_TYPE_SPM: return "SPM";
15
+ case LLAMA_VOCAB_TYPE_BPE: return "BPE";
16
+ case LLAMA_VOCAB_TYPE_WPM: return "WPM";
17
+ case LLAMA_VOCAB_TYPE_UGM: return "UGM";
18
+ case LLAMA_VOCAB_TYPE_RWKV: return "RWKV";
19
+ default: return "unknown";
20
+ }
21
+ }
22
+
23
  struct llm_tokenizer;
24
 
25
  struct llama_vocab {
 
57
  id special_unk_id = 0;
58
  id special_sep_id = LLAMA_TOKEN_NULL;
59
  id special_pad_id = LLAMA_TOKEN_NULL;
60
+ id special_cls_id = LLAMA_TOKEN_NULL; // TODO: revisit if this is really needed https://github.com/ggerganov/llama.cpp/pull/10930
61
  id special_mask_id = LLAMA_TOKEN_NULL;
62
 
63
  id linefeed_id = 13;
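A trivial sketch of the llama_model_vocab_type_name helper added to llama-vocab.h above (the main() wrapper is illustrative):

    #include <cstdio>

    #include "llama-vocab.h"

    int main() {
        // maps the vocab type enum to a short human-readable name, e.g. for logging
        std::printf("%s\n", llama_model_vocab_type_name(LLAMA_VOCAB_TYPE_BPE)); // prints "BPE"
        return 0;
    }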
examples/talk-llama/llama.cpp CHANGED
The diff for this file is too large to render. See raw diff
 
examples/talk-llama/llama.h CHANGED
@@ -34,7 +34,6 @@
34
 
35
  #define LLAMA_DEFAULT_SEED 0xFFFFFFFF
36
 
37
- // TODO: use everywhere in the implementation
38
  #define LLAMA_TOKEN_NULL -1
39
 
40
  #define LLAMA_FILE_MAGIC_GGLA 0x67676c61u // 'ggla'
@@ -105,6 +104,7 @@ extern "C" {
105
  LLAMA_VOCAB_PRE_TYPE_EXAONE = 25,
106
  LLAMA_VOCAB_PRE_TYPE_CHAMELEON = 26,
107
  LLAMA_VOCAB_PRE_TYPE_MINERVA = 27,
 
108
  };
109
 
110
  enum llama_rope_type {
@@ -385,6 +385,7 @@ extern "C" {
385
  } llama_chat_message;
386
 
387
  // lora adapter
 
388
  struct llama_lora_adapter;
389
 
390
  // Helpers for getting default parameters
@@ -412,11 +413,19 @@ extern "C" {
412
  // Call once at the end of the program - currently only used for MPI
413
  LLAMA_API void llama_backend_free(void);
414
 
415
- LLAMA_API struct llama_model * llama_load_model_from_file(
 
 
 
 
 
416
  const char * path_model,
417
  struct llama_model_params params);
418
 
419
- LLAMA_API void llama_free_model(struct llama_model * model);
 
 
 
420
 
421
  // TODO: rename to llama_init_from_model
422
  LLAMA_API struct llama_context * llama_new_context_with_model(
@@ -482,9 +491,6 @@ extern "C" {
482
  // Returns the total number of parameters in the model
483
  LLAMA_API uint64_t llama_model_n_params(const struct llama_model * model);
484
 
485
- // Get a llama model tensor
486
- LLAMA_API struct ggml_tensor * llama_get_model_tensor(struct llama_model * model, const char * name);
487
-
488
  // Returns true if the model contains an encoder that requires llama_encode() call
489
  LLAMA_API bool llama_model_has_encoder(const struct llama_model * model);
490
 
@@ -504,14 +510,19 @@ extern "C" {
504
  const char * fname_out,
505
  const llama_model_quantize_params * params);
506
 
 
 
 
 
507
  // Load a LoRA adapter from file
508
- // The loaded adapter will be associated to the given model, and will be free when the model is deleted
509
  LLAMA_API struct llama_lora_adapter * llama_lora_adapter_init(
510
  struct llama_model * model,
511
  const char * path_lora);
512
 
513
  // Add a loaded LoRA adapter to given context
514
  // This will not modify model's weight
 
515
  LLAMA_API int32_t llama_lora_adapter_set(
516
  struct llama_context * ctx,
517
  struct llama_lora_adapter * adapter,
@@ -519,16 +530,18 @@ extern "C" {
519
 
520
  // Remove a specific LoRA adapter from given context
521
  // Return -1 if the adapter is not present in the context
 
522
  LLAMA_API int32_t llama_lora_adapter_remove(
523
  struct llama_context * ctx,
524
  struct llama_lora_adapter * adapter);
525
 
526
  // Remove all LoRA adapters from given context
527
- LLAMA_API void llama_lora_adapter_clear(
528
- struct llama_context * ctx);
529
 
530
  // Manually free a LoRA adapter
531
  // Note: loaded adapters will be free when the associated model is deleted
 
532
  LLAMA_API void llama_lora_adapter_free(struct llama_lora_adapter * adapter);
533
 
534
  // Apply a loaded control vector to a llama_context, or if data is NULL, clear
@@ -537,6 +550,7 @@ extern "C" {
537
  // to an n_embd x n_layers buffer starting from layer 1.
538
  // il_start and il_end are the layer range the vector should apply to (both inclusive)
539
  // See llama_control_vector_load in common to load a control vector.
 
540
  LLAMA_API int32_t llama_control_vector_apply(
541
  struct llama_context * lctx,
542
  const float * data,
@@ -549,6 +563,8 @@ extern "C" {
549
  // KV cache
550
  //
551
 
 
 
552
  // Information associated with an individual cell in the KV cache view.
553
  struct llama_kv_cache_view_cell {
554
  // The position for this cell. Takes KV cache shifts into account.
@@ -595,8 +611,11 @@ extern "C" {
595
  LLAMA_API void llama_kv_cache_view_free(struct llama_kv_cache_view * view);
596
 
597
  // Update the KV cache view structure with the current state of the KV cache. (use only for debugging purposes)
 
598
  LLAMA_API void llama_kv_cache_view_update(const struct llama_context * ctx, struct llama_kv_cache_view * view);
599
 
 
 
600
  // Returns the number of tokens in the KV cache (slow, use only for debug)
601
  // If a KV cell has multiple sequences assigned to it, it will be counted multiple times
602
  LLAMA_API int32_t llama_get_kv_cache_token_count(const struct llama_context * ctx);
@@ -666,6 +685,9 @@ extern "C" {
666
  struct llama_context * ctx,
667
  llama_seq_id seq_id);
668
 
 
 
 
669
  // Defragment the KV cache
670
  // This will be applied:
671
  // - lazily on next llama_decode()
 
34
 
35
  #define LLAMA_DEFAULT_SEED 0xFFFFFFFF
36
 
 
37
  #define LLAMA_TOKEN_NULL -1
38
 
39
  #define LLAMA_FILE_MAGIC_GGLA 0x67676c61u // 'ggla'
 
104
  LLAMA_VOCAB_PRE_TYPE_EXAONE = 25,
105
  LLAMA_VOCAB_PRE_TYPE_CHAMELEON = 26,
106
  LLAMA_VOCAB_PRE_TYPE_MINERVA = 27,
107
+ LLAMA_VOCAB_PRE_TYPE_DEEPSEEK3_LLM = 28,
108
  };
109
 
110
  enum llama_rope_type {
 
385
  } llama_chat_message;
386
 
387
  // lora adapter
388
+ // TODO: rename to llama_adapter_lora
389
  struct llama_lora_adapter;
390
 
391
  // Helpers for getting default parameters
 
413
  // Call once at the end of the program - currently only used for MPI
414
  LLAMA_API void llama_backend_free(void);
415
 
416
+ DEPRECATED(LLAMA_API struct llama_model * llama_load_model_from_file(
417
+ const char * path_model,
418
+ struct llama_model_params params),
419
+ "use llama_model_load_from_file instead");
420
+
421
+ LLAMA_API struct llama_model * llama_model_load_from_file(
422
  const char * path_model,
423
  struct llama_model_params params);
424
 
425
+ DEPRECATED(LLAMA_API void llama_free_model(struct llama_model * model),
426
+ "use llama_model_free instead");
427
+
428
+ LLAMA_API void llama_model_free(struct llama_model * model);
429
 
430
  // TODO: rename to llama_init_from_model
431
  LLAMA_API struct llama_context * llama_new_context_with_model(
 
491
  // Returns the total number of parameters in the model
492
  LLAMA_API uint64_t llama_model_n_params(const struct llama_model * model);
493
 
 
 
 
494
  // Returns true if the model contains an encoder that requires llama_encode() call
495
  LLAMA_API bool llama_model_has_encoder(const struct llama_model * model);
496
 
 
510
  const char * fname_out,
511
  const llama_model_quantize_params * params);
512
 
513
+ //
514
+ // Adapters
515
+ //
516
+
517
  // Load a LoRA adapter from file
518
+ // TODO: rename to llama_adapter_lora_init
519
  LLAMA_API struct llama_lora_adapter * llama_lora_adapter_init(
520
  struct llama_model * model,
521
  const char * path_lora);
522
 
523
  // Add a loaded LoRA adapter to given context
524
  // This will not modify model's weight
525
+ // TODO: rename to llama_set_adapter_lora
526
  LLAMA_API int32_t llama_lora_adapter_set(
527
  struct llama_context * ctx,
528
  struct llama_lora_adapter * adapter,
 
530
 
531
  // Remove a specific LoRA adapter from given context
532
  // Return -1 if the adapter is not present in the context
533
+ // TODO: rename to llama_rm_adapter_lora
534
  LLAMA_API int32_t llama_lora_adapter_remove(
535
  struct llama_context * ctx,
536
  struct llama_lora_adapter * adapter);
537
 
538
  // Remove all LoRA adapters from given context
539
+ // TODO: rename to llama_clear_adapter_lora
540
+ LLAMA_API void llama_lora_adapter_clear(struct llama_context * ctx);
541
 
542
  // Manually free a LoRA adapter
543
  // Note: loaded adapters will be free when the associated model is deleted
544
+ // TODO: rename to llama_adapter_lora_free
545
  LLAMA_API void llama_lora_adapter_free(struct llama_lora_adapter * adapter);
546
 
547
  // Apply a loaded control vector to a llama_context, or if data is NULL, clear
 
550
  // to an n_embd x n_layers buffer starting from layer 1.
551
  // il_start and il_end are the layer range the vector should apply to (both inclusive)
552
  // See llama_control_vector_load in common to load a control vector.
553
+ // TODO: rename to llama_adapter_cvec_apply
554
  LLAMA_API int32_t llama_control_vector_apply(
555
  struct llama_context * lctx,
556
  const float * data,
 
563
  // KV cache
564
  //
565
 
566
+ // TODO: remove llama_kv_cache_view_* API
567
+
568
  // Information associated with an individual cell in the KV cache view.
569
  struct llama_kv_cache_view_cell {
570
  // The position for this cell. Takes KV cache shifts into account.
 
611
  LLAMA_API void llama_kv_cache_view_free(struct llama_kv_cache_view * view);
612
 
613
  // Update the KV cache view structure with the current state of the KV cache. (use only for debugging purposes)
614
+ // TODO: change signature to llama_kv_cache_view_update(struct llama_kv_cache_view * view, const struct llama_context * ctx)
615
  LLAMA_API void llama_kv_cache_view_update(const struct llama_context * ctx, struct llama_kv_cache_view * view);
616
 
617
+ ///
618
+
619
  // Returns the number of tokens in the KV cache (slow, use only for debug)
620
  // If a KV cell has multiple sequences assigned to it, it will be counted multiple times
621
  LLAMA_API int32_t llama_get_kv_cache_token_count(const struct llama_context * ctx);
 
685
  struct llama_context * ctx,
686
  llama_seq_id seq_id);
687
 
688
+ // TODO: the llama_kv_cache_defrag and llama_kv_cache_update API tightly couples llama_context with llama_kv_cache
689
+ // how to avoid this?
690
+
691
  // Defragment the KV cache
692
  // This will be applied:
693
  // - lazily on next llama_decode()
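The header changes above deprecate llama_load_model_from_file/llama_free_model in favor of llama_model_load_from_file/llama_model_free; the talk-llama.cpp hunk below applies the rename on the load side. A minimal load/free pairing under the new names (the model path is a placeholder):

    #include "llama.h"

    int main() {
        llama_backend_init();

        llama_model_params mparams = llama_model_default_params();
        // "model.gguf" is a placeholder path, not a file shipped with this example
        llama_model * model = llama_model_load_from_file("model.gguf", mparams);
        if (model != nullptr) {
            // ... create a context with llama_new_context_with_model(), run inference ...
            llama_model_free(model); // replaces the deprecated llama_free_model()
        }

        llama_backend_free();
        return 0;
    }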
examples/talk-llama/talk-llama.cpp CHANGED
@@ -304,7 +304,7 @@ int main(int argc, char ** argv) {
304
  lmparams.n_gpu_layers = params.n_gpu_layers;
305
  }
306
 
307
- struct llama_model * model_llama = llama_load_model_from_file(params.model_llama.c_str(), lmparams);
308
  if (!model_llama) {
309
  fprintf(stderr, "No llama.cpp model specified. Please provide using -ml <modelfile>\n");
310
  return 1;
 
304
  lmparams.n_gpu_layers = params.n_gpu_layers;
305
  }
306
 
307
+ struct llama_model * model_llama = llama_model_load_from_file(params.model_llama.c_str(), lmparams);
308
  if (!model_llama) {
309
  fprintf(stderr, "No llama.cpp model specified. Please provide using -ml <modelfile>\n");
310
  return 1;
examples/talk-llama/unicode.cpp CHANGED
@@ -667,18 +667,24 @@ std::vector<std::string> unicode_regex_split(const std::string & text, const std
667
  { "\\p{N}", unicode_cpt_flags::NUMBER },
668
  { "\\p{L}", unicode_cpt_flags::LETTER },
669
  { "\\p{P}", unicode_cpt_flags::PUNCTUATION },
 
 
670
  };
671
 
672
  static const std::map<int, int> k_ucat_cpt = {
673
  { unicode_cpt_flags::NUMBER, 0xD1 },
674
  { unicode_cpt_flags::LETTER, 0xD2 },
675
  { unicode_cpt_flags::PUNCTUATION, 0xD3 },
 
 
676
  };
677
 
678
  static const std::map<int, std::string> k_ucat_map = {
679
  { unicode_cpt_flags::NUMBER, "\x30-\x39" }, // 0-9
680
  { unicode_cpt_flags::LETTER, "\x41-\x5A\x61-\x7A" }, // A-Za-z
681
  { unicode_cpt_flags::PUNCTUATION, "\x21-\x23\x25-\x2A\x2C-\x2F\x3A-\x3B\x3F-\x40\\\x5B-\\\x5D\x5F\\\x7B\\\x7D" }, // !-#%-*,-/:-;?-@\[-\]_\{\}
 
 
682
  };
683
 
684
  // compute collapsed codepoints only if needed by at least one regex
 
667
  { "\\p{N}", unicode_cpt_flags::NUMBER },
668
  { "\\p{L}", unicode_cpt_flags::LETTER },
669
  { "\\p{P}", unicode_cpt_flags::PUNCTUATION },
670
+ { "\\p{M}", unicode_cpt_flags::ACCENT_MARK },
671
+ { "\\p{S}", unicode_cpt_flags::SYMBOL },
672
  };
673
 
674
  static const std::map<int, int> k_ucat_cpt = {
675
  { unicode_cpt_flags::NUMBER, 0xD1 },
676
  { unicode_cpt_flags::LETTER, 0xD2 },
677
  { unicode_cpt_flags::PUNCTUATION, 0xD3 },
678
+ { unicode_cpt_flags::ACCENT_MARK, 0xD4 },
679
+ { unicode_cpt_flags::SYMBOL, 0xD5 },
680
  };
681
 
682
  static const std::map<int, std::string> k_ucat_map = {
683
  { unicode_cpt_flags::NUMBER, "\x30-\x39" }, // 0-9
684
  { unicode_cpt_flags::LETTER, "\x41-\x5A\x61-\x7A" }, // A-Za-z
685
  { unicode_cpt_flags::PUNCTUATION, "\x21-\x23\x25-\x2A\x2C-\x2F\x3A-\x3B\x3F-\x40\\\x5B-\\\x5D\x5F\\\x7B\\\x7D" }, // !-#%-*,-/:-;?-@\[-\]_\{\}
686
+ { unicode_cpt_flags::ACCENT_MARK, "" }, // no sub-128 codepoints
687
+ { unicode_cpt_flags::SYMBOL, "\\\x24\\\x2B\x3C-\x3E\x5E\x60\\\x7C" }, // $+<=>^`|
688
  };
689
 
690
  // compute collapsed codepoints only if needed by at least one regex
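The unicode.cpp additions teach the splitter about \p{M} (combining marks) and \p{S} (symbols), which the new DeepSeek-V3 pretokenizer regex relies on. A hedged usage sketch, assuming the unicode_regex_split(text, regex_exprs) signature suggested by the (truncated) hunk header; the input string and patterns are illustrative and the exact split depends on the regex backend:

    #include <cstdio>
    #include <string>
    #include <vector>

    #include "unicode.h"

    int main() {
        // number runs and symbol runs are split out separately; \p{S} now resolves via the new table entries
        const std::vector<std::string> regex_exprs = {
            "\\p{N}+",
            "\\p{S}+",
        };
        const std::string text = "price: 42$ + 7";
        for (const std::string & piece : unicode_regex_split(text, regex_exprs)) {
            std::printf("|%s", piece.c_str());
        }
        std::printf("|\n");
        return 0;
    }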