ggerganov committed on
Commit 0ec1374 · 1 Parent(s): e37767f

talk-llama : sync llama.cpp
examples/talk-llama/llama-arch.cpp CHANGED
@@ -42,6 +42,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
     { LLM_ARCH_GEMMA, "gemma" },
     { LLM_ARCH_GEMMA2, "gemma2" },
     { LLM_ARCH_GEMMA3, "gemma3" },
+    { LLM_ARCH_GEMMA3N, "gemma3n" },
     { LLM_ARCH_STARCODER2, "starcoder2" },
     { LLM_ARCH_MAMBA, "mamba" },
     { LLM_ARCH_XVERSE, "xverse" },
@@ -75,6 +76,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
     { LLM_ARCH_BAILINGMOE, "bailingmoe" },
     { LLM_ARCH_DOTS1, "dots1" },
     { LLM_ARCH_ARCEE, "arcee" },
+    { LLM_ARCH_ERNIE4_5, "ernie4_5" },
     { LLM_ARCH_UNKNOWN, "(unknown)" },
 };
 
@@ -932,6 +934,42 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
             { LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" },
         },
     },
+    {
+        LLM_ARCH_GEMMA3N,
+        {
+            { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
+            { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
+            { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
+            { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
+            { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" },
+            { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
+            { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" },
+            { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
+            { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
+            { LLM_TENSOR_ATTN_POST_NORM, "blk.%d.post_attention_norm" },
+            { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
+            { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
+            { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
+            { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
+            { LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" },
+            { LLM_TENSOR_PER_LAYER_TOKEN_EMBD, "per_layer_token_embd" },
+            { LLM_TENSOR_PER_LAYER_MODEL_PROJ, "per_layer_model_proj" },
+            { LLM_TENSOR_PER_LAYER_PROJ_NORM, "per_layer_proj_norm" },
+            { LLM_TENSOR_ALTUP_UNEMBD_PROJ, "altup_unembd_proj" },
+            { LLM_TENSOR_ALTUP_PROJ, "altup_proj" },
+            { LLM_TENSOR_PER_LAYER_INP_GATE, "blk.%d.inp_gate" },
+            { LLM_TENSOR_PER_LAYER_PROJ, "blk.%d.proj" },
+            { LLM_TENSOR_PER_LAYER_POST_NORM, "blk.%d.post_norm" },
+            { LLM_TENSOR_ALTUP_CORRECT_COEF, "blk.%d.altup_correct_coef" },
+            { LLM_TENSOR_ALTUP_CORRECT_SCALE, "blk.%d.altup_correct_scale" },
+            { LLM_TENSOR_ALTUP_PREDICT_COEF, "blk.%d.altup_predict_coef" },
+            { LLM_TENSOR_ALTUP_ROUTER, "blk.%d.altup_router" },
+            { LLM_TENSOR_ALTUP_ROUTER_NORM, "blk.%d.altup_router_norm" },
+            { LLM_TENSOR_LAUREL_L, "blk.%d.laurel_l" },
+            { LLM_TENSOR_LAUREL_R, "blk.%d.laurel_r" },
+            { LLM_TENSOR_LAUREL_POST_NORM, "blk.%d.laurel_post_norm" },
+        },
+    },
     {
         LLM_ARCH_STARCODER2,
         {
@@ -1621,6 +1659,23 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
             { LLM_TENSOR_FFN_EXP_PROBS_B, "blk.%d.exp_probs_b" },
         }
     },
+    {
+        LLM_ARCH_ERNIE4_5,
+        {
+            { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
+            { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
+            { LLM_TENSOR_OUTPUT, "output" },
+            { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
+            { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
+            { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
+            { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
+            { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
+            { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
+            { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
+            { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
+            { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
+        },
+    },
     {
         LLM_ARCH_UNKNOWN,
         {
@@ -1749,6 +1804,23 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
     {LLM_TENSOR_FFN_GATE_EXPS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT_ID}},
     {LLM_TENSOR_FFN_UP_EXPS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT_ID}},
     {LLM_TENSOR_FFN_EXP_PROBS_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
+    // altup / laurel (gemma 3n)
+    {LLM_TENSOR_PER_LAYER_TOKEN_EMBD, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_GET_ROWS}},
+    {LLM_TENSOR_PER_LAYER_MODEL_PROJ, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_PER_LAYER_PROJ_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}},
+    {LLM_TENSOR_ALTUP_PROJ, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_ALTUP_UNEMBD_PROJ, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_PER_LAYER_INP_GATE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_PER_LAYER_PROJ, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_PER_LAYER_POST_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
+    {LLM_TENSOR_ALTUP_CORRECT_COEF, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_ALTUP_CORRECT_SCALE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
+    {LLM_TENSOR_ALTUP_PREDICT_COEF, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_ALTUP_ROUTER, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_ALTUP_ROUTER_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
+    {LLM_TENSOR_LAUREL_L, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_LAUREL_R, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_LAUREL_POST_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
     // this tensor is loaded for T5, but never used
     {LLM_TENSOR_DEC_CROSS_ATTN_REL_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_NONE}},
     {LLM_TENSOR_CONV1D, {LLM_TENSOR_LAYER_INPUT, GGML_OP_IM2COL}},
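Note (not part of the diff): the per-layer entries registered above are printf-style patterns ("blk.%d....") expanded once per block. A minimal standalone sketch of that expansion, assuming a hypothetical format_tensor_name() helper; llama.cpp's real lookup goes through its own tensor-name helpers, which this diff does not show.

// Illustrative only: expand a "blk.%d.*" tensor-name pattern for one layer.
// The pattern strings come from LLM_TENSOR_NAMES above; format_tensor_name()
// is a hypothetical helper, not an API from this commit.
#include <cstdio>
#include <string>

static std::string format_tensor_name(const char * pattern, int il) {
    char buf[256];
    std::snprintf(buf, sizeof(buf), pattern, il);
    return std::string(buf);
}

int main() {
    // e.g. the gemma3n entry { LLM_TENSOR_ALTUP_ROUTER, "blk.%d.altup_router" }
    std::printf("%s\n", format_tensor_name("blk.%d.altup_router", 7).c_str()); // prints blk.7.altup_router
}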
examples/talk-llama/llama-arch.h CHANGED
@@ -46,6 +46,7 @@ enum llm_arch {
     LLM_ARCH_GEMMA,
     LLM_ARCH_GEMMA2,
     LLM_ARCH_GEMMA3,
+    LLM_ARCH_GEMMA3N,
     LLM_ARCH_STARCODER2,
     LLM_ARCH_MAMBA,
     LLM_ARCH_XVERSE,
@@ -79,6 +80,7 @@ enum llm_arch {
     LLM_ARCH_BAILINGMOE,
     LLM_ARCH_DOTS1,
     LLM_ARCH_ARCEE,
+    LLM_ARCH_ERNIE4_5,
     LLM_ARCH_UNKNOWN,
 };
 
@@ -269,6 +271,22 @@ enum llm_tensor {
     LLM_TENSOR_LAYER_OUT_NORM,
     LLM_TENSOR_POST_ATTN_NORM,
     LLM_TENSOR_POST_MLP_NORM,
+    LLM_TENSOR_PER_LAYER_TOKEN_EMBD, // gemma3n
+    LLM_TENSOR_PER_LAYER_MODEL_PROJ, // gemma3n
+    LLM_TENSOR_PER_LAYER_INP_GATE, // gemma3n
+    LLM_TENSOR_PER_LAYER_PROJ, // gemma3n
+    LLM_TENSOR_PER_LAYER_PROJ_NORM, // gemma3n
+    LLM_TENSOR_PER_LAYER_POST_NORM, // gemma3n
+    LLM_TENSOR_ALTUP_PROJ, // gemma3n
+    LLM_TENSOR_ALTUP_UNEMBD_PROJ, // gemma3n
+    LLM_TENSOR_ALTUP_CORRECT_COEF, // gemma3n
+    LLM_TENSOR_ALTUP_CORRECT_SCALE, // gemma3n
+    LLM_TENSOR_ALTUP_PREDICT_COEF, // gemma3n
+    LLM_TENSOR_ALTUP_ROUTER, // gemma3n
+    LLM_TENSOR_ALTUP_ROUTER_NORM, // gemma3n
+    LLM_TENSOR_LAUREL_L, // gemma3n
+    LLM_TENSOR_LAUREL_R, // gemma3n
+    LLM_TENSOR_LAUREL_POST_NORM, // gemma3n
     LLM_TENSOR_SSM_IN,
     LLM_TENSOR_SSM_CONV1D,
     LLM_TENSOR_SSM_X,
examples/talk-llama/llama-batch.cpp CHANGED
@@ -244,22 +244,35 @@ bool llama_batch_allocr::init(
             continue;
         }
 
-        if (memory) {
+        const llama_pos p0 = memory ? memory->seq_pos_max(s) : -1;
+
+        if (p0 >= 0) {
+            bool ok = true;
+
             if (batch.token) {
-                if (seq_pos_min(s) != memory->seq_pos_max(s) + 1) {
-                    LLAMA_LOG_ERROR("%s: sequence %d does not start from the last position stored in the memory\n", __func__, s);
-                    return false;
+                if (seq_pos_min(s) != p0 + 1) {
+                    ok = false;
                 }
             } else {
                 assert(batch.embd);
 
                 // for embeddings (typically used as vision input), we allow them to have repeating positions
                 // ref: https://github.com/ggml-org/llama.cpp/issues/13694#issuecomment-2983871762
-                if (seq_pos_min(s) != memory->seq_pos_max(s) && seq_pos_min(s) != memory->seq_pos_max(s) + 1) {
-                    LLAMA_LOG_ERROR("%s: sequence %d does not start from the last position stored in the memory\n", __func__, s);
-                    return false;
+                if (seq_pos_min(s) != p0 && seq_pos_min(s) != p0 + 1) {
+                    ok = false;
                 }
             }
+
+            if (!ok) {
+                LLAMA_LOG_ERROR(
+                        "%s: the tokens of sequence %d in the input batch have inconsistent sequence positions:\n"
+                        " - the last position stored in the memory module of the context (i.e. the KV cache) for sequence %d is X = %d\n"
+                        " - the tokens for sequence %d in the input batch have a starting position of Y = %d\n"
+                        " it is required that the sequence positions remain consecutive: Y = X + 1\n",
+                        __func__, s, s, p0, s, seq_pos_min(s));
+
+                return false;
+            }
         }
 
         if (seq_pos_max(s) - seq_pos_min(s) + 1 > (int) seq_pos[s].size()) {
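The rewritten check boils down to the rule the new error message spells out: for token batches, the first position of a sequence in the batch (Y) must directly continue the last position already stored in the KV cache (X), i.e. Y = X + 1. A small standalone sketch of just that rule, with plain ints instead of llama.cpp's types:

// Minimal illustration of the Y = X + 1 rule from the error message above.
// p0 is the last position already stored for the sequence (-1 if none),
// batch_start is the first position of that sequence in the incoming batch.
// This is a standalone sketch, not code from the diff.
#include <cstdio>

static bool positions_consecutive(int p0, int batch_start) {
    if (p0 < 0) {
        return true; // nothing stored yet - any starting position is accepted
    }
    return batch_start == p0 + 1; // Y = X + 1
}

int main() {
    std::printf("%d\n", positions_consecutive(41, 42)); // 1: continues the cache
    std::printf("%d\n", positions_consecutive(41, 45)); // 0: gap -> rejected
}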
examples/talk-llama/llama-chat.cpp CHANGED
@@ -528,12 +528,17 @@ int32_t llm_chat_apply_template(
         }
     } else if (tmpl == LLM_CHAT_TEMPLATE_RWKV_WORLD) {
         // this template requires the model to have "\n\n" as EOT token
-        for (auto message : chat) {
-            std::string role(message->role);
-            if (role == "user") {
-                ss << "User: " << message->content << "\n\nAssistant:";
-            } else {
-                ss << message->content << "\n\n";
+        for (size_t i = 0; i < chat.size(); i++) {
+            std::string role(chat[i]->role);
+            if (role == "system") {
+                ss << "System: " << trim(chat[i]->content) << "\n\n";
+            } else if (role == "user") {
+                ss << "User: " << trim(chat[i]->content) << "\n\n";
+                if (i == chat.size() - 1) {
+                    ss << "Assistant:";
+                }
+            } else if (role == "assistant") {
+                ss << "Assistant: " << trim(chat[i]->content) << "\n\n";
             }
         }
     } else if (tmpl == LLM_CHAT_TEMPLATE_GRANITE) {
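For reference, a standalone sketch of what the updated RWKV-World branch renders: a simplified message struct, no trim(), and the example messages are invented here; only the formatting rules are taken from the diff above.

// Standalone sketch of the updated RWKV-World rendering (not the real
// llm_chat_apply_template API).
#include <cstdio>
#include <string>
#include <vector>

struct msg { std::string role, content; };

static std::string render_rwkv_world(const std::vector<msg> & chat) {
    std::string ss;
    for (size_t i = 0; i < chat.size(); i++) {
        const auto & m = chat[i];
        if (m.role == "system") {
            ss += "System: " + m.content + "\n\n";
        } else if (m.role == "user") {
            ss += "User: " + m.content + "\n\n";
            if (i == chat.size() - 1) {
                ss += "Assistant:"; // prompt the model only after the final user turn
            }
        } else if (m.role == "assistant") {
            ss += "Assistant: " + m.content + "\n\n";
        }
    }
    return ss;
}

int main() {
    std::string out = render_rwkv_world({{"system", "You are helpful"}, {"user", "Hello"}});
    std::printf("%s\n", out.c_str()); // System: You are helpful\n\nUser: Hello\n\nAssistant:
}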
examples/talk-llama/llama-context.cpp CHANGED
@@ -280,8 +280,8 @@ llama_context::llama_context(
 
     // simulate full KV cache
 
-    const auto mstate = memory->init_full();
-    if (!mstate) {
+    const auto mctx = memory->init_full();
+    if (!mctx) {
         throw std::runtime_error("failed to initialize KV cache");
     }
 
@@ -289,7 +289,7 @@ llama_context::llama_context(
 
     // reserve pp graph first so that buffers are only allocated once
     {
-        auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mstate.get());
+        auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get());
         if (!gf) {
             throw std::runtime_error("failed to allocate compute pp buffers");
         }
@@ -300,7 +300,7 @@ llama_context::llama_context(
 
     // reserve with tg graph to get the number of splits and nodes
     {
-        auto * gf = graph_reserve(1, 1, 1, mstate.get());
+        auto * gf = graph_reserve(1, 1, 1, mctx.get());
         if (!gf) {
             throw std::runtime_error("failed to allocate compute tg buffers");
         }
@@ -311,7 +311,7 @@ llama_context::llama_context(
 
     // reserve again with pp graph to avoid ggml-alloc reallocations during inference
     {
-        auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mstate.get());
+        auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get());
         if (!gf) {
             throw std::runtime_error("failed to allocate compute pp buffers");
         }
@@ -444,8 +444,8 @@ bool llama_context::kv_self_update(bool optimize) {
         optimize |= memory_force_optimize;
         memory_force_optimize = false;
 
-        const auto mstate = memory->init_update(this, optimize);
-        switch (mstate->get_status()) {
+        const auto mctx = memory->init_update(this, optimize);
+        switch (mctx->get_status()) {
             case LLAMA_MEMORY_STATUS_SUCCESS:
                 {
                     // noop
@@ -463,22 +463,22 @@ bool llama_context::kv_self_update(bool optimize) {
                 }
         }
 
-        if (!mstate->apply()) {
+        if (!mctx->apply()) {
             LLAMA_LOG_ERROR("%s: failed to apply memory update\n", __func__);
         }
     }
 
     // if the memory module did any computation, we have to reserve a new worst-case graph
     {
-        const auto mstate = memory->init_full();
-        if (!mstate) {
-            throw std::runtime_error("failed to initialize memory state");
+        const auto mctx = memory->init_full();
+        if (!mctx) {
+            throw std::runtime_error("failed to initialize memory context");
         }
 
         const uint32_t n_seqs = cparams.n_seq_max;
         const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
 
-        auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mstate.get());
+        auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get());
         if (!gf) {
             LLAMA_LOG_ERROR("%s: failed to reserve graph after the memory update\n", __func__);
         }
@@ -678,9 +678,9 @@ bool llama_context::apply_adapter_cvec(
     return cvec.apply(model, data, len, n_embd, il_start, il_end);
 }
 
-llm_graph_result_ptr llama_context::process_ubatch(const llama_ubatch & ubatch, llm_graph_type gtype, llama_memory_state_i * mstate, ggml_status & ret) {
-    if (mstate && !mstate->apply()) {
-        LLAMA_LOG_ERROR("%s: failed to apply memory state\n", __func__);
+llm_graph_result_ptr llama_context::process_ubatch(const llama_ubatch & ubatch, llm_graph_type gtype, llama_memory_context_i * mctx, ggml_status & ret) {
+    if (mctx && !mctx->apply()) {
+        LLAMA_LOG_ERROR("%s: failed to apply memory context\n", __func__);
         ret = GGML_STATUS_FAILED;
         return nullptr;
     }
@@ -692,7 +692,7 @@ llm_graph_result_ptr llama_context::process_ubatch(const llama_ubatch & ubatch,
         return nullptr;
     }
 
-    auto res = graph_build(ctx_compute.get(), gf, ubatch, gtype, mstate);
+    auto res = graph_build(ctx_compute.get(), gf, ubatch, gtype, mctx);
     if (!res) {
         LLAMA_LOG_ERROR("%s: failed to build graph\n", __func__);
         ret = GGML_STATUS_FAILED;
@@ -933,21 +933,21 @@ int llama_context::decode(const llama_batch & batch_inp) {
     // handle any pending defrags/shifts
     kv_self_update(false);
 
-    llama_memory_state_ptr mstate;
+    llama_memory_context_ptr mctx;
 
     while (true) {
-        mstate = memory->init_batch(*balloc, cparams.n_ubatch, output_all);
-        if (!mstate) {
+        mctx = memory->init_batch(*balloc, cparams.n_ubatch, output_all);
+        if (!mctx) {
             return -2;
         }
 
-        switch (mstate->get_status()) {
+        switch (mctx->get_status()) {
             case LLAMA_MEMORY_STATUS_SUCCESS:
                 {
                 } break;
             case LLAMA_MEMORY_STATUS_NO_UPDATE:
                 {
-                    LLAMA_LOG_ERROR("%s: unexpected memory state status: %d\n", __func__, mstate->get_status());
+                    LLAMA_LOG_ERROR("%s: unexpected memory context status: %d\n", __func__, mctx->get_status());
 
                     return -2;
                 }
@@ -987,7 +987,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
     int64_t n_outputs_prev = 0;
 
     do {
-        const auto & ubatch = mstate->get_ubatch();
+        const auto & ubatch = mctx->get_ubatch();
 
         // count the outputs in this ubatch
         {
@@ -1009,7 +1009,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
         ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);
 
         ggml_status status;
-        const auto res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, mstate.get(), status);
+        const auto res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, mctx.get(), status);
 
         if (!res) {
             // the last ubatch failed or was aborted -> remove all positions of that ubatch from the KV cache
@@ -1018,7 +1018,6 @@ int llama_context::decode(const llama_batch & batch_inp) {
                 pos_min[s] = std::numeric_limits<llama_pos>::max();
             }
 
-            // TODO: fix sequence indexing
             for (uint32_t i = 0; i < ubatch.n_tokens; ++i) {
                 const auto & seq_id = ubatch.seq_id[i][0];
 
@@ -1126,7 +1125,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
         }
 
         n_outputs_prev += n_outputs;
-    } while (mstate->next());
+    } while (mctx->next());
 
     // set to total number of outputs in the batch, for use in llama_get_logits_ith
     n_outputs = n_outputs_all;
@@ -1292,7 +1291,7 @@ ggml_cgraph * llama_context::graph_init() {
     return ggml_new_graph_custom(ctx_compute.get(), graph_max_nodes(), false);
 }
 
-ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_state_i * mstate) {
+ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx) {
     LLAMA_LOG_DEBUG("%s: reserving a graph for ubatch with n_tokens = %4u, n_seqs = %2u, n_outputs = %4u\n", __func__, n_tokens, n_seqs, n_outputs);
 
     if (n_tokens % n_seqs != 0) {
@@ -1312,7 +1311,7 @@ ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, u
     llama_ubatch ubatch = balloc.ubatch_reserve(n_tokens/n_seqs, n_seqs);
 
     auto * gf = graph_init();
-    auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT, mstate);
+    auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT, mctx);
 
     this->n_outputs = save_n_outputs;
 
@@ -1333,11 +1332,11 @@ ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, u
 }
 
 llm_graph_result_ptr llama_context::graph_build(
-        ggml_context * ctx,
-        ggml_cgraph * gf,
-        const llama_ubatch & ubatch,
-        llm_graph_type gtype,
-        const llama_memory_state_i * mstate) {
+        ggml_context * ctx,
+        ggml_cgraph * gf,
+        const llama_ubatch & ubatch,
+        llm_graph_type gtype,
+        const llama_memory_context_i * mctx) {
     return model.build_graph(
         {
             /*.ctx =*/ ctx,
@@ -1349,7 +1348,7 @@ llm_graph_result_ptr llama_context::graph_build(
             /*.backend_cpu =*/ backend_cpu,
             /*.cvec =*/ &cvec,
             /*.loras =*/ &loras,
-            /*.mstate =*/ mstate,
+            /*.mctx =*/ mctx,
             /*.cross =*/ &cross,
             /*.n_outputs =*/ n_outputs,
             /*.cb =*/ graph_get_cb(),
@@ -2042,8 +2041,8 @@ void llama_context::opt_epoch_iter(
 
         uint32_t n_outputs_all = n_tokens_all;
 
-        auto mstate = memory->init_batch(*balloc, cparams.n_ubatch, true);
-        if (!mstate || mstate->get_status() != LLAMA_MEMORY_STATUS_SUCCESS) {
+        auto mctx = memory->init_batch(*balloc, cparams.n_ubatch, true);
+        if (!mctx || mctx->get_status() != LLAMA_MEMORY_STATUS_SUCCESS) {
            LLAMA_LOG_ERROR("%s: could not initialize batch\n", __func__);
            break;
        }
@@ -2056,17 +2055,17 @@ void llama_context::opt_epoch_iter(
 
        uint32_t pos_batch = 0;
        do {
-           const auto & ubatch = mstate->get_ubatch();
+           const auto & ubatch = mctx->get_ubatch();
 
           n_outputs = ubatch.n_tokens;
 
-           if (!mstate->apply()) {
-               LLAMA_LOG_ERROR("%s: failed to update the memory state\n", __func__);
+           if (!mctx->apply()) {
+               LLAMA_LOG_ERROR("%s: failed to update the memory context\n", __func__);
               break;
           }
 
          auto * gf = graph_init();
-           auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT, mstate.get());
+           auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT, mctx.get());
 
          struct ggml_context * ctx_compute_opt;
          {
@@ -2101,7 +2100,7 @@ void llama_context::opt_epoch_iter(
            ggml_free(ctx_compute_opt);
 
            pos_batch += ubatch.n_tokens;
-       } while (mstate->next());
+       } while (mctx->next());
    }
 }
 
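The changes in this file are mostly the mechanical rename from llama_memory_state_* to llama_memory_context_*; the consumption pattern is unchanged: init_batch() prepares a context object that is applied and iterated ubatch-by-ubatch. Below is a toy, self-contained mock of that iteration pattern; everything in it is invented for illustration except the method names get_ubatch()/apply()/next(), which mirror the usage visible in the diff.

// Toy mock of the ubatch iteration pattern used by llama_context::decode above.
#include <cstdio>
#include <vector>

struct mock_memory_context {
    std::vector<int> ubatches; // stand-in for the prepared micro-batches
    size_t i;
    bool apply()            { return true; }        // e.g. KV-cache bookkeeping
    int  get_ubatch() const { return ubatches[i]; } // current micro-batch
    bool next()             { return ++i < ubatches.size(); }
};

int main() {
    mock_memory_context mctx{{3, 3, 2}, 0}; // three ubatches of 3, 3 and 2 tokens
    do {
        if (!mctx.apply()) break;
        std::printf("processing ubatch with %d tokens\n", mctx.get_ubatch());
    } while (mctx.next());
}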
examples/talk-llama/llama-context.h CHANGED
@@ -18,7 +18,7 @@ class llama_io_read_i;
 class llama_io_write_i;
 
 struct llama_memory_i;
-struct llama_memory_state_i;
+struct llama_memory_context_i;
 
 struct llama_context {
     // init scheduler and compute buffers, reserve worst-case graphs
@@ -93,14 +93,14 @@ struct llama_context {
             int32_t il_end);
 
     // process a single ubatch with a specific graph type
-    // if memory_state is provided, it will be applied first to the context's memory
+    // if memory_context is provided, it will be applied first to the context's memory
     // ret contains the status of the graph computation
     // returns nullptr only if ret != GGML_STATUS_SUCCESS
     llm_graph_result_ptr process_ubatch(
-            const llama_ubatch & ubatch,
-            llm_graph_type gtype,
-            llama_memory_state_i * mstate,
-            ggml_status & ret);
+            const llama_ubatch & ubatch,
+            llm_graph_type gtype,
+            llama_memory_context_i * mctx,
+            ggml_status & ret);
 
     int encode(const llama_batch & batch_inp);
     int decode(const llama_batch & batch_inp);
@@ -197,15 +197,15 @@ public:
     ggml_status graph_compute(ggml_cgraph * gf, bool batched);
 
     // reserve a graph with a dummy ubatch of the specified size
-    ggml_cgraph * graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_state_i * mstate);
+    ggml_cgraph * graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx);
 
 private:
     llm_graph_result_ptr graph_build(
-            ggml_context * ctx,
-            ggml_cgraph * gf,
-            const llama_ubatch & ubatch,
-            llm_graph_type gtype,
-            const llama_memory_state_i * mstate);
+            ggml_context * ctx,
+            ggml_cgraph * gf,
+            const llama_ubatch & ubatch,
+            llm_graph_type gtype,
+            const llama_memory_context_i * mctx);
 
     llm_graph_cb graph_get_cb() const;
 
examples/talk-llama/llama-graph.cpp CHANGED
@@ -87,7 +87,7 @@ void llm_graph_input_pos_bucket::set_input(const llama_ubatch * ubatch) {
 
 void llm_graph_input_pos_bucket_kv::set_input(const llama_ubatch * ubatch) {
     if (pos_bucket) {
-        kv_state->set_input_pos_bucket(pos_bucket, ubatch);
+        mctx->set_input_pos_bucket(pos_bucket, ubatch);
     }
 }
 
@@ -221,7 +221,7 @@ void llm_graph_input_cls::set_input(const llama_ubatch * ubatch) {
 void llm_graph_input_rs::set_input(const llama_ubatch * ubatch) {
     GGML_UNUSED(ubatch);
 
-    const int64_t n_rs = mem_state->get_n_rs();
+    const int64_t n_rs = mctx->get_n_rs();
 
     if (s_copy) {
         GGML_ASSERT(ggml_backend_buffer_is_host(s_copy->buffer));
@@ -229,7 +229,7 @@ void llm_graph_input_rs::set_input(const llama_ubatch * ubatch) {
 
         // assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
         for (uint32_t i = 0; i < n_rs; ++i) {
-            data[i] = mem_state->s_copy(i);
+            data[i] = mctx->s_copy(i);
         }
     }
 }
@@ -282,17 +282,17 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
 
 void llm_graph_input_attn_kv_unified::set_input(const llama_ubatch * ubatch) {
     if (self_kq_mask) {
-        kv_state->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
+        mctx->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
     }
 }
 
 void llm_graph_input_attn_kv_unified_iswa::set_input(const llama_ubatch * ubatch) {
     if (self_kq_mask) {
-        kv_state->get_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
+        mctx->get_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
     }
 
     if (self_kq_mask_swa) {
-        kv_state->get_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn);
+        mctx->get_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn);
     }
 }
 
@@ -334,10 +334,10 @@ void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
 
 void llm_graph_input_mem_hybrid::set_input(const llama_ubatch * ubatch) {
     if (self_kq_mask) {
-        mem_state->get_state_attn()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
+        mctx->get_attn()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
     }
 
-    const int64_t n_rs = mem_state->get_state_recr()->get_n_rs();
+    const int64_t n_rs = mctx->get_recr()->get_n_rs();
 
     if (s_copy) {
         GGML_ASSERT(ggml_backend_buffer_is_host(s_copy->buffer));
@@ -345,11 +345,17 @@ void llm_graph_input_mem_hybrid::set_input(const llama_ubatch * ubatch) {
 
         // assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
         for (uint32_t i = 0; i < n_rs; ++i) {
-            data[i] = mem_state->get_state_recr()->s_copy(i);
+            data[i] = mctx->get_recr()->s_copy(i);
         }
     }
 }
 
+void llm_graph_input_one::set_input(const llama_ubatch *) {
+    GGML_ASSERT(one && ggml_nelements(one) == 1);
+    float f_one = 1.0f;
+    ggml_backend_tensor_set(one, &f_one, 0, sizeof(float));
+}
+
 //
 // llm_graph_context
 //
@@ -389,7 +395,7 @@ llm_graph_context::llm_graph_context(const llm_graph_params & params) :
     backend_cpu (params.backend_cpu),
     cvec (params.cvec),
     loras (params.loras),
-    mstate (params.mstate),
+    mctx (params.mctx),
     cross (params.cross),
     cb_func (params.cb),
     res (std::make_unique<llm_graph_result>()) {
@@ -554,12 +560,20 @@ ggml_tensor * llm_graph_context::build_ffn(
 
     switch (type_op) {
         case LLM_FFN_SILU:
-            {
+            if (gate && type_gate == LLM_FFN_PAR) {
+                cur = ggml_swiglu_split(ctx0, cur, tmp);
+                cb(cur, "ffn_swiglu", il);
+                type_gate = LLM_FFN_SEQ;
+            } else {
                cur = ggml_silu(ctx0, cur);
                cb(cur, "ffn_silu", il);
            } break;
        case LLM_FFN_GELU:
-            {
+            if (gate && type_gate == LLM_FFN_PAR) {
+                cur = ggml_geglu_split(ctx0, cur, tmp);
+                cb(cur, "ffn_geglu", il);
+                type_gate = LLM_FFN_SEQ;
+            } else {
                cur = ggml_gelu(ctx0, cur);
                cb(cur, "ffn_gelu", il);
                if (act_scales != NULL) {
@@ -568,7 +582,11 @@ ggml_tensor * llm_graph_context::build_ffn(
                }
            } break;
        case LLM_FFN_RELU:
-            {
+            if (gate && type_gate == LLM_FFN_PAR) {
+                cur = ggml_reglu_split(ctx0, cur, tmp);
+                cb(cur, "ffn_reglu", il);
+                type_gate = LLM_FFN_SEQ;
+            } else {
                cur = ggml_relu(ctx0, cur);
                cb(cur, "ffn_relu", il);
            } break;
@@ -582,32 +600,19 @@ ggml_tensor * llm_graph_context::build_ffn(
            } break;
        case LLM_FFN_SWIGLU:
            {
-                // Project to 4h. If using swiglu double the output width, see https://arxiv.org/pdf/2002.05202.pdf
-                int64_t split_point = cur->ne[0] / 2;
-                // TODO: these conts should not be needed, see https://github.com/ggml-org/llama.cpp/pull/14090#discussion_r2137437217
-                ggml_tensor * x0 = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1], 0));
-                ggml_tensor * x1 = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1], split_point * ggml_element_size(cur)));
-
-                x0 = ggml_silu(ctx0, x0);
-                cb(cur, "ffn_silu", il);
-
-                cur = ggml_mul(ctx0, x0, x1);
-                cb(cur, "ffn_mul", il);
+                cur = ggml_swiglu(ctx0, cur);
+                cb(cur, "ffn_swiglu", il);
            } break;
        case LLM_FFN_GEGLU:
            {
-                // Split into two equal parts
-                int64_t split_point = cur->ne[0] / 2;
-                // TODO: these conts should not be needed, see https://github.com/ggml-org/llama.cpp/pull/14090#discussion_r2137437217
-                ggml_tensor * x0 = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1], 0));
-                ggml_tensor * x1 = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1], split_point * ggml_element_size(cur)));
-
-                x0 = ggml_gelu(ctx0, x0);
-                cb(x0, "ffn_gelu", il);
-
-                cur = ggml_mul(ctx0, x0, x1);
+                cur = ggml_geglu(ctx0, cur);
                cb(cur, "ffn_geglu", il);
            } break;
+        case LLM_FFN_REGLU:
+            {
+                cur = ggml_reglu(ctx0, cur);
+                cb(cur, "ffn_reglu", il);
+            } break;
    }
 
    if (gate && type_gate == LLM_FFN_PAR) {
@@ -737,12 +742,18 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
 
     switch (type_op) {
         case LLM_FFN_SILU:
-            {
+            if (gate_exps) {
+                cur = ggml_swiglu_split(ctx0, cur, up);
+                cb(cur, "ffn_moe_swiglu", il);
+            } else {
                cur = ggml_silu(ctx0, cur);
                cb(cur, "ffn_moe_silu", il);
            } break;
        case LLM_FFN_GELU:
-            {
+            if (gate_exps) {
+                cur = ggml_geglu_split(ctx0, cur, up);
+                cb(cur, "ffn_moe_geglu", il);
+            } else {
                cur = ggml_gelu(ctx0, cur);
                cb(cur, "ffn_moe_gelu", il);
            } break;
@@ -750,11 +761,6 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
             GGML_ABORT("fatal error");
     }
 
-    if (gate_exps) {
-        cur = ggml_mul(ctx0, cur, up); // [n_ff, n_expert_used, n_tokens]
-        cb(cur, "ffn_moe_gate_par", il);
-    }
-
     experts = build_lora_mm_id(down_exps, cur, selected_experts); // [n_embd, n_expert_used, n_tokens]
     cb(experts, "ffn_moe_down", il);
 
@@ -950,11 +956,11 @@ ggml_tensor * llm_graph_context::build_inp_pos_bucket_enc() const {
 }
 
 ggml_tensor * llm_graph_context::build_inp_pos_bucket_dec() const {
-    const auto * kv_state = static_cast<const llama_kv_cache_unified_state *>(mstate);
+    const auto * mctx_cur = static_cast<const llama_kv_cache_unified_context *>(mctx);
 
-    auto inp = std::make_unique<llm_graph_input_pos_bucket_kv>(hparams, kv_state);
+    auto inp = std::make_unique<llm_graph_input_pos_bucket_kv>(hparams, mctx_cur);
 
-    const auto n_kv = kv_state->get_n_kv();
+    const auto n_kv = mctx_cur->get_n_kv();
 
     auto & cur = inp->pos_bucket;
 
@@ -982,14 +988,14 @@ ggml_tensor * llm_graph_context::build_pos_bias(ggml_tensor * pos_bucket, ggml_t
 }
 
 llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const {
-    const auto * mem_state = static_cast<const llama_memory_hybrid_state *>(mstate);
+    const auto * mctx_cur = static_cast<const llama_memory_hybrid_context *>(mctx);
 
-    auto inp = std::make_unique<llm_graph_input_mem_hybrid>(hparams, cparams, mem_state);
+    auto inp = std::make_unique<llm_graph_input_mem_hybrid>(hparams, cparams, mctx_cur);
 
     {
         GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Hybrid recurrent is not supported with SWA attention layers");
 
-        const auto n_kv = inp->mem_state->get_state_attn()->get_n_kv();
+        const auto n_kv = inp->mctx->get_attn()->get_n_kv();
 
         inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
         //cb(inp->self_kq_mask, "KQ_mask", -1);
@@ -999,7 +1005,7 @@ llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const {
     }
 
     {
-        const auto n_rs = mem_state->get_state_recr()->get_n_rs();
+        const auto n_rs = mctx_cur->get_recr()->get_n_rs();
 
         inp->s_copy = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_rs);
         ggml_set_input(inp->s_copy);
@@ -1183,14 +1189,14 @@ ggml_tensor * llm_graph_context::build_attn(
 }
 
 llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified() const {
-    const auto * kv_state = static_cast<const llama_kv_cache_unified_state *>(mstate);
+    const auto * mctx_cur = static_cast<const llama_kv_cache_unified_context *>(mctx);
 
-    auto inp = std::make_unique<llm_graph_input_attn_kv_unified>(hparams, cparams, kv_state);
+    auto inp = std::make_unique<llm_graph_input_attn_kv_unified>(hparams, cparams, mctx_cur);
 
     {
         GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified_iswa for SWA");
 
-        const auto n_kv = kv_state->get_n_kv();
+        const auto n_kv = mctx_cur->get_n_kv();
 
         inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
         //cb(inp->self_kq_mask, "KQ_mask", -1);
@@ -1220,19 +1226,19 @@ ggml_tensor * llm_graph_context::build_attn(
     ggml_build_forward_expand(gf, k_cur);
     ggml_build_forward_expand(gf, v_cur);
 
-    const auto * kv_state = static_cast<const llama_kv_cache_unified_state *>(mstate);
+    const auto * mctx_cur = static_cast<const llama_kv_cache_unified_context *>(mctx);
 
     // store to KV cache
     {
-        ggml_build_forward_expand(gf, kv_state->cpy_k(ctx0, k_cur, il));
-        ggml_build_forward_expand(gf, kv_state->cpy_v(ctx0, v_cur, il));
+        ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, il));
+        ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, il));
     }
 
     const auto & kq_mask = inp->get_kq_mask();
 
     ggml_tensor * q = q_cur;
-    ggml_tensor * k = kv_state->get_k(ctx0, il);
-    ggml_tensor * v = kv_state->get_v(ctx0, il);
+    ggml_tensor * k = mctx_cur->get_k(ctx0, il);
+    ggml_tensor * v = mctx_cur->get_v(ctx0, il);
 
     ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
     cb(cur, "kqv_out", il);
@@ -1267,26 +1273,35 @@ ggml_tensor * llm_graph_context::build_attn(
     // these nodes are added to the graph together so that they are not reordered
     // by doing so, the number of splits in the graph is reduced
     ggml_build_forward_expand(gf, q_cur);
-    ggml_build_forward_expand(gf, k_cur);
-    ggml_build_forward_expand(gf, v_cur);
 
-    const auto * kv_state_iswa = static_cast<const llama_kv_cache_unified_iswa_state *>(mstate);
+    if (k_cur) {
+        ggml_build_forward_expand(gf, k_cur);
+    }
+
+    if (v_cur) {
+        ggml_build_forward_expand(gf, v_cur);
+    }
+
+    const auto * mctx_iswa = static_cast<const llama_kv_cache_unified_iswa_context *>(mctx);
 
     const bool is_swa = hparams.is_swa(il);
 
-    const auto * kv_state = is_swa ? kv_state_iswa->get_swa() : kv_state_iswa->get_base();
+    const auto * mctx_cur = is_swa ? mctx_iswa->get_swa() : mctx_iswa->get_base();
 
-    // store to KV cache
-    {
-        ggml_build_forward_expand(gf, kv_state->cpy_k(ctx0, k_cur, il));
-        ggml_build_forward_expand(gf, kv_state->cpy_v(ctx0, v_cur, il));
+    // optionally store to KV cache
+    if (k_cur) {
+        ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, il));
+    }
+
+    if (v_cur) {
+        ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, il));
    }
 
    const auto & kq_mask = is_swa ? inp->get_kq_mask_swa() : inp->get_kq_mask();
 
    ggml_tensor * q = q_cur;
-    ggml_tensor * k = kv_state->get_k(ctx0, il);
-    ggml_tensor * v = kv_state->get_v(ctx0, il);
+    ggml_tensor * k = mctx_cur->get_k(ctx0, il);
+    ggml_tensor * v = mctx_cur->get_v(ctx0, il);
 
    ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
    cb(cur, "kqv_out", il);
@@ -1379,19 +1394,19 @@ ggml_tensor * llm_graph_context::build_attn(
     ggml_build_forward_expand(gf, k_cur);
     ggml_build_forward_expand(gf, v_cur);
 
-    const auto * kv_state = static_cast<const llama_memory_hybrid_state *>(mstate)->get_state_attn();
+    const auto * mctx_cur = static_cast<const llama_memory_hybrid_context *>(mctx)->get_attn();
 
     // store to KV cache
     {
-        ggml_build_forward_expand(gf, kv_state->cpy_k(ctx0, k_cur, il));
-        ggml_build_forward_expand(gf, kv_state->cpy_v(ctx0, v_cur, il));
+        ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, il));
+        ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, il));
     }
 
     const auto & kq_mask = inp->get_kq_mask();
 
     ggml_tensor * q = q_cur;
-    ggml_tensor * k = kv_state->get_k(ctx0, il);
-    ggml_tensor * v = kv_state->get_v(ctx0, il);
+    ggml_tensor * k = mctx_cur->get_k(ctx0, il);
+    ggml_tensor * v = mctx_cur->get_v(ctx0, il);
 
     ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
     cb(cur, "kqv_out", il);
@@ -1412,12 +1427,12 @@ ggml_tensor * llm_graph_context::build_attn(
 }
 
 llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unified_iswa() const {
-    const auto * kv_state = static_cast<const llama_kv_cache_unified_iswa_state *>(mstate);
+    const auto * mctx_cur = static_cast<const llama_kv_cache_unified_iswa_context *>(mctx);
 
-    auto inp = std::make_unique<llm_graph_input_attn_kv_unified_iswa>(hparams, cparams, kv_state);
+    auto inp = std::make_unique<llm_graph_input_attn_kv_unified_iswa>(hparams, cparams, mctx_cur);
 
     {
-        const auto n_kv = kv_state->get_base()->get_n_kv();
+        const auto n_kv = mctx_cur->get_base()->get_n_kv();
 
         inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
         //cb(inp->self_kq_mask, "KQ_mask", -1);
@@ -1429,7 +1444,7 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif
     {
         GGML_ASSERT(hparams.swa_type != LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified for non-SWA");
 
-        const auto n_kv = kv_state->get_swa()->get_n_kv();
+        const auto n_kv = mctx_cur->get_swa()->get_n_kv();
 
         inp->self_kq_mask_swa = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
         //cb(inp->self_kq_mask_swa, "KQ_mask_swa", -1);
@@ -1485,11 +1500,11 @@ ggml_tensor * llm_graph_context::build_rs(
 }
 
 llm_graph_input_rs * llm_graph_context::build_rs_inp() const {
-    const auto * kv_state = static_cast<const llama_memory_recurrent_state *>(mstate);
+    const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);
 
-    auto inp = std::make_unique<llm_graph_input_rs>(kv_state);
+    auto inp = std::make_unique<llm_graph_input_rs>(mctx_cur);
 
-    const auto n_rs = kv_state->get_n_rs();
+    const auto n_rs = mctx_cur->get_n_rs();
 
     inp->s_copy = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_rs);
     ggml_set_input(inp->s_copy);
@@ -1504,9 +1519,9 @@ ggml_tensor * llm_graph_context::build_rs(
         int32_t state_size,
         int32_t n_seqs,
         bool avoid_copies) const {
-    const auto * kv_state = static_cast<const llama_memory_recurrent_state *>(mstate);
+    const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);
 
-    return build_rs(gf, s, inp->s_copy, state_size, n_seqs, kv_state->get_n_rs(), kv_state->get_head(), kv_state->get_size(), kv_state->get_rs_z(), avoid_copies);
+    return build_rs(gf, s, inp->s_copy, state_size, n_seqs, mctx_cur->get_n_rs(), mctx_cur->get_head(), mctx_cur->get_size(), mctx_cur->get_rs_z(), avoid_copies);
 }
 
 ggml_tensor * llm_graph_context::build_rs(
@@ -1516,9 +1531,9 @@ ggml_tensor * llm_graph_context::build_rs(
         int32_t state_size,
         int32_t n_seqs,
         bool avoid_copies) const {
-    const auto * kv_state = static_cast<const llama_memory_hybrid_state *>(mstate)->get_state_recr();
+    const auto * mctx_cur = static_cast<const llama_memory_hybrid_context *>(mctx)->get_recr();
 
-    return build_rs(gf, s, inp->s_copy, state_size, n_seqs, kv_state->get_n_rs(), kv_state->get_head(), kv_state->get_size(), kv_state->get_rs_z(), avoid_copies);
+    return build_rs(gf, s, inp->s_copy, state_size, n_seqs, mctx_cur->get_n_rs(), mctx_cur->get_head(), mctx_cur->get_size(), mctx_cur->get_rs_z(), avoid_copies);
 }
 
 ggml_tensor * llm_graph_context::build_rwkv_token_shift_load(
@@ -1526,13 +1541,13 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_load(
         ggml_cgraph * gf,
         const llama_ubatch & ubatch,
         int il) const {
-    const auto * kv_state = static_cast<const llama_memory_recurrent_state *>(mstate);
+    const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);
 
     const auto token_shift_count = hparams.token_shift_count;
 
     const int64_t n_seqs = ubatch.n_seqs;
 
-    ggml_tensor * token_shift_all = kv_state->get_r_l(il);
+    ggml_tensor * token_shift_all = mctx_cur->get_r_l(il);
 
     ggml_tensor * token_shift = build_rs(
             inp, gf, token_shift_all,
@@ -1547,19 +1562,19 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_store(
         ggml_tensor * token_shift,
         const llama_ubatch & ubatch,
         int il) const {
-    const auto * kv_state = static_cast<const llama_memory_recurrent_state *>(mstate);
+    const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);
 
     const auto token_shift_count = hparams.token_shift_count;
     const auto n_embd = hparams.n_embd;
 
     const int64_t n_seqs = ubatch.n_seqs;
 
-    const auto kv_head = kv_state->get_head();
+    const auto kv_head = mctx_cur->get_head();
 
     return ggml_cpy(
         ctx0,
         ggml_view_1d(ctx0, token_shift, n_embd * n_seqs * token_shift_count, 0),
-        ggml_view_1d(ctx0, kv_state->get_r_l(il), hparams.n_embd_r()*n_seqs, hparams.n_embd_r()*kv_head*ggml_element_size(kv_state->get_r_l(il)))
+        ggml_view_1d(ctx0, mctx_cur->get_r_l(il), hparams.n_embd_r()*n_seqs, hparams.n_embd_r()*kv_head*ggml_element_size(mctx_cur->get_r_l(il)))
     );
 }
 
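The build_ffn()/build_moe_ffn() changes above replace the hand-rolled gated activations (split the up-projection in half, apply SiLU/GELU to one half, multiply by the other) with the fused ggml_swiglu/ggml_geglu/ggml_reglu ops and their *_split variants. A plain-float sketch of the math these ops implement; this is not the ggml API, just the formula, and the helper names here are made up.

// SwiGLU on plain floats: out = silu(gate) * up.
#include <cmath>
#include <cstdio>
#include <vector>

static float silu(float x) { return x / (1.0f + std::exp(-x)); }

// "split" variant: gate and up come from two separate projections
static std::vector<float> swiglu_split(const std::vector<float> & gate,
                                       const std::vector<float> & up) {
    std::vector<float> out(gate.size());
    for (size_t i = 0; i < gate.size(); ++i) {
        out[i] = silu(gate[i]) * up[i];
    }
    return out;
}

int main() {
    auto y = swiglu_split({1.0f, -2.0f}, {0.5f, 3.0f});
    std::printf("%f %f\n", y[0], y[1]);
}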
examples/talk-llama/llama-graph.h CHANGED
@@ -17,12 +17,12 @@ struct ggml_tensor;
 struct llama_ubatch;
 struct llama_cparams;
 
-struct llama_memory_state_i;
 
-class llama_kv_cache_unified_state;
-class llama_kv_cache_unified_iswa_state;
-class llama_memory_recurrent_state;
-class llama_memory_hybrid_state;
 
 // certain models (typically multi-modal) can produce different types of graphs
 enum llm_graph_type {
@@ -38,6 +38,7 @@ enum llm_ffn_op_type {
     LLM_FFN_RELU_SQR,
     LLM_FFN_SWIGLU,
     LLM_FFN_GEGLU,
 };
 
 enum llm_ffn_gate_type {
@@ -136,7 +137,7 @@ class llm_graph_input_pos_bucket_kv : public llm_graph_input_i {
 public:
     llm_graph_input_pos_bucket_kv(
             const llama_hparams & hparams,
-            const llama_kv_cache_unified_state * kv_state) : hparams(hparams), kv_state(kv_state) {}
     virtual ~llm_graph_input_pos_bucket_kv() = default;
 
     void set_input(const llama_ubatch * ubatch) override;
@@ -144,7 +145,8 @@ public:
     ggml_tensor * pos_bucket = nullptr; // I32 [n_kv, n_batch]
 
     const llama_hparams & hparams;
-    const llama_kv_cache_unified_state * kv_state;
 };
 
 class llm_graph_input_out_ids : public llm_graph_input_i {
@@ -191,14 +193,14 @@ public:
 
 class llm_graph_input_rs : public llm_graph_input_i {
 public:
-    llm_graph_input_rs(const llama_memory_recurrent_state * mem_state) : mem_state(mem_state) {}
     virtual ~llm_graph_input_rs() = default;
 
     void set_input(const llama_ubatch * ubatch) override;
 
     ggml_tensor * s_copy; // I32 [kv_size]
 
-    const llama_memory_recurrent_state * mem_state;
 };
 
 class llm_graph_input_cross_embd : public llm_graph_input_i {
@@ -238,10 +240,10 @@ public:
     llm_graph_input_attn_kv_unified(
             const llama_hparams & hparams,
             const llama_cparams & cparams,
-            const llama_kv_cache_unified_state * kv_state) :
         hparams(hparams),
         cparams(cparams),
-        kv_state(kv_state) {
     }
     ~llm_graph_input_attn_kv_unified() = default;
 
@@ -255,7 +257,7 @@ public:
     const llama_hparams & hparams;
     const llama_cparams & cparams;
 
-    const llama_kv_cache_unified_state * kv_state;
 };
 
 class llm_graph_input_attn_kv_unified_iswa : public llm_graph_input_i {
@@ -263,10 +265,10 @@ public:
     llm_graph_input_attn_kv_unified_iswa(
             const llama_hparams & hparams,
             const llama_cparams & cparams,
-            const llama_kv_cache_unified_iswa_state * kv_state) :
         hparams(hparams),
         cparams(cparams),
-        kv_state(kv_state) {
     }
     ~llm_graph_input_attn_kv_unified_iswa() = default;
 
@@ -283,7 +285,7 @@ public:
     const llama_hparams & hparams;
     const llama_cparams & cparams;
 
-    const llama_kv_cache_unified_iswa_state * kv_state;
 };
 
 class llm_graph_input_attn_cross : public llm_graph_input_i {
@@ -306,10 +308,10 @@ public:
     llm_graph_input_mem_hybrid(
             const llama_hparams & hparams,
             const llama_cparams & cparams,
-            const llama_memory_hybrid_state * mem_state) :
         hparams(hparams),
         cparams(cparams),
-        mem_state(mem_state) {
     }
     virtual ~llm_graph_input_mem_hybrid() = default;
 
@@ -325,7 +327,18 @@ public:
     const llama_hparams & hparams;
     const llama_cparams & cparams;
 
-    const llama_memory_hybrid_state * mem_state;
 };
 
 //
@@ -401,10 +414,10 @@ struct llm_graph_params {
     ggml_backend_sched_t sched;
     ggml_backend_t backend_cpu;
 
-    const llama_adapter_cvec * cvec;
-    const llama_adapter_loras * loras;
-    const llama_memory_state_i * mstate;
-    const llama_cross * cross;
 
     uint32_t n_outputs;
 
@@ -453,16 +466,17 @@ struct llm_graph_context {
 
     ggml_backend_t backend_cpu; // TODO: needed by build_attn_mha, figure out a way to remove?
 
-    const llama_adapter_cvec * cvec;
-    const llama_adapter_loras * loras;
-    const llama_memory_state_i * mstate;
-    const llama_cross * cross;
 
     const llm_graph_cb & cb_func;
 
     std::unique_ptr<llm_graph_result> res;
 
     llm_graph_context(const llm_graph_params & params);
 
     void cb(ggml_tensor * cur, const char * name, int il) const;
 
@@ -588,14 +602,15 @@ struct llm_graph_context {
 
     llm_graph_input_attn_kv_unified_iswa * build_attn_inp_kv_unified_iswa() const;
 
     ggml_tensor * build_attn(
             llm_graph_input_attn_kv_unified_iswa * inp,
             ggml_cgraph * gf,
             ggml_tensor * wo,
             ggml_tensor * wo_b,
             ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
-            ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
-            ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
             ggml_tensor * kq_b,
             ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
             float kq_scale,
18
  struct llama_cparams;
19
 
20
+ struct llama_memory_context_i;
21
 
22
+ class llama_kv_cache_unified_context;
23
+ class llama_kv_cache_unified_iswa_context;
24
+ class llama_memory_recurrent_context;
25
+ class llama_memory_hybrid_context;
26
 
27
  // certain models (typically multi-modal) can produce different types of graphs
28
  enum llm_graph_type {
 
38
  LLM_FFN_RELU_SQR,
39
  LLM_FFN_SWIGLU,
40
  LLM_FFN_GEGLU,
41
+ LLM_FFN_REGLU,
42
  };
43
 
44
  enum llm_ffn_gate_type {
 
137
  public:
138
  llm_graph_input_pos_bucket_kv(
139
  const llama_hparams & hparams,
140
+ const llama_kv_cache_unified_context * mctx) : hparams(hparams), mctx(mctx) {}
141
  virtual ~llm_graph_input_pos_bucket_kv() = default;
142
 
143
  void set_input(const llama_ubatch * ubatch) override;
 
145
  ggml_tensor * pos_bucket = nullptr; // I32 [n_kv, n_batch]
146
 
147
  const llama_hparams & hparams;
148
+
149
+ const llama_kv_cache_unified_context * mctx;
150
  };
151
 
152
  class llm_graph_input_out_ids : public llm_graph_input_i {
 
193
 
194
  class llm_graph_input_rs : public llm_graph_input_i {
195
  public:
196
+ llm_graph_input_rs(const llama_memory_recurrent_context * mctx) : mctx(mctx) {}
197
  virtual ~llm_graph_input_rs() = default;
198
 
199
  void set_input(const llama_ubatch * ubatch) override;
200
 
201
  ggml_tensor * s_copy; // I32 [kv_size]
202
 
203
+ const llama_memory_recurrent_context * mctx;
204
  };
205
 
206
  class llm_graph_input_cross_embd : public llm_graph_input_i {
 
240
  llm_graph_input_attn_kv_unified(
241
  const llama_hparams & hparams,
242
  const llama_cparams & cparams,
243
+ const llama_kv_cache_unified_context * mctx) :
244
  hparams(hparams),
245
  cparams(cparams),
246
+ mctx(mctx) {
247
  }
248
  ~llm_graph_input_attn_kv_unified() = default;
249
 
 
257
  const llama_hparams & hparams;
258
  const llama_cparams & cparams;
259
 
260
+ const llama_kv_cache_unified_context * mctx;
261
  };
262
 
263
  class llm_graph_input_attn_kv_unified_iswa : public llm_graph_input_i {
 
265
  llm_graph_input_attn_kv_unified_iswa(
266
  const llama_hparams & hparams,
267
  const llama_cparams & cparams,
268
+ const llama_kv_cache_unified_iswa_context * mctx) :
269
  hparams(hparams),
270
  cparams(cparams),
271
+ mctx(mctx) {
272
  }
273
  ~llm_graph_input_attn_kv_unified_iswa() = default;
274
 
 
285
  const llama_hparams & hparams;
286
  const llama_cparams & cparams;
287
 
288
+ const llama_kv_cache_unified_iswa_context * mctx;
289
  };
290
 
291
  class llm_graph_input_attn_cross : public llm_graph_input_i {
 
308
  llm_graph_input_mem_hybrid(
309
  const llama_hparams & hparams,
310
  const llama_cparams & cparams,
311
+ const llama_memory_hybrid_context * mctx) :
312
  hparams(hparams),
313
  cparams(cparams),
314
+ mctx(mctx) {
315
  }
316
  virtual ~llm_graph_input_mem_hybrid() = default;
317
 
 
327
  const llama_hparams & hparams;
328
  const llama_cparams & cparams;
329
 
330
+ const llama_memory_hybrid_context * mctx;
331
+ };
332
+
333
+ // TODO: remove this when ggml_scale_add is implemented
334
+ class llm_graph_input_one : public llm_graph_input_i {
335
+ public:
336
+ llm_graph_input_one() {}
337
+ virtual ~llm_graph_input_one() = default;
338
+
339
+ void set_input(const llama_ubatch *) override;
340
+
341
+ ggml_tensor * one = nullptr; // F32
342
  };
343
 
344
  //
 
414
  ggml_backend_sched_t sched;
415
  ggml_backend_t backend_cpu;
416
 
417
+ const llama_adapter_cvec * cvec;
418
+ const llama_adapter_loras * loras;
419
+ const llama_memory_context_i * mctx;
420
+ const llama_cross * cross;
421
 
422
  uint32_t n_outputs;
423
 
 
466
 
467
  ggml_backend_t backend_cpu; // TODO: needed by build_attn_mha, figure out a way to remove?
468
 
469
+ const llama_adapter_cvec * cvec;
470
+ const llama_adapter_loras * loras;
471
+ const llama_memory_context_i * mctx;
472
+ const llama_cross * cross;
473
 
474
  const llm_graph_cb & cb_func;
475
 
476
  std::unique_ptr<llm_graph_result> res;
477
 
478
  llm_graph_context(const llm_graph_params & params);
479
+ virtual ~llm_graph_context() = default;
480
 
481
  void cb(ggml_tensor * cur, const char * name, int il) const;
482
 
 
602
 
603
  llm_graph_input_attn_kv_unified_iswa * build_attn_inp_kv_unified_iswa() const;
604
 
605
+ // note: if k_cur or v_cur are not provided, they will not be stored in the memory
606
  ggml_tensor * build_attn(
607
  llm_graph_input_attn_kv_unified_iswa * inp,
608
  ggml_cgraph * gf,
609
  ggml_tensor * wo,
610
  ggml_tensor * wo_b,
611
  ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
612
+ ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens] optional
613
+ ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens] optional
614
  ggml_tensor * kq_b,
615
  ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
616
  float kq_scale,
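
The new comment on build_attn makes k_cur and v_cur optional: when they are null, the helper only reads from the KV memory and skips the store, which is presumably what the Gemma3N KV-reuse layers (see llama-kv-cache-unified.cpp below) rely on. A minimal standalone sketch of that contract, using toy types rather than the real llama.cpp classes:

    #include <cstdio>

    struct toy_kv_cache {
        int n_stored = 0; // how many ubatches have written K/V so far
    };

    // mirrors the documented behaviour: only store when both K and V are provided
    void toy_build_attn(toy_kv_cache & kv, const float * k_cur, const float * v_cur) {
        if (k_cur != nullptr && v_cur != nullptr) {
            kv.n_stored++; // regular layer: append this ubatch's K/V to the cache
        }
        // in both cases the attention itself would read K/V from the cache here
    }

    int main() {
        toy_kv_cache kv;
        float k[1] = {0.f}, v[1] = {0.f};

        toy_build_attn(kv, k, v);             // layer with its own cache entries
        toy_build_attn(kv, nullptr, nullptr); // KV-reuse layer: read-only access

        printf("ubatches stored: %d\n", kv.n_stored); // prints 1
        return 0;
    }
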
examples/talk-llama/llama-hparams.h CHANGED
@@ -143,6 +143,12 @@ struct llama_hparams {
143
  uint32_t n_attn_temp_floor_scale = 8192;
144
  float f_attn_temp_scale = 0.1;
145
 
 
 
 
 
 
 
146
  // needed by encoder-decoder models (e.g. T5, FLAN-T5)
147
  // ref: https://github.com/ggerganov/llama.cpp/pull/8141
148
  llama_token dec_start_token_id = LLAMA_TOKEN_NULL;
 
143
  uint32_t n_attn_temp_floor_scale = 8192;
144
  float f_attn_temp_scale = 0.1;
145
 
146
+ // gemma3n altup
147
+ uint32_t n_altup = 4; // altup_num_inputs
148
+ uint32_t i_altup_act = 0; // altup_active_idx
149
+ uint32_t laurel_rank = 64;
150
+ uint32_t n_embd_altup = 256;
151
+
152
  // needed by encoder-decoder models (e.g. T5, FLAN-T5)
153
  // ref: https://github.com/ggerganov/llama.cpp/pull/8141
154
  llama_token dec_start_token_id = LLAMA_TOKEN_NULL;
examples/talk-llama/llama-kv-cache-unified-iswa.cpp CHANGED
@@ -95,7 +95,7 @@ llama_pos llama_kv_cache_unified_iswa::seq_pos_max(llama_seq_id seq_id) const {
95
  return kv_swa->seq_pos_max(seq_id);
96
  }
97
 
98
- llama_memory_state_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_allocr & balloc, uint32_t n_ubatch, bool embd_all) {
99
  GGML_UNUSED(embd_all);
100
 
101
  // first try simple split
@@ -125,7 +125,7 @@ llama_memory_state_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_alloc
125
 
126
  assert(heads_base.size() == heads_swa.size());
127
 
128
- return std::make_unique<llama_kv_cache_unified_iswa_state>(
129
  this, std::move(heads_base), std::move(heads_swa), std::move(ubatches));
130
  } while (false);
131
 
@@ -156,22 +156,22 @@ llama_memory_state_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_alloc
156
 
157
  assert(heads_base.size() == heads_swa.size());
158
 
159
- return std::make_unique<llama_kv_cache_unified_iswa_state>(
160
  this, std::move(heads_base), std::move(heads_swa), std::move(ubatches));
161
  } while (false);
162
 
163
  // TODO: if we fail again, we should attempt different splitting strategies
164
  // but to do that properly, we first have to refactor the batches to be more flexible
165
 
166
- return std::make_unique<llama_kv_cache_unified_iswa_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
167
  }
168
 
169
- llama_memory_state_ptr llama_kv_cache_unified_iswa::init_full() {
170
- return std::make_unique<llama_kv_cache_unified_iswa_state>(this);
171
  }
172
 
173
- llama_memory_state_ptr llama_kv_cache_unified_iswa::init_update(llama_context * lctx, bool optimize) {
174
- return std::make_unique<llama_kv_cache_unified_iswa_state>(this, lctx, optimize);
175
  }
176
 
177
  bool llama_kv_cache_unified_iswa::get_can_shift() const {
@@ -197,46 +197,46 @@ llama_kv_cache_unified * llama_kv_cache_unified_iswa::get_swa() const {
197
  }
198
 
199
  //
200
- // llama_kv_cache_unified_iswa_state
201
  //
202
 
203
- llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(llama_memory_status status) : status(status) {}
204
 
205
- llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(
206
  llama_kv_cache_unified_iswa * kv) :
207
- state_base(kv->get_base()->init_full()),
208
- state_swa (kv->get_swa ()->init_full()),
209
- status(llama_memory_status_combine(state_base->get_status(), state_swa->get_status())) {
210
  }
211
 
212
- llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(
213
  llama_kv_cache_unified_iswa * kv,
214
  llama_context * lctx,
215
  bool optimize) :
216
- state_base(kv->get_base()->init_update(lctx, optimize)),
217
- state_swa (kv->get_swa ()->init_update(lctx, optimize)),
218
- status(llama_memory_status_combine(state_base->get_status(), state_swa->get_status())) {
219
  }
220
 
221
- llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(
222
  llama_kv_cache_unified_iswa * kv,
223
  std::vector<uint32_t> heads_base,
224
  std::vector<uint32_t> heads_swa,
225
  std::vector<llama_ubatch> ubatches) :
226
  ubatches(std::move(ubatches)),
227
  // note: here we copy the ubatches. not sure if this is ideal
228
- state_base(new llama_kv_cache_unified_state(kv->get_base(), std::move(heads_base), this->ubatches)),
229
- state_swa (new llama_kv_cache_unified_state(kv->get_swa (), std::move(heads_swa), this->ubatches)),
230
- status(llama_memory_status_combine(state_base->get_status(), state_swa->get_status())) {
231
  }
232
 
233
- llama_kv_cache_unified_iswa_state:: ~llama_kv_cache_unified_iswa_state() = default;
234
 
235
- bool llama_kv_cache_unified_iswa_state::next() {
236
  assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
237
 
238
- state_base->next();
239
- state_swa ->next();
240
 
241
  if (++i_next >= ubatches.size()) {
242
  return false;
@@ -245,35 +245,35 @@ bool llama_kv_cache_unified_iswa_state::next() {
245
  return true;
246
  }
247
 
248
- bool llama_kv_cache_unified_iswa_state::apply() {
249
- assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
250
 
251
  bool res = true;
252
 
253
- res = res & state_base->apply();
254
- res = res & state_swa ->apply();
255
 
256
  return res;
257
  }
258
 
259
- llama_memory_status llama_kv_cache_unified_iswa_state::get_status() const {
260
  return status;
261
  }
262
 
263
- const llama_ubatch & llama_kv_cache_unified_iswa_state::get_ubatch() const {
264
  assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
265
 
266
  return ubatches[i_next];
267
  }
268
 
269
- const llama_kv_cache_unified_state * llama_kv_cache_unified_iswa_state::get_base() const {
270
  assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
271
 
272
- return static_cast<const llama_kv_cache_unified_state *>(state_base.get());
273
  }
274
 
275
- const llama_kv_cache_unified_state * llama_kv_cache_unified_iswa_state::get_swa() const {
276
  assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
277
 
278
- return static_cast<const llama_kv_cache_unified_state *>(state_swa.get());
279
  }
 
95
  return kv_swa->seq_pos_max(seq_id);
96
  }
97
 
98
+ llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_allocr & balloc, uint32_t n_ubatch, bool embd_all) {
99
  GGML_UNUSED(embd_all);
100
 
101
  // first try simple split
 
125
 
126
  assert(heads_base.size() == heads_swa.size());
127
 
128
+ return std::make_unique<llama_kv_cache_unified_iswa_context>(
129
  this, std::move(heads_base), std::move(heads_swa), std::move(ubatches));
130
  } while (false);
131
 
 
156
 
157
  assert(heads_base.size() == heads_swa.size());
158
 
159
+ return std::make_unique<llama_kv_cache_unified_iswa_context>(
160
  this, std::move(heads_base), std::move(heads_swa), std::move(ubatches));
161
  } while (false);
162
 
163
  // TODO: if we fail again, we should attempt different splitting strategies
164
  // but to do that properly, we first have to refactor the batches to be more flexible
165
 
166
+ return std::make_unique<llama_kv_cache_unified_iswa_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
167
  }
168
 
169
+ llama_memory_context_ptr llama_kv_cache_unified_iswa::init_full() {
170
+ return std::make_unique<llama_kv_cache_unified_iswa_context>(this);
171
  }
172
 
173
+ llama_memory_context_ptr llama_kv_cache_unified_iswa::init_update(llama_context * lctx, bool optimize) {
174
+ return std::make_unique<llama_kv_cache_unified_iswa_context>(this, lctx, optimize);
175
  }
176
 
177
  bool llama_kv_cache_unified_iswa::get_can_shift() const {
 
197
  }
198
 
199
  //
200
+ // llama_kv_cache_unified_iswa_context
201
  //
202
 
203
+ llama_kv_cache_unified_iswa_context::llama_kv_cache_unified_iswa_context(llama_memory_status status) : status(status) {}
204
 
205
+ llama_kv_cache_unified_iswa_context::llama_kv_cache_unified_iswa_context(
206
  llama_kv_cache_unified_iswa * kv) :
207
+ ctx_base(kv->get_base()->init_full()),
208
+ ctx_swa (kv->get_swa ()->init_full()),
209
+ status(llama_memory_status_combine(ctx_base->get_status(), ctx_swa->get_status())) {
210
  }
211
 
212
+ llama_kv_cache_unified_iswa_context::llama_kv_cache_unified_iswa_context(
213
  llama_kv_cache_unified_iswa * kv,
214
  llama_context * lctx,
215
  bool optimize) :
216
+ ctx_base(kv->get_base()->init_update(lctx, optimize)),
217
+ ctx_swa (kv->get_swa ()->init_update(lctx, optimize)),
218
+ status(llama_memory_status_combine(ctx_base->get_status(), ctx_swa->get_status())) {
219
  }
220
 
221
+ llama_kv_cache_unified_iswa_context::llama_kv_cache_unified_iswa_context(
222
  llama_kv_cache_unified_iswa * kv,
223
  std::vector<uint32_t> heads_base,
224
  std::vector<uint32_t> heads_swa,
225
  std::vector<llama_ubatch> ubatches) :
226
  ubatches(std::move(ubatches)),
227
  // note: here we copy the ubatches. not sure if this is ideal
228
+ ctx_base(new llama_kv_cache_unified_context(kv->get_base(), std::move(heads_base), this->ubatches)),
229
+ ctx_swa (new llama_kv_cache_unified_context(kv->get_swa (), std::move(heads_swa), this->ubatches)),
230
+ status(llama_memory_status_combine(ctx_base->get_status(), ctx_swa->get_status())) {
231
  }
232
 
233
+ llama_kv_cache_unified_iswa_context:: ~llama_kv_cache_unified_iswa_context() = default;
234
 
235
+ bool llama_kv_cache_unified_iswa_context::next() {
236
  assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
237
 
238
+ ctx_base->next();
239
+ ctx_swa ->next();
240
 
241
  if (++i_next >= ubatches.size()) {
242
  return false;
 
245
  return true;
246
  }
247
 
248
+ bool llama_kv_cache_unified_iswa_context::apply() {
249
+ assert(!llama_memory_status_is_fail(status));
250
 
251
  bool res = true;
252
 
253
+ res = res & ctx_base->apply();
254
+ res = res & ctx_swa ->apply();
255
 
256
  return res;
257
  }
258
 
259
+ llama_memory_status llama_kv_cache_unified_iswa_context::get_status() const {
260
  return status;
261
  }
262
 
263
+ const llama_ubatch & llama_kv_cache_unified_iswa_context::get_ubatch() const {
264
  assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
265
 
266
  return ubatches[i_next];
267
  }
268
 
269
+ const llama_kv_cache_unified_context * llama_kv_cache_unified_iswa_context::get_base() const {
270
  assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
271
 
272
+ return static_cast<const llama_kv_cache_unified_context *>(ctx_base.get());
273
  }
274
 
275
+ const llama_kv_cache_unified_context * llama_kv_cache_unified_iswa_context::get_swa() const {
276
  assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
277
 
278
+ return static_cast<const llama_kv_cache_unified_context *>(ctx_swa.get());
279
  }
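
The iSWA context above is a thin wrapper around a base context and an SWA context; its own status is the combination of the two, and apply() now only asserts that the status is not a failure, so it can also run for NO_UPDATE contexts. The real combination logic lives in llama_memory_status_combine, which is not part of this diff; the sketch below is only a plausible reading of how such a combination could behave.

    #include <cstdio>

    enum toy_status { TOY_SUCCESS, TOY_NO_UPDATE, TOY_FAILED_PREPARE };

    // hypothetical combiner: any failure wins, two no-ops stay a no-op,
    // otherwise at least one side has work to do
    toy_status toy_combine(toy_status a, toy_status b) {
        if (a == TOY_FAILED_PREPARE || b == TOY_FAILED_PREPARE) {
            return TOY_FAILED_PREPARE;
        }
        if (a == TOY_NO_UPDATE && b == TOY_NO_UPDATE) {
            return TOY_NO_UPDATE;
        }
        return TOY_SUCCESS;
    }

    bool toy_is_fail(toy_status s) { return s == TOY_FAILED_PREPARE; }

    int main() {
        const toy_status base = TOY_NO_UPDATE; // e.g. the non-SWA cache has nothing to update
        const toy_status swa  = TOY_SUCCESS;   // while the SWA cache does

        const toy_status combined = toy_combine(base, swa);
        printf("combined status: %d, is_fail: %d\n", combined, toy_is_fail(combined));
        return 0;
    }
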
examples/talk-llama/llama-kv-cache-unified-iswa.h CHANGED
@@ -31,14 +31,14 @@ public:
31
  // llama_memory_i
32
  //
33
 
34
- llama_memory_state_ptr init_batch(
35
  llama_batch_allocr & balloc,
36
  uint32_t n_ubatch,
37
  bool embd_all) override;
38
 
39
- llama_memory_state_ptr init_full() override;
40
 
41
- llama_memory_state_ptr init_update(llama_context * lctx, bool optimize) override;
42
 
43
  bool get_can_shift() const override;
44
 
@@ -72,32 +72,32 @@ private:
72
  std::unique_ptr<llama_kv_cache_unified> kv_swa;
73
  };
74
 
75
- class llama_kv_cache_unified_iswa_state : public llama_memory_state_i {
76
  public:
77
  // used for errors
78
- llama_kv_cache_unified_iswa_state(llama_memory_status status);
79
 
80
- // used to create a full-cache state
81
- llama_kv_cache_unified_iswa_state(
82
  llama_kv_cache_unified_iswa * kv);
83
 
84
- // used to create an update state
85
- llama_kv_cache_unified_iswa_state(
86
  llama_kv_cache_unified_iswa * kv,
87
  llama_context * lctx,
88
  bool optimize);
89
 
90
- // used to create a state from a batch
91
- llama_kv_cache_unified_iswa_state(
92
  llama_kv_cache_unified_iswa * kv,
93
  std::vector<uint32_t> heads_base,
94
  std::vector<uint32_t> heads_swa,
95
  std::vector<llama_ubatch> ubatches);
96
 
97
- virtual ~llama_kv_cache_unified_iswa_state();
98
 
99
  //
100
- // llama_memory_state_i
101
  //
102
 
103
  bool next() override;
@@ -107,11 +107,11 @@ public:
107
  const llama_ubatch & get_ubatch() const override;
108
 
109
  //
110
- // llama_kv_cache_unified_iswa_state specific API
111
  //
112
 
113
- const llama_kv_cache_unified_state * get_base() const;
114
- const llama_kv_cache_unified_state * get_swa() const;
115
 
116
  private:
117
  //llama_kv_cache_unified_iswa * kv;
@@ -121,8 +121,8 @@ private:
121
 
122
  std::vector<llama_ubatch> ubatches;
123
 
124
- const llama_memory_state_ptr state_base;
125
- const llama_memory_state_ptr state_swa;
126
 
127
  const llama_memory_status status;
128
  };
 
31
  // llama_memory_i
32
  //
33
 
34
+ llama_memory_context_ptr init_batch(
35
  llama_batch_allocr & balloc,
36
  uint32_t n_ubatch,
37
  bool embd_all) override;
38
 
39
+ llama_memory_context_ptr init_full() override;
40
 
41
+ llama_memory_context_ptr init_update(llama_context * lctx, bool optimize) override;
42
 
43
  bool get_can_shift() const override;
44
 
 
72
  std::unique_ptr<llama_kv_cache_unified> kv_swa;
73
  };
74
 
75
+ class llama_kv_cache_unified_iswa_context : public llama_memory_context_i {
76
  public:
77
  // used for errors
78
+ llama_kv_cache_unified_iswa_context(llama_memory_status status);
79
 
80
+ // used to create a full-cache context
81
+ llama_kv_cache_unified_iswa_context(
82
  llama_kv_cache_unified_iswa * kv);
83
 
84
+ // used to create an update context
85
+ llama_kv_cache_unified_iswa_context(
86
  llama_kv_cache_unified_iswa * kv,
87
  llama_context * lctx,
88
  bool optimize);
89
 
90
+ // used to create a batch processing context from a batch
91
+ llama_kv_cache_unified_iswa_context(
92
  llama_kv_cache_unified_iswa * kv,
93
  std::vector<uint32_t> heads_base,
94
  std::vector<uint32_t> heads_swa,
95
  std::vector<llama_ubatch> ubatches);
96
 
97
+ virtual ~llama_kv_cache_unified_iswa_context();
98
 
99
  //
100
+ // llama_memory_context_i
101
  //
102
 
103
  bool next() override;
 
107
  const llama_ubatch & get_ubatch() const override;
108
 
109
  //
110
+ // llama_kv_cache_unified_iswa_context specific API
111
  //
112
 
113
+ const llama_kv_cache_unified_context * get_base() const;
114
+ const llama_kv_cache_unified_context * get_swa() const;
115
 
116
  private:
117
  //llama_kv_cache_unified_iswa * kv;
 
121
 
122
  std::vector<llama_ubatch> ubatches;
123
 
124
+ const llama_memory_context_ptr ctx_base;
125
+ const llama_memory_context_ptr ctx_swa;
126
 
127
  const llama_memory_status status;
128
  };
examples/talk-llama/llama-kv-cache-unified.cpp CHANGED
@@ -33,13 +33,19 @@ llama_kv_cache_unified::llama_kv_cache_unified(
33
 
34
  GGML_ASSERT(kv_size % n_pad == 0);
35
 
 
 
 
 
 
 
36
  // create a context for each buffer type
37
  std::map<ggml_backend_buffer_type_t, ggml_context *> ctx_map;
38
  auto ctx_for_buft = [&](ggml_backend_buffer_type_t buft) -> ggml_context * {
39
  auto it = ctx_map.find(buft);
40
  if (it == ctx_map.end()) {
41
  ggml_init_params params = {
42
- /*.mem_size =*/ size_t(2u*hparams.n_layer*ggml_tensor_overhead()),
43
  /*.mem_buffer =*/ NULL,
44
  /*.no_alloc =*/ true,
45
  };
@@ -62,7 +68,7 @@ llama_kv_cache_unified::llama_kv_cache_unified(
62
 
63
  cells.resize(kv_size);
64
 
65
- for (uint32_t il = 0; il < hparams.n_layer; il++) {
66
  if (filter && !filter(il)) {
67
  LLAMA_LOG_DEBUG("%s: layer %3d: skipped\n", __func__, il);
68
  continue;
@@ -102,6 +108,26 @@ llama_kv_cache_unified::llama_kv_cache_unified(
102
  layers.push_back({ il, k, v });
103
  }
104
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
105
  // allocate tensors and initialize the buffers to avoid NaNs in the padding
106
  for (auto it : ctx_map) {
107
  auto * buft = it.first;
@@ -307,7 +333,7 @@ llama_pos llama_kv_cache_unified::seq_pos_max(llama_seq_id seq_id) const {
307
  return cells.seq_pos_max(seq_id);
308
  }
309
 
310
- llama_memory_state_ptr llama_kv_cache_unified::init_batch(
311
  llama_batch_allocr & balloc,
312
  uint32_t n_ubatch,
313
  bool embd_all) {
@@ -332,18 +358,18 @@ llama_memory_state_ptr llama_kv_cache_unified::init_batch(
332
  break;
333
  }
334
 
335
- return std::make_unique<llama_kv_cache_unified_state>(
336
  this, std::move(heads), std::move(ubatches));
337
  } while (false);
338
 
339
- return std::make_unique<llama_kv_cache_unified_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
340
  }
341
 
342
- llama_memory_state_ptr llama_kv_cache_unified::init_full() {
343
- return std::make_unique<llama_kv_cache_unified_state>(this);
344
  }
345
 
346
- llama_memory_state_ptr llama_kv_cache_unified::init_update(llama_context * lctx, bool optimize) {
347
  bool do_shift = get_has_shift();
348
 
349
  defrag_info dinfo;
@@ -373,7 +399,7 @@ llama_memory_state_ptr llama_kv_cache_unified::init_update(llama_context * lctx,
373
  }
374
  }
375
 
376
- return std::make_unique<llama_kv_cache_unified_state>(this, lctx, do_shift, std::move(dinfo));
377
  }
378
 
379
  llama_kv_cache_unified::ubatch_heads llama_kv_cache_unified::prepare(const std::vector<llama_ubatch> & ubatches) {
@@ -1710,18 +1736,18 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell
1710
  }
1711
 
1712
  //
1713
- // llama_kv_cache_unified_state
1714
  //
1715
 
1716
- llama_kv_cache_unified_state::llama_kv_cache_unified_state(llama_memory_status status) : status(status) {}
1717
 
1718
- llama_kv_cache_unified_state::llama_kv_cache_unified_state(
1719
  llama_kv_cache_unified * kv) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv) {
1720
  n_kv = kv->get_size();
1721
  head = 0;
1722
  }
1723
 
1724
- llama_kv_cache_unified_state::llama_kv_cache_unified_state(
1725
  llama_kv_cache_unified * kv,
1726
  llama_context * lctx,
1727
  bool do_shift,
@@ -1731,15 +1757,15 @@ llama_kv_cache_unified_state::llama_kv_cache_unified_state(
1731
  }
1732
  }
1733
 
1734
- llama_kv_cache_unified_state::llama_kv_cache_unified_state(
1735
  llama_kv_cache_unified * kv,
1736
  llama_kv_cache_unified::ubatch_heads heads,
1737
  std::vector<llama_ubatch> ubatches) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv), heads(std::move(heads)), ubatches(std::move(ubatches)) {
1738
  }
1739
 
1740
- llama_kv_cache_unified_state::~llama_kv_cache_unified_state() = default;
1741
 
1742
- bool llama_kv_cache_unified_state::next() {
1743
  assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
1744
 
1745
  if (++i_next >= ubatches.size()) {
@@ -1749,8 +1775,8 @@ bool llama_kv_cache_unified_state::next() {
1749
  return true;
1750
  }
1751
 
1752
- bool llama_kv_cache_unified_state::apply() {
1753
- assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
1754
 
1755
  // no ubatches -> this is a KV cache update
1756
  if (ubatches.empty()) {
@@ -1767,45 +1793,45 @@ bool llama_kv_cache_unified_state::apply() {
1767
  return true;
1768
  }
1769
 
1770
- llama_memory_status llama_kv_cache_unified_state::get_status() const {
1771
  return status;
1772
  }
1773
 
1774
- const llama_ubatch & llama_kv_cache_unified_state::get_ubatch() const {
1775
  assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
1776
 
1777
  return ubatches[i_next];
1778
  }
1779
 
1780
- uint32_t llama_kv_cache_unified_state::get_n_kv() const {
1781
  return n_kv;
1782
  }
1783
 
1784
- ggml_tensor * llama_kv_cache_unified_state::get_k(ggml_context * ctx, int32_t il) const {
1785
  return kv->get_k(ctx, il, n_kv);
1786
  }
1787
 
1788
- ggml_tensor * llama_kv_cache_unified_state::get_v(ggml_context * ctx, int32_t il) const {
1789
  return kv->get_v(ctx, il, n_kv);
1790
  }
1791
 
1792
- ggml_tensor * llama_kv_cache_unified_state::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, int32_t il) const {
1793
  return kv->cpy_k(ctx, k_cur, il, head);
1794
  }
1795
 
1796
- ggml_tensor * llama_kv_cache_unified_state::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, int32_t il) const {
1797
  return kv->cpy_v(ctx, v_cur, il, head);
1798
  }
1799
 
1800
- void llama_kv_cache_unified_state::set_input_k_shift(ggml_tensor * dst) const {
1801
  kv->set_input_k_shift(dst);
1802
  }
1803
 
1804
- void llama_kv_cache_unified_state::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const {
1805
  kv->set_input_kq_mask(dst, ubatch, causal_attn);
1806
  }
1807
 
1808
- void llama_kv_cache_unified_state::set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const {
1809
  kv->set_input_pos_bucket(dst, ubatch);
1810
  }
1811
 
 
33
 
34
  GGML_ASSERT(kv_size % n_pad == 0);
35
 
36
+ // TODO: this is temporary until we support passing reuse layer filters [KV_REUSE]
37
+ auto n_layer_cache = hparams.n_layer;
38
+ if (model.arch == LLM_ARCH_GEMMA3N) {
39
+ n_layer_cache = 20;
40
+ }
41
+
42
  // create a context for each buffer type
43
  std::map<ggml_backend_buffer_type_t, ggml_context *> ctx_map;
44
  auto ctx_for_buft = [&](ggml_backend_buffer_type_t buft) -> ggml_context * {
45
  auto it = ctx_map.find(buft);
46
  if (it == ctx_map.end()) {
47
  ggml_init_params params = {
48
+ /*.mem_size =*/ size_t(2u*n_layer_cache*ggml_tensor_overhead()),
49
  /*.mem_buffer =*/ NULL,
50
  /*.no_alloc =*/ true,
51
  };
 
68
 
69
  cells.resize(kv_size);
70
 
71
+ for (uint32_t il = 0; il < n_layer_cache; il++) {
72
  if (filter && !filter(il)) {
73
  LLAMA_LOG_DEBUG("%s: layer %3d: skipped\n", __func__, il);
74
  continue;
 
108
  layers.push_back({ il, k, v });
109
  }
110
 
111
+ // TODO: this is temporary until we support passing reuse layer filters [KV_REUSE]
112
+ if (model.arch == LLM_ARCH_GEMMA3N) {
113
+ LLAMA_LOG_DEBUG("%s: GEMMA3N: reuse layers [%d, %d]\n", __func__, n_layer_cache, hparams.n_layer - 1);
114
+
115
+ for (uint32_t il = n_layer_cache; il < hparams.n_layer; il++) {
116
+ if (filter && !filter(il)) {
117
+ LLAMA_LOG_DEBUG("%s: layer %3d: skipped\n", __func__, il);
118
+ continue;
119
+ }
120
+
121
+ const bool is_swa = hparams.is_swa(il);
122
+ const uint32_t il_reuse = n_layer_cache - (is_swa ? 2 : 1);
123
+
124
+ GGML_ASSERT(map_layer_ids.find(il_reuse) != map_layer_ids.end());
125
+ map_layer_ids[il] = map_layer_ids[il_reuse];
126
+
127
+ LLAMA_LOG_DEBUG("%s: layer %3d: reuse layer %d, isw = %d\n", __func__, il, il_reuse, is_swa);
128
+ }
129
+ }
130
+
131
  // allocate tensors and initialize the buffers to avoid NaNs in the padding
132
  for (auto it : ctx_map) {
133
  auto * buft = it.first;
 
333
  return cells.seq_pos_max(seq_id);
334
  }
335
 
336
+ llama_memory_context_ptr llama_kv_cache_unified::init_batch(
337
  llama_batch_allocr & balloc,
338
  uint32_t n_ubatch,
339
  bool embd_all) {
 
358
  break;
359
  }
360
 
361
+ return std::make_unique<llama_kv_cache_unified_context>(
362
  this, std::move(heads), std::move(ubatches));
363
  } while (false);
364
 
365
+ return std::make_unique<llama_kv_cache_unified_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
366
  }
367
 
368
+ llama_memory_context_ptr llama_kv_cache_unified::init_full() {
369
+ return std::make_unique<llama_kv_cache_unified_context>(this);
370
  }
371
 
372
+ llama_memory_context_ptr llama_kv_cache_unified::init_update(llama_context * lctx, bool optimize) {
373
  bool do_shift = get_has_shift();
374
 
375
  defrag_info dinfo;
 
399
  }
400
  }
401
 
402
+ return std::make_unique<llama_kv_cache_unified_context>(this, lctx, do_shift, std::move(dinfo));
403
  }
404
 
405
  llama_kv_cache_unified::ubatch_heads llama_kv_cache_unified::prepare(const std::vector<llama_ubatch> & ubatches) {
 
1736
  }
1737
 
1738
  //
1739
+ // llama_kv_cache_unified_context
1740
  //
1741
 
1742
+ llama_kv_cache_unified_context::llama_kv_cache_unified_context(llama_memory_status status) : status(status) {}
1743
 
1744
+ llama_kv_cache_unified_context::llama_kv_cache_unified_context(
1745
  llama_kv_cache_unified * kv) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv) {
1746
  n_kv = kv->get_size();
1747
  head = 0;
1748
  }
1749
 
1750
+ llama_kv_cache_unified_context::llama_kv_cache_unified_context(
1751
  llama_kv_cache_unified * kv,
1752
  llama_context * lctx,
1753
  bool do_shift,
 
1757
  }
1758
  }
1759
 
1760
+ llama_kv_cache_unified_context::llama_kv_cache_unified_context(
1761
  llama_kv_cache_unified * kv,
1762
  llama_kv_cache_unified::ubatch_heads heads,
1763
  std::vector<llama_ubatch> ubatches) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv), heads(std::move(heads)), ubatches(std::move(ubatches)) {
1764
  }
1765
 
1766
+ llama_kv_cache_unified_context::~llama_kv_cache_unified_context() = default;
1767
 
1768
+ bool llama_kv_cache_unified_context::next() {
1769
  assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
1770
 
1771
  if (++i_next >= ubatches.size()) {
 
1775
  return true;
1776
  }
1777
 
1778
+ bool llama_kv_cache_unified_context::apply() {
1779
+ assert(!llama_memory_status_is_fail(status));
1780
 
1781
  // no ubatches -> this is a KV cache update
1782
  if (ubatches.empty()) {
 
1793
  return true;
1794
  }
1795
 
1796
+ llama_memory_status llama_kv_cache_unified_context::get_status() const {
1797
  return status;
1798
  }
1799
 
1800
+ const llama_ubatch & llama_kv_cache_unified_context::get_ubatch() const {
1801
  assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
1802
 
1803
  return ubatches[i_next];
1804
  }
1805
 
1806
+ uint32_t llama_kv_cache_unified_context::get_n_kv() const {
1807
  return n_kv;
1808
  }
1809
 
1810
+ ggml_tensor * llama_kv_cache_unified_context::get_k(ggml_context * ctx, int32_t il) const {
1811
  return kv->get_k(ctx, il, n_kv);
1812
  }
1813
 
1814
+ ggml_tensor * llama_kv_cache_unified_context::get_v(ggml_context * ctx, int32_t il) const {
1815
  return kv->get_v(ctx, il, n_kv);
1816
  }
1817
 
1818
+ ggml_tensor * llama_kv_cache_unified_context::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, int32_t il) const {
1819
  return kv->cpy_k(ctx, k_cur, il, head);
1820
  }
1821
 
1822
+ ggml_tensor * llama_kv_cache_unified_context::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, int32_t il) const {
1823
  return kv->cpy_v(ctx, v_cur, il, head);
1824
  }
1825
 
1826
+ void llama_kv_cache_unified_context::set_input_k_shift(ggml_tensor * dst) const {
1827
  kv->set_input_k_shift(dst);
1828
  }
1829
 
1830
+ void llama_kv_cache_unified_context::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const {
1831
  kv->set_input_kq_mask(dst, ubatch, causal_attn);
1832
  }
1833
 
1834
+ void llama_kv_cache_unified_context::set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const {
1835
  kv->set_input_pos_bucket(dst, ubatch);
1836
  }
1837
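
For Gemma3N the unified cache only allocates K/V tensors for the first n_layer_cache layers; every later layer is redirected through map_layer_ids to the last cached layer of the same kind (SWA or non-SWA). The standalone sketch below reproduces that mapping with a made-up SWA pattern, since the real decision comes from hparams.is_swa(il):

    #include <cstdint>
    #include <cstdio>
    #include <map>

    int main() {
        const uint32_t n_layer       = 30; // total layers of the model (illustrative)
        const uint32_t n_layer_cache = 20; // layers that own K/V tensors (hardcoded for GEMMA3N above)

        // hypothetical sliding-window pattern; the real one is model-specific
        auto is_swa = [](uint32_t il) { return (il % 5) != 0; };

        std::map<uint32_t, uint32_t> map_layer_ids;
        for (uint32_t il = 0; il < n_layer_cache; ++il) {
            map_layer_ids[il] = il; // these layers get their own cache slots
        }

        for (uint32_t il = n_layer_cache; il < n_layer; ++il) {
            // reuse the last SWA layer for SWA layers, the last full layer otherwise
            const uint32_t il_reuse = n_layer_cache - (is_swa(il) ? 2 : 1);
            map_layer_ids[il] = map_layer_ids[il_reuse];
            printf("layer %2u -> reuses cache of layer %u\n", il, il_reuse);
        }
        return 0;
    }
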
 
examples/talk-llama/llama-kv-cache-unified.h CHANGED
@@ -56,14 +56,14 @@ public:
56
  // llama_memory_i
57
  //
58
 
59
- llama_memory_state_ptr init_batch(
60
  llama_batch_allocr & balloc,
61
  uint32_t n_ubatch,
62
  bool embd_all) override;
63
 
64
- llama_memory_state_ptr init_full() override;
65
 
66
- llama_memory_state_ptr init_update(llama_context * lctx, bool optimize) override;
67
 
68
  bool get_can_shift() const override;
69
 
@@ -208,36 +208,36 @@ private:
208
  bool state_read_data(llama_io_read_i & io, uint32_t cell_count);
209
  };
210
 
211
- class llama_kv_cache_unified_state : public llama_memory_state_i {
212
  public:
213
  // some shorthands
214
  using ubatch_heads = llama_kv_cache_unified::ubatch_heads;
215
  using defrag_info = llama_kv_cache_unified::defrag_info;
216
 
217
  // used for errors
218
- llama_kv_cache_unified_state(llama_memory_status status);
219
 
220
- // used to create a full-cache state
221
- llama_kv_cache_unified_state(
222
  llama_kv_cache_unified * kv);
223
 
224
- // used to create an update state
225
- llama_kv_cache_unified_state(
226
  llama_kv_cache_unified * kv,
227
  llama_context * lctx,
228
  bool do_shift,
229
  defrag_info dinfo);
230
 
231
- // used to create a decode state from a batch
232
- llama_kv_cache_unified_state(
233
  llama_kv_cache_unified * kv,
234
  ubatch_heads heads,
235
  std::vector<llama_ubatch> ubatches);
236
 
237
- virtual ~llama_kv_cache_unified_state();
238
 
239
  //
240
- // llama_memory_state_i
241
  //
242
 
243
  bool next() override;
@@ -247,7 +247,7 @@ public:
247
  const llama_ubatch & get_ubatch() const override;
248
 
249
  //
250
- // llama_kv_cache_unified_state specific API
251
  //
252
 
253
  uint32_t get_n_kv() const;
@@ -272,7 +272,7 @@ private:
272
  llama_context * lctx;
273
 
274
  //
275
- // update state
276
  //
277
 
278
  bool do_shift = false;
@@ -280,7 +280,7 @@ private:
280
  defrag_info dinfo;
281
 
282
  //
283
- // batch processing state
284
  //
285
 
286
  // the index of the next ubatch to process
 
56
  // llama_memory_i
57
  //
58
 
59
+ llama_memory_context_ptr init_batch(
60
  llama_batch_allocr & balloc,
61
  uint32_t n_ubatch,
62
  bool embd_all) override;
63
 
64
+ llama_memory_context_ptr init_full() override;
65
 
66
+ llama_memory_context_ptr init_update(llama_context * lctx, bool optimize) override;
67
 
68
  bool get_can_shift() const override;
69
 
 
208
  bool state_read_data(llama_io_read_i & io, uint32_t cell_count);
209
  };
210
 
211
+ class llama_kv_cache_unified_context : public llama_memory_context_i {
212
  public:
213
  // some shorthands
214
  using ubatch_heads = llama_kv_cache_unified::ubatch_heads;
215
  using defrag_info = llama_kv_cache_unified::defrag_info;
216
 
217
  // used for errors
218
+ llama_kv_cache_unified_context(llama_memory_status status);
219
 
220
+ // used to create a full-cache context
221
+ llama_kv_cache_unified_context(
222
  llama_kv_cache_unified * kv);
223
 
224
+ // used to create an update context
225
+ llama_kv_cache_unified_context(
226
  llama_kv_cache_unified * kv,
227
  llama_context * lctx,
228
  bool do_shift,
229
  defrag_info dinfo);
230
 
231
+ // used to create a batch processing context from a batch
232
+ llama_kv_cache_unified_context(
233
  llama_kv_cache_unified * kv,
234
  ubatch_heads heads,
235
  std::vector<llama_ubatch> ubatches);
236
 
237
+ virtual ~llama_kv_cache_unified_context();
238
 
239
  //
240
+ // llama_memory_context_i
241
  //
242
 
243
  bool next() override;
 
247
  const llama_ubatch & get_ubatch() const override;
248
 
249
  //
250
+ // llama_kv_cache_unified_context specific API
251
  //
252
 
253
  uint32_t get_n_kv() const;
 
272
  llama_context * lctx;
273
 
274
  //
275
+ // update context
276
  //
277
 
278
  bool do_shift = false;
 
280
  defrag_info dinfo;
281
 
282
  //
283
+ // batch processing context
284
  //
285
 
286
  // the index of the next ubatch to process
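
The renamed context objects keep the same iteration contract: apply() prepares the cache for the ubatch at i_next, get_ubatch() returns it, and next() advances, returning false once every ubatch has been consumed. A toy model of that loop (hypothetical types, not the llama.cpp classes) shows the expected caller pattern:

    #include <cstdio>
    #include <vector>

    struct toy_memory_context {
        std::vector<int> ubatches; // stand-in for std::vector<llama_ubatch>
        size_t i_next = 0;

        bool apply()            { printf("prepare cache for ubatch %zu\n", i_next); return true; }
        int  get_ubatch() const { return ubatches[i_next]; }
        bool next()             { return ++i_next < ubatches.size(); } // false once exhausted
    };

    int main() {
        toy_memory_context ctx{ {10, 20, 30} };

        do {
            if (!ctx.apply()) {
                break;                 // a failed apply() aborts processing
            }
            const int ub = ctx.get_ubatch();
            printf("decode ubatch with %d tokens\n", ub);
        } while (ctx.next());

        return 0;
    }
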
examples/talk-llama/llama-kv-cells.h CHANGED
@@ -7,6 +7,7 @@
7
  #include <cassert>
8
  #include <vector>
9
  #include <set>
 
10
 
11
  // meta information about KV cells that can be part of multiple sequences at the same time
12
  // TODO: add unit tests
@@ -164,7 +165,7 @@ public:
164
  assert(seq_id >= 0);
165
 
166
  seq[i].reset(seq_id);
167
- seq_pos[seq_id].erase(pos[i]);
168
 
169
  if (seq[i].none()) {
170
  pos[i] = -1;
@@ -187,7 +188,7 @@ public:
187
  seq[i].reset();
188
 
189
  seq[i].set(seq_id);
190
- seq_pos[seq_id].insert(pos[i]);
191
 
192
  return false;
193
  }
@@ -232,7 +233,7 @@ public:
232
  assert(!seq[i].test(seq_id));
233
 
234
  seq[i].set(seq_id);
235
- seq_pos[seq_id].insert(pos[i]);
236
  }
237
 
238
  // return the sequence id of this cell
@@ -259,7 +260,9 @@ public:
259
  return -1;
260
  }
261
 
262
- return *seq_pos[seq_id].begin();
 
 
263
  }
264
 
265
  // the maximum position of sequence seq_id currently present in any of the cells
@@ -272,7 +275,9 @@ public:
272
  return -1;
273
  }
274
 
275
- return *seq_pos[seq_id].rbegin();
 
 
276
  }
277
 
278
  // note: call only if the cell is not empty
@@ -389,17 +394,36 @@ private:
389
  // the bitset seq[i] tells us which sequences are currently occupying the i-th cell
390
  std::vector<seq_set_t> seq;
391
 
392
- // the set seq_pos[s] tells us which positions are currently present for sequence s
 
393
  // this way seq_pos[s].begin() and seq_pos[s].rbegin() give us the min/max positions currently in the cache
394
- std::set<llama_pos> seq_pos[LLAMA_MAX_SEQ];
 
 
 
 
 
395
 
396
  // helper functions for updating `seq_pos`, once cell at a time:
397
 
 
 
 
 
 
 
 
 
 
 
 
 
 
398
  // remove cell i
399
  void seq_pos_rm(uint32_t i) {
400
  for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
401
  if (seq[i].test(s)) {
402
- seq_pos[s].erase(pos[i]);
403
  }
404
  }
405
  }
@@ -408,7 +432,7 @@ private:
408
  void seq_pos_add(uint32_t i) {
409
  for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
410
  if (seq[i].test(s)) {
411
- seq_pos[s].insert(pos[i]);
412
  }
413
  }
414
  }
 
7
  #include <cassert>
8
  #include <vector>
9
  #include <set>
10
+ #include <map>
11
 
12
  // meta information about KV cells that can be part of multiple sequences at the same time
13
  // TODO: add unit tests
 
165
  assert(seq_id >= 0);
166
 
167
  seq[i].reset(seq_id);
168
+ seq_pos_dec(seq_id, pos[i]);
169
 
170
  if (seq[i].none()) {
171
  pos[i] = -1;
 
188
  seq[i].reset();
189
 
190
  seq[i].set(seq_id);
191
+ seq_pos_inc(seq_id, pos[i]);
192
 
193
  return false;
194
  }
 
233
  assert(!seq[i].test(seq_id));
234
 
235
  seq[i].set(seq_id);
236
+ seq_pos_inc(seq_id, pos[i]);
237
  }
238
 
239
  // return the sequence id of this cell
 
260
  return -1;
261
  }
262
 
263
+ assert(seq_pos[seq_id].begin()->second > 0);
264
+
265
+ return seq_pos[seq_id].begin()->first;
266
  }
267
 
268
  // the maximum position of sequence seq_id currently present in any of the cells
 
275
  return -1;
276
  }
277
 
278
+ assert(seq_pos[seq_id].rbegin()->second > 0);
279
+
280
+ return seq_pos[seq_id].rbegin()->first;
281
  }
282
 
283
  // note: call only if the cell is not empty
 
394
  // the bitset seq[i] tells us which sequences are currently occupying the i-th cell
395
  std::vector<seq_set_t> seq;
396
 
397
+ // the map seq_pos[s][p] tells us how many times the position p is currently present for sequence s
398
+ // if the position p is not present, seq_pos[s][p] is not set
399
  // this way seq_pos[s].begin() and seq_pos[s].rbegin() give us the min/max positions currently in the cache
400
+ //
401
+ // note that we cannot use a std::set because in some cases a position can occur more than once for the same seq:
402
+ // - when performing a cache reuse via (rm + add)
403
+ // - some vision models have input embeddings with repeating positions
404
+ //
405
+ std::map<llama_pos, int> seq_pos[LLAMA_MAX_SEQ];
406
 
407
  // helper functions for updating `seq_pos`, once cell at a time:
408
 
409
+ void seq_pos_dec(llama_seq_id s, llama_pos p) {
410
+ auto it = seq_pos[s].find(p);
411
+ assert(it != seq_pos[s].end());
412
+
413
+ if (--it->second == 0) {
414
+ seq_pos[s].erase(it);
415
+ }
416
+ }
417
+
418
+ void seq_pos_inc(llama_seq_id s, llama_pos p) {
419
+ seq_pos[s][p]++;
420
+ }
421
+
422
  // remove cell i
423
  void seq_pos_rm(uint32_t i) {
424
  for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
425
  if (seq[i].test(s)) {
426
+ seq_pos_dec(s, pos[i]);
427
  }
428
  }
429
  }
 
432
  void seq_pos_add(uint32_t i) {
433
  for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
434
  if (seq[i].test(s)) {
435
+ seq_pos_inc(s, pos[i]);
436
  }
437
  }
438
  }
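
The key change in llama-kv-cells.h is that the per-sequence position index became a reference-counted std::map<llama_pos, int> instead of a std::set, so the same position can be held by several cells (cache reuse via rm + add, or vision inputs with repeated positions) without the min/max queries breaking. A small standalone sketch of that bookkeeping:

    #include <cassert>
    #include <cstdio>
    #include <map>

    int main() {
        std::map<int, int> seq_pos; // position -> number of cells currently holding it

        auto inc = [&](int p) { seq_pos[p]++; };
        auto dec = [&](int p) {
            auto it = seq_pos.find(p);
            assert(it != seq_pos.end());
            if (--it->second == 0) {
                seq_pos.erase(it); // only drop the position once no cell holds it
            }
        };

        inc(5); inc(5); inc(7); // position 5 is held by two cells, e.g. during rm + add reuse
        dec(5);                 // removing one copy must not make position 5 disappear

        printf("min = %d, max = %d\n", seq_pos.begin()->first, seq_pos.rbegin()->first); // 5 and 7
        return 0;
    }
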
examples/talk-llama/llama-memory-hybrid.cpp CHANGED
@@ -56,7 +56,7 @@ llama_memory_hybrid::llama_memory_hybrid(
56
  n_seq_max
57
  )) {}
58
 
59
- llama_memory_state_ptr llama_memory_hybrid::init_batch(llama_batch_allocr & balloc, uint32_t n_ubatch, bool embd_all) {
60
  do {
61
  balloc.split_reset();
62
 
@@ -82,31 +82,31 @@ llama_memory_state_ptr llama_memory_hybrid::init_batch(llama_batch_allocr & ball
82
 
83
  // prepare the recurrent batches first
84
  if (!mem_recr->prepare(ubatches)) {
85
- // TODO: will the recurrent cache be in an undefined state at this point?
86
  LLAMA_LOG_ERROR("%s: failed to prepare recurrent ubatches\n", __func__);
87
- return std::make_unique<llama_memory_hybrid_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
88
  }
89
 
90
  // prepare the attention cache
91
  auto heads_attn = mem_attn->prepare(ubatches);
92
  if (heads_attn.empty()) {
93
  LLAMA_LOG_ERROR("%s: failed to prepare attention ubatches\n", __func__);
94
- return std::make_unique<llama_memory_hybrid_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
95
  }
96
 
97
- return std::make_unique<llama_memory_hybrid_state>(
98
  this, std::move(heads_attn), std::move(ubatches));
99
  } while(false);
100
 
101
- return std::make_unique<llama_memory_hybrid_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
102
  }
103
 
104
- llama_memory_state_ptr llama_memory_hybrid::init_full() {
105
- return std::make_unique<llama_memory_hybrid_state>(this);
106
  }
107
 
108
- llama_memory_state_ptr llama_memory_hybrid::init_update(llama_context * lctx, bool optimize) {
109
- return std::make_unique<llama_memory_hybrid_state>(this, lctx, optimize);
110
  }
111
 
112
  bool llama_memory_hybrid::get_can_shift() const {
@@ -176,39 +176,39 @@ llama_memory_recurrent * llama_memory_hybrid::get_mem_recr() const {
176
  return mem_recr.get();
177
  }
178
 
179
- llama_memory_hybrid_state::llama_memory_hybrid_state(llama_memory_status status) : status(status) {}
180
 
181
- llama_memory_hybrid_state::llama_memory_hybrid_state(llama_memory_hybrid * mem) :
182
- state_attn(mem->get_mem_attn()->init_full()),
183
- state_recr(mem->get_mem_recr()->init_full()),
184
- status(llama_memory_status_combine(state_attn->get_status(), state_recr->get_status())) {
185
  }
186
 
187
- llama_memory_hybrid_state::llama_memory_hybrid_state(
188
  llama_memory_hybrid * mem,
189
  llama_context * lctx,
190
  bool optimize) :
191
- state_attn(mem->get_mem_attn()->init_update(lctx, optimize)),
192
- state_recr(mem->get_mem_recr()->init_update(lctx, optimize)),
193
- status(llama_memory_status_combine(state_attn->get_status(), state_recr->get_status())) {
194
  }
195
 
196
- llama_memory_hybrid_state::llama_memory_hybrid_state(
197
  llama_memory_hybrid * mem,
198
  std::vector<uint32_t> heads_attn,
199
  std::vector<llama_ubatch> ubatches) :
200
  ubatches(std::move(ubatches)),
201
  // note: here we copy the ubatches. not sure if this is ideal
202
- state_attn(new llama_kv_cache_unified_state(mem->get_mem_attn(), std::move(heads_attn), this->ubatches)),
203
- state_recr(new llama_memory_recurrent_state(mem->get_mem_recr(), this->ubatches)),
204
- status(llama_memory_status_combine(state_attn->get_status(), state_recr->get_status())) {
205
  }
206
 
207
- bool llama_memory_hybrid_state::next() {
208
  assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
209
 
210
- state_attn->next();
211
- state_recr->next();
212
 
213
  if (++i_next >= ubatches.size()) {
214
  return false;
@@ -217,30 +217,30 @@ bool llama_memory_hybrid_state::next() {
217
  return true;
218
  }
219
 
220
- bool llama_memory_hybrid_state::apply() {
221
- assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
222
 
223
  bool res = true;
224
 
225
- res = res & state_attn->apply();
226
- res = res & state_recr->apply();
227
 
228
  return res;
229
  }
230
 
231
- llama_memory_status llama_memory_hybrid_state::get_status() const {
232
  return status;
233
  }
234
 
235
- const llama_ubatch & llama_memory_hybrid_state::get_ubatch() const {
236
  assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
237
  return ubatches[i_next];
238
  }
239
 
240
- const llama_kv_cache_unified_state * llama_memory_hybrid_state::get_state_attn() const {
241
- return static_cast<const llama_kv_cache_unified_state *>(state_attn.get());
242
  }
243
 
244
- const llama_memory_recurrent_state * llama_memory_hybrid_state::get_state_recr() const {
245
- return static_cast<const llama_memory_recurrent_state *>(state_recr.get());
246
  }
 
56
  n_seq_max
57
  )) {}
58
 
59
+ llama_memory_context_ptr llama_memory_hybrid::init_batch(llama_batch_allocr & balloc, uint32_t n_ubatch, bool embd_all) {
60
  do {
61
  balloc.split_reset();
62
 
 
82
 
83
  // prepare the recurrent batches first
84
  if (!mem_recr->prepare(ubatches)) {
85
+ // TODO: will the recurrent cache be in an undefined context at this point?
86
  LLAMA_LOG_ERROR("%s: failed to prepare recurrent ubatches\n", __func__);
87
+ return std::make_unique<llama_memory_hybrid_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
88
  }
89
 
90
  // prepare the attention cache
91
  auto heads_attn = mem_attn->prepare(ubatches);
92
  if (heads_attn.empty()) {
93
  LLAMA_LOG_ERROR("%s: failed to prepare attention ubatches\n", __func__);
94
+ return std::make_unique<llama_memory_hybrid_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
95
  }
96
 
97
+ return std::make_unique<llama_memory_hybrid_context>(
98
  this, std::move(heads_attn), std::move(ubatches));
99
  } while(false);
100
 
101
+ return std::make_unique<llama_memory_hybrid_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
102
  }
103
 
104
+ llama_memory_context_ptr llama_memory_hybrid::init_full() {
105
+ return std::make_unique<llama_memory_hybrid_context>(this);
106
  }
107
 
108
+ llama_memory_context_ptr llama_memory_hybrid::init_update(llama_context * lctx, bool optimize) {
109
+ return std::make_unique<llama_memory_hybrid_context>(this, lctx, optimize);
110
  }
111
 
112
  bool llama_memory_hybrid::get_can_shift() const {
 
176
  return mem_recr.get();
177
  }
178
 
179
+ llama_memory_hybrid_context::llama_memory_hybrid_context(llama_memory_status status) : status(status) {}
180
 
181
+ llama_memory_hybrid_context::llama_memory_hybrid_context(llama_memory_hybrid * mem) :
182
+ ctx_attn(mem->get_mem_attn()->init_full()),
183
+ ctx_recr(mem->get_mem_recr()->init_full()),
184
+ status(llama_memory_status_combine(ctx_attn->get_status(), ctx_recr->get_status())) {
185
  }
186
 
187
+ llama_memory_hybrid_context::llama_memory_hybrid_context(
188
  llama_memory_hybrid * mem,
189
  llama_context * lctx,
190
  bool optimize) :
191
+ ctx_attn(mem->get_mem_attn()->init_update(lctx, optimize)),
192
+ ctx_recr(mem->get_mem_recr()->init_update(lctx, optimize)),
193
+ status(llama_memory_status_combine(ctx_attn->get_status(), ctx_recr->get_status())) {
194
  }
195
 
196
+ llama_memory_hybrid_context::llama_memory_hybrid_context(
197
  llama_memory_hybrid * mem,
198
  std::vector<uint32_t> heads_attn,
199
  std::vector<llama_ubatch> ubatches) :
200
  ubatches(std::move(ubatches)),
201
  // note: here we copy the ubatches. not sure if this is ideal
202
+ ctx_attn(new llama_kv_cache_unified_context(mem->get_mem_attn(), std::move(heads_attn), this->ubatches)),
203
+ ctx_recr(new llama_memory_recurrent_context(mem->get_mem_recr(), this->ubatches)),
204
+ status(llama_memory_status_combine(ctx_attn->get_status(), ctx_recr->get_status())) {
205
  }
206
 
207
+ bool llama_memory_hybrid_context::next() {
208
  assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
209
 
210
+ ctx_attn->next();
211
+ ctx_recr->next();
212
 
213
  if (++i_next >= ubatches.size()) {
214
  return false;
 
217
  return true;
218
  }
219
 
220
+ bool llama_memory_hybrid_context::apply() {
221
+ assert(!llama_memory_status_is_fail(status));
222
 
223
  bool res = true;
224
 
225
+ res = res & ctx_attn->apply();
226
+ res = res & ctx_recr->apply();
227
 
228
  return res;
229
  }
230
 
231
+ llama_memory_status llama_memory_hybrid_context::get_status() const {
232
  return status;
233
  }
234
 
235
+ const llama_ubatch & llama_memory_hybrid_context::get_ubatch() const {
236
  assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
237
  return ubatches[i_next];
238
  }
239
 
240
+ const llama_kv_cache_unified_context * llama_memory_hybrid_context::get_attn() const {
241
+ return static_cast<const llama_kv_cache_unified_context *>(ctx_attn.get());
242
  }
243
 
244
+ const llama_memory_recurrent_context * llama_memory_hybrid_context::get_recr() const {
245
+ return static_cast<const llama_memory_recurrent_context *>(ctx_recr.get());
246
  }
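
The hybrid context drives its attention and recurrent children in lockstep: next() advances both plus the shared ubatch index, and apply() ANDs their results. A toy delegation sketch (again with made-up types) of that pattern:

    #include <cstdio>
    #include <vector>

    struct toy_child {
        const char * name;
        size_t i = 0;
        bool apply() { printf("%s: apply ubatch %zu\n", name, i); return true; }
        void next()  { ++i; }
    };

    struct toy_hybrid {
        std::vector<int> ubatches;
        size_t i_next = 0;
        toy_child attn{"attn"};
        toy_child recr{"recr"};

        bool apply() {
            bool res = true;
            res = res & attn.apply(); // both children must succeed
            res = res & recr.apply();
            return res;
        }

        bool next() {
            attn.next();
            recr.next();
            return ++i_next < ubatches.size();
        }
    };

    int main() {
        toy_hybrid ctx{ {1, 2} };
        do {
            ctx.apply();
        } while (ctx.next());
        return 0;
    }
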
examples/talk-llama/llama-memory-hybrid.h CHANGED
@@ -49,14 +49,14 @@ public:
49
  // llama_memory_i
50
  //
51
 
52
- llama_memory_state_ptr init_batch(
53
  llama_batch_allocr & balloc,
54
  uint32_t n_ubatch,
55
  bool embd_all) override;
56
 
57
- llama_memory_state_ptr init_full() override;
58
 
59
- llama_memory_state_ptr init_update(llama_context * lctx, bool optimize) override;
60
 
61
  bool get_can_shift() const override;
62
 
@@ -90,27 +90,27 @@ private:
90
  const std::unique_ptr<llama_memory_recurrent> mem_recr;
91
  };
92
 
93
- class llama_memory_hybrid_state : public llama_memory_state_i {
94
  public:
95
  // init failure
96
- explicit llama_memory_hybrid_state(llama_memory_status status);
97
 
98
  // init full
99
- explicit llama_memory_hybrid_state(llama_memory_hybrid * mem);
100
 
101
  // init update
102
- explicit llama_memory_hybrid_state(
103
  llama_memory_hybrid * mem,
104
  llama_context * lctx,
105
  bool optimize);
106
 
107
  // init success
108
- llama_memory_hybrid_state(
109
  llama_memory_hybrid * mem,
110
  std::vector<uint32_t> heads_attn,
111
  std::vector<llama_ubatch> ubatches);
112
 
113
- ~llama_memory_hybrid_state() = default;
114
 
115
  bool next() override;
116
  bool apply() override;
@@ -119,11 +119,11 @@ public:
119
  const llama_ubatch & get_ubatch() const override;
120
 
121
  //
122
- // llama_memory_hybrid_state
123
  //
124
 
125
- const llama_kv_cache_unified_state * get_state_attn() const;
126
- const llama_memory_recurrent_state * get_state_recr() const;
127
 
128
  private:
129
  // the index of the next ubatch to process
@@ -131,8 +131,8 @@ private:
131
 
132
  std::vector<llama_ubatch> ubatches;
133
 
134
- const llama_memory_state_ptr state_attn;
135
- const llama_memory_state_ptr state_recr;
136
 
137
  const llama_memory_status status;
138
  };
 
49
  // llama_memory_i
50
  //
51
 
52
+ llama_memory_context_ptr init_batch(
53
  llama_batch_allocr & balloc,
54
  uint32_t n_ubatch,
55
  bool embd_all) override;
56
 
57
+ llama_memory_context_ptr init_full() override;
58
 
59
+ llama_memory_context_ptr init_update(llama_context * lctx, bool optimize) override;
60
 
61
  bool get_can_shift() const override;
62
 
 
90
  const std::unique_ptr<llama_memory_recurrent> mem_recr;
91
  };
92
 
93
+ class llama_memory_hybrid_context : public llama_memory_context_i {
94
  public:
95
  // init failure
96
+ explicit llama_memory_hybrid_context(llama_memory_status status);
97
 
98
  // init full
99
+ explicit llama_memory_hybrid_context(llama_memory_hybrid * mem);
100
 
101
  // init update
102
+ explicit llama_memory_hybrid_context(
103
  llama_memory_hybrid * mem,
104
  llama_context * lctx,
105
  bool optimize);
106
 
107
  // init success
108
+ llama_memory_hybrid_context(
109
  llama_memory_hybrid * mem,
110
  std::vector<uint32_t> heads_attn,
111
  std::vector<llama_ubatch> ubatches);
112
 
113
+ ~llama_memory_hybrid_context() = default;
114
 
115
  bool next() override;
116
  bool apply() override;
 
119
  const llama_ubatch & get_ubatch() const override;
120
 
121
  //
122
+ // llama_memory_hybrid_context
123
  //
124
 
125
+ const llama_kv_cache_unified_context * get_attn() const;
126
+ const llama_memory_recurrent_context * get_recr() const;
127
 
128
  private:
129
  // the index of the next ubatch to process
 
131
 
132
  std::vector<llama_ubatch> ubatches;
133
 
134
+ const llama_memory_context_ptr ctx_attn;
135
+ const llama_memory_context_ptr ctx_recr;
136
 
137
  const llama_memory_status status;
138
  };
examples/talk-llama/llama-memory-recurrent.cpp CHANGED
@@ -362,42 +362,47 @@ llama_pos llama_memory_recurrent::seq_pos_max(llama_seq_id seq_id) const {
362
  return result;
363
  }
364
 
365
- llama_memory_state_ptr llama_memory_recurrent::init_batch(llama_batch_allocr & balloc, uint32_t n_ubatch, bool embd_all) {
366
- std::vector<llama_ubatch> ubatches;
 
367
 
368
- while (true) {
369
- llama_ubatch ubatch;
 
370
 
371
- if (embd_all) {
372
- // if all tokens are output, split by sequence
373
- ubatch = balloc.split_seq(n_ubatch);
374
- } else {
375
- ubatch = balloc.split_equal(n_ubatch);
 
 
 
 
 
 
 
376
  }
377
 
378
- if (ubatch.n_tokens == 0) {
379
  break;
380
  }
381
 
382
- ubatches.push_back(std::move(ubatch)); // NOLINT
383
- }
384
-
385
- if (!prepare(ubatches)) {
386
- return std::make_unique<llama_memory_recurrent_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
387
- }
388
 
389
- return std::make_unique<llama_memory_recurrent_state>(this, std::move(ubatches));
390
  }
391
 
392
- llama_memory_state_ptr llama_memory_recurrent::init_full() {
393
- return std::make_unique<llama_memory_recurrent_state>(this);
394
  }
395
 
396
- llama_memory_state_ptr llama_memory_recurrent::init_update(llama_context * lctx, bool optimize) {
397
  GGML_UNUSED(lctx);
398
  GGML_UNUSED(optimize);
399
 
400
- return std::make_unique<llama_memory_recurrent_state>(LLAMA_MEMORY_STATUS_NO_UPDATE);
401
  }
402
 
403
  bool llama_memory_recurrent::prepare(const std::vector<llama_ubatch> & ubatches) {
@@ -1040,22 +1045,22 @@ bool llama_memory_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell
1040
  }
1041
 
1042
  //
1043
- // llama_memory_recurrent_state
1044
  //
1045
 
1046
- llama_memory_recurrent_state::llama_memory_recurrent_state(llama_memory_status status) : status(status) {}
1047
 
1048
- llama_memory_recurrent_state::llama_memory_recurrent_state(
1049
  llama_memory_recurrent * mem) : status(LLAMA_MEMORY_STATUS_SUCCESS), mem(mem), is_full(true) {
1050
  }
1051
 
1052
- llama_memory_recurrent_state::llama_memory_recurrent_state(
1053
  llama_memory_recurrent * mem,
1054
  std::vector<llama_ubatch> ubatches) : status(LLAMA_MEMORY_STATUS_SUCCESS), mem(mem), ubatches(std::move(ubatches)) {}
1055
 
1056
- llama_memory_recurrent_state::~llama_memory_recurrent_state() = default;
1057
 
1058
- bool llama_memory_recurrent_state::next() {
1059
  assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
1060
 
1061
  if (++i_next >= ubatches.size()) {
@@ -1065,48 +1070,56 @@ bool llama_memory_recurrent_state::next() {
1065
  return true;
1066
  }
1067
 
1068
- bool llama_memory_recurrent_state::apply() {
1069
- assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
1070
 
1071
  mem->find_slot(ubatches[i_next]);
1072
 
1073
  return true;
1074
  }
1075
 
1076
- llama_memory_status llama_memory_recurrent_state::get_status() const {
1077
  return status;
1078
  }
1079
 
1080
- const llama_ubatch & llama_memory_recurrent_state::get_ubatch() const {
1081
  assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
1082
 
1083
  return ubatches[i_next];
1084
  }
1085
 
1086
- uint32_t llama_memory_recurrent_state::get_n_rs() const {
1087
  return is_full ? mem->size : mem->n;
1088
  }
1089
 
1090
- uint32_t llama_memory_recurrent_state::get_head() const {
1091
  return is_full ? 0 : mem->head;
1092
  }
1093
 
1094
- int32_t llama_memory_recurrent_state::get_rs_z() const {
1095
  return is_full ? 0 : mem->rs_z;
1096
  }
1097
 
1098
- uint32_t llama_memory_recurrent_state::get_size() const {
1099
  return mem->size;
1100
  }
1101
 
1102
- ggml_tensor * llama_memory_recurrent_state::get_r_l(int32_t il) const {
1103
  return mem->r_l[il];
1104
  }
1105
 
1106
- ggml_tensor * llama_memory_recurrent_state::get_s_l(int32_t il) const {
1107
  return mem->s_l[il];
1108
  }
1109
 
1110
- int32_t llama_memory_recurrent_state::s_copy(int i) const {
1111
  return mem->cells[i + mem->head].src0;
1112
  }
 
362
  return result;
363
  }
364
 
365
+ llama_memory_context_ptr llama_memory_recurrent::init_batch(llama_batch_allocr & balloc, uint32_t n_ubatch, bool embd_all) {
366
+ do {
367
+ balloc.split_reset();
368
 
369
+ std::vector<llama_ubatch> ubatches;
370
+ while (true) {
371
+ llama_ubatch ubatch;
372
 
373
+ if (embd_all) {
374
+ // if all tokens are output, split by sequence
375
+ ubatch = balloc.split_seq(n_ubatch);
376
+ } else {
377
+ ubatch = balloc.split_equal(n_ubatch);
378
+ }
379
+
380
+ if (ubatch.n_tokens == 0) {
381
+ break;
382
+ }
383
+
384
+ ubatches.push_back(std::move(ubatch)); // NOLINT
385
  }
386
 
387
+ if (!prepare(ubatches)) {
388
  break;
389
  }
390
 
391
+ return std::make_unique<llama_memory_recurrent_context>(this, std::move(ubatches));
392
+ } while (false);
393
 
394
+ return std::make_unique<llama_memory_recurrent_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
395
  }
396
 
397
+ llama_memory_context_ptr llama_memory_recurrent::init_full() {
398
+ return std::make_unique<llama_memory_recurrent_context>(this);
399
  }
400
 
401
+ llama_memory_context_ptr llama_memory_recurrent::init_update(llama_context * lctx, bool optimize) {
402
  GGML_UNUSED(lctx);
403
  GGML_UNUSED(optimize);
404
 
405
+ return std::make_unique<llama_memory_recurrent_context>(LLAMA_MEMORY_STATUS_NO_UPDATE);
406
  }
407
 
408
  bool llama_memory_recurrent::prepare(const std::vector<llama_ubatch> & ubatches) {
 
1045
  }
1046
 
1047
  //
1048
+ // llama_memory_recurrent_context
1049
  //
1050
 
1051
+ llama_memory_recurrent_context::llama_memory_recurrent_context(llama_memory_status status) : status(status) {}
1052
 
1053
+ llama_memory_recurrent_context::llama_memory_recurrent_context(
1054
  llama_memory_recurrent * mem) : status(LLAMA_MEMORY_STATUS_SUCCESS), mem(mem), is_full(true) {
1055
  }
1056
 
1057
+ llama_memory_recurrent_context::llama_memory_recurrent_context(
1058
  llama_memory_recurrent * mem,
1059
  std::vector<llama_ubatch> ubatches) : status(LLAMA_MEMORY_STATUS_SUCCESS), mem(mem), ubatches(std::move(ubatches)) {}
1060
 
1061
+ llama_memory_recurrent_context::~llama_memory_recurrent_context() = default;
1062
 
1063
+ bool llama_memory_recurrent_context::next() {
1064
  assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
1065
 
1066
  if (++i_next >= ubatches.size()) {
 
1070
  return true;
1071
  }
1072
 
1073
+ bool llama_memory_recurrent_context::apply() {
1074
+ assert(!llama_memory_status_is_fail(status));
1075
+
1076
+ // no ubatches -> this is an update
1077
+ if (ubatches.empty()) {
1078
+ // recurrent cache never performs updates
1079
+ assert(status == LLAMA_MEMORY_STATUS_NO_UPDATE);
1080
+
1081
+ return true;
1082
+ }
1083
 
1084
  mem->find_slot(ubatches[i_next]);
1085
 
1086
  return true;
1087
  }
1088
 
1089
+ llama_memory_status llama_memory_recurrent_context::get_status() const {
1090
  return status;
1091
  }
1092
 
1093
+ const llama_ubatch & llama_memory_recurrent_context::get_ubatch() const {
1094
  assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
1095
 
1096
  return ubatches[i_next];
1097
  }
1098
 
1099
+ uint32_t llama_memory_recurrent_context::get_n_rs() const {
1100
  return is_full ? mem->size : mem->n;
1101
  }
1102
 
1103
+ uint32_t llama_memory_recurrent_context::get_head() const {
1104
  return is_full ? 0 : mem->head;
1105
  }
1106
 
1107
+ int32_t llama_memory_recurrent_context::get_rs_z() const {
1108
  return is_full ? 0 : mem->rs_z;
1109
  }
1110
 
1111
+ uint32_t llama_memory_recurrent_context::get_size() const {
1112
  return mem->size;
1113
  }
1114
 
1115
+ ggml_tensor * llama_memory_recurrent_context::get_r_l(int32_t il) const {
1116
  return mem->r_l[il];
1117
  }
1118
 
1119
+ ggml_tensor * llama_memory_recurrent_context::get_s_l(int32_t il) const {
1120
  return mem->s_l[il];
1121
  }
1122
 
1123
+ int32_t llama_memory_recurrent_context::s_copy(int i) const {
1124
  return mem->cells[i + mem->head].src0;
1125
  }
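The rewritten init_batch() above wraps the ubatch splitting in a single-pass do { ... } while (false) block so that every failure path can break to one shared FAILED_PREPARE return instead of duplicating it. A small standalone sketch of the same early-exit idiom, with illustrative names only (not llama.cpp code):

    #include <cstdio>
    #include <memory>
    #include <vector>

    struct result { bool ok; }; // toy stand-in for the real context object

    static std::unique_ptr<result> make_batches(int n_tokens, int n_ubatch) {
        do {
            if (n_ubatch <= 0) {
                break; // invalid split size -> fall through to the shared error return
            }

            std::vector<int> ubatches;
            for (int i = 0; i < n_tokens; i += n_ubatch) {
                ubatches.push_back(i);
            }

            if (ubatches.empty()) {
                break; // nothing to process -> shared error return
            }

            return std::make_unique<result>(result{true}); // success path returns directly
        } while (false);

        // the single place where all failure paths end up
        return std::make_unique<result>(result{false});
    }

    int main() {
        std::printf("ok=%d\n", (int) make_batches(10, 4)->ok);
        std::printf("ok=%d\n", (int) make_batches(10, 0)->ok);
    }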
examples/talk-llama/llama-memory-recurrent.h CHANGED
@@ -11,8 +11,8 @@
11
  // llama_memory_recurrent
12
  //
13
 
14
- // TODO: extract the cache state used for graph computation into llama_memory_recurrent_state_i
15
- // see the implementation of llama_kv_cache_unified_state_i for an example how to do it
16
  class llama_memory_recurrent : public llama_memory_i {
17
  public:
18
 
@@ -34,14 +34,14 @@ public:
34
  // llama_memory_i
35
  //
36
 
37
- llama_memory_state_ptr init_batch(
38
  llama_batch_allocr & balloc,
39
  uint32_t n_ubatch,
40
  bool embd_all) override;
41
 
42
- llama_memory_state_ptr init_full() override;
43
 
44
- llama_memory_state_ptr init_update(llama_context * lctx, bool optimize) override;
45
 
46
  void clear(bool data) override;
47
 
@@ -125,24 +125,24 @@ private:
125
  bool state_read_data(llama_io_read_i & io, uint32_t cell_count);
126
  };
127
 
128
- class llama_memory_recurrent_state : public llama_memory_state_i {
129
  public:
130
  // used for errors
131
- llama_memory_recurrent_state(llama_memory_status status);
132
 
133
- // used to create a full-cache state
134
- llama_memory_recurrent_state(
135
  llama_memory_recurrent * mem);
136
 
137
- // used to create a state from a batch
138
- llama_memory_recurrent_state(
139
  llama_memory_recurrent * mem,
140
  std::vector<llama_ubatch> ubatches);
141
 
142
- virtual ~llama_memory_recurrent_state();
143
 
144
  //
145
- // llama_memory_state_i
146
  //
147
 
148
  bool next() override;
@@ -152,7 +152,7 @@ public:
152
  const llama_ubatch & get_ubatch() const override;
153
 
154
  //
155
- // llama_memory_recurrent_state specific API
156
  //
157
 
158
  uint32_t get_n_rs() const;
 
11
  // llama_memory_recurrent
12
  //
13
 
14
+ // TODO: extract the cache state used for graph computation into llama_memory_recurrent_context_i
15
+ // see the implementation of llama_kv_cache_unified_context_i for an example of how to do it
16
  class llama_memory_recurrent : public llama_memory_i {
17
  public:
18
 
 
34
  // llama_memory_i
35
  //
36
 
37
+ llama_memory_context_ptr init_batch(
38
  llama_batch_allocr & balloc,
39
  uint32_t n_ubatch,
40
  bool embd_all) override;
41
 
42
+ llama_memory_context_ptr init_full() override;
43
 
44
+ llama_memory_context_ptr init_update(llama_context * lctx, bool optimize) override;
45
 
46
  void clear(bool data) override;
47
 
 
125
  bool state_read_data(llama_io_read_i & io, uint32_t cell_count);
126
  };
127
 
128
+ class llama_memory_recurrent_context : public llama_memory_context_i {
129
  public:
130
  // used for errors
131
+ llama_memory_recurrent_context(llama_memory_status status);
132
 
133
+ // used to create a full-cache or update context
134
+ llama_memory_recurrent_context(
135
  llama_memory_recurrent * mem);
136
 
137
+ // used to create a context for processing a batch
138
+ llama_memory_recurrent_context(
139
  llama_memory_recurrent * mem,
140
  std::vector<llama_ubatch> ubatches);
141
 
142
+ virtual ~llama_memory_recurrent_context();
143
 
144
  //
145
+ // llama_memory_context_i
146
  //
147
 
148
  bool next() override;
 
152
  const llama_ubatch & get_ubatch() const override;
153
 
154
  //
155
+ // llama_memory_recurrent_context specific API
156
  //
157
 
158
  uint32_t get_n_rs() const;
examples/talk-llama/llama-memory.cpp CHANGED
@@ -40,3 +40,20 @@ llama_memory_status llama_memory_status_combine(llama_memory_status s0, llama_me
40
  // if either status has an update, then the combined status has an update
41
  return has_update ? LLAMA_MEMORY_STATUS_SUCCESS : LLAMA_MEMORY_STATUS_NO_UPDATE;
42
  }
40
  // if either status has an update, then the combined status has an update
41
  return has_update ? LLAMA_MEMORY_STATUS_SUCCESS : LLAMA_MEMORY_STATUS_NO_UPDATE;
42
  }
43
+
44
+ bool llama_memory_status_is_fail(llama_memory_status status) {
45
+ switch (status) {
46
+ case LLAMA_MEMORY_STATUS_SUCCESS:
47
+ case LLAMA_MEMORY_STATUS_NO_UPDATE:
48
+ {
49
+ return false;
50
+ }
51
+ case LLAMA_MEMORY_STATUS_FAILED_PREPARE:
52
+ case LLAMA_MEMORY_STATUS_FAILED_COMPUTE:
53
+ {
54
+ return true;
55
+ }
56
+ }
57
+
58
+ return false;
59
+ }
examples/talk-llama/llama-memory.h CHANGED
@@ -3,7 +3,6 @@
3
  #include "llama.h"
4
 
5
  #include <memory>
6
- #include <vector>
7
 
8
  struct llama_ubatch;
9
 
@@ -28,23 +27,24 @@ enum llama_memory_status {
28
  LLAMA_MEMORY_STATUS_FAILED_COMPUTE,
29
  };
30
 
31
- // helper function for combining the status of two memory states
32
  // useful for implementing hybrid memory types (e.g. iSWA)
33
  llama_memory_status llama_memory_status_combine(llama_memory_status s0, llama_memory_status s1);
34
 
35
- // the interface for managing the memory state during batch processing
36
  // this interface is implemented per memory type. see:
37
- // - llama_kv_cache_unified_state
38
- // - llama_kv_cache_unified_iswa_state
39
  // ...
40
  //
41
- // the only method that can mutate the memory and the memory state is llama_memory_i::apply()
42
- //
43
- // TODO: rename to llama_memory_context_i ?
44
- struct llama_memory_state_i {
45
- virtual ~llama_memory_state_i() = default;
46
 
47
- // consume the current ubatch from the state and proceed to the next one
48
  // return false if we are done
49
  virtual bool next() = 0;
50
 
@@ -55,11 +55,11 @@ struct llama_memory_state_i {
55
  // get the current ubatch
56
  virtual const llama_ubatch & get_ubatch() const = 0;
57
 
58
- // get the status of the memory state - used for error handling and checking if any updates would be applied
59
  virtual llama_memory_status get_status() const = 0;
60
  };
61
 
62
- using llama_memory_state_ptr = std::unique_ptr<llama_memory_state_i>;
63
 
64
  // general concept of LLM memory
65
  // the KV cache is a type of LLM memory, but there can be other types
@@ -67,19 +67,19 @@ struct llama_memory_i {
67
  virtual ~llama_memory_i() = default;
68
 
69
  // split the input batch into a set of ubatches and verify that they can fit into the cache
70
- // return a state object containing the ubatches and KV cache state required to process them
71
- // check the llama_memory_state_i::get_status() for the result
72
- virtual llama_memory_state_ptr init_batch(
73
  llama_batch_allocr & balloc,
74
  uint32_t n_ubatch,
75
  bool embd_all) = 0;
76
 
77
  // simulate full cache, used for allocating worst-case compute buffers
78
- virtual llama_memory_state_ptr init_full() = 0;
79
 
80
  // prepare for any pending memory updates, such as shifts, defrags, etc.
81
  // status == LLAMA_MEMORY_STATUS_NO_UPDATE if there is nothing to update
82
- virtual llama_memory_state_ptr init_update(llama_context * lctx, bool optimize) = 0;
83
 
84
  // getters
85
  virtual bool get_can_shift() const = 0;
 
3
  #include "llama.h"
4
 
5
  #include <memory>
 
6
 
7
  struct llama_ubatch;
8
 
 
27
  LLAMA_MEMORY_STATUS_FAILED_COMPUTE,
28
  };
29
 
30
+ // helper function for combining the status of two memory contexts
31
  // useful for implementing hybrid memory types (e.g. iSWA)
32
  llama_memory_status llama_memory_status_combine(llama_memory_status s0, llama_memory_status s1);
33
 
34
+ // helper function for checking if a memory status indicates a failure
35
+ bool llama_memory_status_is_fail(llama_memory_status status);
36
+
37
+ // the interface for managing the memory context during batch processing
38
  // this interface is implemented per memory type. see:
39
+ // - llama_kv_cache_unified_context
40
+ // - llama_kv_cache_unified_iswa_context
41
  // ...
42
  //
43
+ // the only method that should mutate the memory and the memory context is llama_memory_context_i::apply()
44
+ struct llama_memory_context_i {
45
+ virtual ~llama_memory_context_i() = default;
 
 
46
 
47
+ // consume the current ubatch from the context and proceed to the next one
48
  // return false if we are done
49
  virtual bool next() = 0;
50
 
 
55
  // get the current ubatch
56
  virtual const llama_ubatch & get_ubatch() const = 0;
57
 
58
+ // get the status of the memory context - used for error handling and checking if any updates would be applied
59
  virtual llama_memory_status get_status() const = 0;
60
  };
61
 
62
+ using llama_memory_context_ptr = std::unique_ptr<llama_memory_context_i>;
63
 
64
  // general concept of LLM memory
65
  // the KV cache is a type of LLM memory, but there can be other types
 
67
  virtual ~llama_memory_i() = default;
68
 
69
  // split the input batch into a set of ubatches and verify that they can fit into the cache
70
+ // return a context object containing the ubatches and memory state required to process them
71
+ // check the llama_memory_context_i::get_status() for the result
72
+ virtual llama_memory_context_ptr init_batch(
73
  llama_batch_allocr & balloc,
74
  uint32_t n_ubatch,
75
  bool embd_all) = 0;
76
 
77
  // simulate full cache, used for allocating worst-case compute buffers
78
+ virtual llama_memory_context_ptr init_full() = 0;
79
 
80
  // prepare for any pending memory updates, such as shifts, defrags, etc.
81
  // status == LLAMA_MEMORY_STATUS_NO_UPDATE if there is nothing to update
82
+ virtual llama_memory_context_ptr init_update(llama_context * lctx, bool optimize) = 0;
83
 
84
  // getters
85
  virtual bool get_can_shift() const = 0;
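To make the renamed interface concrete, this is roughly how a caller is expected to drive a memory context during decoding: create it with init_batch(), bail out on a failing status, then alternate apply() and next() over the prepared ubatches. The sketch below is hypothetical (the function and variable names are not from llama.cpp) and only uses the declarations above plus llama_memory_status_is_fail():

    #include "llama-memory.h" // the header shown above

    // sketch only: a caller-side loop over the ubatches prepared by init_batch()
    static bool process_batch(llama_memory_i & mem, llama_batch_allocr & balloc, uint32_t n_ubatch, bool embd_all) {
        llama_memory_context_ptr mctx = mem.init_batch(balloc, n_ubatch, embd_all);

        if (llama_memory_status_is_fail(mctx->get_status())) {
            return false; // the batch could not be split / does not fit
        }

        do {
            if (!mctx->apply()) { // the only call that mutates the memory
                return false;
            }
            const llama_ubatch & ub = mctx->get_ubatch();
            (void) ub; // ... build and evaluate the compute graph for this ubatch ...
        } while (mctx->next()); // advance until all ubatches are consumed

        return true;
    }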
examples/talk-llama/llama-model.cpp CHANGED
@@ -47,6 +47,7 @@ const char * llm_type_name(llm_type type) {
47
  case LLM_TYPE_475M: return "475M";
48
  case LLM_TYPE_770M: return "770M";
49
  case LLM_TYPE_780M: return "780M";
 
50
  case LLM_TYPE_0_5B: return "0.5B";
51
  case LLM_TYPE_0_6B: return "0.6B";
52
  case LLM_TYPE_1B: return "1B";
@@ -103,6 +104,8 @@ const char * llm_type_name(llm_type type) {
103
  case LLM_TYPE_17B_128E: return "17Bx128E (Maverick)";
104
  case LLM_TYPE_30B_A3B: return "30B.A3B";
105
  case LLM_TYPE_235B_A22B: return "235B.A22B";
 
 
106
  default: return "?B";
107
  }
108
  }
@@ -1017,6 +1020,24 @@ void llama_model::load_hparams(llama_model_loader & ml) {
1017
  ? 1.0f / std::sqrt(float(hparams.n_embd / hparams.n_head(0)))
1018
  : 1.0f / std::sqrt(float(hparams.n_embd_head_k));
1019
  } break;
1020
  case LLM_ARCH_STARCODER2:
1021
  {
1022
  ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
@@ -1484,6 +1505,14 @@ void llama_model::load_hparams(llama_model_loader & ml) {
1484
  default: type = LLM_TYPE_UNKNOWN;
1485
  }
1486
  } break;
1487
  default: throw std::runtime_error("unsupported model architecture");
1488
  }
1489
 
@@ -2950,6 +2979,62 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
2950
  layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, 0);
2951
  }
2952
  } break;
2953
  case LLM_ARCH_STARCODER2:
2954
  {
2955
  tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
@@ -4268,6 +4353,40 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
4268
 
4269
  layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0));
4270
 
4271
  layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0);
4272
  layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
4273
  }
@@ -8980,6 +9099,442 @@ struct llm_build_gemma3_iswa : public llm_graph_context {
8980
  }
8981
  };
8982
 
8983
  // TODO: move up next to build_starcoder
8984
  struct llm_build_starcoder2 : public llm_graph_context {
8985
  llm_build_starcoder2(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
@@ -9171,9 +9726,9 @@ struct llm_build_mamba : public llm_graph_context {
9171
  ggml_tensor * cur,
9172
  const llama_ubatch & ubatch,
9173
  int il) const {
9174
- const auto * kv_state = static_cast<const llama_memory_recurrent_state *>(mstate);
9175
 
9176
- const auto kv_head = kv_state->get_head();
9177
 
9178
  const int64_t d_conv = hparams.ssm_d_conv;
9179
  const int64_t d_inner = hparams.ssm_d_inner;
@@ -9191,8 +9746,8 @@ struct llm_build_mamba : public llm_graph_context {
9191
  GGML_ASSERT(ubatch.equal_seqs);
9192
  GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs);
9193
 
9194
- ggml_tensor * conv_states_all = kv_state->get_r_l(il);
9195
- ggml_tensor * ssm_states_all = kv_state->get_s_l(il);
9196
 
9197
  // (ab)using the KV cache to store the states
9198
  ggml_tensor * conv = build_rs(
@@ -11916,7 +12471,7 @@ struct llm_build_rwkv6_base : public llm_graph_context {
11916
  ggml_tensor * x_prev,
11917
  const llama_ubatch & ubatch,
11918
  int il) const {
11919
- const auto * kv_state = static_cast<const llama_memory_recurrent_state *>(mstate);
11920
 
11921
  const auto n_tokens = ubatch.n_tokens;
11922
  const auto n_seqs = ubatch.n_seqs;
@@ -11926,7 +12481,7 @@ struct llm_build_rwkv6_base : public llm_graph_context {
11926
  const auto n_head = n_embd / head_size;
11927
  const auto n_head_kv = hparams.n_head_kv(il);
11928
 
11929
- const auto kv_head = kv_state->get_head();
11930
 
11931
  const auto & layer = model.layers[il];
11932
 
@@ -12038,7 +12593,7 @@ struct llm_build_rwkv6_base : public llm_graph_context {
12038
  }
12039
 
12040
  ggml_tensor * wkv_state = build_rs(
12041
- inp, gf, kv_state->get_s_l(il),
12042
  hparams.n_embd_s(), n_seqs);
12043
 
12044
  ggml_tensor * wkv_output;
@@ -12057,9 +12612,9 @@ struct llm_build_rwkv6_base : public llm_graph_context {
12057
  wkv_state,
12058
  ggml_view_1d(
12059
  ctx0,
12060
- kv_state->get_s_l(il),
12061
  hparams.n_embd_s() * n_seqs,
12062
- hparams.n_embd_s() * kv_head * ggml_element_size(kv_state->get_s_l(il))
12063
  )
12064
  )
12065
  );
@@ -12313,7 +12868,7 @@ struct llm_build_rwkv7_base : public llm_graph_context {
12313
  ggml_tensor *& first_layer_value,
12314
  const llama_ubatch & ubatch,
12315
  int il) const {
12316
- const auto * kv_state = static_cast<const llama_memory_recurrent_state *>(mstate);
12317
 
12318
  const auto n_tokens = ubatch.n_tokens;
12319
  const auto n_seqs = ubatch.n_seqs;
@@ -12322,7 +12877,7 @@ struct llm_build_rwkv7_base : public llm_graph_context {
12322
  const auto head_count = n_embd / head_size;
12323
  const auto n_seq_tokens = ubatch.n_seq_tokens;
12324
 
12325
- const auto kv_head = kv_state->get_head();
12326
 
12327
  const auto & layer = model.layers[il];
12328
 
@@ -12393,7 +12948,7 @@ struct llm_build_rwkv7_base : public llm_graph_context {
12393
  a = ggml_reshape_3d(ctx0, a, head_size, head_count, n_tokens);
12394
 
12395
  ggml_tensor * wkv_state = build_rs(
12396
- inp, gf, kv_state->get_s_l(il),
12397
  hparams.n_embd_s(), n_seqs);
12398
 
12399
  ggml_tensor * wkv_output = ggml_rwkv_wkv7(ctx0, r, w, k, v, ggml_neg(ctx0, kk), ggml_mul(ctx0, kk, a), wkv_state);
@@ -12407,9 +12962,9 @@ struct llm_build_rwkv7_base : public llm_graph_context {
12407
  wkv_state,
12408
  ggml_view_1d(
12409
  ctx0,
12410
- kv_state->get_s_l(il),
12411
  hparams.n_embd_s() * n_seqs,
12412
- hparams.n_embd_s() * kv_head * ggml_element_size(kv_state->get_s_l(il))
12413
  )
12414
  )
12415
  );
@@ -13613,6 +14168,136 @@ struct llm_build_dots1 : public llm_graph_context {
13613
  }
13614
  };
13615
 
13616
  struct llm_build_arcee : public llm_graph_context {
13617
  llm_build_arcee(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
13618
  const int64_t n_embd_head = hparams.n_embd_head_v;
@@ -13974,6 +14659,10 @@ llm_graph_result_ptr llama_model::build_graph(
13974
  {
13975
  llm = std::make_unique<llm_build_gemma3_iswa>(*this, params, gf);
13976
  } break;
13977
  case LLM_ARCH_STARCODER2:
13978
  {
13979
  llm = std::make_unique<llm_build_starcoder2>(*this, params, gf);
@@ -14119,6 +14808,10 @@ llm_graph_result_ptr llama_model::build_graph(
14119
  {
14120
  llm = std::make_unique<llm_build_arcee>(*this, params, gf);
14121
  } break;
14122
  default:
14123
  GGML_ABORT("fatal error");
14124
  }
@@ -14270,6 +14963,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
14270
  case LLM_ARCH_BAILINGMOE:
14271
  case LLM_ARCH_NEO_BERT:
14272
  case LLM_ARCH_ARCEE:
 
14273
  return LLAMA_ROPE_TYPE_NORM;
14274
 
14275
  // the pairs of head values are offset by n_rot/2
@@ -14295,6 +14989,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
14295
  case LLM_ARCH_GEMMA:
14296
  case LLM_ARCH_GEMMA2:
14297
  case LLM_ARCH_GEMMA3:
 
14298
  case LLM_ARCH_STARCODER2:
14299
  case LLM_ARCH_OPENELM:
14300
  case LLM_ARCH_GPTNEOX:
@@ -14377,7 +15072,7 @@ const char * llama_model_chat_template(const llama_model * model, const char * n
14377
  // do not extend this list unless absolutely necessary
14378
  // Mistral-Small-2503 does not have built-in chat template
14379
  llama_vocab_pre_type pre_type = model->vocab.get_pre_type();
14380
- if (pre_type == LLAMA_VOCAB_PRE_TYPE_TEKKEN && model->layers.size() == 40) {
14381
  return "mistral-v7-tekken";
14382
  }
14383
 
 
47
  case LLM_TYPE_475M: return "475M";
48
  case LLM_TYPE_770M: return "770M";
49
  case LLM_TYPE_780M: return "780M";
50
+ case LLM_TYPE_0_3B: return "0.3B";
51
  case LLM_TYPE_0_5B: return "0.5B";
52
  case LLM_TYPE_0_6B: return "0.6B";
53
  case LLM_TYPE_1B: return "1B";
 
104
  case LLM_TYPE_17B_128E: return "17Bx128E (Maverick)";
105
  case LLM_TYPE_30B_A3B: return "30B.A3B";
106
  case LLM_TYPE_235B_A22B: return "235B.A22B";
107
+ case LLM_TYPE_E2B: return "E2B";
108
+ case LLM_TYPE_E4B: return "E4B";
109
  default: return "?B";
110
  }
111
  }
 
1020
  ? 1.0f / std::sqrt(float(hparams.n_embd / hparams.n_head(0)))
1021
  : 1.0f / std::sqrt(float(hparams.n_embd_head_k));
1022
  } break;
1023
+ case LLM_ARCH_GEMMA3N:
1024
+ {
1025
+ hparams.swa_type = LLAMA_SWA_TYPE_STANDARD;
1026
+ hparams.set_swa_pattern(5);
1027
+
1028
+ hparams.rope_freq_base_train_swa = 10000.0f;
1029
+ hparams.rope_freq_scale_train_swa = 1.0f;
1030
+ hparams.f_attention_scale = 1.0f;
1031
+
1032
+ ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa);
1033
+ ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
1034
+
1035
+ switch (hparams.n_layer) {
1036
+ case 30: type = LLM_TYPE_E2B; break;
1037
+ case 35: type = LLM_TYPE_E4B; break;
1038
+ default: type = LLM_TYPE_UNKNOWN;
1039
+ }
1040
+ } break;
1041
  case LLM_ARCH_STARCODER2:
1042
  {
1043
  ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
 
1505
  default: type = LLM_TYPE_UNKNOWN;
1506
  }
1507
  } break;
1508
+ case LLM_ARCH_ERNIE4_5:
1509
+ {
1510
+ ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
1511
+ switch (hparams.n_layer) {
1512
+ case 18: type = LLM_TYPE_0_3B; break;
1513
+ default: type = LLM_TYPE_UNKNOWN;
1514
+ }
1515
+ } break;
1516
  default: throw std::runtime_error("unsupported model architecture");
1517
  }
1518
 
 
2979
  layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, 0);
2980
  }
2981
  } break;
2982
+ case LLM_ARCH_GEMMA3N:
2983
+ {
2984
+ const int64_t n_altup = hparams.n_altup;
2985
+ const int64_t laurel_rank = hparams.laurel_rank;
2986
+ const int64_t n_embd_altup = hparams.n_embd_altup;
2987
+
2988
+ output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED);
2989
+ // if output is NULL, init from the input tok embed
2990
+ if (output == NULL) {
2991
+ output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED);
2992
+ }
2993
+
2994
+ tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
2995
+ tok_embd_per_layer = create_tensor(tn(LLM_TENSOR_PER_LAYER_TOKEN_EMBD, "weight"), {n_embd_altup * n_layer, n_vocab}, 0);
2996
+
2997
+ altup_proj = create_tensor(tn(LLM_TENSOR_ALTUP_PROJ, "weight"), {n_embd, n_embd, n_altup - 1}, 0);
2998
+ altup_unembd_proj = create_tensor(tn(LLM_TENSOR_ALTUP_UNEMBD_PROJ, "weight"), {n_embd, n_embd, n_altup - 1}, 0);
2999
+ per_layer_model_proj = create_tensor(tn(LLM_TENSOR_PER_LAYER_MODEL_PROJ, "weight"), {n_embd, n_embd_altup * n_layer}, 0);
3000
+ per_layer_proj_norm = create_tensor(tn(LLM_TENSOR_PER_LAYER_PROJ_NORM, "weight"), {n_embd_altup}, 0);
3001
+
3002
+ output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
3003
+
3004
+ for (int i = 0; i < n_layer; ++i) {
3005
+ auto & layer = layers[i];
3006
+
3007
+ layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
3008
+
3009
+ layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0);
3010
+ layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0);
3011
+ layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0);
3012
+ layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0);
3013
+
3014
+ layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0);
3015
+ layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0);
3016
+ layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}, 0);
3017
+
3018
+ layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
3019
+ layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
3020
+ layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
3021
+ layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0);
3022
+ layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, 0);
3023
+
3024
+ // altup & laurel
3025
+ layer.per_layer_inp_gate = create_tensor(tn(LLM_TENSOR_PER_LAYER_INP_GATE, "weight", i), {n_embd, n_embd_altup}, 0);
3026
+ layer.per_layer_proj = create_tensor(tn(LLM_TENSOR_PER_LAYER_PROJ, "weight", i), {n_embd_altup, n_embd}, 0);
3027
+ layer.per_layer_post_norm = create_tensor(tn(LLM_TENSOR_PER_LAYER_POST_NORM, "weight", i), {n_embd}, 0);
3028
+ layer.altup_correct_coef = create_tensor(tn(LLM_TENSOR_ALTUP_CORRECT_COEF, "weight", i), {n_altup, n_altup}, 0);
3029
+ layer.altup_correct_scale = create_tensor(tn(LLM_TENSOR_ALTUP_CORRECT_SCALE, "weight", i), {n_embd}, 0);
3030
+ layer.altup_predict_coef = create_tensor(tn(LLM_TENSOR_ALTUP_PREDICT_COEF, "weight", i), {n_altup, n_altup * n_altup}, 0);
3031
+ layer.altup_router = create_tensor(tn(LLM_TENSOR_ALTUP_ROUTER, "weight", i), {n_embd, n_altup}, 0);
3032
+ layer.altup_router_norm = create_tensor(tn(LLM_TENSOR_ALTUP_ROUTER_NORM, "weight", i), {n_embd}, 0);
3033
+ layer.laurel_l = create_tensor(tn(LLM_TENSOR_LAUREL_L, "weight", i), {n_embd, laurel_rank}, 0);
3034
+ layer.laurel_r = create_tensor(tn(LLM_TENSOR_LAUREL_R, "weight", i), {laurel_rank, n_embd}, 0);
3035
+ layer.laurel_post_norm = create_tensor(tn(LLM_TENSOR_LAUREL_POST_NORM, "weight", i), {n_embd}, 0);
3036
+ }
3037
+ } break;
3038
  case LLM_ARCH_STARCODER2:
3039
  {
3040
  tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
 
4353
 
4354
  layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0));
4355
 
4356
+ layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0);
4357
+ layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
4358
+ }
4359
+ } break;
4360
+ case LLM_ARCH_ERNIE4_5:
4361
+ {
4362
+ tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
4363
+
4364
+ // output
4365
+ output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
4366
+ output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED);
4367
+ // if output is NULL, init from the input tok embed
4368
+ if (output == NULL) {
4369
+ output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED);
4370
+ }
4371
+
4372
+ for (int i = 0; i < n_layer; ++i) {
4373
+ auto & layer = layers[i];
4374
+
4375
+ layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
4376
+
4377
+ layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0);
4378
+ layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0);
4379
+ layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0);
4380
+ layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0);
4381
+
4382
+ // optional bias tensors
4383
+ layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED);
4384
+ layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, TENSOR_NOT_REQUIRED);
4385
+ layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, TENSOR_NOT_REQUIRED);
4386
+ layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED);
4387
+
4388
+ layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
4389
+ layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
4390
  layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0);
4391
  layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
4392
  }
 
9099
  }
9100
  };
9101
 
9102
+ struct llm_build_gemma3n_iswa : public llm_graph_context {
9103
+ const llama_model & model;
9104
+ ggml_cgraph * gf;
9105
+
9106
+ const int64_t n_embd_head;
9107
+ const int64_t n_embd_altup;
9108
+ const int64_t n_altup;
9109
+ const int i_altup_act;
9110
+ const int n_layer_kv = 20; // number of layers having KV [KV_REUSE]
9111
+ const int n_layer_sparsity = 10; // number of layers using activation sparsity
9112
+ const float f_sparsity_std_mul = 1.6448533535003662f; // std_multiplier = normal_dist.icdf(0.95)
9113
+
9114
+ ggml_tensor * one; // containing single element 1.0f
9115
+
9116
+ llm_build_gemma3n_iswa(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf)
9117
+ : llm_graph_context(params),
9118
+ model(model),
9119
+ gf(gf),
9120
+ n_embd_head(model.hparams.n_embd_head_k),
9121
+ n_embd_altup(model.hparams.n_embd_altup),
9122
+ n_altup(model.hparams.n_altup),
9123
+ i_altup_act(model.hparams.i_altup_act) {
9124
+ ggml_tensor * cur;
9125
+ ggml_tensor * inpL;
9126
+
9127
+ // TODO: remove this when ggml_scale_add is implemented
9128
+ one = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
9129
+ {
9130
+ auto inp = std::make_unique<llm_graph_input_one>();
9131
+ inp->one = one;
9132
+ res->add_input(std::move(inp));
9133
+ }
9134
+
9135
+ inpL = build_inp_embd(model.tok_embd);
9136
+
9137
+ // important: do not normalize weights for raw embeddings input (i.e. encoded image embeddings)
9138
+ if (ubatch.token) {
9139
+ inpL = ggml_scale(ctx0, inpL, sqrtf(n_embd));
9140
+ cb(inpL, "inp_scaled", -1);
9141
+ }
9142
+
9143
+ // inp_pos - contains the positions
9144
+ ggml_tensor * inp_pos = build_inp_pos();
9145
+
9146
+ // TODO: is causal == true correct? might need some changes
9147
+ auto * inp_attn = build_attn_inp_kv_unified_iswa();
9148
+
9149
+ // inp_per_layer shape: [n_embd_altup, n_tokens, n_layer]
9150
+ ggml_tensor * inp_per_layer = project_per_layer_inputs(inpL, get_per_layer_inputs());
9151
+
9152
+ // inpL now has only 1 altup, project it to the rest of the altups
9153
+ // these "added" altups will be concatenated to the last dim of inpL
9154
+ {
9155
+ ggml_tensor * target_magnitude = calc_magnitude(inpL);
9156
+ ggml_tensor * inp_repeated = ggml_repeat_4d(ctx0, inpL, n_embd, n_tokens, n_altup - 1, 1);
9157
+ ggml_tensor * altup_added = ggml_mul_mat(ctx0, model.altup_proj, inp_repeated); // shape: [n_embd, n_tokens, n_altup - 1]
9158
+ ggml_tensor * new_magnitude = calc_magnitude(altup_added);
9159
+ altup_added = ggml_div(ctx0,
9160
+ ggml_mul(ctx0, altup_added, target_magnitude),
9161
+ new_magnitude);
9162
+ inpL = ggml_concat(ctx0, inpL, altup_added, 2); // shape: [n_embd, n_tokens, n_altup]
9163
+ cb(inpL, "inp_stacked", -1);
9164
+ }
9165
+
9166
+ // inpL now has shape: [n_embd, n_tokens, n_altup]
9167
+ // inp_per_layer now has shape: [n_embd_altup, n_tokens, n_layer]
9168
+
9169
+ for (int il = 0; il < n_layer; ++il) {
9170
+ // this block is made to closely resemble Gemma3p5DecoderLayer in the python code
9171
+ const bool has_kv = (il < n_layer_kv);
9172
+
9173
+ const float freq_base_l = model.get_rope_freq_base (cparams, il);
9174
+ const float freq_scale_l = model.get_rope_freq_scale(cparams, il);
9175
+
9176
+ ggml_tensor * cur = inpL; // [n_embd, n_tokens, n_altup]
9177
+ ggml_tensor * predictions = altup_predict(cur, il); // [n_embd, n_tokens, n_altup]
9178
+
9179
+ // predicted value will go through self-attention and laurel
9180
+ ggml_tensor * active_prediction = view_2d_slice(predictions, i_altup_act); // [n_embd, n_tokens]
9181
+ cur = active_prediction;
9182
+ cb(cur, "active_prediction", il);
9183
+
9184
+ // norm
9185
+ cur = build_norm(cur, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il);
9186
+ cb(cur, "attn_norm", il);
9187
+
9188
+ // laurel
9189
+ ggml_tensor * laurel_out = laurel(cur, il); // [n_embd, n_tokens]
9190
+
9191
+ // self-attention
9192
+ if (has_kv) {
9193
+ // compute Q and K and RoPE them
9194
+ ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
9195
+ cb(Qcur, "Qcur", il);
9196
+
9197
+ ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
9198
+ cb(Kcur, "Kcur", il);
9199
+
9200
+ ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
9201
+ cb(Vcur, "Vcur", il);
9202
+
9203
+ Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
9204
+ Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
9205
+ Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
9206
+
9207
+ Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il);
9208
+ Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il);
9209
+ Vcur = ggml_rms_norm(ctx0, Vcur, hparams.f_norm_rms_eps);
9210
+
9211
+ cb(Qcur, "Qcur_normed", il);
9212
+ cb(Kcur, "Kcur_normed", il);
9213
+ cb(Vcur, "Vcur_normed", il);
9214
+
9215
+ Qcur = ggml_rope_ext(
9216
+ ctx0, Qcur, inp_pos, nullptr,
9217
+ n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l,
9218
+ ext_factor, attn_factor, beta_fast, beta_slow);
9219
+
9220
+ Kcur = ggml_rope_ext(
9221
+ ctx0, Kcur, inp_pos, nullptr,
9222
+ n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l,
9223
+ ext_factor, attn_factor, beta_fast, beta_slow);
9224
+
9225
+ cb(Qcur, "Qcur_pos", il);
9226
+ cb(Kcur, "Kcur_pos", il);
9227
+
9228
+ cur = build_attn(inp_attn, gf,
9229
+ model.layers[il].wo, NULL,
9230
+ Qcur, Kcur, Vcur, nullptr, nullptr, hparams.f_attention_scale, il);
9231
+ } else {
9232
+ // no KV layers
9233
+ ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
9234
+ cb(Qcur, "Qcur", il);
9235
+ Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
9236
+
9237
+ Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il);
9238
+ cb(Qcur, "Qcur_normed", il);
9239
+
9240
+ Qcur = ggml_rope_ext(
9241
+ ctx0, Qcur, inp_pos, nullptr,
9242
+ n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l,
9243
+ ext_factor, attn_factor, beta_fast, beta_slow);
9244
+ cb(Qcur, "Qcur_pos", il);
9245
+
9246
+ cur = build_attn(inp_attn, gf,
9247
+ model.layers[il].wo, NULL,
9248
+ Qcur, nullptr, nullptr, nullptr, nullptr, hparams.f_attention_scale, il);
9249
+ }
9250
+
9251
+ cur = build_norm(cur,
9252
+ model.layers[il].attn_post_norm, NULL,
9253
+ LLM_NORM_RMS, il);
9254
+ cb(cur, "attn_post_norm", il);
9255
+
9256
+ cur = ggml_add(ctx0, cur, active_prediction); // [n_embd, n_tokens]
9257
+ cb(cur, "attn_gated", il);
9258
+
9259
+ ggml_tensor * attn_laurel = ggml_scale(ctx0,
9260
+ ggml_add(ctx0, cur, laurel_out),
9261
+ 1.0f / sqrtf(2.0f)); // [n_embd, n_tokens]
9262
+ cb(attn_laurel, "attn_laurel", il);
9263
+
9264
+ cur = build_norm(attn_laurel,
9265
+ model.layers[il].ffn_norm, NULL,
9266
+ LLM_NORM_RMS, il);
9267
+ cb(cur, "ffn_norm", il);
9268
+
9269
+ // feed-forward network
9270
+ {
9271
+ ggml_tensor * up_proj = build_lora_mm(model.layers[il].ffn_up, cur);
9272
+ ggml_tensor * gate_proj = build_lora_mm(model.layers[il].ffn_gate, cur);
9273
+
9274
+ if (il < n_layer_sparsity) {
9275
+ // apply activation sparsity
9276
+ gate_proj = gaussian_topk(gate_proj);
9277
+ }
9278
+ gate_proj = ggml_gelu(ctx0, gate_proj);
9279
+
9280
+ cur = ggml_mul(ctx0, up_proj, gate_proj);
9281
+ cur = build_lora_mm(model.layers[il].ffn_down, cur);
9282
+ cb(cur, "ffn_out", il);
9283
+ }
9284
+
9285
+ cur = build_norm(cur,
9286
+ model.layers[il].ffn_post_norm, NULL,
9287
+ LLM_NORM_RMS, -1);
9288
+ cb(cur, "ffn_post_norm", il);
9289
+
9290
+ ggml_tensor * attn_ffw_laurel_gated = ggml_add(ctx0, cur, attn_laurel); // [n_embd, n_tokens]
9291
+ cb(attn_ffw_laurel_gated, "attn_ffw_laurel_gated", il);
9292
+
9293
+ ggml_tensor * corrected = altup_correct(predictions, attn_ffw_laurel_gated, il); // [n_embd, n_tokens, n_altup]
9294
+
9295
+ ggml_tensor * first_prediction; // [n_embd, n_tokens]
9296
+ {
9297
+ first_prediction = view_2d_slice(corrected, i_altup_act); // [n_embd, n_tokens]
9298
+ first_prediction = ggml_mul(ctx0, first_prediction, model.layers[il].altup_correct_scale);
9299
+ first_prediction = build_lora_mm(model.layers[il].per_layer_inp_gate, first_prediction);
9300
+ first_prediction = ggml_gelu(ctx0, first_prediction); // [n_embd_altup, n_tokens]
9301
+ cb(first_prediction, "first_prediction_gated", il);
9302
+ ggml_tensor * inp_this_layer = view_2d_slice(inp_per_layer, il); // [n_embd_altup, n_tokens]
9303
+ first_prediction = ggml_mul(ctx0, first_prediction, inp_this_layer); // [n_embd_altup, n_tokens]
9304
+ cb(first_prediction, "first_prediction_scaled", il);
9305
+
9306
+ first_prediction = build_lora_mm(model.layers[il].per_layer_proj, first_prediction); // [n_embd, n_tokens]
9307
+ first_prediction = build_norm(first_prediction,
9308
+ model.layers[il].per_layer_post_norm, NULL,
9309
+ LLM_NORM_RMS, il);
9310
+ cb(first_prediction, "first_prediction_out", il);
9311
+ }
9312
+
9313
+ // equivalent to python code: corrected_predictions[1:] += first_prediction
9314
+ {
9315
+ ggml_tensor * slice_first = view_2d_slice(corrected, 0);
9316
+ ggml_tensor * slice_rest = ggml_view_3d(ctx0, corrected, n_embd, n_tokens, n_altup - 1,
9317
+ ggml_row_size(corrected->type, n_embd),
9318
+ ggml_row_size(corrected->type, n_embd*n_tokens),
9319
+ n_embd*n_tokens*ggml_element_size(corrected));
9320
+ ggml_tensor * tmp = ggml_add(ctx0, slice_rest, first_prediction); // [n_embd, n_tokens, n_altup - 1]
9321
+ corrected = ggml_concat(ctx0, slice_first, tmp, 2); // [n_embd, n_tokens, n_altup]
9322
+ }
9323
+
9324
+ cur = corrected; // [n_embd, n_tokens, n_altup]
9325
+ cur = build_cvec(cur, il);
9326
+ cb(cur, "l_out", il);
9327
+
9328
+ // input for next layer
9329
+ inpL = cur;
9330
+ }
9331
+
9332
+ cur = inpL; // [n_embd, n_tokens, n_altup]
9333
+
9334
+ // cur now has multiple altup(s), we want to merge them back to 1 altup
9335
+ {
9336
+ ggml_tensor * target_magnitude = calc_magnitude(view_2d_slice(cur, i_altup_act)); // [n_embd, n_tokens]
9337
+ // do a view to skip the first slice (active altup)
9338
+ ggml_tensor * alt_slice = ggml_view_3d(ctx0, cur, n_embd, n_tokens, n_altup - 1,
9339
+ ggml_row_size(cur->type, n_embd),
9340
+ ggml_row_size(cur->type, n_embd*n_tokens),
9341
+ n_embd*n_tokens*ggml_element_size(cur));
9342
+ ggml_tensor * altup_unembd = ggml_mul_mat(ctx0, model.altup_unembd_proj, alt_slice); // shape: [n_embd, n_tokens, n_altup - 1]
9343
+ ggml_tensor * new_magnitude = calc_magnitude(altup_unembd);
9344
+ altup_unembd = ggml_div(ctx0,
9345
+ ggml_mul(ctx0, altup_unembd, target_magnitude),
9346
+ new_magnitude);
9347
+ cb(altup_unembd, "altup_unembd", -1);
9348
+
9349
+ // equivalent to torch.mean(hidden_states, dim=0)
9350
+ cur = view_2d_slice(cur, 0); // [n_embd, n_tokens]
9351
+ for (int i = 0; i < n_altup - 1; ++i) {
9352
+ cur = ggml_add(ctx0, cur, view_2d_slice(altup_unembd, i));
9353
+ }
9354
+ cur = ggml_scale(ctx0, cur, 1.0f / float(n_altup)); // [n_embd, n_tokens]
9355
+ cb(cur, "unembd_merged", -1);
9356
+ }
9357
+
9358
+ // cur now has shape: [n_embd, n_tokens]
9359
+
9360
+ // TODO: move this to right after the last KV layer
9361
+ {
9362
+ // skip computing output for unused tokens
9363
+ ggml_tensor * inp_out_ids = build_inp_out_ids();
9364
+ cur = ggml_get_rows(ctx0, cur, inp_out_ids);
9365
+ }
9366
+
9367
+ cur = build_norm(cur,
9368
+ model.output_norm, NULL,
9369
+ LLM_NORM_RMS, -1);
9370
+
9371
+ cb(cur, "result_norm", -1);
9372
+ res->t_embd = cur;
9373
+
9374
+ cur = build_lora_mm(model.output, cur);
9375
+
9376
+ {
9377
+ // final logit soft-capping
9378
+ cur = ggml_scale(ctx0, cur, 1.0f / hparams.f_final_logit_softcapping);
9379
+ cur = ggml_tanh(ctx0, cur);
9380
+ cur = ggml_scale(ctx0, cur, hparams.f_final_logit_softcapping);
9381
+ }
9382
+
9383
+ cb(cur, "result_output", -1);
9384
+ res->t_logits = cur;
9385
+
9386
+ ggml_build_forward_expand(gf, cur);
9387
+ }
9388
+
9389
+ ggml_tensor * calc_magnitude(ggml_tensor * x) {
9390
+ return ggml_sqrt(ctx0, ggml_sum_rows(ctx0, ggml_sqr(ctx0, x)));
9391
+ }
9392
+
9393
+ // get 2D slice view from a 3D tensor, the idx corresponds to the 3rd dim
9394
+ ggml_tensor * view_2d_slice(ggml_tensor * x, int idx) {
9395
+ GGML_ASSERT(idx < (int)x->ne[2]);
9396
+ return ggml_view_2d(ctx0, x, x->ne[0], x->ne[1],
9397
+ ggml_row_size(x->type, x->ne[0]),
9398
+ idx * x->ne[0] * x->ne[1] * ggml_element_size(x));
9399
+ }
9400
+
9401
+ // equivalent to get_per_layer_inputs() in python code
9402
+ // output shape: [n_embd_altup, n_layer, n_tokens]
9403
+ ggml_tensor * get_per_layer_inputs() {
9404
+ auto inp = std::make_unique<llm_graph_input_embd>();
9405
+ ggml_tensor * inp_per_layer;
9406
+ if (ubatch.token) {
9407
+ inp->tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ubatch.n_tokens);
9408
+ ggml_set_input(inp->tokens);
9409
+ res->t_tokens = inp->tokens;
9410
+ inp_per_layer = ggml_get_rows(ctx0, model.tok_embd_per_layer, inp->tokens);
9411
+ inp_per_layer = ggml_reshape_3d(ctx0, inp_per_layer, n_embd_altup, n_layer, n_tokens);
9412
+ inp_per_layer = ggml_scale(ctx0, inp_per_layer, sqrtf((float)n_embd_altup));
9413
+ cb(inp_per_layer, "inp_per_layer_selected", -1);
9414
+ } else {
9415
+ GGML_ABORT("TODO: support embd input");
9416
+ }
9417
+ res->add_input(std::move(inp));
9418
+ return inp_per_layer;
9419
+ }
9420
+
9421
+ // equivalent to project_per_layer_inputs() in python code
9422
+ // this calculates the per-layer inputs, so the final tensor shape will have n_layer as the last dim
9423
+ // output shape: [n_embd_altup, n_tokens, n_layer]
9424
+ ggml_tensor * project_per_layer_inputs(ggml_tensor * inputs_embeds, ggml_tensor * inp_per_layer) {
9425
+ const float per_layer_projection_scale = 1.0f / sqrtf((float)n_embd);
9426
+ const float per_layer_input_scale = 1.0f / sqrtf(2.0f);
9427
+
9428
+ ggml_tensor * per_layer_proj = ggml_mul_mat(ctx0, model.per_layer_model_proj, inputs_embeds);
9429
+ per_layer_proj = ggml_scale(ctx0, per_layer_proj, per_layer_projection_scale);
9430
+ per_layer_proj = ggml_reshape_3d(ctx0, per_layer_proj, n_embd_altup, n_layer, n_tokens);
9431
+ per_layer_proj = build_norm(per_layer_proj,
9432
+ model.per_layer_proj_norm, NULL,
9433
+ LLM_NORM_RMS, -1); // [n_embd_altup, n_layer, n_tokens]
9434
+ cb(per_layer_proj, "per_layer_proj", -1);
9435
+
9436
+ inp_per_layer = ggml_add(ctx0, inp_per_layer, per_layer_proj);
9437
+ inp_per_layer = ggml_scale(ctx0, inp_per_layer, per_layer_input_scale);
9438
+ cb(inp_per_layer, "inp_per_layer", -1);
9439
+
9440
+ // permute to shape: [n_embd_altup, n_tokens, n_layer]
9441
+ inp_per_layer = ggml_cont(ctx0, ggml_permute(ctx0, inp_per_layer, 0, 2, 1, 3));
9442
+ return inp_per_layer;
9443
+ }
9444
+
9445
+ // input cur shape: [n_embd, n_tokens]
9446
+ // output shape: [n_embd, n_tokens]
9447
+ ggml_tensor * laurel(ggml_tensor * cur, int il) {
9448
+ ggml_tensor * tmp = cur;
9449
+ tmp = build_lora_mm(model.layers[il].laurel_l, tmp);
9450
+ tmp = build_lora_mm(model.layers[il].laurel_r, tmp);
9451
+ tmp = build_norm(tmp, model.layers[il].laurel_post_norm, NULL, LLM_NORM_RMS, il);
9452
+ tmp = ggml_add(ctx0, tmp, cur);
9453
+ cb(tmp, "laurel_out", il);
9454
+ return tmp;
9455
+ }
9456
+
9457
+ // input x shape: [n_embd, n_tokens]
9458
+ // output shape: [n_embd, n_tokens]
9459
+ ggml_tensor * gaussian_topk(ggml_tensor * x) {
9460
+ ggml_tensor * mean = ggml_mean(ctx0, x);
9461
+ ggml_tensor * std = ggml_sqrt(ctx0, ggml_scale(ctx0,
9462
+ ggml_sum_rows(ctx0, ggml_sqr(ctx0, ggml_sub(ctx0, x, mean))),
9463
+ 1.0f / (float)(x->ne[0] - 1)
9464
+ ));
9465
+ ggml_tensor * cutoff_x = ggml_add(ctx0, mean, ggml_scale(ctx0, std, f_sparsity_std_mul));
9466
+ return ggml_relu(ctx0, ggml_sub(ctx0, x, cutoff_x));
9467
+ }
9468
+
9469
+ //
9470
+ // altup functions
9471
+ //
9472
+
9473
+ // equivalent to compute_router_modalities() in python code
9474
+ // input x shape: [n_embd, n_tokens]
9475
+ // output shape: [n_altup, n_tokens]
9476
+ ggml_tensor * altup_compute_router_modalities(ggml_tensor * x, int il) {
9477
+ ggml_tensor * router_inputs = build_norm(x,
9478
+ model.layers[il].altup_router_norm, NULL,
9479
+ LLM_NORM_RMS, il);
9480
+
9481
+ // router_input_scale
9482
+ router_inputs = ggml_scale(ctx0, router_inputs, 1.0f / (float)n_embd);
9483
+
9484
+ ggml_tensor * output = ggml_mul_mat(ctx0, model.layers[il].altup_router, router_inputs);
9485
+ return ggml_tanh(ctx0, output); // [n_altup, n_tokens]
9486
+ }
9487
+
9488
+ // input cur shape: [n_embd, n_tokens, n_altup]
9489
+ // output shape: [n_embd, n_tokens, n_altup]
9490
+ ggml_tensor * altup_predict(ggml_tensor * cur, int il) {
9491
+ ggml_tensor * activated = view_2d_slice(cur, i_altup_act); // [n_embd, n_tokens]
9492
+ ggml_tensor * modalities = altup_compute_router_modalities(activated, il); // [n_altup, n_tokens]
9493
+ cb(modalities, "modalities", il);
9494
+
9495
+ ggml_tensor * all_coefs = build_lora_mm(model.layers[il].altup_predict_coef, modalities);
9496
+ cb(all_coefs, "all_coefs", il);
9497
+ // first dim now having n_altup^2 elements, we reshape it to 2D (so we end up with 3D tensor)
9498
+ all_coefs = ggml_reshape_3d(ctx0, all_coefs, n_altup, n_altup, n_tokens);
9499
+
9500
+ // permute to [n_altup, n_embd, n_tokens]
9501
+ ggml_tensor * cur_permuted = ggml_cont(ctx0, ggml_permute(ctx0, cur, 1, 2, 0, 3));
9502
+ ggml_tensor * predictions = ggml_mul_mat(ctx0, cur_permuted, all_coefs); // [n_altup, n_embd, n_tokens]
9503
+
9504
+ // final shape must be the same as cur: [n_embd, n_tokens, n_altup]
9505
+ predictions = ggml_cont(ctx0, ggml_permute(ctx0, predictions, 0, 2, 1, 3));
9506
+ predictions = ggml_add(ctx0, predictions, cur);
9507
+ cb(predictions, "predictions", il);
9508
+
9509
+ return predictions;
9510
+ }
9511
+
9512
+ // input predictions shape: [n_embd, n_tokens, n_altup]
9513
+ // input activated shape: [n_embd, n_tokens]
9514
+ // output shape: [n_embd, n_tokens, n_altup]
9515
+ ggml_tensor * altup_correct(ggml_tensor * predictions, ggml_tensor * activated, int il) {
9516
+ ggml_tensor * modalities = altup_compute_router_modalities(activated, il); // [n_altup, n_tokens]
9517
+ cb(modalities, "modalities", il);
9518
+
9519
+ ggml_tensor * active_prediction = view_2d_slice(predictions, i_altup_act);
9520
+ ggml_tensor * innovation = ggml_sub(ctx0, activated, active_prediction); // [n_embd, n_tokens]
9521
+ cb(innovation, "innovation", il);
9522
+
9523
+ ggml_tensor * all_coefs = build_lora_mm(model.layers[il].altup_correct_coef, modalities); // [n_altup, n_tokens]
9524
+ all_coefs = ggml_add(ctx0, all_coefs, one);
9525
+ cb(all_coefs, "all_coefs", il);
9526
+ all_coefs = ggml_cont(ctx0, ggml_transpose(ctx0, all_coefs)); // [n_tokens, n_altup]
9527
+ all_coefs = ggml_reshape_3d(ctx0, all_coefs, 1, n_tokens, n_altup); // [1, n_tokens, n_altup]
9528
+
9529
+ innovation = ggml_repeat_4d(ctx0, innovation, n_embd, n_tokens, n_altup, 1);
9530
+ ggml_tensor * corrected = ggml_mul(ctx0, innovation, all_coefs); // [n_embd, n_tokens, n_altup]
9531
+ corrected = ggml_add(ctx0, corrected, predictions); // [n_embd, n_tokens, n_altup]
9532
+ cb(corrected, "corrected", il);
9533
+
9534
+ return corrected;
9535
+ }
9536
+ };
9537
+
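The gaussian_topk() helper in the block above is Gemma 3n's activation sparsity: per row it estimates the mean and (unbiased) standard deviation, sets cutoff = mean + 1.6448533... * std (the 95th-percentile point of a normal distribution, matching f_sparsity_std_mul), and keeps only the part of each activation above that cutoff via a shifted ReLU. A standalone numeric sketch of the same computation on a plain float vector (illustrative only, not ggml code):

    #include <algorithm>
    #include <cmath>
    #include <cstdio>
    #include <vector>

    // same constant as f_sparsity_std_mul above: normal_dist.icdf(0.95)
    static const float k_std_mul = 1.6448533535003662f;

    // keep relu(x - cutoff), i.e. roughly the top ~5% of activations per row
    static std::vector<float> gaussian_topk(const std::vector<float> & x) {
        const size_t n = x.size();

        float mean = 0.0f;
        for (float v : x) {
            mean += v;
        }
        mean /= (float) n;

        float var = 0.0f;
        for (float v : x) {
            var += (v - mean) * (v - mean);
        }
        var /= (float) (n - 1); // unbiased estimate, matching the 1/(ne[0] - 1) scale in the graph code

        const float cutoff = mean + k_std_mul * std::sqrt(var);

        std::vector<float> out(n);
        for (size_t i = 0; i < n; ++i) {
            out[i] = std::max(0.0f, x[i] - cutoff); // relu(x - cutoff)
        }
        return out;
    }

    int main() {
        const std::vector<float> x = { -1.0f, 0.2f, 0.5f, 3.0f, 0.1f, -0.3f };
        for (float v : gaussian_topk(x)) {
            std::printf("%.3f ", v);
        }
        std::printf("\n");
    }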
9538
  // TODO: move up next to build_starcoder
9539
  struct llm_build_starcoder2 : public llm_graph_context {
9540
  llm_build_starcoder2(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
 
9726
  ggml_tensor * cur,
9727
  const llama_ubatch & ubatch,
9728
  int il) const {
9729
+ const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);
9730
 
9731
+ const auto kv_head = mctx_cur->get_head();
9732
 
9733
  const int64_t d_conv = hparams.ssm_d_conv;
9734
  const int64_t d_inner = hparams.ssm_d_inner;
 
9746
  GGML_ASSERT(ubatch.equal_seqs);
9747
  GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs);
9748
 
9749
+ ggml_tensor * conv_states_all = mctx_cur->get_r_l(il);
9750
+ ggml_tensor * ssm_states_all = mctx_cur->get_s_l(il);
9751
 
9752
  // (ab)using the KV cache to store the states
9753
  ggml_tensor * conv = build_rs(
 
12471
  ggml_tensor * x_prev,
12472
  const llama_ubatch & ubatch,
12473
  int il) const {
12474
+ const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);
12475
 
12476
  const auto n_tokens = ubatch.n_tokens;
12477
  const auto n_seqs = ubatch.n_seqs;
 
12481
  const auto n_head = n_embd / head_size;
12482
  const auto n_head_kv = hparams.n_head_kv(il);
12483
 
12484
+ const auto kv_head = mctx_cur->get_head();
12485
 
12486
  const auto & layer = model.layers[il];
12487
 
 
12593
  }
12594
 
12595
  ggml_tensor * wkv_state = build_rs(
12596
+ inp, gf, mctx_cur->get_s_l(il),
12597
  hparams.n_embd_s(), n_seqs);
12598
 
12599
  ggml_tensor * wkv_output;
 
12612
  wkv_state,
12613
  ggml_view_1d(
12614
  ctx0,
12615
+ mctx_cur->get_s_l(il),
12616
  hparams.n_embd_s() * n_seqs,
12617
+ hparams.n_embd_s() * kv_head * ggml_element_size(mctx_cur->get_s_l(il))
12618
  )
12619
  )
12620
  );
 
12868
  ggml_tensor *& first_layer_value,
12869
  const llama_ubatch & ubatch,
12870
  int il) const {
12871
+ const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);
12872
 
12873
  const auto n_tokens = ubatch.n_tokens;
12874
  const auto n_seqs = ubatch.n_seqs;
 
12877
  const auto head_count = n_embd / head_size;
12878
  const auto n_seq_tokens = ubatch.n_seq_tokens;
12879
 
12880
+ const auto kv_head = mctx_cur->get_head();
12881
 
12882
  const auto & layer = model.layers[il];
12883
 
 
12948
  a = ggml_reshape_3d(ctx0, a, head_size, head_count, n_tokens);
12949
 
12950
  ggml_tensor * wkv_state = build_rs(
12951
+ inp, gf, mctx_cur->get_s_l(il),
12952
  hparams.n_embd_s(), n_seqs);
12953
 
12954
  ggml_tensor * wkv_output = ggml_rwkv_wkv7(ctx0, r, w, k, v, ggml_neg(ctx0, kk), ggml_mul(ctx0, kk, a), wkv_state);
 
12962
  wkv_state,
12963
  ggml_view_1d(
12964
  ctx0,
12965
+ mctx_cur->get_s_l(il),
12966
  hparams.n_embd_s() * n_seqs,
12967
+ hparams.n_embd_s() * kv_head * ggml_element_size(mctx_cur->get_s_l(il))
12968
  )
12969
  )
12970
  );
 
14168
  }
14169
  };
14170
 
14171
+ struct llm_build_ernie4_5 : public llm_graph_context {
14172
+ llm_build_ernie4_5(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
14173
+ const int64_t n_embd_head = hparams.n_embd_head_v;
14174
+
14175
+ GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
14176
+ GGML_ASSERT(n_embd_head == hparams.n_rot);
14177
+
14178
+ ggml_tensor * cur;
14179
+ ggml_tensor * inpL;
14180
+
14181
+ inpL = build_inp_embd(model.tok_embd);
14182
+
14183
+ // inp_pos - contains the positions
14184
+ ggml_tensor * inp_pos = build_inp_pos();
14185
+
14186
+ auto * inp_attn = build_attn_inp_kv_unified();
14187
+
14188
+ for (int il = 0; il < n_layer; ++il) {
14189
+ ggml_tensor * inpSA = inpL;
14190
+
14191
+ // norm
14192
+ {
14193
+ cur = build_norm(inpL,
14194
+ model.layers[il].attn_norm, NULL,
14195
+ LLM_NORM_RMS, il);
14196
+ cb(cur, "attn_norm", il);
14197
+ }
14198
+
14199
+ // self-attention
14200
+ {
14201
+ ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
14202
+ cb(Qcur, "Qcur", il);
14203
+ if (model.layers[il].bq) {
14204
+ Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
14205
+ cb(Qcur, "Qcur", il);
14206
+ }
14207
+
14208
+ ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
14209
+ cb(Kcur, "Kcur", il);
14210
+ if (model.layers[il].bk) {
14211
+ Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
14212
+ cb(Kcur, "Kcur", il);
14213
+ }
14214
+
14215
+ ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
14216
+ cb(Vcur, "Vcur", il);
14217
+ if (model.layers[il].bv) {
14218
+ Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
14219
+ cb(Vcur, "Vcur", il);
14220
+ }
14221
+
14222
+ Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
14223
+ Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
14224
+ Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
14225
+
14226
+ Qcur = ggml_rope_ext(
14227
+ ctx0, Qcur, inp_pos, nullptr,
14228
+ n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
14229
+ ext_factor, attn_factor, beta_fast, beta_slow
14230
+ );
14231
+
14232
+ Kcur = ggml_rope_ext(
14233
+ ctx0, Kcur, inp_pos, nullptr,
14234
+ n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
14235
+ ext_factor, attn_factor, beta_fast, beta_slow
14236
+ );
14237
+
14238
+ cb(Qcur, "Qcur", il);
14239
+ cb(Kcur, "Kcur", il);
14240
+ cb(Vcur, "Vcur", il);
14241
+
14242
+ cur = build_attn(inp_attn, gf,
14243
+ model.layers[il].wo, NULL,
14244
+ Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
14245
+ }
14246
+
14247
+ if (il == n_layer - 1) {
14248
+ // skip computing output for unused tokens
14249
+ ggml_tensor * inp_out_ids = build_inp_out_ids();
14250
+ cur = ggml_get_rows(ctx0, cur, inp_out_ids);
14251
+ inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
14252
+ }
14253
+
14254
+ ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
14255
+ cb(ffn_inp, "ffn_inp", il);
14256
+
14257
+ // feed-forward network
14258
+ {
14259
+ cur = build_norm(ffn_inp,
14260
+ model.layers[il].ffn_norm, NULL,
14261
+ LLM_NORM_RMS, il);
14262
+ cb(cur, "ffn_norm", il);
14263
+
14264
+ cur = build_ffn(cur,
14265
+ model.layers[il].ffn_up, NULL, NULL,
14266
+ model.layers[il].ffn_gate, NULL, NULL,
14267
+ model.layers[il].ffn_down, NULL, NULL,
14268
+ NULL,
14269
+ LLM_FFN_SILU, LLM_FFN_PAR, il);
14270
+ cb(cur, "ffn_out", il);
14271
+ }
14272
+
14273
+ cur = ggml_add(ctx0, cur, ffn_inp);
14274
+
14275
+ cur = build_cvec(cur, il);
14276
+ cb(cur, "l_out", il);
14277
+
14278
+ // input for next layer
14279
+ inpL = cur;
14280
+ }
14281
+
14282
+ cur = inpL;
14283
+
14284
+ cur = build_norm(cur,
14285
+ model.output_norm, NULL,
14286
+ LLM_NORM_RMS, -1);
14287
+
14288
+ cb(cur, "result_norm", -1);
14289
+ res->t_embd = cur;
14290
+
14291
+ // lm_head
14292
+ cur = build_lora_mm(model.output, cur);
14293
+
14294
+ cb(cur, "result_output", -1);
14295
+ res->t_logits = cur;
14296
+
14297
+ ggml_build_forward_expand(gf, cur);
14298
+ }
14299
+ };
14300
+
14301
  struct llm_build_arcee : public llm_graph_context {
14302
  llm_build_arcee(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
14303
  const int64_t n_embd_head = hparams.n_embd_head_v;
 
14659
  {
14660
  llm = std::make_unique<llm_build_gemma3_iswa>(*this, params, gf);
14661
  } break;
14662
+ case LLM_ARCH_GEMMA3N:
14663
+ {
14664
+ llm = std::make_unique<llm_build_gemma3n_iswa>(*this, params, gf);
14665
+ } break;
14666
  case LLM_ARCH_STARCODER2:
14667
  {
14668
  llm = std::make_unique<llm_build_starcoder2>(*this, params, gf);
 
14808
  {
14809
  llm = std::make_unique<llm_build_arcee>(*this, params, gf);
14810
  } break;
14811
+ case LLM_ARCH_ERNIE4_5:
14812
+ {
14813
+ llm = std::make_unique<llm_build_ernie4_5>(*this, params, gf);
14814
+ } break;
14815
  default:
14816
  GGML_ABORT("fatal error");
14817
  }
 
14963
  case LLM_ARCH_BAILINGMOE:
14964
  case LLM_ARCH_NEO_BERT:
14965
  case LLM_ARCH_ARCEE:
14966
+ case LLM_ARCH_ERNIE4_5:
14967
  return LLAMA_ROPE_TYPE_NORM;
14968
 
14969
  // the pairs of head values are offset by n_rot/2
 
14989
  case LLM_ARCH_GEMMA:
14990
  case LLM_ARCH_GEMMA2:
14991
  case LLM_ARCH_GEMMA3:
14992
+ case LLM_ARCH_GEMMA3N:
14993
  case LLM_ARCH_STARCODER2:
14994
  case LLM_ARCH_OPENELM:
14995
  case LLM_ARCH_GPTNEOX:
 
15072
  // do not extend this list unless absolutely necessary
15073
  // Mistral-Small-2503 does not have built-in chat template
15074
  llama_vocab_pre_type pre_type = model->vocab.get_pre_type();
15075
+ if (!name && pre_type == LLAMA_VOCAB_PRE_TYPE_TEKKEN && model->layers.size() == 40) {
15076
  return "mistral-v7-tekken";
15077
  }
15078
 
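The last hunk in llama-model.cpp only widens the internal fallback inside the chat-template lookup: when no named template is requested, the tokenizer pre-type is TEKKEN and the model has 40 layers, the lookup now returns "mistral-v7-tekken" even though Mistral-Small-2503 ships no built-in template. Caller-side usage is unchanged; a small sketch under that assumption (the helper name and error handling are illustrative):

    #include "llama.h"
    #include <cstdio>

    // returns the chat template string the runtime would use, or nullptr if none is available
    static const char * pick_chat_template(const struct llama_model * model) {
        // passing nullptr asks for the default template; the fallback in the hunk
        // above applies exactly to this case
        const char * tmpl = llama_model_chat_template(model, /*name=*/ nullptr);
        if (tmpl == nullptr) {
            std::fprintf(stderr, "model has no chat template, supply one manually\n");
        }
        return tmpl;
    }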
examples/talk-llama/llama-model.h CHANGED
@@ -39,6 +39,7 @@ enum llm_type {
39
  LLM_TYPE_475M,
40
  LLM_TYPE_770M,
41
  LLM_TYPE_780M,
 
42
  LLM_TYPE_0_5B,
43
  LLM_TYPE_0_6B,
44
  LLM_TYPE_1B,
@@ -95,6 +96,8 @@ enum llm_type {
95
  LLM_TYPE_17B_128E, // llama4 Maverick
96
  LLM_TYPE_30B_A3B,
97
  LLM_TYPE_235B_A22B,
 
 
98
  };
99
 
100
  std::string llama_rope_scaling_type_name(llama_rope_scaling_type rope_scaling_type);
@@ -316,6 +319,19 @@ struct llama_layer {
316
  struct ggml_tensor * ffn_up_scale = nullptr;
317
  struct ggml_tensor * ffn_down_scale = nullptr;
318
 
 
 
 
 
 
 
 
 
 
 
 
 
 
319
  struct llama_layer_posnet posnet;
320
 
321
  struct llama_layer_convnext convnext;
@@ -354,6 +370,13 @@ struct llama_model {
354
  struct ggml_tensor * conv1d = nullptr;
355
  struct ggml_tensor * conv1d_b = nullptr;
356
 
 
 
 
 
 
 
 
357
  std::vector<llama_layer> layers;
358
 
359
  llama_model_params params;
 
39
  LLM_TYPE_475M,
40
  LLM_TYPE_770M,
41
  LLM_TYPE_780M,
42
+ LLM_TYPE_0_3B,
43
  LLM_TYPE_0_5B,
44
  LLM_TYPE_0_6B,
45
  LLM_TYPE_1B,
 
96
  LLM_TYPE_17B_128E, // llama4 Maverick
97
  LLM_TYPE_30B_A3B,
98
  LLM_TYPE_235B_A22B,
99
+ LLM_TYPE_E2B,
100
+ LLM_TYPE_E4B,
101
  };
102
 
103
  std::string llama_rope_scaling_type_name(llama_rope_scaling_type rope_scaling_type);
 
319
  struct ggml_tensor * ffn_up_scale = nullptr;
320
  struct ggml_tensor * ffn_down_scale = nullptr;
321
 
322
+ // altup & laurel
323
+ struct ggml_tensor * per_layer_inp_gate = nullptr;
324
+ struct ggml_tensor * per_layer_proj = nullptr;
325
+ struct ggml_tensor * per_layer_post_norm = nullptr;
326
+ struct ggml_tensor * altup_correct_coef = nullptr;
327
+ struct ggml_tensor * altup_correct_scale = nullptr;
328
+ struct ggml_tensor * altup_predict_coef = nullptr;
329
+ struct ggml_tensor * altup_router = nullptr;
330
+ struct ggml_tensor * altup_router_norm = nullptr;
331
+ struct ggml_tensor * laurel_l = nullptr;
332
+ struct ggml_tensor * laurel_r = nullptr;
333
+ struct ggml_tensor * laurel_post_norm = nullptr;
334
+
335
  struct llama_layer_posnet posnet;
336
 
337
  struct llama_layer_convnext convnext;
 
370
  struct ggml_tensor * conv1d = nullptr;
371
  struct ggml_tensor * conv1d_b = nullptr;
372
 
373
+ // gemma3n altup
374
+ struct ggml_tensor * tok_embd_per_layer = nullptr;
375
+ struct ggml_tensor * altup_proj = nullptr;
376
+ struct ggml_tensor * altup_unembd_proj = nullptr;
377
+ struct ggml_tensor * per_layer_model_proj = nullptr;
378
+ struct ggml_tensor * per_layer_proj_norm = nullptr;
379
+
380
  std::vector<llama_layer> layers;
381
 
382
  llama_model_params params;
examples/talk-llama/llama-quant.cpp CHANGED
@@ -1,5 +1,4 @@
1
  #include "llama-quant.h"
2
-
3
  #include "llama-impl.h"
4
  #include "llama-model.h"
5
  #include "llama-model-loader.h"
@@ -27,6 +26,56 @@ static void zeros(std::ofstream & file, size_t n) {
27
  }
28
  }
29
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
  struct quantize_state_impl {
31
  const llama_model & model;
32
  const llama_model_quantize_params * params;
@@ -174,7 +223,7 @@ static ggml_type llama_tensor_get_type(quantize_state_impl & qs, ggml_type new_t
174
  new_type = GGML_TYPE_Q6_K;
175
  }
176
  }
177
- } else if (name == "token_embd.weight") {
178
  if (qs.params->token_embedding_type < GGML_TYPE_COUNT) {
179
  new_type = qs.params->token_embedding_type;
180
  } else {
@@ -568,6 +617,11 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
568
  const size_t align = GGUF_DEFAULT_ALIGNMENT;
569
  gguf_context_ptr ctx_out { gguf_init_empty() };
570
 
 
 
 
 
 
571
  // copy the KV pairs from the input file
572
  gguf_set_kv (ctx_out.get(), ml.meta.get());
573
  gguf_set_val_u32(ctx_out.get(), "general.quantization_version", GGML_QNT_VERSION); // TODO: use LLM_KV
@@ -597,12 +651,32 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
597
  }
598
  }
599
 
 
 
 
 
600
  // make a list of weights
601
  std::vector<const llama_model_loader::llama_tensor_weight *> tensors;
602
  tensors.reserve(ml.weights_map.size());
603
  for (const auto & it : ml.weights_map) {
 
 
 
 
 
 
 
 
 
 
 
 
 
604
  tensors.push_back(&it.second);
605
  }
 
 
 
606
 
607
  // keep_split requires that the weights are sorted by split index
608
  if (params->keep_split) {
@@ -640,7 +714,7 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
640
  if (llama_model_has_encoder(&model)) {
641
  n_attn_layer *= 3;
642
  }
643
- GGML_ASSERT((qs.n_attention_wv == n_attn_layer) && "n_attention_wv is unexpected");
644
  }
645
 
646
  size_t total_size_org = 0;
@@ -681,7 +755,7 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
681
  for (size_t i = 0; i < ctx_outs.size(); ++i) {
682
  gguf_set_val_u16(ctx_outs[i].get(), ml.llm_kv(LLM_KV_SPLIT_NO).c_str(), i);
683
  gguf_set_val_u16(ctx_outs[i].get(), ml.llm_kv(LLM_KV_SPLIT_COUNT).c_str(), n_split);
684
- gguf_set_val_i32(ctx_outs[i].get(), ml.llm_kv(LLM_KV_SPLIT_TENSORS_COUNT).c_str(), ml.n_tensors);
685
  }
686
  }
687
 
@@ -756,6 +830,13 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
756
  // NOTE: can't use LLM_TN here because the layer number is not known
757
  quantize &= name.find("ffn_gate_inp.weight") == std::string::npos;
758
 
 
 
 
 
 
 
 
759
  // do not quantize positional embeddings and token types (BERT)
760
  quantize &= name != LLM_TN(model.arch)(LLM_TENSOR_POS_EMBD, "weight");
761
  quantize &= name != LLM_TN(model.arch)(LLM_TENSOR_TOKEN_TYPES, "weight");
@@ -832,7 +913,7 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
832
 
833
  const float * imatrix = nullptr;
834
  if (imatrix_data) {
835
- auto it = imatrix_data->find(tensor->name);
836
  if (it == imatrix_data->end()) {
837
  LLAMA_LOG_INFO("\n====== %s: did not find weights for %s\n", __func__, tensor->name);
838
  } else {
@@ -947,6 +1028,7 @@ llama_model_quantize_params llama_model_quantize_default_params() {
947
  /*.imatrix =*/ nullptr,
948
  /*.kv_overrides =*/ nullptr,
949
  /*.tensor_type =*/ nullptr,
 
950
  };
951
 
952
  return result;
 
1
  #include "llama-quant.h"
 
2
  #include "llama-impl.h"
3
  #include "llama-model.h"
4
  #include "llama-model-loader.h"
 
26
  }
27
  }
28
 
29
+ static std::string remap_layer(const std::string & orig_name, const std::vector<int> & prune, std::map<int, std::string> & mapped, int & next_id) {
30
+ if (prune.empty()) {
31
+ return orig_name;
32
+ }
33
+
34
+ static const std::regex pattern(R"(blk\.(\d+)\.)");
35
+ if (std::smatch match; std::regex_search(orig_name, match, pattern)) {
36
+ const int blk = std::stoi(match[1]);
37
+ std::string new_name = orig_name;
38
+
39
+ if (mapped.count(blk)) {
40
+ // Already mapped, do nothing
41
+ } else if (std::find(prune.begin(), prune.end(), blk) != prune.end()) {
42
+ mapped[blk] = "";
43
+ } else if (blk < prune.front()) {
44
+ mapped[blk] = std::to_string(blk);
45
+ next_id = blk + 1;
46
+ } else {
47
+ mapped[blk] = std::to_string(next_id);
48
+ ++next_id;
49
+ }
50
+
51
+ return mapped[blk].empty() ? mapped[blk] : new_name.replace(match.position(1), match.length(1), mapped[blk]);
52
+ }
53
+
54
+ return orig_name;
55
+ }
56
+
57
+ static std::string remap_imatrix (const std::string & orig_name, const std::map<int, std::string> & mapped) {
58
+ if (mapped.empty()) {
59
+ return orig_name;
60
+ }
61
+
62
+ static const std::regex pattern(R"(blk\.(\d+)\.)");
63
+ if (std::smatch match; std::regex_search(orig_name, match, pattern)) {
64
+ const std::string blk(match[1]);
65
+ std::string new_name = orig_name;
66
+
67
+ for (const auto & p : mapped) {
68
+ if (p.second == blk) {
69
+ LLAMA_LOG_DEBUG("(blk.%d imatrix) ", p.first);
70
+ return new_name.replace(match.position(1), match.length(1), std::to_string(p.first));
71
+ }
72
+ }
73
+ GGML_ABORT("\n%s: imatrix mapping error for %s\n", __func__, orig_name.c_str());
74
+ }
75
+
76
+ return orig_name;
77
+ }
78
+
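To make the new pruning remap concrete: blocks below the first pruned index keep their number, pruned blocks are dropped, and every later block is shifted down so the numbering stays dense (the imatrix lookup is then remapped in the opposite direction). The following self-contained illustration reproduces that renaming rule for prune = {4, 5}; it re-implements the regex logic for demonstration only and is not the remap_layer() above.

    #include <algorithm>
    #include <cstdio>
    #include <regex>
    #include <string>
    #include <vector>

    int main() {
        const std::vector<int> prune = {4, 5}; // layer indices to drop (example values)
        const std::vector<std::string> names = {
            "blk.3.attn_q.weight",  // kept, stays blk.3
            "blk.4.attn_q.weight",  // pruned, dropped entirely
            "blk.6.attn_q.weight",  // kept, renumbered to blk.4
        };

        const std::regex pattern(R"(blk\.(\d+)\.)");
        for (const auto & name : names) {
            std::smatch m;
            std::regex_search(name, m, pattern);
            const int blk = std::stoi(m[1]);
            if (std::find(prune.begin(), prune.end(), blk) != prune.end()) {
                std::printf("%-22s -> (pruned)\n", name.c_str());
                continue;
            }
            // shift the block id down by the number of pruned blocks below it
            const int shift = (int) std::count_if(prune.begin(), prune.end(),
                                                  [&](int p) { return p < blk; });
            std::string out = name;
            out.replace(m.position(1), m.length(1), std::to_string(blk - shift));
            std::printf("%-22s -> %s\n", name.c_str(), out.c_str());
        }
        return 0;
    }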
79
  struct quantize_state_impl {
80
  const llama_model & model;
81
  const llama_model_quantize_params * params;
 
223
  new_type = GGML_TYPE_Q6_K;
224
  }
225
  }
226
+ } else if (name == "token_embd.weight" || name == "per_layer_token_embd.weight") {
227
  if (qs.params->token_embedding_type < GGML_TYPE_COUNT) {
228
  new_type = qs.params->token_embedding_type;
229
  } else {
 
617
  const size_t align = GGUF_DEFAULT_ALIGNMENT;
618
  gguf_context_ptr ctx_out { gguf_init_empty() };
619
 
620
+ std::vector<int> prune_list = {};
621
+ if (params->prune_layers) {
622
+ prune_list = *static_cast<const std::vector<int> *>(params->prune_layers);
623
+ }
624
+
625
  // copy the KV pairs from the input file
626
  gguf_set_kv (ctx_out.get(), ml.meta.get());
627
  gguf_set_val_u32(ctx_out.get(), "general.quantization_version", GGML_QNT_VERSION); // TODO: use LLM_KV
 
651
  }
652
  }
653
 
654
+ std::map<int, std::string> mapped;
655
+ int blk_id = 0;
656
+ int pruned_attention_w = 0;
657
+
658
  // make a list of weights
659
  std::vector<const llama_model_loader::llama_tensor_weight *> tensors;
660
  tensors.reserve(ml.weights_map.size());
661
  for (const auto & it : ml.weights_map) {
662
+ const std::string remapped_name(remap_layer(it.first, prune_list, mapped, blk_id));
663
+ if (remapped_name.empty()) {
664
+ if (it.first.find("attn_v.weight") != std::string::npos ||
665
+ it.first.find("attn_qkv.weight") != std::string::npos ||
666
+ it.first.find("attn_kv_b.weight") != std::string::npos) {
667
+ pruned_attention_w++;
668
+ }
669
+ LLAMA_LOG_DEBUG("%s: pruning tensor %s\n", __func__, it.first.c_str());
670
+ continue;
671
+ } else if (remapped_name != it.first) {
672
+ ggml_set_name(it.second.tensor, remapped_name.c_str());
673
+ LLAMA_LOG_DEBUG("%s: tensor %s remapped to %s\n", __func__, it.first.c_str(), ggml_get_name(it.second.tensor));
674
+ }
675
  tensors.push_back(&it.second);
676
  }
677
+ if (!prune_list.empty()) {
678
+ gguf_set_val_u32(ctx_out.get(), ml.llm_kv(LLM_KV_BLOCK_COUNT).c_str(), blk_id);
679
+ }
680
 
681
  // keep_split requires that the weights are sorted by split index
682
  if (params->keep_split) {
 
714
  if (llama_model_has_encoder(&model)) {
715
  n_attn_layer *= 3;
716
  }
717
+ GGML_ASSERT((qs.n_attention_wv == n_attn_layer - pruned_attention_w) && "n_attention_wv is unexpected");
718
  }
719
 
720
  size_t total_size_org = 0;
 
755
  for (size_t i = 0; i < ctx_outs.size(); ++i) {
756
  gguf_set_val_u16(ctx_outs[i].get(), ml.llm_kv(LLM_KV_SPLIT_NO).c_str(), i);
757
  gguf_set_val_u16(ctx_outs[i].get(), ml.llm_kv(LLM_KV_SPLIT_COUNT).c_str(), n_split);
758
+ gguf_set_val_i32(ctx_outs[i].get(), ml.llm_kv(LLM_KV_SPLIT_TENSORS_COUNT).c_str(), (int32_t)tensors.size());
759
  }
760
  }
761
 
 
830
  // NOTE: can't use LLM_TN here because the layer number is not known
831
  quantize &= name.find("ffn_gate_inp.weight") == std::string::npos;
832
 
833
+ // these are very small (e.g. 4x4)
834
+ quantize &= name.find("altup") == std::string::npos;
835
+ quantize &= name.find("laurel") == std::string::npos;
836
+
837
+ // these are not too big, so keep them as they are
838
+ quantize &= name.find("per_layer_model_proj") == std::string::npos;
839
+
840
  // do not quantize positional embeddings and token types (BERT)
841
  quantize &= name != LLM_TN(model.arch)(LLM_TENSOR_POS_EMBD, "weight");
842
  quantize &= name != LLM_TN(model.arch)(LLM_TENSOR_TOKEN_TYPES, "weight");
 
913
 
914
  const float * imatrix = nullptr;
915
  if (imatrix_data) {
916
+ auto it = imatrix_data->find(remap_imatrix(tensor->name, mapped));
917
  if (it == imatrix_data->end()) {
918
  LLAMA_LOG_INFO("\n====== %s: did not find weights for %s\n", __func__, tensor->name);
919
  } else {
 
1028
  /*.imatrix =*/ nullptr,
1029
  /*.kv_overrides =*/ nullptr,
1030
  /*.tensor_type =*/ nullptr,
1031
+ /*.prune_layers =*/ nullptr
1032
  };
1033
 
1034
  return result;
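Putting the llama-quant.cpp changes together: layer pruning is driven by the new prune_layers field, which the code above casts to a const std::vector<int> * holding the block indices to drop. A minimal caller sketch under that assumption; the file names and the chosen ftype are illustrative.

    #include "llama.h"
    #include <vector>

    int main() {
        std::vector<int> prune = {20, 21}; // layer indices to drop (example values)

        llama_model_quantize_params qparams = llama_model_quantize_default_params();
        qparams.ftype        = LLAMA_FTYPE_MOSTLY_Q4_K_M; // target quantization type
        qparams.prune_layers = &prune;                    // new field added in this sync

        // returns 0 on success
        return (int) llama_model_quantize("model-f16.gguf", "model-q4_k_m-pruned.gguf", &qparams);
    }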
examples/talk-llama/llama.h CHANGED
@@ -390,6 +390,7 @@ extern "C" {
390
  void * imatrix; // pointer to importance matrix data
391
  void * kv_overrides; // pointer to vector containing overrides
392
  void * tensor_types; // pointer to vector containing tensor types
 
393
  } llama_model_quantize_params;
394
 
395
  typedef struct llama_logit_bias {
@@ -943,12 +944,14 @@ extern "C" {
943
  // Requires the context to have a memory.
944
  // For encode-decoder contexts, processes the batch using the decoder.
945
  // Positive return values do not mean a fatal error, but rather a warning.
946
- // Upon non-zero return values, the memory state is restored to the state before this call
 
 
947
  // 0 - success
948
  // 1 - could not find a KV slot for the batch (try reducing the size of the batch or increase the context)
949
- // 2 - aborted
950
  // -1 - invalid input batch
951
- // < -1 - error
952
  LLAMA_API int32_t llama_decode(
953
  struct llama_context * ctx,
954
  struct llama_batch batch);
 
390
  void * imatrix; // pointer to importance matrix data
391
  void * kv_overrides; // pointer to vector containing overrides
392
  void * tensor_types; // pointer to vector containing tensor types
393
+ void * prune_layers; // pointer to vector containing layer indices to prune
394
  } llama_model_quantize_params;
395
 
396
  typedef struct llama_logit_bias {
 
944
  // Requires the context to have a memory.
945
  // For encode-decoder contexts, processes the batch using the decoder.
946
  // Positive return values do not mean a fatal error, but rather a warning.
947
+ // Upon a fatal error or abort, the ubatches that managed to be processed will remain in the memory state of the context
948
+ // To handle this correctly, query the memory state using llama_memory_seq_pos_min() and llama_memory_seq_pos_max()
949
+ // Upon other return values, the memory state is restored to the state before this call
950
  // 0 - success
951
  // 1 - could not find a KV slot for the batch (try reducing the size of the batch or increase the context)
952
+ // 2 - aborted (processed ubatches will remain in the context's memory)
953
  // -1 - invalid input batch
954
+ // < -1 - fatal error (processed ubatches will remain in the context's memory)
955
  LLAMA_API int32_t llama_decode(
956
  struct llama_context * ctx,
957
  struct llama_batch batch);
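The revised llama_decode() contract distinguishes recoverable outcomes (1 and -1 restore the memory state) from aborts and fatal errors (2 and < -1 leave already-processed ubatches in memory). A sketch of a caller that honours it; it assumes the llama_get_memory() / llama_memory_seq_pos_min() / llama_memory_seq_pos_max() accessors from this version of llama.h and uses sequence id 0 purely as an example.

    #include "llama.h"
    #include <cstdio>

    // returns true if the batch was fully processed
    static bool decode_checked(llama_context * ctx, llama_batch batch) {
        const int32_t ret = llama_decode(ctx, batch);
        if (ret == 0) {
            return true; // success
        }
        if (ret == 1) {
            // no KV slot: reduce the batch or enlarge the context, then retry
            return false;
        }
        if (ret == -1) {
            std::fprintf(stderr, "invalid input batch\n");
            return false;
        }
        // ret == 2 (aborted) or ret < -1 (fatal): some ubatches may already be in memory,
        // so inspect what actually landed before deciding how to recover
        llama_memory_t mem = llama_get_memory(ctx);
        const llama_pos p_min = llama_memory_seq_pos_min(mem, /*seq_id=*/ 0);
        const llama_pos p_max = llama_memory_seq_pos_max(mem, /*seq_id=*/ 0);
        std::fprintf(stderr, "decode failed (%d), seq 0 now spans positions [%d, %d]\n",
                     ret, (int) p_min, (int) p_max);
        return false;
    }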