ggerganov committed · Commit 05d6d9c · 1 Parent(s): b12517c

talk-llama : sync llama.cpp

examples/talk-llama/CMakeLists.txt CHANGED
@@ -20,6 +20,7 @@ if (WHISPER_SDL2)
         llama-memory.cpp
         llama-mmap.cpp
         llama-model-loader.cpp
+        llama-model-saver.cpp
         llama-model.cpp
         llama-quant.cpp
         llama-sampling.cpp
examples/talk-llama/llama-adapter.cpp CHANGED
@@ -253,6 +253,9 @@ static void llama_adapter_lora_init_impl(llama_model & model, const char * path_
     std::vector<ggml_backend_buffer_type_t> buft_extra;
     {
         auto * cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU);
+        if (!cpu_dev) {
+            throw std::runtime_error(format("%s: no CPU backend found", __func__));
+        }
         auto * cpu_reg = ggml_backend_dev_backend_reg(cpu_dev);

         auto ggml_backend_dev_get_extra_bufts_fn = (ggml_backend_dev_get_extra_bufts_t)
@@ -291,6 +294,9 @@ static void llama_adapter_lora_init_impl(llama_model & model, const char * path_
                 LLAMA_LOG_WARN("%s: lora for '%s' cannot use buft '%s', fallback to CPU\n", __func__, model_tensor->name, ggml_backend_buft_name(buft));

                 auto * cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU);
+                if (!cpu_dev) {
+                    throw std::runtime_error(format("%s: no CPU backend found", __func__));
+                }
                 buft = ggml_backend_dev_buffer_type(cpu_dev);

                 break;
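Note (not part of the commit): the new guards follow the usual ggml device-lookup pattern, where ggml_backend_dev_by_type() may return null when the requested backend is unavailable, so the result must be checked before deriving a buffer type from it. A minimal standalone sketch of that same check, assuming only the public ggml-backend header:

    #include <stdexcept>
    #include "ggml-backend.h"

    // Defensive device lookup: query the CPU device and fail loudly if it is
    // missing instead of dereferencing a null device handle.
    static ggml_backend_buffer_type_t cpu_buft_or_throw() {
        ggml_backend_dev_t cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU);
        if (!cpu_dev) {
            throw std::runtime_error("no CPU backend found");
        }
        return ggml_backend_dev_buffer_type(cpu_dev);
    }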
examples/talk-llama/llama-batch.cpp CHANGED
@@ -189,7 +189,7 @@ llama_ubatch llama_sbatch::split_seq(size_t n_ubatch) {
     return ubatch;
 }

-void llama_sbatch::from_batch(const llama_batch & batch, size_t n_embd, bool simple_split, bool logits_all) {
+llama_sbatch::llama_sbatch(const llama_batch & batch, size_t n_embd, bool simple_split, bool logits_all) {
     GGML_ASSERT(batch.n_tokens >= 0);
     this->batch = &batch;
     this->n_embd = n_embd;
@@ -203,6 +203,7 @@ void llama_sbatch::from_batch(const llama_batch & batch, size_t n_embd, bool sim
     for (size_t i = 0; i < n_tokens; ++i) {
         ids[i] = i;
     }
+
     if (simple_split) {
         seq.resize(1);
         llama_sbatch_seq & s = seq[0];
@@ -212,6 +213,7 @@ void llama_sbatch::from_batch(const llama_batch & batch, size_t n_embd, bool sim
         s.length = n_tokens;
         return;
     }
+
     std::sort(ids.begin(), ids.end(),
         [&batch](size_t a, size_t b) {
             int32_t n_seq_a = batch.n_seq_id ? batch.n_seq_id[a] : 1;
@@ -239,6 +241,7 @@ void llama_sbatch::from_batch(const llama_batch & batch, size_t n_embd, bool sim
             return n_seq_a > n_seq_b;
         }
     );
+
     // init seq
     llama_sbatch_seq * last_seq = nullptr;

@@ -262,6 +265,7 @@ void llama_sbatch::from_batch(const llama_batch & batch, size_t n_embd, bool sim
         seq.push_back(new_seq);
         last_seq = &seq.back();
     }
+
     // keep shared prompts first at the end, then sort by length descending.
     std::sort(seq.begin(), seq.end(),
         [](llama_sbatch_seq & a, llama_sbatch_seq & b) {
examples/talk-llama/llama-batch.h CHANGED
@@ -70,7 +70,8 @@ struct llama_sbatch {
     // sequence-wise split
     llama_ubatch split_seq(size_t n_ubatch);

-    void from_batch(const llama_batch & batch, size_t n_embd, bool simple_split = false, bool logits_all = false);
+    llama_sbatch() = default;
+    llama_sbatch(const llama_batch & batch, size_t n_embd, bool simple_split = false, bool logits_all = false);
 };

 // temporary allocate memory for the input batch if needed
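Orientation note (not part of the commit): together with the llama-batch.cpp hunk above, the two-step from_batch() initialization is folded into a constructor, so call sites now build the sequence batch in one expression. A hypothetical call site, using the internal llama_sbatch type declared in this header:

    // Before this sync: default-construct, then initialize.
    //   llama_sbatch sbatch;
    //   sbatch.from_batch(batch, n_embd, /*simple_split=*/true, /*logits_all=*/true);
    //
    // After this sync: construct directly with the same arguments.
    llama_sbatch sbatch(batch, n_embd, /*simple_split=*/true, /*logits_all=*/true);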
examples/talk-llama/llama-chat.cpp CHANGED
@@ -35,6 +35,7 @@ static const std::map<std::string, llm_chat_template> LLM_CHAT_TEMPLATES = {
     { "mistral-v3",        LLM_CHAT_TEMPLATE_MISTRAL_V3 },
     { "mistral-v3-tekken", LLM_CHAT_TEMPLATE_MISTRAL_V3_TEKKEN },
     { "mistral-v7",        LLM_CHAT_TEMPLATE_MISTRAL_V7 },
+    { "mistral-v7-tekken", LLM_CHAT_TEMPLATE_MISTRAL_V7_TEKKEN },
     { "phi3",              LLM_CHAT_TEMPLATE_PHI_3 },
     { "phi4",              LLM_CHAT_TEMPLATE_PHI_4 },
     { "falcon3",           LLM_CHAT_TEMPLATE_FALCON_3 },
@@ -202,19 +203,20 @@ int32_t llm_chat_apply_template(
         if (add_ass) {
             ss << "<|im_start|>assistant\n";
         }
-    } else if (tmpl == LLM_CHAT_TEMPLATE_MISTRAL_V7) {
+    } else if (tmpl == LLM_CHAT_TEMPLATE_MISTRAL_V7 || tmpl == LLM_CHAT_TEMPLATE_MISTRAL_V7_TEKKEN) {
         // Official mistral 'v7' template
         // See: https://huggingface.co/mistralai/Mistral-Large-Instruct-2411#basic-instruct-template-v7
+        //      https://huggingface.co/mistralai/Mistral-Small-3.1-24B-Instruct-2503#basic-instruct-template-v7-tekken
+        const char * trailing_space = tmpl == LLM_CHAT_TEMPLATE_MISTRAL_V7 ? " " : "";
         for (auto message : chat) {
             std::string role(message->role);
             std::string content(message->content);
             if (role == "system") {
-                ss << "[SYSTEM_PROMPT] " << content << "[/SYSTEM_PROMPT]";
+                ss << "[SYSTEM_PROMPT]" << trailing_space << content << "[/SYSTEM_PROMPT]";
             } else if (role == "user") {
-                ss << "[INST] " << content << "[/INST]";
-            }
-            else {
-                ss << " " << content << "</s>";
+                ss << "[INST]" << trailing_space << content << "[/INST]";
+            } else {
+                ss << trailing_space << content << "</s>";
             }
         }
     } else if (tmpl == LLM_CHAT_TEMPLATE_MISTRAL_V1
@@ -447,8 +449,16 @@ int32_t llm_chat_apply_template(
         if (add_ass) {
             ss << "<|assistant|>";
         }
-    } else if (tmpl == LLM_CHAT_TEMPLATE_CHATGLM_4 || tmpl == LLM_CHAT_TEMPLATE_GLMEDGE) {
+    } else if (tmpl == LLM_CHAT_TEMPLATE_CHATGLM_4) {
         ss << "[gMASK]" << "<sop>";
+        for (auto message : chat) {
+            std::string role(message->role);
+            ss << "<|" << role << "|>" << "\n" << message->content;
+        }
+        if (add_ass) {
+            ss << "<|assistant|>\n";
+        }
+    } else if (tmpl == LLM_CHAT_TEMPLATE_GLMEDGE) {
         for (auto message : chat) {
             std::string role(message->role);
             ss << "<|" << role << "|>" << "\n" << message->content;
examples/talk-llama/llama-chat.h CHANGED
@@ -14,6 +14,7 @@ enum llm_chat_template {
     LLM_CHAT_TEMPLATE_MISTRAL_V3,
     LLM_CHAT_TEMPLATE_MISTRAL_V3_TEKKEN,
     LLM_CHAT_TEMPLATE_MISTRAL_V7,
+    LLM_CHAT_TEMPLATE_MISTRAL_V7_TEKKEN,
     LLM_CHAT_TEMPLATE_PHI_3,
     LLM_CHAT_TEMPLATE_PHI_4,
     LLM_CHAT_TEMPLATE_FALCON_3,
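Orientation note (not part of the commit): per the template code above, the only rendering difference between "mistral-v7" and the new "mistral-v7-tekken" is the space after the [SYSTEM_PROMPT]/[INST] markers and before the assistant reply. A hedged example through the public llama_chat_apply_template() API, whose current signature is assumed here, using the newly registered alias:

    #include <cstdio>
    #include <vector>
    #include "llama.h"

    int main() {
        // Two messages rendered with the "mistral-v7-tekken" template name.
        const llama_chat_message msgs[] = {
            { "system", "You are a helpful assistant." },
            { "user",   "Hello!" },
        };

        std::vector<char> buf(1024);
        const int32_t n = llama_chat_apply_template("mistral-v7-tekken", msgs, 2, /*add_ass=*/false,
                                                    buf.data(), (int32_t) buf.size());
        if (n > 0 && n <= (int32_t) buf.size()) {
            // Expected output (no space after the markers):
            // [SYSTEM_PROMPT]You are a helpful assistant.[/SYSTEM_PROMPT][INST]Hello![/INST]
            printf("%.*s\n", n, buf.data());
        }
        return 0;
    }

With "mistral-v7" instead, the same call would render "[SYSTEM_PROMPT] You are a helpful assistant.[/SYSTEM_PROMPT][INST] Hello![/INST]", i.e. with the extra space the tekken variant drops.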
examples/talk-llama/llama-context.cpp CHANGED
@@ -6,11 +6,9 @@
6
  #include "llama-model.h"
7
  #include "llama-kv-cache.h"
8
 
9
- #include <cassert>
10
  #include <cstring>
11
  #include <stdexcept>
12
  #include <cinttypes>
13
- #include <cmath>
14
 
15
  //
16
  // llama_context
@@ -95,6 +93,7 @@ llama_context::llama_context(
95
  }
96
 
97
  cparams.n_ubatch = std::min(cparams.n_batch, params.n_ubatch == 0 ? params.n_batch : params.n_ubatch);
 
98
 
99
  const uint32_t n_ctx_per_seq = cparams.n_ctx / cparams.n_seq_max;
100
 
@@ -118,8 +117,6 @@ llama_context::llama_context(
118
  __func__, n_ctx_per_seq, hparams.n_ctx_train);
119
  }
120
 
121
- logits_all = params.logits_all;
122
-
123
  if (!hparams.vocab_only) {
124
  // GPU backends
125
  for (auto * dev : model.devices) {
@@ -177,44 +174,13 @@ llama_context::llama_context(
177
  }
178
 
179
  // init the memory module
180
- // TODO: for now, always create a unified KV cache
181
  if (!hparams.vocab_only) {
182
- kv_self.reset(static_cast<llama_kv_cache_unified *>(model.create_memory()));
183
-
184
- LLAMA_LOG_DEBUG("%s: n_ctx = %u\n", __func__, cparams.n_ctx);
185
-
186
- cparams.n_ctx = GGML_PAD(cparams.n_ctx, kv_self->get_padding(cparams));
187
-
188
- LLAMA_LOG_DEBUG("%s: n_ctx = %u (padded)\n", __func__, cparams.n_ctx);
189
-
190
- uint32_t kv_size = cparams.n_ctx;
191
- ggml_type type_k = params.type_k;
192
- ggml_type type_v = params.type_v;
193
-
194
- if (llama_model_is_recurrent(&model)) {
195
- // Mamba needs at least as many KV cells as there are sequences kept at any time
196
- kv_size = std::max((uint32_t) 1, params.n_seq_max);
197
- // it's probably best to keep as much precision as possible for the states
198
- type_k = GGML_TYPE_F32; // required by ggml_ssm_conv for Mamba's conv_states
199
- type_v = GGML_TYPE_F32; // required by ggml_ssm_scan for Mamba's ssm_states
200
- }
201
-
202
- GGML_ASSERT(hparams.n_embd_head_k % ggml_blck_size(type_k) == 0);
203
- GGML_ASSERT(hparams.n_embd_head_v % ggml_blck_size(type_v) == 0);
204
-
205
- if (!kv_self->init(model, cparams, type_k, type_v, kv_size, cparams.offload_kqv)) {
206
- throw std::runtime_error("failed to initialize self-attention cache");
207
- }
208
 
209
- {
210
- const size_t memory_size_k = kv_self->size_k_bytes();
211
- const size_t memory_size_v = kv_self->size_v_bytes();
212
-
213
- LLAMA_LOG_INFO("%s: KV self size = %7.2f MiB, K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__,
214
- (float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f),
215
- ggml_type_name(type_k), (float)memory_size_k / (1024.0f * 1024.0f),
216
- ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f));
217
- }
218
  }
219
 
220
  // init backends
@@ -278,7 +244,7 @@ llama_context::llama_context(
278
  }
279
  }
280
 
281
- sched.reset(ggml_backend_sched_new(backend_ptrs.data(), backend_buft.data(), backend_ptrs.size(), max_nodes, pipeline_parallel));
282
 
283
  if (pipeline_parallel) {
284
  LLAMA_LOG_INFO("%s: pipeline parallelism enabled (n_copies=%d)\n", __func__, ggml_backend_sched_get_n_copies(sched.get()));
@@ -286,7 +252,7 @@ llama_context::llama_context(
286
  }
287
 
288
  // reserve worst-case graph
289
- if (!hparams.vocab_only) {
290
  const uint32_t n_seqs = 1; // TODO: worst-case number of sequences
291
  const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
292
 
@@ -305,7 +271,9 @@ llama_context::llama_context(
305
  int n_nodes_tg = -1;
306
 
307
  // simulate full KV cache
308
- kv_self->n = kv_self->size;
 
 
309
 
310
  cross.v_embd.clear();
311
 
@@ -391,7 +359,9 @@ llama_context::llama_context(
391
  }
392
  }
393
 
394
- llama_context::~llama_context() = default;
 
 
395
 
396
  void llama_context::synchronize() {
397
  ggml_backend_sched_synchronize(sched.get());
@@ -427,6 +397,18 @@ const llama_model & llama_context::get_model() const {
427
  return model;
428
  }
429
 
430
  uint32_t llama_context::n_ctx() const {
431
  return cparams.n_ctx;
432
  }
@@ -456,337 +438,21 @@ uint32_t llama_context::n_threads_batch() const {
456
  }
457
 
458
  llama_kv_cache * llama_context::get_kv_self() {
459
- return kv_self.get();
 
460
  }
461
 
462
  const llama_kv_cache * llama_context::get_kv_self() const {
463
- return kv_self.get();
464
- }
465
-
466
- ggml_tensor * llama_context::build_rope_shift(
467
- ggml_context * ctx0,
468
- ggml_tensor * cur,
469
- ggml_tensor * shift,
470
- ggml_tensor * factors,
471
- float freq_base,
472
- float freq_scale) const {
473
- const auto & n_ctx_orig = cparams.n_ctx_orig_yarn;
474
-
475
- const auto & yarn_ext_factor = cparams.yarn_ext_factor;
476
- const auto & yarn_beta_fast = cparams.yarn_beta_fast;
477
- const auto & yarn_beta_slow = cparams.yarn_beta_slow;
478
-
479
- const auto & hparams = model.hparams;
480
-
481
- const auto & n_rot = hparams.n_rot;
482
- const auto & rope_type = hparams.rope_type;
483
-
484
- // See llm_build_deepseek2() for why attn_factor has to be scaled for YaRN RoPE to work correctly.
485
- // See https://github.com/ggerganov/llama.cpp/discussions/7416 for detailed explanation.
486
- const float yarn_attn_factor = model.arch == LLM_ARCH_DEEPSEEK2 ? 1.0f / (1.0f + 0.1f * logf(1.0f / freq_scale)) : cparams.yarn_attn_factor;
487
-
488
- ggml_tensor * tmp;
489
-
490
- if (ggml_is_quantized(cur->type)) {
491
- // dequantize to f32 -> RoPE -> quantize back
492
- tmp = ggml_cast(ctx0, cur, GGML_TYPE_F32);
493
-
494
- tmp = ggml_rope_ext(ctx0, tmp,
495
- shift, factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
496
- yarn_ext_factor, yarn_attn_factor, yarn_beta_fast, yarn_beta_slow);
497
-
498
- tmp = ggml_cpy(ctx0, tmp, cur);
499
- } else {
500
- // we rotate only the first n_rot dimensions
501
- tmp = ggml_rope_ext_inplace(ctx0, cur,
502
- shift, factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
503
- yarn_ext_factor, yarn_attn_factor, yarn_beta_fast, yarn_beta_slow);
504
- }
505
-
506
- return tmp;
507
- }
508
-
509
- class llm_graph_input_k_shift : public llm_graph_input_i {
510
- public:
511
- llm_graph_input_k_shift(const llama_kv_cache_unified * kv_self) : kv_self(kv_self) {}
512
- virtual ~llm_graph_input_k_shift() = default;
513
-
514
- void set_input(const llama_ubatch * ubatch) override;
515
-
516
- ggml_tensor * k_shift; // I32 [kv_size]
517
-
518
- const llama_kv_cache_unified * kv_self;
519
- };
520
-
521
- void llm_graph_input_k_shift::set_input(const llama_ubatch * ubatch) {
522
- GGML_UNUSED(ubatch);
523
-
524
- if (k_shift) {
525
- assert(ggml_backend_buffer_is_host(k_shift->buffer));
526
-
527
- int32_t * data = (int32_t *) k_shift->data;
528
-
529
- for (uint32_t i = 0; i < kv_self->size; ++i) {
530
- data[i] = kv_self->cells[i].delta;
531
- }
532
- }
533
- }
534
-
535
- llm_graph_result_ptr llama_context::build_kv_self_shift(
536
- ggml_context * ctx0,
537
- ggml_cgraph * gf) const {
538
- auto res = std::make_unique<llm_graph_result>();
539
-
540
- const auto & hparams = model.hparams;
541
-
542
- const auto & n_layer = hparams.n_layer;
543
-
544
- const auto & n_embd_head_k = hparams.n_embd_head_k;
545
- //const auto & n_embd_head_v = hparams.n_embd_head_v;
546
-
547
- //GGML_ASSERT(kv_self->size == n_ctx);
548
-
549
- auto inp = std::make_unique<llm_graph_input_k_shift>(kv_self.get());
550
-
551
- inp->k_shift = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, cparams.n_ctx);
552
- ggml_set_input(inp->k_shift);
553
-
554
- for (uint32_t il = 0; il < n_layer; ++il) {
555
- const int64_t n_head_kv = hparams.n_head_kv(il);
556
- const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
557
-
558
- const bool is_swa = hparams.is_swa(il);
559
-
560
- // note: the swa rope params could become part of the cparams in the future
561
- // if we decide to make them configurable, like the non-sliding ones
562
- const float freq_base_l = is_swa ? hparams.rope_freq_base_train_swa : cparams.rope_freq_base;
563
- const float freq_scale_l = is_swa ? hparams.rope_freq_scale_train_swa : cparams.rope_freq_scale;
564
-
565
- ggml_tensor * rope_factors = kv_self->cbs.get_rope_factors(n_ctx_per_seq(), il);
566
-
567
- ggml_tensor * k =
568
- ggml_view_3d(ctx0, kv_self->k_l[il],
569
- n_embd_head_k, n_head_kv, kv_self->size,
570
- ggml_row_size(kv_self->k_l[il]->type, n_embd_head_k),
571
- ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa),
572
- 0);
573
-
574
- ggml_tensor * cur = build_rope_shift(ctx0, k, inp->k_shift, rope_factors, freq_base_l, freq_scale_l);
575
-
576
- ggml_build_forward_expand(gf, cur);
577
- }
578
-
579
- res->add_input(std::move(inp));
580
-
581
- return res;
582
- }
583
-
584
- llm_graph_result_ptr llama_context::build_kv_self_defrag(
585
- ggml_context * ctx0,
586
- ggml_cgraph * gf) const {
587
- auto res = std::make_unique<llm_graph_result>();
588
-
589
- const auto & hparams = model.hparams;
590
-
591
- const auto & ids = kv_self->defrag_info.ids;
592
-
593
- #if 0
594
- // CPU defrag
595
- //
596
- // TODO: optimizations are possible:
597
- // - multiple threads
598
- // - avoid copying to the host memory when already there
599
- //
600
- // likely not worth the effort, as we have ggml_graph based defrag
601
- //
602
-
603
- const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa();
604
- const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa();
605
-
606
- const uint32_t kv_size = size;
607
-
608
- std::vector<uint8_t> buf_k;
609
- std::vector<uint8_t> buf_v;
610
-
611
- for (uint32_t il = 0; il < n_layer; ++il) {
612
- const size_t k_size_row = ggml_row_size(k_l[il]->type, n_embd_k_gqa);
613
- const size_t k_size = ggml_row_size(k_l[il]->type, n_embd_k_gqa*kv_size);
614
-
615
- const size_t v_size_el = ggml_type_size(v_l[il]->type);
616
- const size_t v_size = ggml_row_size (v_l[il]->type, n_embd_v_gqa*kv_size);
617
-
618
- buf_k.resize(k_size);
619
- buf_v.resize(v_size);
620
-
621
- ggml_backend_tensor_get(k_l[il], buf_k.data(), 0, buf_k.size());
622
- ggml_backend_tensor_get(v_l[il], buf_v.data(), 0, buf_v.size());
623
-
624
- // batch move [i, i+nm) to [id, id+nm)
625
- // note: cells can move only to a lower index
626
- for (uint32_t i = 0; i < n_kv; ++i) {
627
- const uint32_t id = ids[i];
628
-
629
- if (i == id || id == n_kv) {
630
- continue;
631
- }
632
-
633
- uint32_t nm = 1;
634
-
635
- while (i + nm < n_kv && ids[i + nm] == id + nm) {
636
- nm++;
637
- }
638
-
639
- // move keys
640
- {
641
- const int64_t os = i*k_size_row;
642
- const int64_t od = id*k_size_row;
643
-
644
- memcpy(buf_k.data() + od, buf_k.data() + os, nm*k_size_row);
645
- }
646
-
647
- // move values (note: they are transposed)
648
- {
649
- const int64_t os = i;
650
- const int64_t od = id;
651
-
652
- for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
653
- memcpy(buf_v.data() + (od + j*kv_size)*v_size_el, buf_v.data() + (os + j*kv_size)*v_size_el, nm*v_size_el);
654
- }
655
- }
656
-
657
- i += nm - 1;
658
- }
659
-
660
- ggml_backend_tensor_set(k_l[il], buf_k.data(), 0, buf_k.size());
661
- ggml_backend_tensor_set(v_l[il], buf_v.data(), 0, buf_v.size());
662
- }
663
- #else
664
- for (uint32_t i = 0; i < ids.size(); ++i) {
665
- const uint32_t id = ids[i];
666
-
667
- if (i == id || id == ids.size()) {
668
- continue;
669
- }
670
-
671
- uint32_t nm = 1;
672
-
673
- while (i + nm < ids.size() && ids[i + nm] == id + nm) {
674
- nm++;
675
- }
676
-
677
- for (uint32_t il = 0; il < hparams.n_layer; ++il) { // NOLINT
678
- const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
679
- const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
680
-
681
- ggml_tensor * view_k_src = ggml_view_2d(ctx0, kv_self->k_l[il],
682
- n_embd_k_gqa, nm,
683
- ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa),
684
- ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa*i));
685
-
686
- ggml_tensor * view_k_dst = ggml_view_2d(ctx0, kv_self->k_l[il],
687
- n_embd_k_gqa, nm,
688
- ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa),
689
- ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa*id));
690
-
691
- ggml_tensor * view_v_src;
692
- ggml_tensor * view_v_dst;
693
-
694
- if (cparams.flash_attn) {
695
- // NOTE: the V cache is not transposed when using flash attention
696
- view_v_src = ggml_view_2d(ctx0, kv_self->v_l[il],
697
- n_embd_v_gqa, nm,
698
- ggml_row_size(kv_self->v_l[il]->type, n_embd_v_gqa),
699
- ggml_row_size(kv_self->v_l[il]->type, n_embd_v_gqa*i));
700
-
701
- view_v_dst = ggml_view_2d(ctx0, kv_self->v_l[il],
702
- n_embd_v_gqa, nm,
703
- ggml_row_size(kv_self->v_l[il]->type, n_embd_v_gqa),
704
- ggml_row_size(kv_self->v_l[il]->type, n_embd_v_gqa*id));
705
- } else {
706
- view_v_src = ggml_view_2d(ctx0, kv_self->v_l[il],
707
- nm, n_embd_v_gqa,
708
- ggml_row_size(kv_self->v_l[il]->type, kv_self->size),
709
- ggml_row_size(kv_self->v_l[il]->type, i));
710
-
711
- view_v_dst = ggml_view_2d(ctx0, kv_self->v_l[il],
712
- nm, n_embd_v_gqa,
713
- ggml_row_size(kv_self->v_l[il]->type, kv_self->size),
714
- ggml_row_size(kv_self->v_l[il]->type, id));
715
- }
716
-
717
- ggml_build_forward_expand(gf, ggml_cpy(ctx0, view_k_src, view_k_dst));
718
- ggml_build_forward_expand(gf, ggml_cpy(ctx0, view_v_src, view_v_dst));
719
- }
720
-
721
- i += nm - 1;
722
- }
723
-
724
- //LLAMA_LOG_INFO("gf->n_nodes = %d\n", gf->n_nodes);
725
- #endif
726
-
727
- return res;
728
  }
729
 
730
  void llama_context::kv_self_update() {
731
- auto & kv = kv_self;
732
-
733
  bool need_reserve = false;
734
 
735
- if (kv->has_shift) {
736
- if (!kv->get_can_shift()) {
737
- GGML_ABORT("The current context does not support K-shift");
738
- }
739
-
740
- LLAMA_LOG_DEBUG("%s: applying K-shift\n", __func__);
741
-
742
- // apply K-shift if needed
743
- if (model.hparams.rope_type != LLAMA_ROPE_TYPE_NONE) {
744
- ggml_backend_sched_reset(sched.get());
745
-
746
- auto * gf = graph_init();
747
-
748
- auto res = build_kv_self_shift(ctx_compute.get(), gf);
749
-
750
- ggml_backend_sched_alloc_graph(sched.get(), gf);
751
-
752
- res->set_inputs(nullptr);
753
-
754
- graph_compute(gf, false);
755
 
756
- need_reserve = true;
757
- }
758
-
759
- {
760
- kv->has_shift = false;
761
-
762
- for (uint32_t i = 0; i < kv->size; ++i) {
763
- kv->cells[i].delta = 0;
764
- }
765
- }
766
- }
767
-
768
- // defragment the KV cache if needed
769
- if (kv->do_defrag) {
770
- LLAMA_LOG_DEBUG("%s: defragmenting KV cache\n", __func__);
771
-
772
- if (kv->defrag_prepare(graph_max_nodes())) {
773
- ggml_backend_sched_reset(sched.get());
774
-
775
- auto * gf = graph_init();
776
-
777
- auto res = build_kv_self_defrag(ctx_compute.get(), gf);
778
-
779
- ggml_backend_sched_alloc_graph(sched.get(), gf);
780
-
781
- res->set_inputs(nullptr);
782
-
783
- graph_compute(gf, false);
784
-
785
- need_reserve = true;
786
- }
787
-
788
- kv->do_defrag = false;
789
- }
790
 
791
  // reserve a worst case graph if needed
792
  if (need_reserve) {
@@ -797,7 +463,7 @@ void llama_context::kv_self_update() {
797
  uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
798
 
799
  // simulate full KV cache
800
- kv_self->n = kv_self->size;
801
 
802
  llama_token token = model.vocab.token_bos(); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph
803
  llama_ubatch ubatch = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
@@ -818,9 +484,6 @@ enum llama_pooling_type llama_context::pooling_type() const {
818
  }
819
 
820
  float * llama_context::get_logits() {
821
- // reorder logits for backward compatibility
822
- output_reorder();
823
-
824
  return logits;
825
  }
826
 
@@ -863,9 +526,6 @@ float * llama_context::get_logits_ith(int32_t i) {
863
  }
864
 
865
  float * llama_context::get_embeddings() {
866
- // reorder embeddings for backward compatibility
867
- output_reorder();
868
-
869
  return embd;
870
  }
871
 
@@ -1017,8 +677,8 @@ int llama_context::encode(llama_batch & inp_batch) {
1017
  }
1018
 
1019
  // temporary allocate memory for the input batch if needed
1020
- // TODO: this is incorrect for multiple sequences because pos_max() is the maximum across all sequences
1021
- llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : kv_self->pos_max() + 1);
1022
 
1023
  const llama_batch & batch = batch_allocr.batch;
1024
  const int32_t n_tokens = batch.n_tokens;
@@ -1043,11 +703,13 @@ int llama_context::encode(llama_batch & inp_batch) {
1043
  t_compute_start_us = ggml_time_us();
1044
  }
1045
 
 
 
1046
  n_queued_tokens += n_tokens;
1047
 
1048
  const int64_t n_embd = hparams.n_embd;
1049
 
1050
- sbatch.from_batch(batch, n_embd, /* simple_split */ true, /* logits_all */ true);
1051
 
1052
  const llama_ubatch ubatch = sbatch.split_simple(n_tokens);
1053
 
@@ -1104,12 +766,12 @@ int llama_context::encode(llama_batch & inp_batch) {
1104
  ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(sched.get(), t_embd);
1105
  GGML_ASSERT(backend_embd != nullptr);
1106
 
1107
- GGML_ASSERT(embd != nullptr);
1108
-
1109
  switch (cparams.pooling_type) {
1110
  case LLAMA_POOLING_TYPE_NONE:
1111
  {
1112
  // extract token embeddings
 
 
1113
  GGML_ASSERT(n_tokens*n_embd <= (int64_t) embd_size);
1114
  ggml_backend_tensor_get_async(backend_embd, t_embd, embd, 0, n_tokens*n_embd*sizeof(float));
1115
  } break;
@@ -1134,11 +796,18 @@ int llama_context::encode(llama_batch & inp_batch) {
1134
  } break;
1135
  case LLAMA_POOLING_TYPE_RANK:
1136
  {
1137
- // TODO: this likely should be the same logic as in llama_decoder_internal, but better to
1138
- // wait for an encoder model that requires this pooling type in order to test it
1139
- // https://github.com/ggerganov/llama.cpp/pull/9510
1140
- GGML_ABORT("RANK pooling not implemented yet");
1141
- }
 
 
 
 
 
 
 
1142
  case LLAMA_POOLING_TYPE_UNSPECIFIED:
1143
  {
1144
  GGML_ABORT("unknown pooling type");
@@ -1176,14 +845,21 @@ int llama_context::encode(llama_batch & inp_batch) {
1176
  }
1177
 
1178
  int llama_context::decode(llama_batch & inp_batch) {
 
 
 
 
 
1179
  if (inp_batch.n_tokens == 0) {
1180
  LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__);
1181
  return -1;
1182
  }
1183
 
 
 
1184
  // temporary allocate memory for the input batch if needed
1185
- // TODO: this is incorrect for multiple sequences because pos_max() is the maximum across all sequences
1186
- llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : kv_self->pos_max() + 1);
1187
 
1188
  const llama_batch & batch = batch_allocr.batch;
1189
 
@@ -1195,7 +871,7 @@ int llama_context::decode(llama_batch & inp_batch) {
1195
  const int64_t n_tokens_all = batch.n_tokens;
1196
  const int64_t n_embd = hparams.n_embd;
1197
 
1198
- llama_kv_cache_guard kv_guard(kv_self.get());
1199
 
1200
  GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
1201
 
@@ -1229,18 +905,14 @@ int llama_context::decode(llama_batch & inp_batch) {
1229
  for (uint32_t i = 0; i < n_tokens_all; ++i) {
1230
  n_outputs_all += batch.logits[i] != 0;
1231
  }
1232
- } else if (logits_all || embd_pooled) {
1233
  n_outputs_all = n_tokens_all;
1234
  } else {
1235
  // keep last output only
1236
  n_outputs_all = 1;
1237
  }
1238
 
1239
- const bool logits_all = n_outputs_all == n_tokens_all;
1240
-
1241
- sbatch.from_batch(batch, n_embd,
1242
- /* simple_split */ !kv_self->recurrent,
1243
- /* logits_all */ logits_all);
1244
 
1245
  // reserve output buffer
1246
  if (output_reserve(n_outputs_all) < n_outputs_all) {
@@ -1254,22 +926,7 @@ int llama_context::decode(llama_batch & inp_batch) {
1254
  int64_t n_outputs_prev = 0;
1255
 
1256
  while (sbatch.n_tokens > 0) {
1257
- llama_ubatch ubatch = llama_ubatch();
1258
-
1259
- const auto & n_ubatch = cparams.n_ubatch;
1260
-
1261
- if (kv_self->recurrent) {
1262
- if (embd_pooled) {
1263
- // Pooled embeddings cannot be split across ubatches (yet)
1264
- ubatch = sbatch.split_seq(cparams.n_ubatch);
1265
- } else {
1266
- // recurrent model architectures are easier to implement
1267
- // with equal-length sequences
1268
- ubatch = sbatch.split_equal(cparams.n_ubatch);
1269
- }
1270
- } else {
1271
- ubatch = sbatch.split_simple(n_ubatch);
1272
- }
1273
 
1274
  // count the outputs in this u_batch
1275
  {
@@ -1289,24 +946,12 @@ int llama_context::decode(llama_batch & inp_batch) {
1289
  }
1290
 
1291
  // find KV slot
1292
- {
1293
- if (!kv_self->find_slot(ubatch)) {
1294
- LLAMA_LOG_WARN("%s: failed to find KV cache slot for ubatch of size %d\n", __func__, ubatch.n_tokens);
1295
-
1296
- return 1;
1297
- }
1298
 
1299
- if (!kv_self->recurrent) {
1300
- // a heuristic, to avoid attending the full cache if it is not yet utilized
1301
- // after enough generations, the benefit from this heuristic disappears
1302
- // if we start defragmenting the cache, the benefit from this will be more important
1303
- const uint32_t pad = kv_self->get_padding(cparams);
1304
- kv_self->n = std::min(kv_self->size, std::max(pad, GGML_PAD(kv_self->cell_max(), pad)));
1305
- }
1306
  }
1307
 
1308
- //printf("kv_self.n = %5d, kv_self.used = %5d, kv_self.head = %5d\n", kv_self->n, kv_self->used, kv_self->head);
1309
-
1310
  ggml_backend_sched_reset(sched.get());
1311
  ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);
1312
 
@@ -1420,43 +1065,68 @@ int llama_context::decode(llama_batch & inp_batch) {
1420
  // finalize the batch processing
1421
  kv_guard.commit();
1422
 
 
 
 
1423
  // set output mappings
1424
  {
1425
  bool sorted_output = true;
1426
 
1427
- GGML_ASSERT(sbatch.out_ids.size() == (size_t) n_outputs_all);
 
 
1428
 
1429
  for (int64_t i = 0; i < n_outputs_all; ++i) {
1430
- int64_t out_id = sbatch.out_ids[i];
1431
  output_ids[out_id] = i;
1432
  if (out_id != i) {
1433
  sorted_output = false;
1434
  }
1435
  }
1436
 
1437
- if (sorted_output) {
1438
- sbatch.out_ids.clear();
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1439
  }
1440
  }
1441
 
1442
- // set to total number of outputs in the batch, for use in llama_get_logits_ith
1443
- n_outputs = n_outputs_all;
1444
-
1445
  // wait for the computation to finish (automatically done when obtaining the model output)
1446
  //synchronize();
1447
 
1448
  // decide if we need to defrag the kv cache
1449
- if (cparams.causal_attn && cparams.defrag_thold > 0.0f) {
1450
- // - do not defrag small contexts (i.e. < 2048 tokens)
1451
- // - count the padding towards the number of used tokens
1452
- const float fragmentation = kv_self->n >= 2048 ? std::max(0.0f, 1.0f - float(kv_self->used + kv_self->get_padding(cparams))/float(kv_self->n)) : 0.0f;
1453
-
1454
- // queue defragmentation for next llama_kv_cache_update
1455
- if (fragmentation > cparams.defrag_thold) {
1456
- LLAMA_LOG_DEBUG("%s: fragmentation: %.2f - requesting defrag\n", __func__, fragmentation);
1457
-
1458
- kv_self->defrag();
1459
- }
1460
  }
1461
 
1462
  // Reset state for the next token before backend sync, to allow the CPU activities in the reset to
@@ -1542,44 +1212,6 @@ int32_t llama_context::output_reserve(int32_t n_outputs) {
1542
  return n_outputs_max;
1543
  }
1544
 
1545
- void llama_context::output_reorder() {
1546
- auto & out_ids = sbatch.out_ids;
1547
- if (!out_ids.empty()) {
1548
- const uint32_t n_vocab = model.vocab.n_tokens();
1549
- const uint32_t n_embd = model.hparams.n_embd;
1550
-
1551
- GGML_ASSERT((size_t) n_outputs == out_ids.size());
1552
-
1553
- // TODO: is there something more efficient which also minimizes swaps?
1554
- // selection sort, to minimize swaps (from https://en.wikipedia.org/wiki/Selection_sort)
1555
- for (int32_t i = 0; i < n_outputs - 1; ++i) {
1556
- int32_t j_min = i;
1557
- for (int32_t j = i + 1; j < n_outputs; ++j) {
1558
- if (out_ids[j] < out_ids[j_min]) {
1559
- j_min = j;
1560
- }
1561
- }
1562
- if (j_min == i) { continue; }
1563
- std::swap(out_ids[i], out_ids[j_min]);
1564
- if (logits_size > 0) {
1565
- for (uint32_t k = 0; k < n_vocab; k++) {
1566
- std::swap(logits[i*n_vocab + k], logits[j_min*n_vocab + k]);
1567
- }
1568
- }
1569
- if (embd_size > 0) {
1570
- for (uint32_t k = 0; k < n_embd; k++) {
1571
- std::swap(embd[i*n_embd + k], embd[j_min*n_embd + k]);
1572
- }
1573
- }
1574
- }
1575
- std::fill(output_ids.begin(), output_ids.end(), -1);
1576
- for (int32_t i = 0; i < n_outputs; ++i) {
1577
- output_ids[out_ids[i]] = i;
1578
- }
1579
- out_ids.clear();
1580
- }
1581
- }
1582
-
1583
  //
1584
  // graph
1585
  //
@@ -1616,7 +1248,7 @@ llm_graph_result_ptr llama_context::graph_build(
1616
  /*.backend_cpu =*/ backend_cpu,
1617
  /*.cvec =*/ &cvec,
1618
  /*.loras =*/ &loras,
1619
- /*.memory =*/ kv_self.get(),
1620
  /*.cross =*/ &cross,
1621
  /*.n_outputs =*/ n_outputs,
1622
  /*.cb =*/ graph_get_cb(),
@@ -2020,8 +1652,6 @@ size_t llama_context::state_write_data(llama_io_write_i & io) {
2020
  {
2021
  LLAMA_LOG_DEBUG("%s: - writing output ids\n", __func__);
2022
 
2023
- output_reorder();
2024
-
2025
  const auto n_outputs = this->n_outputs;
2026
  const auto & output_ids = this->output_ids;
2027
 
@@ -2075,6 +1705,8 @@ size_t llama_context::state_write_data(llama_io_write_i & io) {
2075
  }
2076
 
2077
  LLAMA_LOG_DEBUG("%s: - writing KV self\n", __func__);
 
 
2078
  kv_self->state_write(io);
2079
 
2080
  return io.n_bytes();
@@ -2158,8 +1790,13 @@ size_t llama_context::state_read_data(llama_io_read_i & io) {
2158
  }
2159
  }
2160
 
2161
- LLAMA_LOG_DEBUG("%s: - reading KV self\n", __func__);
2162
- kv_self->state_read(io);
 
 
 
 
 
2163
 
2164
  return io.n_bytes();
2165
  }
@@ -2167,7 +1804,11 @@ size_t llama_context::state_read_data(llama_io_read_i & io) {
2167
  size_t llama_context::state_seq_write_data(llama_io_write_i & io, llama_seq_id seq_id) {
2168
  GGML_UNUSED(seq_id);
2169
 
2170
- kv_self->state_write(io, seq_id);
 
 
 
 
2171
 
2172
  return io.n_bytes();
2173
  }
@@ -2175,7 +1816,11 @@ size_t llama_context::state_seq_write_data(llama_io_write_i & io, llama_seq_id s
2175
  size_t llama_context::state_seq_read_data(llama_io_read_i & io, llama_seq_id seq_id) {
2176
  GGML_UNUSED(seq_id);
2177
 
2178
- kv_self->state_read(io, seq_id);
 
 
 
 
2179
 
2180
  return io.n_bytes();
2181
  }
@@ -2203,6 +1848,215 @@ void llama_context::perf_reset() {
2203
  t_p_eval_us = n_p_eval = 0;
2204
  }
2205
 
2206
  //
2207
  // interface implementation
2208
  //
@@ -2230,13 +2084,13 @@ llama_context_params llama_context_default_params() {
2230
  /*.cb_eval_user_data =*/ nullptr,
2231
  /*.type_k =*/ GGML_TYPE_F16,
2232
  /*.type_v =*/ GGML_TYPE_F16,
2233
- /*.logits_all =*/ false,
 
2234
  /*.embeddings =*/ false,
2235
  /*.offload_kqv =*/ true,
2236
  /*.flash_attn =*/ false,
2237
  /*.no_perf =*/ true,
2238
- /*.abort_callback =*/ nullptr,
2239
- /*.abort_callback_data =*/ nullptr,
2240
  };
2241
 
2242
  return result;
@@ -2530,7 +2384,7 @@ void llama_kv_cache_seq_cp(
2530
  llama_seq_id seq_id_dst,
2531
  llama_pos p0,
2532
  llama_pos p1) {
2533
- return llama_kv_self_seq_cp(ctx, seq_id_src, seq_id_dst, p0, p1);
2534
  }
2535
 
2536
  void llama_kv_self_seq_cp(
@@ -2544,14 +2398,14 @@ void llama_kv_self_seq_cp(
2544
  return;
2545
  }
2546
 
2547
- return kv->seq_cp(seq_id_src, seq_id_dst, p0, p1);
2548
  }
2549
 
2550
  // deprecated
2551
  void llama_kv_cache_seq_keep(
2552
  llama_context * ctx,
2553
  llama_seq_id seq_id) {
2554
- return llama_kv_self_seq_keep(ctx, seq_id);
2555
  }
2556
 
2557
  void llama_kv_self_seq_keep(llama_context * ctx, llama_seq_id seq_id) {
@@ -2560,7 +2414,7 @@ void llama_kv_self_seq_keep(llama_context * ctx, llama_seq_id seq_id) {
2560
  return;
2561
  }
2562
 
2563
- return kv->seq_keep(seq_id);
2564
  }
2565
 
2566
  // deprecated
@@ -2570,7 +2424,7 @@ void llama_kv_cache_seq_add(
2570
  llama_pos p0,
2571
  llama_pos p1,
2572
  llama_pos delta) {
2573
- return llama_kv_self_seq_add(ctx, seq_id, p0, p1, delta);
2574
  }
2575
 
2576
  void llama_kv_self_seq_add(
@@ -2584,7 +2438,7 @@ void llama_kv_self_seq_add(
2584
  return;
2585
  }
2586
 
2587
- return kv->seq_add(seq_id, p0, p1, delta);
2588
  }
2589
 
2590
  // deprecated
@@ -2594,7 +2448,7 @@ void llama_kv_cache_seq_div(
2594
  llama_pos p0,
2595
  llama_pos p1,
2596
  int d) {
2597
- return llama_kv_self_seq_div(ctx, seq_id, p0, p1, d);
2598
  }
2599
 
2600
  void llama_kv_self_seq_div(
@@ -2608,7 +2462,7 @@ void llama_kv_self_seq_div(
2608
  return;
2609
  }
2610
 
2611
- return kv->seq_div(seq_id, p0, p1, d);
2612
  }
2613
 
2614
  // deprecated
@@ -2627,7 +2481,7 @@ llama_pos llama_kv_self_seq_pos_max(llama_context * ctx, llama_seq_id seq_id) {
2627
 
2628
  // deprecated
2629
  void llama_kv_cache_defrag(llama_context * ctx) {
2630
- return llama_kv_self_defrag(ctx);
2631
  }
2632
 
2633
  void llama_kv_self_defrag(llama_context * ctx) {
@@ -2636,7 +2490,8 @@ void llama_kv_self_defrag(llama_context * ctx) {
2636
  return;
2637
  }
2638
 
2639
- return kv->defrag();
 
2640
  }
2641
 
2642
  // deprecated
@@ -2820,3 +2675,34 @@ void llama_perf_context_print(const llama_context * ctx) {
2820
  void llama_perf_context_reset(llama_context * ctx) {
2821
  ctx->perf_reset();
2822
  }

6
  #include "llama-model.h"
7
  #include "llama-kv-cache.h"
8
 
 
9
  #include <cstring>
10
  #include <stdexcept>
11
  #include <cinttypes>
 
12
 
13
  //
14
  // llama_context
 
93
  }
94
 
95
  cparams.n_ubatch = std::min(cparams.n_batch, params.n_ubatch == 0 ? params.n_batch : params.n_ubatch);
96
+ cparams.op_offload = params.op_offload;
97
 
98
  const uint32_t n_ctx_per_seq = cparams.n_ctx / cparams.n_seq_max;
99
 
 
117
  __func__, n_ctx_per_seq, hparams.n_ctx_train);
118
  }
119
 
 
 
120
  if (!hparams.vocab_only) {
121
  // GPU backends
122
  for (auto * dev : model.devices) {
 
174
  }
175
 
176
  // init the memory module
 
177
  if (!hparams.vocab_only) {
178
+ llama_memory_params params_mem = {
179
+ /*.type_k =*/ params.type_k,
180
+ /*.type_v =*/ params.type_v,
181
+ };
182
 
183
+ memory.reset(model.create_memory(params_mem, cparams));
184
  }
185
 
186
  // init backends
 
244
  }
245
  }
246
 
247
+ sched.reset(ggml_backend_sched_new(backend_ptrs.data(), backend_buft.data(), backend_ptrs.size(), max_nodes, pipeline_parallel, cparams.op_offload));
248
 
249
  if (pipeline_parallel) {
250
  LLAMA_LOG_INFO("%s: pipeline parallelism enabled (n_copies=%d)\n", __func__, ggml_backend_sched_get_n_copies(sched.get()));
 
252
  }
253
 
254
  // reserve worst-case graph
255
+ if (!hparams.vocab_only && memory) {
256
  const uint32_t n_seqs = 1; // TODO: worst-case number of sequences
257
  const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
258
 
 
271
  int n_nodes_tg = -1;
272
 
273
  // simulate full KV cache
274
+ llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
275
+
276
+ kv_self->set_full();
277
 
278
  cross.v_embd.clear();
279
 
 
359
  }
360
  }
361
 
362
+ llama_context::~llama_context() {
363
+ ggml_opt_free(opt_ctx);
364
+ }
365
 
366
  void llama_context::synchronize() {
367
  ggml_backend_sched_synchronize(sched.get());
 
397
  return model;
398
  }
399
 
400
+ const llama_cparams & llama_context::get_cparams() const {
401
+ return cparams;
402
+ }
403
+
404
+ ggml_backend_sched_t llama_context::get_sched() const {
405
+ return sched.get();
406
+ }
407
+
408
+ ggml_context * llama_context::get_ctx_compute() const {
409
+ return ctx_compute.get();
410
+ }
411
+
412
  uint32_t llama_context::n_ctx() const {
413
  return cparams.n_ctx;
414
  }
 
438
  }
439
 
440
  llama_kv_cache * llama_context::get_kv_self() {
441
+ llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
442
+ return kv_self;
443
  }
444
 
445
  const llama_kv_cache * llama_context::get_kv_self() const {
446
+ llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
447
+ return kv_self;
448
  }
449
 
450
  void llama_context::kv_self_update() {
 
 
451
  bool need_reserve = false;
452
 
453
+ llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
454
 
455
+ need_reserve = kv_self->update(*this);
456
 
457
  // reserve a worst case graph if needed
458
  if (need_reserve) {
 
463
  uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
464
 
465
  // simulate full KV cache
466
+ kv_self->set_full();
467
 
468
  llama_token token = model.vocab.token_bos(); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph
469
  llama_ubatch ubatch = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
 
484
  }
485
 
486
  float * llama_context::get_logits() {
 
 
 
487
  return logits;
488
  }
489
 
 
526
  }
527
 
528
  float * llama_context::get_embeddings() {
 
 
 
529
  return embd;
530
  }
531
 
 
677
  }
678
 
679
  // temporary allocate memory for the input batch if needed
680
+ // note: during encode, we always pass the full sequence starting from pos = 0
681
+ llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : 0);
682
 
683
  const llama_batch & batch = batch_allocr.batch;
684
  const int32_t n_tokens = batch.n_tokens;
 
703
  t_compute_start_us = ggml_time_us();
704
  }
705
 
706
+ embd_seq.clear();
707
+
708
  n_queued_tokens += n_tokens;
709
 
710
  const int64_t n_embd = hparams.n_embd;
711
 
712
+ llama_sbatch sbatch = llama_sbatch(batch, n_embd, /* simple_split */ true, /* logits_all */ true);
713
 
714
  const llama_ubatch ubatch = sbatch.split_simple(n_tokens);
715
 
 
766
  ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(sched.get(), t_embd);
767
  GGML_ASSERT(backend_embd != nullptr);
768
 
 
 
769
  switch (cparams.pooling_type) {
770
  case LLAMA_POOLING_TYPE_NONE:
771
  {
772
  // extract token embeddings
773
+ GGML_ASSERT(embd != nullptr);
774
+
775
  GGML_ASSERT(n_tokens*n_embd <= (int64_t) embd_size);
776
  ggml_backend_tensor_get_async(backend_embd, t_embd, embd, 0, n_tokens*n_embd*sizeof(float));
777
  } break;
 
796
  } break;
797
  case LLAMA_POOLING_TYPE_RANK:
798
  {
799
+ // extract the rerank score - a single float per sequence
800
+ auto & embd_seq_out = embd_seq;
801
+
802
+ for (uint32_t s = 0; s < ubatch.n_seqs; ++s) {
803
+ const llama_seq_id seq_id = ubatch.seq_id[s][0];
804
+ if (embd_seq_out.find(seq_id) != embd_seq_out.end()) {
805
+ continue;
806
+ }
807
+ embd_seq_out[seq_id].resize(1);
808
+ ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (seq_id)*sizeof(float), sizeof(float));
809
+ }
810
+ } break;
811
  case LLAMA_POOLING_TYPE_UNSPECIFIED:
812
  {
813
  GGML_ABORT("unknown pooling type");
 
845
  }
846
 
847
  int llama_context::decode(llama_batch & inp_batch) {
848
+ if (!memory) {
849
+ LLAMA_LOG_WARN("%s: cannot decode batches with this context (use llama_encode() instead)\n", __func__);
850
+ return encode(inp_batch);
851
+ }
852
+
853
  if (inp_batch.n_tokens == 0) {
854
  LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__);
855
  return -1;
856
  }
857
 
858
+ llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
859
+
860
  // temporary allocate memory for the input batch if needed
861
+ // TODO: this is incorrect for multiple sequences because get_pos_max() is the maximum across all sequences
862
+ llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : kv_self->get_pos_max() + 1);
863
 
864
  const llama_batch & batch = batch_allocr.batch;
865
 
 
871
  const int64_t n_tokens_all = batch.n_tokens;
872
  const int64_t n_embd = hparams.n_embd;
873
 
874
+ llama_kv_cache_guard kv_guard(kv_self);
875
 
876
  GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
877
 
 
905
  for (uint32_t i = 0; i < n_tokens_all; ++i) {
906
  n_outputs_all += batch.logits[i] != 0;
907
  }
908
+ } else if (embd_pooled) {
909
  n_outputs_all = n_tokens_all;
910
  } else {
911
  // keep last output only
912
  n_outputs_all = 1;
913
  }
914
 
915
+ llama_sbatch sbatch = kv_self->sbatch_init(batch, /* logits_all */ n_outputs_all == n_tokens_all);
 
 
 
 
916
 
917
  // reserve output buffer
918
  if (output_reserve(n_outputs_all) < n_outputs_all) {
 
926
  int64_t n_outputs_prev = 0;
927
 
928
  while (sbatch.n_tokens > 0) {
929
+ llama_ubatch ubatch = kv_self->ubatch_next(sbatch, cparams.n_ubatch, embd_pooled);
 
 
930
 
931
  // count the outputs in this u_batch
932
  {
 
946
  }
947
 
948
  // find KV slot
949
+ if (!kv_self->find_slot(ubatch)) {
950
+ LLAMA_LOG_WARN("%s: failed to find KV cache slot for ubatch of size %d\n", __func__, ubatch.n_tokens);
 
 
 
 
951
 
952
+ return 1;
 
 
 
 
 
 
953
  }
954
 
 
 
955
  ggml_backend_sched_reset(sched.get());
956
  ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);
957
 
 
1065
  // finalize the batch processing
1066
  kv_guard.commit();
1067
 
1068
+ // set to total number of outputs in the batch, for use in llama_get_logits_ith
1069
+ n_outputs = n_outputs_all;
1070
+
1071
  // set output mappings
1072
  {
1073
  bool sorted_output = true;
1074
 
1075
+ auto & out_ids = sbatch.out_ids;
1076
+
1077
+ GGML_ASSERT(out_ids.size() == (size_t) n_outputs_all);
1078
 
1079
  for (int64_t i = 0; i < n_outputs_all; ++i) {
1080
+ int64_t out_id = out_ids[i];
1081
  output_ids[out_id] = i;
1082
  if (out_id != i) {
1083
  sorted_output = false;
1084
  }
1085
  }
1086
 
1087
+ // make the outputs have the same order they had in the user-provided batch
1088
+ // note: this is mostly relevant for recurrent models atm
1089
+ if (!sorted_output) {
1090
+ const uint32_t n_vocab = model.vocab.n_tokens();
1091
+ const uint32_t n_embd = model.hparams.n_embd;
1092
+
1093
+ GGML_ASSERT((size_t) n_outputs == out_ids.size());
1094
+
1095
+ // TODO: is there something more efficient which also minimizes swaps?
1096
+ // selection sort, to minimize swaps (from https://en.wikipedia.org/wiki/Selection_sort)
1097
+ for (int32_t i = 0; i < n_outputs - 1; ++i) {
1098
+ int32_t j_min = i;
1099
+ for (int32_t j = i + 1; j < n_outputs; ++j) {
1100
+ if (out_ids[j] < out_ids[j_min]) {
1101
+ j_min = j;
1102
+ }
1103
+ }
1104
+ if (j_min == i) { continue; }
1105
+ std::swap(out_ids[i], out_ids[j_min]);
1106
+ if (logits_size > 0) {
1107
+ for (uint32_t k = 0; k < n_vocab; k++) {
1108
+ std::swap(logits[i*n_vocab + k], logits[j_min*n_vocab + k]);
1109
+ }
1110
+ }
1111
+ if (embd_size > 0) {
1112
+ for (uint32_t k = 0; k < n_embd; k++) {
1113
+ std::swap(embd[i*n_embd + k], embd[j_min*n_embd + k]);
1114
+ }
1115
+ }
1116
+ }
1117
+ std::fill(output_ids.begin(), output_ids.end(), -1);
1118
+ for (int32_t i = 0; i < n_outputs; ++i) {
1119
+ output_ids[out_ids[i]] = i;
1120
+ }
1121
  }
1122
  }
1123
 
 
 
 
1124
  // wait for the computation to finish (automatically done when obtaining the model output)
1125
  //synchronize();
1126
 
1127
  // decide if we need to defrag the kv cache
1128
+ if (cparams.defrag_thold > 0.0f) {
1129
+ kv_self->defrag_sched(cparams.defrag_thold);
 
 
 
 
 
 
 
 
 
1130
  }
1131
 
1132
  // Reset state for the next token before backend sync, to allow the CPU activities in the reset to
 
1212
  return n_outputs_max;
1213
  }
1214
 
1215
  //
1216
  // graph
1217
  //
 
1248
  /*.backend_cpu =*/ backend_cpu,
1249
  /*.cvec =*/ &cvec,
1250
  /*.loras =*/ &loras,
1251
+ /*.memory =*/ memory.get(),
1252
  /*.cross =*/ &cross,
1253
  /*.n_outputs =*/ n_outputs,
1254
  /*.cb =*/ graph_get_cb(),
 
1652
  {
1653
  LLAMA_LOG_DEBUG("%s: - writing output ids\n", __func__);
1654
 
 
 
1655
  const auto n_outputs = this->n_outputs;
1656
  const auto & output_ids = this->output_ids;
1657
 
 
1705
  }
1706
 
1707
  LLAMA_LOG_DEBUG("%s: - writing KV self\n", __func__);
1708
+ llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
1709
+
1710
  kv_self->state_write(io);
1711
 
1712
  return io.n_bytes();
 
1790
  }
1791
  }
1792
 
1793
+ if (memory) {
1794
+ LLAMA_LOG_DEBUG("%s: - reading KV self\n", __func__);
1795
+
1796
+ llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
1797
+
1798
+ kv_self->state_read(io);
1799
+ }
1800
 
1801
  return io.n_bytes();
1802
  }
 
1804
  size_t llama_context::state_seq_write_data(llama_io_write_i & io, llama_seq_id seq_id) {
1805
  GGML_UNUSED(seq_id);
1806
 
1807
+ if (memory) {
1808
+ llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
1809
+
1810
+ kv_self->state_write(io, seq_id);
1811
+ }
1812
 
1813
  return io.n_bytes();
1814
  }
 
1816
  size_t llama_context::state_seq_read_data(llama_io_read_i & io, llama_seq_id seq_id) {
1817
  GGML_UNUSED(seq_id);
1818
 
1819
+ if (memory) {
1820
+ llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
1821
+
1822
+ kv_self->state_read(io, seq_id);
1823
+ }
1824
 
1825
  return io.n_bytes();
1826
  }
 
1848
  t_p_eval_us = n_p_eval = 0;
1849
  }
1850
 
1851
+ //
1852
+ // training
1853
+ //
1854
+
1855
+ static void llama_set_param(struct ggml_tensor * tensor, llama_opt_param_filter param_filter, void * userdata) {
1856
+ if (!tensor || tensor->type != GGML_TYPE_F32) {
1857
+ return;
1858
+ }
1859
+ if (!param_filter(tensor, userdata)) {
1860
+ return;
1861
+ }
1862
+ if (strcmp(tensor->name, "token_embd.weight") == 0) {
1863
+ return; // FIXME
1864
+ }
1865
+ if (strcmp(tensor->name, "rope_freqs.weight") == 0) {
1866
+ return; // FIXME
1867
+ }
1868
+ ggml_set_param(tensor);
1869
+ }
1870
+
1871
+ void llama_context::opt_init(struct llama_model * model, struct llama_opt_params lopt_params) {
1872
+ GGML_ASSERT(!opt_ctx);
1873
+ model->hparams.n_ctx_train = lopt_params.n_ctx_train > 0 ? lopt_params.n_ctx_train : n_ctx();
1874
+ const uint32_t n_batch = std::min(this->n_batch(), model->hparams.n_ctx_train);
1875
+ const uint32_t n_ubatch = std::min(this->n_ubatch(), n_batch);
1876
+ GGML_ASSERT(model->hparams.n_ctx_train % n_batch == 0);
1877
+ GGML_ASSERT(n_batch % n_ubatch == 0);
1878
+
1879
+ ggml_opt_params opt_params = ggml_opt_default_params(sched.get(), GGML_OPT_LOSS_TYPE_CROSS_ENTROPY);
1880
+ opt_params.opt_period = n_batch / n_ubatch;
1881
+ opt_params.get_opt_pars = lopt_params.get_opt_pars;
1882
+ opt_params.get_opt_pars_ud = lopt_params.get_opt_pars_ud;
1883
+
1884
+ opt_ctx = ggml_opt_init(opt_params);
1885
+
1886
+ llama_opt_param_filter param_filter = lopt_params.param_filter;
1887
+ void * param_filter_ud = lopt_params.param_filter_ud;
1888
+
1889
+ //llama_set_param(model->tok_embd, param_filter, param_filter_ud); // FIXME
1890
+ llama_set_param(model->type_embd, param_filter, param_filter_ud);
1891
+ llama_set_param(model->pos_embd, param_filter, param_filter_ud);
1892
+ llama_set_param(model->tok_norm, param_filter, param_filter_ud);
1893
+ llama_set_param(model->tok_norm_b, param_filter, param_filter_ud);
1894
+ llama_set_param(model->output_norm, param_filter, param_filter_ud);
1895
+ llama_set_param(model->output_norm_b, param_filter, param_filter_ud);
1896
+ llama_set_param(model->output, param_filter, param_filter_ud);
1897
+ llama_set_param(model->output_b, param_filter, param_filter_ud);
1898
+ llama_set_param(model->output_norm_enc, param_filter, param_filter_ud);
1899
+ llama_set_param(model->cls, param_filter, param_filter_ud);
1900
+ llama_set_param(model->cls_b, param_filter, param_filter_ud);
1901
+ llama_set_param(model->cls_out, param_filter, param_filter_ud);
1902
+ llama_set_param(model->cls_out_b, param_filter, param_filter_ud);
1903
+
1904
+ for (struct llama_layer & layer : model->layers) {
1905
+ for (size_t i = 0; i < sizeof(layer)/sizeof(struct ggml_tensor *); ++i) {
1906
+ llama_set_param(reinterpret_cast<struct ggml_tensor **>(&layer)[i], param_filter, param_filter_ud);
1907
+ }
1908
+ }
1909
+ }
1910
+
1911
+ void llama_context::opt_epoch_iter(
1912
+ ggml_opt_dataset_t dataset,
1913
+ ggml_opt_result_t result,
1914
+ const std::vector<llama_token> & tokens,
1915
+ const std::vector<llama_token> & labels_sparse,
1916
+ llama_batch & batch,
1917
+ ggml_opt_epoch_callback callback,
1918
+ bool train,
1919
+ int64_t idata_in_loop,
1920
+ int64_t ndata_in_loop,
1921
+ int64_t t_loop_start) {
1922
+ GGML_ASSERT(opt_ctx);
1923
+ const uint32_t n_ctx = llama_model_n_ctx_train(&model);
1924
+ const uint32_t n_batch = std::min(this->n_batch(), n_ctx);
1925
+ const uint32_t n_ubatch = std::min(this->n_ubatch(), n_batch);
1926
+
1927
+ llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
1928
+
1929
+ kv_self->clear();
1930
+ llama_kv_cache_guard kv_guard(kv_self);
1931
+
1932
+ for (uint32_t pos_ctx = 0; pos_ctx < n_ctx; pos_ctx += n_batch) {
1933
+ batch.n_tokens = n_batch;
1934
+ for (uint32_t pos_batch = 0; pos_batch < n_batch; ++pos_batch) {
1935
+ batch.token [pos_batch] = tokens[pos_ctx + pos_batch];
1936
+ batch.pos [pos_batch] = pos_ctx + pos_batch;
1937
+ batch.n_seq_id[pos_batch] = 1;
1938
+ batch.seq_id [pos_batch][0] = 0;
1939
+ batch.logits [pos_batch] = true;
1940
+ }
1941
+
1942
+ const auto n_tokens_all = batch.n_tokens;
1943
+
1944
+ n_queued_tokens += n_tokens_all;
1945
+
1946
+ // this indicates we are doing pooled embedding, so we ignore batch.logits and output all tokens
1947
+ const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
1948
+
1949
+ embd_seq.clear();
1950
+
1951
+ int64_t n_outputs_all = n_tokens_all;
1952
+
1953
+ llama_sbatch sbatch = kv_self->sbatch_init(batch, /*logits_all =*/ true);
1954
+
1955
+ // reserve output buffer
1956
+ if (output_reserve(n_outputs_all) < n_outputs_all) {
1957
+ LLAMA_LOG_ERROR("%s: could not reserve space for batch with %" PRId64 " outputs\n", __func__, n_outputs_all);
1958
+ GGML_ABORT("TODO: handle this error");
1959
+ };
1960
+
1961
+ for (uint32_t pos_batch = 0; pos_batch < n_batch; pos_batch += n_ubatch) {
1962
+ llama_ubatch ubatch = kv_self->ubatch_next(sbatch, cparams.n_ubatch, embd_pooled);
1963
+
1964
+ n_outputs = ubatch.n_tokens;
1965
+
1966
+ // TODO: not sure if this is needed
1967
+ if (!kv_self->find_slot(ubatch)) {
1968
+ LLAMA_LOG_WARN("%s: failed to find KV cache slot for ubatch of size %d\n", __func__, ubatch.n_tokens);
1969
+
1970
+ GGML_ABORT("TODO: handle this error");
1971
+ }
1972
+
1973
+ auto * gf = graph_init();
1974
+ auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT);
1975
+
1976
+ struct ggml_context * ctx_compute_opt;
1977
+ {
1978
+ const size_t size_gf = ggml_graph_size(gf);
1979
+ const size_t size_meta = 4*size_gf*ggml_tensor_overhead() + 2*ggml_graph_overhead_custom(size_gf, /*grads = */ true);
1980
+ struct ggml_init_params params = {
1981
+ /*.mem_size =*/ size_meta,
1982
+ /*.mem_buffer =*/ nullptr,
1983
+ /*.no_alloc =*/ true,
1984
+ };
1985
+ ctx_compute_opt = ggml_init(params);
1986
+ }
1987
+ ggml_opt_prepare_alloc(opt_ctx, ctx_compute_opt, gf, res->get_tokens(), res->get_logits());
1988
+ ggml_opt_alloc(opt_ctx, train);
1989
+ res->set_inputs(&ubatch);
1990
+ {
1991
+ struct ggml_tensor * labels = ggml_opt_labels(opt_ctx);
1992
+ GGML_ASSERT(labels->ne[1] == n_ubatch);
1993
+ ggml_set_zero(labels);
1994
+ const float onef = 1.0f;
1995
+ for (uint32_t pos_ubatch = 0; pos_ubatch < n_ubatch; ++pos_ubatch) {
1996
+ const uint32_t ilabel = pos_ctx + pos_batch + pos_ubatch;
1997
+ GGML_ASSERT(labels_sparse[ilabel] < labels->ne[0]);
1998
+ ggml_backend_tensor_set(labels, &onef, (pos_ubatch*labels->ne[0] + labels_sparse[ilabel])*sizeof(float), sizeof(float));
1999
+ }
2000
+ }
2001
+ ggml_opt_eval(opt_ctx, result);
2002
+ if (callback) {
2003
+ callback(train, opt_ctx, dataset, result, idata_in_loop + (pos_ctx + pos_batch)/n_ubatch + 1, ndata_in_loop, t_loop_start);
2004
+ }
2005
+ ggml_free(ctx_compute_opt);
2006
+ }
2007
+ }
2008
+
2009
+ kv_guard.commit();
2010
+ }
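The label setup in opt_epoch_iter builds a one-hot target per ubatch position: ggml_opt_labels() returns an F32 tensor whose first dimension is the class count (presumably the vocabulary size) and whose second dimension is n_ubatch; it is zeroed and a single 1.0f is written at the flat float index pos_ubatch*labels->ne[0] + labels_sparse[ilabel]. As a worked example with illustrative numbers, if labels->ne[0] = 32000, pos_ubatch = 3 and the target token id is 57, the write lands at float index 3*32000 + 57 = 96057, i.e. byte offset 96057*sizeof(float).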
2011
+
2012
+ void llama_context::opt_epoch(
2013
+ ggml_opt_dataset_t dataset,
2014
+ ggml_opt_result_t result_train,
2015
+ ggml_opt_result_t result_eval,
2016
+ int64_t idata_split,
2017
+ ggml_opt_epoch_callback callback_train,
2018
+ ggml_opt_epoch_callback callback_eval) {
2019
+ const uint32_t n_ctx = this->n_ctx();
2020
+ const uint32_t n_batch = std::min(cparams.n_batch, n_ctx);
2021
+ const uint32_t n_ubatch = std::min(cparams.n_ubatch, n_batch);
2022
+ const int64_t ndata = ggml_opt_dataset_ndata(dataset);
2023
+
2024
+ GGML_ASSERT(idata_split >= 0);
2025
+ GGML_ASSERT(idata_split <= ndata);
2026
+
2027
+ const uint32_t ubatch_per_ctx = n_ctx / n_ubatch;
2028
+
2029
+ struct llama_batch batch = llama_batch_init(n_batch, 0, 1);
2030
+ std::vector<llama_token> tokens(n_ctx);
2031
+ std::vector<llama_token> labels_sparse(n_ctx);
2032
+
2033
+ int64_t idata = 0;
2034
+
2035
+ int64_t t_loop_start = ggml_time_us();
2036
+ int64_t ndata_in_loop = idata_split*ubatch_per_ctx;
2037
+ for (; idata < idata_split; ++idata) {
2038
+ constexpr bool train = true;
2039
+ const int64_t idata_in_loop = idata*ubatch_per_ctx;
2040
+
2041
+ ggml_opt_dataset_get_batch_host(dataset, tokens.data(), n_ctx*sizeof(llama_token), labels_sparse.data(), idata);
2042
+ opt_epoch_iter(dataset, result_train, tokens, labels_sparse, batch,
2043
+ callback_train, train, idata_in_loop, ndata_in_loop, t_loop_start);
2044
+ }
2045
+
2046
+ t_loop_start = ggml_time_us();
2047
+ ndata_in_loop = (ndata - idata_split)*ubatch_per_ctx;
2048
+ for (; idata < ndata; ++idata) {
2049
+ constexpr bool train = false;
2050
+ const int64_t idata_in_loop = (idata - idata_split)*ubatch_per_ctx;
2051
+
2052
+ ggml_opt_dataset_get_batch_host(dataset, tokens.data(), n_ctx*sizeof(llama_token), labels_sparse.data(), idata);
2053
+ opt_epoch_iter(dataset, result_eval, tokens, labels_sparse, batch,
2054
+ callback_eval, train, idata_in_loop, ndata_in_loop, t_loop_start);
2055
+ }
2056
+
2057
+ llama_batch_free(batch);
2058
+ }
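For orientation, opt_epoch reports progress to the callbacks in units of ubatches: with n_ctx = 512 and n_ubatch = 128, ubatch_per_ctx = 4, so a dataset with ndata = 10 and idata_split = 8 yields ndata_in_loop = 8*4 = 32 for the training phase and (10 - 8)*4 = 8 for the evaluation phase; the first 8 dataset batches are trained on and the last 2 are only evaluated.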
2059
+
2060
  //
2061
  // interface implementation
2062
  //
 
2084
  /*.cb_eval_user_data =*/ nullptr,
2085
  /*.type_k =*/ GGML_TYPE_F16,
2086
  /*.type_v =*/ GGML_TYPE_F16,
2087
+ /*.abort_callback =*/ nullptr,
2088
+ /*.abort_callback_data =*/ nullptr,
2089
  /*.embeddings =*/ false,
2090
  /*.offload_kqv =*/ true,
2091
  /*.flash_attn =*/ false,
2092
  /*.no_perf =*/ true,
2093
+ /*.op_offload =*/ true,
 
2094
  };
2095
 
2096
  return result;
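The three fields added to the defaults above (abort_callback, abort_callback_data, op_offload) are plain members of llama_context_params, so a caller can override them after fetching the defaults. A minimal sketch, assuming the usual ggml abort-callback convention (return true to abort the in-flight compute) and that op_offload controls whether host-side tensor ops are offloaded to a device backend:

    #include <atomic>
    #include "llama.h"

    static std::atomic<bool> g_stop{false};

    // returning true asks the backend scheduler to abort the current graph compute
    static bool my_abort_cb(void * data) {
        auto * stop = static_cast<std::atomic<bool> *>(data);
        return stop->load();
    }

    static llama_context_params make_ctx_params() {
        llama_context_params cparams = llama_context_default_params();
        cparams.abort_callback      = my_abort_cb;
        cparams.abort_callback_data = &g_stop;
        cparams.op_offload          = false; // assumption: keeps host tensor ops off the device backend
        return cparams;
    }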
 
2384
  llama_seq_id seq_id_dst,
2385
  llama_pos p0,
2386
  llama_pos p1) {
2387
+ llama_kv_self_seq_cp(ctx, seq_id_src, seq_id_dst, p0, p1);
2388
  }
2389
 
2390
  void llama_kv_self_seq_cp(
 
2398
  return;
2399
  }
2400
 
2401
+ kv->seq_cp(seq_id_src, seq_id_dst, p0, p1);
2402
  }
2403
 
2404
  // deprecated
2405
  void llama_kv_cache_seq_keep(
2406
  llama_context * ctx,
2407
  llama_seq_id seq_id) {
2408
+ llama_kv_self_seq_keep(ctx, seq_id);
2409
  }
2410
 
2411
  void llama_kv_self_seq_keep(llama_context * ctx, llama_seq_id seq_id) {
 
2414
  return;
2415
  }
2416
 
2417
+ kv->seq_keep(seq_id);
2418
  }
2419
 
2420
  // deprecated
 
2424
  llama_pos p0,
2425
  llama_pos p1,
2426
  llama_pos delta) {
2427
+ llama_kv_self_seq_add(ctx, seq_id, p0, p1, delta);
2428
  }
2429
 
2430
  void llama_kv_self_seq_add(
 
2438
  return;
2439
  }
2440
 
2441
+ kv->seq_add(seq_id, p0, p1, delta);
2442
  }
2443
 
2444
  // deprecated
 
2448
  llama_pos p0,
2449
  llama_pos p1,
2450
  int d) {
2451
+ llama_kv_self_seq_div(ctx, seq_id, p0, p1, d);
2452
  }
2453
 
2454
  void llama_kv_self_seq_div(
 
2462
  return;
2463
  }
2464
 
2465
+ kv->seq_div(seq_id, p0, p1, d);
2466
  }
2467
 
2468
  // deprecated
 
2481
 
2482
  // deprecated
2483
  void llama_kv_cache_defrag(llama_context * ctx) {
2484
+ llama_kv_self_defrag(ctx);
2485
  }
2486
 
2487
  void llama_kv_self_defrag(llama_context * ctx) {
 
2490
  return;
2491
  }
2492
 
2493
+ // force defrag
2494
+ kv->defrag_sched(-1.0f);
2495
  }
2496
 
2497
  // deprecated
 
2675
  void llama_perf_context_reset(llama_context * ctx) {
2676
  ctx->perf_reset();
2677
  }
2678
+
2679
+ //
2680
+ // training
2681
+ //
2682
+
2683
+ bool llama_opt_param_filter_all(const struct ggml_tensor * tensor, void * userdata) {
2684
+ GGML_UNUSED(tensor);
2685
+ GGML_UNUSED(userdata);
2686
+ return true;
2687
+ }
2688
+
2689
+ void llama_opt_init(struct llama_context * ctx, struct llama_model * model, struct llama_opt_params lopt_params) {
2690
+ ctx->opt_init(model, lopt_params);
2691
+ }
2692
+
2693
+ void llama_opt_epoch(
2694
+ struct llama_context * ctx,
2695
+ ggml_opt_dataset_t dataset,
2696
+ ggml_opt_result_t result_train,
2697
+ ggml_opt_result_t result_eval,
2698
+ int64_t idata_split,
2699
+ ggml_opt_epoch_callback callback_train,
2700
+ ggml_opt_epoch_callback callback_eval) {
2701
+ ctx->opt_epoch(
2702
+ dataset,
2703
+ result_train,
2704
+ result_eval,
2705
+ idata_split,
2706
+ callback_train,
2707
+ callback_eval);
2708
+ }
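Taken together, the new exports above give a small training surface: initialize once with llama_opt_init() (for example with llama_opt_param_filter_all as the parameter filter), then call llama_opt_epoch() once per epoch with a train/eval split. A rough sketch, assuming a ggml_opt dataset has already been built elsewhere and that ggml_opt_result_init/ggml_opt_result_free and the stock ggml_opt_epoch_callback_progress_bar helper from ggml-opt.h are available in this sync:

    #include "llama.h"
    #include "ggml-opt.h"

    // one pass over `dataset`: the first 90% of the dataset batches are trained on,
    // the remaining 10% are used for evaluation only (the split ratio is arbitrary)
    static void run_one_epoch(llama_context * ctx, ggml_opt_dataset_t dataset) {
        ggml_opt_result_t result_train = ggml_opt_result_init();
        ggml_opt_result_t result_eval  = ggml_opt_result_init();

        const int64_t ndata       = ggml_opt_dataset_ndata(dataset);
        const int64_t idata_split = ndata * 9 / 10;

        llama_opt_epoch(ctx, dataset, result_train, result_eval, idata_split,
                        ggml_opt_epoch_callback_progress_bar,
                        ggml_opt_epoch_callback_progress_bar);

        ggml_opt_result_free(result_train);
        ggml_opt_result_free(result_eval);
    }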
examples/talk-llama/llama-context.h CHANGED
@@ -7,6 +7,7 @@
7
  #include "llama-adapter.h"
8
 
9
  #include "ggml-cpp.h"
 
10
 
11
  #include <map>
12
  #include <vector>
@@ -27,7 +28,12 @@ struct llama_context {
27
 
28
  void synchronize();
29
 
30
- const llama_model & get_model() const;
 
 
 
 
 
31
 
32
  uint32_t n_ctx() const;
33
  uint32_t n_ctx_per_seq() const;
@@ -128,6 +134,32 @@ struct llama_context {
128
  llama_perf_context_data perf_get_data() const;
129
  void perf_reset();
130
 
 
131
  private:
132
  //
133
  // output
@@ -137,49 +169,30 @@ private:
137
  // Returns max number of outputs for which space was reserved.
138
  int32_t output_reserve(int32_t n_outputs);
139
 
140
- // make the outputs have the same order they had in the user-provided batch
141
- // TODO: maybe remove this
142
- void output_reorder();
143
-
144
  //
145
  // graph
146
  //
147
 
 
148
  int32_t graph_max_nodes() const;
149
 
150
  // zero-out inputs and create the ctx_compute for the compute graph
151
  ggml_cgraph * graph_init();
152
 
 
 
 
 
 
 
153
  llm_graph_result_ptr graph_build(
154
  ggml_context * ctx,
155
  ggml_cgraph * gf,
156
  const llama_ubatch & ubatch,
157
  llm_graph_type gtype);
158
 
159
- // returns the result of ggml_backend_sched_graph_compute_async execution
160
- ggml_status graph_compute(
161
- ggml_cgraph * gf,
162
- bool batched);
163
-
164
  llm_graph_cb graph_get_cb() const;
165
 
166
- // used by kv_self_update()
167
- ggml_tensor * build_rope_shift(
168
- ggml_context * ctx0,
169
- ggml_tensor * cur,
170
- ggml_tensor * shift,
171
- ggml_tensor * factors,
172
- float freq_base,
173
- float freq_scale) const;
174
-
175
- llm_graph_result_ptr build_kv_self_shift(
176
- ggml_context * ctx0,
177
- ggml_cgraph * gf) const;
178
-
179
- llm_graph_result_ptr build_kv_self_defrag(
180
- ggml_context * ctx0,
181
- ggml_cgraph * gf) const;
182
-
183
  // TODO: read/write lora adapters and cvec
184
  size_t state_write_data(llama_io_write_i & io);
185
  size_t state_read_data (llama_io_read_i & io);
@@ -196,14 +209,10 @@ private:
196
  llama_cparams cparams;
197
  llama_adapter_cvec cvec;
198
  llama_adapter_loras loras;
199
- llama_sbatch sbatch;
200
 
201
  llama_cross cross; // TODO: tmp for handling cross-attention - need something better probably
202
 
203
- std::unique_ptr<llama_kv_cache_unified> kv_self;
204
-
205
- // TODO: remove
206
- bool logits_all = false;
207
 
208
  // decode output (2-dimensional array: [n_outputs][n_vocab])
209
  size_t logits_size = 0; // capacity (of floats) for logits
@@ -230,6 +239,9 @@ private:
230
 
231
  ggml_context_ptr ctx_compute;
232
 
 
 
 
233
  ggml_threadpool_t threadpool = nullptr;
234
  ggml_threadpool_t threadpool_batch = nullptr;
235
 
 
7
  #include "llama-adapter.h"
8
 
9
  #include "ggml-cpp.h"
10
+ #include "ggml-opt.h"
11
 
12
  #include <map>
13
  #include <vector>
 
28
 
29
  void synchronize();
30
 
31
+ const llama_model & get_model() const;
32
+ const llama_cparams & get_cparams() const;
33
+
34
+ ggml_backend_sched_t get_sched() const;
35
+
36
+ ggml_context * get_ctx_compute() const;
37
 
38
  uint32_t n_ctx() const;
39
  uint32_t n_ctx_per_seq() const;
 
134
  llama_perf_context_data perf_get_data() const;
135
  void perf_reset();
136
 
137
+ //
138
+ // training
139
+ //
140
+
141
+ void opt_init(struct llama_model * model, struct llama_opt_params lopt_params);
142
+
143
+ void opt_epoch(
144
+ ggml_opt_dataset_t dataset,
145
+ ggml_opt_result_t result_train,
146
+ ggml_opt_result_t result_eval,
147
+ int64_t idata_split,
148
+ ggml_opt_epoch_callback callback_train,
149
+ ggml_opt_epoch_callback callback_eval);
150
+
151
+ void opt_epoch_iter(
152
+ ggml_opt_dataset_t dataset,
153
+ ggml_opt_result_t result,
154
+ const std::vector<llama_token> & tokens,
155
+ const std::vector<llama_token> & labels_sparse,
156
+ llama_batch & batch,
157
+ ggml_opt_epoch_callback callback,
158
+ bool train,
159
+ int64_t idata_in_loop,
160
+ int64_t ndata_in_loop,
161
+ int64_t t_loop_start);
162
+
163
  private:
164
  //
165
  // output
 
169
  // Returns max number of outputs for which space was reserved.
170
  int32_t output_reserve(int32_t n_outputs);
171
 
 
 
 
 
172
  //
173
  // graph
174
  //
175
 
176
+ public:
177
  int32_t graph_max_nodes() const;
178
 
179
  // zero-out inputs and create the ctx_compute for the compute graph
180
  ggml_cgraph * graph_init();
181
 
182
+ // returns the result of ggml_backend_sched_graph_compute_async execution
183
+ ggml_status graph_compute(
184
+ ggml_cgraph * gf,
185
+ bool batched);
186
+
187
+ private:
188
  llm_graph_result_ptr graph_build(
189
  ggml_context * ctx,
190
  ggml_cgraph * gf,
191
  const llama_ubatch & ubatch,
192
  llm_graph_type gtype);
193
 
 
 
 
 
 
194
  llm_graph_cb graph_get_cb() const;
195
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
196
  // TODO: read/write lora adapters and cvec
197
  size_t state_write_data(llama_io_write_i & io);
198
  size_t state_read_data (llama_io_read_i & io);
 
209
  llama_cparams cparams;
210
  llama_adapter_cvec cvec;
211
  llama_adapter_loras loras;
 
212
 
213
  llama_cross cross; // TODO: tmp for handling cross-attention - need something better probably
214
 
215
+ std::unique_ptr<llama_memory_i> memory;
 
 
 
216
 
217
  // decode output (2-dimensional array: [n_outputs][n_vocab])
218
  size_t logits_size = 0; // capacity (of floats) for logits
 
239
 
240
  ggml_context_ptr ctx_compute;
241
 
242
+ // training
243
+ ggml_opt_context_t opt_ctx = nullptr;
244
+
245
  ggml_threadpool_t threadpool = nullptr;
246
  ggml_threadpool_t threadpool_batch = nullptr;
247
 
examples/talk-llama/llama-cparams.h CHANGED
@@ -30,6 +30,7 @@ struct llama_cparams {
30
  bool flash_attn;
31
  bool no_perf;
32
  bool warmup;
 
33
 
34
  enum llama_pooling_type pooling_type;
35
 
 
30
  bool flash_attn;
31
  bool no_perf;
32
  bool warmup;
33
+ bool op_offload;
34
 
35
  enum llama_pooling_type pooling_type;
36
 
examples/talk-llama/llama-graph.cpp CHANGED
@@ -284,24 +284,7 @@ void llm_graph_input_s_copy::set_input(const llama_ubatch * ubatch) {
284
 
285
  // assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
286
  for (uint32_t i = 0; i < n_kv; ++i) {
287
- const uint32_t cell_id = i + kv_self->head;
288
-
289
- //////////////////////////////////////////////
290
- // TODO: this should not mutate the KV cache !
291
- llama_kv_cell & kv_cell = const_cast<class llama_kv_cache_unified *>(kv_self)->cells[i];
292
-
293
- // prevent out-of-bound sources
294
- if (kv_cell.src < 0 || (uint32_t) kv_cell.src >= kv_self->size) {
295
- kv_cell.src = cell_id;
296
- }
297
-
298
- data[i] = kv_cell.src;
299
-
300
- // TODO: do not mutate the KV cache
301
- // ensure copy only happens once
302
- if (kv_cell.src != (int32_t) cell_id) {
303
- kv_cell.src = cell_id;
304
- }
305
  }
306
  }
307
  }
@@ -317,18 +300,7 @@ void llm_graph_input_s_mask::set_input(const llama_ubatch * ubatch) {
317
 
318
  // clear unused states
319
  for (int i = 0; i < n_kv; ++i) {
320
- const uint32_t cell_id = i + kv_self->head;
321
-
322
- //////////////////////////////////////////////
323
- // TODO: this should not mutate the KV cache !
324
- llama_kv_cell & kv_cell = const_cast<class llama_kv_cache_unified *>(kv_self)->cells[i];
325
-
326
- data[i] = (float) (kv_cell.src >= 0);
327
-
328
- // only clear once
329
- if (kv_cell.src < 0) {
330
- kv_cell.src = cell_id;
331
- }
332
  }
333
  }
334
  }
@@ -810,7 +782,7 @@ ggml_tensor * llm_graph_context::build_ffn(
810
  } break;
811
  }
812
 
813
- if (type_gate == LLM_FFN_PAR) {
814
  cur = ggml_mul(ctx0, cur, tmp);
815
  cb(cur, "ffn_gate_par", il);
816
  }
@@ -999,6 +971,7 @@ ggml_tensor * llm_graph_context::build_inp_embd(ggml_tensor * tok_embd) const {
999
  inp->tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ubatch.n_tokens);
1000
  //cb(inp->tokens, "inp_tokens", -1);
1001
  ggml_set_input(inp->tokens);
 
1002
 
1003
  cur = ggml_get_rows(ctx0, tok_embd, inp->tokens);
1004
 
@@ -1105,7 +1078,7 @@ ggml_tensor * llm_graph_context::build_inp_cls() const {
1105
  }
1106
 
1107
  ggml_tensor * llm_graph_context::build_inp_s_copy() const {
1108
- const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
1109
 
1110
  auto inp = std::make_unique<llm_graph_input_s_copy>(kv_self);
1111
 
@@ -1122,7 +1095,7 @@ ggml_tensor * llm_graph_context::build_inp_s_copy() const {
1122
  }
1123
 
1124
  ggml_tensor * llm_graph_context::build_inp_s_mask() const {
1125
- const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
1126
 
1127
  auto inp = std::make_unique<llm_graph_input_s_mask>(kv_self);
1128
 
@@ -1255,8 +1228,19 @@ ggml_tensor * llm_graph_context::build_attn_mha(
1255
  ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32);
1256
 
1257
  if (v_mla) {
 
 
 
1258
  cur = ggml_reshape_4d(ctx0, cur, v_mla->ne[0], 1, n_head, n_tokens);
1259
  cur = ggml_mul_mat(ctx0, v_mla, cur);
 
 
 
 
 
 
 
 
1260
  }
1261
 
1262
  cur = ggml_reshape_2d(ctx0, cur, cur->ne[0]*n_head, n_tokens);
@@ -1436,8 +1420,6 @@ ggml_tensor * llm_graph_context::build_attn(
1436
 
1437
  // store to KV cache
1438
  {
1439
- GGML_ASSERT(!kv_self->recurrent);
1440
-
1441
  const auto kv_head = kv_self->head;
1442
 
1443
  GGML_ASSERT(kv_self->size == n_ctx);
@@ -1587,7 +1569,7 @@ ggml_tensor * llm_graph_context::build_copy_mask_state(
1587
  ggml_tensor * state_mask,
1588
  int32_t n_state,
1589
  int32_t n_seqs) const {
1590
- const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
1591
 
1592
  const auto n_kv = kv_self->n;
1593
  const auto kv_head = kv_self->head;
@@ -1619,7 +1601,7 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_load(
1619
  ggml_tensor * state_mask,
1620
  const llama_ubatch & ubatch,
1621
  int il) const {
1622
- const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
1623
 
1624
  const auto token_shift_count = hparams.token_shift_count;
1625
 
@@ -1640,7 +1622,7 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_store(
1640
  ggml_tensor * token_shift,
1641
  const llama_ubatch & ubatch,
1642
  int il) const {
1643
- const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
1644
 
1645
  const auto token_shift_count = hparams.token_shift_count;
1646
  const auto n_embd = hparams.n_embd;
 
284
 
285
  // assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
286
  for (uint32_t i = 0; i < n_kv; ++i) {
287
+ data[i] = kv_self->s_copy(i);
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
288
  }
289
  }
290
  }
 
300
 
301
  // clear unused states
302
  for (int i = 0; i < n_kv; ++i) {
303
+ data[i] = kv_self->s_mask(i);
 
 
 
 
 
 
 
 
 
 
 
304
  }
305
  }
306
  }
 
782
  } break;
783
  }
784
 
785
+ if (gate && type_gate == LLM_FFN_PAR) {
786
  cur = ggml_mul(ctx0, cur, tmp);
787
  cb(cur, "ffn_gate_par", il);
788
  }
 
971
  inp->tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ubatch.n_tokens);
972
  //cb(inp->tokens, "inp_tokens", -1);
973
  ggml_set_input(inp->tokens);
974
+ res->t_tokens = inp->tokens;
975
 
976
  cur = ggml_get_rows(ctx0, tok_embd, inp->tokens);
977
 
 
1078
  }
1079
 
1080
  ggml_tensor * llm_graph_context::build_inp_s_copy() const {
1081
+ const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory);
1082
 
1083
  auto inp = std::make_unique<llm_graph_input_s_copy>(kv_self);
1084
 
 
1095
  }
1096
 
1097
  ggml_tensor * llm_graph_context::build_inp_s_mask() const {
1098
+ const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory);
1099
 
1100
  auto inp = std::make_unique<llm_graph_input_s_mask>(kv_self);
1101
 
 
1228
  ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32);
1229
 
1230
  if (v_mla) {
1231
+ #if 0
1232
+ // v_mla can be applied as a matrix-vector multiplication with broadcasting across dimension 3 == n_tokens.
1233
+ // However, the code is optimized for dimensions 0 and 1 being large, so this is inefficient.
1234
  cur = ggml_reshape_4d(ctx0, cur, v_mla->ne[0], 1, n_head, n_tokens);
1235
  cur = ggml_mul_mat(ctx0, v_mla, cur);
1236
+ #else
1237
+ // It's preferable to do the calculation as a matrix-matrix multiplication with n_tokens in dimension 1.
1238
+ // The permutations are noops and only change how the tensor data is interpreted.
1239
+ cur = ggml_permute(ctx0, cur, 0, 2, 1, 3);
1240
+ cur = ggml_mul_mat(ctx0, v_mla, cur);
1241
+ cur = ggml_permute(ctx0, cur, 0, 2, 1, 3);
1242
+ cur = ggml_cont(ctx0, cur); // Needed because ggml_reshape_2d expects contiguous inputs.
1243
+ #endif
1244
  }
1245
 
1246
  cur = ggml_reshape_2d(ctx0, cur, cur->ne[0]*n_head, n_tokens);
 
1420
 
1421
  // store to KV cache
1422
  {
 
 
1423
  const auto kv_head = kv_self->head;
1424
 
1425
  GGML_ASSERT(kv_self->size == n_ctx);
 
1569
  ggml_tensor * state_mask,
1570
  int32_t n_state,
1571
  int32_t n_seqs) const {
1572
+ const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory);
1573
 
1574
  const auto n_kv = kv_self->n;
1575
  const auto kv_head = kv_self->head;
 
1601
  ggml_tensor * state_mask,
1602
  const llama_ubatch & ubatch,
1603
  int il) const {
1604
+ const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory);
1605
 
1606
  const auto token_shift_count = hparams.token_shift_count;
1607
 
 
1622
  ggml_tensor * token_shift,
1623
  const llama_ubatch & ubatch,
1624
  int il) const {
1625
+ const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory);
1626
 
1627
  const auto token_shift_count = hparams.token_shift_count;
1628
  const auto n_embd = hparams.n_embd;
examples/talk-llama/llama-graph.h CHANGED
@@ -19,6 +19,7 @@ struct llama_cparams;
19
 
20
  class llama_memory_i;
21
  class llama_kv_cache_unified;
 
22
 
23
  // certain models (typically multi-modal) can produce different types of graphs
24
  enum llm_graph_type {
@@ -186,26 +187,26 @@ public:
186
 
187
  class llm_graph_input_s_copy : public llm_graph_input_i {
188
  public:
189
- llm_graph_input_s_copy(const llama_kv_cache_unified * kv_self) : kv_self(kv_self) {}
190
  virtual ~llm_graph_input_s_copy() = default;
191
 
192
  void set_input(const llama_ubatch * ubatch) override;
193
 
194
  ggml_tensor * s_copy; // I32 [kv_size]
195
 
196
- const llama_kv_cache_unified * kv_self;
197
  };
198
 
199
  class llm_graph_input_s_mask : public llm_graph_input_i {
200
  public:
201
- llm_graph_input_s_mask(const llama_kv_cache_unified * kv_self) : kv_self(kv_self) {}
202
  virtual ~llm_graph_input_s_mask() = default;
203
 
204
  void set_input(const llama_ubatch * ubatch) override;
205
 
206
  ggml_tensor * s_mask; // F32 [1, n_kv]
207
 
208
- const llama_kv_cache_unified * kv_self;
209
  };
210
 
211
  class llm_graph_input_cross_embd : public llm_graph_input_i {
@@ -297,6 +298,7 @@ class llm_graph_result_i {
297
  public:
298
  virtual ~llm_graph_result_i() = default;
299
 
 
300
  virtual ggml_tensor * get_logits() = 0;
301
  virtual ggml_tensor * get_embd() = 0;
302
  virtual ggml_tensor * get_embd_pooled() = 0;
@@ -311,6 +313,7 @@ class llm_graph_result : public llm_graph_result_i {
311
  public:
312
  virtual ~llm_graph_result() = default;
313
 
 
314
  ggml_tensor * get_logits() override { return t_logits; }
315
  ggml_tensor * get_embd() override { return t_embd; }
316
  ggml_tensor * get_embd_pooled() override { return t_embd_pooled; }
@@ -327,6 +330,7 @@ public:
327
  }
328
 
329
  // important graph nodes
 
330
  ggml_tensor * t_logits = nullptr;
331
  ggml_tensor * t_embd = nullptr;
332
  ggml_tensor * t_embd_pooled = nullptr;
@@ -350,8 +354,8 @@ struct llm_graph_params {
350
  const llama_cparams & cparams;
351
  const llama_ubatch & ubatch;
352
 
353
- ggml_backend_sched * sched;
354
- ggml_backend * backend_cpu;
355
 
356
  const llama_adapter_cvec * cvec;
357
  const llama_adapter_loras * loras;
@@ -402,9 +406,9 @@ struct llm_graph_context {
402
 
403
  ggml_context * ctx0 = nullptr;
404
 
405
- ggml_backend_sched * sched;
406
 
407
- ggml_backend * backend_cpu; // TODO: needed by build_attn_mha, figure out a way to remove?
408
 
409
  const llama_adapter_cvec * cvec;
410
  const llama_adapter_loras * loras;
 
19
 
20
  class llama_memory_i;
21
  class llama_kv_cache_unified;
22
+ class llama_kv_cache_recurrent;
23
 
24
  // certain models (typically multi-modal) can produce different types of graphs
25
  enum llm_graph_type {
 
187
 
188
  class llm_graph_input_s_copy : public llm_graph_input_i {
189
  public:
190
+ llm_graph_input_s_copy(const llama_kv_cache_recurrent * kv_self) : kv_self(kv_self) {}
191
  virtual ~llm_graph_input_s_copy() = default;
192
 
193
  void set_input(const llama_ubatch * ubatch) override;
194
 
195
  ggml_tensor * s_copy; // I32 [kv_size]
196
 
197
+ const llama_kv_cache_recurrent * kv_self;
198
  };
199
 
200
  class llm_graph_input_s_mask : public llm_graph_input_i {
201
  public:
202
+ llm_graph_input_s_mask(const llama_kv_cache_recurrent * kv_self) : kv_self(kv_self) {}
203
  virtual ~llm_graph_input_s_mask() = default;
204
 
205
  void set_input(const llama_ubatch * ubatch) override;
206
 
207
  ggml_tensor * s_mask; // F32 [1, n_kv]
208
 
209
+ const llama_kv_cache_recurrent * kv_self;
210
  };
211
 
212
  class llm_graph_input_cross_embd : public llm_graph_input_i {
 
298
  public:
299
  virtual ~llm_graph_result_i() = default;
300
 
301
+ virtual ggml_tensor * get_tokens() = 0;
302
  virtual ggml_tensor * get_logits() = 0;
303
  virtual ggml_tensor * get_embd() = 0;
304
  virtual ggml_tensor * get_embd_pooled() = 0;
 
313
  public:
314
  virtual ~llm_graph_result() = default;
315
 
316
+ ggml_tensor * get_tokens() override { return t_tokens; }
317
  ggml_tensor * get_logits() override { return t_logits; }
318
  ggml_tensor * get_embd() override { return t_embd; }
319
  ggml_tensor * get_embd_pooled() override { return t_embd_pooled; }
 
330
  }
331
 
332
  // important graph nodes
333
+ ggml_tensor * t_tokens = nullptr;
334
  ggml_tensor * t_logits = nullptr;
335
  ggml_tensor * t_embd = nullptr;
336
  ggml_tensor * t_embd_pooled = nullptr;
 
354
  const llama_cparams & cparams;
355
  const llama_ubatch & ubatch;
356
 
357
+ ggml_backend_sched_t sched;
358
+ ggml_backend_t backend_cpu;
359
 
360
  const llama_adapter_cvec * cvec;
361
  const llama_adapter_loras * loras;
 
406
 
407
  ggml_context * ctx0 = nullptr;
408
 
409
+ ggml_backend_sched_t sched;
410
 
411
+ ggml_backend_t backend_cpu; // TODO: needed by build_attn_mha, figure out a way to remove?
412
 
413
  const llama_adapter_cvec * cvec;
414
  const llama_adapter_loras * loras;
examples/talk-llama/llama-kv-cache.cpp CHANGED
@@ -4,33 +4,41 @@
4
  #include "llama-batch.h"
5
  #include "llama-cparams.h"
6
  #include "llama-model.h"
 
7
 
8
  #include <algorithm>
9
  #include <cassert>
 
10
  #include <limits>
11
  #include <map>
12
  #include <stdexcept>
13
 
14
- llama_kv_cache_unified::llama_kv_cache_unified(const llama_hparams & hparams, callbacks cbs) : hparams(hparams), cbs(std::move(cbs)) {
 
 
 
 
 
 
15
  }
16
 
17
- bool llama_kv_cache_unified::init(
18
  const llama_model & model,
19
- const llama_cparams & cparams,
20
  ggml_type type_k,
21
  ggml_type type_v,
 
 
22
  uint32_t kv_size,
23
- bool offload) {
24
  const int32_t n_layer = hparams.n_layer;
25
 
26
  has_shift = false;
 
27
 
28
- recurrent = llama_model_is_recurrent(&model);
29
- v_trans = !recurrent && !cparams.flash_attn;
30
- can_shift = !recurrent;
31
 
32
- LLAMA_LOG_INFO("%s: kv_size = %d, offload = %d, type_k = '%s', type_v = '%s', n_layer = %d, can_shift = %d\n",
33
- __func__, kv_size, offload, ggml_type_name(type_k), ggml_type_name(type_v), n_layer, can_shift);
34
 
35
  head = 0;
36
  size = kv_size;
@@ -76,23 +84,20 @@ bool llama_kv_cache_unified::init(
76
 
77
  const char * dev_name = "CPU";
78
 
79
- ggml_backend_buffer_type_t buft;
 
80
  if (offload) {
81
  auto * dev = model.dev_layer(i);
82
  buft = ggml_backend_dev_buffer_type(dev);
83
 
84
  dev_name = ggml_backend_dev_name(dev);
85
- } else {
86
- buft = ggml_backend_cpu_buffer_type();
87
  }
88
 
89
- LLAMA_LOG_DEBUG("%s: layer %3d: n_embd_k_gqa = %d, n_embd_v_gqa = %d, dev = %s\n", __func__,
90
- i, n_embd_k_gqa, n_embd_v_gqa, dev_name);
91
 
92
  ggml_context * ctx = ctx_for_buft(buft);
93
  if (!ctx) {
94
- LLAMA_LOG_ERROR("%s: failed to create ggml context for kv cache\n", __func__);
95
- return false;
96
  }
97
 
98
  ggml_tensor * k = ggml_new_tensor_1d(ctx, type_k, n_embd_k_gqa*kv_size);
@@ -110,55 +115,28 @@ bool llama_kv_cache_unified::init(
110
 
111
  ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft);
112
  if (!buf) {
113
- LLAMA_LOG_ERROR("%s: failed to allocate buffer for kv cache\n", __func__);
114
- return false;
115
  }
116
  ggml_backend_buffer_clear(buf, 0);
117
  LLAMA_LOG_INFO("%s: %10s KV buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf), ggml_backend_buffer_get_size(buf)/1024.0/1024.0);
118
  bufs.emplace_back(buf);
119
  }
120
 
121
- return true;
122
- }
123
-
124
- int32_t llama_kv_cache_unified::get_n_tokens() const {
125
- int32_t result = 0;
126
-
127
- for (uint32_t i = 0; i < size; i++) {
128
- result += cells[i].seq_id.size();
129
- }
130
-
131
- return result;
132
- }
133
-
134
- int32_t llama_kv_cache_unified::get_used_cells() const {
135
- return used;
136
- }
137
-
138
- size_t llama_kv_cache_unified::total_size() const {
139
- size_t size = 0;
140
- for (const auto & buf : bufs) {
141
- size += ggml_backend_buffer_get_size(buf.get());
142
- }
143
-
144
- return size;
145
- }
146
 
147
- llama_pos llama_kv_cache_unified::pos_max() const {
148
- llama_pos pos_max = -1;
149
- for (const auto & cell : cells) {
150
- pos_max = std::max(pos_max, cell.pos);
151
  }
152
-
153
- return pos_max;
154
  }
155
 
156
  void llama_kv_cache_unified::clear() {
157
  for (int32_t i = 0; i < (int32_t) size; ++i) {
158
  cells[i].pos = -1;
159
  cells[i].seq_id.clear();
160
- cells[i].src = -1;
161
- cells[i].tail = -1;
162
  }
163
  head = 0;
164
  used = 0;
@@ -179,35 +157,6 @@ bool llama_kv_cache_unified::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos
179
  p1 = std::numeric_limits<llama_pos>::max();
180
  }
181
 
182
- // models like Mamba or RWKV can't have a state partially erased
183
- if (recurrent) {
184
- if (seq_id >= (int64_t) size) {
185
- // could be fatal
186
- return false;
187
- }
188
- if (0 <= seq_id) {
189
- int32_t & tail_id = cells[seq_id].tail;
190
- if (tail_id >= 0) {
191
- const llama_kv_cell & cell = cells[tail_id];
192
- // partial intersection is invalid
193
- if ((0 < p0 && p0 <= cell.pos) || (0 < p1 && p1 <= cell.pos)) {
194
- return false;
195
- }
196
- // invalidate tails which will be cleared
197
- if (p0 <= cell.pos && cell.pos < p1) {
198
- tail_id = -1;
199
- }
200
- }
201
- } else {
202
- // seq_id is negative, then the range should include everything or nothing
203
- if (p0 != p1 && (p0 != 0 || p1 != std::numeric_limits<llama_pos>::max())) {
204
- return false;
205
- }
206
- }
207
-
208
- return true;
209
- }
210
-
211
  for (uint32_t i = 0; i < size; ++i) {
212
  if (cells[i].pos >= p0 && cells[i].pos < p1) {
213
  if (seq_id < 0) {
@@ -224,7 +173,6 @@ bool llama_kv_cache_unified::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos
224
  }
225
 
226
  cells[i].pos = -1;
227
- cells[i].src = -1;
228
 
229
  if (new_head == size) {
230
  new_head = i;
@@ -254,34 +202,6 @@ void llama_kv_cache_unified::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id
254
  p1 = std::numeric_limits<llama_pos>::max();
255
  }
256
 
257
- if (recurrent) {
258
- if ((uint32_t) seq_id_dst < size && (uint32_t) seq_id_src < size) {
259
- llama_kv_cell & tail_src = cells[seq_id_src];
260
- llama_kv_cell & tail_dst = cells[seq_id_dst];
261
- if (tail_dst.tail >= 0) {
262
- // clear destination seq_id if it wasn't empty
263
- llama_kv_cell & cell_dst = cells[tail_dst.tail];
264
-
265
- cell_dst.seq_id.erase(seq_id_dst);
266
- tail_dst.tail = -1;
267
- if (cell_dst.seq_id.empty()) {
268
- cell_dst.pos = -1;
269
- cell_dst.delta = -1;
270
- cell_dst.src = -1;
271
- used -= 1;
272
- }
273
- }
274
- if (tail_src.tail >= 0) {
275
- llama_kv_cell & cell_src = cells[tail_src.tail];
276
-
277
- cell_src.seq_id.insert(seq_id_dst);
278
- tail_dst.tail = tail_src.tail;
279
- }
280
- }
281
-
282
- return;
283
- }
284
-
285
  // otherwise, this is the KV of a Transformer-like model
286
  head = 0;
287
 
@@ -296,17 +216,12 @@ void llama_kv_cache_unified::seq_keep(llama_seq_id seq_id) {
296
  uint32_t new_head = size;
297
 
298
  for (uint32_t i = 0; i < size; ++i) {
299
- if (recurrent && (llama_seq_id) i != seq_id) {
300
- cells[i].tail = -1;
301
- }
302
-
303
  if (!cells[i].has_seq_id(seq_id)) {
304
  if (cells[i].pos >= 0) {
305
  used--;
306
  }
307
 
308
  cells[i].pos = -1;
309
- cells[i].src = -1;
310
  cells[i].seq_id.clear();
311
 
312
  if (new_head == size){
@@ -344,20 +259,6 @@ void llama_kv_cache_unified::seq_add(llama_seq_id seq_id, llama_pos p0, llama_po
344
  return;
345
  }
346
 
347
- if (recurrent) {
348
- // for Mamba-like or RWKV models, only the pos needs to be shifted
349
- if (0 <= seq_id && seq_id < (int64_t) size) {
350
- const int32_t tail_id = cells[seq_id].tail;
351
- if (tail_id >= 0) {
352
- llama_kv_cell & cell = cells[tail_id];
353
- if (cell.has_seq_id(seq_id) && p0 <= cell.pos && cell.pos < p1) {
354
- cell.pos += delta;
355
- }
356
- }
357
- }
358
- return;
359
- }
360
-
361
  for (uint32_t i = 0; i < size; ++i) {
362
  if (cells[i].has_seq_id(seq_id) && cells[i].pos >= p0 && cells[i].pos < p1) {
363
  has_shift = true;
@@ -400,21 +301,6 @@ void llama_kv_cache_unified::seq_div(llama_seq_id seq_id, llama_pos p0, llama_po
400
  return;
401
  }
402
 
403
- if (recurrent) {
404
- // for Mamba-like or RWKV models, only the pos needs to be changed
405
- if (0 <= seq_id && seq_id < (int64_t) size) {
406
- const int32_t tail_id = cells[seq_id].tail;
407
- if (tail_id >= 0) {
408
- llama_kv_cell & cell = cells[tail_id];
409
- if (cell.has_seq_id(seq_id) && p0 <= cell.pos && cell.pos < p1) {
410
- cell.pos /= d;
411
- }
412
- }
413
- }
414
-
415
- return;
416
- }
417
-
418
  for (uint32_t i = 0; i < size; ++i) {
419
  if (cells[i].has_seq_id(seq_id) && cells[i].pos >= p0 && cells[i].pos < p1) {
420
  has_shift = true;
@@ -440,23 +326,11 @@ llama_pos llama_kv_cache_unified::seq_pos_max(llama_seq_id seq_id) const {
440
  return result;
441
  }
442
 
443
- void llama_kv_cache_unified::defrag() {
444
- if (!recurrent) {
445
- do_defrag = true;
446
- }
447
- }
448
-
449
  void llama_kv_cache_unified::restore() {
450
  if (pending.ranges.empty()) {
451
  return;
452
  }
453
 
454
- // TODO: tmp - move to llama_kv_cache_recurrent
455
- if (recurrent) {
456
- seq_rm(-1, -1, -1);
457
- return;
458
- }
459
-
460
  uint32_t new_head = size;
461
 
462
  for (auto & range : pending.ranges) {
@@ -469,7 +343,6 @@ void llama_kv_cache_unified::restore() {
469
  }
470
 
471
  cells[i].pos = -1;
472
- cells[i].src = -1;
473
  }
474
 
475
  new_head = std::min(new_head, range.c0);
@@ -481,11 +354,6 @@ void llama_kv_cache_unified::restore() {
481
  }
482
 
483
  void llama_kv_cache_unified::commit() {
484
- // TODO: tmp - move to llama_kv_cache_recurrent
485
- if (recurrent) {
486
- return;
487
- }
488
-
489
  if (pending.ranges.empty()) {
490
  LLAMA_LOG_WARN("%s: no pending KV cache updates to commit - might indicate a bug (ref: %s)\n",
491
  __func__, "https://github.com/ggml-org/llama.cpp/pull/12695");
@@ -495,183 +363,110 @@ void llama_kv_cache_unified::commit() {
495
  pending.ranges.clear();
496
  }
497
 
498
- bool llama_kv_cache_unified::get_can_shift() const {
499
- return can_shift;
500
- }
501
 
502
- bool llama_kv_cache_unified::find_slot(
503
- const llama_ubatch & ubatch) {
504
- const uint32_t n_tokens = ubatch.n_tokens;
505
- const uint32_t n_seqs = ubatch.n_seqs;
506
- const uint32_t n_seq_tokens = ubatch.n_seq_tokens;
507
 
508
- // if we have enough unused cells before the current head ->
509
- // better to start searching from the beginning of the cache, hoping to fill it
510
- if (head > used + 2*ubatch.n_tokens) {
511
- head = 0;
512
- }
513
 
514
- if (recurrent) {
515
- // For recurrent state architectures (like Mamba or RWKV),
516
- // each cache cell can store the state for a whole sequence.
517
- // A slot should be always be contiguous.
518
 
519
- // can only process batches with an equal number of new tokens in each sequence
520
- GGML_ASSERT(ubatch.equal_seqs);
 
521
 
522
- int32_t min = size - 1;
523
- int32_t max = 0;
524
 
525
- // everything should fit if all seq_ids are smaller than the max
526
- for (uint32_t s = 0; s < n_seqs; ++s) {
527
- const uint32_t n_seq_id = ubatch.n_seq_id[s];
528
- for (uint32_t j = 0; j < n_seq_id; ++j) {
529
- const llama_seq_id seq_id = ubatch.seq_id[s][j];
530
 
531
- if (seq_id < 0 || (uint32_t) seq_id >= size) {
532
- // too big seq_id
533
- // TODO: would it be possible to resize the cache instead?
534
- LLAMA_LOG_ERROR("%s: seq_id=%d >= n_seq_max=%d Try using a bigger --parallel value\n", __func__, seq_id, size);
535
- return false;
536
- }
537
- if (j > 0) {
538
- llama_kv_cell & seq = cells[seq_id];
539
- if (seq.tail >= 0) {
540
- llama_kv_cell & cell = cells[seq.tail];
541
- // clear cells from seq_ids that become shared
542
- // (should not normally happen, but let's handle it anyway)
543
- cell.seq_id.erase(seq_id);
544
- seq.tail = -1;
545
- if (cell.seq_id.empty()) {
546
- cell.pos = -1;
547
- cell.src = -1;
548
- used -= 1;
549
- }
550
- }
551
- }
552
- }
553
  }
554
 
555
- #ifndef NDEBUG
556
  {
557
- std::vector<int32_t> tails_verif;
558
- tails_verif.assign(size, -1);
559
- for (uint32_t i = 0; i < size; ++i) {
560
- llama_kv_cell & cell = cells[i];
561
- for (llama_seq_id seq_id : cell.seq_id) {
562
- if (tails_verif[seq_id] != -1) {
563
- LLAMA_LOG_ERROR("%s: duplicate tail for seq_id %d in cell %d and %d\n", __func__, seq_id, i, tails_verif[seq_id]);
564
- }
565
- tails_verif[seq_id] = i;
566
- }
567
- }
568
  for (uint32_t i = 0; i < size; ++i) {
569
- if (tails_verif[i] != cells[i].tail) {
570
- LLAMA_LOG_ERROR("%s: wrong tail for seq_id %d, (%d instead of %d)\n", __func__, i, cells[i].tail, tails_verif[i]);
571
- }
572
  }
573
  }
574
- #endif
575
 
576
- // find next empty cell
577
- uint32_t next_empty_cell = head;
578
 
579
- for (uint32_t i = 0; i < size; ++i) {
580
- if (next_empty_cell >= size) { next_empty_cell -= size; }
581
- llama_kv_cell & cell = cells[next_empty_cell];
582
- if (cell.is_empty()) { break; }
583
- next_empty_cell += 1;
584
- }
585
 
586
- // find usable cell range
587
- for (uint32_t s = 0; s < n_seqs; ++s) {
588
- const llama_seq_id seq_id = ubatch.seq_id[s][0];
589
- llama_kv_cell & seq_meta = cells[seq_id];
590
- bool has_cell = false;
591
- if (seq_meta.tail >= 0) {
592
- llama_kv_cell & cell = cells[seq_meta.tail];
593
- GGML_ASSERT(cell.has_seq_id(seq_id));
594
- // does this seq_id "own" the cell?
595
- if (cell.seq_id.size() == 1) { has_cell = true; }
596
- }
597
- if (!has_cell) {
598
- llama_kv_cell & empty_cell = cells[next_empty_cell];
599
- GGML_ASSERT(empty_cell.is_empty());
600
- // copy old tail into the empty cell
601
- if (seq_meta.tail >= 0) {
602
- llama_kv_cell & orig_cell = cells[seq_meta.tail];
603
- empty_cell.pos = orig_cell.pos;
604
- empty_cell.src = orig_cell.src;
605
- orig_cell.seq_id.erase(seq_id);
606
- empty_cell.seq_id.insert(seq_id); // will be overwritten
607
- }
608
- seq_meta.tail = next_empty_cell;
609
- // find next empty cell
610
- if (s + 1 < n_seqs) {
611
- next_empty_cell += 1;
612
- for (uint32_t i = 0; i < size; ++i) {
613
- if (next_empty_cell >= size) { next_empty_cell -= size; }
614
- llama_kv_cell & cell = cells[next_empty_cell];
615
- if (cell.is_empty()) { break; }
616
- next_empty_cell += 1;
617
- }
618
- }
619
- }
620
- if (min > seq_meta.tail) { min = seq_meta.tail; }
621
- if (max < seq_meta.tail) { max = seq_meta.tail; }
622
- }
623
 
624
- // gather and re-order
625
- for (uint32_t s = 0; s < n_seqs; ++s) {
626
- int32_t dst_id = s + min;
627
- int32_t src_id = cells[ubatch.seq_id[s][0]].tail;
628
- if (dst_id != src_id) {
629
- llama_kv_cell & dst_cell = cells[dst_id];
630
- llama_kv_cell & src_cell = cells[src_id];
631
 
632
- std::swap(dst_cell.pos, src_cell.pos);
633
- std::swap(dst_cell.src, src_cell.src);
634
- std::swap(dst_cell.seq_id, src_cell.seq_id);
635
 
636
- // swap tails (assuming they NEVER overlap)
637
- for (const llama_seq_id seq_id : src_cell.seq_id) {
638
- cells[seq_id].tail = src_id;
639
- }
640
- for (const llama_seq_id seq_id : dst_cell.seq_id) {
641
- cells[seq_id].tail = dst_id;
642
- }
643
- }
644
- }
645
 
646
- // update the pos of the used seqs
647
- for (uint32_t s = 0; s < n_seqs; ++s) {
648
- const llama_pos last_pos = ubatch.pos[n_seq_tokens * s + n_seq_tokens - 1];
649
- int32_t cell_id = s + min;
650
- llama_kv_cell & cell = cells[cell_id];
651
 
652
- if (cell.pos >= 0 && last_pos != cell.pos + (llama_pos) n_seq_tokens) {
653
- // What should happen when the pos backtracks or skips a value?
654
- // Clearing the state mid-batch would require special-casing which isn't done.
655
- LLAMA_LOG_WARN("%s: non-consecutive token position %d after %d for sequence %d with %u new tokens\n",
656
- __func__, last_pos, cell.pos, ubatch.seq_id[s][0], n_seq_tokens);
657
- }
658
- cell.pos = last_pos;
659
- cell.seq_id.clear();
660
- for (int32_t j = 0; j < ubatch.n_seq_id[s]; ++j) {
661
- const llama_seq_id seq_id = ubatch.seq_id[s][j];
662
- cell.seq_id.insert(seq_id);
663
- cells[seq_id].tail = cell_id;
664
- }
665
  }
666
 
667
- // allow getting the range of used cells, from head to head + n
668
- head = min;
669
- n = max - min + 1;
670
- used = std::count_if(cells.begin(), cells.end(),
671
- [](const llama_kv_cell& cell){ return !cell.is_empty(); });
 
 
 
 
 
 
 
 
672
 
673
- // sanity check
674
- return n >= n_seqs;
 
 
675
  }
676
 
677
  // otherwise, one cell per token.
@@ -725,24 +520,50 @@ bool llama_kv_cache_unified::find_slot(
725
 
726
  pending.ranges.push_back({head, head + n_tokens});
727
 
 
 
 
 
 
 
 
728
  return true;
729
  }
730
 
731
- uint32_t llama_kv_cache_unified::get_padding(const llama_cparams & cparams) const {
732
- // the FA kernels require padding to avoid extra runtime boundary checks
733
- return cparams.flash_attn ? 256u : 32u;
 
 
 
 
 
734
  }
735
 
736
- uint32_t llama_kv_cache_unified::cell_max() const {
737
- for (uint32_t i = size; i > 0; --i) {
738
- const llama_kv_cell & cell = cells[i - 1];
739
 
740
- if (cell.pos >= 0 && !cell.is_empty()) {
741
- return i;
742
- }
 
 
 
 
 
743
  }
744
 
745
- return 0;
 
 
 
 
 
 
 
 
 
746
  }
747
 
748
  size_t llama_kv_cache_unified::size_k_bytes() const {
@@ -765,68 +586,331 @@ size_t llama_kv_cache_unified::size_v_bytes() const {
765
  return size_v_bytes;
766
  }
767
 
768
- bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) {
769
- const uint32_t n_layer = hparams.n_layer;
 
 
 
 
 
 
 
770
 
771
- const uint32_t n_kv = cell_max();
772
- const uint32_t n_used = used;
 
773
 
774
- assert(n_used <= n_kv);
 
775
 
776
- //const int64_t t_start = ggml_time_us();
 
 
777
 
778
- // number of cells moved
779
- uint32_t n_moves = 0;
780
 
781
- // each move requires 6*n_layer tensors (see graph_build_kv_self_defrag)
782
- // - source view, destination view, copy operation
783
- // - x2 for keys and values
784
- //const uint32_t max_moves = max_nodes()/(6*n_layer);
785
- // TODO: tmp fix https://github.com/ggerganov/llama.cpp/issues/6685#issuecomment-2057579516
786
- const uint32_t max_moves = (n_max_nodes - 2*n_layer)/(6*n_layer);
787
 
788
- // determine which KV cells to move where
789
- //
790
- // cell i moves to ids[i]
791
- //
792
- // if ids[i] == i || ids[i] == n_kv, then cell i is not moved
793
- //
794
- auto & ids = defrag_info.ids;
795
 
796
- ids.clear();
797
- ids.resize(n_kv, n_kv);
 
 
 
 
 
798
 
799
- for (uint32_t i0 = 0; i0 < n_used; ++i0) {
800
- const auto & cell0 = cells[i0];
801
 
802
- if (!cell0.is_empty()) {
803
- ids[i0] = i0;
 
 
804
 
805
- continue;
806
- }
807
 
808
- // found a hole - fill it with data from the end of the cache
809
 
810
- uint32_t nh = 1;
 
811
 
812
- // determine the size of the hole
813
- while (i0 + nh < n_used && cells[i0 + nh].is_empty()) {
814
- nh++;
 
 
 
 
 
 
 
815
  }
 
 
816
 
817
- uint32_t nf = 0;
818
- uint32_t is = n_kv - 1;
 
 
 
819
 
820
- // starting from the end, find nh non-empty cells
821
- for (; is > i0; --is) {
822
- const auto & cell1 = cells[is];
823
 
824
- if (cell1.is_empty() || ids[is] != n_kv) {
825
- continue;
826
- }
827
 
828
- // non-empty cell which is not yet moved
829
- nf++;
 
 
 
 
 
 
 
 
 
 
 
 
 
830
 
831
  if (nf == nh) {
832
  break;
@@ -867,7 +951,7 @@ bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) {
867
  cells[i0 + nf] = cell1;
868
 
869
  // clear the old cell and move the head there
870
- cell1 = llama_kv_cell();
871
  head = n_used;
872
 
873
  if (!cont) {
@@ -895,13 +979,25 @@ bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) {
895
  return false;
896
  }
897
 
898
- LLAMA_LOG_DEBUG("(tmp log) KV defrag cell moves: %u\n", n_moves);
899
 
900
- LLAMA_LOG_DEBUG("expected gf nodes: %u\n", 6*n_moves*n_layer);
901
 
902
  return true;
903
  }
904
 
 
 
 
 
 
 
 
 
 
 
 
 
905
  void llama_kv_cache_unified::state_write(llama_io_write_i & io, llama_seq_id seq_id) const {
906
  std::vector<std::pair<uint32_t, uint32_t>> cell_ranges; // ranges, from inclusive, to exclusive
907
  uint32_t cell_count = 0;
@@ -1110,7 +1206,7 @@ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell
1110
  clear();
1111
 
1112
  for (uint32_t i = 0; i < cell_count; ++i) {
1113
- llama_kv_cell & cell = cells[i];
1114
 
1115
  llama_pos pos;
1116
  uint32_t n_seq_id;
@@ -1133,15 +1229,6 @@ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell
1133
  }
1134
 
1135
  cell.seq_id.insert(seq_id);
1136
-
1137
- if (recurrent) {
1138
- int32_t & tail = cells[seq_id].tail;
1139
- if (tail != -1) {
1140
- LLAMA_LOG_ERROR("%s: duplicate tail for seq_id %d in cell %d and %d\n", __func__, seq_id, i, tail);
1141
- return false;
1142
- }
1143
- tail = i;
1144
- }
1145
  }
1146
  }
1147
 
@@ -1149,14 +1236,6 @@ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell
1149
  used = cell_count;
1150
  }
1151
 
1152
- if (recurrent) {
1153
- for (uint32_t i = 0; i < cell_count; ++i) {
1154
- uint32_t cell_id = head + i;
1155
- // make sure the recurrent states will keep their restored state
1156
- cells[cell_id].src = cell_id;
1157
- }
1158
- }
1159
-
1160
  return true;
1161
  }
1162
 
@@ -1174,7 +1253,1034 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell
1174
  LLAMA_LOG_ERROR("%s: not enough cells in kv cache to restore state (%u > %u)\n", __func__, cell_count, size);
1175
  return false;
1176
  }
1177
- if (v_trans != (bool) v_trans) {
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1178
  LLAMA_LOG_ERROR("%s: incompatible V transposition\n", __func__);
1179
  return false;
1180
  }
@@ -1326,7 +2432,7 @@ void llama_kv_cache_view_update(llama_kv_cache_view * view, const llama_kv_cache
1326
  view->cells_sequences = (llama_seq_id *)p;
1327
  }
1328
 
1329
- const std::vector<llama_kv_cell> & kv_cells = kvu->cells;
1330
  llama_kv_cache_view_cell * c_curr = view->cells;
1331
  llama_seq_id * cs_curr = view->cells_sequences;
1332
  int32_t used_cells = 0;
 
4
  #include "llama-batch.h"
5
  #include "llama-cparams.h"
6
  #include "llama-model.h"
7
+ #include "llama-context.h"
8
 
9
  #include <algorithm>
10
  #include <cassert>
11
+ #include <cmath>
12
  #include <limits>
13
  #include <map>
14
  #include <stdexcept>
15
 
16
+ //
17
+ // llama_kv_cache_unified
18
+ //
19
+
20
+ uint32_t llama_kv_cache_unified::get_padding(const llama_cparams & cparams) {
21
+ // the FA kernels require padding to avoid extra runtime boundary checks
22
+ return cparams.flash_attn ? 256u : 32u;
23
  }
24
 
25
+ llama_kv_cache_unified::llama_kv_cache_unified(
26
  const llama_model & model,
 
27
  ggml_type type_k,
28
  ggml_type type_v,
29
+ bool v_trans,
30
+ bool offload,
31
  uint32_t kv_size,
32
+ uint32_t padding) : model(model), hparams(model.hparams), v_trans(v_trans), padding(padding) {
33
  const int32_t n_layer = hparams.n_layer;
34
 
35
  has_shift = false;
36
+ can_shift = true;
37
 
38
+ LLAMA_LOG_INFO("%s: kv_size = %d, type_k = '%s', type_v = '%s', n_layer = %d, can_shift = %d, padding = %d\n",
39
+ __func__, kv_size, ggml_type_name(type_k), ggml_type_name(type_v), n_layer, can_shift, padding);
 
40
 
41
+ GGML_ASSERT(kv_size % padding == 0 && "kv_size must be a multiple of padding");
 
42
 
43
  head = 0;
44
  size = kv_size;
 
84
 
85
  const char * dev_name = "CPU";
86
 
87
+ ggml_backend_buffer_type_t buft = ggml_backend_cpu_buffer_type();
88
+
89
  if (offload) {
90
  auto * dev = model.dev_layer(i);
91
  buft = ggml_backend_dev_buffer_type(dev);
92
 
93
  dev_name = ggml_backend_dev_name(dev);
 
 
94
  }
95
 
96
+ LLAMA_LOG_DEBUG("%s: layer %3d: dev = %s\n", __func__, i, dev_name);
 
97
 
98
  ggml_context * ctx = ctx_for_buft(buft);
99
  if (!ctx) {
100
+ throw std::runtime_error("failed to create ggml context for kv cache");
 
101
  }
102
 
103
  ggml_tensor * k = ggml_new_tensor_1d(ctx, type_k, n_embd_k_gqa*kv_size);
 
115
 
116
  ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft);
117
  if (!buf) {
118
+ throw std::runtime_error("failed to allocate buffer for kv cache");
 
119
  }
120
  ggml_backend_buffer_clear(buf, 0);
121
  LLAMA_LOG_INFO("%s: %10s KV buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf), ggml_backend_buffer_get_size(buf)/1024.0/1024.0);
122
  bufs.emplace_back(buf);
123
  }
124
 
125
+ {
126
+ const size_t memory_size_k = size_k_bytes();
127
+ const size_t memory_size_v = size_v_bytes();
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
128
 
129
+ LLAMA_LOG_INFO("%s: KV self size = %7.2f MiB, K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__,
130
+ (float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f),
131
+ ggml_type_name(type_k), (float)memory_size_k / (1024.0f * 1024.0f),
132
+ ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f));
133
  }
 
 
134
  }
135
 
136
  void llama_kv_cache_unified::clear() {
137
  for (int32_t i = 0; i < (int32_t) size; ++i) {
138
  cells[i].pos = -1;
139
  cells[i].seq_id.clear();
 
 
140
  }
141
  head = 0;
142
  used = 0;
 
157
  p1 = std::numeric_limits<llama_pos>::max();
158
  }
159
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
160
  for (uint32_t i = 0; i < size; ++i) {
161
  if (cells[i].pos >= p0 && cells[i].pos < p1) {
162
  if (seq_id < 0) {
 
173
  }
174
 
175
  cells[i].pos = -1;
 
176
 
177
  if (new_head == size) {
178
  new_head = i;
 
202
  p1 = std::numeric_limits<llama_pos>::max();
203
  }
204
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
205
  // otherwise, this is the KV of a Transformer-like model
206
  head = 0;
207
 
 
216
  uint32_t new_head = size;
217
 
218
  for (uint32_t i = 0; i < size; ++i) {
 
 
 
 
219
  if (!cells[i].has_seq_id(seq_id)) {
220
  if (cells[i].pos >= 0) {
221
  used--;
222
  }
223
 
224
  cells[i].pos = -1;
 
225
  cells[i].seq_id.clear();
226
 
227
  if (new_head == size){
 
259
  return;
260
  }
261
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
262
  for (uint32_t i = 0; i < size; ++i) {
263
  if (cells[i].has_seq_id(seq_id) && cells[i].pos >= p0 && cells[i].pos < p1) {
264
  has_shift = true;
 
301
  return;
302
  }
303
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
304
  for (uint32_t i = 0; i < size; ++i) {
305
  if (cells[i].has_seq_id(seq_id) && cells[i].pos >= p0 && cells[i].pos < p1) {
306
  has_shift = true;
 
326
  return result;
327
  }
328
 
 
 
 
 
 
 
329
  void llama_kv_cache_unified::restore() {
330
  if (pending.ranges.empty()) {
331
  return;
332
  }
333
 
 
 
 
 
 
 
334
  uint32_t new_head = size;
335
 
336
  for (auto & range : pending.ranges) {
 
343
  }
344
 
345
  cells[i].pos = -1;
 
346
  }
347
 
348
  new_head = std::min(new_head, range.c0);
 
354
  }
355
 
356
  void llama_kv_cache_unified::commit() {
 
 
 
 
 
357
  if (pending.ranges.empty()) {
358
  LLAMA_LOG_WARN("%s: no pending KV cache updates to commit - might indicate a bug (ref: %s)\n",
359
  __func__, "https://github.com/ggml-org/llama.cpp/pull/12695");
 
363
  pending.ranges.clear();
364
  }
365
 
366
+ bool llama_kv_cache_unified::update(llama_context & lctx) {
367
+ bool need_reserve = false;
 
368
 
369
+ auto * sched = lctx.get_sched();
 
 
 
 
370
 
371
+ if (has_shift) {
372
+ if (!get_can_shift()) {
373
+ GGML_ABORT("The current KV cache / model configuration does not support K-shift");
374
+ }
 
375
 
376
+ LLAMA_LOG_DEBUG("%s: applying K-shift\n", __func__);
 
 
 
377
 
378
+ // apply K-shift if needed
379
+ if (hparams.rope_type != LLAMA_ROPE_TYPE_NONE) {
380
+ ggml_backend_sched_reset(sched);
381
 
382
+ auto * gf = lctx.graph_init();
 
383
 
384
+ auto res = build_graph_shift(lctx.get_cparams(), lctx.get_ctx_compute(), gf);
 
 
 
 
385
 
386
+ ggml_backend_sched_alloc_graph(sched, gf);
387
+
388
+ res->set_inputs(nullptr);
389
+
390
+ lctx.graph_compute(gf, false);
391
+
392
+ need_reserve = true;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
393
  }
394
 
 
395
  {
396
+ has_shift = false;
397
+
 
 
 
 
 
 
 
 
 
398
  for (uint32_t i = 0; i < size; ++i) {
399
+ cells[i].delta = 0;
 
 
400
  }
401
  }
402
+ }
403
 
404
+ if (do_defrag) {
405
+ LLAMA_LOG_DEBUG("%s: defragmenting KV cache\n", __func__);
406
 
407
+ if (defrag_prepare(lctx.graph_max_nodes())) {
408
+ ggml_backend_sched_reset(sched);
 
 
 
 
409
 
410
+ auto * gf = lctx.graph_init();
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
411
 
412
+ auto res = build_graph_defrag(lctx.get_cparams(), lctx.get_ctx_compute(), gf);
 
 
 
 
 
 
413
 
414
+ ggml_backend_sched_alloc_graph(sched, gf);
 
 
415
 
416
+ res->set_inputs(nullptr);
 
 
 
 
 
 
 
 
417
 
418
+ lctx.graph_compute(gf, false);
 
 
 
 
419
 
420
+ need_reserve = true;
 
 
 
 
 
 
 
 
 
 
 
 
421
  }
422
 
423
+ do_defrag = false;
424
+ }
425
+
426
+ return need_reserve;
427
+ }
428
+
429
+ void llama_kv_cache_unified::defrag_sched(float thold) {
430
+ // - do not defrag small contexts (i.e. < 2048 tokens)
431
+ // - count the padding towards the number of used tokens
432
+ const float fragmentation = n >= 2048 ? std::max(0.0f, 1.0f - (float(used + padding)/n)) : 0.0f;
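    // worked example (illustrative numbers): with n = 4096 cells in the active window,
    // used = 3000 and padding = 32, fragmentation = 1 - (3000 + 32)/4096 ≈ 0.26, so a
    // threshold of 0.1 would queue a defrag; below 2048 cells this always evaluates to 0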
433
+
434
+ // queue defragmentation for next llama_kv_cache_update
435
+ if (fragmentation > thold) {
436
+ LLAMA_LOG_DEBUG("%s: fragmentation: %.2f - requesting defrag\n", __func__, fragmentation);
437
+
438
+ do_defrag = true;
439
+ }
440
+ }
441
+
442
+ void llama_kv_cache_unified::set_full() {
443
+ n = size;
444
+ }
445
+
446
+ llama_sbatch llama_kv_cache_unified::sbatch_init(
447
+ const llama_batch & batch,
448
+ bool logits_all) {
449
+ return llama_sbatch(batch, hparams.n_embd, true, logits_all);
450
+ }
451
+
452
+ llama_ubatch llama_kv_cache_unified::ubatch_next(
453
+ llama_sbatch & sbatch,
454
+ uint32_t n_ubatch,
455
+ bool embd_pooled) const {
456
+ GGML_UNUSED(embd_pooled);
457
+ return sbatch.split_simple(n_ubatch);
458
+ }
459
+
460
+ bool llama_kv_cache_unified::find_slot(
461
+ const llama_ubatch & ubatch) {
462
+ const uint32_t n_tokens = ubatch.n_tokens;
463
+ const uint32_t n_seqs = ubatch.n_seqs;
464
+ const uint32_t n_seq_tokens = ubatch.n_seq_tokens;
465
 
466
+ // if we have enough unused cells before the current head ->
467
+ // better to start searching from the beginning of the cache, hoping to fill it
468
+ if (head > used + 2*ubatch.n_tokens) {
469
+ head = 0;
470
  }
471
 
472
  // otherwise, one cell per token.
 
520
 
521
  pending.ranges.push_back({head, head + n_tokens});
522
 
523
+ // a heuristic, to avoid attending the full cache if it is not yet utilized
524
+ // after enough generations, the benefit from this heuristic disappears
525
+ // if we start defragmenting the cache, the benefit from this will be more important
526
+ n = std::min(size, std::max(padding, GGML_PAD(cell_max(), padding)));
527
+
528
+ //printf("n = %5d, used = %5d, head = %5d\n", n, used, head);
529
+
530
  return true;
531
  }
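// Worked example of the heuristic above: with padding = 32, size = 4096 and cell_max() = 100,
// GGML_PAD(100, 32) = 128, so n = 128 and attention only covers the first 128 cells;
// as the cache fills up, cell_max() grows and n saturates at size.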
532
 
533
+ int32_t llama_kv_cache_unified::get_n_tokens() const {
534
+ int32_t result = 0;
535
+
536
+ for (uint32_t i = 0; i < size; i++) {
537
+ result += cells[i].seq_id.size();
538
+ }
539
+
540
+ return result;
541
  }
542
 
543
+ int32_t llama_kv_cache_unified::get_used_cells() const {
544
+ return used;
545
+ }
546
 
547
+ bool llama_kv_cache_unified::get_can_shift() const {
548
+ return can_shift;
549
+ }
550
+
551
+ llama_pos llama_kv_cache_unified::get_pos_max() const {
552
+ llama_pos pos_max = -1;
553
+ for (const auto & cell : cells) {
554
+ pos_max = std::max(pos_max, cell.pos);
555
  }
556
 
557
+ return pos_max;
558
+ }
559
+
560
+ size_t llama_kv_cache_unified::total_size() const {
561
+ size_t size = 0;
562
+ for (const auto & buf : bufs) {
563
+ size += ggml_backend_buffer_get_size(buf.get());
564
+ }
565
+
566
+ return size;
567
  }
568
 
569
  size_t llama_kv_cache_unified::size_k_bytes() const {
 
586
  return size_v_bytes;
587
  }
588
 
589
+ ggml_tensor * llama_kv_cache_unified::build_rope_shift(
590
+ const llama_cparams & cparams,
591
+ ggml_context * ctx,
592
+ ggml_tensor * cur,
593
+ ggml_tensor * shift,
594
+ ggml_tensor * factors,
595
+ float freq_base,
596
+ float freq_scale) const {
597
+ const auto & n_ctx_orig = cparams.n_ctx_orig_yarn;
598
 
599
+ const auto & yarn_ext_factor = cparams.yarn_ext_factor;
600
+ const auto & yarn_beta_fast = cparams.yarn_beta_fast;
601
+ const auto & yarn_beta_slow = cparams.yarn_beta_slow;
602
 
603
+ const auto & n_rot = hparams.n_rot;
604
+ const auto & rope_type = hparams.rope_type;
605
 
606
+ // See llm_build_deepseek2() for why attn_factor has to be scaled for YaRN RoPE to work correctly.
607
+ // See https://github.com/ggerganov/llama.cpp/discussions/7416 for detailed explanation.
608
+ const float yarn_attn_factor = model.arch == LLM_ARCH_DEEPSEEK2 ? 1.0f / (1.0f + 0.1f * logf(1.0f / freq_scale)) : cparams.yarn_attn_factor;
609
 
610
+ ggml_tensor * tmp;
 
611
 
612
+ if (ggml_is_quantized(cur->type)) {
613
+ // dequantize to f32 -> RoPE -> quantize back
614
+ tmp = ggml_cast(ctx, cur, GGML_TYPE_F32);
 
 
 
615
 
616
+ tmp = ggml_rope_ext(ctx, tmp,
617
+ shift, factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
618
+ yarn_ext_factor, yarn_attn_factor, yarn_beta_fast, yarn_beta_slow);
 
 
 
 
619
 
620
+ tmp = ggml_cpy(ctx, tmp, cur);
621
+ } else {
622
+ // we rotate only the first n_rot dimensions
623
+ tmp = ggml_rope_ext_inplace(ctx, cur,
624
+ shift, factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
625
+ yarn_ext_factor, yarn_attn_factor, yarn_beta_fast, yarn_beta_slow);
626
+ }
627
 
628
+ return tmp;
629
+ }
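// Worked example of the DeepSeek2 special case above: with freq_scale = 0.25f,
// yarn_attn_factor = 1.0f / (1.0f + 0.1f * logf(4.0f)) = 1.0f / 1.1386f ≈ 0.878f,
// replacing cparams.yarn_attn_factor; every other architecture uses the configured value as-is.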
630
 
631
+ class llm_graph_input_k_shift : public llm_graph_input_i {
632
+ public:
633
+ llm_graph_input_k_shift(const llama_kv_cache_unified * kv_self) : kv_self(kv_self) {}
634
+ virtual ~llm_graph_input_k_shift() = default;
635
 
636
+ void set_input(const llama_ubatch * ubatch) override;
 
637
 
638
+ ggml_tensor * k_shift; // I32 [kv_size]
639
 
640
+ const llama_kv_cache_unified * kv_self;
641
+ };
642
 
643
+ void llm_graph_input_k_shift::set_input(const llama_ubatch * ubatch) {
644
+ GGML_UNUSED(ubatch);
645
+
646
+ if (k_shift) {
647
+ assert(ggml_backend_buffer_is_host(k_shift->buffer));
648
+
649
+ int32_t * data = (int32_t *) k_shift->data;
650
+
651
+ for (uint32_t i = 0; i < kv_self->size; ++i) {
652
+ data[i] = kv_self->cells[i].delta;
653
  }
654
+ }
655
+ }
656
 
657
+ llm_graph_result_ptr llama_kv_cache_unified::build_graph_shift(
658
+ const llama_cparams & cparams,
659
+ ggml_context * ctx,
660
+ ggml_cgraph * gf) const {
661
+ auto res = std::make_unique<llm_graph_result>();
662
 
663
+ const auto & n_layer = hparams.n_layer;
 
 
664
 
665
+ const auto & n_embd_head_k = hparams.n_embd_head_k;
666
+ //const auto & n_embd_head_v = hparams.n_embd_head_v;
 
667
 
668
+ const uint32_t n_ctx_per_seq = cparams.n_ctx / cparams.n_seq_max;
669
+
670
+ //GGML_ASSERT(kv_self->size == n_ctx);
671
+
672
+ auto inp = std::make_unique<llm_graph_input_k_shift>(this);
673
+
674
+ inp->k_shift = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, cparams.n_ctx);
675
+ ggml_set_input(inp->k_shift);
676
+
677
+ for (uint32_t il = 0; il < n_layer; ++il) {
678
+ const int64_t n_head_kv = hparams.n_head_kv(il);
679
+ const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
680
+
681
+ const bool is_swa = hparams.is_swa(il);
682
+
683
+ // note: the swa rope params could become part of the cparams in the future
684
+ // if we decide to make them configurable, like the non-sliding ones
685
+ const float freq_base_l = is_swa ? hparams.rope_freq_base_train_swa : cparams.rope_freq_base;
686
+ const float freq_scale_l = is_swa ? hparams.rope_freq_scale_train_swa : cparams.rope_freq_scale;
687
+
688
+ ggml_tensor * rope_factors = model.get_rope_factors(n_ctx_per_seq, il);
689
+
690
+ ggml_tensor * k =
691
+ ggml_view_3d(ctx, k_l[il],
692
+ n_embd_head_k, n_head_kv, size,
693
+ ggml_row_size(k_l[il]->type, n_embd_head_k),
694
+ ggml_row_size(k_l[il]->type, n_embd_k_gqa),
695
+ 0);
696
+
697
+ ggml_tensor * cur = build_rope_shift(cparams, ctx, k, inp->k_shift, rope_factors, freq_base_l, freq_scale_l);
698
+
699
+ ggml_build_forward_expand(gf, cur);
700
+ }
701
+
702
+ res->add_input(std::move(inp));
703
+
704
+ return res;
705
+ }
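// Data-flow sketch, assuming the usual delta accumulation in the unified cache's seq_add():
// set_input() above copies cells[i].delta into the I32 k_shift tensor, one entry per cell, and
// each layer's K view is re-rotated by that per-cell offset via ggml_rope_ext. For example,
// after shifting a sequence by -5 the affected cells carry delta = -5, so their cached K rows
// end up rotated as if the tokens had been stored 5 positions earlier.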
706
+
707
+ llm_graph_result_ptr llama_kv_cache_unified::build_graph_defrag(
708
+ const llama_cparams & cparams,
709
+ ggml_context * ctx,
710
+ ggml_cgraph * gf) const {
711
+ auto res = std::make_unique<llm_graph_result>();
712
+
713
+ const auto & ids = defrag_info.ids;
714
+
715
+ #if 0
716
+ // CPU defrag
717
+ //
718
+ // TODO: optimizations are possible:
719
+ // - multiple threads
720
+ // - avoid copying to the host memory when already there
721
+ //
722
+ // likely not worth the effort, as we have ggml_graph based defrag
723
+ //
724
+
725
+ const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa();
726
+ const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa();
727
+
728
+ const uint32_t kv_size = size;
729
+
730
+ std::vector<uint8_t> buf_k;
731
+ std::vector<uint8_t> buf_v;
732
+
733
+ for (uint32_t il = 0; il < n_layer; ++il) {
734
+ const size_t k_size_row = ggml_row_size(k_l[il]->type, n_embd_k_gqa);
735
+ const size_t k_size = ggml_row_size(k_l[il]->type, n_embd_k_gqa*kv_size);
736
+
737
+ const size_t v_size_el = ggml_type_size(v_l[il]->type);
738
+ const size_t v_size = ggml_row_size (v_l[il]->type, n_embd_v_gqa*kv_size);
739
+
740
+ buf_k.resize(k_size);
741
+ buf_v.resize(v_size);
742
+
743
+ ggml_backend_tensor_get(k_l[il], buf_k.data(), 0, buf_k.size());
744
+ ggml_backend_tensor_get(v_l[il], buf_v.data(), 0, buf_v.size());
745
+
746
+ // batch move [i, i+nm) to [id, id+nm)
747
+ // note: cells can move only to a lower index
748
+ for (uint32_t i = 0; i < n_kv; ++i) {
749
+ const uint32_t id = ids[i];
750
+
751
+ if (i == id || id == n_kv) {
752
+ continue;
753
+ }
754
+
755
+ uint32_t nm = 1;
756
+
757
+ while (i + nm < n_kv && ids[i + nm] == id + nm) {
758
+ nm++;
759
+ }
760
+
761
+ // move keys
762
+ {
763
+ const int64_t os = i*k_size_row;
764
+ const int64_t od = id*k_size_row;
765
+
766
+ memcpy(buf_k.data() + od, buf_k.data() + os, nm*k_size_row);
767
+ }
768
+
769
+ // move values (note: they are transposed)
770
+ {
771
+ const int64_t os = i;
772
+ const int64_t od = id;
773
+
774
+ for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
775
+ memcpy(buf_v.data() + (od + j*kv_size)*v_size_el, buf_v.data() + (os + j*kv_size)*v_size_el, nm*v_size_el);
776
+ }
777
+ }
778
+
779
+ i += nm - 1;
780
+ }
781
+
782
+ ggml_backend_tensor_set(k_l[il], buf_k.data(), 0, buf_k.size());
783
+ ggml_backend_tensor_set(v_l[il], buf_v.data(), 0, buf_v.size());
784
+ }
785
+ #else
786
+ for (uint32_t i = 0; i < ids.size(); ++i) {
787
+ const uint32_t id = ids[i];
788
+
789
+ if (i == id || id == ids.size()) {
790
+ continue;
791
+ }
792
+
793
+ uint32_t nm = 1;
794
+
795
+ while (i + nm < ids.size() && ids[i + nm] == id + nm) {
796
+ nm++;
797
+ }
798
+
799
+ for (uint32_t il = 0; il < hparams.n_layer; ++il) { // NOLINT
800
+ const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
801
+ const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
802
+
803
+ ggml_tensor * view_k_src = ggml_view_2d(ctx, k_l[il],
804
+ n_embd_k_gqa, nm,
805
+ ggml_row_size(k_l[il]->type, n_embd_k_gqa),
806
+ ggml_row_size(k_l[il]->type, n_embd_k_gqa*i));
807
+
808
+ ggml_tensor * view_k_dst = ggml_view_2d(ctx, k_l[il],
809
+ n_embd_k_gqa, nm,
810
+ ggml_row_size(k_l[il]->type, n_embd_k_gqa),
811
+ ggml_row_size(k_l[il]->type, n_embd_k_gqa*id));
812
+
813
+ ggml_tensor * view_v_src;
814
+ ggml_tensor * view_v_dst;
815
+
816
+ if (cparams.flash_attn) {
817
+ // NOTE: the V cache is not transposed when using flash attention
818
+ view_v_src = ggml_view_2d(ctx, v_l[il],
819
+ n_embd_v_gqa, nm,
820
+ ggml_row_size(v_l[il]->type, n_embd_v_gqa),
821
+ ggml_row_size(v_l[il]->type, n_embd_v_gqa*i));
822
+
823
+ view_v_dst = ggml_view_2d(ctx, v_l[il],
824
+ n_embd_v_gqa, nm,
825
+ ggml_row_size(v_l[il]->type, n_embd_v_gqa),
826
+ ggml_row_size(v_l[il]->type, n_embd_v_gqa*id));
827
+ } else {
828
+ view_v_src = ggml_view_2d(ctx, v_l[il],
829
+ nm, n_embd_v_gqa,
830
+ ggml_row_size(v_l[il]->type, size),
831
+ ggml_row_size(v_l[il]->type, i));
832
+
833
+ view_v_dst = ggml_view_2d(ctx, v_l[il],
834
+ nm, n_embd_v_gqa,
835
+ ggml_row_size(v_l[il]->type, size),
836
+ ggml_row_size(v_l[il]->type, id));
837
+ }
838
+
839
+ ggml_build_forward_expand(gf, ggml_cpy(ctx, view_k_src, view_k_dst));
840
+ ggml_build_forward_expand(gf, ggml_cpy(ctx, view_v_src, view_v_dst));
841
+ }
842
+
843
+ i += nm - 1;
844
+ }
845
+
846
+ //LLAMA_LOG_INFO("gf->n_nodes = %d\n", gf->n_nodes);
847
+ #endif
848
+
849
+ return res;
850
+ }
851
+
852
+ bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) {
853
+ const uint32_t n_layer = hparams.n_layer;
854
+
855
+ const uint32_t n_kv = cell_max();
856
+ const uint32_t n_used = used;
857
+
858
+ assert(n_used <= n_kv);
859
+
860
+ //const int64_t t_start = ggml_time_us();
861
+
862
+ // number of cells moved
863
+ uint32_t n_moves = 0;
864
+
865
+ // each move requires 6*n_layer tensors (see graph_build_kv_self_defrag)
866
+ // - source view, destination view, copy operation
867
+ // - x2 for keys and values
868
+ //const uint32_t max_moves = max_nodes()/(6*n_layer);
869
+ // TODO: tmp fix https://github.com/ggerganov/llama.cpp/issues/6685#issuecomment-2057579516
870
+ const uint32_t max_moves = (n_max_nodes - 2*n_layer)/(6*n_layer);
871
+
872
+ // determine which KV cells to move where
873
+ //
874
+ // cell i moves to ids[i]
875
+ //
876
+ // if ids[i] == i || ids[i] == n_kv, then cell i is not moved
877
+ //
878
+ auto & ids = defrag_info.ids;
879
+
880
+ ids.clear();
881
+ ids.resize(n_kv, n_kv);
882
+
883
+ for (uint32_t i0 = 0; i0 < n_used; ++i0) {
884
+ const auto & cell0 = cells[i0];
885
+
886
+ if (!cell0.is_empty()) {
887
+ ids[i0] = i0;
888
+
889
+ continue;
890
+ }
891
+
892
+ // found a hole - fill it with data from the end of the cache
893
+
894
+ uint32_t nh = 1;
895
+
896
+ // determine the size of the hole
897
+ while (i0 + nh < n_used && cells[i0 + nh].is_empty()) {
898
+ nh++;
899
+ }
900
+
901
+ uint32_t nf = 0;
902
+ uint32_t is = n_kv - 1;
903
+
904
+ // starting from the end, find nh non-empty cells
905
+ for (; is > i0; --is) {
906
+ const auto & cell1 = cells[is];
907
+
908
+ if (cell1.is_empty() || ids[is] != n_kv) {
909
+ continue;
910
+ }
911
+
912
+ // non-empty cell which is not yet moved
913
+ nf++;
914
 
915
  if (nf == nh) {
916
  break;
 
951
  cells[i0 + nf] = cell1;
952
 
953
  // clear the old cell and move the head there
954
+ cell1 = kv_cell();
955
  head = n_used;
956
 
957
  if (!cont) {
 
979
  return false;
980
  }
981
 
982
+ LLAMA_LOG_DEBUG("%s: (tmp log) KV defrag cell moves: %u\n", __func__, n_moves);
983
 
984
+ LLAMA_LOG_DEBUG("%s: expected gf nodes: %u\n", __func__, 6*n_moves*n_layer);
985
 
986
  return true;
987
  }
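// Worked example of the max_moves bound above: with n_max_nodes = 8192 and n_layer = 32,
// max_moves = (8192 - 2*32) / (6*32) = 8128 / 192 = 42 contiguous move groups per defrag pass;
// each move costs 6 nodes per layer (source view, destination view and copy, for both K and V),
// which is exactly the 6*n_moves*n_layer figure printed in the debug log above.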
988
 
989
+ uint32_t llama_kv_cache_unified::cell_max() const {
990
+ for (uint32_t i = size; i > 0; --i) {
991
+ const kv_cell & cell = cells[i - 1];
992
+
993
+ if (cell.pos >= 0 && !cell.is_empty()) {
994
+ return i;
995
+ }
996
+ }
997
+
998
+ return 0;
999
+ }
1000
+
1001
  void llama_kv_cache_unified::state_write(llama_io_write_i & io, llama_seq_id seq_id) const {
1002
  std::vector<std::pair<uint32_t, uint32_t>> cell_ranges; // ranges, from inclusive, to exclusive
1003
  uint32_t cell_count = 0;
 
1206
  clear();
1207
 
1208
  for (uint32_t i = 0; i < cell_count; ++i) {
1209
+ kv_cell & cell = cells[i];
1210
 
1211
  llama_pos pos;
1212
  uint32_t n_seq_id;
 
1229
  }
1230
 
1231
  cell.seq_id.insert(seq_id);
1232
  }
1233
  }
1234
 
 
1236
  used = cell_count;
1237
  }
1238
1239
  return true;
1240
  }
1241
 
 
1253
  LLAMA_LOG_ERROR("%s: not enough cells in kv cache to restore state (%u > %u)\n", __func__, cell_count, size);
1254
  return false;
1255
  }
1256
+ if (this->v_trans != (bool) v_trans) {
1257
+ LLAMA_LOG_ERROR("%s: incompatible V transposition\n", __func__);
1258
+ return false;
1259
+ }
1260
+
1261
+ // For each layer, read the keys for each cell, one row is one cell, read as one contiguous block
1262
+ for (uint32_t il = 0; il < n_layer; ++il) {
1263
+ const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
1264
+
1265
+ // Read type of key
1266
+ int32_t k_type_i_ref;
1267
+ io.read_to(&k_type_i_ref, sizeof(k_type_i_ref));
1268
+ const int32_t k_type_i = (int32_t) k_l[il]->type;
1269
+ if (k_type_i != k_type_i_ref) {
1270
+ LLAMA_LOG_ERROR("%s: mismatched key type (%d != %d, layer %d)\n", __func__, k_type_i, k_type_i_ref, il);
1271
+ return false;
1272
+ }
1273
+
1274
+ // Read row size of key
1275
+ uint64_t k_size_row_ref;
1276
+ io.read_to(&k_size_row_ref, sizeof(k_size_row_ref));
1277
+ const size_t k_size_row = ggml_row_size(k_l[il]->type, n_embd_k_gqa);
1278
+ if (k_size_row != k_size_row_ref) {
1279
+ LLAMA_LOG_ERROR("%s: mismatched key row size (%zu != %zu, layer %d)\n", __func__, k_size_row, (size_t) k_size_row_ref, il);
1280
+ return false;
1281
+ }
1282
+
1283
+ if (cell_count) {
1284
+ // Read and set the keys for the whole cell range
1285
+ ggml_backend_tensor_set(k_l[il], io.read(cell_count * k_size_row), head * k_size_row, cell_count * k_size_row);
1286
+ }
1287
+ }
1288
+
1289
+ if (!this->v_trans) {
1290
+ for (uint32_t il = 0; il < n_layer; ++il) {
1291
+ const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
1292
+
1293
+ // Read type of value
1294
+ int32_t v_type_i_ref;
1295
+ io.read_to(&v_type_i_ref, sizeof(v_type_i_ref));
1296
+ const int32_t v_type_i = (int32_t)v_l[il]->type;
1297
+ if (v_type_i != v_type_i_ref) {
1298
+ LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il);
1299
+ return false;
1300
+ }
1301
+
1302
+ // Read row size of value
1303
+ uint64_t v_size_row_ref;
1304
+ io.read_to(&v_size_row_ref, sizeof(v_size_row_ref));
1305
+ const size_t v_size_row = ggml_row_size(v_l[il]->type, n_embd_v_gqa);
1306
+ if (v_size_row != v_size_row_ref) {
1307
+ LLAMA_LOG_ERROR("%s: mismatched value row size (%zu != %zu, layer %d)\n", __func__, v_size_row, (size_t) v_size_row_ref, il);
1308
+ return false;
1309
+ }
1310
+
1311
+ if (cell_count) {
1312
+ // Read and set the values for the whole cell range
1313
+ ggml_backend_tensor_set(v_l[il], io.read(cell_count * v_size_row), head * v_size_row, cell_count * v_size_row);
1314
+ }
1315
+ }
1316
+ } else {
1317
+ // For each layer, read the values for each cell (transposed)
1318
+ for (uint32_t il = 0; il < n_layer; ++il) {
1319
+ const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
1320
+
1321
+ // Read type of value
1322
+ int32_t v_type_i_ref;
1323
+ io.read_to(&v_type_i_ref, sizeof(v_type_i_ref));
1324
+ const int32_t v_type_i = (int32_t)v_l[il]->type;
1325
+ if (v_type_i != v_type_i_ref) {
1326
+ LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il);
1327
+ return false;
1328
+ }
1329
+
1330
+ // Read element size of value
1331
+ uint32_t v_size_el_ref;
1332
+ io.read_to(&v_size_el_ref, sizeof(v_size_el_ref));
1333
+ const size_t v_size_el = ggml_type_size(v_l[il]->type);
1334
+ if (v_size_el != v_size_el_ref) {
1335
+ LLAMA_LOG_ERROR("%s: mismatched value element size (%zu != %zu, layer %d)\n", __func__, v_size_el, (size_t) v_size_el_ref, il);
1336
+ return false;
1337
+ }
1338
+
1339
+ // Read GQA embedding size
1340
+ uint32_t n_embd_v_gqa_ref;
1341
+ io.read_to(&n_embd_v_gqa_ref, sizeof(n_embd_v_gqa_ref));
1342
+ if (n_embd_v_gqa != n_embd_v_gqa_ref) {
1343
+ LLAMA_LOG_ERROR("%s: mismatched GQA embedding size (%u != %u, layer %d)\n", __func__, n_embd_v_gqa, n_embd_v_gqa_ref, il);
1344
+ return false;
1345
+ }
1346
+
1347
+ if (cell_count) {
1348
+ // For each row in the transposed matrix, read the values for the whole cell range
1349
+ for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
1350
+ const size_t dst_offset = (head + j * size) * v_size_el;
1351
+ ggml_backend_tensor_set(v_l[il], io.read(cell_count * v_size_el), dst_offset, cell_count * v_size_el);
1352
+ }
1353
+ }
1354
+ }
1355
+ }
1356
+
1357
+ return true;
1358
+ }
1359
+
1360
+ //
1361
+ // llama_kv_cache_recurrent
1362
+ //
1363
+
1364
+ llama_kv_cache_recurrent::llama_kv_cache_recurrent(
1365
+ const llama_model & model,
1366
+ ggml_type type_k,
1367
+ ggml_type type_v,
1368
+ bool offload,
1369
+ uint32_t kv_size) : hparams(model.hparams) {
1370
+ const int32_t n_layer = hparams.n_layer;
1371
+
1372
+ LLAMA_LOG_INFO("%s: kv_size = %d, type_k = '%s', type_v = '%s', n_layer = %d\n",
1373
+ __func__, kv_size, ggml_type_name(type_k), ggml_type_name(type_v), n_layer);
1374
+
1375
+ head = 0;
1376
+ size = kv_size;
1377
+ used = 0;
1378
+
1379
+ this->type_k = type_k;
1380
+ this->type_v = type_v;
1381
+
1382
+ cells.clear();
1383
+ cells.resize(kv_size);
1384
+
1385
+ // create a context for each buffer type
1386
+ std::map<ggml_backend_buffer_type_t, ggml_context *> ctx_map;
1387
+ auto ctx_for_buft = [&](ggml_backend_buffer_type_t buft) -> ggml_context * {
1388
+ auto it = ctx_map.find(buft);
1389
+ if (it == ctx_map.end()) {
1390
+ ggml_init_params params = {
1391
+ /*.mem_size =*/ size_t(2u*n_layer*ggml_tensor_overhead()),
1392
+ /*.mem_buffer =*/ NULL,
1393
+ /*.no_alloc =*/ true,
1394
+ };
1395
+
1396
+ ggml_context * ctx = ggml_init(params);
1397
+ if (!ctx) {
1398
+ return nullptr;
1399
+ }
1400
+
1401
+ ctx_map[buft] = ctx;
1402
+ ctxs.emplace_back(ctx);
1403
+
1404
+ return ctx;
1405
+ }
1406
+
1407
+ return it->second;
1408
+ };
1409
+
1410
+ k_l.reserve(n_layer);
1411
+ v_l.reserve(n_layer);
1412
+
1413
+ for (int i = 0; i < n_layer; i++) {
1414
+ const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s();
1415
+ const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s();
1416
+
1417
+ const char * dev_name = "CPU";
1418
+
1419
+ ggml_backend_buffer_type_t buft = ggml_backend_cpu_buffer_type();
1420
+
1421
+ if (offload) {
1422
+ auto * dev = model.dev_layer(i);
1423
+ buft = ggml_backend_dev_buffer_type(dev);
1424
+
1425
+ dev_name = ggml_backend_dev_name(dev);
1426
+ }
1427
+
1428
+ LLAMA_LOG_DEBUG("%s, layer %3d: dev = %s\n", __func__, i, dev_name);
1429
+
1430
+ ggml_context * ctx = ctx_for_buft(buft);
1431
+ if (!ctx) {
1432
+ throw std::runtime_error("failed to create ggml context for kv cache");
1433
+ }
1434
+
1435
+ ggml_tensor * k = ggml_new_tensor_1d(ctx, type_k, n_embd_k_gqa*kv_size);
1436
+ ggml_tensor * v = ggml_new_tensor_1d(ctx, type_v, n_embd_v_gqa*kv_size);
1437
+ ggml_format_name(k, "cache_k_l%d", i);
1438
+ ggml_format_name(v, "cache_v_l%d", i);
1439
+ k_l.push_back(k);
1440
+ v_l.push_back(v);
1441
+ }
1442
+
1443
+ // allocate tensors and initialize the buffers to avoid NaNs in the padding
1444
+ for (auto it : ctx_map) {
1445
+ auto * buft = it.first;
1446
+ auto * ctx = it.second;
1447
+
1448
+ ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft);
1449
+ if (!buf) {
1450
+ throw std::runtime_error("failed to allocate buffer for kv cache");
1451
+ }
1452
+ ggml_backend_buffer_clear(buf, 0);
1453
+ LLAMA_LOG_INFO("%s: %10s KV buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf), ggml_backend_buffer_get_size(buf)/1024.0/1024.0);
1454
+ bufs.emplace_back(buf);
1455
+ }
1456
+
1457
+ {
1458
+ const size_t memory_size_k = size_k_bytes();
1459
+ const size_t memory_size_v = size_v_bytes();
1460
+
1461
+ LLAMA_LOG_INFO("%s: KV self size = %7.2f MiB, K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__,
1462
+ (float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f),
1463
+ ggml_type_name(type_k), (float)memory_size_k / (1024.0f * 1024.0f),
1464
+ ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f));
1465
+ }
1466
+ }
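// Note on sizing, as implied by the code above and the n_seq_max check in find_slot() below:
// in this recurrent cache each cell holds the state of one whole sequence, so kv_size
// effectively bounds the number of parallel sequences rather than a token count; per layer the
// K and V tensors are flat 1-D buffers of n_embd_k_gqa*kv_size and n_embd_v_gqa*kv_size elements.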
1467
+
1468
+ void llama_kv_cache_recurrent::clear() {
1469
+ for (int32_t i = 0; i < (int32_t) size; ++i) {
1470
+ cells[i].pos = -1;
1471
+ cells[i].seq_id.clear();
1472
+ cells[i].src = -1;
1473
+ cells[i].tail = -1;
1474
+ }
1475
+ head = 0;
1476
+ used = 0;
1477
+
1478
+ for (auto & buf : bufs) {
1479
+ ggml_backend_buffer_clear(buf.get(), 0);
1480
+ }
1481
+ }
1482
+
1483
+ bool llama_kv_cache_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
1484
+ uint32_t new_head = size;
1485
+
1486
+ if (p0 < 0) {
1487
+ p0 = 0;
1488
+ }
1489
+
1490
+ if (p1 < 0) {
1491
+ p1 = std::numeric_limits<llama_pos>::max();
1492
+ }
1493
+
1494
+ // models like Mamba or RWKV can't have a state partially erased
1495
+ if (seq_id >= (int64_t) size) {
1496
+ // could be fatal
1497
+ return false;
1498
+ }
1499
+ if (0 <= seq_id) {
1500
+ int32_t & tail_id = cells[seq_id].tail;
1501
+ if (tail_id >= 0) {
1502
+ const kv_cell & cell = cells[tail_id];
1503
+ // partial intersection is invalid
1504
+ if ((0 < p0 && p0 <= cell.pos) || (0 < p1 && p1 <= cell.pos)) {
1505
+ return false;
1506
+ }
1507
+ // invalidate tails which will be cleared
1508
+ if (p0 <= cell.pos && cell.pos < p1) {
1509
+ tail_id = -1;
1510
+ }
1511
+ }
1512
+ } else {
1513
+ // seq_id is negative, then the range should include everything or nothing
1514
+ if (p0 != p1 && (p0 != 0 || p1 != std::numeric_limits<llama_pos>::max())) {
1515
+ return false;
1516
+ }
1517
+ }
1518
+
1519
+ for (uint32_t i = 0; i < size; ++i) {
1520
+ if (cells[i].pos >= p0 && cells[i].pos < p1) {
1521
+ if (seq_id < 0) {
1522
+ cells[i].seq_id.clear();
1523
+ } else if (cells[i].has_seq_id(seq_id)) {
1524
+ cells[i].seq_id.erase(seq_id);
1525
+ } else {
1526
+ continue;
1527
+ }
1528
+ if (cells[i].is_empty()) {
1529
+ // keep count of the number of used cells
1530
+ if (cells[i].pos >= 0) {
1531
+ used--;
1532
+ }
1533
+ cells[i].pos = -1;
1534
+ cells[i].src = -1;
1535
+ if (new_head == size) {
1536
+ new_head = i;
1537
+ }
1538
+ }
1539
+ }
1540
+ }
1541
+
1542
+ // If we freed up a slot, set head to it so searching can start there.
1543
+ if (new_head != size && new_head < head) {
1544
+ head = new_head;
1545
+ }
1546
+
1547
+ return true;
1548
+ }
1549
+
1550
+ void llama_kv_cache_recurrent::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
1551
+ if (seq_id_src == seq_id_dst) {
1552
+ return;
1553
+ }
1554
+
1555
+ if (p0 < 0) {
1556
+ p0 = 0;
1557
+ }
1558
+
1559
+ if (p1 < 0) {
1560
+ p1 = std::numeric_limits<llama_pos>::max();
1561
+ }
1562
+
1563
+ if ((uint32_t) seq_id_dst < size && (uint32_t) seq_id_src < size) {
1564
+ kv_cell & tail_src = cells[seq_id_src];
1565
+ kv_cell & tail_dst = cells[seq_id_dst];
1566
+ if (tail_dst.tail >= 0) {
1567
+ // clear destination seq_id if it wasn't empty
1568
+ kv_cell & cell_dst = cells[tail_dst.tail];
1569
+
1570
+ cell_dst.seq_id.erase(seq_id_dst);
1571
+ tail_dst.tail = -1;
1572
+ if (cell_dst.seq_id.empty()) {
1573
+ cell_dst.pos = -1;
1574
+ cell_dst.src = -1;
1575
+ used -= 1;
1576
+ }
1577
+ }
1578
+ if (tail_src.tail >= 0) {
1579
+ kv_cell & cell_src = cells[tail_src.tail];
1580
+
1581
+ cell_src.seq_id.insert(seq_id_dst);
1582
+ tail_dst.tail = tail_src.tail;
1583
+ }
1584
+ }
1585
+ }
1586
+
1587
+ void llama_kv_cache_recurrent::seq_keep(llama_seq_id seq_id) {
1588
+ uint32_t new_head = size;
1589
+
1590
+ for (uint32_t i = 0; i < size; ++i) {
1591
+ if ((llama_seq_id) i != seq_id) {
1592
+ cells[i].tail = -1;
1593
+ }
1594
+
1595
+ if (!cells[i].has_seq_id(seq_id)) {
1596
+ if (cells[i].pos >= 0) {
1597
+ used--;
1598
+ }
1599
+
1600
+ cells[i].pos = -1;
1601
+ cells[i].src = -1;
1602
+ cells[i].seq_id.clear();
1603
+
1604
+ if (new_head == size){
1605
+ new_head = i;
1606
+ }
1607
+ } else {
1608
+ cells[i].seq_id.clear();
1609
+ cells[i].seq_id.insert(seq_id);
1610
+ }
1611
+ }
1612
+
1613
+ // If we freed up a slot, set head to it so searching can start there.
1614
+ if (new_head != size && new_head < head) {
1615
+ head = new_head;
1616
+ }
1617
+ }
1618
+
1619
+ void llama_kv_cache_recurrent::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) {
1620
+ if (delta == 0) {
1621
+ return;
1622
+ }
1623
+
1624
+ if (p0 < 0) {
1625
+ p0 = 0;
1626
+ }
1627
+
1628
+ if (p1 < 0) {
1629
+ p1 = std::numeric_limits<llama_pos>::max();
1630
+ }
1631
+
1632
+ // If there is no range then return early to avoid looping over the cache.
1633
+ if (p0 == p1) {
1634
+ return;
1635
+ }
1636
+
1637
+ // for Mamba-like or RWKV models, only the pos needs to be shifted
1638
+ if (0 <= seq_id && seq_id < (int64_t) size) {
1639
+ const int32_t tail_id = cells[seq_id].tail;
1640
+ if (tail_id >= 0) {
1641
+ kv_cell & cell = cells[tail_id];
1642
+ if (cell.has_seq_id(seq_id) && p0 <= cell.pos && cell.pos < p1) {
1643
+ cell.pos += delta;
1644
+ }
1645
+ }
1646
+ }
1647
+ }
1648
+
1649
+ void llama_kv_cache_recurrent::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
1650
+ if (d == 1) {
1651
+ return;
1652
+ }
1653
+
1654
+ if (p0 < 0) {
1655
+ p0 = 0;
1656
+ }
1657
+
1658
+ if (p1 < 0) {
1659
+ p1 = std::numeric_limits<llama_pos>::max();
1660
+ }
1661
+
1662
+ // If there is no range then return early to avoid looping over the cache.
1663
+ if (p0 == p1) {
1664
+ return;
1665
+ }
1666
+
1667
+ // for Mamba-like or RWKV models, only the pos needs to be changed
1668
+ if (0 <= seq_id && seq_id < (int64_t) size) {
1669
+ const int32_t tail_id = cells[seq_id].tail;
1670
+ if (tail_id >= 0) {
1671
+ kv_cell & cell = cells[tail_id];
1672
+ if (cell.has_seq_id(seq_id) && p0 <= cell.pos && cell.pos < p1) {
1673
+ cell.pos /= d;
1674
+ }
1675
+ }
1676
+ }
1677
+ }
1678
+
1679
+ llama_pos llama_kv_cache_recurrent::seq_pos_max(llama_seq_id seq_id) const {
1680
+ llama_pos result = 0;
1681
+
1682
+ for (uint32_t i = 0; i < size; ++i) {
1683
+ if (cells[i].has_seq_id(seq_id)) {
1684
+ result = std::max(result, cells[i].pos);
1685
+ }
1686
+ }
1687
+
1688
+ return result;
1689
+ }
1690
+
1691
+ void llama_kv_cache_recurrent::restore() {
1692
+ if (pending.ranges.empty()) {
1693
+ return;
1694
+ }
1695
+
1696
+ seq_rm(-1, -1, -1);
1697
+ }
1698
+
1699
+ void llama_kv_cache_recurrent::commit() {
1700
+ pending.ranges.clear();
1701
+ }
1702
+
1703
+ bool llama_kv_cache_recurrent::update(llama_context & lctx) {
1704
+ GGML_UNUSED(lctx);
1705
+ return false;
1706
+ }
1707
+
1708
+ void llama_kv_cache_recurrent::defrag_sched(float thold) {
1709
+ GGML_UNUSED(thold);
1710
+ // noop
1711
+ }
1712
+
1713
+ void llama_kv_cache_recurrent::set_full() {
1714
+ n = size;
1715
+ }
1716
+
1717
+ llama_sbatch llama_kv_cache_recurrent::sbatch_init(
1718
+ const llama_batch & batch,
1719
+ bool logits_all) {
1720
+ return llama_sbatch(batch, hparams.n_embd, false, logits_all);
1721
+ }
1722
+
1723
+ llama_ubatch llama_kv_cache_recurrent::ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const {
1724
+ if (embd_pooled) {
1725
+ // Pooled embeddings cannot be split across ubatches (yet)
1726
+ return sbatch.split_seq(n_ubatch);
1727
+ }
1728
+
1729
+ return sbatch.split_equal(n_ubatch);
1730
+ }
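// Split strategies in short: split_simple() (used by the unified cache) takes the next
// n_ubatch tokens in batch order; split_equal() builds ubatches where every sequence
// contributes the same number of new tokens, which find_slot() below asserts via
// ubatch.equal_seqs; split_seq() keeps a sequence together, which is why it is used when
// embeddings are pooled and must not be split across ubatches.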
1731
+
1732
+ bool llama_kv_cache_recurrent::find_slot(
1733
+ const llama_ubatch & ubatch) {
1734
+ const uint32_t n_tokens = ubatch.n_tokens;
1735
+ const uint32_t n_seqs = ubatch.n_seqs;
1736
+
1737
+ const uint32_t n_seq_tokens = ubatch.n_seq_tokens;
1738
+
1739
+ // if we have enough unused cells before the current head ->
1740
+ // better to start searching from the beginning of the cache, hoping to fill it
1741
+ if (head > used + 2*n_tokens) {
1742
+ head = 0;
1743
+ }
1744
+
1745
+ // For recurrent state architectures (like Mamba or RWKV),
1746
+ // each cache cell can store the state for a whole sequence.
1747
+ // A slot should always be contiguous.
1748
+
1749
+ // can only process batches with an equal number of new tokens in each sequence
1750
+ GGML_ASSERT(ubatch.equal_seqs);
1751
+
1752
+ int32_t min = size - 1;
1753
+ int32_t max = 0;
1754
+
1755
+ // everything should fit if all seq_ids are smaller than the max
1756
+ for (uint32_t s = 0; s < n_seqs; ++s) {
1757
+ const uint32_t n_seq_id = ubatch.n_seq_id[s];
1758
+ for (uint32_t j = 0; j < n_seq_id; ++j) {
1759
+ const llama_seq_id seq_id = ubatch.seq_id[s][j];
1760
+
1761
+ if (seq_id < 0 || (uint32_t) seq_id >= size) {
1762
+ // too big seq_id
1763
+ // TODO: would it be possible to resize the cache instead?
1764
+ LLAMA_LOG_ERROR("%s: seq_id=%d >= n_seq_max=%d Try using a bigger --parallel value\n", __func__, seq_id, size);
1765
+ return false;
1766
+ }
1767
+ if (j > 0) {
1768
+ kv_cell & seq = cells[seq_id];
1769
+ if (seq.tail >= 0) {
1770
+ kv_cell & cell = cells[seq.tail];
1771
+ // clear cells from seq_ids that become shared
1772
+ // (should not normally happen, but let's handle it anyway)
1773
+ cell.seq_id.erase(seq_id);
1774
+ seq.tail = -1;
1775
+ if (cell.seq_id.empty()) {
1776
+ cell.pos = -1;
1777
+ cell.src = -1;
1778
+ used -= 1;
1779
+ }
1780
+ }
1781
+ }
1782
+ }
1783
+ }
1784
+
1785
+ #ifndef NDEBUG
1786
+ {
1787
+ std::vector<int32_t> tails_verif;
1788
+ tails_verif.assign(size, -1);
1789
+ for (uint32_t i = 0; i < size; ++i) {
1790
+ kv_cell & cell = cells[i];
1791
+ for (llama_seq_id seq_id : cell.seq_id) {
1792
+ if (tails_verif[seq_id] != -1) {
1793
+ LLAMA_LOG_ERROR("%s: duplicate tail for seq_id %d in cell %d and %d\n", __func__, seq_id, i, tails_verif[seq_id]);
1794
+ }
1795
+ tails_verif[seq_id] = i;
1796
+ }
1797
+ }
1798
+ for (uint32_t i = 0; i < size; ++i) {
1799
+ if (tails_verif[i] != cells[i].tail) {
1800
+ LLAMA_LOG_ERROR("%s: wrong tail for seq_id %d, (%d instead of %d)\n", __func__, i, cells[i].tail, tails_verif[i]);
1801
+ }
1802
+ }
1803
+ }
1804
+ #endif
1805
+
1806
+ // find next empty cell
1807
+ uint32_t next_empty_cell = head;
1808
+
1809
+ for (uint32_t i = 0; i < size; ++i) {
1810
+ if (next_empty_cell >= size) { next_empty_cell -= size; }
1811
+ kv_cell & cell = cells[next_empty_cell];
1812
+ if (cell.is_empty()) { break; }
1813
+ next_empty_cell += 1;
1814
+ }
1815
+
1816
+ // find usable cell range
1817
+ for (uint32_t s = 0; s < n_seqs; ++s) {
1818
+ const llama_seq_id seq_id = ubatch.seq_id[s][0];
1819
+ kv_cell & seq_meta = cells[seq_id];
1820
+ bool has_cell = false;
1821
+ if (seq_meta.tail >= 0) {
1822
+ kv_cell & cell = cells[seq_meta.tail];
1823
+ GGML_ASSERT(cell.has_seq_id(seq_id));
1824
+ // does this seq_id "own" the cell?
1825
+ if (cell.seq_id.size() == 1) { has_cell = true; }
1826
+ }
1827
+ if (!has_cell) {
1828
+ kv_cell & empty_cell = cells[next_empty_cell];
1829
+ GGML_ASSERT(empty_cell.is_empty());
1830
+ // copy old tail into the empty cell
1831
+ if (seq_meta.tail >= 0) {
1832
+ kv_cell & orig_cell = cells[seq_meta.tail];
1833
+ empty_cell.pos = orig_cell.pos;
1834
+ empty_cell.src = orig_cell.src;
1835
+ orig_cell.seq_id.erase(seq_id);
1836
+ empty_cell.seq_id.insert(seq_id); // will be overwritten
1837
+ }
1838
+ seq_meta.tail = next_empty_cell;
1839
+ // find next empty cell
1840
+ if (s + 1 < n_seqs) {
1841
+ next_empty_cell += 1;
1842
+ for (uint32_t i = 0; i < size; ++i) {
1843
+ if (next_empty_cell >= size) { next_empty_cell -= size; }
1844
+ kv_cell & cell = cells[next_empty_cell];
1845
+ if (cell.is_empty()) { break; }
1846
+ next_empty_cell += 1;
1847
+ }
1848
+ }
1849
+ }
1850
+ if (min > seq_meta.tail) { min = seq_meta.tail; }
1851
+ if (max < seq_meta.tail) { max = seq_meta.tail; }
1852
+ }
1853
+
1854
+ // gather and re-order
1855
+ for (uint32_t s = 0; s < n_seqs; ++s) {
1856
+ int32_t dst_id = s + min;
1857
+ int32_t src_id = cells[ubatch.seq_id[s][0]].tail;
1858
+ if (dst_id != src_id) {
1859
+ kv_cell & dst_cell = cells[dst_id];
1860
+ kv_cell & src_cell = cells[src_id];
1861
+
1862
+ std::swap(dst_cell.pos, src_cell.pos);
1863
+ std::swap(dst_cell.src, src_cell.src);
1864
+ std::swap(dst_cell.seq_id, src_cell.seq_id);
1865
+
1866
+ // swap tails (assuming they NEVER overlap)
1867
+ for (const llama_seq_id seq_id : src_cell.seq_id) {
1868
+ cells[seq_id].tail = src_id;
1869
+ }
1870
+ for (const llama_seq_id seq_id : dst_cell.seq_id) {
1871
+ cells[seq_id].tail = dst_id;
1872
+ }
1873
+ }
1874
+ }
1875
+
1876
+ // update the pos of the used seqs
1877
+ for (uint32_t s = 0; s < n_seqs; ++s) {
1878
+ const llama_pos last_pos = ubatch.pos[n_seq_tokens * s + n_seq_tokens - 1];
1879
+ int32_t cell_id = s + min;
1880
+ kv_cell & cell = cells[cell_id];
1881
+
1882
+ if (cell.pos >= 0 && last_pos != cell.pos + (llama_pos) n_seq_tokens) {
1883
+ // What should happen when the pos backtracks or skips a value?
1884
+ // Clearing the state mid-batch would require special-casing which isn't done.
1885
+ LLAMA_LOG_WARN("%s: non-consecutive token position %d after %d for sequence %d with %u new tokens\n",
1886
+ __func__, last_pos, cell.pos, ubatch.seq_id[s][0], n_seq_tokens);
1887
+ }
1888
+ cell.pos = last_pos;
1889
+ cell.seq_id.clear();
1890
+ for (int32_t j = 0; j < ubatch.n_seq_id[s]; ++j) {
1891
+ const llama_seq_id seq_id = ubatch.seq_id[s][j];
1892
+ cell.seq_id.insert(seq_id);
1893
+ cells[seq_id].tail = cell_id;
1894
+ }
1895
+ }
1896
+
1897
+ // allow getting the range of used cells, from head to head + n
1898
+ head = min;
1899
+ n = max - min + 1;
1900
+ used = std::count_if(cells.begin(), cells.end(),
1901
+ [](const kv_cell & cell){ return !cell.is_empty(); });
1902
+
1903
+ // sanity check
1904
+ return n >= n_seqs;
1905
+ }
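// Worked example of the bookkeeping above: if three sequences have tail cells 4, 9 and 5
// before re-ordering, then min = 4 and max = 9; the gather step swaps their states into the
// contiguous cells 4..6, head becomes 4 and n = max - min + 1 = 6, covering every touched
// cell and satisfying the n >= n_seqs sanity check.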
1906
+
1907
+ int32_t llama_kv_cache_recurrent::get_n_tokens() const {
1908
+ int32_t result = 0;
1909
+
1910
+ for (uint32_t i = 0; i < size; i++) {
1911
+ result += cells[i].seq_id.size();
1912
+ }
1913
+
1914
+ return result;
1915
+ }
1916
+
1917
+ int32_t llama_kv_cache_recurrent::get_used_cells() const {
1918
+ return used;
1919
+ }
1920
+
1921
+ llama_pos llama_kv_cache_recurrent::get_pos_max() const {
1922
+ llama_pos pos_max = -1;
1923
+ for (const auto & cell : cells) {
1924
+ pos_max = std::max(pos_max, cell.pos);
1925
+ }
1926
+
1927
+ return pos_max;
1928
+ }
1929
+
1930
+ bool llama_kv_cache_recurrent::get_can_shift() const {
1931
+ return false;
1932
+ }
1933
+
1934
+ int32_t llama_kv_cache_recurrent::s_copy(int i) const {
1935
+ const uint32_t cell_id = i + head;
1936
+
1937
+ //////////////////////////////////////////////
1938
+ // TODO: this should not mutate the KV cache !
1939
+ kv_cell & cell = const_cast<kv_cell &>(cells[cell_id]);
1940
+
1941
+ // prevent out-of-bound sources
1942
+ if (cell.src < 0 || (uint32_t) cell.src >= size) {
1943
+ cell.src = cell_id;
1944
+ }
1945
+
1946
+ int32_t res = cell.src;
1947
+
1948
+ // TODO: do not mutate the KV cache
1949
+ // ensure copy only happens once
1950
+ if (cell.src != (int32_t) cell_id) {
1951
+ cell.src = cell_id;
1952
+ }
1953
+
1954
+ return res;
1955
+ }
1956
+
1957
+ float llama_kv_cache_recurrent::s_mask(int i) const {
1958
+ const uint32_t cell_id = i + head;
1959
+
1960
+ //////////////////////////////////////////////
1961
+ // TODO: this should not mutate the KV cache !
1962
+ kv_cell & cell = const_cast<kv_cell &>(cells[cell_id]);
1963
+
1964
+ float res = (float) (cell.src >= 0);
1965
+
1966
+ // only clear once
1967
+ if (cell.src < 0) {
1968
+ cell.src = cell_id;
1969
+ }
1970
+
1971
+ return res;
1972
+ }
1973
+
1974
+ uint32_t llama_kv_cache_recurrent::cell_max() const {
1975
+ for (uint32_t i = size; i > 0; --i) {
1976
+ const kv_cell & cell = cells[i - 1];
1977
+
1978
+ if (cell.pos >= 0 && !cell.is_empty()) {
1979
+ return i;
1980
+ }
1981
+ }
1982
+
1983
+ return 0;
1984
+ }
1985
+
1986
+ size_t llama_kv_cache_recurrent::total_size() const {
1987
+ size_t size = 0;
1988
+ for (const auto & buf : bufs) {
1989
+ size += ggml_backend_buffer_get_size(buf.get());
1990
+ }
1991
+
1992
+ return size;
1993
+ }
1994
+
1995
+ size_t llama_kv_cache_recurrent::size_k_bytes() const {
1996
+ size_t size_k_bytes = 0;
1997
+
1998
+ for (const auto & k : k_l) {
1999
+ size_k_bytes += ggml_nbytes(k);
2000
+ }
2001
+
2002
+ return size_k_bytes;
2003
+ }
2004
+
2005
+ size_t llama_kv_cache_recurrent::size_v_bytes() const {
2006
+ size_t size_v_bytes = 0;
2007
+
2008
+ for (const auto & v : v_l) {
2009
+ size_v_bytes += ggml_nbytes(v);
2010
+ }
2011
+
2012
+ return size_v_bytes;
2013
+ }
2014
+
2015
+ void llama_kv_cache_recurrent::state_write(llama_io_write_i & io, llama_seq_id seq_id) const {
2016
+ std::vector<std::pair<uint32_t, uint32_t>> cell_ranges; // ranges, from inclusive, to exclusive
2017
+ uint32_t cell_count = 0;
2018
+
2019
+ // Count the number of cells with the specified seq_id
2020
+ // Find all the ranges of cells with this seq id (or all, when -1)
2021
+ uint32_t cell_range_begin = size;
2022
+ for (uint32_t i = 0; i < size; ++i) {
2023
+ const auto & cell = cells[i];
2024
+ if ((seq_id == -1 && !cell.is_empty()) || cell.has_seq_id(seq_id)) {
2025
+ ++cell_count;
2026
+ if (cell_range_begin == size) {
2027
+ cell_range_begin = i;
2028
+ }
2029
+ } else {
2030
+ if (cell_range_begin != size) {
2031
+ cell_ranges.emplace_back(cell_range_begin, i);
2032
+ cell_range_begin = size;
2033
+ }
2034
+ }
2035
+ }
2036
+ if (cell_range_begin != size) {
2037
+ cell_ranges.emplace_back(cell_range_begin, size);
2038
+ }
2039
+
2040
+ // DEBUG CHECK: Sum of cell counts in ranges should equal the total cell count
2041
+ uint32_t cell_count_check = 0;
2042
+ for (const auto & range : cell_ranges) {
2043
+ cell_count_check += range.second - range.first;
2044
+ }
2045
+ GGML_ASSERT(cell_count == cell_count_check);
2046
+
2047
+ io.write(&cell_count, sizeof(cell_count));
2048
+
2049
+ state_write_meta(io, cell_ranges, seq_id);
2050
+ state_write_data(io, cell_ranges);
2051
+ }
2052
+
2053
+ void llama_kv_cache_recurrent::state_read(llama_io_read_i & io, llama_seq_id seq_id) {
2054
+ uint32_t cell_count;
2055
+ io.read_to(&cell_count, sizeof(cell_count));
2056
+
2057
+ bool res = true;
2058
+ res = res && state_read_meta(io, cell_count, seq_id);
2059
+ res = res && state_read_data(io, cell_count);
2060
+
2061
+ if (!res) {
2062
+ if (seq_id == -1) {
2063
+ clear();
2064
+ } else {
2065
+ seq_rm(seq_id, -1, -1);
2066
+ }
2067
+ throw std::runtime_error("failed to restore kv cache");
2068
+ }
2069
+ }
2070
+
2071
+ void llama_kv_cache_recurrent::state_write_meta(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id) const {
2072
+ for (const auto & range : cell_ranges) {
2073
+ for (uint32_t i = range.first; i < range.second; ++i) {
2074
+ const auto & cell = cells[i];
2075
+ const llama_pos pos = cell.pos;
2076
+ const uint32_t n_seq_id = seq_id == -1 ? cell.seq_id.size() : 0;
2077
+
2078
+ io.write(&pos, sizeof(pos));
2079
+ io.write(&n_seq_id, sizeof(n_seq_id));
2080
+
2081
+ if (n_seq_id) {
2082
+ for (auto seq_id : cell.seq_id) {
2083
+ io.write(&seq_id, sizeof(seq_id));
2084
+ }
2085
+ }
2086
+ }
2087
+ }
2088
+ }
2089
+
2090
+ void llama_kv_cache_recurrent::state_write_data(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges) const {
2091
+ const uint32_t v_trans = 0;
2092
+ const uint32_t n_layer = hparams.n_layer;
2093
+
2094
+ io.write(&v_trans, sizeof(v_trans));
2095
+ io.write(&n_layer, sizeof(n_layer));
2096
+
2097
+ std::vector<uint8_t> tmp_buf;
2098
+
2099
+ // Iterate and write all the keys first, each row is a cell
2100
+ // Get whole range at a time
2101
+ for (uint32_t il = 0; il < n_layer; ++il) {
2102
+ const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
2103
+
2104
+ // Write key type
2105
+ const int32_t k_type_i = (int32_t)k_l[il]->type;
2106
+ io.write(&k_type_i, sizeof(k_type_i));
2107
+
2108
+ // Write row size of key
2109
+ const uint64_t k_size_row = ggml_row_size(k_l[il]->type, n_embd_k_gqa);
2110
+ io.write(&k_size_row, sizeof(k_size_row));
2111
+
2112
+ // Read each range of cells of k_size length each into tmp_buf and write out
2113
+ for (const auto & range : cell_ranges) {
2114
+ const size_t range_size = range.second - range.first;
2115
+ const size_t buf_size = range_size * k_size_row;
2116
+ io.write_tensor(k_l[il], range.first * k_size_row, buf_size);
2117
+ }
2118
+ }
2119
+
2120
+ if (!v_trans) {
2121
+ for (uint32_t il = 0; il < n_layer; ++il) {
2122
+ const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
2123
+
2124
+ // Write value type
2125
+ const int32_t v_type_i = (int32_t)v_l[il]->type;
2126
+ io.write(&v_type_i, sizeof(v_type_i));
2127
+
2128
+ // Write row size of value
2129
+ const uint64_t v_size_row = ggml_row_size(v_l[il]->type, n_embd_v_gqa);
2130
+ io.write(&v_size_row, sizeof(v_size_row));
2131
+
2132
+ // Read each range of cells of v_size length each into tmp_buf and write out
2133
+ for (const auto & range : cell_ranges) {
2134
+ const size_t range_size = range.second - range.first;
2135
+ const size_t buf_size = range_size * v_size_row;
2136
+ io.write_tensor(v_l[il], range.first * v_size_row, buf_size);
2137
+ }
2138
+ }
2139
+ } else {
2140
+ // When v is transposed, we also need the element size and get the element ranges from each row
2141
+ const uint32_t kv_size = size;
2142
+ for (uint32_t il = 0; il < n_layer; ++il) {
2143
+ const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
2144
+
2145
+ // Write value type
2146
+ const int32_t v_type_i = (int32_t)v_l[il]->type;
2147
+ io.write(&v_type_i, sizeof(v_type_i));
2148
+
2149
+ // Write element size
2150
+ const uint32_t v_size_el = ggml_type_size(v_l[il]->type);
2151
+ io.write(&v_size_el, sizeof(v_size_el));
2152
+
2153
+ // Write GQA embedding size
2154
+ io.write(&n_embd_v_gqa, sizeof(n_embd_v_gqa));
2155
+
2156
+ // For each row, we get the element values of each cell
2157
+ for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
2158
+ // Read each range of cells of v_size_el length each into tmp_buf and write out
2159
+ for (const auto & range : cell_ranges) {
2160
+ const size_t range_size = range.second - range.first;
2161
+ const size_t src_offset = (range.first + j * kv_size) * v_size_el;
2162
+ const size_t buf_size = range_size * v_size_el;
2163
+ io.write_tensor(v_l[il], src_offset, buf_size);
2164
+ }
2165
+ }
2166
+ }
2167
+ }
2168
+ }
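// Resulting session-file layout for this cache (written by state_write() above, validated by
// state_read_meta()/state_read_data() below):
//   uint32  cell_count
//   per cell:  llama_pos pos, uint32 n_seq_id, followed by n_seq_id llama_seq_id values
//   uint32  v_trans (always 0 here), uint32 n_layer
//   per layer: int32 k_type, uint64 k_size_row, then the key rows of each cell range
//   per layer: int32 v_type, uint64 v_size_row, then the value rows of each cell range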
2169
+
2170
+ bool llama_kv_cache_recurrent::state_read_meta(llama_io_read_i & io, uint32_t cell_count, llama_seq_id dest_seq_id) {
2171
+ if (dest_seq_id != -1) {
2172
+ // single sequence
2173
+
2174
+ seq_rm(dest_seq_id, -1, -1);
2175
+
2176
+ llama_sbatch sbatch;
2177
+ llama_ubatch batch = sbatch.reserve_ubatch(cell_count, /* has_embd */ false);
2178
+
2179
+ batch.n_tokens = cell_count;
2180
+ batch.n_seq_tokens = cell_count;
2181
+ batch.n_seqs = 1;
2182
+
2183
+ for (uint32_t i = 0; i < cell_count; ++i) {
2184
+ llama_pos pos;
2185
+ uint32_t n_seq_id;
2186
+
2187
+ io.read_to(&pos, sizeof(pos));
2188
+ io.read_to(&n_seq_id, sizeof(n_seq_id));
2189
+
2190
+ if (n_seq_id != 0) {
2191
+ LLAMA_LOG_ERROR("%s: invalid seq_id-agnostic kv cell\n", __func__);
2192
+ return false;
2193
+ }
2194
+
2195
+ batch.pos[i] = pos;
2196
+ }
2197
+ batch.n_seq_id[0] = 1;
2198
+ batch.seq_id[0] = &dest_seq_id;
2199
+ if (!find_slot(batch)) {
2200
+ LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__);
2201
+ return false;
2202
+ }
2203
+ commit();
2204
+
2205
+ // DEBUG CHECK: kv.head should be our first cell, kv.head + cell_count - 1 should be our last cell (verify seq_id and pos values)
2206
+ // Assume that this is one contiguous block of cells
2207
+ GGML_ASSERT(head + cell_count <= size);
2208
+ GGML_ASSERT(cells[head].pos == batch.pos[0]);
2209
+ GGML_ASSERT(cells[head + cell_count - 1].pos == batch.pos[cell_count - 1]);
2210
+ GGML_ASSERT(cells[head].has_seq_id(dest_seq_id));
2211
+ GGML_ASSERT(cells[head + cell_count - 1].has_seq_id(dest_seq_id));
2212
+ } else {
2213
+ // whole KV cache restore
2214
+
2215
+ if (cell_count > size) {
2216
+ LLAMA_LOG_ERROR("%s: not enough cells in kv cache\n", __func__);
2217
+ return false;
2218
+ }
2219
+
2220
+ clear();
2221
+
2222
+ for (uint32_t i = 0; i < cell_count; ++i) {
2223
+ kv_cell & cell = cells[i];
2224
+
2225
+ llama_pos pos;
2226
+ uint32_t n_seq_id;
2227
+
2228
+ io.read_to(&pos, sizeof(pos));
2229
+ io.read_to(&n_seq_id, sizeof(n_seq_id));
2230
+
2231
+ cell.pos = pos;
2232
+
2233
+ for (uint32_t j = 0; j < n_seq_id; ++j) {
2234
+ llama_seq_id seq_id;
2235
+ io.read_to(&seq_id, sizeof(seq_id));
2236
+
2237
+ // TODO: llama_kv_cache_recurrent should have a notion of max sequences
2238
+ //if (seq_id < 0 || (uint32_t) seq_id >= llama_n_seq_max(ctx)) {
2239
+ if (seq_id < 0) {
2240
+ //LLAMA_LOG_ERROR("%s: invalid seq_id, %d is out of range [0, %u)\n", __func__, seq_id, llama_n_seq_max(ctx));
2241
+ LLAMA_LOG_ERROR("%s: invalid seq_id, %d is out of range [0, inf)\n", __func__, seq_id);
2242
+ return false;
2243
+ }
2244
+
2245
+ cell.seq_id.insert(seq_id);
2246
+
2247
+ int32_t & tail = cells[seq_id].tail;
2248
+ if (tail != -1) {
2249
+ LLAMA_LOG_ERROR("%s: duplicate tail for seq_id %d in cell %d and %d\n", __func__, seq_id, i, tail);
2250
+ return false;
2251
+ }
2252
+ tail = i;
2253
+ }
2254
+ }
2255
+
2256
+ head = 0;
2257
+ used = cell_count;
2258
+ }
2259
+
2260
+ for (uint32_t i = 0; i < cell_count; ++i) {
2261
+ uint32_t cell_id = head + i;
2262
+ // make sure the recurrent states will keep their restored state
2263
+ cells[cell_id].src = cell_id;
2264
+ }
2265
+
2266
+ return true;
2267
+ }
2268
+
2269
+ bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell_count) {
2270
+ uint32_t v_trans;
2271
+ uint32_t n_layer;
2272
+ io.read_to(&v_trans, sizeof(v_trans));
2273
+ io.read_to(&n_layer, sizeof(n_layer));
2274
+
2275
+ if (n_layer != hparams.n_layer) {
2276
+ LLAMA_LOG_ERROR("%s: mismatched layer count (%u instead of %u)\n", __func__, n_layer, hparams.n_layer);
2277
+ return false;
2278
+ }
2279
+ if (cell_count > size) {
2280
+ LLAMA_LOG_ERROR("%s: not enough cells in kv cache to restore state (%u > %u)\n", __func__, cell_count, size);
2281
+ return false;
2282
+ }
2283
+ if (false != (bool) v_trans) {
2284
  LLAMA_LOG_ERROR("%s: incompatible V transposition\n", __func__);
2285
  return false;
2286
  }
 
2432
  view->cells_sequences = (llama_seq_id *)p;
2433
  }
2434
 
2435
+ const std::vector<llama_kv_cache_unified::kv_cell> & kv_cells = kvu->cells;
2436
  llama_kv_cache_view_cell * c_curr = view->cells;
2437
  llama_seq_id * cs_curr = view->cells_sequences;
2438
  int32_t used_cells = 0;
examples/talk-llama/llama-kv-cache.h CHANGED
@@ -2,32 +2,72 @@
2
 
3
  #include "llama.h"
4
  #include "llama-io.h"
 
5
  #include "llama-memory.h"
6
 
7
  #include "ggml-cpp.h"
8
 
9
- #include <functional>
10
  #include <set>
11
  #include <vector>
12
 
13
  struct llama_cparams;
14
  struct llama_hparams;
15
  struct llama_ubatch;
 
 
 
16
 
17
  struct llama_kv_cache : public llama_memory_i {
18
- using llama_memory_i::llama_memory_i;
19
 
20
- virtual void restore() = 0; // call if batch processing fails - restores the cache state
21
- virtual void commit() = 0; // call after successful batch processing - clears any pending state
22
 
23
- virtual int32_t get_n_tokens() const = 0;
24
- virtual int32_t get_used_cells() const = 0; // TODO: remove, this is too-specific to the unified cache
25
 
26
- virtual bool get_can_shift() const = 0;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
 
28
  bool get_can_edit() const override { return get_can_shift(); }
  };
30
 
 
 
 
 
31
  struct llama_kv_cache_guard {
32
  llama_kv_cache_guard(llama_kv_cache * kv) : kv(kv) {}
33
 
@@ -43,65 +83,50 @@ private:
43
  llama_kv_cache * kv;
44
  };
45
 
46
- struct llama_kv_cell {
47
- llama_pos pos = -1;
48
- llama_pos delta = 0;
49
- int32_t src = -1; // used by recurrent state models to copy states
50
- int32_t tail = -1;
51
 
52
- std::set<llama_seq_id> seq_id;
 
 
 
 
 
53
 
54
- bool has_seq_id(const llama_seq_id & id) const {
55
- return seq_id.find(id) != seq_id.end();
56
- }
57
 
58
- bool is_empty() const {
59
- return seq_id.empty();
60
- }
61
 
62
- bool is_same_seq(const llama_kv_cell & other) const {
63
- return seq_id == other.seq_id;
64
- }
65
- };
66
 
67
- // ring-buffer of cached KV data
68
- // TODO: pimpl
69
- // TODO: add notion of max sequences
70
- class llama_kv_cache_unified : public llama_kv_cache {
71
- public:
72
- // can be used to query data from the model if needed
73
- struct callbacks {
74
- std::function<ggml_tensor * (uint32_t n_ctx_per_seq, int il)> get_rope_factors;
75
  };
76
 
77
- llama_kv_cache_unified(
78
- const llama_hparams & hparams,
79
- callbacks cbs);
80
-
81
- virtual ~llama_kv_cache_unified() = default;
82
 
83
- // TODO: become constructor
84
- bool init(
85
- const llama_model & model, // TODO: do not reference the model
86
- const llama_cparams & cparams,
87
  ggml_type type_k,
88
  ggml_type type_v,
 
 
89
  uint32_t kv_size,
90
- bool offload);
91
-
92
- int32_t get_n_tokens() const override;
93
- int32_t get_used_cells() const override;
94
 
95
- size_t total_size() const;
96
 
97
- // TODO: better data structures to reduce the cost of this operation
98
- llama_pos pos_max() const;
 
99
 
100
  void clear() override;
101
- void defrag() override;
102
-
103
- virtual void restore() override;
104
- virtual void commit() override;
105
 
106
  bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override;
107
  void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
@@ -111,25 +136,76 @@ public:
111
 
112
  llama_pos seq_pos_max(llama_seq_id seq_id) const override;
113
 
114
- bool get_can_shift() const override;
115
 
116
- // find an empty slot of size "n_tokens" in the cache
117
  // updates the cache head
118
  // Note: On success, it's important that cache.head points
119
  // to the first cell of the slot.
120
- bool find_slot(const llama_ubatch & batch);
121
 
122
- // TODO: maybe not needed
123
- uint32_t get_padding(const llama_cparams & cparams) const;
124
 
125
- // find how many cells are currently in use
126
- uint32_t cell_max() const;
127
 
128
- size_t size_k_bytes() const;
129
- size_t size_v_bytes() const;
130
 
131
- // defrag
132
 
 
133
  struct {
134
  std::vector<uint32_t> ids;
135
  } defrag_info;
@@ -138,7 +214,6 @@ public:
138
  bool defrag_prepare(int32_t n_max_nodes);
139
 
140
  // commit/restore cache
141
-
142
  struct slot_range {
143
  uint32_t c0 = 0; // note: these are cell indices, not sequence positions
144
  uint32_t c1 = 0;
@@ -149,25 +224,124 @@ public:
149
  std::vector<slot_range> ranges;
150
  } pending;
151
 
152
- // state write/load
 
153
 
154
- void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const;
155
- void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1);
156
 
157
- // members
 
158
 
159
- const llama_hparams & hparams;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
160
 
161
- callbacks cbs;
 
162
 
163
- bool has_shift = false;
164
- bool do_defrag = false;
 
165
 
166
- // TODO: remove this and implement llama_kv_cache_recurrent instead
167
- bool recurrent = false; // with recurrent state models, a cell can hold the state for more than one past token
 
168
 
169
- bool v_trans = true; // the value tensor is transposed
170
- bool can_shift = false;
171
 
172
  // Note: The value of head isn't only used to optimize searching
173
  // for a free KV slot. llama_decode_impl also uses it, so it
@@ -179,18 +353,41 @@ public:
179
  // computed before each graph build
180
  uint32_t n = 0;
181
 
182
- std::vector<llama_kv_cell> cells;
183
 
184
  std::vector<ggml_tensor *> k_l; // per layer
185
  std::vector<ggml_tensor *> v_l;
186
 
187
  private:
188
  ggml_type type_k = GGML_TYPE_F16;
189
  ggml_type type_v = GGML_TYPE_F16;
190
 
191
  std::vector<ggml_context_ptr> ctxs;
192
  std::vector<ggml_backend_buffer_ptr> bufs;
193
 
 
 
 
 
 
 
 
 
194
  void state_write_meta(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id = -1) const;
195
  void state_write_data(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges) const;
196
 
@@ -198,11 +395,6 @@ private:
198
  bool state_read_data(llama_io_read_i & io, uint32_t cell_count);
199
  };
200
 
201
- // TODO: temporary reusing llama_kv_cache_unified -- implement recurrent cache and simplify llama_kv_cache_unified
202
- //class llama_kv_cache_recurrent : public llama_kv_cache_unified {
203
- //public:
204
- // using llama_kv_cache_unified::llama_kv_cache_unified;
205
- //};
206
 
207
  //
208
  // kv cache view
 
2
 
3
  #include "llama.h"
4
  #include "llama-io.h"
5
+ #include "llama-graph.h"
6
  #include "llama-memory.h"
7
 
8
  #include "ggml-cpp.h"
9
 
 
10
  #include <set>
11
  #include <vector>
12
 
13
  struct llama_cparams;
14
  struct llama_hparams;
15
  struct llama_ubatch;
16
+ struct llama_sbatch;
17
+ struct llama_model;
18
+ struct llama_context;
19
 
20
  struct llama_kv_cache : public llama_memory_i {
21
+ virtual ~llama_kv_cache() = default;
22
 
23
+ // call if batch processing fails - restores the cache state
24
+ virtual void restore() = 0;
25
 
26
+ // call after successful batch processing - clears any pending state
27
+ virtual void commit() = 0;
28
 
29
+ // process any pending defrag/shift/etc. operations
30
+ // optionally call once before processing a new batch
31
+ virtual bool update(llama_context & lctx) = 0;
32
+
33
+ // schedule a defrag if the fragmentation threshold is exceeded. otherwise, do nothing
34
+ virtual void defrag_sched(float thold) = 0;
35
+
36
+ // simulate full cache, used for allocating worst-case compute buffers
37
+ virtual void set_full() = 0;
38
+
39
+ //
40
+ // batch processing
41
+ //
42
+
43
+ virtual llama_sbatch sbatch_init(const llama_batch & batch, bool logits_all) = 0;
44
+
45
+ // different KV caches require different batch splitting strategies
46
+ virtual llama_ubatch ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const = 0;
47
+
48
+ // find an empty slot of size "n_tokens" in the cache
49
+ virtual bool find_slot(const llama_ubatch & batch) = 0;
50
+
51
+ // getters
52
+ virtual int32_t get_n_tokens() const = 0;
53
+ virtual int32_t get_used_cells() const = 0; // TODO: remove, this is too-specific to the unified cache
54
+ virtual llama_pos get_pos_max() const = 0;
55
+ virtual bool get_can_shift() const = 0;
56
 
57
  bool get_can_edit() const override { return get_can_shift(); }
58
+
59
+ //
60
+ // state write/read
61
+ //
62
+
63
+ virtual void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const = 0;
64
+ virtual void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) = 0;
65
  };
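// Typical per-decode call order against this interface (an illustrative sketch, not a
// prescribed contract; error handling elided):
//
//   kv->update(lctx);                                        // apply pending shift/defrag
//   llama_sbatch sbatch = kv->sbatch_init(batch, logits_all);
//   while (/* tokens remain in sbatch */) {
//       llama_ubatch ubatch = kv->ubatch_next(sbatch, n_ubatch, embd_pooled);
//       if (!kv->find_slot(ubatch)) { kv->restore(); break; }
//       // ... build and compute the graph for this ubatch ...
//   }
//   kv->commit();                                            // or restore() on failure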
66
 
67
+ //
68
+ // llama_kv_cache_guard
69
+ //
70
+
71
  struct llama_kv_cache_guard {
72
  llama_kv_cache_guard(llama_kv_cache * kv) : kv(kv) {}
73
 
 
83
  llama_kv_cache * kv;
84
  };
85
 
86
+ //
87
+ // llama_kv_cache_unified
88
+ //
 
 
89
 
90
+ // TODO: add notion of max sequences
91
+ class llama_kv_cache_unified : public llama_kv_cache {
92
+ public:
93
+ struct kv_cell {
94
+ llama_pos pos = -1;
95
+ llama_pos delta = 0;
96
 
97
+ std::set<llama_seq_id> seq_id;
98
 
99
+ bool has_seq_id(const llama_seq_id & id) const {
100
+ return seq_id.find(id) != seq_id.end();
101
+ }
102
 
103
+ bool is_empty() const {
104
+ return seq_id.empty();
105
+ }
 
106
 
107
+ bool is_same_seq(const kv_cell & other) const {
108
+ return seq_id == other.seq_id;
109
+ }
110
  };
111
 
112
+ static uint32_t get_padding(const llama_cparams & cparams);
113
 
114
+ llama_kv_cache_unified(
115
+ const llama_model & model,
116
  ggml_type type_k,
117
  ggml_type type_v,
118
+ bool v_trans,
119
+ bool offload,
120
  uint32_t kv_size,
121
+ uint32_t padding);
122
 
123
+ ~llama_kv_cache_unified() = default;
124
 
125
+ //
126
+ // llama_memory_i
127
+ //
128
 
129
  void clear() override;
130
 
131
  bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override;
132
  void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
 
136
 
137
  llama_pos seq_pos_max(llama_seq_id seq_id) const override;
138
 
139
+ //
140
+ // llama_kv_cache
141
+ //
142
+
143
+ void restore() override;
144
+ void commit() override;
145
+
146
+ bool update(llama_context & ctx) override;
147
+
148
+ void defrag_sched(float thold) override;
149
+
150
+ void set_full() override;
151
+
152
+ llama_sbatch sbatch_init(const llama_batch & batch, bool logits_all) override;
153
+
154
+ llama_ubatch ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const override;
155
 
 
156
  // updates the cache head
157
  // Note: On success, it's important that cache.head points
158
  // to the first cell of the slot.
159
+ bool find_slot(const llama_ubatch & batch) override;
160
 
161
+ int32_t get_n_tokens() const override;
162
+ int32_t get_used_cells() const override;
163
 
164
+ // TODO: better data structures to reduce the cost of this operation
165
+ llama_pos get_pos_max() const override;
166
 
167
+ bool get_can_shift() const override;
 
168
 
169
+ // state write/load
170
+
171
+ void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override;
172
+ void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) override;
173
+
174
+ // Note: The value of head isn't only used to optimize searching
175
+ // for a free KV slot. llama_decode_impl also uses it, so it
176
+ // cannot be freely changed after a slot has been allocated.
177
+ uint32_t head = 0;
178
+ uint32_t size = 0;
179
+ uint32_t used = 0; // used cells (i.e. at least one seq_id)
180
+
181
+ // computed before each graph build
182
+ uint32_t n = 0;
183
+
184
+ std::vector<kv_cell> cells;
185
+
186
+ std::vector<ggml_tensor *> k_l; // per layer
187
+ std::vector<ggml_tensor *> v_l;
188
+
189
+ private:
190
+ const llama_model & model;
191
+ const llama_hparams & hparams;
192
+
193
+ bool has_shift = false;
194
+ bool do_defrag = false;
195
+
196
+ bool v_trans = true; // the value tensor is transposed
197
+ bool can_shift = false;
198
+
199
+ // required padding
200
+ uint32_t padding = 1;
201
+
202
+ ggml_type type_k = GGML_TYPE_F16;
203
+ ggml_type type_v = GGML_TYPE_F16;
204
+
205
+ std::vector<ggml_context_ptr> ctxs;
206
+ std::vector<ggml_backend_buffer_ptr> bufs;
207
 
208
+ // defrag
209
  struct {
210
  std::vector<uint32_t> ids;
211
  } defrag_info;
 
214
  bool defrag_prepare(int32_t n_max_nodes);
215
 
216
  // commit/restore cache
 
217
  struct slot_range {
218
  uint32_t c0 = 0; // note: these are cell indices, not sequence positions
219
  uint32_t c1 = 0;
 
224
  std::vector<slot_range> ranges;
225
  } pending;
226
 
227
+ // find how many cells are currently in use
228
+ uint32_t cell_max() const;
229
 
230
+ size_t total_size() const;
 
231
 
232
+ size_t size_k_bytes() const;
233
+ size_t size_v_bytes() const;
234
 
235
+ ggml_tensor * build_rope_shift(
236
+ const llama_cparams & cparams,
237
+ ggml_context * ctx,
238
+ ggml_tensor * cur,
239
+ ggml_tensor * shift,
240
+ ggml_tensor * factors,
241
+ float freq_base,
242
+ float freq_scale) const;
243
+
244
+ llm_graph_result_ptr build_graph_shift(
245
+ const llama_cparams & cparams,
246
+ ggml_context * ctx,
247
+ ggml_cgraph * gf) const;
248
+
249
+ llm_graph_result_ptr build_graph_defrag(
250
+ const llama_cparams & cparams,
251
+ ggml_context * ctx,
252
+ ggml_cgraph * gf) const;
253
 
254
+ void state_write_meta(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id = -1) const;
255
+ void state_write_data(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges) const;
256
 
257
+ bool state_read_meta(llama_io_read_i & io, uint32_t cell_count, llama_seq_id dest_seq_id = -1);
258
+ bool state_read_data(llama_io_read_i & io, uint32_t cell_count);
259
+ };
260
 
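To make the head/slot note above concrete (illustrative only; the decode-side code sketched here is an assumption, not part of this commit):

    // after find_slot() succeeds, `head` is the index of the first cell of the slot,
    // and cells [head, head + n_tokens) receive this ubatch's K/V data
    if (kv.find_slot(ubatch)) {
        const uint32_t kv_head = kv.head;   // first cell of the newly reserved slot
        // the graph writes the new keys for layer il into rows
        // [kv_head, kv_head + ubatch.n_tokens) of kv.k_l[il]
    }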
261
+ //
262
+ // llama_kv_cache_recurrent
263
+ //
264
 
265
+ class llama_kv_cache_recurrent : public llama_kv_cache {
266
+ public:
267
+ struct kv_cell {
268
+ llama_pos pos = -1;
269
+ int32_t src = -1; // used to copy states
270
+ int32_t tail = -1;
271
+
272
+ std::set<llama_seq_id> seq_id;
273
+
274
+ bool has_seq_id(const llama_seq_id & id) const {
275
+ return seq_id.find(id) != seq_id.end();
276
+ }
277
+
278
+ bool is_empty() const {
279
+ return seq_id.empty();
280
+ }
281
+
282
+ bool is_same_seq(const kv_cell & other) const {
283
+ return seq_id == other.seq_id;
284
+ }
285
+ };
286
+
287
+ llama_kv_cache_recurrent(
288
+ const llama_model & model,
289
+ ggml_type type_k,
290
+ ggml_type type_v,
291
+ bool offload,
292
+ uint32_t kv_size);
293
+
294
+ ~llama_kv_cache_recurrent() = default;
295
+
296
+ //
297
+ // llama_memory_i
298
+ //
299
+
300
+ void clear() override;
301
+
302
+ bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override;
303
+ void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
304
+ void seq_keep(llama_seq_id seq_id) override;
305
+ void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) override;
306
+ void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override;
307
+
308
+ llama_pos seq_pos_max(llama_seq_id seq_id) const override;
309
+
310
+ //
311
+ // llama_kv_cache
312
+ //
313
+
314
+ void restore() override;
315
+ void commit() override;
316
+
317
+ bool update(llama_context & lctx) override;
318
+
319
+ void defrag_sched(float thold) override;
320
+
321
+ void set_full() override;
322
+
323
+ llama_sbatch sbatch_init(const llama_batch & batch, bool logits_all) override;
324
+
325
+ llama_ubatch ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const override;
326
+
327
+ bool find_slot(const llama_ubatch & batch) override;
328
+
329
+ int32_t get_n_tokens() const override;
330
+ int32_t get_used_cells() const override;
331
+
332
+ // TODO: better data structures to reduce the cost of this operation
333
+ llama_pos get_pos_max() const override;
334
+
335
+ bool get_can_shift() const override;
336
+
337
+ // TODO: temporary methods - they are not really const as they do const_cast<>, fix this
338
+ int32_t s_copy(int i) const;
339
+ float s_mask(int i) const;
340
+
341
+ // state write/load
342
+
343
+ void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override;
344
+ void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) override;
345
 
346
  // Note: The value of head isn't only used to optimize searching
347
  // for a free KV slot. llama_decode_impl also uses it, so it
 
353
  // computed before each graph build
354
  uint32_t n = 0;
355
 
356
+ std::vector<kv_cell> cells;
357
 
358
  std::vector<ggml_tensor *> k_l; // per layer
359
  std::vector<ggml_tensor *> v_l;
360
 
361
  private:
362
+ //const llama_model & model;
363
+ const llama_hparams & hparams;
364
+
365
+ // commit/restore cache
366
+ // TODO: rework for recurrent cache
367
+ struct slot_range {
368
+ uint32_t c0 = 0; // note: these are cell indices, not sequence positions
369
+ uint32_t c1 = 0;
370
+ };
371
+
372
+ // pending cell updates that are not yet committed
373
+ struct {
374
+ std::vector<slot_range> ranges;
375
+ } pending;
376
+
377
  ggml_type type_k = GGML_TYPE_F16;
378
  ggml_type type_v = GGML_TYPE_F16;
379
 
380
  std::vector<ggml_context_ptr> ctxs;
381
  std::vector<ggml_backend_buffer_ptr> bufs;
382
 
383
+ // find how many cells are currently in use
384
+ uint32_t cell_max() const;
385
+
386
+ size_t total_size() const;
387
+
388
+ size_t size_k_bytes() const;
389
+ size_t size_v_bytes() const;
390
+
391
  void state_write_meta(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id = -1) const;
392
  void state_write_data(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges) const;
393
 
 
395
  bool state_read_data(llama_io_read_i & io, uint32_t cell_count);
396
  };
397
 
398
 
399
  //
400
  // kv cache view
examples/talk-llama/llama-memory.h CHANGED
@@ -2,12 +2,22 @@
2
 
3
  #include "llama.h"
4
 
5
  // general concept of LLM memory
6
  // the KV cache is a type of LLM memory, but there can be other types
7
  class llama_memory_i {
8
  public:
9
  virtual void clear() = 0;
10
- virtual void defrag() = 0;
11
 
12
  virtual bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) = 0;
13
  virtual void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) = 0;
 
2
 
3
  #include "llama.h"
4
 
5
+ struct llama_memory_params {
6
+ // kv cache
7
+ ggml_type type_k;
8
+ ggml_type type_v;
9
+
10
+ // parameters for other types of memory
11
+ // ...
12
+ };
13
+
14
  // general concept of LLM memory
15
  // the KV cache is a type of LLM memory, but there can be other types
16
  class llama_memory_i {
17
  public:
18
+ virtual ~llama_memory_i() = default;
19
+
20
  virtual void clear() = 0;
 
21
 
22
  virtual bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) = 0;
23
  virtual void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) = 0;
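The new llama_memory_params struct is consumed by llama_model::create_memory() later in this commit; a minimal caller-side sketch of how the two fit together (the surrounding context-setup code is an assumption):

    // sketch: building the memory module for a context
    llama_memory_params mparams = {
        /*.type_k =*/ GGML_TYPE_F16,
        /*.type_v =*/ GGML_TYPE_F16,
    };
    // note: create_memory() may mutate cparams, e.g. it pads n_ctx for the unified cache
    llama_memory_i * mem = model.create_memory(mparams, cparams);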
examples/talk-llama/llama-model-loader.cpp CHANGED
@@ -301,12 +301,12 @@ namespace GGUFMeta {
301
  GGUFMeta::GKV<GGUFMeta::ArrayInfo>::get_kv(meta.get(), kid);
302
 
303
  switch (arr_info.gt) {
304
- case GGUF_TYPE_FLOAT32: GGML_ASSERT((std::is_same<T, float>::value)); break;
305
- case GGUF_TYPE_INT32: GGML_ASSERT(
306
- (std::is_same<T, int32_t>::value) ||
307
- (std::is_same<T, uint32_t>::value)); break;
308
  default:
309
- throw std::runtime_error(format("%s is not a float32, int32 array", key.c_str()));
310
  }
311
 
312
  result.resize(arr_info.length);
@@ -330,12 +330,12 @@ namespace GGUFMeta {
330
  GGUFMeta::GKV<GGUFMeta::ArrayInfo>::get_kv(meta.get(), kid);
331
 
332
  switch (arr_info.gt) {
333
- case GGUF_TYPE_FLOAT32: GGML_ASSERT((std::is_same<T, float>::value)); break;
334
- case GGUF_TYPE_INT32: GGML_ASSERT(
335
- (std::is_same<T, int32_t>::value) ||
336
- (std::is_same<T, uint32_t>::value)); break;
337
  default:
338
- throw std::runtime_error(format("%s is not a float32, int32 array", key.c_str()));
339
  }
340
 
341
  if (arr_info.length > N_MAX) {
@@ -823,6 +823,10 @@ void llama_model_loader::init_mappings(bool prefetch, llama_mlocks * mlock_mmaps
823
  mmaps_used.reserve(files.size());
824
  for (const auto & file : files) {
825
  auto * reg = ggml_backend_dev_backend_reg(ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU));
826
  auto * is_numa_fn = (decltype(ggml_is_numa) *) ggml_backend_reg_get_proc_address(reg, "ggml_backend_cpu_is_numa");
827
  std::unique_ptr<llama_mmap> mapping = std::make_unique<llama_mmap>(file.get(), prefetch ? -1 : 0, is_numa_fn());
828
  mmaps_used.emplace_back(mapping->size(), 0);
 
301
  GGUFMeta::GKV<GGUFMeta::ArrayInfo>::get_kv(meta.get(), kid);
302
 
303
  switch (arr_info.gt) {
304
+ case GGUF_TYPE_UINT32:
305
+ case GGUF_TYPE_INT32: GGML_ASSERT((std::is_same<T, int32_t>::value) ||
306
+ (std::is_same<T, uint32_t>::value)); break;
307
+ case GGUF_TYPE_FLOAT32: GGML_ASSERT((std::is_same<T, float>::value)); break;
308
  default:
309
+ throw std::runtime_error(format("%s is not a float32/uint32/int32 array", key.c_str()));
310
  }
311
 
312
  result.resize(arr_info.length);
 
330
  GGUFMeta::GKV<GGUFMeta::ArrayInfo>::get_kv(meta.get(), kid);
331
 
332
  switch (arr_info.gt) {
333
+ case GGUF_TYPE_UINT32:
334
+ case GGUF_TYPE_INT32: GGML_ASSERT((std::is_same<T, int32_t>::value) ||
335
+ (std::is_same<T, uint32_t>::value)); break;
336
+ case GGUF_TYPE_FLOAT32: GGML_ASSERT((std::is_same<T, float>::value)); break;
337
  default:
338
+ throw std::runtime_error(format("%s is not a float32/uint32/int32 array", key.c_str()));
339
  }
340
 
341
  if (arr_info.length > N_MAX) {
 
823
  mmaps_used.reserve(files.size());
824
  for (const auto & file : files) {
825
  auto * reg = ggml_backend_dev_backend_reg(ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU));
826
+ if (!reg) {
827
+ throw std::runtime_error(format("%s: no CPU backend found", __func__));
828
+ }
829
+
830
  auto * is_numa_fn = (decltype(ggml_is_numa) *) ggml_backend_reg_get_proc_address(reg, "ggml_backend_cpu_is_numa");
831
  std::unique_ptr<llama_mmap> mapping = std::make_unique<llama_mmap>(file.get(), prefetch ? -1 : 0, is_numa_fn());
832
  mmaps_used.emplace_back(mapping->size(), 0);
examples/talk-llama/llama-model-saver.cpp ADDED
@@ -0,0 +1,281 @@
1
+ #include "llama-model-saver.h"
2
+
3
+ #include "gguf.h"
4
+
5
+ #include "llama.h"
6
+ #include "llama-hparams.h"
7
+ #include "llama-model.h"
8
+ #include "llama-vocab.h"
9
+
10
+ #include <string>
11
+
12
+ llama_model_saver::llama_model_saver(const struct llama_model & model) : model(model), llm_kv(model.arch) {
13
+ gguf_ctx = gguf_init_empty();
14
+ }
15
+
16
+ llama_model_saver::~llama_model_saver() {
17
+ gguf_free(gguf_ctx);
18
+ }
19
+
20
+ void llama_model_saver::add_kv(const enum llm_kv key, const uint32_t value) {
21
+ gguf_set_val_u32(gguf_ctx, llm_kv(key).c_str(), value);
22
+ }
23
+
24
+ void llama_model_saver::add_kv(const enum llm_kv key, const int32_t value) {
25
+ gguf_set_val_i32(gguf_ctx, llm_kv(key).c_str(), value);
26
+ }
27
+
28
+ void llama_model_saver::add_kv(const enum llm_kv key, const float value) {
29
+ gguf_set_val_f32(gguf_ctx, llm_kv(key).c_str(), value);
30
+ }
31
+
32
+ void llama_model_saver::add_kv(const enum llm_kv key, const bool value) {
33
+ gguf_set_val_bool(gguf_ctx, llm_kv(key).c_str(), value);
34
+ }
35
+
36
+ void llama_model_saver::add_kv(const enum llm_kv key, const char * value) {
37
+ gguf_set_val_str(gguf_ctx, llm_kv(key).c_str(), value);
38
+ }
39
+
40
+ [[noreturn]]
41
+ void llama_model_saver::add_kv(const enum llm_kv key, const char value) {
42
+ GGML_UNUSED(key);
43
+ GGML_UNUSED(value);
44
+ GGML_ABORT("fatal error"); // this should never be called, only needed to make the template below compile
45
+ }
46
+
47
+ template <typename Container>
48
+ void llama_model_saver::add_kv(const enum llm_kv key, const Container & value, const bool per_layer) {
49
+ const size_t n_values = per_layer ? size_t(model.hparams.n_layer) : value.size();
50
+ GGML_ASSERT(n_values <= value.size());
51
+
52
+ if (n_values == 0) {
53
+ return;
54
+ }
55
+
56
+ if (per_layer) {
57
+ bool all_values_the_same = true;
58
+ for (size_t i = 1; i < n_values; ++i) {
59
+ if (value[i] != value[0]) {
60
+ all_values_the_same = false;
61
+ break;
62
+ }
63
+ }
64
+ if (all_values_the_same) {
65
+ add_kv(key, value[0]);
66
+ return;
67
+ }
68
+ }
69
+
70
+ if (std::is_same<typename Container::value_type, uint8_t>::value) {
71
+ gguf_set_arr_data(gguf_ctx, llm_kv(key).c_str(), GGUF_TYPE_UINT8, value.data(), n_values);
72
+ } else if (std::is_same<typename Container::value_type, int8_t>::value) {
73
+ gguf_set_arr_data(gguf_ctx, llm_kv(key).c_str(), GGUF_TYPE_INT8, value.data(), n_values);
74
+ } else if (std::is_same<typename Container::value_type, uint32_t>::value) {
75
+ gguf_set_arr_data(gguf_ctx, llm_kv(key).c_str(), GGUF_TYPE_UINT32, value.data(), n_values);
76
+ } else if (std::is_same<typename Container::value_type, int32_t>::value) {
77
+ gguf_set_arr_data(gguf_ctx, llm_kv(key).c_str(), GGUF_TYPE_INT32, value.data(), n_values);
78
+ } else if (std::is_same<typename Container::value_type, float>::value) {
79
+ gguf_set_arr_data(gguf_ctx, llm_kv(key).c_str(), GGUF_TYPE_FLOAT32, value.data(), n_values);
80
+ } else if (std::is_same<Container, std::string>::value) {
81
+ gguf_set_val_str(gguf_ctx, llm_kv(key).c_str(), reinterpret_cast<const char *>(value.data()));
82
+ } else {
83
+ GGML_ABORT("fatal error");
84
+ }
85
+ }
86
+
87
+ void llama_model_saver::add_kv(const enum llm_kv key, const std::vector<std::string> & value) {
88
+ std::vector<const char *> tmp(value.size());
89
+ for (size_t i = 0; i < value.size(); ++i) {
90
+ tmp[i] = value[i].c_str();
91
+ }
92
+ gguf_set_arr_str(gguf_ctx, llm_kv(key).c_str(), tmp.data(), tmp.size());
93
+ }
94
+
95
+ void llama_model_saver::add_tensor(const struct ggml_tensor * tensor) {
96
+ if (!tensor) {
97
+ return;
98
+ }
99
+ if (gguf_find_tensor(gguf_ctx, tensor->name) >= 0) {
100
+ GGML_ASSERT(std::string(tensor->name) == "rope_freqs.weight"); // FIXME
101
+ return;
102
+ }
103
+ gguf_add_tensor(gguf_ctx, tensor);
104
+ }
105
+
106
+ void llama_model_saver::add_kv_from_model() {
107
+ const llama_hparams & hparams = model.hparams;
108
+ const llama_vocab & vocab = model.vocab;
109
+
110
+ const int32_t n_vocab = vocab.n_tokens();
111
+ std::vector<std::string> tokens(n_vocab);
112
+ std::vector<float> scores(n_vocab);
113
+ std::vector<int32_t> token_types(n_vocab);
114
+
115
+ for (int32_t id = 0; id < n_vocab; ++id) {
116
+ const llama_vocab::token_data & token_data = vocab.get_token_data(id);
117
+
118
+ tokens[id] = token_data.text;
119
+ scores[id] = token_data.score;
120
+
121
+ switch(token_data.attr) {
122
+ case LLAMA_TOKEN_ATTR_UNKNOWN: token_types[id] = LLAMA_TOKEN_TYPE_UNKNOWN; break;
123
+ case LLAMA_TOKEN_ATTR_UNUSED: token_types[id] = LLAMA_TOKEN_TYPE_UNUSED; break;
124
+ case LLAMA_TOKEN_ATTR_NORMAL: token_types[id] = LLAMA_TOKEN_TYPE_NORMAL; break;
125
+ case LLAMA_TOKEN_ATTR_CONTROL: token_types[id] = LLAMA_TOKEN_TYPE_CONTROL; break;
126
+ case LLAMA_TOKEN_ATTR_USER_DEFINED: token_types[id] = LLAMA_TOKEN_TYPE_USER_DEFINED; break;
127
+ case LLAMA_TOKEN_ATTR_BYTE: token_types[id] = LLAMA_TOKEN_TYPE_BYTE; break;
128
+ case LLAMA_TOKEN_ATTR_UNDEFINED:
129
+ default: token_types[id] = LLAMA_TOKEN_TYPE_UNDEFINED; break;
130
+ }
131
+ }
132
+
133
+ // add_kv(LLM_KV_GENERAL_TYPE, ???);
134
+ add_kv(LLM_KV_GENERAL_ARCHITECTURE, model.arch_name());
135
+ // add_kv(LLM_KV_GENERAL_QUANTIZATION_VERSION, ???);
136
+ // add_kv(LLM_KV_GENERAL_ALIGNMENT, ???);
137
+ add_kv(LLM_KV_GENERAL_NAME, model.name);
138
+ // add_kv(LLM_KV_GENERAL_AUTHOR, ???);
139
+ // add_kv(LLM_KV_GENERAL_VERSION, ???);
140
+ // add_kv(LLM_KV_GENERAL_URL, ???);
141
+ // add_kv(LLM_KV_GENERAL_DESCRIPTION, ???);
142
+ // add_kv(LLM_KV_GENERAL_LICENSE, ???);
143
+ // add_kv(LLM_KV_GENERAL_SOURCE_URL, ???);
144
+ // add_kv(LLM_KV_GENERAL_SOURCE_HF_REPO, ???);
145
+
146
+ add_kv(LLM_KV_VOCAB_SIZE, vocab.n_tokens());
147
+ add_kv(LLM_KV_CONTEXT_LENGTH, hparams.n_ctx_train);
148
+ add_kv(LLM_KV_EMBEDDING_LENGTH, hparams.n_embd);
149
+ add_kv(LLM_KV_BLOCK_COUNT, hparams.n_layer);
150
+ add_kv(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead);
151
+ add_kv(LLM_KV_FEED_FORWARD_LENGTH, hparams.n_ff_arr, true);
152
+ add_kv(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp);
153
+ add_kv(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_exp);
154
+ add_kv(LLM_KV_USE_PARALLEL_RESIDUAL, hparams.use_par_res);
155
+ // add_kv(LLM_KV_TENSOR_DATA_LAYOUT, ???);
156
+ add_kv(LLM_KV_EXPERT_COUNT, hparams.n_expert);
157
+ add_kv(LLM_KV_EXPERT_USED_COUNT, hparams.n_expert_used);
158
+ add_kv(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared);
159
+ add_kv(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale);
160
+ add_kv(LLM_KV_POOLING_TYPE, uint32_t(hparams.pooling_type));
161
+ add_kv(LLM_KV_LOGIT_SCALE, hparams.f_logit_scale);
162
+ add_kv(LLM_KV_DECODER_START_TOKEN_ID, hparams.dec_start_token_id);
163
+ add_kv(LLM_KV_ATTN_LOGIT_SOFTCAPPING, hparams.f_attn_logit_softcapping);
164
+ add_kv(LLM_KV_FINAL_LOGIT_SOFTCAPPING, hparams.f_final_logit_softcapping);
165
+ add_kv(LLM_KV_SWIN_NORM, hparams.swin_norm);
166
+ add_kv(LLM_KV_RESCALE_EVERY_N_LAYERS, hparams.rescale_every_n_layers);
167
+ add_kv(LLM_KV_TIME_MIX_EXTRA_DIM, hparams.time_mix_extra_dim);
168
+ add_kv(LLM_KV_TIME_DECAY_EXTRA_DIM, hparams.time_decay_extra_dim);
169
+ add_kv(LLM_KV_RESIDUAL_SCALE, hparams.f_residual_scale);
170
+ add_kv(LLM_KV_EMBEDDING_SCALE, hparams.f_embedding_scale);
171
+
172
+ add_kv(LLM_KV_ATTENTION_HEAD_COUNT, hparams.n_head_arr, true);
173
+ add_kv(LLM_KV_ATTENTION_HEAD_COUNT_KV, hparams.n_head_kv_arr, true);
174
+ add_kv(LLM_KV_ATTENTION_MAX_ALIBI_BIAS, hparams.f_max_alibi_bias);
175
+ add_kv(LLM_KV_ATTENTION_CLAMP_KQV, hparams.f_clamp_kqv);
176
+ add_kv(LLM_KV_ATTENTION_KEY_LENGTH, hparams.n_embd_head_k);
177
+ add_kv(LLM_KV_ATTENTION_VALUE_LENGTH, hparams.n_embd_head_v);
178
+ add_kv(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
179
+ add_kv(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
180
+ add_kv(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn);
181
+ add_kv(LLM_KV_ATTENTION_Q_LORA_RANK, hparams.n_lora_q);
182
+ add_kv(LLM_KV_ATTENTION_KV_LORA_RANK, hparams.n_lora_kv);
183
+ add_kv(LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT, hparams.n_rel_attn_bkts);
184
+ add_kv(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa);
185
+ add_kv(LLM_KV_ATTENTION_SCALE, hparams.f_attention_scale);
186
+
187
+ const float rope_scaling_factor = hparams.rope_freq_scale_train == 1.0f ? 0.0f : 1.0f/hparams.rope_freq_scale_train;
188
+
189
+ add_kv(LLM_KV_ROPE_DIMENSION_COUNT, hparams.n_rot);
190
+ add_kv(LLM_KV_ROPE_FREQ_BASE, hparams.rope_freq_base_train);
191
+ // add_kv(LLM_KV_ROPE_SCALE_LINEAR, rope_scaling_factor); // old name
192
+ add_kv(LLM_KV_ROPE_SCALING_TYPE, llama_rope_scaling_type_name(hparams.rope_scaling_type_train));
193
+ add_kv(LLM_KV_ROPE_SCALING_FACTOR, rope_scaling_factor);
194
+ add_kv(LLM_KV_ROPE_SCALING_ATTN_FACTOR, hparams.rope_attn_factor);
195
+ add_kv(LLM_KV_ROPE_SCALING_ORIG_CTX_LEN, hparams.n_ctx_orig_yarn);
196
+ add_kv(LLM_KV_ROPE_SCALING_FINETUNED, hparams.rope_finetuned);
197
+ add_kv(LLM_KV_ROPE_SCALING_YARN_LOG_MUL, hparams.rope_yarn_log_mul);
198
+
199
+ // TODO: implement split file support
200
+ // add_kv(LLM_KV_SPLIT_NO, ???);
201
+ // add_kv(LLM_KV_SPLIT_COUNT, ???);
202
+ // add_kv(LLM_KV_SPLIT_TENSORS_COUNT, ???);
203
+
204
+ add_kv(LLM_KV_SSM_INNER_SIZE, hparams.ssm_d_inner);
205
+ add_kv(LLM_KV_SSM_CONV_KERNEL, hparams.ssm_d_conv);
206
+ add_kv(LLM_KV_SSM_STATE_SIZE, hparams.ssm_d_state);
207
+ add_kv(LLM_KV_SSM_TIME_STEP_RANK, hparams.ssm_dt_rank);
208
+ add_kv(LLM_KV_SSM_DT_B_C_RMS, hparams.ssm_dt_b_c_rms);
209
+
210
+ add_kv(LLM_KV_WKV_HEAD_SIZE, hparams.wkv_head_size);
211
+
212
+ add_kv(LLM_KV_TOKENIZER_MODEL, vocab.get_tokenizer_model());
213
+ add_kv(LLM_KV_TOKENIZER_PRE, vocab.get_tokenizer_pre());
214
+ add_kv(LLM_KV_TOKENIZER_LIST, tokens);
215
+ add_kv(LLM_KV_TOKENIZER_TOKEN_TYPE, token_types);
216
+ add_kv(LLM_KV_TOKENIZER_TOKEN_TYPE_COUNT, vocab.n_token_types());
217
+ add_kv(LLM_KV_TOKENIZER_SCORES, scores);
218
+ add_kv(LLM_KV_TOKENIZER_MERGES, vocab.get_bpe_merges());
219
+ // FIXME llama_token is type i32 but when reading in a GGUF file u32 is expected, not an issue for writing though
220
+ add_kv(LLM_KV_TOKENIZER_BOS_ID, uint32_t(vocab.token_bos()));
221
+ add_kv(LLM_KV_TOKENIZER_EOS_ID, uint32_t(vocab.token_eos()));
222
+ add_kv(LLM_KV_TOKENIZER_EOT_ID, uint32_t(vocab.token_eot()));
223
+ add_kv(LLM_KV_TOKENIZER_EOM_ID, uint32_t(vocab.token_eom()));
224
+ add_kv(LLM_KV_TOKENIZER_UNK_ID, uint32_t(vocab.token_unk()));
225
+ add_kv(LLM_KV_TOKENIZER_SEP_ID, uint32_t(vocab.token_sep()));
226
+ add_kv(LLM_KV_TOKENIZER_PAD_ID, uint32_t(vocab.token_pad()));
227
+ // add_kv(LLM_KV_TOKENIZER_CLS_ID, uint32_t(vocab.token_bos())); // deprecated
228
+ // add_kv(LLM_KV_TOKENIZER_MASK_ID, ???);
229
+ add_kv(LLM_KV_TOKENIZER_ADD_BOS, vocab.get_add_bos());
230
+ add_kv(LLM_KV_TOKENIZER_ADD_EOS, vocab.get_add_eos());
231
+ add_kv(LLM_KV_TOKENIZER_ADD_PREFIX, vocab.get_add_space_prefix());
232
+ add_kv(LLM_KV_TOKENIZER_REMOVE_EXTRA_WS, vocab.get_remove_extra_whitespaces());
233
+ add_kv(LLM_KV_TOKENIZER_PRECOMPILED_CHARSMAP, vocab.get_precompiled_charsmap());
234
+ // add_kv(LLM_KV_TOKENIZER_HF_JSON, ???);
235
+ // add_kv(LLM_KV_TOKENIZER_RWKV, ???);
236
+ add_kv(LLM_KV_TOKENIZER_FIM_PRE_ID, uint32_t(vocab.token_fim_pre()));
237
+ add_kv(LLM_KV_TOKENIZER_FIM_SUF_ID, uint32_t(vocab.token_fim_suf()));
238
+ add_kv(LLM_KV_TOKENIZER_FIM_MID_ID, uint32_t(vocab.token_fim_mid()));
239
+ add_kv(LLM_KV_TOKENIZER_FIM_PAD_ID, uint32_t(vocab.token_fim_pad()));
240
+ add_kv(LLM_KV_TOKENIZER_FIM_REP_ID, uint32_t(vocab.token_fim_rep()));
241
+ add_kv(LLM_KV_TOKENIZER_FIM_SEP_ID, uint32_t(vocab.token_fim_sep()));
242
+
243
+ // TODO: implement LoRA support
244
+ // add_kv(LLM_KV_ADAPTER_TYPE, ???);
245
+ // add_kv(LLM_KV_ADAPTER_LORA_ALPHA, ???);
246
+
247
+ // deprecated
248
+ // add_kv(LLM_KV_TOKENIZER_PREFIX_ID, ???);
249
+ // add_kv(LLM_KV_TOKENIZER_SUFFIX_ID, ???);
250
+ // add_kv(LLM_KV_TOKENIZER_MIDDLE_ID, ???);
251
+ }
252
+
253
+ void llama_model_saver::add_tensors_from_model() {
254
+ if (std::string(model.output->name) != std::string(model.tok_embd->name)) {
255
+ add_tensor(model.tok_embd); // some models use the same tensor for tok_embd and output
256
+ }
257
+ add_tensor(model.type_embd);
258
+ add_tensor(model.pos_embd);
259
+ add_tensor(model.tok_norm);
260
+ add_tensor(model.tok_norm_b);
261
+ add_tensor(model.output_norm);
262
+ add_tensor(model.output_norm_b);
263
+ add_tensor(model.output);
264
+ add_tensor(model.output_b);
265
+ add_tensor(model.output_norm_enc);
266
+ add_tensor(model.cls);
267
+ add_tensor(model.cls_b);
268
+ add_tensor(model.cls_out);
269
+ add_tensor(model.cls_out_b);
270
+
271
+ for (const struct llama_layer & layer : model.layers) {
272
+ for (size_t i = 0; i < sizeof(layer)/sizeof(struct ggml_tensor *); ++i) {
273
+ add_tensor(reinterpret_cast<const struct ggml_tensor * const *>(&layer)[i]);
274
+ }
275
+ }
276
+ }
277
+
278
+ void llama_model_saver::save(const std::string & path_model) {
279
+ gguf_write_to_file(gguf_ctx, path_model.c_str(), false);
280
+ }
281
+
examples/talk-llama/llama-model-saver.h ADDED
@@ -0,0 +1,37 @@
1
+ #pragma once
2
+
3
+ #include "llama.h"
4
+ #include "llama-arch.h"
5
+
6
+ #include <vector>
7
+
8
+ struct llama_model_saver {
9
+ struct gguf_context * gguf_ctx = nullptr;
10
+ const struct llama_model & model;
11
+ const struct LLM_KV llm_kv;
12
+
13
+ llama_model_saver(const struct llama_model & model);
14
+ ~llama_model_saver();
15
+
16
+ void add_kv(enum llm_kv key, uint32_t value);
17
+ void add_kv(enum llm_kv key, int32_t value);
18
+ void add_kv(enum llm_kv key, float value);
19
+ void add_kv(enum llm_kv key, bool value);
20
+ void add_kv(enum llm_kv key, const char * value);
21
+
22
+ [[noreturn]]
23
+ void add_kv(enum llm_kv key, char value); // needed to make the template below compile
24
+
25
+ template <typename Container>
26
+ void add_kv(enum llm_kv key, const Container & value, bool per_layer = false);
27
+
28
+ void add_kv(enum llm_kv key, const std::vector<std::string> & value);
29
+
30
+ void add_tensor(const struct ggml_tensor * tensor);
31
+
32
+ void add_kv_from_model();
33
+
34
+ void add_tensors_from_model();
35
+
36
+ void save(const std::string & path_model);
37
+ };
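Putting the new class together, the intended call sequence mirrors the declaration order above; a minimal end-to-end sketch (the output path is a placeholder):

    // sketch: export a loaded model back to GGUF with the new saver
    llama_model_saver saver(model);   // wraps a fresh, empty gguf_context
    saver.add_kv_from_model();        // hparams and tokenizer metadata
    saver.add_tensors_from_model();   // tok_embd/output plus every per-layer tensor
    saver.save("model-out.gguf");     // gguf_write_to_file(..., only_meta=false)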
examples/talk-llama/llama-model.cpp CHANGED
@@ -80,6 +80,7 @@ const char * llm_type_name(llm_type type) {
80
  case LLM_TYPE_236B: return "236B";
81
  case LLM_TYPE_290B: return "290B";
82
  case LLM_TYPE_314B: return "314B";
 
83
  case LLM_TYPE_671B: return "671B";
84
  case LLM_TYPE_SMALL: return "0.1B";
85
  case LLM_TYPE_MEDIUM: return "0.4B";
@@ -116,6 +117,10 @@ static const std::map<llama_rope_scaling_type, const char *> LLAMA_ROPE_SCALING_
116
  { LLAMA_ROPE_SCALING_TYPE_LONGROPE, "longrope" },
117
  };
118
 
119
  static llama_rope_scaling_type llama_rope_scaling_type_from_string(const std::string & name) {
120
  for (const auto & kv : LLAMA_ROPE_SCALING_TYPES) {
121
  if (kv.second == name) {
@@ -298,6 +303,10 @@ static buft_list_t make_cpu_buft_list(const std::vector<ggml_backend_dev_t> & de
298
  // add extra buffer types, only if no GPU device is present
299
  // ref: https://github.com/ggml-org/llama.cpp/issues/12481#issuecomment-2743136094
300
  auto * cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU);
301
  auto * cpu_reg = ggml_backend_dev_backend_reg(cpu_dev);
302
  auto ggml_backend_dev_get_extra_bufts_fn = (ggml_backend_dev_get_extra_bufts_t)
303
  ggml_backend_reg_get_proc_address(cpu_reg, "ggml_backend_dev_get_extra_bufts");
@@ -582,6 +591,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
582
  switch (hparams.n_layer) {
583
  case 32: type = LLM_TYPE_7B; break;
584
  case 80: type = LLM_TYPE_70B; break;
 
585
  default: type = LLM_TYPE_UNKNOWN;
586
  }
587
  } break;
@@ -773,6 +783,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
773
  // fall through
774
  case LLM_ARCH_QWEN2:
775
  {
 
776
  ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
777
  switch (hparams.n_layer) {
778
  case 24: type = hparams.n_embd == 1024 ? LLM_TYPE_0_5B : LLM_TYPE_1B; break;
@@ -1481,6 +1492,9 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
1481
  }
1482
 
1483
  ggml_backend_dev_t cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU);
1484
  const int i_gpu_start = std::max((int) hparams.n_layer - n_gpu_layers, (int) 0);
1485
  const int act_gpu_layers = devices.empty() ? 0 : std::min(n_gpu_layers, (int)n_layer + 1);
1486
  auto get_layer_buft_list = [&](int il) -> llama_model::impl::layer_dev {
@@ -1648,8 +1662,11 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
1648
  for (const auto * overrides = ml.tensor_buft_overrides; overrides->pattern != nullptr; ++overrides) {
1649
  std::regex pattern(overrides->pattern);
1650
  if (std::regex_search(tensor_name, pattern)) {
1651
- LLAMA_LOG_DEBUG("tensor %s buffer type overriden to %s\n", tensor_name.c_str(), ggml_backend_buft_name(overrides->buft));
1652
  buft = overrides->buft;
1653
  break;
1654
  }
1655
  }
@@ -1666,6 +1683,9 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
1666
  auto * buft_dev = ggml_backend_buft_get_device(buft);
1667
  if (ml.use_mmap && buft_dev && buft == ggml_backend_dev_host_buffer_type(buft_dev)) {
1668
  auto * cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU);
1669
  buft = ggml_backend_dev_buffer_type(cpu_dev);
1670
  }
1671
 
@@ -1847,7 +1867,9 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
1847
  layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, TENSOR_NOT_REQUIRED);
1848
  layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED);
1849
 
1850
- layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
1851
 
1852
  if (hparams.rope_scaling_type_train == LLAMA_ROPE_SCALING_TYPE_LONGROPE) {
1853
  layer.rope_long = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_LONG, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0));
@@ -1857,9 +1879,11 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
1857
  layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0));
1858
  }
1859
 
1860
- layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
1861
- layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0);
1862
 
1863
 
1864
  // optional MLP bias
1865
  layer.ffn_gate_b = create_tensor(tn(LLM_TENSOR_FFN_GATE, "bias", i), {n_ff}, TENSOR_NOT_REQUIRED);
@@ -3503,7 +3527,11 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
3503
 
3504
  // output
3505
  output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
3506
- output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0);
3507
 
3508
  for (int i = 0; i < n_layer; ++i) {
3509
  auto & layer = layers[i];
@@ -4108,6 +4136,9 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
4108
  if (!dev) {
4109
  // FIXME: workaround for CPU backend buft having a NULL device
4110
  dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU);
4111
  }
4112
  ggml_backend_dev_props props;
4113
  ggml_backend_dev_get_props(dev, &props);
@@ -4237,7 +4268,7 @@ uint64_t llama_model::n_elements() const {
4237
  }
4238
 
4239
  void llama_model::print_info() const {
4240
- const char * rope_scaling_type = LLAMA_ROPE_SCALING_TYPES.at(hparams.rope_scaling_type_train);
4241
 
4242
  auto print_f = [](const std::function<uint32_t(uint32_t)> & f, uint32_t n) {
4243
  bool is_var = false;
@@ -4298,7 +4329,7 @@ void llama_model::print_info() const {
4298
  LLAMA_LOG_INFO("%s: causal attn = %d\n", __func__, hparams.causal_attn);
4299
  LLAMA_LOG_INFO("%s: pooling type = %d\n", __func__, hparams.pooling_type);
4300
  LLAMA_LOG_INFO("%s: rope type = %d\n", __func__, hparams.rope_type);
4301
- LLAMA_LOG_INFO("%s: rope scaling = %s\n", __func__, rope_scaling_type);
4302
  LLAMA_LOG_INFO("%s: freq_base_train = %.1f\n", __func__, hparams.rope_freq_base_train);
4303
  LLAMA_LOG_INFO("%s: freq_scale_train = %g\n", __func__, hparams.rope_freq_scale_train);
4304
  LLAMA_LOG_INFO("%s: n_ctx_orig_yarn = %u\n", __func__, hparams.n_ctx_orig_yarn);
@@ -4445,6 +4476,19 @@ const ggml_tensor * llama_model::get_tensor(const char * name) const {
4445
  return it->second;
4446
  }
4447
 
4448
  struct llm_build_llama : public llm_graph_context {
4449
  llm_build_llama(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
4450
  const int64_t n_embd_head = hparams.n_embd_head_v;
@@ -4485,7 +4529,7 @@ struct llm_build_llama : public llm_graph_context {
4485
  // self-attention
4486
  {
4487
  // rope freq factors for llama3; may return nullptr for llama2 and other models
4488
- ggml_tensor * rope_factors = static_cast<const llama_kv_cache_unified *>(memory)->cbs.get_rope_factors(n_ctx_per_seq, il);
4489
 
4490
  // compute Q and K and RoPE them
4491
  ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
@@ -4691,6 +4735,7 @@ struct llm_build_deci : public llm_graph_context {
4691
  ggml_tensor * inpSA = inpL;
4692
  const int64_t n_head_kv = hparams.n_head_kv(il);
4693
  const int64_t n_head = hparams.n_head(il);
 
4694
 
4695
  if (n_head == 0) {
4696
  // attention-free layer of Llama-3_1-Nemotron-51B
@@ -4710,7 +4755,7 @@ struct llm_build_deci : public llm_graph_context {
4710
  } else if (n_head > 0) {
4711
  // self-attention
4712
  // rope freq factors for llama3; may return nullptr for llama2 and other models
4713
- ggml_tensor * rope_factors = static_cast<const llama_kv_cache_unified *>(memory)->cbs.get_rope_factors(n_ctx_per_seq, il);
4714
 
4715
  // compute Q and K and RoPE them
4716
  ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
@@ -4766,6 +4811,11 @@ struct llm_build_deci : public llm_graph_context {
4766
  inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
4767
  }
4768
 
 
4769
  // For Granite architecture
4770
  if (hparams.f_residual_scale) {
4771
  cur = ggml_scale(ctx0, cur, hparams.f_residual_scale);
@@ -7192,7 +7242,7 @@ struct llm_build_phi3 : public llm_graph_context {
7192
  // self-attention
7193
  {
7194
  // rope freq factors for 128k context
7195
- ggml_tensor * rope_factors = static_cast<const llama_kv_cache_unified *>(memory)->cbs.get_rope_factors(n_ctx_per_seq, il);
7196
 
7197
  ggml_tensor* attn_norm_output = build_norm(inpL,
7198
  model.layers[il].attn_norm,
@@ -7944,7 +7994,7 @@ struct llm_build_minicpm3 : public llm_graph_context {
7944
  for (int il = 0; il < n_layer; ++il) {
7945
  ggml_tensor * inpSA = inpL;
7946
 
7947
- ggml_tensor * rope_factors = static_cast<const llama_kv_cache_unified *>(memory)->cbs.get_rope_factors(n_ctx_per_seq, il);
7948
 
7949
  // norm
7950
  cur = build_norm(inpL,
@@ -8711,7 +8761,7 @@ struct llm_build_mamba : public llm_graph_context {
8711
  ggml_tensor * state_mask,
8712
  const llama_ubatch & ubatch,
8713
  int il) const {
8714
- const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
8715
 
8716
  const auto kv_head = kv_self->head;
8717
 
@@ -9012,7 +9062,7 @@ struct llm_build_cohere2 : public llm_graph_context {
9012
  // self-attention
9013
  {
9014
  // rope freq factors for 128k context
9015
- ggml_tensor * rope_factors = static_cast<const llama_kv_cache_unified *>(memory)->cbs.get_rope_factors(n_ctx_per_seq, il);
9016
 
9017
  // compute Q and K and RoPE them
9018
  ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
@@ -9950,7 +10000,7 @@ struct llm_build_deepseek : public llm_graph_context {
9950
  // self-attention
9951
  {
9952
  // rope freq factors for llama3; may return nullptr for llama2 and other models
9953
- ggml_tensor * rope_factors = static_cast<const llama_kv_cache_unified *>(memory)->cbs.get_rope_factors(n_ctx_per_seq, il);
9954
 
9955
  // compute Q and K and RoPE them
9956
  ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
@@ -11314,7 +11364,7 @@ struct llm_build_exaone : public llm_graph_context {
11314
  // self-attention
11315
  {
11316
  // rope freq factors for llama3; may return nullptr for llama2 and other models
11317
- ggml_tensor * rope_factors = static_cast<const llama_kv_cache_unified *>(memory)->cbs.get_rope_factors(n_ctx_per_seq, il);
11318
 
11319
  // compute Q and K and RoPE them
11320
  ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
@@ -11459,7 +11509,7 @@ struct llm_build_rwkv6_base : public llm_graph_context {
11459
  ggml_tensor * state_mask,
11460
  const llama_ubatch & ubatch,
11461
  int il) const {
11462
- const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
11463
 
11464
  const auto n_tokens = ubatch.n_tokens;
11465
  const auto n_seqs = ubatch.n_seqs;
@@ -11855,7 +11905,7 @@ struct llm_build_rwkv7_base : public llm_graph_context {
11855
  ggml_tensor *& first_layer_value,
11856
  const llama_ubatch & ubatch,
11857
  int il) const {
11858
- const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
11859
 
11860
  const auto n_tokens = ubatch.n_tokens;
11861
  const auto n_seqs = ubatch.n_seqs;
@@ -12695,7 +12745,7 @@ struct llm_build_bailingmoe : public llm_graph_context {
12695
  // self-attention
12696
  {
12697
  // rope freq factors for llama3; may return nullptr for llama2 and other models
12698
- ggml_tensor * rope_factors = static_cast<const llama_kv_cache_unified *>(memory)->cbs.get_rope_factors(n_ctx_per_seq, il);
12699
 
12700
  // compute Q and K and RoPE them
12701
  ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
@@ -12815,36 +12865,46 @@ struct llm_build_bailingmoe : public llm_graph_context {
12815
  }
12816
  };
12817
 
12818
- llama_memory_i * llama_model::create_memory() const {
12819
  llama_memory_i * res;
12820
 
12821
  switch (arch) {
 
  case LLM_ARCH_MAMBA:
12823
  case LLM_ARCH_RWKV6:
12824
  case LLM_ARCH_RWKV6QWEN2:
12825
  case LLM_ARCH_RWKV7:
12826
  case LLM_ARCH_ARWKV7:
12827
  {
12828
- res = new llama_kv_cache_unified(hparams, {
12829
- /*.get_rope_factors =*/ nullptr
12830
- });
12831
  } break;
12832
  default:
12833
  {
12834
- res = new llama_kv_cache_unified(hparams, {
12835
- /*.get_rope_factors =*/ [this](uint32_t n_ctx_per_seq, int il) {
12836
- // choose long/short freq factors based on the context size
12837
- if (layers[il].rope_freqs != nullptr) {
12838
- return layers[il].rope_freqs;
12839
- }
12840
 
12841
- if (n_ctx_per_seq > hparams.n_ctx_orig_yarn) {
12842
- return layers[il].rope_long;
12843
- }
12844
 
12845
- return layers[il].rope_short;
12846
- }
12847
- });
 
  }
12849
  }
12850
 
@@ -13226,8 +13286,6 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
13226
  case LLM_ARCH_DECI:
13227
  case LLM_ARCH_BAICHUAN:
13228
  case LLM_ARCH_STARCODER:
13229
- case LLM_ARCH_PLAMO:
13230
- case LLM_ARCH_ORION:
13231
  case LLM_ARCH_INTERNLM2:
13232
  case LLM_ARCH_MINICPM:
13233
  case LLM_ARCH_XVERSE:
@@ -13265,6 +13323,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
13265
  case LLM_ARCH_PHI2:
13266
  case LLM_ARCH_PHI3:
13267
  case LLM_ARCH_PHIMOE:
 
13268
  case LLM_ARCH_GEMMA:
13269
  case LLM_ARCH_GEMMA2:
13270
  case LLM_ARCH_GEMMA3:
@@ -13272,6 +13331,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
13272
  case LLM_ARCH_OPENELM:
13273
  case LLM_ARCH_GPTNEOX:
13274
  case LLM_ARCH_CODESHELL:
 
13275
  case LLM_ARCH_NEMOTRON:
13276
  case LLM_ARCH_EXAONE:
13277
  case LLM_ARCH_MINICPM3:
@@ -13344,6 +13404,14 @@ const char * llama_model_chat_template(const llama_model * model, const char * n
13344
  : LLM_KV(model->arch)(LLM_KV_TOKENIZER_CHAT_TEMPLATE);
13345
  const auto & it = model->gguf_kv.find(key);
13346
  if (it == model->gguf_kv.end()) {
  return nullptr;
13348
  }
13349
 
 
80
  case LLM_TYPE_236B: return "236B";
81
  case LLM_TYPE_290B: return "290B";
82
  case LLM_TYPE_314B: return "314B";
83
+ case LLM_TYPE_405B: return "405B";
84
  case LLM_TYPE_671B: return "671B";
85
  case LLM_TYPE_SMALL: return "0.1B";
86
  case LLM_TYPE_MEDIUM: return "0.4B";
 
117
  { LLAMA_ROPE_SCALING_TYPE_LONGROPE, "longrope" },
118
  };
119
 
120
+ std::string llama_rope_scaling_type_name(llama_rope_scaling_type rope_scaling_type) {
121
+ return LLAMA_ROPE_SCALING_TYPES.at(rope_scaling_type);
122
+ }
123
+
124
  static llama_rope_scaling_type llama_rope_scaling_type_from_string(const std::string & name) {
125
  for (const auto & kv : LLAMA_ROPE_SCALING_TYPES) {
126
  if (kv.second == name) {
 
303
  // add extra buffer types, only if no GPU device is present
304
  // ref: https://github.com/ggml-org/llama.cpp/issues/12481#issuecomment-2743136094
305
  auto * cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU);
306
+ if (cpu_dev == nullptr) {
307
+ throw std::runtime_error(format("%s: no CPU backend found", __func__));
308
+ }
309
+
310
  auto * cpu_reg = ggml_backend_dev_backend_reg(cpu_dev);
311
  auto ggml_backend_dev_get_extra_bufts_fn = (ggml_backend_dev_get_extra_bufts_t)
312
  ggml_backend_reg_get_proc_address(cpu_reg, "ggml_backend_dev_get_extra_bufts");
 
591
  switch (hparams.n_layer) {
592
  case 32: type = LLM_TYPE_7B; break;
593
  case 80: type = LLM_TYPE_70B; break;
594
+ case 162: type = LLM_TYPE_405B; break;
595
  default: type = LLM_TYPE_UNKNOWN;
596
  }
597
  } break;
 
783
  // fall through
784
  case LLM_ARCH_QWEN2:
785
  {
786
+ ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type, false);
787
  ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
788
  switch (hparams.n_layer) {
789
  case 24: type = hparams.n_embd == 1024 ? LLM_TYPE_0_5B : LLM_TYPE_1B; break;
 
1492
  }
1493
 
1494
  ggml_backend_dev_t cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU);
1495
+ if (cpu_dev == nullptr) {
1496
+ throw std::runtime_error(format("%s: no CPU backend found", __func__));
1497
+ }
1498
  const int i_gpu_start = std::max((int) hparams.n_layer - n_gpu_layers, (int) 0);
1499
  const int act_gpu_layers = devices.empty() ? 0 : std::min(n_gpu_layers, (int)n_layer + 1);
1500
  auto get_layer_buft_list = [&](int il) -> llama_model::impl::layer_dev {
 
1662
  for (const auto * overrides = ml.tensor_buft_overrides; overrides->pattern != nullptr; ++overrides) {
1663
  std::regex pattern(overrides->pattern);
1664
  if (std::regex_search(tensor_name, pattern)) {
 
1665
  buft = overrides->buft;
1666
+ LLAMA_LOG_DEBUG("tensor %s (%zu MiB %s) buffer type overridden to %s\n",
1667
+ tensor_name.c_str(),
1668
+ ggml_nbytes(t_meta) / 1024 / 1024, ggml_type_name(t_meta->type),
1669
+ ggml_backend_buft_name(buft));
1670
  break;
1671
  }
1672
  }
 
1683
  auto * buft_dev = ggml_backend_buft_get_device(buft);
1684
  if (ml.use_mmap && buft_dev && buft == ggml_backend_dev_host_buffer_type(buft_dev)) {
1685
  auto * cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU);
1686
+ if (!cpu_dev) {
1687
+ throw std::runtime_error("no CPU backend found");
1688
+ }
1689
  buft = ggml_backend_dev_buffer_type(cpu_dev);
1690
  }
1691
 
 
1867
  layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, TENSOR_NOT_REQUIRED);
1868
  layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED);
1869
 
1870
+ if (n_ff > 0) {
1871
+ layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
1872
+ }
1873
 
1874
  if (hparams.rope_scaling_type_train == LLAMA_ROPE_SCALING_TYPE_LONGROPE) {
1875
  layer.rope_long = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_LONG, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0));
 
1879
  layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0));
1880
  }
1881
 
1882
+ if (n_ff > 0) {
1883
+ layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
1884
+ layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0);
1885
+ layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
1886
+ }
1887
 
1888
  // optional MLP bias
1889
  layer.ffn_gate_b = create_tensor(tn(LLM_TENSOR_FFN_GATE, "bias", i), {n_ff}, TENSOR_NOT_REQUIRED);
 
3527
 
3528
  // output
3529
  output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
3530
+ output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED);
3531
+ // if output is NULL, init from the input tok embed
3532
+ if (output == NULL) {
3533
+ output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED);
3534
+ }
3535
 
3536
  for (int i = 0; i < n_layer; ++i) {
3537
  auto & layer = layers[i];
 
4136
  if (!dev) {
4137
  // FIXME: workaround for CPU backend buft having a NULL device
4138
  dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU);
4139
+ if (!dev) {
4140
+ throw std::runtime_error(format("%s: no CPU backend found", __func__));
4141
+ }
4142
  }
4143
  ggml_backend_dev_props props;
4144
  ggml_backend_dev_get_props(dev, &props);
 
4268
  }
4269
 
4270
  void llama_model::print_info() const {
4271
+ const std::string rope_scaling_type = llama_rope_scaling_type_name(hparams.rope_scaling_type_train);
4272
 
4273
  auto print_f = [](const std::function<uint32_t(uint32_t)> & f, uint32_t n) {
4274
  bool is_var = false;
 
4329
  LLAMA_LOG_INFO("%s: causal attn = %d\n", __func__, hparams.causal_attn);
4330
  LLAMA_LOG_INFO("%s: pooling type = %d\n", __func__, hparams.pooling_type);
4331
  LLAMA_LOG_INFO("%s: rope type = %d\n", __func__, hparams.rope_type);
4332
+ LLAMA_LOG_INFO("%s: rope scaling = %s\n", __func__, rope_scaling_type.c_str());
4333
  LLAMA_LOG_INFO("%s: freq_base_train = %.1f\n", __func__, hparams.rope_freq_base_train);
4334
  LLAMA_LOG_INFO("%s: freq_scale_train = %g\n", __func__, hparams.rope_freq_scale_train);
4335
  LLAMA_LOG_INFO("%s: n_ctx_orig_yarn = %u\n", __func__, hparams.n_ctx_orig_yarn);
 
4476
  return it->second;
4477
  }
4478
 
4479
+ ggml_tensor * llama_model::get_rope_factors(uint32_t n_ctx_per_seq, int il) const {
4480
+ // choose long/short freq factors based on the context size
4481
+ if (layers[il].rope_freqs != nullptr) {
4482
+ return layers[il].rope_freqs;
4483
+ }
4484
+
4485
+ if (n_ctx_per_seq > hparams.n_ctx_orig_yarn) {
4486
+ return layers[il].rope_long;
4487
+ }
4488
+
4489
+ return layers[il].rope_short;
4490
+ }
4491
+
4492
  struct llm_build_llama : public llm_graph_context {
4493
  llm_build_llama(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
4494
  const int64_t n_embd_head = hparams.n_embd_head_v;
 
4529
  // self-attention
4530
  {
4531
  // rope freq factors for llama3; may return nullptr for llama2 and other models
4532
+ ggml_tensor * rope_factors = model.get_rope_factors(n_ctx_per_seq, il);
4533
 
4534
  // compute Q and K and RoPE them
4535
  ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
 
4735
  ggml_tensor * inpSA = inpL;
4736
  const int64_t n_head_kv = hparams.n_head_kv(il);
4737
  const int64_t n_head = hparams.n_head(il);
4738
+ const int64_t n_ff = hparams.n_ff(il);
4739
 
4740
  if (n_head == 0) {
4741
  // attention-free layer of Llama-3_1-Nemotron-51B
 
4755
  } else if (n_head > 0) {
4756
  // self-attention
4757
  // rope freq factors for llama3; may return nullptr for llama2 and other models
4758
+ ggml_tensor * rope_factors = model.get_rope_factors(n_ctx_per_seq, il);
4759
 
4760
  // compute Q and K and RoPE them
4761
  ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
 
4811
  inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
4812
  }
4813
 
4814
+ // FFN-free layer of Llama-3_1-Nemotron-Ultra-253B
4815
+ if (n_ff == 0) {
4816
+ continue;
4817
+ }
4818
+
4819
  // For Granite architecture
4820
  if (hparams.f_residual_scale) {
4821
  cur = ggml_scale(ctx0, cur, hparams.f_residual_scale);
 
7242
  // self-attention
7243
  {
7244
  // rope freq factors for 128k context
7245
+ ggml_tensor * rope_factors = model.get_rope_factors(n_ctx_per_seq, il);
7246
 
7247
  ggml_tensor* attn_norm_output = build_norm(inpL,
7248
  model.layers[il].attn_norm,
 
7994
  for (int il = 0; il < n_layer; ++il) {
7995
  ggml_tensor * inpSA = inpL;
7996
 
7997
+ ggml_tensor * rope_factors = model.get_rope_factors(n_ctx_per_seq, il);
7998
 
7999
  // norm
8000
  cur = build_norm(inpL,
 
8761
  ggml_tensor * state_mask,
8762
  const llama_ubatch & ubatch,
8763
  int il) const {
8764
+ const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory);
8765
 
8766
  const auto kv_head = kv_self->head;
8767
 
 
9062
  // self-attention
9063
  {
9064
  // rope freq factors for 128k context
9065
+ ggml_tensor * rope_factors = model.get_rope_factors(n_ctx_per_seq, il);
9066
 
9067
  // compute Q and K and RoPE them
9068
  ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
 
10000
  // self-attention
10001
  {
10002
  // rope freq factors for llama3; may return nullptr for llama2 and other models
10003
+ ggml_tensor * rope_factors = model.get_rope_factors(n_ctx_per_seq, il);
10004
 
10005
  // compute Q and K and RoPE them
10006
  ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
 
11364
  // self-attention
11365
  {
11366
  // rope freq factors for llama3; may return nullptr for llama2 and other models
11367
+ ggml_tensor * rope_factors = model.get_rope_factors(n_ctx_per_seq, il);
11368
 
11369
  // compute Q and K and RoPE them
11370
  ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
 
11509
  ggml_tensor * state_mask,
11510
  const llama_ubatch & ubatch,
11511
  int il) const {
11512
+ const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory);
11513
 
11514
  const auto n_tokens = ubatch.n_tokens;
11515
  const auto n_seqs = ubatch.n_seqs;
 
11905
  ggml_tensor *& first_layer_value,
11906
  const llama_ubatch & ubatch,
11907
  int il) const {
11908
+ const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory);
11909
 
11910
  const auto n_tokens = ubatch.n_tokens;
11911
  const auto n_seqs = ubatch.n_seqs;
 
12745
  // self-attention
12746
  {
12747
  // rope freq factors for llama3; may return nullptr for llama2 and other models
12748
+ ggml_tensor * rope_factors = model.get_rope_factors(n_ctx_per_seq, il);
12749
 
12750
  // compute Q and K and RoPE them
12751
  ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
 
12865
  }
12866
  };
12867
 
12868
+ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, llama_cparams & cparams) const {
12869
  llama_memory_i * res;
12870
 
12871
  switch (arch) {
12872
+ case LLM_ARCH_BERT:
12873
+ case LLM_ARCH_JINA_BERT_V2:
12874
+ case LLM_ARCH_NOMIC_BERT:
12875
+ case LLM_ARCH_NOMIC_BERT_MOE:
12876
+ {
12877
+ res = nullptr;
12878
+ } break;
12879
  case LLM_ARCH_MAMBA:
12880
  case LLM_ARCH_RWKV6:
12881
  case LLM_ARCH_RWKV6QWEN2:
12882
  case LLM_ARCH_RWKV7:
12883
  case LLM_ARCH_ARWKV7:
12884
  {
12885
+ res = new llama_kv_cache_recurrent(
12886
+ *this,
12887
+ GGML_TYPE_F32,
12888
+ GGML_TYPE_F32,
12889
+ cparams.offload_kqv,
12890
+ std::max((uint32_t) 1, cparams.n_seq_max));
12891
  } break;
12892
  default:
12893
  {
12894
+ const auto padding = llama_kv_cache_unified::get_padding(cparams);
 
 
12896
+ cparams.n_ctx = GGML_PAD(cparams.n_ctx, padding);
12897
 
12898
+ LLAMA_LOG_DEBUG("%s: n_ctx = %u (padded)\n", __func__, cparams.n_ctx);
12899
+
12900
+ res = new llama_kv_cache_unified(
12901
+ *this,
12902
+ params.type_k,
12903
+ params.type_v,
12904
+ !cparams.flash_attn,
12905
+ cparams.offload_kqv,
12906
+ cparams.n_ctx,
12907
+ padding);
12908
  }
12909
  }
12910
 
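Worth spelling out the padding step above: GGML_PAD rounds the requested context length up to the next multiple of the value returned by get_padding(). A quick worked example (the padding value 256 is only an assumption for illustration):

    // GGML_PAD(x, n) expands to (((x) + (n) - 1) / (n) * (n)), i.e. round x up to a multiple of n
    //   GGML_PAD(1000, 256) -> 1024
    //   GGML_PAD(1024, 256) -> 1024
    uint32_t n_ctx_padded = GGML_PAD(1000u, 256u);   // 1024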
 
13286
  case LLM_ARCH_DECI:
13287
  case LLM_ARCH_BAICHUAN:
13288
  case LLM_ARCH_STARCODER:
 
  case LLM_ARCH_INTERNLM2:
13290
  case LLM_ARCH_MINICPM:
13291
  case LLM_ARCH_XVERSE:
 
13323
  case LLM_ARCH_PHI2:
13324
  case LLM_ARCH_PHI3:
13325
  case LLM_ARCH_PHIMOE:
13326
+ case LLM_ARCH_PLAMO:
13327
  case LLM_ARCH_GEMMA:
13328
  case LLM_ARCH_GEMMA2:
13329
  case LLM_ARCH_GEMMA3:
 
13331
  case LLM_ARCH_OPENELM:
13332
  case LLM_ARCH_GPTNEOX:
13333
  case LLM_ARCH_CODESHELL:
13334
+ case LLM_ARCH_ORION:
13335
  case LLM_ARCH_NEMOTRON:
13336
  case LLM_ARCH_EXAONE:
13337
  case LLM_ARCH_MINICPM3:
 
13404
  : LLM_KV(model->arch)(LLM_KV_TOKENIZER_CHAT_TEMPLATE);
13405
  const auto & it = model->gguf_kv.find(key);
13406
  if (it == model->gguf_kv.end()) {
13407
+ // one-off fix for very popular models (so we are not flooded with issues)
13408
+ // do not extend this list unless absolutely necessary
13409
+ // Mistral-Small-2503 does not have built-in chat template
13410
+ llama_vocab_pre_type pre_type = model->vocab.get_pre_type();
13411
+ if (pre_type == LLAMA_VOCAB_PRE_TYPE_TEKKEN && model->layers.size() == 40) {
13412
+ return "mistral-v7-tekken";
13413
+ }
13414
+
13415
  return nullptr;
13416
  }
13417
 
examples/talk-llama/llama-model.h CHANGED
@@ -76,6 +76,7 @@ enum llm_type {
76
  LLM_TYPE_236B,
77
  LLM_TYPE_290B,
78
  LLM_TYPE_314B,
 
79
  LLM_TYPE_671B,
80
  LLM_TYPE_SMALL,
81
  LLM_TYPE_MEDIUM,
@@ -95,6 +96,8 @@ enum llm_type {
95
  LLM_TYPE_235B_A22B,
96
  };
97
 
98
  struct llama_layer_posnet {
99
  // resnet
100
  struct ggml_tensor * norm1 = nullptr;
@@ -395,8 +398,11 @@ struct llama_model {
395
 
396
  const struct ggml_tensor * get_tensor(const char * name) const;
397
 
398
  // TODO: move this to new llm_arch_model_i interface
399
- llama_memory_i * create_memory() const; // TODO: params
400
 
401
  // TODO: move this to new llm_arch_model_i interface
402
  llm_graph_result_ptr build_graph(
 
76
  LLM_TYPE_236B,
77
  LLM_TYPE_290B,
78
  LLM_TYPE_314B,
79
+ LLM_TYPE_405B,
80
  LLM_TYPE_671B,
81
  LLM_TYPE_SMALL,
82
  LLM_TYPE_MEDIUM,
 
96
  LLM_TYPE_235B_A22B,
97
  };
98
 
99
+ std::string llama_rope_scaling_type_name(llama_rope_scaling_type rope_scaling_type);
100
+
101
  struct llama_layer_posnet {
102
  // resnet
103
  struct ggml_tensor * norm1 = nullptr;
 
398
 
399
  const struct ggml_tensor * get_tensor(const char * name) const;
400
 
401
+ ggml_tensor * get_rope_factors(uint32_t n_ctx_per_seq, int il) const;
402
+
403
+ // note: can mutate `cparams`
404
  // TODO: move this to new llm_arch_model_i interface
405
+ llama_memory_i * create_memory(const llama_memory_params & params, llama_cparams & cparams) const;
406
 
407
  // TODO: move this to new llm_arch_model_i interface
408
  llm_graph_result_ptr build_graph(
examples/talk-llama/llama-quant.cpp CHANGED
@@ -519,7 +519,7 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
519
  nthread = std::thread::hardware_concurrency();
520
  }
521
 
522
- // mmap consistently increases speed Linux, and also increases speed on Windows with
523
  // hot cache. It may cause a slowdown on macOS, possibly related to free memory.
524
  #if defined(__linux__) || defined(_WIN32)
525
  constexpr bool use_mmap = true;
@@ -529,7 +529,7 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
529
 
530
  llama_model_kv_override * kv_overrides = nullptr;
531
  if (params->kv_overrides) {
532
- auto v = (std::vector<llama_model_kv_override>*)params->kv_overrides;
533
  kv_overrides = v->data();
534
  }
535
 
 
519
  nthread = std::thread::hardware_concurrency();
520
  }
521
 
522
+ // mmap consistently increases speed on Linux, and also increases speed on Windows with
523
  // hot cache. It may cause a slowdown on macOS, possibly related to free memory.
524
  #if defined(__linux__) || defined(_WIN32)
525
  constexpr bool use_mmap = true;
 
529
 
530
  llama_model_kv_override * kv_overrides = nullptr;
531
  if (params->kv_overrides) {
532
+ auto * v = (std::vector<llama_model_kv_override>*)params->kv_overrides;
533
  kv_overrides = v->data();
534
  }
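The pointer cast above implies that params->kv_overrides is really a std::vector<llama_model_kv_override> owned by the caller. A hedged caller-side sketch follows; the field names match struct llama_model_kv_override in llama.h, and the empty-key terminator mirrors how the model loader is assumed to detect the end of the override array (worth verifying against the current headers).

```cpp
#include "llama.h"
#include <cstring>
#include <vector>

// Sketch: quantize a model while overriding one GGUF metadata key.
// Assumption: the loader stops at an entry whose key is an empty string.
static bool quantize_with_override(const char * fname_inp, const char * fname_out) {
    std::vector<llama_model_kv_override> overrides;

    llama_model_kv_override kvo = {};
    std::strncpy(kvo.key, "general.name", sizeof(kvo.key) - 1);
    kvo.tag = LLAMA_KV_OVERRIDE_TYPE_STR;
    std::strncpy(kvo.val_str, "my-quantized-model", sizeof(kvo.val_str) - 1);
    overrides.push_back(kvo);

    overrides.emplace_back();         // zero-initialized terminator (empty key)

    llama_model_quantize_params params = llama_model_quantize_default_params();
    params.ftype        = LLAMA_FTYPE_MOSTLY_Q4_K_M;
    params.kv_overrides = &overrides; // cast back to the vector by the quantizer, as above

    return llama_model_quantize(fname_inp, fname_out, &params) == 0;
}
```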
535
 
examples/talk-llama/llama-sampling.cpp CHANGED
@@ -1750,23 +1750,35 @@ static const char * llama_sampler_top_n_sigma_name(const struct llama_sampler *
1750
  static void llama_sampler_top_n_sigma_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
1751
  const auto * ctx = (llama_sampler_top_n_sigma *) smpl->ctx;
1752
 
 
 
 
 
1753
  // find max logit and calculate mean
1754
  float max = cur_p->data[0].logit;
1755
  float logits_sum = 0;
 
1756
  for (size_t i = 0; i < cur_p->size; ++i) {
1757
- if (cur_p->data[i].logit > max) {
1758
- max = cur_p->data[i].logit;
 
 
 
 
 
1759
  }
1760
- logits_sum += cur_p->data[i].logit;
1761
  }
1762
- float mean = logits_sum/cur_p->size;
1763
 
1764
  // calculate standard deviation
1765
  float acc = 0;
1766
  for (size_t i = 0; i < cur_p->size; ++i) {
1767
- acc += pow(cur_p->data[i].logit - mean, 2);
 
 
 
1768
  }
1769
- float std = sqrt(acc/cur_p->size);
1770
 
1771
  //apply mask
1772
  for (size_t i = 0; i < cur_p->size; ++i) {
 
1750
  static void llama_sampler_top_n_sigma_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
1751
  const auto * ctx = (llama_sampler_top_n_sigma *) smpl->ctx;
1752
 
1753
+ if (ctx->n <= 0.0f || cur_p->size <= 1) {
1754
+ return;
1755
+ }
1756
+
1757
  // find max logit and calculate mean
1758
  float max = cur_p->data[0].logit;
1759
  float logits_sum = 0;
1760
+ size_t valid_count = 0;
1761
  for (size_t i = 0; i < cur_p->size; ++i) {
1762
+ // Only count values that are not -infinity
1763
+ if (cur_p->data[i].logit != -INFINITY) {
1764
+ if (cur_p->data[i].logit > max) {
1765
+ max = cur_p->data[i].logit;
1766
+ }
1767
+ logits_sum += cur_p->data[i].logit;
1768
+ valid_count++;
1769
  }
 
1770
  }
1771
+ float mean = valid_count > 0 ? logits_sum/valid_count : 0;
1772
 
1773
  // calculate standard deviation
1774
  float acc = 0;
1775
  for (size_t i = 0; i < cur_p->size; ++i) {
1776
+ // Skip -infinity in std calculation
1777
+ if (cur_p->data[i].logit != -INFINITY) {
1778
+ acc += pow(cur_p->data[i].logit - mean, 2);
1779
+ }
1780
  }
1781
+ float std = valid_count > 0 ? sqrt(acc/valid_count) : 0;
1782
 
1783
  //apply mask
1784
  for (size_t i = 0; i < cur_p->size; ++i) {
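The reworked llama_sampler_top_n_sigma_apply now ignores -INFINITY logits (tokens already masked by earlier samplers) when computing the mean and standard deviation, and returns early when n <= 0 or only one candidate is left. A standalone sketch of the same statistics over a plain vector of logits follows; the final max - n * std masking rule is assumed from the upstream implementation, since the body of the apply-mask loop is truncated in the hunk above.

```cpp
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <limits>
#include <vector>

// Standalone sketch of top-n-sigma filtering over raw logits.
// -inf entries are excluded from the statistics, mirroring valid_count above.
static void top_n_sigma(std::vector<float> & logits, float n) {
    const float neg_inf = -std::numeric_limits<float>::infinity();
    if (n <= 0.0f || logits.size() <= 1) {
        return;
    }

    float  max   = neg_inf;
    double sum   = 0.0;
    size_t valid = 0;
    for (float l : logits) {
        if (l != neg_inf) {
            max = std::max(max, l);
            sum += l;
            valid++;
        }
    }
    const float mean = valid > 0 ? (float) (sum / valid) : 0.0f;

    double acc = 0.0;
    for (float l : logits) {
        if (l != neg_inf) {
            acc += (l - mean) * (l - mean);
        }
    }
    const float std_dev = valid > 0 ? std::sqrt((float) (acc / valid)) : 0.0f;

    // assumed masking rule: drop tokens more than n standard deviations below the max
    for (float & l : logits) {
        if (l < max - n * std_dev) {
            l = neg_inf;
        }
    }
}

int main() {
    std::vector<float> logits = { 5.0f, 4.5f, 1.0f, -std::numeric_limits<float>::infinity(), 0.5f };
    top_n_sigma(logits, 1.0f);
    for (float l : logits) {
        printf("%f\n", l); // 5.0 and 4.5 survive, the rest become -inf
    }
    return 0;
}
```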
examples/talk-llama/llama-vocab.cpp CHANGED
@@ -1,5 +1,7 @@
1
  #include "llama-vocab.h"
2
 
 
 
3
  #include "llama-impl.h"
4
  #include "llama-model-loader.h"
5
 
@@ -415,6 +417,13 @@ struct llm_tokenizer_bpe : llm_tokenizer {
415
  "'(?:[sSdDmMtT]|[lL][lL]|[vV][eE]|[rR][eE])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]|\\s+(?!\\S)|\\s+",
416
  };
417
  break;
 
 
 
 
 
 
 
418
  default:
419
  // default regex for BPE tokenization pre-processing
420
  regex_exprs = {
@@ -1227,6 +1236,9 @@ struct fragment_buffer_variant {
1227
  struct llama_vocab::impl {
1228
  uint32_t n_token_types = 0; // for BERT-style token types
1229
 
 
 
 
1230
  enum llama_vocab_type type = LLAMA_VOCAB_TYPE_SPM;
1231
  enum llama_vocab_pre_type pre_type = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
1232
 
@@ -1362,9 +1374,6 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
1362
 
1363
  // determine vocab type
1364
  {
1365
- std::string tokenizer_model;
1366
- std::string tokenizer_pre;
1367
-
1368
  ml.get_key(LLM_KV_TOKENIZER_MODEL, tokenizer_model);
1369
  ml.get_key(LLM_KV_TOKENIZER_PRE, tokenizer_pre, false);
1370
 
@@ -1459,7 +1468,10 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
1459
 
1460
  const int precompiled_charsmap_keyidx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_PRECOMPILED_CHARSMAP).c_str());
1461
  if (precompiled_charsmap_keyidx != -1) {
1462
- size_t n_precompiled_charsmap = gguf_get_arr_n(ctx, precompiled_charsmap_keyidx);
 
 
 
1463
  const char * pc = (const char *) gguf_get_arr_data(ctx, precompiled_charsmap_keyidx);
1464
  precompiled_charsmap.assign(pc, pc + n_precompiled_charsmap);
1465
  #ifdef IS_BIG_ENDIAN
@@ -1634,6 +1646,10 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
1634
  tokenizer_pre == "bailingmoe") {
1635
  pre_type = LLAMA_VOCAB_PRE_TYPE_BAILINGMOE;
1636
  clean_spaces = false;
 
 
 
 
1637
  } else {
1638
  throw std::runtime_error(format("unknown pre-tokenizer type: '%s'", tokenizer_pre.c_str()));
1639
  }
@@ -2778,6 +2794,14 @@ void llama_vocab::load(llama_model_loader & ml, const LLM_KV & kv) {
2778
  pimpl->load(ml, kv);
2779
  }
2780
 
 
 
 
 
 
 
 
 
2781
  enum llama_vocab_type llama_vocab::get_type() const {
2782
  return pimpl->type;
2783
  }
@@ -3000,6 +3024,20 @@ int llama_vocab::find_bpe_rank(const std::string & token_left, const std::string
3000
  return it->second;
3001
  }
3002
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3003
  int32_t llama_vocab::tokenize(
3004
  const char * text,
3005
  int32_t text_len,
 
1
  #include "llama-vocab.h"
2
 
3
+ #include "ggml.h"
4
+ #include "gguf.h"
5
  #include "llama-impl.h"
6
  #include "llama-model-loader.h"
7
 
 
417
  "'(?:[sSdDmMtT]|[lL][lL]|[vV][eE]|[rR][eE])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]|\\s+(?!\\S)|\\s+",
418
  };
419
  break;
420
+ case LLAMA_VOCAB_PRE_TYPE_SEED_CODER:
421
+ regex_exprs = {
422
+ // original regex from tokenizer.json
423
+ // "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1}| ?[^\\s\\p{L}\\p{N}\r\n]+|\\s*[\r\n]+|\\s+(?!\\S)|\\s+"
424
+ "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1}| ?[^\\s\\p{L}\\p{N}\\r\\n]+|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
425
+ };
426
+ break;
427
  default:
428
  // default regex for BPE tokenization pre-processing
429
  regex_exprs = {
 
1236
  struct llama_vocab::impl {
1237
  uint32_t n_token_types = 0; // for BERT-style token types
1238
 
1239
+ std::string tokenizer_model;
1240
+ std::string tokenizer_pre;
1241
+
1242
  enum llama_vocab_type type = LLAMA_VOCAB_TYPE_SPM;
1243
  enum llama_vocab_pre_type pre_type = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
1244
 
 
1374
 
1375
  // determine vocab type
1376
  {
 
 
 
1377
  ml.get_key(LLM_KV_TOKENIZER_MODEL, tokenizer_model);
1378
  ml.get_key(LLM_KV_TOKENIZER_PRE, tokenizer_pre, false);
1379
 
 
1468
 
1469
  const int precompiled_charsmap_keyidx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_PRECOMPILED_CHARSMAP).c_str());
1470
  if (precompiled_charsmap_keyidx != -1) {
1471
+ const gguf_type pc_type = gguf_get_arr_type(ctx, precompiled_charsmap_keyidx);
1472
+ GGML_ASSERT(pc_type == GGUF_TYPE_INT8 || pc_type == GGUF_TYPE_UINT8);
1473
+
1474
+ const size_t n_precompiled_charsmap = gguf_get_arr_n(ctx, precompiled_charsmap_keyidx);
1475
  const char * pc = (const char *) gguf_get_arr_data(ctx, precompiled_charsmap_keyidx);
1476
  precompiled_charsmap.assign(pc, pc + n_precompiled_charsmap);
1477
  #ifdef IS_BIG_ENDIAN
 
1646
  tokenizer_pre == "bailingmoe") {
1647
  pre_type = LLAMA_VOCAB_PRE_TYPE_BAILINGMOE;
1648
  clean_spaces = false;
1649
+ } else if (
1650
+ tokenizer_pre == "seed-coder") {
1651
+ pre_type = LLAMA_VOCAB_PRE_TYPE_SEED_CODER;
1652
+ clean_spaces = false;
1653
  } else {
1654
  throw std::runtime_error(format("unknown pre-tokenizer type: '%s'", tokenizer_pre.c_str()));
1655
  }
 
2794
  pimpl->load(ml, kv);
2795
  }
2796
 
2797
+ std::string llama_vocab::get_tokenizer_model() const {
2798
+ return pimpl->tokenizer_model;
2799
+ }
2800
+
2801
+ std::string llama_vocab::get_tokenizer_pre() const {
2802
+ return pimpl->tokenizer_pre;
2803
+ }
2804
+
2805
  enum llama_vocab_type llama_vocab::get_type() const {
2806
  return pimpl->type;
2807
  }
 
3024
  return it->second;
3025
  }
3026
 
3027
+ std::vector<std::string> llama_vocab::get_bpe_merges() const {
3028
+ std::vector<std::string> result(pimpl->bpe_ranks.size());
3029
+
3030
+ for (const auto & pair : pimpl->bpe_ranks) {
3031
+ result[pair.second] = pair.first.first + " " + pair.first.second;
3032
+ }
3033
+
3034
+ return result;
3035
+ }
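get_bpe_merges() inverts the (left, right) -> rank map back into the rank-ordered "left right" strings that GGUF stores for BPE merges. A toy sketch of the same inversion follows, assuming every rank in [0, size) occurs exactly once (otherwise some slots would stay empty); a std::map stands in for the real hashed bpe_ranks container.

```cpp
#include <cstdio>
#include <map>
#include <string>
#include <utility>
#include <vector>

int main() {
    // toy stand-in for pimpl->bpe_ranks: (left, right) -> merge rank
    std::map<std::pair<std::string, std::string>, int> bpe_ranks = {
        { { "h",  "e"  }, 0 },
        { { "l",  "l"  }, 1 },
        { { "he", "ll" }, 2 },
    };

    // same inversion as llama_vocab::get_bpe_merges(): the rank becomes the index
    std::vector<std::string> merges(bpe_ranks.size());
    for (const auto & pair : bpe_ranks) {
        merges[pair.second] = pair.first.first + " " + pair.first.second;
    }

    for (const auto & m : merges) {
        printf("%s\n", m.c_str()); // "h e", "l l", "he ll" in rank order
    }
    return 0;
}
```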
3036
+
3037
+ std::vector<char> llama_vocab::get_precompiled_charsmap() const {
3038
+ return pimpl->precompiled_charsmap;
3039
+ }
3040
+
3041
  int32_t llama_vocab::tokenize(
3042
  const char * text,
3043
  int32_t text_len,
examples/talk-llama/llama-vocab.h CHANGED
@@ -21,6 +21,9 @@ struct llama_vocab {
21
 
22
  void load(llama_model_loader & ml, const LLM_KV & kv);
23
 
 
 
 
24
  enum llama_vocab_type get_type() const;
25
  enum llama_vocab_pre_type get_pre_type() const;
26
 
@@ -80,6 +83,9 @@ struct llama_vocab {
80
  int max_token_len() const;
81
 
82
  int find_bpe_rank(const std::string & token_left, const std::string & token_right) const;
 
 
 
83
 
84
  int32_t tokenize(
85
  const char * text,
 
21
 
22
  void load(llama_model_loader & ml, const LLM_KV & kv);
23
 
24
+ std::string get_tokenizer_model() const;
25
+ std::string get_tokenizer_pre() const;
26
+
27
  enum llama_vocab_type get_type() const;
28
  enum llama_vocab_pre_type get_pre_type() const;
29
 
 
83
  int max_token_len() const;
84
 
85
  int find_bpe_rank(const std::string & token_left, const std::string & token_right) const;
86
+ std::vector<std::string> get_bpe_merges() const;
87
+
88
+ std::vector<char> get_precompiled_charsmap() const;
89
 
90
  int32_t tokenize(
91
  const char * text,
examples/talk-llama/llama.cpp CHANGED
@@ -4,6 +4,7 @@
4
  #include "llama-mmap.h"
5
  #include "llama-vocab.h"
6
  #include "llama-model-loader.h"
 
7
  #include "llama-model.h"
8
 
9
  #include "ggml.h"
@@ -16,6 +17,10 @@
16
  #include <cstring>
17
  #include <ctime>
18
 
 
 
 
 
19
  //
20
  // interface implementation
21
  //
@@ -249,6 +254,13 @@ struct llama_model * llama_model_load_from_splits(
249
  return llama_model_load_from_file_impl(splits.front(), splits, params);
250
  }
251
 
 
 
 
 
 
 
 
252
  //
253
  // chat templates
254
  //
@@ -334,3 +346,4 @@ const char * llama_print_system_info(void) {
334
 
335
  return s.c_str();
336
  }
 
 
4
  #include "llama-mmap.h"
5
  #include "llama-vocab.h"
6
  #include "llama-model-loader.h"
7
+ #include "llama-model-saver.h"
8
  #include "llama-model.h"
9
 
10
  #include "ggml.h"
 
17
  #include <cstring>
18
  #include <ctime>
19
 
20
+ #if defined(_MSC_VER)
21
+ #pragma warning(disable: 4244 4267) // possible loss of data
22
+ #endif
23
+
24
  //
25
  // interface implementation
26
  //
 
254
  return llama_model_load_from_file_impl(splits.front(), splits, params);
255
  }
256
 
257
+ void llama_model_save_to_file(const struct llama_model * model, const char * path_model) {
258
+ llama_model_saver ms(*model);
259
+ ms.add_kv_from_model();
260
+ ms.add_tensors_from_model();
261
+ ms.save(path_model);
262
+ }
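llama_model_save_to_file() simply drives the new llama_model_saver (added to the build in CMakeLists.txt above): collect the KV metadata, collect the tensors, write the GGUF file. A hedged round-trip sketch using the public API follows; the file paths are placeholders.

```cpp
#include "llama.h"
#include <cstdio>

int main() {
    llama_backend_init();

    llama_model_params mparams = llama_model_default_params();

    // placeholder paths
    llama_model * model = llama_model_load_from_file("input.gguf", mparams);
    if (model == nullptr) {
        fprintf(stderr, "failed to load model\n");
        return 1;
    }

    // re-serialize the loaded KV metadata and tensors through llama_model_saver
    llama_model_save_to_file(model, "output.gguf");

    llama_model_free(model);
    llama_backend_free();
    return 0;
}
```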
263
+
264
  //
265
  // chat templates
266
  //
 
346
 
347
  return s.c_str();
348
  }
349
+
examples/talk-llama/llama.h CHANGED
@@ -4,6 +4,7 @@
4
  #include "ggml.h"
5
  #include "ggml-cpu.h"
6
  #include "ggml-backend.h"
 
7
 
8
  #include <stddef.h>
9
  #include <stdint.h>
@@ -112,6 +113,7 @@ extern "C" {
112
  LLAMA_VOCAB_PRE_TYPE_BAILINGMOE = 32,
113
  LLAMA_VOCAB_PRE_TYPE_LLAMA4 = 33,
114
  LLAMA_VOCAB_PRE_TYPE_PIXTRAL = 34,
 
115
  };
116
 
117
  enum llama_rope_type {
@@ -343,7 +345,7 @@ extern "C" {
343
  float yarn_beta_fast; // YaRN low correction dim
344
  float yarn_beta_slow; // YaRN high correction dim
345
  uint32_t yarn_orig_ctx; // YaRN original context size
346
- float defrag_thold; // defragment the KV cache if holes/size > thold, < 0 disabled (default)
347
 
348
  ggml_backend_sched_eval_callback cb_eval;
349
  void * cb_eval_user_data;
@@ -351,19 +353,18 @@ extern "C" {
351
  enum ggml_type type_k; // data type for K cache [EXPERIMENTAL]
352
  enum ggml_type type_v; // data type for V cache [EXPERIMENTAL]
353
 
354
- // Keep the booleans together and at the end of the struct to avoid misalignment during copy-by-value.
355
- // TODO: move at the end of the struct
356
- bool logits_all; // the llama_decode() call computes all logits, not just the last one (DEPRECATED - set llama_batch.logits instead)
357
- bool embeddings; // if true, extract embeddings (together with logits)
358
- bool offload_kqv; // whether to offload the KQV ops (including the KV cache) to GPU
359
- bool flash_attn; // whether to use flash attention [EXPERIMENTAL]
360
- bool no_perf; // whether to measure performance timings
361
-
362
  // Abort callback
363
  // if it returns true, execution of llama_decode() will be aborted
364
  // currently works only with CPU execution
365
  ggml_abort_callback abort_callback;
366
  void * abort_callback_data;
 
 
 
 
 
 
 
367
  };
368
 
369
  // model quantization parameters
@@ -445,6 +446,10 @@ extern "C" {
445
  size_t n_paths,
446
  struct llama_model_params params);
447
 
 
 
 
 
448
  DEPRECATED(LLAMA_API void llama_free_model(struct llama_model * model),
449
  "use llama_model_free instead");
450
 
@@ -924,14 +929,19 @@ extern "C" {
924
  // Frees a batch of tokens allocated with llama_batch_init()
925
  LLAMA_API void llama_batch_free(struct llama_batch batch);
926
 
927
- // Processes a batch of tokens with the ecoder part of the encoder-decoder model.
928
- // Stores the encoder output internally for later use by the decoder cross-attention layers.
 
 
929
  // 0 - success
930
  // < 0 - error. the KV cache state is restored to the state before this call
931
  LLAMA_API int32_t llama_encode(
932
  struct llama_context * ctx,
933
  struct llama_batch batch);
934
 
 
 
 
935
  // Positive return values does not mean a fatal error, but rather a warning.
936
  // 0 - success
937
  // 1 - could not find a KV slot for the batch (try reducing the size of the batch or increase the context)
@@ -1428,6 +1438,37 @@ extern "C" {
1428
  LLAMA_API void llama_perf_sampler_print(const struct llama_sampler * chain);
1429
  LLAMA_API void llama_perf_sampler_reset( struct llama_sampler * chain);
1430
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1431
  #ifdef __cplusplus
1432
  }
1433
  #endif
 
4
  #include "ggml.h"
5
  #include "ggml-cpu.h"
6
  #include "ggml-backend.h"
7
+ #include "ggml-opt.h"
8
 
9
  #include <stddef.h>
10
  #include <stdint.h>
 
113
  LLAMA_VOCAB_PRE_TYPE_BAILINGMOE = 32,
114
  LLAMA_VOCAB_PRE_TYPE_LLAMA4 = 33,
115
  LLAMA_VOCAB_PRE_TYPE_PIXTRAL = 34,
116
+ LLAMA_VOCAB_PRE_TYPE_SEED_CODER = 35,
117
  };
118
 
119
  enum llama_rope_type {
 
345
  float yarn_beta_fast; // YaRN low correction dim
346
  float yarn_beta_slow; // YaRN high correction dim
347
  uint32_t yarn_orig_ctx; // YaRN original context size
348
+ float defrag_thold; // defragment the KV cache if holes/size > thold, <= 0 disabled (default)
349
 
350
  ggml_backend_sched_eval_callback cb_eval;
351
  void * cb_eval_user_data;
 
353
  enum ggml_type type_k; // data type for K cache [EXPERIMENTAL]
354
  enum ggml_type type_v; // data type for V cache [EXPERIMENTAL]
355
 
 
 
 
 
 
 
 
 
356
  // Abort callback
357
  // if it returns true, execution of llama_decode() will be aborted
358
  // currently works only with CPU execution
359
  ggml_abort_callback abort_callback;
360
  void * abort_callback_data;
361
+
362
+ // Keep the booleans together and at the end of the struct to avoid misalignment during copy-by-value.
363
+ bool embeddings; // if true, extract embeddings (together with logits)
364
+ bool offload_kqv; // whether to offload the KQV ops (including the KV cache) to GPU
365
+ bool flash_attn; // whether to use flash attention [EXPERIMENTAL]
366
+ bool no_perf; // whether to measure performance timings
367
+ bool op_offload; // whether to offload host tensor operations to device
368
  };
369
 
370
  // model quantization parameters
 
446
  size_t n_paths,
447
  struct llama_model_params params);
448
 
449
+ LLAMA_API void llama_model_save_to_file(
450
+ const struct llama_model * model,
451
+ const char * path_model);
452
+
453
  DEPRECATED(LLAMA_API void llama_free_model(struct llama_model * model),
454
  "use llama_model_free instead");
455
 
 
929
  // Frees a batch of tokens allocated with llama_batch_init()
930
  LLAMA_API void llama_batch_free(struct llama_batch batch);
931
 
932
+ // Process a batch of tokens.
933
+ // In contrast to llama_decode() - this call does not use KV cache.
934
+ // For encoder-decoder contexts, processes the batch using the encoder.
935
+ // Can store the encoder output internally for later use by the decoder's cross-attention layers.
936
  // 0 - success
937
  // < 0 - error. the KV cache state is restored to the state before this call
938
  LLAMA_API int32_t llama_encode(
939
  struct llama_context * ctx,
940
  struct llama_batch batch);
941
 
942
+ // Process a batch of tokens.
943
+ // Requires KV cache.
944
+ // For encoder-decoder contexts, processes the batch using the decoder.
945
  // Positive return values does not mean a fatal error, but rather a warning.
946
  // 0 - success
947
  // 1 - could not find a KV slot for the batch (try reducing the size of the batch or increase the context)
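The updated comments draw the line between llama_encode() (no KV cache, encoder side of an encoder-decoder model) and llama_decode() (KV cache required, decoder side). A hedged sketch of the call order follows; llama_batch_get_one() is used for brevity and the decoder start token is passed in rather than queried from the model.

```cpp
#include "llama.h"
#include <cstdio>
#include <vector>

// Sketch: feed a tokenized prompt to the encoder, then take one decoder step.
// Assumes ctx belongs to an encoder-decoder model and prompt is already tokenized.
static int32_t encode_then_decode(llama_context * ctx,
                                  std::vector<llama_token> & prompt,
                                  llama_token decoder_start) {
    // encoder pass: no KV cache involved, the output is kept for cross-attention
    if (llama_encode(ctx, llama_batch_get_one(prompt.data(), (int32_t) prompt.size())) < 0) {
        fprintf(stderr, "llama_encode failed\n");
        return -1;
    }

    // decoder pass: uses the KV cache; 0 = success, > 0 = warning (e.g. no KV slot)
    llama_token tok = decoder_start;
    const int32_t ret = llama_decode(ctx, llama_batch_get_one(&tok, 1));
    if (ret != 0) {
        fprintf(stderr, "llama_decode returned %d\n", ret);
    }
    return ret;
}
```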
 
1438
  LLAMA_API void llama_perf_sampler_print(const struct llama_sampler * chain);
1439
  LLAMA_API void llama_perf_sampler_reset( struct llama_sampler * chain);
1440
 
1441
+ //
1442
+ // training
1443
+ //
1444
+
1445
+ // function that returns whether or not a given tensor contains trainable parameters
1446
+ typedef bool (*llama_opt_param_filter)(const struct ggml_tensor * tensor, void * userdata);
1447
+
1448
+ // always returns true
1449
+ LLAMA_API bool llama_opt_param_filter_all(const struct ggml_tensor * tensor, void * userdata);
1450
+
1451
+ struct llama_opt_params {
1452
+ uint32_t n_ctx_train; // assumed context size post training, use context size specified in llama_context if 0
1453
+
1454
+ llama_opt_param_filter param_filter; // callback for determining which tensors contain trainable parameters
1455
+ void * param_filter_ud; // userdata for determining which tensors contain trainable parameters
1456
+
1457
+ ggml_opt_get_optimizer_params get_opt_pars; // callback for calculating optimizer parameters
1458
+ void * get_opt_pars_ud; // userdata for calculating optimizer parameters
1459
+ };
1460
+
1461
+ LLAMA_API void llama_opt_init(struct llama_context * lctx, struct llama_model * model, struct llama_opt_params lopt_params);
1462
+
1463
+ LLAMA_API void llama_opt_epoch(
1464
+ struct llama_context * lctx,
1465
+ ggml_opt_dataset_t dataset,
1466
+ ggml_opt_result_t result_train,
1467
+ ggml_opt_result_t result_eval,
1468
+ int64_t idata_split,
1469
+ ggml_opt_epoch_callback callback_train,
1470
+ ggml_opt_epoch_callback callback_eval);
1471
+
1472
  #ifdef __cplusplus
1473
  }
1474
  #endif
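The new training block builds on the ggml-opt types pulled in by the ggml-opt.h include above. A hedged sketch of how the pieces might be wired together for a single epoch follows; the dataset is assumed to be built elsewhere with the ggml-opt API, and the helper names taken from ggml-opt.h (ggml_opt_result_init, ggml_opt_dataset_ndata, ggml_opt_get_default_optimizer_params, ggml_opt_epoch_callback_progress_bar) should be verified against that header.

```cpp
#include "llama.h"
#include <cstdint>

// Sketch: run one training epoch over a pre-built ggml-opt dataset.
// Assumes ctx/model are already created and dataset comes from the ggml-opt API.
static void train_one_epoch(llama_context * ctx, llama_model * model, ggml_opt_dataset_t dataset) {
    llama_opt_params opt_params = {};
    opt_params.n_ctx_train     = 0;                          // 0: use the context size of ctx
    opt_params.param_filter    = llama_opt_param_filter_all; // train every tensor
    opt_params.param_filter_ud = nullptr;
    opt_params.get_opt_pars    = ggml_opt_get_default_optimizer_params;
    opt_params.get_opt_pars_ud = nullptr;

    llama_opt_init(ctx, model, opt_params);

    ggml_opt_result_t result_train = ggml_opt_result_init();
    ggml_opt_result_t result_eval  = ggml_opt_result_init();

    // use ~90% of the data for training, the rest for evaluation
    const int64_t idata_split = ggml_opt_dataset_ndata(dataset) * 9 / 10;

    llama_opt_epoch(ctx, dataset, result_train, result_eval, idata_split,
                    ggml_opt_epoch_callback_progress_bar, ggml_opt_epoch_callback_progress_bar);

    ggml_opt_result_free(result_train);
    ggml_opt_result_free(result_eval);
}
```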