ggerganov committed
Commit ade9bc3 · 1 Parent(s): 48a7292

talk-llama : sync llama.cpp

examples/talk-llama/CMakeLists.txt CHANGED
@@ -18,7 +18,8 @@ if (WHISPER_SDL2)
     llama-io.cpp
     llama-kv-cache-unified.cpp
    llama-kv-cache-unified-iswa.cpp
-    llama-kv-cache-recurrent.cpp
+    llama-memory-recurrent.cpp
+    llama-memory-hybrid.cpp
     llama-memory.cpp
     llama-mmap.cpp
     llama-model-loader.cpp
examples/talk-llama/llama-arch.cpp CHANGED
@@ -147,6 +147,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
     { LLM_KV_ATTENTION_SCALE, "%s.attention.scale" },
     { LLM_KV_ATTENTION_KEY_LENGTH_MLA, "%s.attention.key_length_mla" },
     { LLM_KV_ATTENTION_VALUE_LENGTH_MLA, "%s.attention.value_length_mla" },
+    { LLM_KV_ATTENTION_LAYER_INDICES, "%s.attention.layer_indices" },
 
     { LLM_KV_ROPE_DIMENSION_COUNT, "%s.rope.dimension_count" },
     { LLM_KV_ROPE_DIMENSION_SECTIONS, "%s.rope.dimension_sections" },
@@ -197,6 +198,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
     { LLM_KV_TOKENIZER_MASK_ID, "tokenizer.ggml.mask_token_id" },
     { LLM_KV_TOKENIZER_ADD_BOS, "tokenizer.ggml.add_bos_token" },
     { LLM_KV_TOKENIZER_ADD_EOS, "tokenizer.ggml.add_eos_token" },
+    { LLM_KV_TOKENIZER_ADD_SEP, "tokenizer.ggml.add_sep_token" },
     { LLM_KV_TOKENIZER_ADD_PREFIX, "tokenizer.ggml.add_space_prefix" },
     { LLM_KV_TOKENIZER_REMOVE_EXTRA_WS, "tokenizer.ggml.remove_extra_whitespaces" },
     { LLM_KV_TOKENIZER_PRECOMPILED_CHARSMAP, "tokenizer.ggml.precompiled_charsmap" },
@@ -1816,3 +1818,25 @@ llm_arch llm_arch_from_string(const std::string & name) {
 const llm_tensor_info & llm_tensor_info_for(llm_tensor tensor) {
     return LLM_TENSOR_INFOS.at(tensor);
 }
+
+bool llm_arch_is_recurrent(const llm_arch & arch) {
+    switch (arch) {
+        case LLM_ARCH_MAMBA:
+        case LLM_ARCH_RWKV6:
+        case LLM_ARCH_RWKV6QWEN2:
+        case LLM_ARCH_RWKV7:
+        case LLM_ARCH_ARWKV7:
+            return true;
+        default:
+            return false;
+    }
+}
+
+bool llm_arch_is_hybrid(const llm_arch & arch) {
+    // TODO: There are currently no hybrid models! Once there are, this will be
+    // the place to identify them
+    switch (arch) {
+        default:
+            return false;
+    }
+}
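
The new llm_arch_is_recurrent()/llm_arch_is_hybrid() helpers give callers a single place to ask whether an architecture needs recurrent (or hybrid) state instead of a regular KV cache. Below is a minimal standalone sketch of the dispatch pattern such a helper enables; the arch_t enum and the memory-type messages are illustrative placeholders, not the actual llama.cpp API.

#include <cstdio>

// placeholder architecture ids, standing in for llm_arch (assumption, not the real enum)
enum class arch_t { llama, mamba, rwkv7 };

// same shape as the new llm_arch_is_recurrent() helper
static bool arch_is_recurrent(arch_t arch) {
    switch (arch) {
        case arch_t::mamba:
        case arch_t::rwkv7:
            return true;
        default:
            return false;
    }
}

int main() {
    const arch_t arch = arch_t::mamba;

    // callers can branch once on the helper instead of duplicating
    // per-architecture switch statements at every allocation site
    if (arch_is_recurrent(arch)) {
        std::printf("allocate recurrent state memory\n");
    } else {
        std::printf("allocate unified KV cache\n");
    }
    return 0;
}

Centralizing the check keeps the per-architecture knowledge in llama-arch and out of the memory-allocation call sites.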
examples/talk-llama/llama-arch.h CHANGED
@@ -151,6 +151,7 @@ enum llm_kv {
     LLM_KV_ATTENTION_SCALE,
     LLM_KV_ATTENTION_KEY_LENGTH_MLA,
     LLM_KV_ATTENTION_VALUE_LENGTH_MLA,
+    LLM_KV_ATTENTION_LAYER_INDICES,
 
     LLM_KV_ROPE_DIMENSION_COUNT,
     LLM_KV_ROPE_DIMENSION_SECTIONS,
@@ -193,6 +194,7 @@ enum llm_kv {
     LLM_KV_TOKENIZER_MASK_ID,
     LLM_KV_TOKENIZER_ADD_BOS,
     LLM_KV_TOKENIZER_ADD_EOS,
+    LLM_KV_TOKENIZER_ADD_SEP,
     LLM_KV_TOKENIZER_ADD_PREFIX,
     LLM_KV_TOKENIZER_REMOVE_EXTRA_WS,
     LLM_KV_TOKENIZER_PRECOMPILED_CHARSMAP,
@@ -439,3 +441,6 @@ const char * llm_arch_name(llm_arch arch);
 llm_arch llm_arch_from_string(const std::string & name);
 
 const llm_tensor_info & llm_tensor_info_for(llm_tensor tensor);
+
+bool llm_arch_is_recurrent(const llm_arch & arch);
+bool llm_arch_is_hybrid   (const llm_arch & arch);
examples/talk-llama/llama-batch.cpp CHANGED
@@ -1,7 +1,6 @@
1
  #include "llama-batch.h"
2
 
3
  #include "llama-impl.h"
4
- #include "llama-cparams.h"
5
  #include "llama-vocab.h"
6
  #include "llama-memory.h"
7
 
@@ -10,282 +9,7 @@
10
  #include <algorithm>
11
  #include <sstream>
12
 
13
- llama_ubatch llama_sbatch::reserve_ubatch(size_t n_ubatch, bool has_embd) {
14
- // clear empty sequences
15
- // the previous ubatch is assumed to be gone,
16
- // so nothing should refer to values in these sequences anymore.
17
- for (size_t i = seq.size(); i-- > 0;) {
18
- if (seq[i].length == 0) {
19
- seq.pop_back();
20
- } else {
21
- break;
22
- }
23
- }
24
-
25
- udatas.push_back({});
26
-
27
- auto & udata = udatas.back();
28
-
29
- udata.token.resize(!has_embd ? n_ubatch : 0);
30
- udata.embd.resize(has_embd ? n_embd * n_ubatch : 0);
31
- udata.pos.resize(n_ubatch);
32
- udata.n_seq_id.resize(n_ubatch);
33
- udata.seq_id.resize(n_ubatch);
34
- udata.output.resize(n_ubatch);
35
-
36
- llama_ubatch ubatch = {
37
- /*equal_seqs =*/ true,
38
- /*n_tokens =*/ 0,
39
- /*n_seq_tokens =*/ 0,
40
- /*n_seqs =*/ 0,
41
- /*token =*/ !has_embd ? udata.token.data() : nullptr,
42
- /*embd =*/ has_embd ? udata.embd.data() : nullptr,
43
- /*pos =*/ udata.pos.data(),
44
- /*n_seq_id =*/ udata.n_seq_id.data(),
45
- /*seq_id =*/ udata.seq_id.data(),
46
- /*output =*/ udata.output.data(),
47
- };
48
-
49
- return ubatch;
50
- }
51
-
52
- void llama_sbatch::add_seq_to_ubatch(llama_ubatch & ubatch, llama_sbatch_seq & seq, size_t length) {
53
- GGML_ASSERT(batch != nullptr);
54
- GGML_ASSERT(length <= seq.length);
55
- // Can only add sequences of equal lengths to a batch,
56
- // otherwise it isn't clear to which sequence a token belongs
57
- GGML_ASSERT(seq.n_seq_id == 0 || ubatch.n_seqs == 0 || length == (size_t) ubatch.n_tokens / ubatch.n_seqs);
58
- GGML_ASSERT((seq.n_seq_id != 0) == ubatch.equal_seqs);
59
- // NOTE: loops are separated for cache-friendliness
60
- if (batch->token) {
61
- if (ubatch.equal_seqs) {
62
- for (size_t i = 0; i < length; ++i) {
63
- ubatch.token[ubatch.n_tokens + i] = batch->token[ids[seq.offset + i]];
64
- }
65
- } else {
66
- // simple split
67
- ubatch.token = batch->token + seq.offset;
68
- }
69
- } else {
70
- ubatch.token = nullptr;
71
- }
72
- if (batch->embd) {
73
- if (ubatch.equal_seqs) {
74
- for (size_t i = 0; i < length; ++i) {
75
- memcpy(
76
- ubatch.embd + (n_embd * (ubatch.n_tokens + i)),
77
- batch->embd + (n_embd * ids[seq.offset + i]),
78
- n_embd * sizeof(float)
79
- );
80
- }
81
- } else {
82
- // simple split
83
- ubatch.embd = batch->embd + (n_embd * seq.offset);
84
- }
85
- } else {
86
- ubatch.embd = nullptr;
87
- }
88
- if (ubatch.equal_seqs) {
89
- for (size_t i = 0; i < length; ++i) {
90
- ubatch.pos[ubatch.n_tokens + i] = batch->pos[ids[seq.offset + i]];
91
- }
92
- } else {
93
- // simple split
94
- ubatch.pos = batch->pos + seq.offset;
95
- }
96
- if (ubatch.equal_seqs) {
97
- ubatch.n_seq_id[ubatch.n_seqs] = seq.n_seq_id;
98
- if (seq.seq_id) {
99
- ubatch.seq_id[ubatch.n_seqs] = seq.seq_id;
100
- }
101
- } else {
102
- // simple split
103
- if (batch->n_seq_id) {
104
- ubatch.n_seq_id = batch->n_seq_id + seq.offset;
105
- } else {
106
- for (size_t i = 0; i < length; ++i) {
107
- ubatch.n_seq_id[ubatch.n_seqs + i] = 1;
108
- }
109
- }
110
- if (batch->seq_id) {
111
- ubatch.seq_id = batch->seq_id + seq.offset;
112
- }
113
- }
114
- if (batch->logits) {
115
- if (ubatch.equal_seqs) {
116
- for (size_t i = 0; i < length; ++i) {
117
- size_t id = ids[seq.offset + i];
118
- int8_t is_output = batch->logits[id];
119
- ubatch.output[ubatch.n_tokens + i] = is_output;
120
- if (is_output) { out_ids.push_back(id); }
121
- }
122
- } else {
123
- // simple split
124
- ubatch.output = batch->logits + seq.offset;
125
- for (size_t i = 0; i < length; ++i) {
126
- if (ubatch.output[i] != 0) { out_ids.push_back(seq.offset + i); }
127
- }
128
- }
129
- } else {
130
- // only get last output
131
- for (size_t i = 0; i < length; ++i) {
132
- size_t id = ids[seq.offset + i];
133
- int8_t is_last = id == ids.size() - 1;
134
- ubatch.output[ubatch.n_tokens + i] = is_last;
135
- if (is_last) { out_ids.push_back(id); }
136
- }
137
- }
138
- if (ubatch.n_tokens == 0 && ubatch.n_seqs == 0) {
139
- ubatch.n_seq_tokens = ubatch.equal_seqs ? length : 1;
140
- }
141
- ubatch.n_tokens += length;
142
- ubatch.n_seqs += ubatch.equal_seqs ? 1 : length; // virtual sequences for simple splits
143
- seq.offset += length;
144
- seq.length -= length;
145
- n_tokens -= length;
146
- GGML_ASSERT(ubatch.n_tokens == ubatch.n_seq_tokens * ubatch.n_seqs);
147
- }
148
-
149
- llama_ubatch llama_sbatch::split_simple(size_t n_ubatch) {
150
- n_ubatch = n_tokens < n_ubatch ? n_tokens : n_ubatch;
151
- llama_ubatch ubatch = reserve_ubatch(n_ubatch, /* has_embd */ batch->embd != nullptr);
152
- ubatch.equal_seqs = false;
153
- if (!seq.empty()) {
154
- llama_sbatch_seq & s = seq[0];
155
- size_t length = s.length < n_ubatch ? s.length : n_ubatch;
156
- GGML_ASSERT(seq.size() == 1 && s.n_seq_id == 0); // don't mix with other splits
157
- add_seq_to_ubatch(ubatch, s, length);
158
- }
159
- return ubatch;
160
- }
161
-
162
- llama_ubatch llama_sbatch::split_equal(size_t n_ubatch) {
163
- n_ubatch = n_tokens < n_ubatch ? n_tokens : n_ubatch;
164
- llama_ubatch ubatch = reserve_ubatch(n_ubatch, /* has_embd */ batch->embd != nullptr);
165
- if (!seq.empty()) {
166
- size_t length = 0;
167
- size_t n_tokens_in_ubatch = 0;
168
- GGML_ASSERT(seq[0].n_seq_id > 0); // should not be mixed with simple splits
169
- // smallest first, because it's easier to split this way;
170
- // starting from the end to pop in constant time.
171
- for (size_t i = seq.size(); i-- > 0;) {
172
- llama_sbatch_seq & s = seq[i];
173
- GGML_ASSERT(s.length > 0);
174
- if (length == 0) {
175
- length = s.length < n_ubatch ? s.length : n_ubatch;
176
- }
177
- add_seq_to_ubatch(ubatch, s, length);
178
- n_tokens_in_ubatch += length;
179
- // shared prompts can't be mixed with any of their sequences,
180
- // so it's safer to compute them in their own ubatch
181
- if (s.n_seq_id > 1) { break; }
182
- // stop when there isn't enough space for another sequence
183
- if (length + n_tokens_in_ubatch > n_ubatch) { break; }
184
- }
185
- }
186
- return ubatch;
187
- }
188
-
189
- llama_ubatch llama_sbatch::split_seq(size_t n_ubatch) {
190
- n_ubatch = n_tokens < n_ubatch ? n_tokens : n_ubatch;
191
- llama_ubatch ubatch = reserve_ubatch(n_ubatch, /* has_embd */ batch->embd != nullptr);
192
- if (!seq.empty()) {
193
- llama_sbatch_seq & s = seq[seq.size() - 1];
194
- size_t length = s.length < n_ubatch ? s.length : n_ubatch;
195
- GGML_ASSERT(s.n_seq_id > 0); // should not be mixed with simple splits
196
- add_seq_to_ubatch(ubatch, s, length);
197
- }
198
- return ubatch;
199
- }
200
-
201
- llama_sbatch::llama_sbatch(const llama_batch & batch, size_t n_embd, bool simple_split) {
202
- GGML_ASSERT(batch.n_tokens >= 0);
203
- this->batch = &batch;
204
- this->n_embd = n_embd;
205
-
206
- n_tokens = batch.n_tokens;
207
- ids.resize(n_tokens);
208
- out_ids.clear();
209
- // TODO: reserve out_ids and seq
210
-
211
- for (size_t i = 0; i < n_tokens; ++i) {
212
- ids[i] = i;
213
- }
214
-
215
- if (simple_split) {
216
- seq.resize(1);
217
- llama_sbatch_seq & s = seq[0];
218
- s.n_seq_id = 0;
219
- s.seq_id = nullptr;
220
- s.offset = 0;
221
- s.length = n_tokens;
222
- return;
223
- }
224
-
225
- std::sort(ids.begin(), ids.end(),
226
- [&batch](size_t a, size_t b) {
227
- int32_t n_seq_a = batch.n_seq_id ? batch.n_seq_id[a] : 1;
228
- int32_t n_seq_b = batch.n_seq_id ? batch.n_seq_id[b] : 1;
229
- // sort by seq_id, then by pos
230
- if (n_seq_a == n_seq_b) {
231
- if (batch.seq_id) {
232
- for (int32_t i = 0; i < n_seq_a; ++i) {
233
- llama_seq_id seq_id_a = batch.seq_id[a][i];
234
- llama_seq_id seq_id_b = batch.seq_id[b][i];
235
- // smaller seq_ids go first
236
- if (seq_id_a != seq_id_b) {
237
- return seq_id_a < seq_id_b;
238
- }
239
- }
240
- }
241
- // when all else is equal, sort by pos
242
- if (batch.pos) {
243
- return batch.pos[a] < batch.pos[b];
244
- }
245
- // no pos, sort by id
246
- return a < b;
247
- }
248
- // shared prompts go first
249
- return n_seq_a > n_seq_b;
250
- }
251
- );
252
-
253
- // init seq
254
- llama_sbatch_seq * last_seq = nullptr;
255
-
256
- for (size_t i = 0; i < n_tokens; ++i) {
257
- const size_t bi = ids[i];
258
- const int32_t n_seqs = batch.n_seq_id[bi];
259
- llama_seq_id * seq_ids = batch.seq_id[bi];
260
- if (last_seq != nullptr) {
261
- bool same = n_seqs == last_seq->n_seq_id;
262
- for (int32_t j = 0; same && j < n_seqs; ++j) {
263
- if (seq_ids[j] != last_seq->seq_id[j]) {
264
- same = false;
265
- }
266
- }
267
- if (same) {
268
- last_seq->length += 1;
269
- continue;
270
- }
271
- }
272
- llama_sbatch_seq new_seq = {n_seqs, seq_ids, i, 1};
273
- seq.push_back(new_seq);
274
- last_seq = &seq.back();
275
- }
276
-
277
- // keep shared prompts first at the end, then sort by length descending.
278
- std::sort(seq.begin(), seq.end(),
279
- [](llama_sbatch_seq & a, llama_sbatch_seq & b) {
280
- if (a.n_seq_id == b.n_seq_id) {
281
- return a.length > b.length;
282
- }
283
- return a.n_seq_id < b.n_seq_id;
284
- }
285
- );
286
- }
287
-
288
- llama_batch_allocr::llama_batch_allocr() {
289
  const char * LLAMA_BATCH_DEBUG = getenv("LLAMA_BATCH_DEBUG");
290
  debug = LLAMA_BATCH_DEBUG ? atoi(LLAMA_BATCH_DEBUG) : 0;
291
 
@@ -294,17 +18,22 @@ llama_batch_allocr::llama_batch_allocr() {
294
  for (auto & cur : seq_cpl) {
295
  cur.resize(LLAMA_MAX_SEQ);
296
  }
 
 
297
  }
298
 
299
  bool llama_batch_allocr::init(
300
  const llama_batch & batch_inp,
301
  const llama_vocab & vocab,
302
  const llama_memory_i * memory,
303
- bool embd_all) {
 
304
  clear();
305
 
306
  batch = batch_inp;
307
 
 
 
308
  GGML_ASSERT(batch.n_tokens > 0);
309
 
310
  //
@@ -359,6 +88,7 @@ bool llama_batch_allocr::init(
359
  llama_pos p0[LLAMA_MAX_SEQ];
360
  for (int32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
361
  if (!memory) {
 
362
  p0[s] = 0;
363
  } else {
364
  p0[s] = memory->seq_pos_max(s) + 1;
@@ -370,8 +100,11 @@ bool llama_batch_allocr::init(
370
 
371
  pos[i] = p0[seq_id];
372
 
 
373
  for (int32_t s = 0; s < batch.n_seq_id[i]; ++s) {
374
- p0[batch.seq_id[i][s]] = pos[i] + 1;
 
 
375
  }
376
  }
377
 
@@ -379,7 +112,7 @@ bool llama_batch_allocr::init(
379
  }
380
 
381
  if (!batch.logits) {
382
- if (embd_all) {
383
  // return the output for all tokens
384
  output.resize(batch.n_tokens, true);
385
  } else {
@@ -389,7 +122,7 @@ bool llama_batch_allocr::init(
389
  }
390
 
391
  batch.logits = output.data();
392
- } else if (embd_all) {
393
  bool warn = false;
394
 
395
  for (int32_t i = 0; i < batch.n_tokens; ++i) {
@@ -410,6 +143,9 @@ bool llama_batch_allocr::init(
410
  // compute stats
411
  //
412
 
 
 
 
413
  for (int32_t i = 0; i < batch.n_tokens; ++i) {
414
  n_outputs += batch.logits[i] != 0;
415
  }
@@ -417,85 +153,86 @@ bool llama_batch_allocr::init(
417
  // determine coupled sequences
418
  // these are pairs of sequences that have at least one token in the input batch that is assigned to both of them
419
  for (int32_t i = 0; i < batch.n_tokens; ++i) {
 
 
420
  for (int32_t s = 0; s < batch.n_seq_id[i]; ++s) {
421
- seq_pos[batch.seq_id[i][s]].insert(batch.pos[i]);
422
 
423
- if (s > 0) {
424
- const llama_seq_id s0 = batch.seq_id[i][0];
425
- const llama_seq_id s1 = batch.seq_id[i][s];
426
 
 
427
  // mark that sequence s1 is coupled to s0
428
  seq_cpl[s1][s0] = true;
429
 
430
- // note: the other way around is not necessary for now
431
  //seq_cpl[s0][s1] = true;
432
  }
433
  }
434
  }
435
 
436
- if (debug > 0) {
437
- LLAMA_LOG_DEBUG("%s: input batch info:\n", __func__);
438
- LLAMA_LOG_DEBUG("%s: n_tokens = %d\n", __func__, batch.n_tokens);
439
- LLAMA_LOG_DEBUG("%s: token = %p\n", __func__, (void *) batch.token);
440
- LLAMA_LOG_DEBUG("%s: embd = %p\n", __func__, (void *) batch.embd);
441
- LLAMA_LOG_DEBUG("%s: pos = %p\n", __func__, (void *) batch.pos);
442
- LLAMA_LOG_DEBUG("%s: n_seq_id = %p\n", __func__, (void *) batch.n_seq_id);
443
- LLAMA_LOG_DEBUG("%s: seq_id = %p\n", __func__, (void *) batch.seq_id);
444
- LLAMA_LOG_DEBUG("%s: logits = %p\n", __func__, (void *) batch.logits);
445
- LLAMA_LOG_DEBUG("%s: n_outputs = %d\n", __func__, n_outputs);
446
 
447
- if (debug > 1) {
448
- int seq_id_max = 0;
449
- for (int32_t i = 0; i < batch.n_tokens; ++i) {
450
- for (int s = 0; s < batch.n_seq_id[i]; ++s) {
451
- for (int s = 0; s < batch.n_seq_id[i]; ++s) {
452
- seq_id_max = std::max(seq_id_max, batch.seq_id[i][s]);
453
- }
454
- }
455
  }
456
- ++seq_id_max;
457
 
458
- LLAMA_LOG_DEBUG("%s: token = [\n", __func__);
459
- for (int32_t i = 0; i < batch.n_tokens; ++i) {
460
- std::vector<int8_t> seq_id(seq_id_max);
461
 
462
- for (int s = 0; s < batch.n_seq_id[i]; ++s) {
463
- seq_id[batch.seq_id[i][s]] = 1;
464
- }
 
 
 
 
465
 
466
- std::stringstream ss;
467
- for (int s = 0; s < seq_id_max; ++s) {
468
- if (seq_id[s]) {
469
- ss << s%10;
470
- } else {
471
- ss << ".";
472
- }
473
- }
474
 
475
- LLAMA_LOG_DEBUG("%s: %4d: id = %6d (%16s), pos = %4d, n_seq_id = %2d, seq_id = [%s], output = %d\n",
476
- __func__, i, batch.token[i], vocab.token_to_piece(batch.token[i]).c_str(),
477
- batch.pos[i], batch.n_seq_id[i], ss.str().c_str(), batch.logits[i]);
478
  }
479
- LLAMA_LOG_DEBUG("%s: ]\n", __func__);
480
-
481
- LLAMA_LOG_DEBUG("%s: seq = [\n", __func__);
482
- for (int s0 = 0; s0 < (int) seq_pos.size(); ++s0) {
483
- if (seq_pos[s0].empty()) {
484
- continue;
485
- }
486
 
487
- std::stringstream ss;
488
- for (int s1 = 0; s1 < (int) seq_cpl[s0].size(); ++s1) {
489
- if (seq_cpl[s0][s1]) {
490
- ss << s1 << " ";
491
- }
492
  }
493
-
494
- LLAMA_LOG_DEBUG("%s: %4d: pos = [%4d, %4d], cpl = %s\n",
495
- __func__, s0, seq_pos_min(s0), seq_pos_max(s0), ss.str().empty() ? "-" : ss.str().c_str());
496
  }
497
- LLAMA_LOG_DEBUG("%s: ]\n", __func__);
 
 
498
  }
 
499
  }
500
 
501
  //
@@ -507,9 +244,22 @@ bool llama_batch_allocr::init(
507
  continue;
508
  }
509
 
510
- if (memory && seq_pos_min(s) != memory->seq_pos_max(s) + 1) {
511
- LLAMA_LOG_ERROR("%s: sequence %d does not start from the last position stored in the memory\n", __func__, s);
512
- return false;
 
 
 
 
 
 
 
 
 
 
 
 
 
513
  }
514
 
515
  if (seq_pos_max(s) - seq_pos_min(s) + 1 > (int) seq_pos[s].size()) {
@@ -532,17 +282,120 @@ bool llama_batch_allocr::init(
532
  }
533
  }
534
 
535
  return true;
536
  }
537
 
 
538
  const llama_batch & llama_batch_allocr::get_batch() const {
539
  return batch;
540
  }
541
 
 
 
 
 
542
  uint32_t llama_batch_allocr::get_n_outputs() const {
543
  return n_outputs;
544
  }
545
 
 
 
 
 
546
  llama_pos llama_batch_allocr::seq_pos_min(llama_seq_id seq_id) const {
547
  return seq_pos[seq_id].empty() ? -1 : *seq_pos[seq_id].begin();
548
  }
@@ -551,14 +404,188 @@ llama_pos llama_batch_allocr::seq_pos_max(llama_seq_id seq_id) const {
551
  return seq_pos[seq_id].empty() ? -1 : *seq_pos[seq_id].rbegin();
552
  }
553
 
554
  void llama_batch_allocr::clear() {
555
  n_outputs = 0;
556
 
557
  batch = {};
558
- pos.clear();
559
- n_seq_id.clear();
560
- seq_id.clear();
561
- output.clear();
 
 
562
 
563
  for (auto & cur : seq_pos) {
564
  cur.clear();
@@ -567,6 +594,177 @@ void llama_batch_allocr::clear() {
567
  for (auto & cur : seq_cpl) {
568
  std::fill(cur.begin(), cur.end(), false);
569
  }
570
  }
571
 
572
  //
@@ -577,25 +775,25 @@ struct llama_batch llama_batch_get_one(
577
  llama_token * tokens,
578
  int32_t n_tokens) {
579
  return {
580
- /*n_tokens =*/ n_tokens,
581
- /*tokens =*/ tokens,
582
- /*embd =*/ nullptr,
583
- /*pos =*/ nullptr,
584
- /*n_seq_id =*/ nullptr,
585
- /*seq_id =*/ nullptr,
586
- /*logits =*/ nullptr,
587
  };
588
  }
589
 
590
  struct llama_batch llama_batch_init(int32_t n_tokens_alloc, int32_t embd, int32_t n_seq_max) {
591
  llama_batch batch = {
592
- /*n_tokens =*/ 0,
593
- /*tokens =*/ nullptr,
594
- /*embd =*/ nullptr,
595
- /*pos =*/ nullptr,
596
- /*n_seq_id =*/ nullptr,
597
- /*seq_id =*/ nullptr,
598
- /*logits =*/ nullptr,
599
  };
600
 
601
  if (embd) {
 
1
  #include "llama-batch.h"
2
 
3
  #include "llama-impl.h"
 
4
  #include "llama-vocab.h"
5
  #include "llama-memory.h"
6
 
 
9
  #include <algorithm>
10
  #include <sstream>
11
 
12
+ llama_batch_allocr::llama_batch_allocr(uint32_t n_pos_per_embd) : n_pos_per_embd(n_pos_per_embd) {
13
  const char * LLAMA_BATCH_DEBUG = getenv("LLAMA_BATCH_DEBUG");
14
  debug = LLAMA_BATCH_DEBUG ? atoi(LLAMA_BATCH_DEBUG) : 0;
15
 
 
18
  for (auto & cur : seq_cpl) {
19
  cur.resize(LLAMA_MAX_SEQ);
20
  }
21
+
22
+ seq_idx.resize(LLAMA_MAX_SEQ, -1);
23
  }
24
 
25
  bool llama_batch_allocr::init(
26
  const llama_batch & batch_inp,
27
  const llama_vocab & vocab,
28
  const llama_memory_i * memory,
29
+ uint32_t n_embd,
30
+ bool output_all) {
31
  clear();
32
 
33
  batch = batch_inp;
34
 
35
+ this->vocab = &vocab;
36
+
37
  GGML_ASSERT(batch.n_tokens > 0);
38
 
39
  //
 
88
  llama_pos p0[LLAMA_MAX_SEQ];
89
  for (int32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
90
  if (!memory) {
91
+ // if no memory -> start from 0
92
  p0[s] = 0;
93
  } else {
94
  p0[s] = memory->seq_pos_max(s) + 1;
 
100
 
101
  pos[i] = p0[seq_id];
102
 
103
+ // update the starting position for all sequences that are assigned to the this token
104
  for (int32_t s = 0; s < batch.n_seq_id[i]; ++s) {
105
+ const llama_seq_id seq_id = batch.seq_id[i][s];
106
+
107
+ p0[seq_id] = pos[i] + 1;
108
  }
109
  }
110
 
 
112
  }
113
 
114
  if (!batch.logits) {
115
+ if (output_all) {
116
  // return the output for all tokens
117
  output.resize(batch.n_tokens, true);
118
  } else {
 
122
  }
123
 
124
  batch.logits = output.data();
125
+ } else if (output_all) {
126
  bool warn = false;
127
 
128
  for (int32_t i = 0; i < batch.n_tokens; ++i) {
 
143
  // compute stats
144
  //
145
 
146
+ this->n_embd = n_embd;
147
+
148
+ // count the outputs in this batch
149
  for (int32_t i = 0; i < batch.n_tokens; ++i) {
150
  n_outputs += batch.logits[i] != 0;
151
  }
 
153
  // determine coupled sequences
154
  // these are pairs of sequences that have at least one token in the input batch that is assigned to both of them
155
  for (int32_t i = 0; i < batch.n_tokens; ++i) {
156
+ const llama_seq_id s0 = batch.seq_id[i][0];
157
+
158
  for (int32_t s = 0; s < batch.n_seq_id[i]; ++s) {
159
+ const llama_seq_id s1 = batch.seq_id[i][s];
160
 
161
+ seq_pos[s1].insert(batch.pos[i]);
 
 
162
 
163
+ if (s > 0) {
164
  // mark that sequence s1 is coupled to s0
165
  seq_cpl[s1][s0] = true;
166
 
167
+ // note: tracking the other way around is not necessary for now
168
  //seq_cpl[s0][s1] = true;
169
  }
170
  }
171
  }
172
 
173
+ // precompute the sequence sets for each token and determine the unique sequence ids that participate in the batch
174
+ {
175
+ seq_set_t seq_set_unq;
 
 
 
 
 
 
 
176
 
177
+ for (int32_t i = 0; i < batch.n_tokens; ++i) {
178
+ seq_set_t cur;
179
+ for (int32_t s = 0; s < batch.n_seq_id[i]; ++s) {
180
+ const llama_seq_id seq_id = batch.seq_id[i][s];
181
+
182
+ cur .set(seq_id);
183
+ seq_set_unq.set(seq_id);
 
184
  }
 
185
 
186
+ seq_set.push_back(cur);
187
+ seq_set_map[cur].push_back(i);
188
+ }
189
 
190
+ for (int32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
191
+ if (seq_set_unq.test(s)) {
192
+ seq_idx[s] = seq_id_unq.size();
193
+ seq_id_unq.push_back(s);
194
+ }
195
+ }
196
+ }
197
 
198
+ if (debug > 0) {
199
+ LLAMA_LOG_DEBUG("%s: input batch info:\n", __func__);
 
 
 
 
 
 
200
 
201
+ llama_ubatch ubatch {
202
+ /*.equal_seqs =*/ false,
203
+ /*.n_tokens =*/ (uint32_t) batch.n_tokens,
204
+ /*.n_seq_tokens =*/ (uint32_t) 1,
205
+ /*.n_seqs =*/ (uint32_t) batch.n_tokens,
206
+ /*.n_seqs_unq =*/ (uint32_t) this->seq_id_unq.size(),
207
+ /*.token =*/ batch.token,
208
+ /*.embd =*/ batch.embd,
209
+ /*.pos =*/ batch.pos,
210
+ /*.n_seq_id =*/ batch.n_seq_id,
211
+ /*.seq_id =*/ batch.seq_id,
212
+ /*.seq_id_unq =*/ this->seq_id_unq.data(),
213
+ /*.seq_idx =*/ this->seq_idx.data(),
214
+ /*.output =*/ batch.logits,
215
+ };
216
+
217
+ ubatch_print(ubatch, debug);
218
+
219
+ LLAMA_LOG_DEBUG("%s: seq = [\n", __func__);
220
+ for (int s0 = 0; s0 < (int) seq_pos.size(); ++s0) {
221
+ if (seq_pos[s0].empty()) {
222
+ continue;
223
  }
 
 
 
 
 
 
 
224
 
225
+ std::stringstream ss;
226
+ for (int s1 = 0; s1 < (int) seq_cpl[s0].size(); ++s1) {
227
+ if (seq_cpl[s0][s1]) {
228
+ ss << s1 << " ";
 
229
  }
 
 
 
230
  }
231
+
232
+ LLAMA_LOG_DEBUG("%s: %4d: pos = [%4d, %4d], cpl = %s\n",
233
+ __func__, s0, seq_pos_min(s0), seq_pos_max(s0), ss.str().empty() ? "-" : ss.str().c_str());
234
  }
235
+ LLAMA_LOG_DEBUG("%s: ]\n", __func__);
236
  }
237
 
238
  //
 
244
  continue;
245
  }
246
 
247
+ if (memory) {
248
+ if (batch.token) {
249
+ if (seq_pos_min(s) != memory->seq_pos_max(s) + 1) {
250
+ LLAMA_LOG_ERROR("%s: sequence %d does not start from the last position stored in the memory\n", __func__, s);
251
+ return false;
252
+ }
253
+ } else {
254
+ assert(batch.embd);
255
+
256
+ // for embeddings (typically used as vision input), we allow them to have repeating positions
257
+ // ref: https://github.com/ggml-org/llama.cpp/issues/13694#issuecomment-2983871762
258
+ if (seq_pos_min(s) != memory->seq_pos_max(s) && seq_pos_min(s) != memory->seq_pos_max(s) + 1) {
259
+ LLAMA_LOG_ERROR("%s: sequence %d does not start from the last position stored in the memory\n", __func__, s);
260
+ return false;
261
+ }
262
+ }
263
  }
264
 
265
  if (seq_pos_max(s) - seq_pos_min(s) + 1 > (int) seq_pos[s].size()) {
 
282
  }
283
  }
284
 
285
+ // disallow partial sequence sub-sets:
286
+ //
287
+ // invalid: x
288
+ // i: 0 1 2 ...
289
+ // ---------------------------------------
290
+ // seq_id[i][0]: 0 0 1
291
+ // seq_id[i][1]: 1 1 2
292
+ // seq_id[i][2]: 2
293
+ //
294
+ // disallow decreasing sequence positions:
295
+ //
296
+ // invalid: x
297
+ // i: 0 1 2 3 4 5 6 ...
298
+ // ---------------------------------------
299
+ // pos[i]: 4 5 0 1 6 2 3
300
+ // seq_id[i][0]: 0 0 1 1 0 1 0
301
+ //
302
+ {
303
+ seq_set_t cur_seq_set[LLAMA_MAX_SEQ];
304
+ for (int32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
305
+ cur_seq_set[s].set();
306
+ }
307
+
308
+ llama_pos cur_seq_pos[LLAMA_MAX_SEQ];
309
+ for (int32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
310
+ cur_seq_pos[s] = -1;
311
+ }
312
+
313
+ for (int32_t i = 0; i < batch.n_tokens; ++i) {
314
+ const llama_pos pos = batch.pos[i];
315
+
316
+ for (int32_t s = 0; s < batch.n_seq_id[i]; ++s) {
317
+ const llama_seq_id seq_id = batch.seq_id[i][s];
318
+
319
+ cur_seq_set[seq_id] &= seq_set[i];
320
+
321
+ if (cur_seq_set[seq_id].none()) {
322
+ LLAMA_LOG_ERROR("%s: sequence %d belongs to incompatible sequence sets (not allowed)\n", __func__, seq_id);
323
+ return false;
324
+ }
325
+
326
+ if (pos < cur_seq_pos[seq_id]) {
327
+ LLAMA_LOG_ERROR("%s: sequence %d positions are decreasing (not allowed)\n", __func__, seq_id);
328
+ return false;
329
+ }
330
+ }
331
+ }
332
+ }
333
+
334
+ split_reset();
335
+
336
  return true;
337
  }
338
 
339
+ llama_ubatch llama_batch_allocr::ubatch_reserve(uint32_t n_seq_tokens, uint32_t n_seqs) {
340
+ const uint32_t n_tokens = n_seq_tokens*n_seqs;
341
+
342
+ clear();
343
+ split_reset();
344
+
345
+ ubatches.emplace_back();
346
+
347
+ auto & ubatch = ubatches.back();
348
+
349
+ ubatch.token .resize(n_tokens);
350
+ ubatch.embd .clear();
351
+ ubatch.pos .resize(n_tokens);
352
+ ubatch.n_seq_id .resize(n_tokens);
353
+ ubatch.seq_id .resize(n_tokens);
354
+ ubatch.seq_id_unq.resize(0);
355
+ ubatch.seq_idx .resize(LLAMA_MAX_SEQ, -1);
356
+ ubatch.output .resize(n_tokens);
357
+
358
+ for (uint32_t s = 0; s < n_seqs; ++s) {
359
+ ubatch.seq_idx[s] = s;
360
+ ubatch.seq_id_unq.push_back(s);
361
+ }
362
+
363
+ llama_ubatch res {
364
+ /*.equal_seqs =*/ true,
365
+ /*.n_tokens =*/ n_tokens,
366
+ /*.n_seq_tokens =*/ n_seq_tokens,
367
+ /*.n_seqs =*/ n_seqs,
368
+ /*.n_seqs_unq =*/ n_seqs,
369
+
370
+ /*.token =*/ ubatch.token.data(),
371
+ /*.embd =*/ nullptr,
372
+ /*.pos =*/ ubatch.pos.data(),
373
+ /*.n_seq_id =*/ ubatch.n_seq_id.data(),
374
+ /*.seq_id =*/ ubatch.seq_id.data(),
375
+ /*.seq_id_unq =*/ ubatch.seq_id_unq.data(),
376
+ /*.seq_idx =*/ ubatch.seq_idx.data(),
377
+ /*.output =*/ ubatch.output.data(),
378
+ };
379
+
380
+ return res;
381
+ }
382
+
383
  const llama_batch & llama_batch_allocr::get_batch() const {
384
  return batch;
385
  }
386
 
387
+ uint32_t llama_batch_allocr::get_n_tokens() const {
388
+ return batch.n_tokens;
389
+ }
390
+
391
  uint32_t llama_batch_allocr::get_n_outputs() const {
392
  return n_outputs;
393
  }
394
 
395
+ std::vector<int32_t> & llama_batch_allocr::get_out_ids() {
396
+ return out_ids;
397
+ }
398
+
399
  llama_pos llama_batch_allocr::seq_pos_min(llama_seq_id seq_id) const {
400
  return seq_pos[seq_id].empty() ? -1 : *seq_pos[seq_id].begin();
401
  }
 
404
  return seq_pos[seq_id].empty() ? -1 : *seq_pos[seq_id].rbegin();
405
  }
406
 
407
+ void llama_batch_allocr::split_reset() {
408
+ out_ids.clear();
409
+
410
+ used.clear();
411
+ used.resize(get_n_tokens(), false);
412
+
413
+ ubatches.clear();
414
+ }
415
+
416
+ llama_ubatch llama_batch_allocr::split_simple(uint32_t n_ubatch) {
417
+ // find the first unused token
418
+ uint32_t cur_idx = 0;
419
+ while (cur_idx < used.size() && used[cur_idx]) {
420
+ ++cur_idx;
421
+ }
422
+
423
+ // we are done
424
+ if (cur_idx >= used.size()) {
425
+ return {};
426
+ }
427
+
428
+ std::vector<int32_t> idxs;
429
+
430
+ while (true) {
431
+ idxs.push_back(cur_idx);
432
+
433
+ used[cur_idx] = true;
434
+
435
+ ++cur_idx;
436
+
437
+ if (cur_idx >= used.size()) {
438
+ break;
439
+ }
440
+
441
+ if (idxs.size() >= n_ubatch) {
442
+ break;
443
+ }
444
+ }
445
+
446
+ return ubatch_add(idxs, idxs.size(), false);
447
+ }
448
+
449
+ llama_ubatch llama_batch_allocr::split_equal(uint32_t n_ubatch) {
450
+ std::vector<seq_set_t> cur_seq_set;
451
+
452
+ // determine the non-overlapping sequence sets participating in this ubatch
453
+ for (int32_t i = 0; i < batch.n_tokens; ++i) {
454
+ if (used[i]) {
455
+ continue;
456
+ }
457
+
458
+ bool add = true;
459
+
460
+ for (uint32_t s = 0; s < cur_seq_set.size(); ++s) {
461
+ // no overlap with existing sequence sets:
462
+ if (!(cur_seq_set[s] & seq_set[i]).none()) {
463
+ add = false;
464
+ break;
465
+ }
466
+ }
467
+
468
+ if (add) {
469
+ cur_seq_set.push_back(seq_set[i]);
470
+
471
+ if (cur_seq_set.size() > n_ubatch) {
472
+ break;
473
+ }
474
+ }
475
+ }
476
+
477
+ const uint32_t n_seqs = cur_seq_set.size();
478
+
479
+ // we are done
480
+ if (n_seqs == 0) {
481
+ return {};
482
+ }
483
+
484
+ // the current batch index of each sequence set
485
+ std::vector<int32_t> cur_idx(n_seqs, 0);
486
+
487
+ for (uint32_t s = 0; s < n_seqs; ++s) {
488
+ while (used[seq_set_map[cur_seq_set[s]][cur_idx[s]]]) {
489
+ ++cur_idx[s];
490
+ }
491
+ }
492
+
493
+ // the list of batch indices for each sequence set
494
+ // at the end we will concat these to get the final ubatch
495
+ std::vector<idx_vec_t> idxs_per_seq(n_seqs);
496
+
497
+ while (true) {
498
+ // we can only add new n_seq_tokens tokens if all the sequence sets have at least one more unused token and
499
+ // if we haven't reached n_ubatch
500
+ bool can_expand = true;
501
+
502
+ for (uint32_t s = 0; s < n_seqs; ++s) {
503
+ if (cur_idx[s] >= (int32_t) seq_set_map[cur_seq_set[s]].size()) {
504
+ can_expand = false;
505
+ break;
506
+ }
507
+ }
508
+
509
+ if (!can_expand) {
510
+ break;
511
+ }
512
+
513
+ for (uint32_t s = 0; s < n_seqs; ++s) {
514
+ const int32_t idx = seq_set_map[cur_seq_set[s]][cur_idx[s]];
515
+
516
+ idxs_per_seq[s].push_back(idx);
517
+
518
+ used[idx] = true;
519
+
520
+ ++cur_idx[s];
521
+ }
522
+
523
+ if ((idxs_per_seq[0].size() + 1)*n_seqs > n_ubatch) {
524
+ break;
525
+ }
526
+ }
527
+
528
+ // concat the per-sequence-set lists
529
+ std::vector<int32_t> idxs;
530
+
531
+ for (uint32_t s = 0; s < n_seqs; ++s) {
532
+ idxs.insert(idxs.end(), idxs_per_seq[s].begin(), idxs_per_seq[s].end());
533
+ }
534
+
535
+ return ubatch_add(idxs, n_seqs, true);
536
+ }
537
+
538
+ llama_ubatch llama_batch_allocr::split_seq(uint32_t n_ubatch) {
539
+ // find the first unused token
540
+ uint32_t cur_idx = 0;
541
+ while (cur_idx < used.size() && used[cur_idx]) {
542
+ ++cur_idx;
543
+ }
544
+
545
+ // we are done
546
+ if (cur_idx >= used.size()) {
547
+ return {};
548
+ }
549
+
550
+ // this is the starting sequence set
551
+ // we allow adding tokens only if their sequence set is a subset of the current sequence set
552
+ auto cur_seq_set = seq_set[cur_idx];
553
+
554
+ std::vector<int32_t> idxs;
555
+
556
+ while (true) {
557
+ idxs.push_back(cur_idx);
558
+
559
+ used[cur_idx] = true;
560
+
561
+ if (idxs.size() >= n_ubatch) {
562
+ break;
563
+ }
564
+
565
+ do {
566
+ ++cur_idx;
567
+ } while (cur_idx < get_n_tokens() && (used[cur_idx] || ((cur_seq_set & seq_set[cur_idx]) != seq_set[cur_idx])));
568
+
569
+ if (cur_idx == get_n_tokens()) {
570
+ break;
571
+ }
572
+
573
+ cur_seq_set = seq_set[cur_idx];
574
+ }
575
+
576
+ return ubatch_add(idxs, 1, true);
577
+ }
578
+
579
  void llama_batch_allocr::clear() {
580
  n_outputs = 0;
581
 
582
  batch = {};
583
+
584
+ pos .clear();
585
+ n_seq_id .clear();
586
+ seq_id .clear();
587
+ seq_id_unq.clear();
588
+ output .clear();
589
 
590
  for (auto & cur : seq_pos) {
591
  cur.clear();
 
594
  for (auto & cur : seq_cpl) {
595
  std::fill(cur.begin(), cur.end(), false);
596
  }
597
+
598
+ seq_set.clear();
599
+
600
+ seq_set_map.clear();
601
+
602
+ std::fill(seq_idx.begin(), seq_idx.end(), -1);
603
+ }
604
+
605
+ llama_ubatch llama_batch_allocr::ubatch_add(const std::vector<int32_t> & idxs, uint32_t n_seqs, bool equal_seqs) {
606
+ const uint32_t n_tokens = idxs.size();
607
+
608
+ assert(n_tokens%n_seqs == 0);
609
+
610
+ ubatches.emplace_back();
611
+
612
+ auto & ubatch = ubatches.back();
613
+
614
+ const int32_t n_pos_cur = batch.embd ? n_pos_per_embd : 1;
615
+
616
+ const int64_t n_embd_all = batch.embd ? (int64_t) n_tokens*n_embd : 0;
617
+ const int64_t n_pos_all = (int64_t) n_tokens*n_pos_cur;
618
+
619
+ ubatch.token .resize(n_tokens);
620
+ ubatch.embd .resize(n_embd_all);
621
+ ubatch.pos .resize(n_pos_all);
622
+ ubatch.n_seq_id .resize(n_tokens);
623
+ ubatch.seq_id .resize(n_tokens);
624
+ ubatch.seq_id_unq.resize(0);
625
+ ubatch.seq_idx .resize(LLAMA_MAX_SEQ, -1);
626
+ ubatch.output .resize(n_tokens);
627
+
628
+ seq_set_t seq_set_unq;
629
+
630
+ for (size_t i = 0; i < idxs.size(); ++i) {
631
+ if (batch.token) {
632
+ ubatch.token[i] = batch.token[idxs[i]];
633
+ }
634
+
635
+ if (batch.embd) {
636
+ memcpy(ubatch.embd.data() + i*n_embd, batch.embd + (int64_t) idxs[i]*n_embd, n_embd*sizeof(float));
637
+ }
638
+
639
+ for (int j = 0; j < n_pos_cur; ++j) {
640
+ ubatch.pos[j*n_tokens + i] = batch.pos[j*batch.n_tokens + idxs[i]];
641
+ }
642
+
643
+ ubatch.n_seq_id[i] = batch.n_seq_id[idxs[i]];
644
+ ubatch.seq_id[i] = batch.seq_id[idxs[i]];
645
+ ubatch.output[i] = batch.logits[idxs[i]];
646
+
647
+ for (int s = 0; s < ubatch.n_seq_id[i]; ++s) {
648
+ seq_set_unq.set(ubatch.seq_id[i][s]);
649
+ }
650
+
651
+ if (ubatch.output[i]) {
652
+ out_ids.push_back(idxs[i]);
653
+ }
654
+ }
655
+
656
+ for (int32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
657
+ if (seq_set_unq.test(s)) {
658
+ ubatch.seq_idx[s] = ubatch.seq_id_unq.size();
659
+ ubatch.seq_id_unq.push_back(s);
660
+ }
661
+ }
662
+
663
+ llama_ubatch res {
664
+ /*.equal_seqs =*/ equal_seqs,
665
+ /*.n_tokens =*/ n_tokens,
666
+ /*.n_seq_tokens =*/ n_tokens/n_seqs,
667
+ /*.n_seqs =*/ n_seqs,
668
+ /*.n_seqs_unq =*/ (uint32_t) ubatch.seq_id_unq.size(),
669
+
670
+ /*.token =*/ batch.token ? ubatch.token.data() : nullptr,
671
+ /*.embd =*/ batch.embd ? ubatch.embd.data() : nullptr,
672
+ /*.pos =*/ ubatch.pos.data(),
673
+ /*.n_seq_id =*/ ubatch.n_seq_id.data(),
674
+ /*.seq_id =*/ ubatch.seq_id.data(),
675
+ /*.seq_id_unq =*/ ubatch.seq_id_unq.data(),
676
+ /*.seq_idx =*/ ubatch.seq_idx.data(),
677
+ /*.output =*/ ubatch.output.data(),
678
+ };
679
+
680
+ if (debug > 0) {
681
+ LLAMA_LOG_DEBUG("%s: added ubatch %d to split:\n", __func__, (int) ubatches.size() - 1);
682
+
683
+ ubatch_print(res, debug);
684
+ }
685
+
686
+ return res;
687
+ }
688
+
689
+ void llama_batch_allocr::ubatch_print(const llama_ubatch & ubatch, int debug) {
690
+ if (debug > 0) {
691
+ LLAMA_LOG_DEBUG("%s: equal_seqs = %d\n", __func__, ubatch.equal_seqs);
692
+ LLAMA_LOG_DEBUG("%s: n_tokens = %d\n", __func__, ubatch.n_tokens);
693
+ LLAMA_LOG_DEBUG("%s: n_seq_tokens = %d\n", __func__, ubatch.n_seq_tokens);
694
+ LLAMA_LOG_DEBUG("%s: n_seqs = %d\n", __func__, ubatch.n_seqs);
695
+ LLAMA_LOG_DEBUG("%s: n_seqs_unq = %d\n", __func__, ubatch.n_seqs_unq);
696
+
697
+ std::stringstream ss_seq_id_unq;
698
+ std::stringstream ss_seq_idx;
699
+
700
+ ss_seq_id_unq << "[ ";
701
+ ss_seq_idx << "[";
702
+
703
+ for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) {
704
+ ss_seq_id_unq << ubatch.seq_id_unq[s] << " ";
705
+ }
706
+
707
+ for (uint32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
708
+ if (ubatch.seq_idx[s] >= 0) {
709
+ ss_seq_idx << ubatch.seq_idx[s]%10;
710
+ } else {
711
+ ss_seq_idx << ".";
712
+ }
713
+ }
714
+
715
+ ss_seq_id_unq << "]";
716
+ ss_seq_idx << "]";
717
+
718
+ LLAMA_LOG_DEBUG("%s: token = %p\n", __func__, (void *) ubatch.token);
719
+ LLAMA_LOG_DEBUG("%s: embd = %p\n", __func__, (void *) ubatch.embd);
720
+ LLAMA_LOG_DEBUG("%s: pos = %p\n", __func__, (void *) ubatch.pos);
721
+ LLAMA_LOG_DEBUG("%s: n_seq_id = %p\n", __func__, (void *) ubatch.n_seq_id);
722
+ LLAMA_LOG_DEBUG("%s: seq_id = %p\n", __func__, (void *) ubatch.seq_id);
723
+ LLAMA_LOG_DEBUG("%s: seq_id_unq = %s\n", __func__, ss_seq_id_unq.str().c_str());
724
+ LLAMA_LOG_DEBUG("%s: seq_idx = %s\n", __func__, ss_seq_idx.str().c_str());
725
+ LLAMA_LOG_DEBUG("%s: output = %p\n", __func__, (void *) ubatch.output);
726
+ LLAMA_LOG_DEBUG("%s: n_outputs = %d\n", __func__, n_outputs);
727
+
728
+ if (debug > 1) {
729
+ int seq_id_max = 0;
730
+ for (uint32_t i = 0; i < ubatch.n_tokens; ++i) {
731
+ for (int s = 0; s < ubatch.n_seq_id[i]; ++s) {
732
+ for (int s = 0; s < ubatch.n_seq_id[i]; ++s) {
733
+ seq_id_max = std::max(seq_id_max, ubatch.seq_id[i][s]);
734
+ }
735
+ }
736
+ }
737
+ ++seq_id_max;
738
+
739
+ LLAMA_LOG_DEBUG("%s: token = [\n", __func__);
740
+ for (uint32_t i = 0; i < ubatch.n_tokens; ++i) {
741
+ std::vector<int8_t> seq_id(seq_id_max);
742
+
743
+ for (int s = 0; s < ubatch.n_seq_id[i]; ++s) {
744
+ seq_id[ubatch.seq_id[i][s]] = 1;
745
+ }
746
+
747
+ std::stringstream ss;
748
+ for (int s = 0; s < seq_id_max; ++s) {
749
+ if (seq_id[s]) {
750
+ ss << s%10;
751
+ } else {
752
+ ss << ".";
753
+ }
754
+ }
755
+
756
+ if (ubatch.token) {
757
+ LLAMA_LOG_DEBUG("%s: %4d: id = %6d (%16s), pos = %4d, n_seq_id = %2d, seq_id = [%s], output = %d\n",
758
+ __func__, i, ubatch.token[i], vocab->token_to_piece(ubatch.token[i]).c_str(),
759
+ ubatch.pos[i], ubatch.n_seq_id[i], ss.str().c_str(), ubatch.output[i]);
760
+ } else {
761
+ LLAMA_LOG_DEBUG("%s: %4d: [embd], pos = %4d, n_seq_id = %2d, seq_id = [%s], output = %d\n",
762
+ __func__, i, ubatch.pos[i], ubatch.n_seq_id[i], ss.str().c_str(), ubatch.output[i]);
763
+ }
764
+ }
765
+ LLAMA_LOG_DEBUG("%s: ]\n", __func__);
766
+ }
767
+ }
768
  }
769
 
770
  //
 
775
  llama_token * tokens,
776
  int32_t n_tokens) {
777
  return {
778
+ /*n_tokens =*/ n_tokens,
779
+ /*tokens =*/ tokens,
780
+ /*embd =*/ nullptr,
781
+ /*pos =*/ nullptr,
782
+ /*n_seq_id =*/ nullptr,
783
+ /*seq_id =*/ nullptr,
784
+ /*logits =*/ nullptr,
785
  };
786
  }
787
 
788
  struct llama_batch llama_batch_init(int32_t n_tokens_alloc, int32_t embd, int32_t n_seq_max) {
789
  llama_batch batch = {
790
+ /*n_tokens =*/ 0,
791
+ /*tokens =*/ nullptr,
792
+ /*embd =*/ nullptr,
793
+ /*pos =*/ nullptr,
794
+ /*n_seq_id =*/ nullptr,
795
+ /*seq_id =*/ nullptr,
796
+ /*logits =*/ nullptr,
797
  };
798
 
799
  if (embd) {
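
The rewritten llama_batch_allocr above now owns both batch sanitation and ubatch splitting: callers reset the split state and then repeatedly request ubatches until an empty one comes back. The following standalone sketch mimics that consume-until-empty pattern with a toy splitter; it is an illustration under assumptions, not the real llama_batch_allocr interface.

#include <cstdio>
#include <vector>

// toy stand-in for llama_batch_allocr's splitting state (assumption):
// hand out at most n_ubatch unused token indices per call
struct toy_splitter {
    std::vector<bool> used;

    explicit toy_splitter(size_t n_tokens) : used(n_tokens, false) {}

    std::vector<int> split_simple(size_t n_ubatch) {
        std::vector<int> idxs;
        for (size_t i = 0; i < used.size() && idxs.size() < n_ubatch; ++i) {
            if (!used[i]) {
                used[i] = true;
                idxs.push_back((int) i);
            }
        }
        return idxs; // empty result == whole batch consumed
    }
};

int main() {
    toy_splitter splitter(10); // a batch of 10 tokens

    // same consume-until-empty loop shape as the new split_* API
    while (true) {
        const auto ubatch = splitter.split_simple(4);
        if (ubatch.empty()) {
            break;
        }
        std::printf("ubatch with %zu tokens, first idx %d\n", ubatch.size(), ubatch[0]);
    }
    return 0;
}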
examples/talk-llama/llama-batch.h CHANGED
@@ -2,86 +2,44 @@
 
 #include "llama.h"
 
+#include "llama-cparams.h"
+
 #include <array>
 #include <vector>
 #include <set>
+#include <bitset>
+#include <unordered_map>
 
-// very similar to llama_batch,
-// but has more metadata about sequences
+// keep this struct lightweight
+// it points to data in `llama_batch_allocr`
 struct llama_ubatch {
     bool equal_seqs;
     // TODO: whole_seqs for embeddings?
 
     uint32_t n_tokens; // total tokens (n_seq_tokens * n_seqs)
-    uint32_t n_seq_tokens; // tokens per sequence
-    uint32_t n_seqs;
-
-    llama_token * token;    // [n_tokens]
-    float * embd;           // [n_embd, n_tokens]
-    llama_pos * pos;        // [n_tokens]
-    int32_t * n_seq_id;     // [n_seqs]
-    llama_seq_id ** seq_id; // [n_seqs]
-    int8_t * output;        // [n_tokens]
-};
-
-struct llama_sbatch_seq {
-    int32_t n_seq_id;
-
-    llama_seq_id * seq_id;
-
-    size_t offset;
-    size_t length;
-};
-
-// sequence-length-aware batch splitting
-struct llama_sbatch {
-    // tokens left in this batch
-    size_t n_tokens;
-
-    size_t n_embd;
-
-    // sorted indices into the batch
-    std::vector<int64_t> ids;
-    // batch indices of the output
-    std::vector<int64_t> out_ids;
-    std::vector<llama_sbatch_seq> seq;
-
-    const llama_batch * batch = nullptr;
-
-    // buffers for the ubatches
-    // TODO: very hacky, this needs a complete rework
-    struct ubatch_data {
-        std::vector<llama_token>    token;
-        std::vector<float>          embd;
-        std::vector<llama_pos>      pos;
-        std::vector<int32_t>        n_seq_id;
-        std::vector<llama_seq_id *> seq_id;
-        std::vector<int8_t>         output;
-    };
-
-    std::vector<ubatch_data> udatas;
-
-    llama_ubatch reserve_ubatch(size_t n_ubatch, bool has_embd = false);
-
-    void add_seq_to_ubatch(llama_ubatch & ubatch, llama_sbatch_seq & seq, size_t length);
-
-    // simple split, unknown number of sequences of unequal lengths
-    llama_ubatch split_simple(size_t n_ubatch);
-
-    // make batches of equal-length sequences
-    llama_ubatch split_equal(size_t n_ubatch);
-
-    // sequence-wise split
-    llama_ubatch split_seq(size_t n_ubatch);
-
-    llama_sbatch() = default;
-    llama_sbatch(const llama_batch & batch, size_t n_embd, bool simple_split = false);
+    uint32_t n_seq_tokens; // tokens per sequence set
+    uint32_t n_seqs;       // sequence sets in the ubatch
+    uint32_t n_seqs_unq;   // unique sequence ids in the ubatch
+
+    // seq_id_unq: unique sequence ids in the ubatch
+    // seq_idx:    indices of the unique sequence ids in the ubatch in [0, n_seqs_unq)
+    //             used for extracting sequence pooled embeddings
+
+    //                          // size               | idx | val
+    llama_token  *  token;      // [n_tokens]         | i   | id, token
+    float        *  embd;       // [n_embd, n_tokens] | i   | embd
+    llama_pos    *  pos;        // [n_tokens]         | i   | pos
+    int32_t      *  n_seq_id;   // [n_tokens]         | i   | -
+    llama_seq_id ** seq_id;     // [n_tokens]         | s   | s0, s1, seq_id
+    llama_seq_id *  seq_id_unq; // [n_seqs_unq]       | s   | seq_id
+    int32_t      *  seq_idx;    // [LLAMA_MAX_SEQ]    | -   | seq_idx
+    int8_t       *  output;     // [n_tokens]         | i   | -
 };
 
-// a helper for sanitizing and fulfilling a batch
+// a helper for sanitizing, fulfilling and splitting a batch
 class llama_batch_allocr {
 public:
-    llama_batch_allocr();
+    llama_batch_allocr(uint32_t n_pos_per_embd);
 
     // sanitize and auto-gen missing data in the input batch
     // memory is optional. if provided will be used to check for sequence continuity and to determine the positions
@@ -89,20 +47,57 @@ public:
             const llama_batch & batch_inp,
             const llama_vocab & vocab,
             const llama_memory_i * memory,
-            bool embd_all);
+            uint32_t n_embd,
+            bool output_all);
 
     const llama_batch & get_batch() const;
 
+    uint32_t get_n_tokens() const;
     uint32_t get_n_outputs() const;
 
+    // the array of output indices in the order they were encountered during the ubatch splitting
+    std::vector<int32_t> & get_out_ids();
+
+    // min/max positions of each sequence in the current ubatch
    llama_pos seq_pos_min(llama_seq_id seq_id) const;
     llama_pos seq_pos_max(llama_seq_id seq_id) const;
 
+    // call once before splitting the batch to reset the internal state
+    void split_reset();
+
+    // simple split, unknown number of sequence sets of unequal lengths
+    llama_ubatch split_simple(uint32_t n_ubatch);
+
+    // make ubatches of equal-length sequences sets
+    llama_ubatch split_equal(uint32_t n_ubatch);
+
+    // sequence-set-wise split - each ubatch contains a single sequence-set
+    llama_ubatch split_seq(uint32_t n_ubatch);
+
+    // a helper method for creating a well-defined ubatch of tokens
+    // TODO: support embeddings if needed in the future
+    llama_ubatch ubatch_reserve(uint32_t n_seq_tokens, uint32_t n_seqs);
+
 private:
     void clear();
 
+    // create the next ubatch based on the provided batch indices (idxs) and the number of sequence sets (n_seqs)
+    // return llama_ubatch.n_tokens == 0 if the entire batch was consumed
+    llama_ubatch ubatch_add(const std::vector<int32_t> & idxs, uint32_t n_seqs, bool equal_seqs);
+
+    // for debugging, start with LLAMA_BATCH_DEBUG=2
+    void ubatch_print(const llama_ubatch & ubatch, int debug);
+
     llama_batch batch;
 
+    // only for debugging purposes
+    const llama_vocab * vocab;
+
+    // TODO: this is more of a temporary solution until we have a better way to handle multiple positions per token/embd
+    // ref: https://github.com/ggml-org/llama.cpp/issues/13694#issuecomment-2983871762
+    const uint32_t n_pos_per_embd;
+
+    uint32_t n_embd;
     uint32_t n_outputs;
 
     std::array<llama_seq_id, 1> seq_id_0 = { 0 }; // default sequence id
@@ -110,10 +105,43 @@ private:
     std::vector<llama_pos> pos;
     std::vector<int32_t> n_seq_id;
     std::vector<llama_seq_id *> seq_id;
+    std::vector<llama_seq_id> seq_id_unq;
+    std::vector<int32_t> seq_idx;
     std::vector<int8_t> output;
 
-    std::vector<std::set<llama_pos>> seq_pos; // seq_pos[s]: the set of positions in sequence s
-    std::vector<std::vector<bool>> seq_cpl;   // seq_cpl[s0][s1]: if sequence s0 is coupled to sequence s1
+    using pos_set_t = std::set<llama_pos>;
+    using seq_cpl_t = std::vector<bool>;
+
+    std::vector<pos_set_t> seq_pos; // seq_pos[s]: the set of positions in sequence s
+    std::vector<seq_cpl_t> seq_cpl; // seq_cpl[s0][s1]: if sequence s0 is coupled to sequence s1
+
+    using idx_vec_t = std::vector<int32_t>;
+    using seq_set_t = std::bitset<LLAMA_MAX_SEQ>;
+
+    std::vector<seq_set_t> seq_set; // seq_set[i]: the sequence set of token i
+
+    std::unordered_map<seq_set_t, idx_vec_t> seq_set_map; // the indices at which the sequence set appears
+
+    // batch indices of the output
+    std::vector<int32_t> out_ids;
+
+    // used[i] indicates if token i has already been used in a previous ubatch
+    std::vector<bool> used;
+
+    // llama_ubatch points to this data:
+    struct ubatch {
+        std::vector<llama_token>    token;
+        std::vector<float>          embd;
+        std::vector<llama_pos>      pos;
+        std::vector<int32_t>        n_seq_id;
+        std::vector<llama_seq_id *> seq_id;
+        std::vector<llama_seq_id>   seq_id_unq;
+        std::vector<int32_t>        seq_idx;
+        std::vector<int8_t>         output;
+    };
+
+    // current splitting state:
+    std::vector<ubatch> ubatches;
 
     int debug;
 };
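
The header documents the new seq_id_unq/seq_idx pair: seq_idx maps a sequence id to a dense index in [0, n_seqs_unq), which is what pooled-embedding extraction indexes into. A small standalone illustration of that mapping, using made-up sequence ids, is sketched below.

#include <cstdio>
#include <vector>

int main() {
    const int max_seq = 8; // stand-in for LLAMA_MAX_SEQ (assumption)

    // sequence ids that happen to appear in a ubatch (example values)
    const std::vector<int> seq_ids_in_ubatch = { 5, 2, 5, 7 };

    std::vector<int> seq_idx(max_seq, -1); // seq id -> dense index, -1 if absent
    std::vector<int> seq_id_unq;           // dense index -> seq id

    for (int s : seq_ids_in_ubatch) {
        if (seq_idx[s] < 0) {
            seq_idx[s] = (int) seq_id_unq.size();
            seq_id_unq.push_back(s);
        }
    }

    // a pooled-embedding buffer would hold one row per unique sequence;
    // seq_idx[s] selects the row for sequence s
    for (int s : seq_id_unq) {
        std::printf("seq_id %d -> row %d\n", s, seq_idx[s]);
    }
    return 0;
}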
examples/talk-llama/llama-chat.cpp CHANGED
@@ -333,7 +333,7 @@ int32_t llm_chat_apply_template(
         std::string role(message->role);
         if (role == "system") {
             // there is no system message for gemma, but we will merge it with user prompt, so nothing is broken
-            system_prompt = trim(message->content);
+            system_prompt += trim(message->content);
             continue;
         }
         // in gemma, "assistant" is "model"
@@ -355,7 +355,7 @@ int32_t llm_chat_apply_template(
         std::string role(message->role);
         if (role == "system") {
             // there is no system message support, we will merge it with user prompt
-            system_prompt = message->content;
+            system_prompt += message->content;
             continue;
         } else if (role == "user") {
             ss << "Human: ";
examples/talk-llama/llama-context.cpp CHANGED
@@ -20,7 +20,7 @@ llama_context::llama_context(
         const llama_model & model,
         llama_context_params params) :
     model(model),
-    batch_allocr(std::make_unique<llama_batch_allocr>()) {
     LLAMA_LOG_INFO("%s: constructing llama_context\n", __func__);
 
     t_start_us = model.t_start_us;
@@ -722,22 +722,26 @@ llm_graph_result_ptr llama_context::process_ubatch(const llama_ubatch & ubatch,
 }
 
 int llama_context::encode(const llama_batch & batch_inp) {
     if (batch_inp.n_tokens == 0) {
         LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__);
         return -1;
     }
 
     // note: during encode, we always pass the full sequence starting from pos = 0
-    if (!batch_allocr->init(batch_inp, model.vocab, nullptr, true)) {
         LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
         return -1;
     }
 
-    const llama_batch & batch = batch_allocr->get_batch();
 
-    const uint32_t n_tokens = batch.n_tokens;
-
-    GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
 
     // micro-batching is not possible for non-causal encoding, so we process the batch in a single shot
     GGML_ASSERT(cparams.n_ubatch >= n_tokens && "encoder requires n_ubatch >= n_tokens");
@@ -751,14 +755,6 @@ int llama_context::encode(const llama_batch & batch_inp) {
 
     n_queued_tokens += n_tokens;
 
-    const auto & hparams = model.hparams;
-
-    const int64_t n_embd = hparams.n_embd;
-
-    llama_sbatch sbatch = llama_sbatch(batch, n_embd, /* simple_split */ true);
-
-    const llama_ubatch ubatch = sbatch.split_simple(n_tokens);
-
     // reserve output buffer
     if (output_reserve(n_tokens) < n_tokens) {
         LLAMA_LOG_ERROR("%s: could not reserve space for batch with %u outputs\n", __func__, n_tokens);
@@ -817,34 +813,28 @@ int llama_context::encode(const llama_batch & batch_inp) {
             {
                 // extract sequence embeddings
                 auto & embd_seq_out = embd_seq;
-                embd_seq_out.clear();
 
-                GGML_ASSERT(!ubatch.equal_seqs); // TODO: handle equal splits
 
-                // TODO: fix indexing [UBATCH_IDX]
-                for (uint32_t i = 0; i < n_tokens; i++) {
-                    const llama_seq_id seq_id = ubatch.seq_id[i][0];
-                    if (embd_seq_out.find(seq_id) != embd_seq_out.end()) {
-                        continue;
-                    }
                     embd_seq_out[seq_id].resize(n_embd);
-                    ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_embd*seq_id)*sizeof(float), n_embd*sizeof(float));
                 }
             } break;
         case LLAMA_POOLING_TYPE_RANK:
             {
                 // extract the rerank score - n_cls_out floats per sequence
                 auto & embd_seq_out = embd_seq;
                 const uint32_t n_cls_out = hparams.n_cls_out;
 
-                // TODO: fix indexing [UBATCH_IDX]
-                for (uint32_t s = 0; s < ubatch.n_seqs; ++s) {
-                    const llama_seq_id seq_id = ubatch.seq_id[s][0];
-                    if (embd_seq_out.find(seq_id) != embd_seq_out.end()) {
-                        continue;
-                    }
                     embd_seq_out[seq_id].resize(n_cls_out);
-                    ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_cls_out*seq_id)*sizeof(float), n_cls_out*sizeof(float));
                 }
             } break;
         case LLAMA_POOLING_TYPE_UNSPECIFIED:
@@ -869,12 +859,16 @@ int llama_context::encode(const llama_batch & batch_inp) {
     cross.v_embd.resize(cross.n_embd*cross.n_enc);
     memcpy(cross.v_embd.data(), embd, ggml_nbytes(t_embd));
 
     // remember the sequence ids used during the encoding - needed for cross attention later
     cross.seq_ids_enc.resize(n_tokens);
     for (uint32_t i = 0; i < n_tokens; i++) {
         cross.seq_ids_enc[i].clear();
         for (int s = 0; s < batch.n_seq_id[i]; s++) {
-            llama_seq_id seq_id = batch.seq_id[i][s];
             cross.seq_ids_enc[i].insert(seq_id);
         }
     }
@@ -884,6 +878,8 @@ int llama_context::encode(const llama_batch & batch_inp) {
 }
 
 int llama_context::decode(const llama_batch & batch_inp) {
     if (!memory) {
         LLAMA_LOG_DEBUG("%s: cannot decode batches with this context (calling encode() instead)\n", __func__);
         return encode(batch_inp);
@@ -894,29 +890,24 @@ int llama_context::decode(const llama_batch & batch_inp) {
         return -1;
     }
 
-    // when computing embeddings, all tokens are output
-    const bool embd_all = cparams.embeddings;
-
-    if (!batch_allocr->init(batch_inp, model.vocab, memory.get(), embd_all)) {
-        LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
-        return -1;
-    }
-
-    const llama_batch & batch = batch_allocr->get_batch();
-
     const auto & vocab   = model.vocab;
     const auto & hparams = model.hparams;
 
     const int32_t n_vocab = vocab.n_tokens();
     const int64_t n_embd  = hparams.n_embd;
 
-    const uint32_t n_tokens_all = batch.n_tokens;
 
-    GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
 
-    const uint32_t n_outputs_all = batch_allocr->get_n_outputs();
 
-    if (embd_all) {
         // require that all tokens are output
         if (n_outputs_all != n_tokens_all) {
             LLAMA_LOG_ERROR("%s: pooled embedding requires that all tokens are output (n_outputs_all = %d, n_tokens_all = %d)\n",
@@ -945,7 +936,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
     llama_memory_state_ptr mstate;
 
     while (true) {
-        mstate = memory->init_batch(batch, cparams.n_ubatch, embd_all);
        if (!mstate) {
            return -2;
        }
@@ -966,19 +957,19 @@ int llama_context::decode(const llama_batch & batch_inp) {
                 did_optimize = true;
 
                 if (kv_self_update(true)) {
-                    LLAMA_LOG_DEBUG("%s: retrying batch size %d after cache optimization\n", __func__, batch.n_tokens);
 
                     continue;
                 }
             }
 
-            LLAMA_LOG_WARN("%s: failed to find a memory slot for batch of size %d\n", __func__, batch.n_tokens);
 
             return 1;
         }
         case LLAMA_MEMORY_STATUS_FAILED_COMPUTE:
             {
-                LLAMA_LOG_ERROR("%s: compute failed while preparing batch of size %d\n", __func__, batch.n_tokens);
 
                 return -2;
             }
@@ -1005,7 +996,6 @@ int llama_context::decode(const llama_batch & batch_inp) {
     if (n_outputs_all == n_tokens_all) {
         n_outputs_new = ubatch.n_tokens;
     } else {
-        GGML_ASSERT(ubatch.output);
         for (uint32_t i = 0; i < ubatch.n_tokens; i++) {
             n_outputs_new += (int32_t) (ubatch.output[i] != 0);
         }
@@ -1105,27 +1095,27 @@ int llama_context::decode(const llama_batch & batch_inp) {
             // extract sequence embeddings (cleared before processing each batch)
             auto & embd_seq_out = embd_seq;
 
-            for (uint32_t s = 0; s < ubatch.n_seqs; ++s) {
-                const llama_seq_id seq_id = ubatch.seq_id[s][0];
-                if (embd_seq_out.find(seq_id) != embd_seq_out.end()) {
-                    continue;
-                }
                 embd_seq_out[seq_id].resize(n_embd);
-                ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_embd*seq_id)*sizeof(float), n_embd*sizeof(float));
             }
         } break;
         case LLAMA_POOLING_TYPE_RANK:
         {
-            // extract the rerank score - a single float per sequence
            auto & embd_seq_out = embd_seq;
 
-            for (uint32_t s = 0; s < ubatch.n_seqs; ++s) {
-                const llama_seq_id seq_id = ubatch.seq_id[s][0];
-                if (embd_seq_out.find(seq_id) != embd_seq_out.end()) {
-                    continue;
1126
- }
1127
- embd_seq_out[seq_id].resize(1);
1128
- ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (seq_id)*sizeof(float), sizeof(float));
 
1129
  }
1130
  } break;
1131
  case LLAMA_POOLING_TYPE_UNSPECIFIED:
@@ -1145,7 +1135,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
1145
  if (n_outputs > 0) {
1146
  bool sorted_output = true;
1147
 
1148
- auto & out_ids = mstate->out_ids();
1149
 
1150
  GGML_ASSERT(out_ids.size() == (size_t) n_outputs);
1151
 
@@ -1318,8 +1308,8 @@ ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, u
1318
 
1319
  this->n_outputs = n_outputs;
1320
 
1321
- llama_token token = model.vocab.token_bos(); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph
1322
- llama_ubatch ubatch = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
1323
 
1324
  auto * gf = graph_init();
1325
  auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT, mstate);
@@ -2039,7 +2029,12 @@ void llama_context::opt_epoch_iter(
2039
  batch.logits [pos_batch] = true;
2040
  }
2041
 
2042
- const auto n_tokens_all = batch.n_tokens;
 
 
 
 
 
2043
 
2044
  n_queued_tokens += n_tokens_all;
2045
 
@@ -2047,7 +2042,7 @@ void llama_context::opt_epoch_iter(
2047
 
2048
  uint32_t n_outputs_all = n_tokens_all;
2049
 
2050
- auto mstate = memory->init_batch(batch, cparams.n_ubatch, true);
2051
  if (!mstate || mstate->get_status() != LLAMA_MEMORY_STATUS_SUCCESS) {
2052
  LLAMA_LOG_ERROR("%s: could not initialize batch\n", __func__);
2053
  break;
 
20
  const llama_model & model,
21
  llama_context_params params) :
22
  model(model),
23
+ balloc(std::make_unique<llama_batch_allocr>(model.hparams.n_pos_per_embd())) {
24
  LLAMA_LOG_INFO("%s: constructing llama_context\n", __func__);
25
 
26
  t_start_us = model.t_start_us;
 
722
  }
723
 
724
  int llama_context::encode(const llama_batch & batch_inp) {
725
+ GGML_ASSERT((!batch_inp.token && batch_inp.embd) || (batch_inp.token && !batch_inp.embd)); // NOLINT
726
+
727
  if (batch_inp.n_tokens == 0) {
728
  LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__);
729
  return -1;
730
  }
731
 
732
+ const auto & hparams = model.hparams;
733
+
734
+ const int64_t n_embd = hparams.n_embd;
735
+
736
  // note: during encode, we always pass the full sequence starting from pos = 0
737
+ if (!balloc->init(batch_inp, model.vocab, nullptr, n_embd, true)) {
738
  LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
739
  return -1;
740
  }
741
 
742
+ const uint32_t n_tokens = balloc->get_n_tokens();
743
 
744
+ const llama_ubatch ubatch = balloc->split_simple(n_tokens);
 
 
745
 
746
  // micro-batching is not possible for non-causal encoding, so we process the batch in a single shot
747
  GGML_ASSERT(cparams.n_ubatch >= n_tokens && "encoder requires n_ubatch >= n_tokens");
 
755
 
756
  n_queued_tokens += n_tokens;
757
 
758
  // reserve output buffer
759
  if (output_reserve(n_tokens) < n_tokens) {
760
  LLAMA_LOG_ERROR("%s: could not reserve space for batch with %u outputs\n", __func__, n_tokens);
 
813
  {
814
  // extract sequence embeddings
815
  auto & embd_seq_out = embd_seq;
 
816
 
817
+ for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) {
818
+ const llama_seq_id seq_id = ubatch.seq_id_unq[s];
819
+ const int32_t seq_idx = ubatch.seq_idx[seq_id];
820
 
 
 
 
 
 
 
821
  embd_seq_out[seq_id].resize(n_embd);
822
+ ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_embd*seq_idx)*sizeof(float), n_embd*sizeof(float));
823
  }
824
  } break;
825
  case LLAMA_POOLING_TYPE_RANK:
826
  {
827
  // extract the rerank score - n_cls_out floats per sequence
828
  auto & embd_seq_out = embd_seq;
829
+
830
  const uint32_t n_cls_out = hparams.n_cls_out;
831
 
832
+ for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) {
833
+ const llama_seq_id seq_id = ubatch.seq_id_unq[s];
834
+ const int32_t seq_idx = ubatch.seq_idx[seq_id];
835
+
 
 
836
  embd_seq_out[seq_id].resize(n_cls_out);
837
+ ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_cls_out*seq_idx)*sizeof(float), n_cls_out*sizeof(float));
838
  }
839
  } break;
840
  case LLAMA_POOLING_TYPE_UNSPECIFIED:
 
859
  cross.v_embd.resize(cross.n_embd*cross.n_enc);
860
  memcpy(cross.v_embd.data(), embd, ggml_nbytes(t_embd));
861
 
862
+ const auto & batch = balloc->get_batch();
863
+
864
  // remember the sequence ids used during the encoding - needed for cross attention later
865
  cross.seq_ids_enc.resize(n_tokens);
866
  for (uint32_t i = 0; i < n_tokens; i++) {
867
  cross.seq_ids_enc[i].clear();
868
+
869
  for (int s = 0; s < batch.n_seq_id[i]; s++) {
870
+ const llama_seq_id seq_id = batch.seq_id[i][s];
871
+
872
  cross.seq_ids_enc[i].insert(seq_id);
873
  }
874
  }
 
878
  }
879
 
880
  int llama_context::decode(const llama_batch & batch_inp) {
881
+ GGML_ASSERT((!batch_inp.token && batch_inp.embd) || (batch_inp.token && !batch_inp.embd)); // NOLINT
882
+
883
  if (!memory) {
884
  LLAMA_LOG_DEBUG("%s: cannot decode batches with this context (calling encode() instead)\n", __func__);
885
  return encode(batch_inp);
 
890
  return -1;
891
  }
892
 
893
  const auto & vocab = model.vocab;
894
  const auto & hparams = model.hparams;
895
 
896
  const int32_t n_vocab = vocab.n_tokens();
897
  const int64_t n_embd = hparams.n_embd;
898
 
899
+ // when computing embeddings, all tokens are output
900
+ const bool output_all = cparams.embeddings;
901
 
902
+ if (!balloc->init(batch_inp, vocab, memory.get(), n_embd, output_all)) {
903
+ LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
904
+ return -1;
905
+ }
906
 
907
+ const uint32_t n_tokens_all = balloc->get_n_tokens();
908
+ const uint32_t n_outputs_all = balloc->get_n_outputs();
909
 
910
+ if (output_all) {
911
  // require that all tokens are output
912
  if (n_outputs_all != n_tokens_all) {
913
  LLAMA_LOG_ERROR("%s: pooled embedding requires that all tokens are output (n_outputs_all = %d, n_tokens_all = %d)\n",
 
936
  llama_memory_state_ptr mstate;
937
 
938
  while (true) {
939
+ mstate = memory->init_batch(*balloc, cparams.n_ubatch, output_all);
940
  if (!mstate) {
941
  return -2;
942
  }
 
957
  did_optimize = true;
958
 
959
  if (kv_self_update(true)) {
960
+ LLAMA_LOG_DEBUG("%s: retrying batch size %d after cache optimization\n", __func__, balloc->get_n_tokens());
961
 
962
  continue;
963
  }
964
  }
965
 
966
+ LLAMA_LOG_WARN("%s: failed to find a memory slot for batch of size %d\n", __func__, balloc->get_n_tokens());
967
 
968
  return 1;
969
  }
970
  case LLAMA_MEMORY_STATUS_FAILED_COMPUTE:
971
  {
972
+ LLAMA_LOG_ERROR("%s: compute failed while preparing batch of size %d\n", __func__, balloc->get_n_tokens());
973
 
974
  return -2;
975
  }
 
996
  if (n_outputs_all == n_tokens_all) {
997
  n_outputs_new = ubatch.n_tokens;
998
  } else {
 
999
  for (uint32_t i = 0; i < ubatch.n_tokens; i++) {
1000
  n_outputs_new += (int32_t) (ubatch.output[i] != 0);
1001
  }
 
1095
  // extract sequence embeddings (cleared before processing each batch)
1096
  auto & embd_seq_out = embd_seq;
1097
 
1098
+ for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) {
1099
+ const llama_seq_id seq_id = ubatch.seq_id_unq[s];
1100
+ const int32_t seq_idx = ubatch.seq_idx[seq_id];
1101
+
 
1102
  embd_seq_out[seq_id].resize(n_embd);
1103
+ ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_embd*seq_idx)*sizeof(float), n_embd*sizeof(float));
1104
  }
1105
  } break;
1106
  case LLAMA_POOLING_TYPE_RANK:
1107
  {
1108
+ // extract the rerank score - n_cls_out floats per sequence
1109
  auto & embd_seq_out = embd_seq;
1110
 
1111
+ const uint32_t n_cls_out = hparams.n_cls_out;
1112
+
1113
+ for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) {
1114
+ const llama_seq_id seq_id = ubatch.seq_id_unq[s];
1115
+ const int32_t seq_idx = ubatch.seq_idx[seq_id];
1116
+
1117
+ embd_seq_out[seq_id].resize(n_cls_out);
1118
+ ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_cls_out*seq_idx)*sizeof(float), n_cls_out*sizeof(float));
1119
  }
1120
  } break;
1121
  case LLAMA_POOLING_TYPE_UNSPECIFIED:
 
1135
  if (n_outputs > 0) {
1136
  bool sorted_output = true;
1137
 
1138
+ auto & out_ids = balloc->get_out_ids();
1139
 
1140
  GGML_ASSERT(out_ids.size() == (size_t) n_outputs);
1141
 
 
1308
 
1309
  this->n_outputs = n_outputs;
1310
 
1311
+ llama_batch_allocr balloc(model.hparams.n_pos_per_embd());
1312
+ llama_ubatch ubatch = balloc.ubatch_reserve(n_tokens/n_seqs, n_seqs);
1313
 
1314
  auto * gf = graph_init();
1315
  auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT, mstate);
 
2029
  batch.logits [pos_batch] = true;
2030
  }
2031
 
2032
+ if (!balloc->init(batch, model.vocab, nullptr, model.hparams.n_embd, true)) {
2033
+ LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
2034
+ return;
2035
+ }
2036
+
2037
+ const uint32_t n_tokens_all = balloc->get_n_tokens();
2038
 
2039
  n_queued_tokens += n_tokens_all;
2040
 
 
2042
 
2043
  uint32_t n_outputs_all = n_tokens_all;
2044
 
2045
+ auto mstate = memory->init_batch(*balloc, cparams.n_ubatch, true);
2046
  if (!mstate || mstate->get_status() != LLAMA_MEMORY_STATUS_SUCCESS) {
2047
  LLAMA_LOG_ERROR("%s: could not initialize batch\n", __func__);
2048
  break;
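
The llama-context.cpp hunks above replace the ad-hoc llama_sbatch with the context-owned batch allocator and switch pooled-embedding extraction from indexing by raw seq_id to the ubatch's compacted sequence index (seq_id_unq / seq_idx), so the read offset into the pooled tensor becomes n_embd*seq_idx. A minimal standalone sketch of that indexing scheme, using plain std:: containers rather than the llama.cpp structures (the function name and container layout here are illustrative only):

    #include <cstdint>
    #include <map>
    #include <vector>

    using llama_seq_id = int32_t;

    // pooled holds one n_embd-sized row per *unique* sequence, in the order the
    // sequences first appear in the ubatch - this is what seq_idx indexes into
    std::map<llama_seq_id, std::vector<float>> extract_seq_embd(
            const std::vector<float>        & pooled,     // [n_seqs_unq * n_embd]
            const std::vector<llama_seq_id> & seq_id_unq, // unique sequence ids, ubatch order
            int64_t                           n_embd) {
        std::map<llama_seq_id, std::vector<float>> embd_seq_out;

        for (size_t seq_idx = 0; seq_idx < seq_id_unq.size(); ++seq_idx) {
            const llama_seq_id seq_id = seq_id_unq[seq_idx];

            // read row seq_idx (not seq_id) - rows are laid out per unique sequence
            const float * src = pooled.data() + n_embd*seq_idx;

            embd_seq_out[seq_id].assign(src, src + n_embd);
        }

        return embd_seq_out;
    }
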
examples/talk-llama/llama-context.h CHANGED
@@ -247,7 +247,7 @@ private:
247
  std::map<llama_seq_id, std::vector<float>> embd_seq;
248
 
249
  // reuse the batch_allocr to avoid unnecessary memory allocations
250
- std::unique_ptr<llama_batch_allocr> batch_allocr;
251
 
252
  uint32_t n_outputs = 0; // number of actually-used outputs in the current ubatch or last logical batch
253
 
 
247
  std::map<llama_seq_id, std::vector<float>> embd_seq;
248
 
249
  // reuse the batch_allocr to avoid unnecessary memory allocations
250
+ std::unique_ptr<llama_batch_allocr> balloc;
251
 
252
  uint32_t n_outputs = 0; // number of actually-used outputs in the current ubatch or last logical batch
253
 
examples/talk-llama/llama-graph.cpp CHANGED
@@ -6,7 +6,8 @@
6
 
7
  #include "llama-kv-cache-unified.h"
8
  #include "llama-kv-cache-unified-iswa.h"
9
- #include "llama-kv-cache-recurrent.h"
 
10
 
11
  #include <cassert>
12
  #include <cmath>
@@ -91,36 +92,28 @@ void llm_graph_input_pos_bucket_kv::set_input(const llama_ubatch * ubatch) {
91
  }
92
 
93
  void llm_graph_input_out_ids::set_input(const llama_ubatch * ubatch) {
94
- if (hparams.causal_attn || cparams.pooling_type == LLAMA_POOLING_TYPE_NONE) {
95
- //GGML_ASSERT(out_ids && "every model that can must skip unused outputs");
96
 
97
- if (!out_ids) {
98
- LLAMA_LOG_WARN("%s: 'out_ids' is not created\n", __func__);
99
- } else {
100
- const int64_t n_tokens = ubatch->n_tokens;
101
 
102
- GGML_ASSERT(ggml_backend_buffer_is_host(out_ids->buffer));
103
- int32_t * data = (int32_t *) out_ids->data;
104
 
105
- if (n_outputs == n_tokens) {
106
- for (int i = 0; i < n_tokens; ++i) {
107
- data[i] = i;
108
- }
109
- } else if (ubatch->output) {
110
- int32_t n_outputs = 0;
111
- for (int i = 0; i < n_tokens; ++i) {
112
- if (ubatch->output[i]) {
113
- data[n_outputs++] = i;
114
- }
115
- }
116
- // the graph needs to have been passed the correct number of outputs
117
- GGML_ASSERT(n_outputs == n_outputs);
118
- } else if (n_outputs == 1) {
119
- // only keep last output
120
- data[0] = n_tokens - 1;
121
- } else {
122
- GGML_ASSERT(n_outputs == 0);
123
- }
124
  }
125
  }
126
  }
@@ -129,127 +122,114 @@ void llm_graph_input_mean::set_input(const llama_ubatch * ubatch) {
129
  if (cparams.embeddings && cparams.pooling_type == LLAMA_POOLING_TYPE_MEAN) {
130
  const int64_t n_tokens = ubatch->n_tokens;
131
  const int64_t n_seq_tokens = ubatch->n_seq_tokens;
132
- const int64_t n_seqs = ubatch->n_seqs;
133
 
134
  GGML_ASSERT(mean);
135
  GGML_ASSERT(ggml_backend_buffer_is_host(mean->buffer));
136
 
137
  float * data = (float *) mean->data;
138
- memset(mean->data, 0, n_tokens * n_tokens * ggml_element_size(mean));
139
 
140
- std::vector<uint64_t> sum(n_tokens, 0);
 
 
 
 
141
 
142
- // TODO: fix indexing [UBATCH_IDX]
143
- for (int s = 0; s < n_seqs; ++s) {
144
- const llama_seq_id seq_id = ubatch->seq_id[s][0];
145
-
146
- // TODO: adapt limits to n_seqs when ubatch->equal_seqs is true
147
- GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == MEAN");
148
-
149
- sum[seq_id] += ubatch->n_seq_tokens;
150
  }
151
 
152
- std::vector<float> div(n_tokens, 0.0f);
153
- for (int i = 0; i < n_tokens; ++i) {
154
- const uint64_t s = sum[i];
155
- if (s > 0) {
156
- div[i] = 1.0f/float(s);
157
  }
158
  }
159
 
160
- // TODO: fix indexing [UBATCH_IDX]
161
- for (int s = 0; s < n_seqs; ++s) {
162
- const llama_seq_id seq_id = ubatch->seq_id[s][0];
 
163
 
164
- for (int i = 0; i < n_seq_tokens; ++i) {
165
- data[seq_id*n_tokens + s*n_seq_tokens + i] = div[seq_id];
 
166
  }
167
  }
168
  }
169
  }
170
 
171
  void llm_graph_input_cls::set_input(const llama_ubatch * ubatch) {
172
- if (cparams.embeddings && (
173
- cparams.pooling_type == LLAMA_POOLING_TYPE_CLS ||
174
- cparams.pooling_type == LLAMA_POOLING_TYPE_RANK)) {
175
- const int64_t n_tokens = ubatch->n_tokens;
176
- const int64_t n_seq_tokens = ubatch->n_seq_tokens;
177
- const int64_t n_seqs = ubatch->n_seqs;
178
 
 
 
 
 
179
  GGML_ASSERT(cls);
180
  GGML_ASSERT(ggml_backend_buffer_is_host(cls->buffer));
181
 
182
  uint32_t * data = (uint32_t *) cls->data;
183
- memset(cls->data, 0, n_tokens * ggml_element_size(cls));
184
-
185
- // TODO: fix indexing [UBATCH_IDX]
186
- for (int s = 0; s < n_seqs; ++s) {
187
- const llama_seq_id seq_id = ubatch->seq_id[s][0];
188
-
189
- // TODO: adapt limits to n_seqs when ubatch->equal_seqs is true
190
- GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == CLS or RANK");
191
 
192
- for (int i = 0; i < n_seq_tokens; ++i) {
193
- const llama_pos pos = ubatch->pos[s*n_seq_tokens + i];
 
 
194
 
195
- if (pos == 0) {
196
- data[seq_id] = s*n_seq_tokens + i;
197
- }
198
  }
199
  }
200
  }
201
 
202
  if (cparams.embeddings && cparams.pooling_type == LLAMA_POOLING_TYPE_LAST) {
203
- const int64_t n_tokens = ubatch->n_tokens;
204
- const int64_t n_seq_tokens = ubatch->n_seq_tokens;
205
- const int64_t n_seqs = ubatch->n_seqs;
206
-
207
  GGML_ASSERT(cls);
208
  GGML_ASSERT(ggml_backend_buffer_is_host(cls->buffer));
209
 
210
  uint32_t * data = (uint32_t *) cls->data;
211
- memset(cls->data, 0, n_tokens * ggml_element_size(cls));
212
-
213
- std::vector<int> last_pos(n_tokens, -1);
214
- std::vector<int> last_row(n_tokens, -1);
215
 
216
- // TODO: fix indexing [UBATCH_IDX]
217
- for (int s = 0; s < n_seqs; ++s) {
218
- const llama_seq_id seq_id = ubatch->seq_id[s][0];
219
 
220
- // TODO: adapt limits to n_seqs when ubatch->equal_seqs is true
221
- GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == LAST");
222
 
223
- for (int i = 0; i < n_seq_tokens; ++i) {
224
- const llama_pos pos = ubatch->pos[s*n_seq_tokens + i];
 
225
 
226
- if (pos >= last_pos[seq_id]) {
227
- last_pos[seq_id] = pos;
228
- last_row[seq_id] = s*n_seq_tokens + i;
229
  }
230
  }
231
  }
232
 
233
- for (int i = 0; i < n_tokens; ++i) {
234
- if (last_row[i] >= 0) {
235
- data[i] = last_row[i];
236
  }
237
  }
238
  }
239
  }
240
 
241
- void llm_graph_input_s_copy::set_input(const llama_ubatch * ubatch) {
242
  GGML_UNUSED(ubatch);
243
 
244
- const int64_t n_kv = kv_state->get_n_kv();
245
 
246
  if (s_copy) {
247
  GGML_ASSERT(ggml_backend_buffer_is_host(s_copy->buffer));
248
  int32_t * data = (int32_t *) s_copy->data;
249
 
250
  // assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
251
- for (uint32_t i = 0; i < n_kv; ++i) {
252
- data[i] = kv_state->s_copy(i);
253
  }
254
  }
255
  }
@@ -265,89 +245,36 @@ void llm_graph_input_cross_embd::set_input(const llama_ubatch * ubatch) {
265
  }
266
 
267
  void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
268
- if (kq_mask) {
269
- if (cparams.causal_attn) {
270
- const int64_t n_kv = ubatch->n_tokens;
271
- const int64_t n_tokens = ubatch->n_tokens;
272
- const int64_t n_seq_tokens = ubatch->n_seq_tokens;
273
- const int64_t n_seqs = ubatch->n_seqs;
274
-
275
- GGML_ASSERT(ggml_backend_buffer_is_host(kq_mask->buffer));
276
- float * data = (float *) kq_mask->data;
277
-
278
- for (int h = 0; h < 1; ++h) {
279
- for (int s1 = 0; s1 < n_seqs; ++s1) {
280
- const llama_seq_id seq_id = ubatch->seq_id[s1][0];
281
-
282
- for (int j = 0; j < n_seq_tokens; ++j) {
283
- const int32_t tj = s1*n_seq_tokens + j;
284
-
285
- for (int s0 = 0; s0 < n_seqs; ++s0) {
286
- for (int i = 0; i < n_seq_tokens; ++i) {
287
- const int32_t ti = s0*n_seq_tokens + i;
288
- float f = -INFINITY;
289
-
290
- // TODO: fix indexing [UBATCH_IDX]
291
- for (int s = 0; s < ubatch->n_seq_id[s0]; ++s) {
292
- if (ubatch->seq_id[s0][s] == seq_id && ubatch->pos[ti] <= ubatch->pos[tj]) {
293
- if (hparams.use_alibi) {
294
- f = -std::abs(ubatch->pos[ti] - ubatch->pos[tj]);
295
- } else {
296
- f = 0.0f;
297
- }
298
- break;
299
- }
300
- }
301
-
302
- data[h*(n_kv*n_tokens) + tj*n_kv + ti] = f;
303
- }
304
- }
305
- }
306
- }
307
- }
308
- } else {
309
- const int64_t n_tokens = ubatch->n_tokens;
310
- const int64_t n_seq_tokens = ubatch->n_seq_tokens;
311
- const int64_t n_seqs = ubatch->n_seqs;
312
- const int64_t n_stride = ubatch->n_tokens;
313
-
314
- GGML_ASSERT(ggml_backend_buffer_is_host(kq_mask->buffer));
315
-
316
- float * data = (float *) kq_mask->data;
317
-
318
- for (int h = 0; h < 1; ++h) {
319
- for (int s1 = 0; s1 < n_seqs; ++s1) {
320
- const llama_seq_id seq_id = ubatch->seq_id[s1][0];
321
-
322
- for (int j = 0; j < n_seq_tokens; ++j) {
323
- const int32_t tj = s1*n_seq_tokens + j;
324
-
325
- for (int s0 = 0; s0 < n_seqs; ++s0) {
326
- for (int i = 0; i < n_seq_tokens; ++i) {
327
- const int32_t ti = s0*n_seq_tokens + i;
328
- float f = -INFINITY;
329
-
330
- // TODO: fix indexing [UBATCH_IDX]
331
- for (int s = 0; s < ubatch->n_seq_id[s0]; ++s) {
332
- if (ubatch->seq_id[s0][s] == seq_id) {
333
- if (hparams.use_alibi) {
334
- f = -std::abs(ubatch->pos[ti] - ubatch->pos[tj]);
335
- } else {
336
- f = 0.0f;
337
- }
338
- break;
339
- }
340
- }
341
-
342
- data[h*(n_tokens*n_tokens) + tj*n_stride + ti] = f;
343
- }
344
- }
345
 
346
- for (int i = n_tokens; i < n_stride; ++i) {
347
- data[h*(n_tokens*n_tokens) + tj*n_stride + i] = -INFINITY;
 
 
 
 
348
  }
 
349
  }
350
  }
 
 
351
  }
352
  }
353
  }
@@ -370,39 +297,59 @@ void llm_graph_input_attn_kv_unified_iswa::set_input(const llama_ubatch * ubatch
370
  }
371
 
372
  void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
373
- if (cross_kq_mask) {
374
- const int64_t n_enc = cross_kq_mask->ne[0];
375
- const int64_t n_tokens = ubatch->n_tokens;
376
 
377
- GGML_ASSERT(ggml_backend_buffer_is_host(cross_kq_mask->buffer));
378
- GGML_ASSERT(!ubatch->equal_seqs); // TODO: use ubatch->n_seqs instead of failing
379
 
380
- float * data = (float *) cross_kq_mask->data;
 
381
 
382
- for (int h = 0; h < 1; ++h) {
383
- for (int j = 0; j < n_tokens; ++j) {
384
- for (int i = 0; i < n_enc; ++i) {
385
- float f = -INFINITY;
386
- // TODO: fix indexing [UBATCH_IDX]
387
- for (int s = 0; s < ubatch->n_seq_id[j]; ++s) {
388
- const llama_seq_id seq_id = ubatch->seq_id[j][s];
389
- if (cross->seq_ids_enc[i].find(seq_id) != cross->seq_ids_enc[i].end()) {
390
- f = 0.0f;
391
- }
 
 
392
  }
393
- data[h*(n_enc*n_tokens) + j*n_enc + i] = f;
394
  }
 
 
395
  }
 
396
 
397
- for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
398
- for (int j = 0; j < n_enc; ++j) {
399
- data[h*(n_enc*n_tokens) + i*n_enc + j] = -INFINITY;
400
- }
401
  }
402
  }
403
  }
404
  }
405
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
406
  //
407
  // llm_graph_context
408
  //
@@ -448,10 +395,6 @@ llm_graph_context::llm_graph_context(const llm_graph_params & params) :
448
  res (std::make_unique<llm_graph_result>()) {
449
  }
450
 
451
- int64_t llm_graph_context::n_pos_per_embd() const {
452
- return hparams.rope_type == LLAMA_ROPE_TYPE_MROPE ? 4 : 1;
453
- }
454
-
455
  void llm_graph_context::cb(ggml_tensor * cur, const char * name, int il) const {
456
  if (cb_func) {
457
  cb_func(ubatch, cur, name, il);
@@ -896,11 +839,11 @@ ggml_tensor * llm_graph_context::build_inp_embd(ggml_tensor * tok_embd) const {
896
  }
897
 
898
  ggml_tensor * llm_graph_context::build_inp_pos() const {
899
- auto inp = std::make_unique<llm_graph_input_pos>(n_pos_per_embd());
900
 
901
  auto & cur = inp->pos;
902
 
903
- cur = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens*n_pos_per_embd());
904
  ggml_set_input(cur);
905
 
906
  res->add_input(std::move(inp));
@@ -923,6 +866,14 @@ ggml_tensor * llm_graph_context::build_inp_attn_scale() const {
923
  }
924
 
925
  ggml_tensor * llm_graph_context::build_inp_out_ids() const {
 
 
 
 
 
 
 
 
926
  auto inp = std::make_unique<llm_graph_input_out_ids>(hparams, cparams, n_outputs);
927
 
928
  auto & cur = inp->out_ids;
@@ -940,7 +891,7 @@ ggml_tensor * llm_graph_context::build_inp_mean() const {
940
 
941
  auto & cur = inp->mean;
942
 
943
- cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, n_tokens);
944
  ggml_set_input(cur);
945
 
946
  res->add_input(std::move(inp));
@@ -953,24 +904,7 @@ ggml_tensor * llm_graph_context::build_inp_cls() const {
953
 
954
  auto & cur = inp->cls;
955
 
956
- cur = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
957
- ggml_set_input(cur);
958
-
959
- res->add_input(std::move(inp));
960
-
961
- return cur;
962
- }
963
-
964
- ggml_tensor * llm_graph_context::build_inp_s_copy() const {
965
- const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
966
-
967
- auto inp = std::make_unique<llm_graph_input_s_copy>(kv_state);
968
-
969
- const auto n_kv = kv_state->get_n_kv();
970
-
971
- auto & cur = inp->s_copy;
972
-
973
- cur = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_kv);
974
  ggml_set_input(cur);
975
 
976
  res->add_input(std::move(inp));
@@ -1047,6 +981,33 @@ ggml_tensor * llm_graph_context::build_pos_bias(ggml_tensor * pos_bucket, ggml_t
1047
  return pos_bias;
1048
  }
1049
 
1050
  ggml_tensor * llm_graph_context::build_attn_mha(
1051
  ggml_cgraph * gf,
1052
  ggml_tensor * q,
@@ -1291,36 +1252,6 @@ ggml_tensor * llm_graph_context::build_attn(
1291
  return cur;
1292
  }
1293
 
1294
- llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unified_iswa() const {
1295
- const auto * kv_state = static_cast<const llama_kv_cache_unified_iswa_state *>(mstate);
1296
-
1297
- auto inp = std::make_unique<llm_graph_input_attn_kv_unified_iswa>(hparams, cparams, kv_state);
1298
-
1299
- {
1300
- const auto n_kv = kv_state->get_base()->get_n_kv();
1301
-
1302
- inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
1303
- //cb(inp->self_kq_mask, "KQ_mask", -1);
1304
- ggml_set_input(inp->self_kq_mask);
1305
-
1306
- inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
1307
- }
1308
-
1309
- {
1310
- GGML_ASSERT(hparams.swa_type != LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified for non-SWA");
1311
-
1312
- const auto n_kv = kv_state->get_swa()->get_n_kv();
1313
-
1314
- inp->self_kq_mask_swa = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
1315
- //cb(inp->self_kq_mask_swa, "KQ_mask_swa", -1);
1316
- ggml_set_input(inp->self_kq_mask_swa);
1317
-
1318
- inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;
1319
- }
1320
-
1321
- return (llm_graph_input_attn_kv_unified_iswa *) res->add_input(std::move(inp));
1322
- }
1323
-
1324
  ggml_tensor * llm_graph_context::build_attn(
1325
  llm_graph_input_attn_kv_unified_iswa * inp,
1326
  ggml_cgraph * gf,
@@ -1430,20 +1361,99 @@ ggml_tensor * llm_graph_context::build_attn(
1430
  return cur;
1431
  }
1432
 
1433
- ggml_tensor * llm_graph_context::build_recurrent_state(
1434
- ggml_cgraph * gf,
1435
- ggml_tensor * s,
1436
- ggml_tensor * state_copy,
1437
- int32_t state_size,
1438
- int32_t n_seqs,
1439
- bool avoid_copies) const {
1440
- const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
1441
-
1442
- const auto n_kv = kv_state->get_n_kv();
1443
- const auto kv_head = kv_state->get_head();
1444
- const auto rs_zero = kv_state->get_rs_z();
1445
 
1446
- ggml_tensor * states = ggml_reshape_2d(ctx0, s, state_size, kv_state->get_size());
1447
 
1448
  // Clear a single state which will then be copied to the other cleared states.
1449
  // Note that this is a no-op when the view is zero-sized.
@@ -1474,22 +1484,59 @@ ggml_tensor * llm_graph_context::build_recurrent_state(
1474
  return output_states;
1475
  }
1476
 
1477
  ggml_tensor * llm_graph_context::build_rwkv_token_shift_load(
1478
- ggml_cgraph * gf,
1479
- ggml_tensor * state_copy,
1480
- const llama_ubatch & ubatch,
1481
  int il) const {
1482
- const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
1483
 
1484
  const auto token_shift_count = hparams.token_shift_count;
1485
 
1486
  const int64_t n_seqs = ubatch.n_seqs;
1487
 
1488
- ggml_tensor * token_shift_all = kv_state->get_k_l(il);
1489
 
1490
- ggml_tensor * token_shift = build_recurrent_state(
1491
- gf, token_shift_all, state_copy,
1492
- hparams.n_embd_k_s(), n_seqs);
1493
 
1494
  token_shift = ggml_reshape_3d(ctx0, token_shift, hparams.n_embd, token_shift_count, n_seqs);
1495
 
@@ -1500,7 +1547,7 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_store(
1500
  ggml_tensor * token_shift,
1501
  const llama_ubatch & ubatch,
1502
  int il) const {
1503
- const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
1504
 
1505
  const auto token_shift_count = hparams.token_shift_count;
1506
  const auto n_embd = hparams.n_embd;
@@ -1512,7 +1559,7 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_store(
1512
  return ggml_cpy(
1513
  ctx0,
1514
  ggml_view_1d(ctx0, token_shift, n_embd * n_seqs * token_shift_count, 0),
1515
- ggml_view_1d(ctx0, kv_state->get_k_l(il), hparams.n_embd_k_s()*n_seqs, hparams.n_embd_k_s()*kv_head*ggml_element_size(kv_state->get_k_l(il)))
1516
  );
1517
  }
1518
 
 
6
 
7
  #include "llama-kv-cache-unified.h"
8
  #include "llama-kv-cache-unified-iswa.h"
9
+ #include "llama-memory-hybrid.h"
10
+ #include "llama-memory-recurrent.h"
11
 
12
  #include <cassert>
13
  #include <cmath>
 
92
  }
93
 
94
  void llm_graph_input_out_ids::set_input(const llama_ubatch * ubatch) {
95
+ GGML_ASSERT(out_ids);
 
96
 
97
+ const int64_t n_tokens = ubatch->n_tokens;
 
 
 
98
 
99
+ GGML_ASSERT(ggml_backend_buffer_is_host(out_ids->buffer));
100
+ int32_t * data = (int32_t *) out_ids->data;
101
 
102
+ if (n_outputs == n_tokens) {
103
+ for (int i = 0; i < n_tokens; ++i) {
104
+ data[i] = i;
105
+ }
106
+
107
+ return;
108
+ }
109
+
110
+ GGML_ASSERT(ubatch->output);
111
+
112
+ int n_outputs = 0;
113
+
114
+ for (int i = 0; i < n_tokens; ++i) {
115
+ if (ubatch->output[i]) {
116
+ data[n_outputs++] = i;
 
 
 
 
117
  }
118
  }
119
  }
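
The rewritten llm_graph_input_out_ids::set_input above reduces to a plain index selection: an identity mapping when every token is an output, otherwise the positions of the tokens whose output flag is set. A standalone sketch of that selection with a std::vector in place of the I32 ggml tensor:

    #include <cstdint>
    #include <vector>

    // returns the token indices whose logits/embeddings are kept for this ubatch
    std::vector<int32_t> select_out_ids(const std::vector<int8_t> & output, int32_t n_outputs) {
        const int32_t n_tokens = (int32_t) output.size();

        std::vector<int32_t> out_ids;
        out_ids.reserve(n_outputs);

        if (n_outputs == n_tokens) {
            // all tokens are output - identity mapping
            for (int32_t i = 0; i < n_tokens; ++i) {
                out_ids.push_back(i);
            }
            return out_ids;
        }

        // otherwise keep only the tokens explicitly flagged as outputs
        for (int32_t i = 0; i < n_tokens; ++i) {
            if (output[i]) {
                out_ids.push_back(i);
            }
        }

        return out_ids;
    }
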
 
122
  if (cparams.embeddings && cparams.pooling_type == LLAMA_POOLING_TYPE_MEAN) {
123
  const int64_t n_tokens = ubatch->n_tokens;
124
  const int64_t n_seq_tokens = ubatch->n_seq_tokens;
125
+ const int64_t n_seqs_unq = ubatch->n_seqs_unq;
126
 
127
  GGML_ASSERT(mean);
128
  GGML_ASSERT(ggml_backend_buffer_is_host(mean->buffer));
129
 
130
  float * data = (float *) mean->data;
131
+ memset(mean->data, 0, n_tokens*n_seqs_unq*ggml_element_size(mean));
132
 
133
+ std::vector<uint64_t> sums(n_seqs_unq, 0);
134
+ for (int i = 0; i < n_tokens; i += n_seq_tokens) {
135
+ for (int s = 0; s < ubatch->n_seq_id[i]; ++s) {
136
+ const llama_seq_id seq_id = ubatch->seq_id[i][s];
137
+ const int32_t seq_idx = ubatch->seq_idx[seq_id];
138
 
139
+ sums[seq_idx] += ubatch->n_seq_tokens;
140
+ }
 
 
 
 
 
 
141
  }
142
 
143
+ std::vector<float> div(n_seqs_unq, 0.0f);
144
+ for (int s = 0; s < n_seqs_unq; ++s) {
145
+ const uint64_t sum = sums[s];
146
+ if (sum > 0) {
147
+ div[s] = 1.0f/float(sum);
148
  }
149
  }
150
 
151
+ for (int i = 0; i < n_tokens; i += n_seq_tokens) {
152
+ for (int s = 0; s < ubatch->n_seq_id[i]; ++s) {
153
+ const llama_seq_id seq_id = ubatch->seq_id[i][s];
154
+ const int32_t seq_idx = ubatch->seq_idx[seq_id];
155
 
156
+ for (int j = 0; j < n_seq_tokens; ++j) {
157
+ data[seq_idx*n_tokens + i + j] = div[seq_idx];
158
+ }
159
  }
160
  }
161
  }
162
  }
163
 
164
  void llm_graph_input_cls::set_input(const llama_ubatch * ubatch) {
165
+ const int64_t n_tokens = ubatch->n_tokens;
166
+ const int64_t n_seq_tokens = ubatch->n_seq_tokens;
167
+ const int64_t n_seqs_unq = ubatch->n_seqs_unq;
 
 
 
168
 
169
+ if (cparams.embeddings && (
170
+ cparams.pooling_type == LLAMA_POOLING_TYPE_CLS ||
171
+ cparams.pooling_type == LLAMA_POOLING_TYPE_RANK
172
+ )) {
173
  GGML_ASSERT(cls);
174
  GGML_ASSERT(ggml_backend_buffer_is_host(cls->buffer));
175
 
176
  uint32_t * data = (uint32_t *) cls->data;
177
+ memset(cls->data, 0, n_seqs_unq*ggml_element_size(cls));
 
 
 
 
 
 
 
178
 
179
+ for (int i = 0; i < n_tokens; i += n_seq_tokens) {
180
+ for (int s = 0; s < ubatch->n_seq_id[i]; ++s) {
181
+ const llama_seq_id seq_id = ubatch->seq_id[i][s];
182
+ const int32_t seq_idx = ubatch->seq_idx[seq_id];
183
 
184
+ data[seq_idx] = i;
 
 
185
  }
186
  }
187
  }
188
 
189
  if (cparams.embeddings && cparams.pooling_type == LLAMA_POOLING_TYPE_LAST) {
 
 
 
 
190
  GGML_ASSERT(cls);
191
  GGML_ASSERT(ggml_backend_buffer_is_host(cls->buffer));
192
 
193
  uint32_t * data = (uint32_t *) cls->data;
194
+ memset(cls->data, 0, n_seqs_unq*ggml_element_size(cls));
 
 
 
195
 
196
+ std::vector<int> last_pos(n_seqs_unq, -1);
197
+ std::vector<int> last_row(n_seqs_unq, -1);
 
198
 
199
+ for (int i = 0; i < n_tokens; ++i) {
200
+ const llama_pos pos = ubatch->pos[i];
201
 
202
+ for (int s = 0; s < ubatch->n_seq_id[i]; ++s) {
203
+ const llama_seq_id seq_id = ubatch->seq_id[i][s];
204
+ const int32_t seq_idx = ubatch->seq_idx[seq_id];
205
 
206
+ if (pos >= last_pos[seq_idx]) {
207
+ last_pos[seq_idx] = pos;
208
+ last_row[seq_idx] = i;
209
  }
210
  }
211
  }
212
 
213
+ for (int s = 0; s < n_seqs_unq; ++s) {
214
+ if (last_row[s] >= 0) {
215
+ data[s] = last_row[s];
216
  }
217
  }
218
  }
219
  }
220
 
221
+ void llm_graph_input_rs::set_input(const llama_ubatch * ubatch) {
222
  GGML_UNUSED(ubatch);
223
 
224
+ const int64_t n_rs = mem_state->get_n_rs();
225
 
226
  if (s_copy) {
227
  GGML_ASSERT(ggml_backend_buffer_is_host(s_copy->buffer));
228
  int32_t * data = (int32_t *) s_copy->data;
229
 
230
  // assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
231
+ for (uint32_t i = 0; i < n_rs; ++i) {
232
+ data[i] = mem_state->s_copy(i);
233
  }
234
  }
235
  }
 
245
  }
246
 
247
  void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
248
+ const int64_t n_kv = ubatch->n_tokens;
249
+ const int64_t n_tokens = ubatch->n_tokens;
250
+
251
+ GGML_ASSERT(kq_mask);
252
+ GGML_ASSERT(ggml_backend_buffer_is_host(kq_mask->buffer));
253
+
254
+ float * data = (float *) kq_mask->data;
255
+
256
+ for (int h = 0; h < 1; ++h) {
257
+ for (int i1 = 0; i1 < n_tokens; ++i1) {
258
+ const llama_seq_id s1 = ubatch->seq_id[i1][0];
259
+
260
+ for (int i0 = 0; i0 < n_tokens; ++i0) {
261
+ float f = -INFINITY;
262
+
263
+ for (int s = 0; s < ubatch->n_seq_id[i0]; ++s) {
264
+ const llama_seq_id s0 = ubatch->seq_id[i0][s];
265
 
266
+ // TODO: reimplement this like in llama_kv_cache_unified
267
+ if (s0 == s1 && (!cparams.causal_attn || ubatch->pos[i0] <= ubatch->pos[i1])) {
268
+ if (hparams.use_alibi) {
269
+ f = -std::abs(ubatch->pos[i0] - ubatch->pos[i1]);
270
+ } else {
271
+ f = 0.0f;
272
  }
273
+ break;
274
  }
275
  }
276
+
277
+ data[h*(n_kv*n_tokens) + i1*n_kv + i0] = f;
278
  }
279
  }
280
  }
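
The simplified llm_graph_input_attn_no_cache::set_input above collapses the old equal_seqs bookkeeping into a single token-by-token double loop: token i1 may attend to token i0 if the two share a sequence and, when causal attention is enabled, pos[i0] <= pos[i1]; with ALiBi the allowed entries hold -|pos[i0]-pos[i1]| instead of 0. A standalone sketch of that mask, restricted to one sequence id per token and without the GGML_KQ_MASK_PAD padding:

    #include <cmath>
    #include <cstdint>
    #include <limits>
    #include <vector>

    // returns mask[i1*n_tokens + i0]: 0.0f (or -|dpos| with ALiBi) if attention is
    // allowed, -INFINITY otherwise
    std::vector<float> build_kq_mask(
            const std::vector<int32_t> & seq_id, // sequence id per token
            const std::vector<int32_t> & pos,    // position per token
            bool causal,
            bool use_alibi) {
        const int32_t n_tokens = (int32_t) seq_id.size();
        const float   neg_inf  = -std::numeric_limits<float>::infinity();

        std::vector<float> mask((size_t) n_tokens*n_tokens, neg_inf);

        for (int32_t i1 = 0; i1 < n_tokens; ++i1) {
            for (int32_t i0 = 0; i0 < n_tokens; ++i0) {
                const bool same_seq = seq_id[i0] == seq_id[i1];
                const bool in_order = !causal || pos[i0] <= pos[i1];

                if (same_seq && in_order) {
                    mask[(size_t) i1*n_tokens + i0] = use_alibi ? -std::abs(float(pos[i0] - pos[i1])) : 0.0f;
                }
            }
        }

        return mask;
    }
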
 
297
  }
298
 
299
  void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
300
+ GGML_ASSERT(cross_kq_mask);
 
 
301
 
302
+ const int64_t n_enc = cross_kq_mask->ne[0];
303
+ const int64_t n_tokens = ubatch->n_tokens;
304
 
305
+ GGML_ASSERT(ggml_backend_buffer_is_host(cross_kq_mask->buffer));
306
+ GGML_ASSERT(!ubatch->equal_seqs); // TODO: use ubatch->n_seqs instead of failing
307
 
308
+ float * data = (float *) cross_kq_mask->data;
309
+
310
+ for (int h = 0; h < 1; ++h) {
311
+ for (int i = 0; i < n_tokens; ++i) {
312
+ for (int j = 0; j < n_enc; ++j) {
313
+ float f = -INFINITY;
314
+
315
+ for (int s = 0; s < ubatch->n_seq_id[i]; ++s) {
316
+ const llama_seq_id seq_id = ubatch->seq_id[i][s];
317
+
318
+ if (cross->seq_ids_enc[j].find(seq_id) != cross->seq_ids_enc[j].end()) {
319
+ f = 0.0f;
320
  }
 
321
  }
322
+
323
+ data[h*(n_enc*n_tokens) + i*n_enc + j] = f;
324
  }
325
+ }
326
 
327
+ for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
328
+ for (int j = 0; j < n_enc; ++j) {
329
+ data[h*(n_enc*n_tokens) + i*n_enc + j] = -INFINITY;
 
330
  }
331
  }
332
  }
333
  }
334
 
335
+ void llm_graph_input_mem_hybrid::set_input(const llama_ubatch * ubatch) {
336
+ if (self_kq_mask) {
337
+ mem_state->get_state_attn()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
338
+ }
339
+
340
+ const int64_t n_rs = mem_state->get_state_recr()->get_n_rs();
341
+
342
+ if (s_copy) {
343
+ GGML_ASSERT(ggml_backend_buffer_is_host(s_copy->buffer));
344
+ int32_t * data = (int32_t *) s_copy->data;
345
+
346
+ // assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
347
+ for (uint32_t i = 0; i < n_rs; ++i) {
348
+ data[i] = mem_state->get_state_recr()->s_copy(i);
349
+ }
350
+ }
351
+ }
352
+
353
  //
354
  // llm_graph_context
355
  //
 
395
  res (std::make_unique<llm_graph_result>()) {
396
  }
397
 
 
 
 
 
398
  void llm_graph_context::cb(ggml_tensor * cur, const char * name, int il) const {
399
  if (cb_func) {
400
  cb_func(ubatch, cur, name, il);
 
839
  }
840
 
841
  ggml_tensor * llm_graph_context::build_inp_pos() const {
842
+ auto inp = std::make_unique<llm_graph_input_pos>(hparams.n_pos_per_embd());
843
 
844
  auto & cur = inp->pos;
845
 
846
+ cur = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, (int64_t)n_tokens*hparams.n_pos_per_embd());
847
  ggml_set_input(cur);
848
 
849
  res->add_input(std::move(inp));
 
866
  }
867
 
868
  ggml_tensor * llm_graph_context::build_inp_out_ids() const {
869
+ // note: when all tokens are output, we could skip this optimization to spare the ggml_get_rows() calls,
870
+ // but this would make the graph topology depend on the number of output tokens, which can interfere with
871
+ // features that require constant topology such as pipeline parallelism
872
+ // ref: https://github.com/ggml-org/llama.cpp/pull/14275#issuecomment-2987424471
873
+ //if (n_outputs < n_tokens) {
874
+ // return nullptr;
875
+ //}
876
+
877
  auto inp = std::make_unique<llm_graph_input_out_ids>(hparams, cparams, n_outputs);
878
 
879
  auto & cur = inp->out_ids;
 
891
 
892
  auto & cur = inp->mean;
893
 
894
+ cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, ubatch.n_seqs_unq);
895
  ggml_set_input(cur);
896
 
897
  res->add_input(std::move(inp));
 
904
 
905
  auto & cur = inp->cls;
906
 
907
+ cur = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ubatch.n_seqs_unq);
908
  ggml_set_input(cur);
909
 
910
  res->add_input(std::move(inp));
 
981
  return pos_bias;
982
  }
983
 
984
+ llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const {
985
+ const auto * mem_state = static_cast<const llama_memory_hybrid_state *>(mstate);
986
+
987
+ auto inp = std::make_unique<llm_graph_input_mem_hybrid>(hparams, cparams, mem_state);
988
+
989
+ {
990
+ GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Hybrid recurrent is not supported with SWA attention layers");
991
+
992
+ const auto n_kv = inp->mem_state->get_state_attn()->get_n_kv();
993
+
994
+ inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
995
+ //cb(inp->self_kq_mask, "KQ_mask", -1);
996
+ ggml_set_input(inp->self_kq_mask);
997
+
998
+ inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
999
+ }
1000
+
1001
+ {
1002
+ const auto n_rs = mem_state->get_state_recr()->get_n_rs();
1003
+
1004
+ inp->s_copy = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_rs);
1005
+ ggml_set_input(inp->s_copy);
1006
+ }
1007
+
1008
+ return (llm_graph_input_mem_hybrid *) res->add_input(std::move(inp));
1009
+ }
1010
+
1011
  ggml_tensor * llm_graph_context::build_attn_mha(
1012
  ggml_cgraph * gf,
1013
  ggml_tensor * q,
 
1252
  return cur;
1253
  }
1254
 
1255
  ggml_tensor * llm_graph_context::build_attn(
1256
  llm_graph_input_attn_kv_unified_iswa * inp,
1257
  ggml_cgraph * gf,
 
1361
  return cur;
1362
  }
1363
 
1364
+ ggml_tensor * llm_graph_context::build_attn(
1365
+ llm_graph_input_mem_hybrid * inp,
1366
+ ggml_cgraph * gf,
1367
+ ggml_tensor * wo,
1368
+ ggml_tensor * wo_b,
1369
+ ggml_tensor * q_cur,
1370
+ ggml_tensor * k_cur,
1371
+ ggml_tensor * v_cur,
1372
+ ggml_tensor * kq_b,
1373
+ ggml_tensor * v_mla,
1374
+ float kq_scale,
1375
+ int il) const {
1376
+ // these nodes are added to the graph together so that they are not reordered
1377
+ // by doing so, the number of splits in the graph is reduced
1378
+ ggml_build_forward_expand(gf, q_cur);
1379
+ ggml_build_forward_expand(gf, k_cur);
1380
+ ggml_build_forward_expand(gf, v_cur);
1381
+
1382
+ const auto * kv_state = static_cast<const llama_memory_hybrid_state *>(mstate)->get_state_attn();
1383
+
1384
+ // store to KV cache
1385
+ {
1386
+ ggml_build_forward_expand(gf, kv_state->cpy_k(ctx0, k_cur, il));
1387
+ ggml_build_forward_expand(gf, kv_state->cpy_v(ctx0, v_cur, il));
1388
+ }
1389
+
1390
+ const auto & kq_mask = inp->get_kq_mask();
1391
+
1392
+ ggml_tensor * q = q_cur;
1393
+ ggml_tensor * k = kv_state->get_k(ctx0, il);
1394
+ ggml_tensor * v = kv_state->get_v(ctx0, il);
1395
+
1396
+ ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
1397
+ cb(cur, "kqv_out", il);
1398
+
1399
+ if (wo) {
1400
+ cur = build_lora_mm(wo, cur);
1401
+ if (arch == LLM_ARCH_GLM4) {
1402
+ // GLM4 seems to have numerical issues with half-precision accumulators
1403
+ ggml_mul_mat_set_prec(cur, GGML_PREC_F32);
1404
+ }
1405
+ }
1406
+
1407
+ if (wo_b) {
1408
+ cur = ggml_add(ctx0, cur, wo_b);
1409
+ }
1410
+
1411
+ return cur;
1412
+ }
1413
+
1414
+ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unified_iswa() const {
1415
+ const auto * kv_state = static_cast<const llama_kv_cache_unified_iswa_state *>(mstate);
1416
+
1417
+ auto inp = std::make_unique<llm_graph_input_attn_kv_unified_iswa>(hparams, cparams, kv_state);
1418
+
1419
+ {
1420
+ const auto n_kv = kv_state->get_base()->get_n_kv();
1421
+
1422
+ inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
1423
+ //cb(inp->self_kq_mask, "KQ_mask", -1);
1424
+ ggml_set_input(inp->self_kq_mask);
1425
+
1426
+ inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
1427
+ }
1428
+
1429
+ {
1430
+ GGML_ASSERT(hparams.swa_type != LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified for non-SWA");
1431
 
1432
+ const auto n_kv = kv_state->get_swa()->get_n_kv();
1433
+
1434
+ inp->self_kq_mask_swa = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
1435
+ //cb(inp->self_kq_mask_swa, "KQ_mask_swa", -1);
1436
+ ggml_set_input(inp->self_kq_mask_swa);
1437
+
1438
+ inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;
1439
+ }
1440
+
1441
+ return (llm_graph_input_attn_kv_unified_iswa *) res->add_input(std::move(inp));
1442
+ }
1443
+
1444
+ ggml_tensor * llm_graph_context::build_rs(
1445
+ ggml_cgraph * gf,
1446
+ ggml_tensor * s,
1447
+ ggml_tensor * state_copy,
1448
+ int32_t state_size,
1449
+ int32_t n_seqs,
1450
+ uint32_t n_kv,
1451
+ uint32_t kv_head,
1452
+ uint32_t kv_size,
1453
+ int32_t rs_zero,
1454
+ bool avoid_copies) const {
1455
+
1456
+ ggml_tensor * states = ggml_reshape_2d(ctx0, s, state_size, kv_size);
1457
 
1458
  // Clear a single state which will then be copied to the other cleared states.
1459
  // Note that this is a no-op when the view is zero-sized.
 
1484
  return output_states;
1485
  }
1486
 
1487
+ llm_graph_input_rs * llm_graph_context::build_rs_inp() const {
1488
+ const auto * kv_state = static_cast<const llama_memory_recurrent_state *>(mstate);
1489
+
1490
+ auto inp = std::make_unique<llm_graph_input_rs>(kv_state);
1491
+
1492
+ const auto n_rs = kv_state->get_n_rs();
1493
+
1494
+ inp->s_copy = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_rs);
1495
+ ggml_set_input(inp->s_copy);
1496
+
1497
+ return (llm_graph_input_rs *) res->add_input(std::move(inp));
1498
+ }
1499
+
1500
+ ggml_tensor * llm_graph_context::build_rs(
1501
+ llm_graph_input_rs * inp,
1502
+ ggml_cgraph * gf,
1503
+ ggml_tensor * s,
1504
+ int32_t state_size,
1505
+ int32_t n_seqs,
1506
+ bool avoid_copies) const {
1507
+ const auto * kv_state = static_cast<const llama_memory_recurrent_state *>(mstate);
1508
+
1509
+ return build_rs(gf, s, inp->s_copy, state_size, n_seqs, kv_state->get_n_rs(), kv_state->get_head(), kv_state->get_size(), kv_state->get_rs_z(), avoid_copies);
1510
+ }
1511
+
1512
+ ggml_tensor * llm_graph_context::build_rs(
1513
+ llm_graph_input_mem_hybrid * inp,
1514
+ ggml_cgraph * gf,
1515
+ ggml_tensor * s,
1516
+ int32_t state_size,
1517
+ int32_t n_seqs,
1518
+ bool avoid_copies) const {
1519
+ const auto * kv_state = static_cast<const llama_memory_hybrid_state *>(mstate)->get_state_recr();
1520
+
1521
+ return build_rs(gf, s, inp->s_copy, state_size, n_seqs, kv_state->get_n_rs(), kv_state->get_head(), kv_state->get_size(), kv_state->get_rs_z(), avoid_copies);
1522
+ }
1523
+
1524
  ggml_tensor * llm_graph_context::build_rwkv_token_shift_load(
1525
+ llm_graph_input_rs * inp,
1526
+ ggml_cgraph * gf,
1527
+ const llama_ubatch & ubatch,
1528
  int il) const {
1529
+ const auto * kv_state = static_cast<const llama_memory_recurrent_state *>(mstate);
1530
 
1531
  const auto token_shift_count = hparams.token_shift_count;
1532
 
1533
  const int64_t n_seqs = ubatch.n_seqs;
1534
 
1535
+ ggml_tensor * token_shift_all = kv_state->get_r_l(il);
1536
 
1537
+ ggml_tensor * token_shift = build_rs(
1538
+ inp, gf, token_shift_all,
1539
+ hparams.n_embd_r(), n_seqs);
1540
 
1541
  token_shift = ggml_reshape_3d(ctx0, token_shift, hparams.n_embd, token_shift_count, n_seqs);
1542
 
 
1547
  ggml_tensor * token_shift,
1548
  const llama_ubatch & ubatch,
1549
  int il) const {
1550
+ const auto * kv_state = static_cast<const llama_memory_recurrent_state *>(mstate);
1551
 
1552
  const auto token_shift_count = hparams.token_shift_count;
1553
  const auto n_embd = hparams.n_embd;
 
1559
  return ggml_cpy(
1560
  ctx0,
1561
  ggml_view_1d(ctx0, token_shift, n_embd * n_seqs * token_shift_count, 0),
1562
+ ggml_view_1d(ctx0, kv_state->get_r_l(il), hparams.n_embd_r()*n_seqs, hparams.n_embd_r()*kv_head*ggml_element_size(kv_state->get_r_l(il)))
1563
  );
1564
  }
1565
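
Behind the new build_rs overloads above, the core operation is unchanged from build_recurrent_state: treat the per-layer state tensor as kv_size states of state_size elements each and gather, via the state_copy indices, the states that feed the n_seqs sequences of the current ubatch. A rough CPU analogue of that gather (it ignores the rs_zero clearing and the write-back of the shuffled states that the graph code also performs):

    #include <algorithm>
    #include <cstdint>
    #include <vector>

    // states:     kv_size consecutive states of state_size floats each
    // state_copy: for each slot used by this ubatch, the index of the cell whose
    //             state should be read (identity when nothing was moved)
    std::vector<float> gather_recurrent_states(
            const std::vector<float>   & states,
            const std::vector<int32_t> & state_copy,
            int32_t state_size,
            int32_t n_seqs) {
        std::vector<float> out((size_t) n_seqs*state_size);

        for (int32_t s = 0; s < n_seqs; ++s) {
            const int32_t src = state_copy[s]; // source cell for slot s

            const float * src_row = states.data() + (size_t) src*state_size;
            std::copy(src_row, src_row + state_size, out.begin() + (size_t) s*state_size);
        }

        return out;
    }
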
 
examples/talk-llama/llama-graph.h CHANGED
@@ -21,7 +21,8 @@ struct llama_memory_state_i;
21
 
22
  class llama_kv_cache_unified_state;
23
  class llama_kv_cache_unified_iswa_state;
24
- class llama_kv_cache_recurrent_state;
 
25
 
26
  // certain models (typically multi-modal) can produce different types of graphs
27
  enum llm_graph_type {
@@ -94,14 +95,14 @@ public:
94
 
95
  class llm_graph_input_pos : public llm_graph_input_i {
96
  public:
97
- llm_graph_input_pos(int64_t n_pos_per_embd) : n_pos_per_embd(n_pos_per_embd) {}
98
  virtual ~llm_graph_input_pos() = default;
99
 
100
  void set_input(const llama_ubatch * ubatch) override;
101
 
102
  ggml_tensor * pos = nullptr; // I32 [n_batch]
103
 
104
- const int64_t n_pos_per_embd = 1;
105
  };
106
 
107
  // temperature tuning, used by llama4
@@ -188,16 +189,16 @@ public:
188
  const llama_cparams & cparams;
189
  };
190
 
191
- class llm_graph_input_s_copy : public llm_graph_input_i {
192
  public:
193
- llm_graph_input_s_copy(const llama_kv_cache_recurrent_state * kv_state) : kv_state(kv_state) {}
194
- virtual ~llm_graph_input_s_copy() = default;
195
 
196
  void set_input(const llama_ubatch * ubatch) override;
197
 
198
  ggml_tensor * s_copy; // I32 [kv_size]
199
 
200
- const llama_kv_cache_recurrent_state * kv_state;
201
  };
202
 
203
  class llm_graph_input_cross_embd : public llm_graph_input_i {
@@ -300,6 +301,33 @@ public:
300
  const llama_cross * cross = nullptr;
301
  };
302
 
303
  //
304
  // llm_graph_result
305
  //
@@ -436,8 +464,6 @@ struct llm_graph_context {
436
 
437
  llm_graph_context(const llm_graph_params & params);
438
 
439
- int64_t n_pos_per_embd() const;
440
-
441
  void cb(ggml_tensor * cur, const char * name, int il) const;
442
 
443
  //
@@ -508,13 +534,14 @@ struct llm_graph_context {
508
  ggml_tensor * build_inp_out_ids() const;
509
  ggml_tensor * build_inp_mean() const;
510
  ggml_tensor * build_inp_cls() const;
511
- ggml_tensor * build_inp_s_copy() const;
512
 
513
  ggml_tensor * build_inp_cross_embd() const;
514
  ggml_tensor * build_inp_pos_bucket_enc() const;
515
  ggml_tensor * build_inp_pos_bucket_dec() const;
516
  ggml_tensor * build_pos_bias(ggml_tensor * pos_bucket, ggml_tensor * attn_rel_b) const;
517
 
 
 
518
  //
519
  // attention
520
  //
@@ -589,22 +616,62 @@ struct llm_graph_context {
589
  float kq_scale,
590
  int il) const;
591
 
592
  //
593
  // recurrent
594
  //
595
 
596
- ggml_tensor * build_recurrent_state(
597
- ggml_cgraph * gf,
598
- ggml_tensor * s,
599
- ggml_tensor * state_copy,
600
- int32_t state_size,
601
- int32_t n_seqs,
602
- bool avoid_copies = false) const;
603
 
604
  ggml_tensor * build_rwkv_token_shift_load(
605
- ggml_cgraph * gf,
606
- ggml_tensor * state_copy,
607
- const llama_ubatch & ubatch,
608
  int il) const;
609
 
610
  ggml_tensor * build_rwkv_token_shift_store(
 
21
 
22
  class llama_kv_cache_unified_state;
23
  class llama_kv_cache_unified_iswa_state;
24
+ class llama_memory_recurrent_state;
25
+ class llama_memory_hybrid_state;
26
 
27
  // certain models (typically multi-modal) can produce different types of graphs
28
  enum llm_graph_type {
 
95
 
96
  class llm_graph_input_pos : public llm_graph_input_i {
97
  public:
98
+ llm_graph_input_pos(uint32_t n_pos_per_embd) : n_pos_per_embd(n_pos_per_embd) {}
99
  virtual ~llm_graph_input_pos() = default;
100
 
101
  void set_input(const llama_ubatch * ubatch) override;
102
 
103
  ggml_tensor * pos = nullptr; // I32 [n_batch]
104
 
105
+ const uint32_t n_pos_per_embd = 1;
106
  };
107
 
108
  // temperature tuning, used by llama4
 
189
  const llama_cparams & cparams;
190
  };
191
 
192
+ class llm_graph_input_rs : public llm_graph_input_i {
193
  public:
194
+ llm_graph_input_rs(const llama_memory_recurrent_state * mem_state) : mem_state(mem_state) {}
195
+ virtual ~llm_graph_input_rs() = default;
196
 
197
  void set_input(const llama_ubatch * ubatch) override;
198
 
199
  ggml_tensor * s_copy; // I32 [kv_size]
200
 
201
+ const llama_memory_recurrent_state * mem_state;
202
  };
203
 
204
  class llm_graph_input_cross_embd : public llm_graph_input_i {
 
301
  const llama_cross * cross = nullptr;
302
  };
303
 
304
+ class llm_graph_input_mem_hybrid : public llm_graph_input_i {
305
+ public:
306
+ llm_graph_input_mem_hybrid(
307
+ const llama_hparams & hparams,
308
+ const llama_cparams & cparams,
309
+ const llama_memory_hybrid_state * mem_state) :
310
+ hparams(hparams),
311
+ cparams(cparams),
312
+ mem_state(mem_state) {
313
+ }
314
+ virtual ~llm_graph_input_mem_hybrid() = default;
315
+
316
+ void set_input(const llama_ubatch * ubatch) override;
317
+
318
+ ggml_tensor * s_copy; // I32 [kv_size]
319
+
320
+ ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; }
321
+
322
+ ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch]
323
+ ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch]
324
+
325
+ const llama_hparams & hparams;
326
+ const llama_cparams & cparams;
327
+
328
+ const llama_memory_hybrid_state * mem_state;
329
+ };
330
+
331
  //
332
  // llm_graph_result
333
  //
 
464
 
465
  llm_graph_context(const llm_graph_params & params);
466
 
 
 
467
  void cb(ggml_tensor * cur, const char * name, int il) const;
468
 
469
  //
 
534
  ggml_tensor * build_inp_out_ids() const;
535
  ggml_tensor * build_inp_mean() const;
536
  ggml_tensor * build_inp_cls() const;
 
537
 
538
  ggml_tensor * build_inp_cross_embd() const;
539
  ggml_tensor * build_inp_pos_bucket_enc() const;
540
  ggml_tensor * build_inp_pos_bucket_dec() const;
541
  ggml_tensor * build_pos_bias(ggml_tensor * pos_bucket, ggml_tensor * attn_rel_b) const;
542
 
543
+ llm_graph_input_mem_hybrid * build_inp_mem_hybrid() const;
544
+
545
  //
546
  // attention
547
  //
 
616
  float kq_scale,
617
  int il) const;
618
 
619
+ ggml_tensor * build_attn(
620
+ llm_graph_input_mem_hybrid * inp,
621
+ ggml_cgraph * gf,
622
+ ggml_tensor * wo,
623
+ ggml_tensor * wo_b,
624
+ ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
625
+ ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
626
+ ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
627
+ ggml_tensor * kq_b,
628
+ ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
629
+ float kq_scale,
630
+ int il) const;
631
  //
632
  // recurrent
633
  //
634
 
635
+ // TODO: avoid notion of "kv"
636
+ // TODO: move this implementation to llama_memory_recurrent.
637
+ // this is analogous to llama_kv_cache_unified::cpy_k / cpy_v
638
+ // when moving, avoid passing `ggml_cgraph` - only pass `ggml_context`. would likely need to split the
639
+ // implementation in 2 separate methods. the goal is to avoid calling `ggml_build_forward_expand` in
640
+ // `llama_memory_recurrent`
641
+ ggml_tensor * build_rs(
642
+ ggml_cgraph * gf,
643
+ ggml_tensor * s,
644
+ ggml_tensor * state_copy,
645
+ int32_t state_size,
646
+ int32_t n_seqs,
647
+ uint32_t n_kv,
648
+ uint32_t kv_head,
649
+ uint32_t kv_size,
650
+ int32_t rs_zero,
651
+ bool avoid_copies = false) const;
652
+
653
+ llm_graph_input_rs * build_rs_inp() const;
654
+
655
+ ggml_tensor * build_rs(
656
+ llm_graph_input_rs * inp,
657
+ ggml_cgraph * gf,
658
+ ggml_tensor * s,
659
+ int32_t state_size,
660
+ int32_t n_seqs,
661
+ bool avoid_copies = false) const;
662
+
663
+ ggml_tensor * build_rs(
664
+ llm_graph_input_mem_hybrid * inp,
665
+ ggml_cgraph * gf,
666
+ ggml_tensor * s,
667
+ int32_t state_size,
668
+ int32_t n_seqs,
669
+ bool avoid_copies = false) const;
670
 
671
  ggml_tensor * build_rwkv_token_shift_load(
672
+ llm_graph_input_rs * inp,
673
+ ggml_cgraph * gf,
674
+ const llama_ubatch & ubatch,
675
  int il) const;
676
 
677
  ggml_tensor * build_rwkv_token_shift_store(
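
In the llama-graph.h changes above, llm_graph_input_s_copy becomes llm_graph_input_rs and llm_graph_input_mem_hybrid is added; both follow the same pattern of a small object that holds the input tensor pointers plus a pointer to the relevant memory state and fills the tensors in set_input(). A stripped-down sketch of that pattern with stand-in types (the real classes derive from llm_graph_input_i and write into ggml tensors, not std::vector):

    #include <cstdint>
    #include <vector>

    // stand-in for the recurrent memory state queried by llm_graph_input_rs
    struct recurrent_state_view {
        virtual ~recurrent_state_view() = default;
        virtual uint32_t get_n_rs() const = 0;
        virtual int32_t  s_copy(uint32_t i) const = 0;
    };

    // stand-in for llm_graph_input_i: a single hook that populates the input buffers
    struct graph_input {
        virtual ~graph_input() = default;
        virtual void set_input() = 0;
    };

    struct graph_input_rs : graph_input {
        explicit graph_input_rs(const recurrent_state_view * mem_state) : mem_state(mem_state) {}

        void set_input() override {
            // mirrors llm_graph_input_rs::set_input: copy the cell-copy indices into s_copy
            const uint32_t n_rs = mem_state->get_n_rs();

            s_copy.resize(n_rs);
            for (uint32_t i = 0; i < n_rs; ++i) {
                s_copy[i] = mem_state->s_copy(i);
            }
        }

        std::vector<int32_t> s_copy; // stands in for the I32 ggml tensor
        const recurrent_state_view * mem_state;
    };
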
examples/talk-llama/llama-hparams.cpp CHANGED
@@ -65,7 +65,7 @@ uint32_t llama_hparams::n_embd_v_gqa(uint32_t il) const {
65
  return n_embd_head_v * n_head_kv;
66
  }
67
 
68
- uint32_t llama_hparams::n_embd_k_s() const {
69
  if (wkv_head_size != 0) {
70
  // for RWKV models
71
  return token_shift_count * n_embd;
@@ -76,7 +76,7 @@ uint32_t llama_hparams::n_embd_k_s() const {
76
  return (ssm_d_conv > 0 ? ssm_d_conv - 1 : 0) * ssm_d_inner;
77
  }
78
 
79
- uint32_t llama_hparams::n_embd_v_s() const {
80
  if (wkv_head_size != 0) {
81
  // corresponds to RWKV's wkv_states size
82
  return n_embd * wkv_head_size;
@@ -86,6 +86,14 @@ uint32_t llama_hparams::n_embd_v_s() const {
86
  return ssm_d_state * ssm_d_inner;
87
  }
88
 
 
 
 
 
 
 
 
 
89
  bool llama_hparams::is_swa(uint32_t il) const {
90
  if (il < n_layer) {
91
  return swa_layers[il];
 
65
  return n_embd_head_v * n_head_kv;
66
  }
67
 
68
+ uint32_t llama_hparams::n_embd_r() const {
69
  if (wkv_head_size != 0) {
70
  // for RWKV models
71
  return token_shift_count * n_embd;
 
76
  return (ssm_d_conv > 0 ? ssm_d_conv - 1 : 0) * ssm_d_inner;
77
  }
78
 
79
+ uint32_t llama_hparams::n_embd_s() const {
80
  if (wkv_head_size != 0) {
81
  // corresponds to RWKV's wkv_states size
82
  return n_embd * wkv_head_size;
 
86
  return ssm_d_state * ssm_d_inner;
87
  }
88
 
89
+ bool llama_hparams::is_recurrent(uint32_t il) const {
90
+ return recurrent_layer_arr[il];
91
+ }
92
+
93
+ uint32_t llama_hparams::n_pos_per_embd() const {
94
+ return rope_type == LLAMA_ROPE_TYPE_MROPE ? 4 : 1;
95
+ }
96
+
97
  bool llama_hparams::is_swa(uint32_t il) const {
98
  if (il < n_layer) {
99
  return swa_layers[il];
examples/talk-llama/llama-hparams.h CHANGED
@@ -115,6 +115,9 @@ struct llama_hparams {
115
  uint32_t ssm_d_state = 0;
116
  uint32_t ssm_dt_rank = 0;
117
 
 
 
 
118
  bool ssm_dt_b_c_rms = false;
119
 
120
  float f_clamp_kqv = 0.0f;
@@ -181,10 +184,15 @@ struct llama_hparams {
181
 
182
  // dimension of the rolling state embeddings
183
  // corresponds to Mamba's conv_states size or RWKV's token_shift states size
184
- uint32_t n_embd_k_s() const;
185
 
186
  // dimension of the recurrent state embeddings
187
- uint32_t n_embd_v_s() const;
 
 
 
 
 
188
 
189
  bool is_swa(uint32_t il) const;
190
  };
 
115
  uint32_t ssm_d_state = 0;
116
  uint32_t ssm_dt_rank = 0;
117
 
118
+ // for hybrid state space models
119
+ std::array<bool, LLAMA_MAX_LAYERS> recurrent_layer_arr;
120
+
121
  bool ssm_dt_b_c_rms = false;
122
 
123
  float f_clamp_kqv = 0.0f;
 
184
 
185
  // dimension of the rolling state embeddings
186
  // corresponds to Mamba's conv_states size or RWKV's token_shift states size
187
+ uint32_t n_embd_r() const;
188
 
189
  // dimension of the recurrent state embeddings
190
+ uint32_t n_embd_s() const;
191
+
192
+ // whether or not the given layer is recurrent (for hybrid models)
193
+ bool is_recurrent(uint32_t il) const;
194
+
195
+ uint32_t n_pos_per_embd() const;
196
 
197
  bool is_swa(uint32_t il) const;
198
  };
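The new recurrent_layer_arr / is_recurrent(il) pair lets a hybrid model mark, per layer, whether it is recurrent or attention-based; the hybrid memory added later in this commit filters its two sub-caches with exactly this predicate. A minimal, self-contained sketch of that filtering idea (the layer pattern below is made up for illustration):

```cpp
#include <array>
#include <cstdio>

int main() {
    constexpr int n_layer = 6;

    // hypothetical pattern: every third layer is attention-based, the rest are recurrent
    std::array<bool, n_layer> recurrent_layer_arr = { true, true, false, true, true, false };

    auto filter_attn = [&](int il) { return !recurrent_layer_arr[il]; };
    auto filter_recr = [&](int il) { return  recurrent_layer_arr[il]; };

    for (int il = 0; il < n_layer; ++il) {
        std::printf("layer %d -> %s\n", il, filter_attn(il) ? "attention cache" : "recurrent cache");
    }
    return 0;
}
```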
examples/talk-llama/llama-kv-cache-unified-iswa.cpp CHANGED
@@ -95,19 +95,22 @@ llama_pos llama_kv_cache_unified_iswa::seq_pos_max(llama_seq_id seq_id) const {
95
  return kv_swa->seq_pos_max(seq_id);
96
  }
97
 
98
- llama_memory_state_ptr llama_kv_cache_unified_iswa::init_batch(const llama_batch & batch, uint32_t n_ubatch, bool embd_all) {
99
  GGML_UNUSED(embd_all);
100
 
101
  // first try simple split
102
  do {
103
- auto sbatch = llama_sbatch(batch, hparams.n_embd, true);
104
 
105
  std::vector<llama_ubatch> ubatches;
 
 
106
 
107
- while (sbatch.n_tokens > 0) {
108
- auto ubatch = sbatch.split_simple(n_ubatch);
 
109
 
110
- ubatches.push_back(ubatch);
111
  }
112
 
113
  auto heads_base = kv_base->prepare(ubatches);
@@ -123,19 +126,22 @@ llama_memory_state_ptr llama_kv_cache_unified_iswa::init_batch(const llama_batch
123
  assert(heads_base.size() == heads_swa.size());
124
 
125
  return std::make_unique<llama_kv_cache_unified_iswa_state>(
126
- this, std::move(sbatch), std::move(heads_base), std::move(heads_swa), std::move(ubatches));
127
  } while (false);
128
 
129
  // if it fails, try equal split
130
  do {
131
- auto sbatch = llama_sbatch(batch, hparams.n_embd, false);
132
 
133
  std::vector<llama_ubatch> ubatches;
 
 
134
 
135
- while (sbatch.n_tokens > 0) {
136
- auto ubatch = sbatch.split_equal(n_ubatch);
 
137
 
138
- ubatches.push_back(ubatch);
139
  }
140
 
141
  auto heads_base = kv_base->prepare(ubatches);
@@ -151,7 +157,7 @@ llama_memory_state_ptr llama_kv_cache_unified_iswa::init_batch(const llama_batch
151
  assert(heads_base.size() == heads_swa.size());
152
 
153
  return std::make_unique<llama_kv_cache_unified_iswa_state>(
154
- this, std::move(sbatch), std::move(heads_base), std::move(heads_swa), std::move(ubatches));
155
  } while (false);
156
 
157
  // TODO: if we fail again, we should attempt different splitting strategies
@@ -197,37 +203,31 @@ llama_kv_cache_unified * llama_kv_cache_unified_iswa::get_swa() const {
197
  llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(llama_memory_status status) : status(status) {}
198
 
199
  llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(
200
- llama_kv_cache_unified_iswa * kv) : status(LLAMA_MEMORY_STATUS_SUCCESS) {
201
- state_base = kv->get_base()->init_full();
202
- state_swa = kv->get_swa ()->init_full();
203
-
204
- status = llama_memory_status_combine(state_base->get_status(), state_swa->get_status());
205
  }
206
 
207
  llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(
208
  llama_kv_cache_unified_iswa * kv,
209
  llama_context * lctx,
210
- bool optimize) : status(LLAMA_MEMORY_STATUS_SUCCESS) {
211
- state_base = kv->get_base()->init_update(lctx, optimize);
212
- state_swa = kv->get_swa ()->init_update(lctx, optimize);
213
-
214
- status = llama_memory_status_combine(state_base->get_status(), state_swa->get_status());
215
  }
216
 
217
  llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(
218
  llama_kv_cache_unified_iswa * kv,
219
- llama_sbatch sbatch,
220
  std::vector<uint32_t> heads_base,
221
  std::vector<uint32_t> heads_swa,
222
- std::vector<llama_ubatch> ubatches)
223
- : status(LLAMA_MEMORY_STATUS_SUCCESS),
224
- sbatch(std::move(sbatch)),
225
- ubatches(std::move(ubatches)) {
226
  // note: here we copy the ubatches. not sure if this is ideal
227
- state_base.reset(new llama_kv_cache_unified_state(kv->get_base(), {}, std::move(heads_base), this->ubatches));
228
- state_swa .reset(new llama_kv_cache_unified_state(kv->get_swa (), {}, std::move(heads_swa), this->ubatches));
229
-
230
- status = llama_memory_status_combine(state_base->get_status(), state_swa->get_status());
231
  }
232
 
233
  llama_kv_cache_unified_iswa_state:: ~llama_kv_cache_unified_iswa_state() = default;
@@ -256,12 +256,6 @@ bool llama_kv_cache_unified_iswa_state::apply() {
256
  return res;
257
  }
258
 
259
- std::vector<int64_t> & llama_kv_cache_unified_iswa_state::out_ids() {
260
- assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
261
-
262
- return sbatch.out_ids;
263
- }
264
-
265
  llama_memory_status llama_kv_cache_unified_iswa_state::get_status() const {
266
  return status;
267
  }
 
95
  return kv_swa->seq_pos_max(seq_id);
96
  }
97
 
98
+ llama_memory_state_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_allocr & balloc, uint32_t n_ubatch, bool embd_all) {
99
  GGML_UNUSED(embd_all);
100
 
101
  // first try simple split
102
  do {
103
+ balloc.split_reset();
104
 
105
  std::vector<llama_ubatch> ubatches;
106
+ while (true) {
107
+ auto ubatch = balloc.split_simple(n_ubatch);
108
 
109
+ if (ubatch.n_tokens == 0) {
110
+ break;
111
+ }
112
 
113
+ ubatches.push_back(std::move(ubatch)); // NOLINT
114
  }
115
 
116
  auto heads_base = kv_base->prepare(ubatches);
 
126
  assert(heads_base.size() == heads_swa.size());
127
 
128
  return std::make_unique<llama_kv_cache_unified_iswa_state>(
129
+ this, std::move(heads_base), std::move(heads_swa), std::move(ubatches));
130
  } while (false);
131
 
132
  // if it fails, try equal split
133
  do {
134
+ balloc.split_reset();
135
 
136
  std::vector<llama_ubatch> ubatches;
137
+ while (true) {
138
+ auto ubatch = balloc.split_equal(n_ubatch);
139
 
140
+ if (ubatch.n_tokens == 0) {
141
+ break;
142
+ }
143
 
144
+ ubatches.push_back(std::move(ubatch)); // NOLINT
145
  }
146
 
147
  auto heads_base = kv_base->prepare(ubatches);
 
157
  assert(heads_base.size() == heads_swa.size());
158
 
159
  return std::make_unique<llama_kv_cache_unified_iswa_state>(
160
+ this, std::move(heads_base), std::move(heads_swa), std::move(ubatches));
161
  } while (false);
162
 
163
  // TODO: if we fail again, we should attempt different splitting strategies
 
203
  llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(llama_memory_status status) : status(status) {}
204
 
205
  llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(
206
+ llama_kv_cache_unified_iswa * kv) :
207
+ state_base(kv->get_base()->init_full()),
208
+ state_swa (kv->get_swa ()->init_full()),
209
+ status(llama_memory_status_combine(state_base->get_status(), state_swa->get_status())) {
 
210
  }
211
 
212
  llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(
213
  llama_kv_cache_unified_iswa * kv,
214
  llama_context * lctx,
215
+ bool optimize) :
216
+ state_base(kv->get_base()->init_update(lctx, optimize)),
217
+ state_swa (kv->get_swa ()->init_update(lctx, optimize)),
218
+ status(llama_memory_status_combine(state_base->get_status(), state_swa->get_status())) {
 
219
  }
220
 
221
  llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(
222
  llama_kv_cache_unified_iswa * kv,
 
223
  std::vector<uint32_t> heads_base,
224
  std::vector<uint32_t> heads_swa,
225
+ std::vector<llama_ubatch> ubatches) :
226
+ ubatches(std::move(ubatches)),
 
 
227
  // note: here we copy the ubatches. not sure if this is ideal
228
+ state_base(new llama_kv_cache_unified_state(kv->get_base(), std::move(heads_base), this->ubatches)),
229
+ state_swa (new llama_kv_cache_unified_state(kv->get_swa (), std::move(heads_swa), this->ubatches)),
230
+ status(llama_memory_status_combine(state_base->get_status(), state_swa->get_status())) {
 
231
  }
232
 
233
  llama_kv_cache_unified_iswa_state:: ~llama_kv_cache_unified_iswa_state() = default;
 
256
  return res;
257
  }
258
 
 
 
 
 
 
 
259
  llama_memory_status llama_kv_cache_unified_iswa_state::get_status() const {
260
  return status;
261
  }
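The iswa cache's init_batch() no longer builds an llama_sbatch; it resets the batch allocator and keeps pulling splits (simple first, equal as a fallback) until an empty ubatch comes back. A self-contained sketch of that loop shape, using a mock allocator rather than the real llama_batch_allocr:

```cpp
#include <cstdint>
#include <cstdio>
#include <vector>

struct mock_ubatch { uint32_t n_tokens = 0; };

struct mock_balloc {
    uint32_t remaining = 0;
    void split_reset() { /* a real allocator would rewind its split state here */ }
    mock_ubatch split_simple(uint32_t n_ubatch) {
        mock_ubatch ub;
        ub.n_tokens = remaining < n_ubatch ? remaining : n_ubatch;
        remaining  -= ub.n_tokens;
        return ub;
    }
};

int main() {
    mock_balloc balloc{ /*remaining =*/ 10 };
    balloc.split_reset();

    std::vector<mock_ubatch> ubatches;
    while (true) {
        auto ubatch = balloc.split_simple(/*n_ubatch =*/ 4);
        if (ubatch.n_tokens == 0) {
            break; // nothing left: all ubatches collected
        }
        ubatches.push_back(ubatch);
    }

    std::printf("split into %zu ubatches\n", ubatches.size()); // 4 + 4 + 2 -> 3
    return 0;
}
```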
examples/talk-llama/llama-kv-cache-unified-iswa.h CHANGED
@@ -32,7 +32,7 @@ public:
32
  //
33
 
34
  llama_memory_state_ptr init_batch(
35
- const llama_batch & batch,
36
  uint32_t n_ubatch,
37
  bool embd_all) override;
38
 
@@ -90,7 +90,6 @@ public:
90
  // used to create a state from a batch
91
  llama_kv_cache_unified_iswa_state(
92
  llama_kv_cache_unified_iswa * kv,
93
- llama_sbatch sbatch,
94
  std::vector<uint32_t> heads_base,
95
  std::vector<uint32_t> heads_swa,
96
  std::vector<llama_ubatch> ubatches);
@@ -104,8 +103,6 @@ public:
104
  bool next() override;
105
  bool apply() override;
106
 
107
- std::vector<int64_t> & out_ids() override;
108
-
109
  llama_memory_status get_status() const override;
110
  const llama_ubatch & get_ubatch() const override;
111
 
@@ -117,17 +114,15 @@ public:
117
  const llama_kv_cache_unified_state * get_swa() const;
118
 
119
  private:
120
- llama_memory_status status;
121
-
122
  //llama_kv_cache_unified_iswa * kv;
123
 
124
- llama_sbatch sbatch;
125
-
126
  // the index of the next ubatch to process
127
  size_t i_next = 0;
128
 
129
  std::vector<llama_ubatch> ubatches;
130
 
131
- llama_memory_state_ptr state_base;
132
- llama_memory_state_ptr state_swa;
 
 
133
  };
 
32
  //
33
 
34
  llama_memory_state_ptr init_batch(
35
+ llama_batch_allocr & balloc,
36
  uint32_t n_ubatch,
37
  bool embd_all) override;
38
 
 
90
  // used to create a state from a batch
91
  llama_kv_cache_unified_iswa_state(
92
  llama_kv_cache_unified_iswa * kv,
 
93
  std::vector<uint32_t> heads_base,
94
  std::vector<uint32_t> heads_swa,
95
  std::vector<llama_ubatch> ubatches);
 
103
  bool next() override;
104
  bool apply() override;
105
 
 
 
106
  llama_memory_status get_status() const override;
107
  const llama_ubatch & get_ubatch() const override;
108
 
 
114
  const llama_kv_cache_unified_state * get_swa() const;
115
 
116
  private:
 
 
117
  //llama_kv_cache_unified_iswa * kv;
118
 
 
 
119
  // the index of the next ubatch to process
120
  size_t i_next = 0;
121
 
122
  std::vector<llama_ubatch> ubatches;
123
 
124
+ const llama_memory_state_ptr state_base;
125
+ const llama_memory_state_ptr state_swa;
126
+
127
+ const llama_memory_status status;
128
  };
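In the header, state_base, state_swa and status become const members, with status declared last. That order matters: non-static data members are initialized in declaration order, so the combined status can be computed from the two freshly constructed sub-states directly in the constructor's initializer list, as the .cpp hunk above now does. A small self-contained sketch of the pattern, with simplified stand-in types:

```cpp
#include <cstdio>
#include <memory>

struct mock_state { int status; };

struct iswa_like_state {
    const std::unique_ptr<mock_state> state_base;
    const std::unique_ptr<mock_state> state_swa;
    const int status; // declared last on purpose

    iswa_like_state() :
        state_base(new mock_state{0}),
        state_swa (new mock_state{0}),
        // both pointers above are already initialized when this runs
        status(state_base->status + state_swa->status) {}
};

int main() {
    iswa_like_state st;
    std::printf("combined status = %d\n", st.status);
    return 0;
}
```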
examples/talk-llama/llama-kv-cache-unified.cpp CHANGED
@@ -68,8 +68,8 @@ llama_kv_cache_unified::llama_kv_cache_unified(
68
  continue;
69
  }
70
 
71
- const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
72
- const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
73
 
74
  const char * dev_name = "CPU";
75
 
@@ -308,17 +308,23 @@ llama_pos llama_kv_cache_unified::seq_pos_max(llama_seq_id seq_id) const {
308
  }
309
 
310
  llama_memory_state_ptr llama_kv_cache_unified::init_batch(
311
- const llama_batch & batch,
312
  uint32_t n_ubatch,
313
  bool embd_all) {
314
  GGML_UNUSED(embd_all);
315
 
316
  do {
317
- auto sbatch = llama_sbatch(batch, hparams.n_embd, true);
318
 
319
  std::vector<llama_ubatch> ubatches;
320
- while (sbatch.n_tokens > 0) {
321
- ubatches.push_back(sbatch.split_simple(n_ubatch));
 
 
 
 
 
 
322
  }
323
 
324
  auto heads = prepare(ubatches);
@@ -327,7 +333,7 @@ llama_memory_state_ptr llama_kv_cache_unified::init_batch(
327
  }
328
 
329
  return std::make_unique<llama_kv_cache_unified_state>(
330
- this, std::move(sbatch), std::move(heads), std::move(ubatches));
331
  } while (false);
332
 
333
  return std::make_unique<llama_kv_cache_unified_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
@@ -644,12 +650,6 @@ int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const {
644
  }
645
 
646
  void llama_kv_cache_unified::apply_ubatch(uint32_t head_cur, const llama_ubatch & ubatch) {
647
- if (debug > 0) {
648
- LLAMA_LOG_DEBUG("%s: ubatch info:\n", __func__);
649
- LLAMA_LOG_DEBUG("%s: n_tokens = %d, equal_seqs = %d\n", __func__, ubatch.n_tokens, ubatch.equal_seqs);
650
- LLAMA_LOG_DEBUG("%s: n_seq_tokens = %d, n_seqs = %d\n", __func__, ubatch.n_seq_tokens, ubatch.n_seqs);
651
- }
652
-
653
  // keep track of the max sequence position that we would overwrite with this ubatch
654
  // for non-SWA cache, this would be always empty
655
  llama_seq_id seq_pos_max_rm[LLAMA_MAX_SEQ];
@@ -657,27 +657,22 @@ void llama_kv_cache_unified::apply_ubatch(uint32_t head_cur, const llama_ubatch
657
  seq_pos_max_rm[s] = -1;
658
  }
659
 
660
- for (uint32_t s = 0; s < ubatch.n_seqs; ++s) {
661
- for (uint32_t j = 0; j < ubatch.n_seq_tokens; ++j) {
662
- const uint32_t idx = s*ubatch.n_seq_tokens + j;
663
-
664
- if (!cells.is_empty(head_cur + idx)) {
665
- assert(cells.seq_count(head_cur + idx) == 1);
666
 
667
- const llama_seq_id seq_id = cells.seq_get(head_cur + idx);
668
- const llama_pos pos = cells.pos_get(head_cur + idx);
669
 
670
- seq_pos_max_rm[seq_id] = std::max(seq_pos_max_rm[seq_id], pos);
671
 
672
- cells.rm(head_cur + idx);
673
- }
674
 
675
- cells.pos_set(head_cur + idx, ubatch.pos[idx]);
676
 
677
- // TODO: fix indexing [UBATCH_IDX]
678
- for (int32_t i = 0; i < ubatch.n_seq_id[s]; i++) {
679
- cells.seq_add(head_cur + idx, ubatch.seq_id[s][i]);
680
- }
681
  }
682
  }
683
 
@@ -696,6 +691,7 @@ void llama_kv_cache_unified::apply_ubatch(uint32_t head_cur, const llama_ubatch
696
  seq_rm(s, cells.seq_pos_min(s), seq_pos_max_rm[s] + 1);
697
  }
698
  }
 
699
  // move the head to the end of the slot
700
  head = head_cur + ubatch.n_tokens;
701
  }
@@ -792,9 +788,7 @@ ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_
792
  }
793
 
794
  void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const {
795
- const uint32_t n_tokens = ubatch->n_tokens;
796
- const uint32_t n_seq_tokens = ubatch->n_seq_tokens;
797
- const uint32_t n_seqs = ubatch->n_seqs;
798
 
799
  GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
800
  float * data = (float *) dst->data;
@@ -814,52 +808,48 @@ void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ub
814
  // xxxxx-----
815
  // To visualize the mask, see https://github.com/ggml-org/llama.cpp/pull/12615
816
  for (uint32_t h = 0; h < 1; ++h) {
817
- for (uint32_t s = 0; s < n_seqs; ++s) {
818
- const llama_seq_id seq_id = ubatch->seq_id[s][0];
819
 
820
- for (uint32_t j = 0; j < n_seq_tokens; ++j) {
821
- const uint32_t idx = s*n_seq_tokens + j;
822
 
823
- const llama_pos p1 = ubatch->pos[idx];
 
824
 
825
- for (uint32_t i = 0; i < n_kv; ++i) {
826
- float f = 0.0f;
827
 
828
- bool masked = false;
829
-
830
- if (cells.is_empty(i)) {
831
- masked = true;
832
- } else {
833
- const llama_pos p0 = cells.pos_get(i);
834
-
835
- // mask the token if not the same sequence
836
- masked = masked || (!cells.seq_has(i, seq_id));
837
 
838
- // mask future tokens
839
- masked = masked || (causal_attn && p0 > p1);
840
 
841
- // apply SWA if any
842
- masked = masked || (is_masked_swa(p0, p1));
843
 
844
- if (!masked && hparams.use_alibi) {
845
- f = -std::abs(p0 - p1);
846
- }
847
- }
848
 
849
- if (masked) {
850
- f = -INFINITY;
851
  }
 
852
 
853
- data[h*(n_kv*n_tokens) + idx*n_kv + i] = f;
 
854
  }
 
 
855
  }
856
  }
857
 
858
  // mask padded tokens
859
  if (data) {
860
- for (uint32_t j = n_tokens; j < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++j) {
861
- for (uint32_t i = 0; i < n_kv; ++i) {
862
- data[h*(n_kv*n_tokens) + j*n_kv + i] = -INFINITY;
863
  }
864
  }
865
  }
@@ -887,12 +877,12 @@ void llama_kv_cache_unified::set_input_pos_bucket(ggml_tensor * dst, const llama
887
  const int32_t n_kv = dst->ne[0];
888
 
889
  for (int h = 0; h < 1; ++h) {
890
- for (int j = 0; j < n_tokens; ++j) {
891
- for (int i = 0; i < n_kv; ++i) {
892
  // the position when the cell is empty is irrelevant - it will be masked out later in the attention
893
- const llama_pos p0 = cells.is_empty(i) ? -1 : cells.pos_get(i);
894
 
895
- data[h*(n_kv*n_tokens) + j*n_kv + i] = llama_relative_position_bucket(p0, ubatch->pos[j], hparams.n_rel_attn_bkts, false);
896
  }
897
  }
898
  }
@@ -1430,7 +1420,7 @@ void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const std::
1430
  for (const auto & layer : layers) {
1431
  const uint32_t il = layer.il;
1432
 
1433
- const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
1434
 
1435
  // Write key type
1436
  const int32_t k_type_i = (int32_t)layer.k->type;
@@ -1452,7 +1442,7 @@ void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const std::
1452
  for (const auto & layer : layers) {
1453
  const uint32_t il = layer.il;
1454
 
1455
- const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
1456
 
1457
  // Write value type
1458
  const int32_t v_type_i = (int32_t)layer.v->type;
@@ -1476,7 +1466,7 @@ void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const std::
1476
  for (const auto & layer : layers) {
1477
  const uint32_t il = layer.il;
1478
 
1479
- const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
1480
 
1481
  // Write value type
1482
  const int32_t v_type_i = (int32_t)layer.v->type;
@@ -1509,12 +1499,9 @@ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell
1509
 
1510
  seq_rm(dest_seq_id, -1, -1);
1511
 
1512
- llama_sbatch sbatch;
1513
- llama_ubatch ubatch = sbatch.reserve_ubatch(cell_count, /* has_embd */ false);
1514
 
1515
- ubatch.n_tokens = cell_count;
1516
- ubatch.n_seq_tokens = cell_count;
1517
- ubatch.n_seqs = 1;
1518
 
1519
  for (uint32_t i = 0; i < cell_count; ++i) {
1520
  llama_pos pos;
@@ -1621,7 +1608,7 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell
1621
  for (const auto & layer : layers) {
1622
  const uint32_t il = layer.il;
1623
 
1624
- const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
1625
 
1626
  // Read type of key
1627
  int32_t k_type_i_ref;
@@ -1651,7 +1638,7 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell
1651
  for (const auto & layer : layers) {
1652
  const uint32_t il = layer.il;
1653
 
1654
- const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
1655
 
1656
  // Read type of value
1657
  int32_t v_type_i_ref;
@@ -1681,7 +1668,7 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell
1681
  for (const auto & layer : layers) {
1682
  const uint32_t il = layer.il;
1683
 
1684
- const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
1685
 
1686
  // Read type of value
1687
  int32_t v_type_i_ref;
@@ -1746,9 +1733,8 @@ llama_kv_cache_unified_state::llama_kv_cache_unified_state(
1746
 
1747
  llama_kv_cache_unified_state::llama_kv_cache_unified_state(
1748
  llama_kv_cache_unified * kv,
1749
- llama_sbatch sbatch,
1750
  llama_kv_cache_unified::ubatch_heads heads,
1751
- std::vector<llama_ubatch> ubatches) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv), sbatch(std::move(sbatch)), heads(std::move(heads)), ubatches(std::move(ubatches)) {
1752
  }
1753
 
1754
  llama_kv_cache_unified_state::~llama_kv_cache_unified_state() = default;
@@ -1781,12 +1767,6 @@ bool llama_kv_cache_unified_state::apply() {
1781
  return true;
1782
  }
1783
 
1784
- std::vector<int64_t> & llama_kv_cache_unified_state::out_ids() {
1785
- assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
1786
-
1787
- return sbatch.out_ids;
1788
- }
1789
-
1790
  llama_memory_status llama_kv_cache_unified_state::get_status() const {
1791
  return status;
1792
  }
 
68
  continue;
69
  }
70
 
71
+ const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
72
+ const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
73
 
74
  const char * dev_name = "CPU";
75
 
 
308
  }
309
 
310
  llama_memory_state_ptr llama_kv_cache_unified::init_batch(
311
+ llama_batch_allocr & balloc,
312
  uint32_t n_ubatch,
313
  bool embd_all) {
314
  GGML_UNUSED(embd_all);
315
 
316
  do {
317
+ balloc.split_reset();
318
 
319
  std::vector<llama_ubatch> ubatches;
320
+ while (true) {
321
+ auto ubatch = balloc.split_simple(n_ubatch);
322
+
323
+ if (ubatch.n_tokens == 0) {
324
+ break;
325
+ }
326
+
327
+ ubatches.push_back(std::move(ubatch)); // NOLINT
328
  }
329
 
330
  auto heads = prepare(ubatches);
 
333
  }
334
 
335
  return std::make_unique<llama_kv_cache_unified_state>(
336
+ this, std::move(heads), std::move(ubatches));
337
  } while (false);
338
 
339
  return std::make_unique<llama_kv_cache_unified_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
 
650
  }
651
 
652
  void llama_kv_cache_unified::apply_ubatch(uint32_t head_cur, const llama_ubatch & ubatch) {
 
 
 
 
 
 
653
  // keep track of the max sequence position that we would overwrite with this ubatch
654
  // for non-SWA cache, this would be always empty
655
  llama_seq_id seq_pos_max_rm[LLAMA_MAX_SEQ];
 
657
  seq_pos_max_rm[s] = -1;
658
  }
659
 
660
+ for (uint32_t i = 0; i < ubatch.n_tokens; ++i) {
661
+ if (!cells.is_empty(head_cur + i)) {
662
+ assert(cells.seq_count(head_cur + i) == 1);
 
 
 
663
 
664
+ const llama_seq_id seq_id = cells.seq_get(head_cur + i);
665
+ const llama_pos pos = cells.pos_get(head_cur + i);
666
 
667
+ seq_pos_max_rm[seq_id] = std::max(seq_pos_max_rm[seq_id], pos);
668
 
669
+ cells.rm(head_cur + i);
670
+ }
671
 
672
+ cells.pos_set(head_cur + i, ubatch.pos[i]);
673
 
674
+ for (int32_t s = 0; s < ubatch.n_seq_id[i]; s++) {
675
+ cells.seq_add(head_cur + i, ubatch.seq_id[i][s]);
 
 
676
  }
677
  }
678
 
 
691
  seq_rm(s, cells.seq_pos_min(s), seq_pos_max_rm[s] + 1);
692
  }
693
  }
694
+
695
  // move the head to the end of the slot
696
  head = head_cur + ubatch.n_tokens;
697
  }
 
788
  }
789
 
790
  void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const {
791
+ const uint32_t n_tokens = ubatch->n_tokens;
 
 
792
 
793
  GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
794
  float * data = (float *) dst->data;
 
808
  // xxxxx-----
809
  // To visualize the mask, see https://github.com/ggml-org/llama.cpp/pull/12615
810
  for (uint32_t h = 0; h < 1; ++h) {
811
+ for (uint32_t i = 0; i < n_tokens; ++i) {
812
+ const llama_seq_id seq_id = ubatch->seq_id[i][0];
813
 
814
+ const llama_pos p1 = ubatch->pos[i];
 
815
 
816
+ for (uint32_t j = 0; j < n_kv; ++j) {
817
+ float f = 0.0f;
818
 
819
+ bool masked = false;
 
820
 
821
+ if (cells.is_empty(j)) {
822
+ masked = true;
823
+ } else {
824
+ const llama_pos p0 = cells.pos_get(j);
 
 
 
 
 
825
 
826
+ // mask the token if not the same sequence
827
+ masked = masked || (!cells.seq_has(j, seq_id));
828
 
829
+ // mask future tokens
830
+ masked = masked || (causal_attn && p0 > p1);
831
 
832
+ // apply SWA if any
833
+ masked = masked || (is_masked_swa(p0, p1));
 
 
834
 
835
+ if (!masked && hparams.use_alibi) {
836
+ f = -std::abs(p0 - p1);
837
  }
838
+ }
839
 
840
+ if (masked) {
841
+ f = -INFINITY;
842
  }
843
+
844
+ data[h*(n_kv*n_tokens) + i*n_kv + j] = f;
845
  }
846
  }
847
 
848
  // mask padded tokens
849
  if (data) {
850
+ for (uint32_t i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
851
+ for (uint32_t j = 0; j < n_kv; ++j) {
852
+ data[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
853
  }
854
  }
855
  }
 
877
  const int32_t n_kv = dst->ne[0];
878
 
879
  for (int h = 0; h < 1; ++h) {
880
+ for (int i = 0; i < n_tokens; ++i) {
881
+ for (int j = 0; j < n_kv; ++j) {
882
  // the position when the cell is empty is irrelevant - it will be masked out later in the attention
883
+ const llama_pos p0 = cells.is_empty(j) ? -1 : cells.pos_get(j);
884
 
885
+ data[h*(n_kv*n_tokens) + i*n_kv + j] = llama_relative_position_bucket(p0, ubatch->pos[i], hparams.n_rel_attn_bkts, false);
886
  }
887
  }
888
  }
 
1420
  for (const auto & layer : layers) {
1421
  const uint32_t il = layer.il;
1422
 
1423
+ const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
1424
 
1425
  // Write key type
1426
  const int32_t k_type_i = (int32_t)layer.k->type;
 
1442
  for (const auto & layer : layers) {
1443
  const uint32_t il = layer.il;
1444
 
1445
+ const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
1446
 
1447
  // Write value type
1448
  const int32_t v_type_i = (int32_t)layer.v->type;
 
1466
  for (const auto & layer : layers) {
1467
  const uint32_t il = layer.il;
1468
 
1469
+ const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
1470
 
1471
  // Write value type
1472
  const int32_t v_type_i = (int32_t)layer.v->type;
 
1499
 
1500
  seq_rm(dest_seq_id, -1, -1);
1501
 
1502
+ llama_batch_allocr balloc(hparams.n_pos_per_embd());
 
1503
 
1504
+ llama_ubatch ubatch = balloc.ubatch_reserve(cell_count, 1);
 
 
1505
 
1506
  for (uint32_t i = 0; i < cell_count; ++i) {
1507
  llama_pos pos;
 
1608
  for (const auto & layer : layers) {
1609
  const uint32_t il = layer.il;
1610
 
1611
+ const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
1612
 
1613
  // Read type of key
1614
  int32_t k_type_i_ref;
 
1638
  for (const auto & layer : layers) {
1639
  const uint32_t il = layer.il;
1640
 
1641
+ const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
1642
 
1643
  // Read type of value
1644
  int32_t v_type_i_ref;
 
1668
  for (const auto & layer : layers) {
1669
  const uint32_t il = layer.il;
1670
 
1671
+ const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
1672
 
1673
  // Read type of value
1674
  int32_t v_type_i_ref;
 
1733
 
1734
  llama_kv_cache_unified_state::llama_kv_cache_unified_state(
1735
  llama_kv_cache_unified * kv,
 
1736
  llama_kv_cache_unified::ubatch_heads heads,
1737
+ std::vector<llama_ubatch> ubatches) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv), heads(std::move(heads)), ubatches(std::move(ubatches)) {
1738
  }
1739
 
1740
  llama_kv_cache_unified_state::~llama_kv_cache_unified_state() = default;
 
1767
  return true;
1768
  }
1769
 
 
 
 
 
 
 
1770
  llama_memory_status llama_kv_cache_unified_state::get_status() const {
1771
  return status;
1772
  }
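The rewritten set_input_kq_mask() walks the ubatch token-major (i over n_tokens, j over n_kv) instead of the old seqs/seq-tokens nesting, but the per-entry rule is unchanged: mask empty cells, cells belonging to another sequence, future cells under causal attention and cells outside the SWA window; otherwise write 0 or the ALiBi bias -|p0 - p1|. A simplified, self-contained sketch of that rule (SWA check omitted, a single sequence id per cell assumed):

```cpp
#include <cmath>
#include <cstdio>
#include <limits>

struct mock_cell { bool empty; int seq_id; int pos; };

static float mask_entry(const mock_cell & cell, int seq_id, int p1, bool causal_attn, bool use_alibi) {
    bool masked = cell.empty;

    if (!masked) {
        masked = masked || (cell.seq_id != seq_id);        // different sequence
        masked = masked || (causal_attn && cell.pos > p1); // future token
    }

    if (masked) {
        return -std::numeric_limits<float>::infinity();
    }

    return use_alibi ? -std::abs((float)(cell.pos - p1)) : 0.0f;
}

int main() {
    const mock_cell past   = { false, 0, 3 };
    const mock_cell future = { false, 0, 9 };

    std::printf("past cell:   %f\n", mask_entry(past,   0, /*p1 =*/ 5, true, false)); // 0.000000
    std::printf("future cell: %f\n", mask_entry(future, 0, /*p1 =*/ 5, true, false)); // -inf
    return 0;
}
```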
examples/talk-llama/llama-kv-cache-unified.h CHANGED
@@ -57,7 +57,7 @@ public:
57
  //
58
 
59
  llama_memory_state_ptr init_batch(
60
- const llama_batch & batch,
61
  uint32_t n_ubatch,
62
  bool embd_all) override;
63
 
@@ -231,7 +231,6 @@ public:
231
  // used to create a decode state from a batch
232
  llama_kv_cache_unified_state(
233
  llama_kv_cache_unified * kv,
234
- llama_sbatch sbatch,
235
  ubatch_heads heads,
236
  std::vector<llama_ubatch> ubatches);
237
 
@@ -244,8 +243,6 @@ public:
244
  bool next() override;
245
  bool apply() override;
246
 
247
- std::vector<int64_t> & out_ids() override;
248
-
249
  llama_memory_status get_status() const override;
250
  const llama_ubatch & get_ubatch() const override;
251
 
@@ -286,8 +283,6 @@ private:
286
  // batch processing state
287
  //
288
 
289
- llama_sbatch sbatch;
290
-
291
  // the index of the next ubatch to process
292
  size_t i_next = 0;
293
 
 
57
  //
58
 
59
  llama_memory_state_ptr init_batch(
60
+ llama_batch_allocr & balloc,
61
  uint32_t n_ubatch,
62
  bool embd_all) override;
63
 
 
231
  // used to create a decode state from a batch
232
  llama_kv_cache_unified_state(
233
  llama_kv_cache_unified * kv,
 
234
  ubatch_heads heads,
235
  std::vector<llama_ubatch> ubatches);
236
 
 
243
  bool next() override;
244
  bool apply() override;
245
 
 
 
246
  llama_memory_status get_status() const override;
247
  const llama_ubatch & get_ubatch() const override;
248
 
 
283
  // batch processing state
284
  //
285
 
 
 
286
  // the index of the next ubatch to process
287
  size_t i_next = 0;
288
 
examples/talk-llama/llama-kv-cells.h CHANGED
@@ -384,10 +384,10 @@ private:
384
  //
385
  std::vector<llama_pos> shift;
386
 
387
- using bits_t = std::bitset<LLAMA_MAX_SEQ>;
388
 
389
  // the bitset seq[i] tells us which sequences are currently occupying the i-th cell
390
- std::vector<bits_t> seq;
391
 
392
  // the set seq_pos[s] tells us which positions are currently present for sequence s
393
  // this way seq_pos[s].begin() and seq_pos[s].rbegin() give us the min/max positions currently in the cache
 
384
  //
385
  std::vector<llama_pos> shift;
386
 
387
+ using seq_set_t = std::bitset<LLAMA_MAX_SEQ>;
388
 
389
  // the bitset seq[i] tells us which sequences are currently occupying the i-th cell
390
+ std::vector<seq_set_t> seq;
391
 
392
  // the set seq_pos[s] tells us which positions are currently present for sequence s
393
  // this way seq_pos[s].begin() and seq_pos[s].rbegin() give us the min/max positions currently in the cache
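The rename from bits_t to seq_set_t is cosmetic, but the idea behind the member is worth spelling out: each cell carries a std::bitset with one bit per sequence id, set while that sequence occupies the cell, which makes membership tests and per-cell sequence counts cheap. A minimal, self-contained sketch (MAX_SEQ below is a made-up stand-in for LLAMA_MAX_SEQ):

```cpp
#include <bitset>
#include <cstdio>
#include <vector>

int main() {
    constexpr int MAX_SEQ = 64;
    using seq_set_t = std::bitset<MAX_SEQ>;

    std::vector<seq_set_t> seq(4); // one bitset per cell

    seq[0].set(2); // cell 0 is used by sequence 2
    seq[0].set(5); // ... and also by sequence 5 (e.g. a shared prompt)

    std::printf("cell 0 used by seq 2: %d\n",  (int) seq[0].test(2)); // 1
    std::printf("cell 0 seq count:     %zu\n", seq[0].count());       // 2
    std::printf("cell 1 empty:         %d\n",  (int) seq[1].none());  // 1
    return 0;
}
```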
examples/talk-llama/llama-memory-hybrid.cpp ADDED
@@ -0,0 +1,246 @@
1
+ #include "llama-memory-hybrid.h"
2
+
3
+ #include "llama-impl.h"
4
+ #include "llama-model.h"
5
+ #include "llama-context.h"
6
+
7
+ //
8
+ // llama_memory_hybrid
9
+ //
10
+
11
+ llama_memory_hybrid::llama_memory_hybrid(
12
+ const llama_model & model,
13
+ /* attn */
14
+ ggml_type type_k,
15
+ ggml_type type_v,
16
+ bool v_trans,
17
+ uint32_t kv_size,
18
+ uint32_t n_pad,
19
+ uint32_t n_swa,
20
+ llama_swa_type swa_type,
21
+ /* recurrent */
22
+ ggml_type type_r,
23
+ ggml_type type_s,
24
+ uint32_t rs_size,
25
+ /* common */
26
+ uint32_t n_seq_max,
27
+ bool offload,
28
+ /* layer filters */
29
+ layer_filter_cb && filter_attn,
30
+ layer_filter_cb && filter_recr) :
31
+ hparams(model.hparams),
32
+ mem_attn(new llama_kv_cache_unified(
33
+ model,
34
+ filter_attn == nullptr ?
35
+ [&](int32_t il) { return !hparams.is_recurrent(il); }
36
+ : filter_attn,
37
+ type_k,
38
+ type_v,
39
+ v_trans,
40
+ offload,
41
+ kv_size,
42
+ n_seq_max,
43
+ n_pad,
44
+ n_swa,
45
+ swa_type
46
+ )),
47
+ mem_recr(new llama_memory_recurrent(
48
+ model,
49
+ filter_recr == nullptr ?
50
+ [&](int32_t il) { return hparams.is_recurrent(il); }
51
+ : filter_recr,
52
+ type_r,
53
+ type_s,
54
+ offload,
55
+ rs_size,
56
+ n_seq_max
57
+ )) {}
58
+
59
+ llama_memory_state_ptr llama_memory_hybrid::init_batch(llama_batch_allocr & balloc, uint32_t n_ubatch, bool embd_all) {
60
+ do {
61
+ balloc.split_reset();
62
+
63
+ // follow the recurrent pattern for creating the ubatch splits
64
+ std::vector<llama_ubatch> ubatches;
65
+
66
+ while (true) {
67
+ llama_ubatch ubatch;
68
+
69
+ if (embd_all) {
70
+ // if all tokens are output, split by sequence
71
+ ubatch = balloc.split_seq(n_ubatch);
72
+ } else {
73
+ ubatch = balloc.split_equal(n_ubatch);
74
+ }
75
+
76
+ if (ubatch.n_tokens == 0) {
77
+ break;
78
+ }
79
+
80
+ ubatches.push_back(std::move(ubatch)); // NOLINT
81
+ }
82
+
83
+ // prepare the recurrent batches first
84
+ if (!mem_recr->prepare(ubatches)) {
85
+ // TODO: will the recurrent cache be in an undefined state at this point?
86
+ LLAMA_LOG_ERROR("%s: failed to prepare recurrent ubatches\n", __func__);
87
+ return std::make_unique<llama_memory_hybrid_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
88
+ }
89
+
90
+ // prepare the attention cache
91
+ auto heads_attn = mem_attn->prepare(ubatches);
92
+ if (heads_attn.empty()) {
93
+ LLAMA_LOG_ERROR("%s: failed to prepare attention ubatches\n", __func__);
94
+ return std::make_unique<llama_memory_hybrid_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
95
+ }
96
+
97
+ return std::make_unique<llama_memory_hybrid_state>(
98
+ this, std::move(heads_attn), std::move(ubatches));
99
+ } while(false);
100
+
101
+ return std::make_unique<llama_memory_hybrid_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
102
+ }
103
+
104
+ llama_memory_state_ptr llama_memory_hybrid::init_full() {
105
+ return std::make_unique<llama_memory_hybrid_state>(this);
106
+ }
107
+
108
+ llama_memory_state_ptr llama_memory_hybrid::init_update(llama_context * lctx, bool optimize) {
109
+ return std::make_unique<llama_memory_hybrid_state>(this, lctx, optimize);
110
+ }
111
+
112
+ bool llama_memory_hybrid::get_can_shift() const {
113
+ // Shifting is trivially supported for recurrent
114
+ return mem_attn->get_can_shift();
115
+ }
116
+
117
+ void llama_memory_hybrid::clear(bool data) {
118
+ mem_attn->clear(data);
119
+ mem_recr->clear(data);
120
+ }
121
+
122
+ bool llama_memory_hybrid::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
123
+ // Try removing from the recurrent cache first since it may fail. If it does
124
+ // fail, the cache will not have been mutated.
125
+ if (!mem_recr->seq_rm(seq_id, p0, p1)) {
126
+ return false;
127
+ }
128
+ return mem_attn->seq_rm(seq_id, p0, p1);
129
+ }
130
+
131
+ void llama_memory_hybrid::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
132
+ mem_attn->seq_cp(seq_id_src, seq_id_dst, p0, p1);
133
+ mem_recr->seq_cp(seq_id_src, seq_id_dst, p0, p1);
134
+ }
135
+
136
+ void llama_memory_hybrid::seq_keep(llama_seq_id seq_id) {
137
+ mem_attn->seq_keep(seq_id);
138
+ mem_recr->seq_keep(seq_id);
139
+ }
140
+
141
+ void llama_memory_hybrid::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) {
142
+ mem_attn->seq_add(seq_id, p0, p1, shift);
143
+ mem_recr->seq_add(seq_id, p0, p1, shift);
144
+ }
145
+
146
+ void llama_memory_hybrid::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
147
+ mem_attn->seq_div(seq_id, p0, p1, d);
148
+ mem_recr->seq_div(seq_id, p0, p1, d);
149
+ }
150
+
151
+ llama_pos llama_memory_hybrid::seq_pos_min(llama_seq_id seq_id) const {
152
+ // the min of the total cache is the max of the two caches' min values
153
+ return std::max(mem_attn->seq_pos_min(seq_id), mem_recr->seq_pos_min(seq_id));
154
+ }
155
+
156
+ llama_pos llama_memory_hybrid::seq_pos_max(llama_seq_id seq_id) const {
157
+ // the max of the total cache is the min of the two caches' max values
158
+ return std::min(mem_attn->seq_pos_max(seq_id), mem_recr->seq_pos_max(seq_id));
159
+ }
160
+
161
+ void llama_memory_hybrid::state_write(llama_io_write_i & io, llama_seq_id seq_id) const {
162
+ mem_attn->state_write(io, seq_id);
163
+ mem_recr->state_write(io, seq_id);
164
+ }
165
+
166
+ void llama_memory_hybrid::state_read(llama_io_read_i & io, llama_seq_id seq_id) {
167
+ mem_attn->state_read(io, seq_id);
168
+ mem_recr->state_read(io, seq_id);
169
+ }
170
+
171
+ llama_kv_cache_unified * llama_memory_hybrid::get_mem_attn() const {
172
+ return mem_attn.get();
173
+ }
174
+
175
+ llama_memory_recurrent * llama_memory_hybrid::get_mem_recr() const {
176
+ return mem_recr.get();
177
+ }
178
+
179
+ llama_memory_hybrid_state::llama_memory_hybrid_state(llama_memory_status status) : status(status) {}
180
+
181
+ llama_memory_hybrid_state::llama_memory_hybrid_state(llama_memory_hybrid * mem) :
182
+ state_attn(mem->get_mem_attn()->init_full()),
183
+ state_recr(mem->get_mem_recr()->init_full()),
184
+ status(llama_memory_status_combine(state_attn->get_status(), state_recr->get_status())) {
185
+ }
186
+
187
+ llama_memory_hybrid_state::llama_memory_hybrid_state(
188
+ llama_memory_hybrid * mem,
189
+ llama_context * lctx,
190
+ bool optimize) :
191
+ state_attn(mem->get_mem_attn()->init_update(lctx, optimize)),
192
+ state_recr(mem->get_mem_recr()->init_update(lctx, optimize)),
193
+ status(llama_memory_status_combine(state_attn->get_status(), state_recr->get_status())) {
194
+ }
195
+
196
+ llama_memory_hybrid_state::llama_memory_hybrid_state(
197
+ llama_memory_hybrid * mem,
198
+ std::vector<uint32_t> heads_attn,
199
+ std::vector<llama_ubatch> ubatches) :
200
+ ubatches(std::move(ubatches)),
201
+ // note: here we copy the ubatches. not sure if this is ideal
202
+ state_attn(new llama_kv_cache_unified_state(mem->get_mem_attn(), std::move(heads_attn), this->ubatches)),
203
+ state_recr(new llama_memory_recurrent_state(mem->get_mem_recr(), this->ubatches)),
204
+ status(llama_memory_status_combine(state_attn->get_status(), state_recr->get_status())) {
205
+ }
206
+
207
+ bool llama_memory_hybrid_state::next() {
208
+ assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
209
+
210
+ state_attn->next();
211
+ state_recr->next();
212
+
213
+ if (++i_next >= ubatches.size()) {
214
+ return false;
215
+ }
216
+
217
+ return true;
218
+ }
219
+
220
+ bool llama_memory_hybrid_state::apply() {
221
+ assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
222
+
223
+ bool res = true;
224
+
225
+ res = res & state_attn->apply();
226
+ res = res & state_recr->apply();
227
+
228
+ return res;
229
+ }
230
+
231
+ llama_memory_status llama_memory_hybrid_state::get_status() const {
232
+ return status;
233
+ }
234
+
235
+ const llama_ubatch & llama_memory_hybrid_state::get_ubatch() const {
236
+ assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
237
+ return ubatches[i_next];
238
+ }
239
+
240
+ const llama_kv_cache_unified_state * llama_memory_hybrid_state::get_state_attn() const {
241
+ return static_cast<const llama_kv_cache_unified_state *>(state_attn.get());
242
+ }
243
+
244
+ const llama_memory_recurrent_state * llama_memory_hybrid_state::get_state_recr() const {
245
+ return static_cast<const llama_memory_recurrent_state *>(state_recr.get());
246
+ }
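The two comments in seq_pos_min()/seq_pos_max() are terse, so to spell out the reasoning: a position is only usable by the hybrid memory if both sub-caches still hold it, i.e. the valid range per sequence is the intersection of the two ranges, hence the max of the mins and the min of the maxes. A tiny self-contained example with made-up positions:

```cpp
#include <algorithm>
#include <cstdio>

int main() {
    // hypothetical: the attention cache holds positions [5, 20] for a sequence,
    // the recurrent cache holds [0, 18]
    const int attn_min = 5,  attn_max = 20;
    const int recr_min = 0,  recr_max = 18;

    const int hybrid_min = std::max(attn_min, recr_min); // 5
    const int hybrid_max = std::min(attn_max, recr_max); // 18

    std::printf("hybrid range for the sequence: [%d, %d]\n", hybrid_min, hybrid_max);
    return 0;
}
```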
examples/talk-llama/llama-memory-hybrid.h ADDED
@@ -0,0 +1,138 @@
1
+ #pragma once
2
+
3
+ #include "llama-batch.h"
4
+ #include "llama-graph.h"
5
+ #include "llama-kv-cache-unified.h"
6
+ #include "llama-memory.h"
7
+ #include "llama-memory-recurrent.h"
8
+
9
+ #include <memory>
10
+ #include <vector>
11
+
12
+ //
13
+ // llama_memory_hybrid
14
+ //
15
+
16
+ // utilizes instances of llama_memory_recurrent and llama_kv_cache_unified to
17
+ // support models where each layer may be either attention-based or recurrent
18
+
19
+ class llama_memory_hybrid : public llama_memory_i {
20
+ public:
21
+
22
+ // this callback is used to filter out layers that should not be included in the cache
23
+ using layer_filter_cb = std::function<bool(int32_t il)>;
24
+
25
+ llama_memory_hybrid(
26
+ const llama_model & model,
27
+ /* attn */
28
+ ggml_type type_k,
29
+ ggml_type type_v,
30
+ bool v_trans,
31
+ uint32_t kv_size,
32
+ uint32_t n_pad,
33
+ uint32_t n_swa,
34
+ llama_swa_type swa_type,
35
+ /* recurrent */
36
+ ggml_type type_r,
37
+ ggml_type type_s,
38
+ uint32_t rs_size,
39
+ /* common */
40
+ uint32_t n_seq_max,
41
+ bool offload,
42
+ /* layer filters */
43
+ layer_filter_cb && filter_attn = nullptr,
44
+ layer_filter_cb && filter_recr = nullptr);
45
+
46
+ ~llama_memory_hybrid() = default;
47
+
48
+ //
49
+ // llama_memory_i
50
+ //
51
+
52
+ llama_memory_state_ptr init_batch(
53
+ llama_batch_allocr & balloc,
54
+ uint32_t n_ubatch,
55
+ bool embd_all) override;
56
+
57
+ llama_memory_state_ptr init_full() override;
58
+
59
+ llama_memory_state_ptr init_update(llama_context * lctx, bool optimize) override;
60
+
61
+ bool get_can_shift() const override;
62
+
63
+ void clear(bool data) override;
64
+
65
+ bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override;
66
+ void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
67
+ void seq_keep(llama_seq_id seq_id) override;
68
+ void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) override;
69
+ void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override;
70
+
71
+ llama_pos seq_pos_min(llama_seq_id seq_id) const override;
72
+ llama_pos seq_pos_max(llama_seq_id seq_id) const override;
73
+
74
+ // state write/load
75
+
76
+ void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override;
77
+ void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) override;
78
+
79
+ //
80
+ // llama_memory_hybrid specific API
81
+ //
82
+
83
+ llama_kv_cache_unified * get_mem_attn() const;
84
+ llama_memory_recurrent * get_mem_recr() const;
85
+
86
+ private:
87
+ const llama_hparams & hparams;
88
+
89
+ const std::unique_ptr<llama_kv_cache_unified> mem_attn;
90
+ const std::unique_ptr<llama_memory_recurrent> mem_recr;
91
+ };
92
+
93
+ class llama_memory_hybrid_state : public llama_memory_state_i {
94
+ public:
95
+ // init failure
96
+ explicit llama_memory_hybrid_state(llama_memory_status status);
97
+
98
+ // init full
99
+ explicit llama_memory_hybrid_state(llama_memory_hybrid * mem);
100
+
101
+ // init update
102
+ explicit llama_memory_hybrid_state(
103
+ llama_memory_hybrid * mem,
104
+ llama_context * lctx,
105
+ bool optimize);
106
+
107
+ // init success
108
+ llama_memory_hybrid_state(
109
+ llama_memory_hybrid * mem,
110
+ std::vector<uint32_t> heads_attn,
111
+ std::vector<llama_ubatch> ubatches);
112
+
113
+ ~llama_memory_hybrid_state() = default;
114
+
115
+ bool next() override;
116
+ bool apply() override;
117
+
118
+ llama_memory_status get_status() const override;
119
+ const llama_ubatch & get_ubatch() const override;
120
+
121
+ //
122
+ // llama_memory_hybrid_state
123
+ //
124
+
125
+ const llama_kv_cache_unified_state * get_state_attn() const;
126
+ const llama_memory_recurrent_state * get_state_recr() const;
127
+
128
+ private:
129
+ // the index of the next ubatch to process
130
+ size_t i_next = 0;
131
+
132
+ std::vector<llama_ubatch> ubatches;
133
+
134
+ const llama_memory_state_ptr state_attn;
135
+ const llama_memory_state_ptr state_recr;
136
+
137
+ const llama_memory_status status;
138
+ };
examples/talk-llama/{llama-kv-cache-recurrent.cpp → llama-memory-recurrent.cpp} RENAMED
@@ -1,4 +1,4 @@
1
- #include "llama-kv-cache-recurrent.h"
2
 
3
  #include "llama-impl.h"
4
  #include "llama-io.h"
@@ -12,27 +12,28 @@
12
  #include <stdexcept>
13
 
14
  //
15
- // llama_kv_cache_recurrent
16
  //
17
 
18
- llama_kv_cache_recurrent::llama_kv_cache_recurrent(
19
- const llama_model & model,
20
- ggml_type type_k,
21
- ggml_type type_v,
22
- bool offload,
23
- uint32_t kv_size,
24
- uint32_t n_seq_max) : hparams(model.hparams), n_seq_max(n_seq_max) {
 
25
  const int32_t n_layer = hparams.n_layer;
26
 
27
- LLAMA_LOG_INFO("%s: kv_size = %u, n_seq_max = %u, type_k = '%s', type_v = '%s', n_layer = %d\n",
28
- __func__, kv_size, n_seq_max, ggml_type_name(type_k), ggml_type_name(type_v), n_layer);
29
 
30
  head = 0;
31
- size = kv_size;
32
  used = 0;
33
 
34
  cells.clear();
35
- cells.resize(kv_size);
36
 
37
  // create a context for each buffer type
38
  std::map<ggml_backend_buffer_type_t, ggml_context *> ctx_map;
@@ -59,12 +60,14 @@ llama_kv_cache_recurrent::llama_kv_cache_recurrent(
59
  return it->second;
60
  };
61
 
62
- k_l.reserve(n_layer);
63
- v_l.reserve(n_layer);
64
 
65
  for (int i = 0; i < n_layer; i++) {
66
- const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s();
67
- const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s();
 
 
68
 
69
  const char * dev_name = "CPU";
70
 
@@ -84,12 +87,12 @@ llama_kv_cache_recurrent::llama_kv_cache_recurrent(
84
  throw std::runtime_error("failed to create ggml context for kv cache");
85
  }
86
 
87
- ggml_tensor * k = ggml_new_tensor_1d(ctx, type_k, n_embd_k_gqa*kv_size);
88
- ggml_tensor * v = ggml_new_tensor_1d(ctx, type_v, n_embd_v_gqa*kv_size);
89
- ggml_format_name(k, "cache_k_l%d", i);
90
- ggml_format_name(v, "cache_v_l%d", i);
91
- k_l.push_back(k);
92
- v_l.push_back(v);
93
  }
94
 
95
  // allocate tensors and initialize the buffers to avoid NaNs in the padding
@@ -107,17 +110,17 @@ llama_kv_cache_recurrent::llama_kv_cache_recurrent(
107
  }
108
 
109
  {
110
- const size_t memory_size_k = size_k_bytes();
111
- const size_t memory_size_v = size_v_bytes();
112
 
113
- LLAMA_LOG_INFO("%s: KV self size = %7.2f MiB, K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__,
114
- (float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f),
115
- ggml_type_name(type_k), (float)memory_size_k / (1024.0f * 1024.0f),
116
- ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f));
117
  }
118
  }
119
 
120
- void llama_kv_cache_recurrent::clear(bool data) {
121
  for (int32_t i = 0; i < (int32_t) size; ++i) {
122
  cells[i].pos = -1;
123
  cells[i].seq_id.clear();
@@ -135,7 +138,7 @@ void llama_kv_cache_recurrent::clear(bool data) {
135
  }
136
  }
137
 
138
- bool llama_kv_cache_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
139
  uint32_t new_head = size;
140
 
141
  if (p0 < 0) {
@@ -154,7 +157,7 @@ bool llama_kv_cache_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_p
154
  if (0 <= seq_id) {
155
  int32_t & tail_id = cells[seq_id].tail;
156
  if (tail_id >= 0) {
157
- const kv_cell & cell = cells[tail_id];
158
  // partial intersection is invalid
159
  if ((0 < p0 && p0 <= cell.pos) || (0 < p1 && p1 <= cell.pos)) {
160
  return false;
@@ -202,7 +205,7 @@ bool llama_kv_cache_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_p
202
  return true;
203
  }
204
 
205
- void llama_kv_cache_recurrent::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
206
  if (seq_id_src == seq_id_dst) {
207
  return;
208
  }
@@ -216,11 +219,11 @@ void llama_kv_cache_recurrent::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_
216
  }
217
 
218
  if ((uint32_t) seq_id_dst < size && (uint32_t) seq_id_src < size) {
219
- kv_cell & tail_src = cells[seq_id_src];
220
- kv_cell & tail_dst = cells[seq_id_dst];
221
  if (tail_dst.tail >= 0) {
222
  // clear destination seq_id if it wasn't empty
223
- kv_cell & cell_dst = cells[tail_dst.tail];
224
 
225
  cell_dst.seq_id.erase(seq_id_dst);
226
  tail_dst.tail = -1;
@@ -231,7 +234,7 @@ void llama_kv_cache_recurrent::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_
231
  }
232
  }
233
  if (tail_src.tail >= 0) {
234
- kv_cell & cell_src = cells[tail_src.tail];
235
 
236
  cell_src.seq_id.insert(seq_id_dst);
237
  tail_dst.tail = tail_src.tail;
@@ -239,7 +242,7 @@ void llama_kv_cache_recurrent::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_
239
  }
240
  }
241
 
242
- void llama_kv_cache_recurrent::seq_keep(llama_seq_id seq_id) {
243
  uint32_t new_head = size;
244
 
245
  for (uint32_t i = 0; i < size; ++i) {
@@ -271,7 +274,7 @@ void llama_kv_cache_recurrent::seq_keep(llama_seq_id seq_id) {
271
  }
272
  }
273
 
274
- void llama_kv_cache_recurrent::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) {
275
  if (shift == 0) {
276
  return;
277
  }
@@ -293,7 +296,7 @@ void llama_kv_cache_recurrent::seq_add(llama_seq_id seq_id, llama_pos p0, llama_
293
  if (0 <= seq_id && seq_id < (int64_t) size) {
294
  const int32_t tail_id = cells[seq_id].tail;
295
  if (tail_id >= 0) {
296
- kv_cell & cell = cells[tail_id];
297
  if (cell.has_seq_id(seq_id) && p0 <= cell.pos && cell.pos < p1) {
298
  cell.pos += shift;
299
  }
@@ -301,7 +304,7 @@ void llama_kv_cache_recurrent::seq_add(llama_seq_id seq_id, llama_pos p0, llama_
301
  }
302
  }
303
 
304
- void llama_kv_cache_recurrent::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
305
  if (d == 1) {
306
  return;
307
  }
@@ -323,7 +326,7 @@ void llama_kv_cache_recurrent::seq_div(llama_seq_id seq_id, llama_pos p0, llama_
323
  if (0 <= seq_id && seq_id < (int64_t) size) {
324
  const int32_t tail_id = cells[seq_id].tail;
325
  if (tail_id >= 0) {
326
- kv_cell & cell = cells[tail_id];
327
  if (cell.has_seq_id(seq_id) && p0 <= cell.pos && cell.pos < p1) {
328
  cell.pos /= d;
329
  }
@@ -331,7 +334,7 @@ void llama_kv_cache_recurrent::seq_div(llama_seq_id seq_id, llama_pos p0, llama_
331
  }
332
  }
333
 
334
- llama_pos llama_kv_cache_recurrent::seq_pos_min(llama_seq_id seq_id) const {
335
  llama_pos result = std::numeric_limits<llama_pos>::max();
336
 
337
  for (uint32_t i = 0; i < size; ++i) {
@@ -347,7 +350,7 @@ llama_pos llama_kv_cache_recurrent::seq_pos_min(llama_seq_id seq_id) const {
347
  return result;
348
  }
349
 
350
- llama_pos llama_kv_cache_recurrent::seq_pos_max(llama_seq_id seq_id) const {
351
  llama_pos result = -1;
352
 
353
  for (uint32_t i = 0; i < size; ++i) {
@@ -359,43 +362,45 @@ llama_pos llama_kv_cache_recurrent::seq_pos_max(llama_seq_id seq_id) const {
359
  return result;
360
  }
361
 
362
- llama_memory_state_ptr llama_kv_cache_recurrent::init_batch(const llama_batch & batch, uint32_t n_ubatch, bool embd_all) {
363
- auto sbatch = llama_sbatch(batch, hparams.n_embd, false);
364
-
365
  std::vector<llama_ubatch> ubatches;
366
 
367
- while (sbatch.n_tokens > 0) {
368
  llama_ubatch ubatch;
369
 
370
  if (embd_all) {
371
  // if all tokens are output, split by sequence
372
- ubatch = sbatch.split_seq(n_ubatch);
373
  } else {
374
- ubatch = sbatch.split_equal(n_ubatch);
375
  }
376
 
377
- ubatches.push_back(ubatch);
 
 
 
 
378
  }
379
 
380
  if (!prepare(ubatches)) {
381
- return std::make_unique<llama_kv_cache_recurrent_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
382
  }
383
 
384
- return std::make_unique<llama_kv_cache_recurrent_state>(LLAMA_MEMORY_STATUS_SUCCESS, this, std::move(sbatch), std::move(ubatches));
385
  }
386
 
387
- llama_memory_state_ptr llama_kv_cache_recurrent::init_full() {
388
- return std::make_unique<llama_kv_cache_recurrent_state>(LLAMA_MEMORY_STATUS_SUCCESS, this);
389
  }
390
 
391
- llama_memory_state_ptr llama_kv_cache_recurrent::init_update(llama_context * lctx, bool optimize) {
392
  GGML_UNUSED(lctx);
393
  GGML_UNUSED(optimize);
394
 
395
- return std::make_unique<llama_kv_cache_recurrent_state>(LLAMA_MEMORY_STATUS_NO_UPDATE);
396
  }
397
 
398
- bool llama_kv_cache_recurrent::prepare(const std::vector<llama_ubatch> & ubatches) {
399
  // simply remember the full state because it is very small for this type of cache
400
  // TODO: optimize
401
  auto org_cells = cells;
@@ -419,10 +424,9 @@ bool llama_kv_cache_recurrent::prepare(const std::vector<llama_ubatch> & ubatche
419
  return success;
420
  }
421
 
422
- bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) {
423
- const uint32_t n_seqs = ubatch.n_seqs;
424
-
425
  const uint32_t n_seq_tokens = ubatch.n_seq_tokens;
 
426
 
427
  // if we have enough unused cells before the current head ->
428
  // better to start searching from the beginning of the cache, hoping to fill it
@@ -442,9 +446,11 @@ bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) {
442
 
443
  // everything should fit if all seq_ids are smaller than the max
444
  for (uint32_t s = 0; s < n_seqs; ++s) {
445
- const uint32_t n_seq_id = ubatch.n_seq_id[s];
 
 
446
  for (uint32_t j = 0; j < n_seq_id; ++j) {
447
- const llama_seq_id seq_id = ubatch.seq_id[s][j];
448
 
449
  if (seq_id < 0 || (uint32_t) seq_id >= size) {
450
  // too big seq_id
@@ -453,9 +459,9 @@ bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) {
453
  return false;
454
  }
455
  if (j > 0) {
456
- kv_cell & seq = cells[seq_id];
457
  if (seq.tail >= 0) {
458
- kv_cell & cell = cells[seq.tail];
459
  // clear cells from seq_ids that become shared
460
  // (should not normally happen, but let's handle it anyway)
461
  cell.seq_id.erase(seq_id);
@@ -475,7 +481,7 @@ bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) {
475
  std::vector<int32_t> tails_verif;
476
  tails_verif.assign(size, -1);
477
  for (uint32_t i = 0; i < size; ++i) {
478
- kv_cell & cell = cells[i];
479
  for (llama_seq_id seq_id : cell.seq_id) {
480
  if (tails_verif[seq_id] != -1) {
481
  LLAMA_LOG_ERROR("%s: duplicate tail for seq_id %d in cell %d and %d\n", __func__, seq_id, i, tails_verif[seq_id]);
@@ -496,28 +502,29 @@ bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) {
496
 
497
  for (uint32_t i = 0; i < size; ++i) {
498
  if (next_empty_cell >= size) { next_empty_cell -= size; }
499
- kv_cell & cell = cells[next_empty_cell];
500
  if (cell.is_empty()) { break; }
501
  next_empty_cell += 1;
502
  }
503
 
504
  // find usable cell range
505
  for (uint32_t s = 0; s < n_seqs; ++s) {
506
- const llama_seq_id seq_id = ubatch.seq_id[s][0];
507
- kv_cell & seq_meta = cells[seq_id];
 
508
  bool has_cell = false;
509
  if (seq_meta.tail >= 0) {
510
- kv_cell & cell = cells[seq_meta.tail];
511
  GGML_ASSERT(cell.has_seq_id(seq_id));
512
  // does this seq_id "own" the cell?
513
  if (cell.seq_id.size() == 1) { has_cell = true; }
514
  }
515
  if (!has_cell) {
516
- kv_cell & empty_cell = cells[next_empty_cell];
517
  GGML_ASSERT(empty_cell.is_empty());
518
  // copy old tail into the empty cell
519
  if (seq_meta.tail >= 0) {
520
- kv_cell & orig_cell = cells[seq_meta.tail];
521
  empty_cell.pos = orig_cell.pos;
522
  empty_cell.src = orig_cell.src;
523
  orig_cell.seq_id.erase(seq_id);
@@ -527,10 +534,10 @@ bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) {
527
  seq_meta.tail = next_empty_cell;
528
  // find next empty cell
529
  if (s + 1 < n_seqs) {
530
- for (uint32_t i = 0; i < size; ++i) {
531
  next_empty_cell += 1;
532
  if (next_empty_cell >= size) { next_empty_cell -= size; }
533
- kv_cell & cell = cells[next_empty_cell];
534
  if (cell.is_empty()) { break; }
535
  }
536
  }
@@ -541,19 +548,20 @@ bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) {
541
 
542
  // gather and re-order
543
  for (uint32_t s = 0; s < n_seqs; ++s) {
 
544
  const int32_t dst_id = s + min;
545
- const int32_t src_id = cells[ubatch.seq_id[s][0]].tail;
546
  if (dst_id != src_id) {
547
- kv_cell & dst_cell = cells[dst_id];
548
- kv_cell & src_cell = cells[src_id];
549
 
550
  std::swap(dst_cell.pos, src_cell.pos);
551
  std::swap(dst_cell.src, src_cell.src);
552
  std::swap(dst_cell.seq_id, src_cell.seq_id);
553
 
554
  // swap tails
555
- for (uint32_t i = 0; i < size; ++i) {
556
- int32_t & tail = cells[i].tail;
557
  if (tail == src_id) {
558
  tail = dst_id;
559
  } else if (tail == dst_id) {
@@ -565,20 +573,21 @@ bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) {
565
 
566
  // update the pos of the used seqs
567
  for (uint32_t s = 0; s < n_seqs; ++s) {
568
- const llama_pos last_pos = ubatch.pos[n_seq_tokens * s + n_seq_tokens - 1];
 
569
  const int32_t cell_id = s + min;
570
- kv_cell & cell = cells[cell_id];
571
 
572
  if (cell.pos >= 0 && last_pos != cell.pos + (llama_pos) n_seq_tokens) {
573
  // What should happen when the pos backtracks or skips a value?
574
  // Clearing the state mid-batch would require special-casing which isn't done.
575
  LLAMA_LOG_WARN("%s: non-consecutive token position %d after %d for sequence %d with %u new tokens\n",
576
- __func__, last_pos, cell.pos, ubatch.seq_id[s][0], n_seq_tokens);
577
  }
578
  cell.pos = last_pos;
579
  cell.seq_id.clear();
580
- for (int32_t j = 0; j < ubatch.n_seq_id[s]; ++j) {
581
- const llama_seq_id seq_id = ubatch.seq_id[s][j];
582
  cell.seq_id.insert(seq_id);
583
  cells[seq_id].tail = cell_id;
584
  }
@@ -620,18 +629,18 @@ bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) {
620
  head = min;
621
  n = max - min + 1;
622
  used = std::count_if(cells.begin(), cells.end(),
623
- [](const kv_cell & cell){ return !cell.is_empty(); });
624
 
625
  // sanity check
626
  return n >= n_seqs;
627
  }
628
 
629
- bool llama_kv_cache_recurrent::get_can_shift() const {
630
  // shifting the pos is trivial for recurrent models
631
  return true;
632
  }
633
 
634
- size_t llama_kv_cache_recurrent::total_size() const {
635
  size_t size = 0;
636
  for (const auto & buf : bufs) {
637
  size += ggml_backend_buffer_get_size(buf.get());
@@ -640,27 +649,31 @@ size_t llama_kv_cache_recurrent::total_size() const {
640
  return size;
641
  }
642
 
643
- size_t llama_kv_cache_recurrent::size_k_bytes() const {
644
- size_t size_k_bytes = 0;
645
 
646
- for (const auto & k : k_l) {
647
- size_k_bytes += ggml_nbytes(k);
 
 
648
  }
649
 
650
- return size_k_bytes;
651
  }
652
 
653
- size_t llama_kv_cache_recurrent::size_v_bytes() const {
654
- size_t size_v_bytes = 0;
655
 
656
- for (const auto & v : v_l) {
657
- size_v_bytes += ggml_nbytes(v);
 
 
658
  }
659
 
660
- return size_v_bytes;
661
  }
662
 
663
- void llama_kv_cache_recurrent::state_write(llama_io_write_i & io, llama_seq_id seq_id) const {
664
  std::vector<std::pair<uint32_t, uint32_t>> cell_ranges; // ranges, from inclusive, to exclusive
665
  uint32_t cell_count = 0;
666
 
@@ -698,7 +711,7 @@ void llama_kv_cache_recurrent::state_write(llama_io_write_i & io, llama_seq_id s
698
  state_write_data(io, cell_ranges);
699
  }
700
 
701
- void llama_kv_cache_recurrent::state_read(llama_io_read_i & io, llama_seq_id seq_id) {
702
  uint32_t cell_count;
703
  io.read_to(&cell_count, sizeof(cell_count));
704
 
@@ -717,7 +730,7 @@ void llama_kv_cache_recurrent::state_read(llama_io_read_i & io, llama_seq_id seq
717
  }
718
  }
719
 
720
- void llama_kv_cache_recurrent::state_write_meta(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id) const {
721
  for (const auto & range : cell_ranges) {
722
  for (uint32_t i = range.first; i < range.second; ++i) {
723
  const auto & cell = cells[i];
@@ -736,98 +749,93 @@ void llama_kv_cache_recurrent::state_write_meta(llama_io_write_i & io, const std
736
  }
737
  }
738
 
739
- void llama_kv_cache_recurrent::state_write_data(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges) const {
740
- const uint32_t v_trans = 0;
741
  const uint32_t n_layer = hparams.n_layer;
742
 
743
- io.write(&v_trans, sizeof(v_trans));
744
- io.write(&n_layer, sizeof(n_layer));
745
 
746
  std::vector<uint8_t> tmp_buf;
747
 
748
  // Iterate and write all the keys first, each row is a cell
749
  // Get whole range at a time
750
  for (uint32_t il = 0; il < n_layer; ++il) {
751
- const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
752
 
753
  // Write key type
754
- const int32_t k_type_i = (int32_t)k_l[il]->type;
755
- io.write(&k_type_i, sizeof(k_type_i));
756
 
757
  // Write row size of key
758
- const uint64_t k_size_row = ggml_row_size(k_l[il]->type, n_embd_k_gqa);
759
- io.write(&k_size_row, sizeof(k_size_row));
760
 
761
  // Read each range of cells of k_size length each into tmp_buf and write out
762
  for (const auto & range : cell_ranges) {
763
  const size_t range_size = range.second - range.first;
764
- const size_t buf_size = range_size * k_size_row;
765
- io.write_tensor(k_l[il], range.first * k_size_row, buf_size);
766
  }
767
  }
768
 
769
- if (!v_trans) {
770
  for (uint32_t il = 0; il < n_layer; ++il) {
771
- const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
772
 
773
  // Write value type
774
- const int32_t v_type_i = (int32_t)v_l[il]->type;
775
- io.write(&v_type_i, sizeof(v_type_i));
776
 
777
  // Write row size of value
778
- const uint64_t v_size_row = ggml_row_size(v_l[il]->type, n_embd_v_gqa);
779
- io.write(&v_size_row, sizeof(v_size_row));
780
 
781
- // Read each range of cells of v_size length each into tmp_buf and write out
782
  for (const auto & range : cell_ranges) {
783
  const size_t range_size = range.second - range.first;
784
- const size_t buf_size = range_size * v_size_row;
785
- io.write_tensor(v_l[il], range.first * v_size_row, buf_size);
786
  }
787
  }
788
  } else {
789
  // When v is transposed, we also need the element size and get the element ranges from each row
790
- const uint32_t kv_size = size;
791
  for (uint32_t il = 0; il < n_layer; ++il) {
792
- const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
793
 
794
  // Write value type
795
- const int32_t v_type_i = (int32_t)v_l[il]->type;
796
- io.write(&v_type_i, sizeof(v_type_i));
797
 
798
  // Write element size
799
- const uint32_t v_size_el = ggml_type_size(v_l[il]->type);
800
- io.write(&v_size_el, sizeof(v_size_el));
801
 
802
  // Write GQA embedding size
803
- io.write(&n_embd_v_gqa, sizeof(n_embd_v_gqa));
804
 
805
  // For each row, we get the element values of each cell
806
- for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
807
  // Read each range of cells of v_size_el length each into tmp_buf and write out
808
  for (const auto & range : cell_ranges) {
809
  const size_t range_size = range.second - range.first;
810
- const size_t src_offset = (range.first + j * kv_size) * v_size_el;
811
- const size_t buf_size = range_size * v_size_el;
812
- io.write_tensor(v_l[il], src_offset, buf_size);
813
  }
814
  }
815
  }
816
  }
817
  }
818
 
819
- bool llama_kv_cache_recurrent::state_read_meta(llama_io_read_i & io, uint32_t cell_count, llama_seq_id dest_seq_id) {
820
  if (dest_seq_id != -1) {
821
  // single sequence
822
 
823
  seq_rm(dest_seq_id, -1, -1);
824
 
825
- llama_sbatch sbatch;
826
- llama_ubatch batch = sbatch.reserve_ubatch(cell_count, /* has_embd */ false);
827
 
828
- batch.n_tokens = cell_count;
829
- batch.n_seq_tokens = cell_count;
830
- batch.n_seqs = 1;
831
 
832
  for (uint32_t i = 0; i < cell_count; ++i) {
833
  llama_pos pos;
@@ -841,12 +849,12 @@ bool llama_kv_cache_recurrent::state_read_meta(llama_io_read_i & io, uint32_t ce
841
  return false;
842
  }
843
 
844
- batch.pos[i] = pos;
845
  }
846
- batch.n_seq_id[0] = 1;
847
- batch.seq_id[0] = &dest_seq_id;
848
 
849
- if (!find_slot(batch)) {
850
  LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__);
851
  return false;
852
  }
@@ -854,8 +862,8 @@ bool llama_kv_cache_recurrent::state_read_meta(llama_io_read_i & io, uint32_t ce
854
  // DEBUG CHECK: kv.head should be our first cell, kv.head + cell_count - 1 should be our last cell (verify seq_id and pos values)
855
  // Assume that this is one contiguous block of cells
856
  GGML_ASSERT(head + cell_count <= size);
857
- GGML_ASSERT(cells[head].pos == batch.pos[0]);
858
- GGML_ASSERT(cells[head + cell_count - 1].pos == batch.pos[cell_count - 1]);
859
  GGML_ASSERT(cells[head].has_seq_id(dest_seq_id));
860
  GGML_ASSERT(cells[head + cell_count - 1].has_seq_id(dest_seq_id));
861
  } else {
@@ -869,7 +877,7 @@ bool llama_kv_cache_recurrent::state_read_meta(llama_io_read_i & io, uint32_t ce
869
  clear(true);
870
 
871
  for (uint32_t i = 0; i < cell_count; ++i) {
872
- kv_cell & cell = cells[i];
873
 
874
  llama_pos pos;
875
  uint32_t n_seq_id;
@@ -883,7 +891,7 @@ bool llama_kv_cache_recurrent::state_read_meta(llama_io_read_i & io, uint32_t ce
883
  llama_seq_id seq_id;
884
  io.read_to(&seq_id, sizeof(seq_id));
885
 
886
- // TODO: llama_kv_cache_recurrent should have a notion of max sequences
887
  //if (seq_id < 0 || (uint32_t) seq_id >= llama_n_seq_max(ctx)) {
888
  if (seq_id < 0) {
889
  //LLAMA_LOG_ERROR("%s: invalid seq_id, %d is out of range [0, %u)\n", __func__, seq_id, llama_n_seq_max(ctx));
@@ -915,10 +923,10 @@ bool llama_kv_cache_recurrent::state_read_meta(llama_io_read_i & io, uint32_t ce
915
  return true;
916
  }
917
 
918
- bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell_count) {
919
- uint32_t v_trans;
920
  uint32_t n_layer;
921
- io.read_to(&v_trans, sizeof(v_trans));
922
  io.read_to(&n_layer, sizeof(n_layer));
923
 
924
  if (n_layer != hparams.n_layer) {
@@ -929,102 +937,100 @@ bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t ce
929
  LLAMA_LOG_ERROR("%s: not enough cells in kv cache to restore state (%u > %u)\n", __func__, cell_count, size);
930
  return false;
931
  }
932
- if (false != (bool) v_trans) {
933
- LLAMA_LOG_ERROR("%s: incompatible V transposition\n", __func__);
934
  return false;
935
  }
936
 
937
  // For each layer, read the keys for each cell, one row is one cell, read as one contiguous block
938
  for (uint32_t il = 0; il < n_layer; ++il) {
939
- const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
940
 
941
  // Read type of key
942
- int32_t k_type_i_ref;
943
- io.read_to(&k_type_i_ref, sizeof(k_type_i_ref));
944
- const int32_t k_type_i = (int32_t) k_l[il]->type;
945
- if (k_type_i != k_type_i_ref) {
946
- LLAMA_LOG_ERROR("%s: mismatched key type (%d != %d, layer %d)\n", __func__, k_type_i, k_type_i_ref, il);
947
  return false;
948
  }
949
 
950
  // Read row size of key
951
- uint64_t k_size_row_ref;
952
- io.read_to(&k_size_row_ref, sizeof(k_size_row_ref));
953
- const size_t k_size_row = ggml_row_size(k_l[il]->type, n_embd_k_gqa);
954
- if (k_size_row != k_size_row_ref) {
955
- LLAMA_LOG_ERROR("%s: mismatched key row size (%zu != %zu, layer %d)\n", __func__, k_size_row, (size_t) k_size_row_ref, il);
956
  return false;
957
  }
958
 
959
  if (cell_count) {
960
  // Read and set the keys for the whole cell range
961
- ggml_backend_tensor_set(k_l[il], io.read(cell_count * k_size_row), head * k_size_row, cell_count * k_size_row);
962
  }
963
  }
964
 
965
- if (!v_trans) {
966
  for (uint32_t il = 0; il < n_layer; ++il) {
967
- const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
968
 
969
  // Read type of value
970
- int32_t v_type_i_ref;
971
- io.read_to(&v_type_i_ref, sizeof(v_type_i_ref));
972
- const int32_t v_type_i = (int32_t)v_l[il]->type;
973
- if (v_type_i != v_type_i_ref) {
974
- LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il);
975
  return false;
976
  }
977
 
978
  // Read row size of value
979
- uint64_t v_size_row_ref;
980
- io.read_to(&v_size_row_ref, sizeof(v_size_row_ref));
981
- const size_t v_size_row = ggml_row_size(v_l[il]->type, n_embd_v_gqa);
982
- if (v_size_row != v_size_row_ref) {
983
- LLAMA_LOG_ERROR("%s: mismatched value row size (%zu != %zu, layer %d)\n", __func__, v_size_row, (size_t) v_size_row_ref, il);
984
  return false;
985
  }
986
 
987
  if (cell_count) {
988
  // Read and set the values for the whole cell range
989
- ggml_backend_tensor_set(v_l[il], io.read(cell_count * v_size_row), head * v_size_row, cell_count * v_size_row);
990
  }
991
  }
992
  } else {
993
  // For each layer, read the values for each cell (transposed)
994
  for (uint32_t il = 0; il < n_layer; ++il) {
995
- const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
996
 
997
  // Read type of value
998
- int32_t v_type_i_ref;
999
- io.read_to(&v_type_i_ref, sizeof(v_type_i_ref));
1000
- const int32_t v_type_i = (int32_t)v_l[il]->type;
1001
- if (v_type_i != v_type_i_ref) {
1002
- LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il);
1003
  return false;
1004
  }
1005
 
1006
  // Read element size of value
1007
- uint32_t v_size_el_ref;
1008
- io.read_to(&v_size_el_ref, sizeof(v_size_el_ref));
1009
- const size_t v_size_el = ggml_type_size(v_l[il]->type);
1010
- if (v_size_el != v_size_el_ref) {
1011
- LLAMA_LOG_ERROR("%s: mismatched value element size (%zu != %zu, layer %d)\n", __func__, v_size_el, (size_t) v_size_el_ref, il);
1012
  return false;
1013
  }
1014
 
1015
- // Read GQA embedding size
1016
- uint32_t n_embd_v_gqa_ref;
1017
- io.read_to(&n_embd_v_gqa_ref, sizeof(n_embd_v_gqa_ref));
1018
- if (n_embd_v_gqa != n_embd_v_gqa_ref) {
1019
- LLAMA_LOG_ERROR("%s: mismatched GQA embedding size (%u != %u, layer %d)\n", __func__, n_embd_v_gqa, n_embd_v_gqa_ref, il);
1020
  return false;
1021
  }
1022
 
1023
  if (cell_count) {
1024
  // For each row in the transposed matrix, read the values for the whole cell range
1025
- for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
1026
- const size_t dst_offset = (head + j * size) * v_size_el;
1027
- ggml_backend_tensor_set(v_l[il], io.read(cell_count * v_size_el), dst_offset, cell_count * v_size_el);
1028
  }
1029
  }
1030
  }
@@ -1034,25 +1040,22 @@ bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t ce
1034
  }
1035
 
1036
  //
1037
- // llama_kv_cache_recurrent_state
1038
  //
1039
 
1040
- llama_kv_cache_recurrent_state::llama_kv_cache_recurrent_state(llama_memory_status status) : status(status) {}
1041
 
1042
- llama_kv_cache_recurrent_state::llama_kv_cache_recurrent_state(
1043
- llama_memory_status status,
1044
- llama_kv_cache_recurrent * kv) : status(status), kv(kv), is_full(true) {
1045
  }
1046
 
1047
- llama_kv_cache_recurrent_state::llama_kv_cache_recurrent_state(
1048
- llama_memory_status status,
1049
- llama_kv_cache_recurrent * kv,
1050
- llama_sbatch sbatch,
1051
- std::vector<llama_ubatch> ubatches) : status(status), kv(kv), sbatch(std::move(sbatch)), ubatches(std::move(ubatches)) {}
1052
 
1053
- llama_kv_cache_recurrent_state::~llama_kv_cache_recurrent_state() = default;
1054
 
1055
- bool llama_kv_cache_recurrent_state::next() {
1056
  assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
1057
 
1058
  if (++i_next >= ubatches.size()) {
@@ -1062,54 +1065,48 @@ bool llama_kv_cache_recurrent_state::next() {
1062
  return true;
1063
  }
1064
 
1065
- bool llama_kv_cache_recurrent_state::apply() {
1066
  assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
1067
 
1068
- kv->find_slot(ubatches[i_next]);
1069
 
1070
  return true;
1071
  }
1072
 
1073
- std::vector<int64_t> & llama_kv_cache_recurrent_state::out_ids() {
1074
- assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
1075
-
1076
- return sbatch.out_ids;
1077
- }
1078
-
1079
- llama_memory_status llama_kv_cache_recurrent_state::get_status() const {
1080
  return status;
1081
  }
1082
 
1083
- const llama_ubatch & llama_kv_cache_recurrent_state::get_ubatch() const {
1084
  assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
1085
 
1086
  return ubatches[i_next];
1087
  }
1088
 
1089
- uint32_t llama_kv_cache_recurrent_state::get_n_kv() const {
1090
- return is_full ? kv->size : kv->n;
1091
  }
1092
 
1093
- uint32_t llama_kv_cache_recurrent_state::get_head() const {
1094
- return is_full ? 0 : kv->head;
1095
  }
1096
 
1097
- int32_t llama_kv_cache_recurrent_state::get_rs_z() const {
1098
- return is_full ? 0 : kv->rs_z;
1099
  }
1100
 
1101
- uint32_t llama_kv_cache_recurrent_state::get_size() const {
1102
- return kv->size;
1103
  }
1104
 
1105
- ggml_tensor * llama_kv_cache_recurrent_state::get_k_l(int32_t il) const {
1106
- return kv->k_l[il];
1107
  }
1108
 
1109
- ggml_tensor * llama_kv_cache_recurrent_state::get_v_l(int32_t il) const {
1110
- return kv->v_l[il];
1111
  }
1112
 
1113
- int32_t llama_kv_cache_recurrent_state::s_copy(int i) const {
1114
- return kv->cells[i + kv->head].src0;
1115
  }
 
1
+ #include "llama-memory-recurrent.h"
2
 
3
  #include "llama-impl.h"
4
  #include "llama-io.h"
 
12
  #include <stdexcept>
13
 
14
  //
15
+ // llama_memory_recurrent
16
  //
17
 
18
+ llama_memory_recurrent::llama_memory_recurrent(
19
+ const llama_model & model,
20
+ layer_filter_cb && filter,
21
+ ggml_type type_r,
22
+ ggml_type type_s,
23
+ bool offload,
24
+ uint32_t mem_size,
25
+ uint32_t n_seq_max) : hparams(model.hparams), n_seq_max(n_seq_max) {
26
  const int32_t n_layer = hparams.n_layer;
27
 
28
+ LLAMA_LOG_INFO("%s: mem_size = %u, n_seq_max = %u, type_r = '%s', type_s = '%s', n_layer = %d\n",
29
+ __func__, mem_size, n_seq_max, ggml_type_name(type_r), ggml_type_name(type_s), n_layer);
30
 
31
  head = 0;
32
+ size = mem_size;
33
  used = 0;
34
 
35
  cells.clear();
36
+ cells.resize(mem_size);
37
 
38
  // create a context for each buffer type
39
  std::map<ggml_backend_buffer_type_t, ggml_context *> ctx_map;
 
60
  return it->second;
61
  };
62
 
63
+ r_l.resize(n_layer);
64
+ s_l.resize(n_layer);
65
 
66
  for (int i = 0; i < n_layer; i++) {
67
+ if (filter && !filter(i)) {
68
+ LLAMA_LOG_DEBUG("%s: layer %3d: skipped\n", __func__, i);
69
+ continue;
70
+ }
71
 
72
  const char * dev_name = "CPU";
73
 
 
87
  throw std::runtime_error("failed to create ggml context for kv cache");
88
  }
89
 
90
+ ggml_tensor * r = ggml_new_tensor_1d(ctx, type_r, hparams.n_embd_r()*mem_size);
91
+ ggml_tensor * s = ggml_new_tensor_1d(ctx, type_s, hparams.n_embd_s()*mem_size);
92
+ ggml_format_name(r, "cache_r_l%d", i);
93
+ ggml_format_name(s, "cache_s_l%d", i);
94
+ r_l[i] = r;
95
+ s_l[i] = s;
96
  }
97
 
98
  // allocate tensors and initialize the buffers to avoid NaNs in the padding
 
110
  }
111
 
112
  {
113
+ const size_t memory_size_r = size_r_bytes();
114
+ const size_t memory_size_s = size_s_bytes();
115
 
116
+ LLAMA_LOG_INFO("%s: KV self size = %7.2f MiB, R (%s): %7.2f MiB, S (%s): %7.2f MiB\n", __func__,
117
+ (float)(memory_size_r + memory_size_s) / (1024.0f * 1024.0f),
118
+ ggml_type_name(type_r), (float)memory_size_r / (1024.0f * 1024.0f),
119
+ ggml_type_name(type_s), (float)memory_size_s / (1024.0f * 1024.0f));
120
  }
121
  }
122
 
123
+ void llama_memory_recurrent::clear(bool data) {
124
  for (int32_t i = 0; i < (int32_t) size; ++i) {
125
  cells[i].pos = -1;
126
  cells[i].seq_id.clear();
 
138
  }
139
  }
140
 
141
+ bool llama_memory_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
142
  uint32_t new_head = size;
143
 
144
  if (p0 < 0) {
 
157
  if (0 <= seq_id) {
158
  int32_t & tail_id = cells[seq_id].tail;
159
  if (tail_id >= 0) {
160
+ const auto & cell = cells[tail_id];
161
  // partial intersection is invalid
162
  if ((0 < p0 && p0 <= cell.pos) || (0 < p1 && p1 <= cell.pos)) {
163
  return false;
 
205
  return true;
206
  }
207
 
208
+ void llama_memory_recurrent::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
209
  if (seq_id_src == seq_id_dst) {
210
  return;
211
  }
 
219
  }
220
 
221
  if ((uint32_t) seq_id_dst < size && (uint32_t) seq_id_src < size) {
222
+ auto & tail_src = cells[seq_id_src];
223
+ auto & tail_dst = cells[seq_id_dst];
224
  if (tail_dst.tail >= 0) {
225
  // clear destination seq_id if it wasn't empty
226
+ auto & cell_dst = cells[tail_dst.tail];
227
 
228
  cell_dst.seq_id.erase(seq_id_dst);
229
  tail_dst.tail = -1;
 
234
  }
235
  }
236
  if (tail_src.tail >= 0) {
237
+ auto & cell_src = cells[tail_src.tail];
238
 
239
  cell_src.seq_id.insert(seq_id_dst);
240
  tail_dst.tail = tail_src.tail;
 
242
  }
243
  }
244
 
245
+ void llama_memory_recurrent::seq_keep(llama_seq_id seq_id) {
246
  uint32_t new_head = size;
247
 
248
  for (uint32_t i = 0; i < size; ++i) {
 
274
  }
275
  }
276
 
277
+ void llama_memory_recurrent::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) {
278
  if (shift == 0) {
279
  return;
280
  }
 
296
  if (0 <= seq_id && seq_id < (int64_t) size) {
297
  const int32_t tail_id = cells[seq_id].tail;
298
  if (tail_id >= 0) {
299
+ auto & cell = cells[tail_id];
300
  if (cell.has_seq_id(seq_id) && p0 <= cell.pos && cell.pos < p1) {
301
  cell.pos += shift;
302
  }
 
304
  }
305
  }
306
 
307
+ void llama_memory_recurrent::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
308
  if (d == 1) {
309
  return;
310
  }
 
326
  if (0 <= seq_id && seq_id < (int64_t) size) {
327
  const int32_t tail_id = cells[seq_id].tail;
328
  if (tail_id >= 0) {
329
+ auto & cell = cells[tail_id];
330
  if (cell.has_seq_id(seq_id) && p0 <= cell.pos && cell.pos < p1) {
331
  cell.pos /= d;
332
  }
 
334
  }
335
  }
336
 
337
+ llama_pos llama_memory_recurrent::seq_pos_min(llama_seq_id seq_id) const {
338
  llama_pos result = std::numeric_limits<llama_pos>::max();
339
 
340
  for (uint32_t i = 0; i < size; ++i) {
 
350
  return result;
351
  }
352
 
353
+ llama_pos llama_memory_recurrent::seq_pos_max(llama_seq_id seq_id) const {
354
  llama_pos result = -1;
355
 
356
  for (uint32_t i = 0; i < size; ++i) {
 
362
  return result;
363
  }
364
 
365
+ llama_memory_state_ptr llama_memory_recurrent::init_batch(llama_batch_allocr & balloc, uint32_t n_ubatch, bool embd_all) {
 
 
366
  std::vector<llama_ubatch> ubatches;
367
 
368
+ while (true) {
369
  llama_ubatch ubatch;
370
 
371
  if (embd_all) {
372
  // if all tokens are output, split by sequence
373
+ ubatch = balloc.split_seq(n_ubatch);
374
  } else {
375
+ ubatch = balloc.split_equal(n_ubatch);
376
  }
377
 
378
+ if (ubatch.n_tokens == 0) {
379
+ break;
380
+ }
381
+
382
+ ubatches.push_back(std::move(ubatch)); // NOLINT
383
  }
384
 
385
  if (!prepare(ubatches)) {
386
+ return std::make_unique<llama_memory_recurrent_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
387
  }
388
 
389
+ return std::make_unique<llama_memory_recurrent_state>(this, std::move(ubatches));
390
  }
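A usage sketch of the batch entry point above, assuming `mem` is the `llama_memory_recurrent` instance and `balloc`/`n_ubatch` come from the caller (the graph-building step is elided):

    auto mstate = mem.init_batch(balloc, n_ubatch, /*embd_all=*/false);
    if (mstate->get_status() == LLAMA_MEMORY_STATUS_SUCCESS) {
        do {
            mstate->apply();                          // runs find_slot() for the current ubatch
            const llama_ubatch & ub = mstate->get_ubatch();
            // ... build and compute the graph for `ub` ...
        } while (mstate->next());
    }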
391
 
392
+ llama_memory_state_ptr llama_memory_recurrent::init_full() {
393
+ return std::make_unique<llama_memory_recurrent_state>(this);
394
  }
395
 
396
+ llama_memory_state_ptr llama_memory_recurrent::init_update(llama_context * lctx, bool optimize) {
397
  GGML_UNUSED(lctx);
398
  GGML_UNUSED(optimize);
399
 
400
+ return std::make_unique<llama_memory_recurrent_state>(LLAMA_MEMORY_STATUS_NO_UPDATE);
401
  }
402
 
403
+ bool llama_memory_recurrent::prepare(const std::vector<llama_ubatch> & ubatches) {
404
  // simply remember the full state because it is very small for this type of cache
405
  // TODO: optimize
406
  auto org_cells = cells;
 
424
  return success;
425
  }
426
 
427
+ bool llama_memory_recurrent::find_slot(const llama_ubatch & ubatch) {
 
 
428
  const uint32_t n_seq_tokens = ubatch.n_seq_tokens;
429
+ const uint32_t n_seqs = ubatch.n_seqs;
430
 
431
  // if we have enough unused cells before the current head ->
432
  // better to start searching from the beginning of the cache, hoping to fill it
 
446
 
447
  // everything should fit if all seq_ids are smaller than the max
448
  for (uint32_t s = 0; s < n_seqs; ++s) {
449
+ const uint32_t i = s*n_seq_tokens; // first token of sequence set s
450
+ const uint32_t n_seq_id = ubatch.n_seq_id[i];
451
+
452
  for (uint32_t j = 0; j < n_seq_id; ++j) {
453
+ const llama_seq_id seq_id = ubatch.seq_id[i][j];
454
 
455
  if (seq_id < 0 || (uint32_t) seq_id >= size) {
456
  // too big seq_id
 
459
  return false;
460
  }
461
  if (j > 0) {
462
+ auto & seq = cells[seq_id];
463
  if (seq.tail >= 0) {
464
+ auto & cell = cells[seq.tail];
465
  // clear cells from seq_ids that become shared
466
  // (should not normally happen, but let's handle it anyway)
467
  cell.seq_id.erase(seq_id);
 
481
  std::vector<int32_t> tails_verif;
482
  tails_verif.assign(size, -1);
483
  for (uint32_t i = 0; i < size; ++i) {
484
+ auto & cell = cells[i];
485
  for (llama_seq_id seq_id : cell.seq_id) {
486
  if (tails_verif[seq_id] != -1) {
487
  LLAMA_LOG_ERROR("%s: duplicate tail for seq_id %d in cell %d and %d\n", __func__, seq_id, i, tails_verif[seq_id]);
 
502
 
503
  for (uint32_t i = 0; i < size; ++i) {
504
  if (next_empty_cell >= size) { next_empty_cell -= size; }
505
+ auto & cell = cells[next_empty_cell];
506
  if (cell.is_empty()) { break; }
507
  next_empty_cell += 1;
508
  }
509
 
510
  // find usable cell range
511
  for (uint32_t s = 0; s < n_seqs; ++s) {
512
+ const uint32_t i = s*n_seq_tokens;
513
+ const llama_seq_id seq_id = ubatch.seq_id[i][0];
514
+ auto & seq_meta = cells[seq_id];
515
  bool has_cell = false;
516
  if (seq_meta.tail >= 0) {
517
+ auto & cell = cells[seq_meta.tail];
518
  GGML_ASSERT(cell.has_seq_id(seq_id));
519
  // does this seq_id "own" the cell?
520
  if (cell.seq_id.size() == 1) { has_cell = true; }
521
  }
522
  if (!has_cell) {
523
+ auto & empty_cell = cells[next_empty_cell];
524
  GGML_ASSERT(empty_cell.is_empty());
525
  // copy old tail into the empty cell
526
  if (seq_meta.tail >= 0) {
527
+ auto & orig_cell = cells[seq_meta.tail];
528
  empty_cell.pos = orig_cell.pos;
529
  empty_cell.src = orig_cell.src;
530
  orig_cell.seq_id.erase(seq_id);
 
534
  seq_meta.tail = next_empty_cell;
535
  // find next empty cell
536
  if (s + 1 < n_seqs) {
537
+ for (uint32_t j = 0; j < size; ++j) {
538
  next_empty_cell += 1;
539
  if (next_empty_cell >= size) { next_empty_cell -= size; }
540
+ auto & cell = cells[next_empty_cell];
541
  if (cell.is_empty()) { break; }
542
  }
543
  }
 
548
 
549
  // gather and re-order
550
  for (uint32_t s = 0; s < n_seqs; ++s) {
551
+ const uint32_t i = s*n_seq_tokens;
552
  const int32_t dst_id = s + min;
553
+ const int32_t src_id = cells[ubatch.seq_id[i][0]].tail;
554
  if (dst_id != src_id) {
555
+ auto & dst_cell = cells[dst_id];
556
+ auto & src_cell = cells[src_id];
557
 
558
  std::swap(dst_cell.pos, src_cell.pos);
559
  std::swap(dst_cell.src, src_cell.src);
560
  std::swap(dst_cell.seq_id, src_cell.seq_id);
561
 
562
  // swap tails
563
+ for (uint32_t j = 0; j < size; ++j) {
564
+ int32_t & tail = cells[j].tail;
565
  if (tail == src_id) {
566
  tail = dst_id;
567
  } else if (tail == dst_id) {
 
573
 
574
  // update the pos of the used seqs
575
  for (uint32_t s = 0; s < n_seqs; ++s) {
576
+ const uint32_t i = s*n_seq_tokens;
577
+ const llama_pos last_pos = ubatch.pos[i + n_seq_tokens - 1];
578
  const int32_t cell_id = s + min;
579
+ auto & cell = cells[cell_id];
580
 
581
  if (cell.pos >= 0 && last_pos != cell.pos + (llama_pos) n_seq_tokens) {
582
  // What should happen when the pos backtracks or skips a value?
583
  // Clearing the state mid-batch would require special-casing which isn't done.
584
  LLAMA_LOG_WARN("%s: non-consecutive token position %d after %d for sequence %d with %u new tokens\n",
585
+ __func__, last_pos, cell.pos, ubatch.seq_id[i][0], n_seq_tokens);
586
  }
587
  cell.pos = last_pos;
588
  cell.seq_id.clear();
589
+ for (int32_t j = 0; j < ubatch.n_seq_id[i]; ++j) {
590
+ const llama_seq_id seq_id = ubatch.seq_id[i][j];
591
  cell.seq_id.insert(seq_id);
592
  cells[seq_id].tail = cell_id;
593
  }
 
629
  head = min;
630
  n = max - min + 1;
631
  used = std::count_if(cells.begin(), cells.end(),
632
+ [](const mem_cell & cell){ return !cell.is_empty(); });
633
 
634
  // sanity check
635
  return n >= n_seqs;
636
  }
637
 
638
+ bool llama_memory_recurrent::get_can_shift() const {
639
  // shifting the pos is trivial for recurrent models
640
  return true;
641
  }
642
 
643
+ size_t llama_memory_recurrent::total_size() const {
644
  size_t size = 0;
645
  for (const auto & buf : bufs) {
646
  size += ggml_backend_buffer_get_size(buf.get());
 
649
  return size;
650
  }
651
 
652
+ size_t llama_memory_recurrent::size_r_bytes() const {
653
+ size_t size_r_bytes = 0;
654
 
655
+ for (const auto & r : r_l) {
656
+ if (r != nullptr) {
657
+ size_r_bytes += ggml_nbytes(r);
658
+ }
659
  }
660
 
661
+ return size_r_bytes;
662
  }
663
 
664
+ size_t llama_memory_recurrent::size_s_bytes() const {
665
+ size_t size_s_bytes = 0;
666
 
667
+ for (const auto & s : s_l) {
668
+ if (s != nullptr) {
669
+ size_s_bytes += ggml_nbytes(s);
670
+ }
671
  }
672
 
673
+ return size_s_bytes;
674
  }
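As a rough size check, each layer's contribution is just the flat tensor size; for example, with F32 states (a sketch, assuming `mem_size` cells per layer):

    const size_t r_bytes_per_layer = ggml_row_size(GGML_TYPE_F32, hparams.n_embd_r()) * mem_size;
    const size_t s_bytes_per_layer = ggml_row_size(GGML_TYPE_F32, hparams.n_embd_s()) * mem_size;
    // size_r_bytes()/size_s_bytes() sum these over the layers that passed the filter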
675
 
676
+ void llama_memory_recurrent::state_write(llama_io_write_i & io, llama_seq_id seq_id) const {
677
  std::vector<std::pair<uint32_t, uint32_t>> cell_ranges; // ranges, from inclusive, to exclusive
678
  uint32_t cell_count = 0;
679
 
 
711
  state_write_data(io, cell_ranges);
712
  }
713
 
714
+ void llama_memory_recurrent::state_read(llama_io_read_i & io, llama_seq_id seq_id) {
715
  uint32_t cell_count;
716
  io.read_to(&cell_count, sizeof(cell_count));
717
 
 
730
  }
731
  }
732
 
733
+ void llama_memory_recurrent::state_write_meta(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id) const {
734
  for (const auto & range : cell_ranges) {
735
  for (uint32_t i = range.first; i < range.second; ++i) {
736
  const auto & cell = cells[i];
 
749
  }
750
  }
751
 
752
+ void llama_memory_recurrent::state_write_data(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges) const {
753
+ const uint32_t s_trans = 0;
754
  const uint32_t n_layer = hparams.n_layer;
755
 
756
+ io.write(&s_trans, sizeof(s_trans));
757
+ io.write(&n_layer, sizeof(n_layer));
758
 
759
  std::vector<uint8_t> tmp_buf;
760
 
761
  // Iterate and write all the keys first, each row is a cell
762
  // Get whole range at a time
763
  for (uint32_t il = 0; il < n_layer; ++il) {
 
764
 
765
  // Write key type
766
+ const int32_t r_type_i = (int32_t)r_l[il]->type;
767
+ io.write(&r_type_i, sizeof(r_type_i));
768
 
769
  // Write row size of key
770
+ const uint64_t r_size_row = ggml_row_size(r_l[il]->type, hparams.n_embd_r());
771
+ io.write(&r_size_row, sizeof(r_size_row));
772
 
773
  // Read each range of cells of k_size length each into tmp_buf and write out
774
  for (const auto & range : cell_ranges) {
775
  const size_t range_size = range.second - range.first;
776
+ const size_t buf_size = range_size * r_size_row;
777
+ io.write_tensor(r_l[il], range.first * r_size_row, buf_size);
778
  }
779
  }
780
 
781
+ if (!s_trans) {
782
  for (uint32_t il = 0; il < n_layer; ++il) {
 
783
 
784
  // Write value type
785
+ const int32_t s_type_i = (int32_t)s_l[il]->type;
786
+ io.write(&s_type_i, sizeof(s_type_i));
787
 
788
  // Write row size of value
789
+ const uint64_t s_size_row = ggml_row_size(s_l[il]->type, hparams.n_embd_s());
790
+ io.write(&s_size_row, sizeof(s_size_row));
791
 
792
+ // Read each range of cells of s_size length each into tmp_buf and write out
793
  for (const auto & range : cell_ranges) {
794
  const size_t range_size = range.second - range.first;
795
+ const size_t buf_size = range_size * s_size_row;
796
+ io.write_tensor(s_l[il], range.first * s_size_row, buf_size);
797
  }
798
  }
799
  } else {
800
  // When v is transposed, we also need the element size and get the element ranges from each row
801
+ const uint32_t mem_size = size;
802
  for (uint32_t il = 0; il < n_layer; ++il) {
803
+ const uint32_t n_embd_s = hparams.n_embd_s();
804
 
805
  // Write value type
806
+ const int32_t s_type_i = (int32_t)s_l[il]->type;
807
+ io.write(&s_type_i, sizeof(s_type_i));
808
 
809
  // Write element size
810
+ const uint32_t s_size_el = ggml_type_size(s_l[il]->type);
811
+ io.write(&s_size_el, sizeof(s_size_el));
812
 
813
  // Write GQA embedding size
814
+ io.write(&n_embd_s, sizeof(n_embd_s));
815
 
816
  // For each row, we get the element values of each cell
817
+ for (uint32_t j = 0; j < n_embd_s; ++j) {
818
  // Read each range of cells of v_size_el length each into tmp_buf and write out
819
  for (const auto & range : cell_ranges) {
820
  const size_t range_size = range.second - range.first;
821
+ const size_t src_offset = (range.first + j * mem_size) * s_size_el;
822
+ const size_t buf_size = range_size * s_size_el;
823
+ io.write_tensor(s_l[il], src_offset, buf_size);
824
  }
825
  }
826
  }
827
  }
828
  }
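For reference, the write path above produces roughly the following layout in the non-transposed case (a sketch derived from the calls above, not a formal spec):

    // uint32  s_trans                 (always 0 here)
    // uint32  n_layer
    // for each layer il:
    //   int32   r_type                (r_l[il]->type)
    //   uint64  r_size_row            (ggml_row_size(type, n_embd_r()))
    //   bytes   r rows, one per cell in each written range
    // for each layer il:
    //   int32   s_type
    //   uint64  s_size_row            (ggml_row_size(type, n_embd_s()))
    //   bytes   s rows, one per cell in each written range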
829
 
830
+ bool llama_memory_recurrent::state_read_meta(llama_io_read_i & io, uint32_t cell_count, llama_seq_id dest_seq_id) {
831
  if (dest_seq_id != -1) {
832
  // single sequence
833
 
834
  seq_rm(dest_seq_id, -1, -1);
835
 
836
+ llama_batch_allocr balloc(hparams.n_pos_per_embd());
 
837
 
838
+ llama_ubatch ubatch = balloc.ubatch_reserve(cell_count, 1);
 
 
839
 
840
  for (uint32_t i = 0; i < cell_count; ++i) {
841
  llama_pos pos;
 
849
  return false;
850
  }
851
 
852
+ ubatch.pos[i] = pos;
853
  }
854
+ ubatch.n_seq_id[0] = 1;
855
+ ubatch.seq_id[0] = &dest_seq_id;
856
 
857
+ if (!find_slot(ubatch)) {
858
  LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__);
859
  return false;
860
  }
 
862
  // DEBUG CHECK: kv.head should be our first cell, kv.head + cell_count - 1 should be our last cell (verify seq_id and pos values)
863
  // Assume that this is one contiguous block of cells
864
  GGML_ASSERT(head + cell_count <= size);
865
+ GGML_ASSERT(cells[head].pos == ubatch.pos[0]);
866
+ GGML_ASSERT(cells[head + cell_count - 1].pos == ubatch.pos[cell_count - 1]);
867
  GGML_ASSERT(cells[head].has_seq_id(dest_seq_id));
868
  GGML_ASSERT(cells[head + cell_count - 1].has_seq_id(dest_seq_id));
869
  } else {
 
877
  clear(true);
878
 
879
  for (uint32_t i = 0; i < cell_count; ++i) {
880
+ auto & cell = cells[i];
881
 
882
  llama_pos pos;
883
  uint32_t n_seq_id;
 
891
  llama_seq_id seq_id;
892
  io.read_to(&seq_id, sizeof(seq_id));
893
 
894
+ // TODO: llama_memory_recurrent should have a notion of max sequences
895
  //if (seq_id < 0 || (uint32_t) seq_id >= llama_n_seq_max(ctx)) {
896
  if (seq_id < 0) {
897
  //LLAMA_LOG_ERROR("%s: invalid seq_id, %d is out of range [0, %u)\n", __func__, seq_id, llama_n_seq_max(ctx));
 
923
  return true;
924
  }
925
 
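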
926
+ bool llama_memory_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell_count) {
927
+ uint32_t s_trans;
928
  uint32_t n_layer;
929
+ io.read_to(&s_trans, sizeof(s_trans));
930
  io.read_to(&n_layer, sizeof(n_layer));
931
 
932
  if (n_layer != hparams.n_layer) {
 
937
  LLAMA_LOG_ERROR("%s: not enough cells in kv cache to restore state (%u > %u)\n", __func__, cell_count, size);
938
  return false;
939
  }
940
+ if (false != (bool) s_trans) {
941
+ LLAMA_LOG_ERROR("%s: incompatible s transposition\n", __func__);
942
  return false;
943
  }
944
 
945
  // For each layer, read the keys for each cell, one row is one cell, read as one contiguous block
946
  for (uint32_t il = 0; il < n_layer; ++il) {
 
947
 
948
  // Read type of key
949
+ int32_t r_type_i_ref;
950
+ io.read_to(&r_type_i_ref, sizeof(r_type_i_ref));
951
+ const int32_t r_type_i = (int32_t) r_l[il]->type;
952
+ if (r_type_i != r_type_i_ref) {
953
+ LLAMA_LOG_ERROR("%s: mismatched r type (%d != %d, layer %d)\n", __func__, r_type_i, r_type_i_ref, il);
954
  return false;
955
  }
956
 
957
  // Read row size of key
958
+ uint64_t r_size_row_ref;
959
+ io.read_to(&r_size_row_ref, sizeof(r_size_row_ref));
960
+ const size_t r_size_row = ggml_row_size(r_l[il]->type, hparams.n_embd_r());
961
+ if (r_size_row != r_size_row_ref) {
962
+ LLAMA_LOG_ERROR("%s: mismatched r row size (%zu != %zu, layer %d)\n", __func__, r_size_row, (size_t) r_size_row_ref, il);
963
  return false;
964
  }
965
 
966
  if (cell_count) {
967
  // Read and set the keys for the whole cell range
968
+ ggml_backend_tensor_set(r_l[il], io.read(cell_count * r_size_row), head * r_size_row, cell_count * r_size_row);
969
  }
970
  }
971
 
972
+ if (!s_trans) {
973
  for (uint32_t il = 0; il < n_layer; ++il) {
 
974
 
975
  // Read type of value
976
+ int32_t s_type_i_ref;
977
+ io.read_to(&s_type_i_ref, sizeof(s_type_i_ref));
978
+ const int32_t s_type_i = (int32_t)s_l[il]->type;
979
+ if (s_type_i != s_type_i_ref) {
980
+ LLAMA_LOG_ERROR("%s: mismatched s type (%d != %d, layer %d)\n", __func__, s_type_i, s_type_i_ref, il);
981
  return false;
982
  }
983
 
984
  // Read row size of value
985
+ uint64_t s_size_row_ref;
986
+ io.read_to(&s_size_row_ref, sizeof(s_size_row_ref));
987
+ const size_t s_size_row = ggml_row_size(s_l[il]->type, hparams.n_embd_s());
988
+ if (s_size_row != s_size_row_ref) {
989
+ LLAMA_LOG_ERROR("%s: mismatched s row size (%zu != %zu, layer %d)\n", __func__, s_size_row, (size_t) s_size_row_ref, il);
990
  return false;
991
  }
992
 
993
  if (cell_count) {
994
  // Read and set the values for the whole cell range
995
+ ggml_backend_tensor_set(s_l[il], io.read(cell_count * s_size_row), head * s_size_row, cell_count * s_size_row);
996
  }
997
  }
998
  } else {
999
  // For each layer, read the values for each cell (transposed)
1000
  for (uint32_t il = 0; il < n_layer; ++il) {
1001
+ const uint32_t n_embd_s = hparams.n_embd_s();
1002
 
1003
  // Read type of value
1004
+ int32_t s_type_i_ref;
1005
+ io.read_to(&s_type_i_ref, sizeof(s_type_i_ref));
1006
+ const int32_t s_type_i = (int32_t)s_l[il]->type;
1007
+ if (s_type_i != s_type_i_ref) {
1008
+ LLAMA_LOG_ERROR("%s: mismatched s type (%d != %d, layer %d)\n", __func__, s_type_i, s_type_i_ref, il);
1009
  return false;
1010
  }
1011
 
1012
  // Read element size of value
1013
+ uint32_t s_size_el_ref;
1014
+ io.read_to(&s_size_el_ref, sizeof(s_size_el_ref));
1015
+ const size_t s_size_el = ggml_type_size(s_l[il]->type);
1016
+ if (s_size_el != s_size_el_ref) {
1017
+ LLAMA_LOG_ERROR("%s: mismatched s element size (%zu != %zu, layer %d)\n", __func__, s_size_el, (size_t) s_size_el_ref, il);
1018
  return false;
1019
  }
1020
 
1021
+ // Read state embedding size
1022
+ uint32_t n_embd_s_ref;
1023
+ io.read_to(&n_embd_s_ref, sizeof(n_embd_s_ref));
1024
+ if (n_embd_s != n_embd_s_ref) {
1025
+ LLAMA_LOG_ERROR("%s: mismatched s embedding size (%u != %u, layer %d)\n", __func__, n_embd_s, n_embd_s_ref, il);
1026
  return false;
1027
  }
1028
 
1029
  if (cell_count) {
1030
  // For each row in the transposed matrix, read the values for the whole cell range
1031
+ for (uint32_t j = 0; j < n_embd_s; ++j) {
1032
+ const size_t dst_offset = (head + j * size) * s_size_el;
1033
+ ggml_backend_tensor_set(s_l[il], io.read(cell_count * s_size_el), dst_offset, cell_count * s_size_el);
1034
  }
1035
  }
1036
  }
 
1040
  }
1041
 
1042
  //
1043
+ // llama_memory_recurrent_state
1044
  //
1045
 
1046
+ llama_memory_recurrent_state::llama_memory_recurrent_state(llama_memory_status status) : status(status) {}
1047
 
1048
+ llama_memory_recurrent_state::llama_memory_recurrent_state(
1049
+ llama_memory_recurrent * mem) : status(LLAMA_MEMORY_STATUS_SUCCESS), mem(mem), is_full(true) {
 
1050
  }
1051
 
1052
+ llama_memory_recurrent_state::llama_memory_recurrent_state(
1053
+ llama_memory_recurrent * mem,
1054
+ std::vector<llama_ubatch> ubatches) : status(LLAMA_MEMORY_STATUS_SUCCESS), mem(mem), ubatches(std::move(ubatches)) {}
 
 
1055
 
1056
+ llama_memory_recurrent_state::~llama_memory_recurrent_state() = default;
1057
 
1058
+ bool llama_memory_recurrent_state::next() {
1059
  assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
1060
 
1061
  if (++i_next >= ubatches.size()) {
 
1065
  return true;
1066
  }
1067
 
1068
+ bool llama_memory_recurrent_state::apply() {
1069
  assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
1070
 
1071
+ mem->find_slot(ubatches[i_next]);
1072
 
1073
  return true;
1074
  }
1075
 
1076
+ llama_memory_status llama_memory_recurrent_state::get_status() const {
1077
  return status;
1078
  }
1079
 
1080
+ const llama_ubatch & llama_memory_recurrent_state::get_ubatch() const {
1081
  assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
1082
 
1083
  return ubatches[i_next];
1084
  }
1085
 
1086
+ uint32_t llama_memory_recurrent_state::get_n_rs() const {
1087
+ return is_full ? mem->size : mem->n;
1088
  }
1089
 
1090
+ uint32_t llama_memory_recurrent_state::get_head() const {
1091
+ return is_full ? 0 : mem->head;
1092
  }
1093
 
1094
+ int32_t llama_memory_recurrent_state::get_rs_z() const {
1095
+ return is_full ? 0 : mem->rs_z;
1096
  }
1097
 
1098
+ uint32_t llama_memory_recurrent_state::get_size() const {
1099
+ return mem->size;
1100
  }
1101
 
1102
+ ggml_tensor * llama_memory_recurrent_state::get_r_l(int32_t il) const {
1103
+ return mem->r_l[il];
1104
  }
1105
 
1106
+ ggml_tensor * llama_memory_recurrent_state::get_s_l(int32_t il) const {
1107
+ return mem->s_l[il];
1108
  }
1109
 
1110
+ int32_t llama_memory_recurrent_state::s_copy(int i) const {
1111
+ return mem->cells[i + mem->head].src0;
1112
  }
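A sketch of how a graph builder might consume these accessors (assumes `mstate` is a `llama_memory_recurrent_state` reference and `n_layer` is the layer count; the active window is `[get_head(), get_head() + get_n_rs())`):

    const uint32_t n_rs = mstate.get_n_rs();          // number of recurrent states in play
    const uint32_t head = mstate.get_head();          // first cell of the active window
    for (uint32_t il = 0; il < n_layer; ++il) {
        ggml_tensor * r = mstate.get_r_l(il);         // flat [n_embd_r() * get_size()] tensor
        ggml_tensor * s = mstate.get_s_l(il);         // flat [n_embd_s() * get_size()] tensor
        // rows [head, head + n_rs) hold this ubatch's states;
        // mstate.s_copy(i) gives the source cell index for row i
    }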
examples/talk-llama/{llama-kv-cache-recurrent.h → llama-memory-recurrent.h} RENAMED
@@ -8,29 +8,34 @@
8
  #include <vector>
9
 
10
  //
11
- // llama_kv_cache_recurrent
12
  //
13
 
14
- // TODO: extract the KV cache state used for graph computation into llama_kv_cache_recurrent_state_i
15
  // see the implementation of llama_kv_cache_unified_state_i for an example how to do it
16
- class llama_kv_cache_recurrent : public llama_memory_i {
17
  public:
18
- llama_kv_cache_recurrent(
19
- const llama_model & model,
20
- ggml_type type_k,
21
- ggml_type type_v,
22
- bool offload,
23
- uint32_t kv_size,
24
- uint32_t n_seq_max);
25
 
26
- ~llama_kv_cache_recurrent() = default;
27
 
28
  //
29
  // llama_memory_i
30
  //
31
 
32
  llama_memory_state_ptr init_batch(
33
- const llama_batch & batch,
34
  uint32_t n_ubatch,
35
  bool embd_all) override;
36
 
@@ -51,7 +56,7 @@ public:
51
 
52
  bool prepare(const std::vector<llama_ubatch> & ubatches);
53
 
54
- // find a contiguous slot of kv cells and emplace the ubatch there
55
  bool find_slot(const llama_ubatch & ubatch);
56
 
57
  bool get_can_shift() const override;
@@ -72,7 +77,7 @@ public:
72
  int32_t rs_z = -1;
73
 
74
  // TODO: optimize for recurrent state needs
75
- struct kv_cell {
76
  llama_pos pos = -1;
77
  int32_t src = -1; // used to know where states should be copied from
78
  int32_t src0 = -1; // like src, but only used when setting the inputs (allowing to copy once)
@@ -88,15 +93,16 @@ public:
88
  return seq_id.empty();
89
  }
90
 
91
- bool is_same_seq(const kv_cell & other) const {
92
  return seq_id == other.seq_id;
93
  }
94
  };
95
 
96
- std::vector<kv_cell> cells;
97
 
98
- std::vector<ggml_tensor *> k_l; // per layer
99
- std::vector<ggml_tensor *> v_l;
 
100
 
101
  private:
102
  //const llama_model & model;
@@ -109,8 +115,8 @@ private:
109
 
110
  size_t total_size() const;
111
 
112
- size_t size_k_bytes() const;
113
- size_t size_v_bytes() const;
114
 
115
  void state_write_meta(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id = -1) const;
116
  void state_write_data(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges) const;
@@ -119,24 +125,21 @@ private:
119
  bool state_read_data(llama_io_read_i & io, uint32_t cell_count);
120
  };
121
 
122
- class llama_kv_cache_recurrent_state : public llama_memory_state_i {
123
  public:
124
  // used for errors
125
- llama_kv_cache_recurrent_state(llama_memory_status status);
126
 
127
  // used to create a full-cache state
128
- llama_kv_cache_recurrent_state(
129
- llama_memory_status status,
130
- llama_kv_cache_recurrent * kv);
131
 
132
  // used to create a state from a batch
133
- llama_kv_cache_recurrent_state(
134
- llama_memory_status status,
135
- llama_kv_cache_recurrent * kv,
136
- llama_sbatch sbatch,
137
  std::vector<llama_ubatch> ubatches);
138
 
139
- virtual ~llama_kv_cache_recurrent_state();
140
 
141
  //
142
  // llama_memory_state_i
@@ -145,31 +148,27 @@ public:
145
  bool next() override;
146
  bool apply() override;
147
 
148
- std::vector<int64_t> & out_ids() override;
149
-
150
  llama_memory_status get_status() const override;
151
  const llama_ubatch & get_ubatch() const override;
152
 
153
  //
154
- // llama_kv_cache_recurrent_state specific API
155
  //
156
 
157
- uint32_t get_n_kv() const;
158
  uint32_t get_head() const;
159
  int32_t get_rs_z() const;
160
  uint32_t get_size() const;
161
 
162
- ggml_tensor * get_k_l(int32_t il) const;
163
- ggml_tensor * get_v_l(int32_t il) const;
164
 
165
  int32_t s_copy(int i) const;
166
 
167
  private:
168
  const llama_memory_status status;
169
 
170
- llama_kv_cache_recurrent * kv;
171
-
172
- llama_sbatch sbatch;
173
 
174
  size_t i_next = 0;
175
 
 
8
  #include <vector>
9
 
10
  //
11
+ // llama_memory_recurrent
12
  //
13
 
14
+ // TODO: extract the cache state used for graph computation into llama_memory_recurrent_state_i
15
  // see the implementation of llama_kv_cache_unified_state_i for an example how to do it
16
+ class llama_memory_recurrent : public llama_memory_i {
17
  public:
18
 
19
+ // this callback is used to filter out layers that should not be included in the cache
20
+ using layer_filter_cb = std::function<bool(int32_t il)>;
21
+
22
+ llama_memory_recurrent(
23
+ const llama_model & model,
24
+ layer_filter_cb && filter,
25
+ ggml_type type_r,
26
+ ggml_type type_s,
27
+ bool offload,
28
+ uint32_t mem_size,
29
+ uint32_t n_seq_max);
30
+
31
+ ~llama_memory_recurrent() = default;
32
 
33
  //
34
  // llama_memory_i
35
  //
36
 
37
  llama_memory_state_ptr init_batch(
38
+ llama_batch_allocr & balloc,
39
  uint32_t n_ubatch,
40
  bool embd_all) override;
41
 
 
56
 
57
  bool prepare(const std::vector<llama_ubatch> & ubatches);
58
 
59
+ // find a contiguous slot of memory cells and emplace the ubatch there
60
  bool find_slot(const llama_ubatch & ubatch);
61
 
62
  bool get_can_shift() const override;
 
77
  int32_t rs_z = -1;
78
 
79
  // TODO: optimize for recurrent state needs
80
+ struct mem_cell {
81
  llama_pos pos = -1;
82
  int32_t src = -1; // used to know where states should be copied from
83
  int32_t src0 = -1; // like src, but only used when setting the inputs (allowing to copy once)
 
93
  return seq_id.empty();
94
  }
95
 
96
+ bool is_same_seq(const mem_cell & other) const {
97
  return seq_id == other.seq_id;
98
  }
99
  };
100
 
101
+ std::vector<mem_cell> cells;
102
 
103
+ // per layer
104
+ std::vector<ggml_tensor *> r_l;
105
+ std::vector<ggml_tensor *> s_l;
106
 
107
  private:
108
  //const llama_model & model;
 
115
 
116
  size_t total_size() const;
117
 
118
+ size_t size_r_bytes() const;
119
+ size_t size_s_bytes() const;
120
 
121
  void state_write_meta(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id = -1) const;
122
  void state_write_data(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges) const;
 
125
  bool state_read_data(llama_io_read_i & io, uint32_t cell_count);
126
  };
127
 
128
+ class llama_memory_recurrent_state : public llama_memory_state_i {
129
  public:
130
  // used for errors
131
+ llama_memory_recurrent_state(llama_memory_status status);
132
 
133
  // used to create a full-cache state
134
+ llama_memory_recurrent_state(
135
+ llama_memory_recurrent * mem);
 
136
 
137
  // used to create a state from a batch
138
+ llama_memory_recurrent_state(
139
+ llama_memory_recurrent * mem,
 
 
140
  std::vector<llama_ubatch> ubatches);
141
 
142
+ virtual ~llama_memory_recurrent_state();
143
 
144
  //
145
  // llama_memory_state_i
 
148
  bool next() override;
149
  bool apply() override;
150
 
 
 
151
  llama_memory_status get_status() const override;
152
  const llama_ubatch & get_ubatch() const override;
153
 
154
  //
155
+ // llama_memory_recurrent_state specific API
156
  //
157
 
158
+ uint32_t get_n_rs() const;
159
  uint32_t get_head() const;
160
  int32_t get_rs_z() const;
161
  uint32_t get_size() const;
162
 
163
+ ggml_tensor * get_r_l(int32_t il) const;
164
+ ggml_tensor * get_s_l(int32_t il) const;
165
 
166
  int32_t s_copy(int i) const;
167
 
168
  private:
169
  const llama_memory_status status;
170
 
171
+ llama_memory_recurrent * mem;
 
 
172
 
173
  size_t i_next = 0;
174
 
examples/talk-llama/llama-memory.h CHANGED
@@ -7,6 +7,8 @@
7
 
8
  struct llama_ubatch;
9
 
 
 
10
  class llama_io_write_i;
11
  class llama_io_read_i;
12
 
@@ -50,9 +52,6 @@ struct llama_memory_state_i {
50
  // return false on failure
51
  virtual bool apply() = 0;
52
 
53
- // TODO: this might get reworked in the future when refactoring llama_batch
54
- virtual std::vector<int64_t> & out_ids() = 0;
55
-
56
  // get the current ubatch
57
  virtual const llama_ubatch & get_ubatch() const = 0;
58
 
@@ -71,7 +70,7 @@ struct llama_memory_i {
71
  // return a state object containing the ubatches and KV cache state required to process them
72
  // check the llama_memory_state_i::get_status() for the result
73
  virtual llama_memory_state_ptr init_batch(
74
- const llama_batch & batch,
75
  uint32_t n_ubatch,
76
  bool embd_all) = 0;
77
 
 
7
 
8
  struct llama_ubatch;
9
 
10
+ class llama_batch_allocr;
11
+
12
  class llama_io_write_i;
13
  class llama_io_read_i;
14
 
 
52
  // return false on failure
53
  virtual bool apply() = 0;
54
 
 
 
 
55
  // get the current ubatch
56
  virtual const llama_ubatch & get_ubatch() const = 0;
57
 
 
70
  // return a state object containing the ubatches and KV cache state required to process them
71
  // check the llama_memory_state_i::get_status() for the result
72
  virtual llama_memory_state_ptr init_batch(
73
+ llama_batch_allocr & balloc,
74
  uint32_t n_ubatch,
75
  bool embd_all) = 0;
76
 
examples/talk-llama/llama-model-saver.cpp CHANGED
@@ -228,6 +228,7 @@ void llama_model_saver::add_kv_from_model() {
228
  // add_kv(LLM_KV_TOKENIZER_MASK_ID, ???);
229
  add_kv(LLM_KV_TOKENIZER_ADD_BOS, vocab.get_add_bos());
230
  add_kv(LLM_KV_TOKENIZER_ADD_EOS, vocab.get_add_eos());
 
231
  add_kv(LLM_KV_TOKENIZER_ADD_PREFIX, vocab.get_add_space_prefix());
232
  add_kv(LLM_KV_TOKENIZER_REMOVE_EXTRA_WS, vocab.get_remove_extra_whitespaces());
233
  add_kv(LLM_KV_TOKENIZER_PRECOMPILED_CHARSMAP, vocab.get_precompiled_charsmap());
 
228
  // add_kv(LLM_KV_TOKENIZER_MASK_ID, ???);
229
  add_kv(LLM_KV_TOKENIZER_ADD_BOS, vocab.get_add_bos());
230
  add_kv(LLM_KV_TOKENIZER_ADD_EOS, vocab.get_add_eos());
231
+ add_kv(LLM_KV_TOKENIZER_ADD_SEP, vocab.get_add_sep());
232
  add_kv(LLM_KV_TOKENIZER_ADD_PREFIX, vocab.get_add_space_prefix());
233
  add_kv(LLM_KV_TOKENIZER_REMOVE_EXTRA_WS, vocab.get_remove_extra_whitespaces());
234
  add_kv(LLM_KV_TOKENIZER_PRECOMPILED_CHARSMAP, vocab.get_precompiled_charsmap());
examples/talk-llama/llama-model.cpp CHANGED
@@ -8,7 +8,8 @@
8
 
9
  #include "llama-kv-cache-unified.h"
10
  #include "llama-kv-cache-unified-iswa.h"
11
- #include "llama-kv-cache-recurrent.h"
 
12
 
13
  #include "ggml-cpp.h"
14
 
@@ -470,6 +471,10 @@ void llama_model::load_hparams(llama_model_loader & ml) {
470
  std::fill(hparams.n_head_arr.begin(), hparams.n_head_arr.end(), 0);
471
  std::fill(hparams.n_head_kv_arr.begin(), hparams.n_head_kv_arr.end(), 0);
472
  std::fill(hparams.n_ff_arr.begin(), hparams.n_ff_arr.end(), 0);
 
 
 
 
473
 
474
  std::fill(hparams.rope_sections.begin(), hparams.rope_sections.end(), 0);
475
 
@@ -4702,6 +4707,8 @@ struct llm_build_llama : public llm_graph_context {
4702
 
4703
  const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale;
4704
 
 
 
4705
  for (int il = 0; il < n_layer; ++il) {
4706
  ggml_tensor * inpSA = inpL;
4707
 
@@ -4764,9 +4771,7 @@ struct llm_build_llama : public llm_graph_context {
4764
  cb(cur, "attn_out", il);
4765
  }
4766
 
4767
- if (il == n_layer - 1) {
4768
- // skip computing output for unused tokens
4769
- ggml_tensor * inp_out_ids = build_inp_out_ids();
4770
  cur = ggml_get_rows(ctx0, cur, inp_out_ids);
4771
  inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
4772
  }
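The removed per-layer block above pairs with two lines added just after `kq_scale` at the top of the builder, which suggests `build_inp_out_ids()` is now hoisted out of the loop; a hedged sketch of the presumed resulting pattern:

    ggml_tensor * inp_out_ids = build_inp_out_ids();

    for (int il = 0; il < n_layer; ++il) {
        // ... attention + FFN for layer il ...
        if (il == n_layer - 1 && inp_out_ids) {
            cur   = ggml_get_rows(ctx0, cur,   inp_out_ids);
            inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
        }
        // ...
    }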
@@ -4862,6 +4867,8 @@ struct llm_build_llama_iswa : public llm_graph_context {
4862
 
4863
  const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale;
4864
 
 
 
4865
  for (int il = 0; il < n_layer; ++il) {
4866
  ggml_tensor * inpSA = inpL;
4867
 
@@ -4938,9 +4945,7 @@ struct llm_build_llama_iswa : public llm_graph_context {
4938
  cb(cur, "attn_out", il);
4939
  }
4940
 
4941
- if (il == n_layer - 1) {
4942
- // skip computing output for unused tokens
4943
- ggml_tensor * inp_out_ids = build_inp_out_ids();
4944
  cur = ggml_get_rows(ctx0, cur, inp_out_ids);
4945
  inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
4946
  }
@@ -5040,6 +5045,9 @@ struct llm_build_deci : public llm_graph_context {
5040
  auto * inp_attn = build_attn_inp_kv_unified();
5041
 
5042
  const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale;
 
 
 
5043
  for (int il = 0; il < n_layer; ++il) {
5044
  ggml_tensor * inpSA = inpL;
5045
  const int64_t n_head_kv = hparams.n_head_kv(il);
@@ -5113,9 +5121,7 @@ struct llm_build_deci : public llm_graph_context {
5113
  Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il);
5114
  }
5115
 
5116
- if (il == n_layer - 1) {
5117
- // skip computing output for unused tokens
5118
- ggml_tensor * inp_out_ids = build_inp_out_ids();
5119
  cur = ggml_get_rows(ctx0, cur, inp_out_ids);
5120
  inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
5121
  }
@@ -5194,6 +5200,8 @@ struct llm_build_baichuan : public llm_graph_context {
5194
 
5195
  auto * inp_attn = build_attn_inp_kv_unified();
5196
 
 
 
5197
  for (int il = 0; il < n_layer; ++il) {
5198
  ggml_tensor * inpSA = inpL;
5199
 
@@ -5245,9 +5253,7 @@ struct llm_build_baichuan : public llm_graph_context {
5245
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
5246
  }
5247
 
5248
- if (il == n_layer - 1) {
5249
- // skip computing output for unused tokens
5250
- ggml_tensor * inp_out_ids = build_inp_out_ids();
5251
  cur = ggml_get_rows(ctx0, cur, inp_out_ids);
5252
  inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
5253
  }
@@ -5316,6 +5322,8 @@ struct llm_build_xverse : public llm_graph_context {
5316
 
5317
  auto * inp_attn = build_attn_inp_kv_unified();
5318
 
 
 
5319
  for (int il = 0; il < n_layer; ++il) {
5320
  ggml_tensor * inpSA = inpL;
5321
 
@@ -5360,9 +5368,7 @@ struct llm_build_xverse : public llm_graph_context {
5360
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
5361
  }
5362
 
5363
- if (il == n_layer - 1) {
5364
- // skip computing output for unused tokens
5365
- ggml_tensor * inp_out_ids = build_inp_out_ids();
5366
  cur = ggml_get_rows(ctx0, cur, inp_out_ids);
5367
  inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
5368
  }
@@ -5430,6 +5436,8 @@ struct llm_build_falcon : public llm_graph_context {
5430
 
5431
  auto * inp_attn = build_attn_inp_kv_unified();
5432
 
 
 
5433
  for (int il = 0; il < n_layer; ++il) {
5434
  ggml_tensor * attn_norm;
5435
 
@@ -5485,9 +5493,7 @@ struct llm_build_falcon : public llm_graph_context {
5485
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
5486
  }
5487
 
5488
- if (il == n_layer - 1) {
5489
- // skip computing output for unused tokens
5490
- ggml_tensor * inp_out_ids = build_inp_out_ids();
5491
  cur = ggml_get_rows(ctx0, cur, inp_out_ids);
5492
  inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
5493
  attn_norm = ggml_get_rows(ctx0, attn_norm, inp_out_ids);
@@ -5556,6 +5562,8 @@ struct llm_build_grok : public llm_graph_context {
5556
 
5557
  auto * inp_attn = build_attn_inp_kv_unified();
5558
 
 
 
5559
  for (int il = 0; il < n_layer; ++il) {
5560
  ggml_tensor * inpSA = inpL;
5561
 
@@ -5615,9 +5623,7 @@ struct llm_build_grok : public llm_graph_context {
5615
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f, il);
5616
  }
5617
 
5618
- if (il == n_layer - 1) {
5619
- // skip computing output for unused tokens
5620
- ggml_tensor * inp_out_ids = build_inp_out_ids();
5621
  cur = ggml_get_rows(ctx0, cur, inp_out_ids);
5622
  inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
5623
  }
@@ -5716,6 +5722,8 @@ struct llm_build_dbrx : public llm_graph_context {
5716
 
5717
  auto * inp_attn = build_attn_inp_kv_unified();
5718
 
 
 
5719
  for (int il = 0; il < n_layer; ++il) {
5720
  ggml_tensor * inpSA = inpL;
5721
 
@@ -5766,9 +5774,7 @@ struct llm_build_dbrx : public llm_graph_context {
5766
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
5767
  }
5768
 
5769
- if (il == n_layer - 1) {
5770
- // skip computing output for unused tokens
5771
- ggml_tensor * inp_out_ids = build_inp_out_ids();
5772
  cur = ggml_get_rows(ctx0, cur, inp_out_ids);
5773
  inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
5774
  }
@@ -5848,6 +5854,8 @@ struct llm_build_starcoder : public llm_graph_context {
5848
  inpL = ggml_add(ctx0, inpL, pos);
5849
  cb(inpL, "inpL", -1);
5850
 
 
 
5851
  for (int il = 0; il < n_layer; ++il) {
5852
  cur = build_norm(inpL,
5853
  model.layers[il].attn_norm,
@@ -5880,9 +5888,7 @@ struct llm_build_starcoder : public llm_graph_context {
5880
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
5881
  }
5882
 
5883
- if (il == n_layer - 1) {
5884
- // skip computing output for unused tokens
5885
- ggml_tensor * inp_out_ids = build_inp_out_ids();
5886
  cur = ggml_get_rows(ctx0, cur, inp_out_ids);
5887
  inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
5888
  }
@@ -5947,6 +5953,8 @@ struct llm_build_refact : public llm_graph_context {
5947
 
5948
  auto * inp_attn = build_attn_inp_kv_unified();
5949
 
 
 
5950
  for (int il = 0; il < n_layer; ++il) {
5951
  ggml_tensor * inpSA = inpL;
5952
 
@@ -5979,9 +5987,7 @@ struct llm_build_refact : public llm_graph_context {
5979
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
5980
  }
5981
 
5982
- if (il == n_layer - 1) {
5983
- // skip computing output for unused tokens
5984
- ggml_tensor * inp_out_ids = build_inp_out_ids();
5985
  cur = ggml_get_rows(ctx0, cur, inp_out_ids);
5986
  inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
5987
  }
@@ -6067,78 +6073,79 @@ struct llm_build_bert : public llm_graph_context {
6067
 
6068
  auto * inp_attn = build_attn_inp_no_cache();
6069
 
6070
- // iterate layers
 
6071
  for (int il = 0; il < n_layer; ++il) {
6072
  ggml_tensor * cur = inpL;
6073
 
6074
- ggml_tensor * Qcur;
6075
- ggml_tensor * Kcur;
6076
- ggml_tensor * Vcur;
 
6077
 
6078
- // self-attention
6079
- if (model.layers[il].wqkv) {
6080
- cur = build_lora_mm(model.layers[il].wqkv, cur);
6081
- cb(cur, "wqkv", il);
6082
 
6083
- if (model.layers[il].bqkv) {
6084
- cur = ggml_add(ctx0, cur, model.layers[il].bqkv);
6085
- cb(cur, "bqkv", il);
6086
- }
6087
 
6088
- Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)));
6089
- Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)));
6090
- Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)));
6091
- } else {
6092
- Qcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wq, cur), model.layers[il].bq);
6093
- Kcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wk, cur), model.layers[il].bk);
6094
- Vcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wv, cur), model.layers[il].bv);
6095
- }
6096
 
6097
- if (model.layers[il].attn_q_norm) {
6098
- Qcur = build_norm(Qcur,
6099
- model.layers[il].attn_q_norm,
6100
- model.layers[il].attn_q_norm_b,
6101
- LLM_NORM, il);
6102
- }
6103
 
6104
- if (model.layers[il].attn_k_norm) {
6105
- Kcur = build_norm(Kcur,
6106
- model.layers[il].attn_k_norm,
6107
- model.layers[il].attn_k_norm_b,
6108
- LLM_NORM, il);
6109
- }
6110
 
6111
- Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
6112
- Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
6113
- Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
6114
 
6115
- // RoPE
6116
- if (model.arch == LLM_ARCH_NOMIC_BERT || model.arch == LLM_ARCH_NOMIC_BERT_MOE) {
6117
- Qcur = ggml_rope_ext(
6118
- ctx0, Qcur, inp_pos, nullptr,
6119
- n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
6120
- ext_factor, attn_factor, beta_fast, beta_slow
6121
- );
6122
 
6123
- Kcur = ggml_rope_ext(
6124
- ctx0, Kcur, inp_pos, nullptr,
6125
- n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
6126
- ext_factor, attn_factor, beta_fast, beta_slow
6127
- );
6128
- }
6129
 
6130
- cb(Qcur, "Qcur", il);
6131
- cb(Kcur, "Kcur", il);
6132
- cb(Vcur, "Vcur", il);
6133
 
6134
- cur = build_attn(inp_attn, gf,
6135
- model.layers[il].wo, model.layers[il].bo,
6136
- Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
6137
- cb(cur, "kqv_out", il);
 
6138
 
6139
- if (il == n_layer - 1 && pooling_type == LLAMA_POOLING_TYPE_NONE) {
6140
- // skip computing output for unused tokens
6141
- ggml_tensor * inp_out_ids = build_inp_out_ids();
6142
  cur = ggml_get_rows(ctx0, cur, inp_out_ids);
6143
  inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
6144
  }
@@ -6235,56 +6242,57 @@ struct llm_build_neo_bert : public llm_graph_context {
6235
 
6236
  auto * inp_attn = build_attn_inp_no_cache();
6237
 
6238
- // iterate layers
 
6239
  for (int il = 0; il < n_layer; ++il) {
6240
  ggml_tensor * cur = inpL;
6241
 
6242
- ggml_tensor * Qcur;
6243
- ggml_tensor * Kcur;
6244
- ggml_tensor * Vcur;
6245
-
6246
  // pre-norm
6247
  cur = build_norm(inpL,
6248
  model.layers[il].attn_norm, NULL,
6249
  LLM_NORM_RMS, il);
6250
 
6251
- // self-attention
6252
- cur = build_lora_mm(model.layers[il].wqkv, cur);
6253
- cb(cur, "wqkv", il);
6254
-
6255
- Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)));
6256
- Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)));
6257
- Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)));
6258
-
6259
- Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
6260
- Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
6261
- Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
6262
-
6263
- // RoPE
6264
- Qcur = ggml_rope_ext(
6265
- ctx0, Qcur, inp_pos, nullptr,
6266
- n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
6267
- ext_factor, attn_factor, beta_fast, beta_slow
6268
- );
6269
 
6270
- Kcur = ggml_rope_ext(
6271
- ctx0, Kcur, inp_pos, nullptr,
6272
- n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
6273
- ext_factor, attn_factor, beta_fast, beta_slow
6274
- );
 
 
6275
 
6276
- cb(Qcur, "Qcur", il);
6277
- cb(Kcur, "Kcur", il);
6278
- cb(Vcur, "Vcur", il);
6279
 
6280
- cur = build_attn(inp_attn, gf,
6281
- model.layers[il].wo, nullptr,
6282
- Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
6283
- cb(cur, "kqv_out", il);
 
 
6284
 
6285
- if (il == n_layer - 1 && pooling_type == LLAMA_POOLING_TYPE_NONE) {
6286
- // skip computing output for unused tokens
6287
- ggml_tensor * inp_out_ids = build_inp_out_ids();
6288
  cur = ggml_get_rows(ctx0, cur, inp_out_ids);
6289
  inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
6290
  }
@@ -6349,6 +6357,8 @@ struct llm_build_bloom : public llm_graph_context {
6349
  LLM_NORM, -1);
6350
  cb(inpL, "inp_norm", -1);
6351
 
 
 
6352
  for (int il = 0; il < n_layer; ++il) {
6353
  cur = build_norm(inpL,
6354
  model.layers[il].attn_norm,
@@ -6381,9 +6391,7 @@ struct llm_build_bloom : public llm_graph_context {
6381
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
6382
  }
6383
 
6384
- if (il == n_layer - 1) {
6385
- // skip computing output for unused tokens
6386
- ggml_tensor * inp_out_ids = build_inp_out_ids();
6387
  cur = ggml_get_rows(ctx0, cur, inp_out_ids);
6388
  inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
6389
  }
@@ -6460,6 +6468,8 @@ struct llm_build_mpt : public llm_graph_context {
6460
  cb(inpL, "inpL", -1);
6461
  }
6462
 
 
 
6463
  for (int il = 0; il < n_layer; ++il) {
6464
  ggml_tensor * attn_norm;
6465
 
@@ -6522,9 +6532,7 @@ struct llm_build_mpt : public llm_graph_context {
6522
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
6523
  }
6524
 
6525
- if (il == n_layer - 1) {
6526
- // skip computing output for unused tokens
6527
- ggml_tensor * inp_out_ids = build_inp_out_ids();
6528
  cur = ggml_get_rows(ctx0, cur, inp_out_ids);
6529
  inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
6530
  }
@@ -6593,6 +6601,8 @@ struct llm_build_stablelm : public llm_graph_context {
6593
 
6594
  auto * inp_attn = build_attn_inp_kv_unified();
6595
 
 
 
6596
  for (int il = 0; il < n_layer; ++il) {
6597
  // norm
6598
  cur = build_norm(inpL,
@@ -6668,9 +6678,7 @@ struct llm_build_stablelm : public llm_graph_context {
6668
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
6669
  }
6670
 
6671
- if (il == n_layer - 1) {
6672
- // skip computing output for unused tokens
6673
- ggml_tensor * inp_out_ids = build_inp_out_ids();
6674
  cur = ggml_get_rows(ctx0, cur, inp_out_ids);
6675
  inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
6676
  inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
@@ -6745,6 +6753,8 @@ struct llm_build_qwen : public llm_graph_context {
6745
 
6746
  auto * inp_attn = build_attn_inp_kv_unified();
6747
 
 
 
6748
  for (int il = 0; il < n_layer; ++il) {
6749
  ggml_tensor * inpSA = inpL;
6750
 
@@ -6791,9 +6801,7 @@ struct llm_build_qwen : public llm_graph_context {
6791
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
6792
  }
6793
 
6794
- if (il == n_layer - 1) {
6795
- // skip computing output for unused tokens
6796
- ggml_tensor * inp_out_ids = build_inp_out_ids();
6797
  cur = ggml_get_rows(ctx0, cur, inp_out_ids);
6798
  inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
6799
  }
@@ -6862,6 +6870,8 @@ struct llm_build_qwen2 : public llm_graph_context {
6862
 
6863
  auto * inp_attn = build_attn_inp_kv_unified();
6864
 
 
 
6865
  for (int il = 0; il < n_layer; ++il) {
6866
  ggml_tensor * inpSA = inpL;
6867
 
@@ -6911,9 +6921,7 @@ struct llm_build_qwen2 : public llm_graph_context {
6911
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
6912
  }
6913
 
6914
- if (il == n_layer - 1) {
6915
- // skip computing output for unused tokens
6916
- ggml_tensor * inp_out_ids = build_inp_out_ids();
6917
  cur = ggml_get_rows(ctx0, cur, inp_out_ids);
6918
  inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
6919
  }
@@ -6983,6 +6991,8 @@ struct llm_build_qwen2vl : public llm_graph_context {
6983
  int sections[4];
6984
  std::copy(std::begin(hparams.rope_sections), std::begin(hparams.rope_sections) + 4, sections);
6985
 
 
 
6986
  for (int il = 0; il < n_layer; ++il) {
6987
  ggml_tensor * inpSA = inpL;
6988
 
@@ -7032,9 +7042,7 @@ struct llm_build_qwen2vl : public llm_graph_context {
7032
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
7033
  }
7034
 
7035
- if (il == n_layer - 1) {
7036
- // skip computing output for unused tokens
7037
- ggml_tensor * inp_out_ids = build_inp_out_ids();
7038
  cur = ggml_get_rows(ctx0, cur, inp_out_ids);
7039
  inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
7040
  }
@@ -7101,6 +7109,8 @@ struct llm_build_qwen2moe : public llm_graph_context {
7101
 
7102
  auto * inp_attn = build_attn_inp_kv_unified();
7103
 
 
 
7104
  for (int il = 0; il < n_layer; ++il) {
7105
  ggml_tensor * inpSA = inpL;
7106
 
@@ -7159,9 +7169,7 @@ struct llm_build_qwen2moe : public llm_graph_context {
7159
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
7160
  }
7161
 
7162
- if (il == n_layer - 1) {
7163
- // skip computing output for unused tokens
7164
- ggml_tensor * inp_out_ids = build_inp_out_ids();
7165
  cur = ggml_get_rows(ctx0, cur, inp_out_ids);
7166
  inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
7167
  }
@@ -7260,6 +7268,8 @@ struct llm_build_qwen3 : public llm_graph_context {
7260
 
7261
  auto * inp_attn = build_attn_inp_kv_unified();
7262
 
 
 
7263
  for (int il = 0; il < n_layer; ++il) {
7264
  ggml_tensor * inpSA = inpL;
7265
 
@@ -7312,9 +7322,7 @@ struct llm_build_qwen3 : public llm_graph_context {
7312
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
7313
  }
7314
 
7315
- if (il == n_layer - 1) {
7316
- // skip computing output for unused tokens
7317
- ggml_tensor * inp_out_ids = build_inp_out_ids();
7318
  cur = ggml_get_rows(ctx0, cur, inp_out_ids);
7319
  inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
7320
  }
@@ -7381,6 +7389,8 @@ struct llm_build_qwen3moe : public llm_graph_context {
7381
 
7382
  auto * inp_attn = build_attn_inp_kv_unified();
7383
 
 
 
7384
  for (int il = 0; il < n_layer; ++il) {
7385
  ggml_tensor * inpSA = inpL;
7386
 
@@ -7433,9 +7443,7 @@ struct llm_build_qwen3moe : public llm_graph_context {
7433
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
7434
  }
7435
 
7436
- if (il == n_layer - 1) {
7437
- // skip computing output for unused tokens
7438
- ggml_tensor * inp_out_ids = build_inp_out_ids();
7439
  cur = ggml_get_rows(ctx0, cur, inp_out_ids);
7440
  inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
7441
  }
@@ -7511,6 +7519,8 @@ struct llm_build_phi2 : public llm_graph_context {
7511
 
7512
  auto * inp_attn = build_attn_inp_kv_unified();
7513
 
 
 
7514
  for (int il = 0; il < n_layer; ++il) {
7515
  attn_norm_output = build_norm(inpL,
7516
  model.layers[il].attn_norm,
@@ -7573,9 +7583,7 @@ struct llm_build_phi2 : public llm_graph_context {
7573
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f, il);
7574
  }
7575
 
7576
- if (il == n_layer - 1) {
7577
- // skip computing output for unused tokens
7578
- ggml_tensor * inp_out_ids = build_inp_out_ids();
7579
  cur = ggml_get_rows(ctx0, cur, inp_out_ids);
7580
  inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
7581
  attn_norm_output = ggml_get_rows(ctx0, attn_norm_output, inp_out_ids);
@@ -7647,6 +7655,8 @@ struct llm_build_phi3 : public llm_graph_context {
7647
  inp_attn = build_attn_inp_kv_unified();
7648
  }
7649
 
 
 
7650
  for (int il = 0; il < n_layer; ++il) {
7651
  auto * residual = inpL;
7652
 
@@ -7710,9 +7720,7 @@ struct llm_build_phi3 : public llm_graph_context {
7710
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f, il);
7711
  }
7712
 
7713
- if (il == n_layer - 1) {
7714
- // skip computing output for unused tokens
7715
- ggml_tensor* inp_out_ids = build_inp_out_ids();
7716
  cur = ggml_get_rows(ctx0, cur, inp_out_ids);
7717
  residual = ggml_get_rows(ctx0, residual, inp_out_ids);
7718
  }
@@ -7798,15 +7806,16 @@ struct llm_build_plamo : public llm_graph_context {
7798
 
7799
  auto * inp_attn = build_attn_inp_kv_unified();
7800
 
7801
- for (int il = 0; il < n_layer; ++il) {
7802
 
 
7803
  // norm
7804
  cur = build_norm(inpL,
7805
  model.layers[il].attn_norm, NULL,
7806
  LLM_NORM_RMS, il);
7807
  cb(cur, "attn_norm", il);
7808
 
7809
- ggml_tensor * attention_norm = cur;
7810
 
7811
  // self-attention
7812
  {
@@ -7844,18 +7853,17 @@ struct llm_build_plamo : public llm_graph_context {
7844
  model.layers[il].wo, NULL,
7845
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
7846
  }
7847
- ggml_tensor * sa_out = cur;
7848
-
7849
- cur = attention_norm;
7850
 
7851
- if (il == n_layer - 1) {
7852
- // skip computing output for unused tokens
7853
- ggml_tensor * inp_out_ids = build_inp_out_ids();
7854
  cur = ggml_get_rows(ctx0, cur, inp_out_ids);
7855
- sa_out = ggml_get_rows(ctx0, sa_out, inp_out_ids);
7856
  inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
7857
  }
7858
 
 
 
 
 
7859
  // feed-forward network
7860
  {
7861
  cur = build_ffn(cur,
@@ -7920,6 +7928,8 @@ struct llm_build_gpt2 : public llm_graph_context {
7920
  inpL = ggml_add(ctx0, inpL, pos);
7921
  cb(inpL, "inpL", -1);
7922
 
 
 
7923
  for (int il = 0; il < n_layer; ++il) {
7924
  cur = build_norm(inpL,
7925
  model.layers[il].attn_norm,
@@ -7952,9 +7962,7 @@ struct llm_build_gpt2 : public llm_graph_context {
7952
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
7953
  }
7954
 
7955
- if (il == n_layer - 1) {
7956
- // skip computing output for unused tokens
7957
- ggml_tensor * inp_out_ids = build_inp_out_ids();
7958
  cur = ggml_get_rows(ctx0, cur, inp_out_ids);
7959
  inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
7960
  }
@@ -8024,6 +8032,8 @@ struct llm_build_codeshell : public llm_graph_context {
8024
 
8025
  auto * inp_attn = build_attn_inp_kv_unified();
8026
 
 
 
8027
  for (int il = 0; il < n_layer; ++il) {
8028
  cur = build_norm(inpL,
8029
  model.layers[il].attn_norm,
@@ -8068,9 +8078,7 @@ struct llm_build_codeshell : public llm_graph_context {
8068
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
8069
  }
8070
 
8071
- if (il == n_layer - 1) {
8072
- // skip computing output for unused tokens
8073
- ggml_tensor * inp_out_ids = build_inp_out_ids();
8074
  cur = ggml_get_rows(ctx0, cur, inp_out_ids);
8075
  inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
8076
  }
@@ -8124,128 +8132,128 @@ struct llm_build_codeshell : public llm_graph_context {
8124
 
8125
  struct llm_build_orion : public llm_graph_context {
8126
  llm_build_orion(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
8127
- const int64_t n_embd_head = hparams.n_embd_head_v;
8128
 
8129
- GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
8130
- GGML_ASSERT(n_embd_head == hparams.n_rot);
8131
 
8132
- ggml_tensor * cur;
8133
- ggml_tensor * inpL;
8134
 
8135
- inpL = build_inp_embd(model.tok_embd);
8136
 
8137
- // inp_pos - contains the positions
8138
- ggml_tensor * inp_pos = build_inp_pos();
8139
 
8140
- auto * inp_attn = build_attn_inp_kv_unified();
8141
 
8142
- for (int il = 0; il < n_layer; ++il) {
8143
- ggml_tensor * inpSA = inpL;
8144
 
8145
- // norm
8146
- cur = build_norm(inpL,
8147
- model.layers[il].attn_norm, model.layers[il].attn_norm_b,
8148
- LLM_NORM, il);
8149
- cb(cur, "attn_norm", il);
8150
 
8151
- // self-attention
8152
- {
8153
- // compute Q and K and RoPE them
8154
- ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
8155
- cb(Qcur, "Qcur", il);
8156
- // if (model.layers[il].bq) {
8157
- // Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
8158
- // cb(Qcur, "Qcur", il);
8159
- // }
8160
-
8161
- ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
8162
- cb(Kcur, "Kcur", il);
8163
- // if (model.layers[il].bk) {
8164
- // Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
8165
- // cb(Kcur, "Kcur", il);
8166
- // }
8167
-
8168
- ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
8169
- cb(Vcur, "Vcur", il);
8170
- // if (model.layers[il].bv) {
8171
- // Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
8172
- // cb(Vcur, "Vcur", il);
8173
- // }
8174
-
8175
- Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
8176
- Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
8177
- Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
8178
-
8179
- Qcur = ggml_rope_ext(
8180
- ctx0, Qcur, inp_pos, nullptr,
8181
- n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
8182
- ext_factor, attn_factor, beta_fast, beta_slow
8183
- );
8184
 
8185
- Kcur = ggml_rope_ext(
8186
- ctx0, Kcur, inp_pos, nullptr,
8187
- n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
8188
- ext_factor, attn_factor, beta_fast, beta_slow
8189
- );
8190
 
8191
- cb(Qcur, "Qcur", il);
8192
- cb(Kcur, "Kcur", il);
8193
- cb(Vcur, "Vcur", il);
 
 
 
8194
 
8195
- cur = build_attn(inp_attn, gf,
8196
- model.layers[il].wo, NULL,
8197
- Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
8198
- }
8199
 
8200
- if (il == n_layer - 1) {
8201
- // skip computing output for unused tokens
8202
- ggml_tensor * inp_out_ids = build_inp_out_ids();
8203
- cur = ggml_get_rows(ctx0, cur, inp_out_ids);
8204
- inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
8205
- }
8206
 
8207
- ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
8208
- cb(ffn_inp, "ffn_inp", il);
 
 
 
8209
 
8210
- // feed-forward network
8211
- cur = build_norm(ffn_inp,
8212
- model.layers[il].ffn_norm, model.layers[il].ffn_norm_b,
8213
- LLM_NORM, il);
8214
- cb(cur, "ffn_norm", il);
8215
 
8216
- cur = build_ffn(cur,
8217
- model.layers[il].ffn_up, NULL, NULL,
8218
- model.layers[il].ffn_gate, NULL, NULL,
8219
- model.layers[il].ffn_down, NULL, NULL,
8220
- NULL,
8221
- LLM_FFN_SILU, LLM_FFN_PAR, il);
8222
- cb(cur, "ffn_out", il);
8223
 
8224
- cur = ggml_add(ctx0, cur, ffn_inp);
 
 
 
8225
 
8226
- cur = build_cvec(cur, il);
8227
- cb(cur, "l_out", il);
8228
 
8229
- // input for next layer
8230
- inpL = cur;
8231
- }
8232
 
8233
- cur = inpL;
8234
 
8235
- cur = build_norm(cur,
8236
- model.output_norm, model.output_norm_b,
8237
- LLM_NORM, -1);
8238
 
8239
- cb(cur, "result_norm", -1);
8240
- res->t_embd = cur;
8241
 
8242
- // lm_head
8243
- cur = build_lora_mm(model.output, cur);
8244
 
8245
- cb(cur, "result_output", -1);
8246
- res->t_logits = cur;
8247
 
8248
- ggml_build_forward_expand(gf, cur);
8249
  }
8250
  };
8251
 
@@ -8266,6 +8274,8 @@ struct llm_build_internlm2 : public llm_graph_context {
8266
 
8267
  auto * inp_attn = build_attn_inp_kv_unified();
8268
 
 
 
8269
  for (int il = 0; il < n_layer; ++il) {
8270
  ggml_tensor * inpSA = inpL;
8271
 
@@ -8324,9 +8334,7 @@ struct llm_build_internlm2 : public llm_graph_context {
8324
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
8325
  }
8326
 
8327
- if (il == n_layer - 1) {
8328
- // skip computing output for unused tokens
8329
- ggml_tensor * inp_out_ids = build_inp_out_ids();
8330
  cur = ggml_get_rows(ctx0, cur, inp_out_ids);
8331
  inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
8332
  }
@@ -8402,6 +8410,8 @@ struct llm_build_minicpm3 : public llm_graph_context {
8402
 
8403
  auto * inp_attn = build_attn_inp_kv_unified();
8404
 
 
 
8405
  for (int il = 0; il < n_layer; ++il) {
8406
  ggml_tensor * inpSA = inpL;
8407
 
@@ -8521,15 +8531,13 @@ struct llm_build_minicpm3 : public llm_graph_context {
8521
  q_states, k_states, v_states, nullptr, nullptr, kq_scale, il);
8522
  }
8523
 
8524
- if (il == n_layer - 1) {
8525
- // skip computing output for unused tokens
8526
- ggml_tensor * inp_out_ids = build_inp_out_ids();
8527
  cur = ggml_get_rows(ctx0, cur, inp_out_ids);
8528
  inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
8529
  }
8530
 
8531
  // scale_res - scale the hidden states for residual connection
8532
- const float scale_res = scale_depth/sqrtf(float(n_layer));
8533
  cur = ggml_scale(ctx0, cur, scale_res);
8534
  cb(cur, "hidden_scaled", il);
8535
 
@@ -8606,6 +8614,8 @@ struct llm_build_gemma : public llm_graph_context {
8606
 
8607
  auto * inp_attn = build_attn_inp_kv_unified();
8608
 
 
 
8609
  for (int il = 0; il < n_layer; ++il) {
8610
  // norm
8611
  cur = build_norm(inpL,
@@ -8651,9 +8661,7 @@ struct llm_build_gemma : public llm_graph_context {
8651
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f, il);
8652
  }
8653
 
8654
- if (il == n_layer - 1) {
8655
- // skip computing output for unused tokens
8656
- ggml_tensor * inp_out_ids = build_inp_out_ids();
8657
  cur = ggml_get_rows(ctx0, cur, inp_out_ids);
8658
  inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
8659
  }
@@ -8722,6 +8730,8 @@ struct llm_build_gemma2_iswa : public llm_graph_context {
8722
 
8723
  auto * inp_attn = build_attn_inp_kv_unified_iswa();
8724
 
 
 
8725
  for (int il = 0; il < n_layer; ++il) {
8726
  // norm
8727
  cur = build_norm(inpL,
@@ -8766,18 +8776,16 @@ struct llm_build_gemma2_iswa : public llm_graph_context {
8766
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f, il);
8767
  }
8768
 
 
 
 
 
 
8769
  cur = build_norm(cur,
8770
  model.layers[il].attn_post_norm, NULL,
8771
  LLM_NORM_RMS, il);
8772
  cb(cur, "attn_post_norm", il);
8773
 
8774
- if (il == n_layer - 1) {
8775
- // skip computing output for unused tokens
8776
- ggml_tensor * inp_out_ids = build_inp_out_ids();
8777
- cur = ggml_get_rows(ctx0, cur, inp_out_ids);
8778
- inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
8779
- }
8780
-
8781
  ggml_tensor * sa_out = ggml_add(ctx0, cur, inpL);
8782
  cb(sa_out, "sa_out", il);
8783
 
@@ -8856,6 +8864,8 @@ struct llm_build_gemma3_iswa : public llm_graph_context {
8856
  // TODO: is causal == true correct? might need some changes
8857
  auto * inp_attn = build_attn_inp_kv_unified_iswa();
8858
 
 
 
8859
  for (int il = 0; il < n_layer; ++il) {
8860
  const float freq_base_l = model.get_rope_freq_base (cparams, il);
8861
  const float freq_scale_l = model.get_rope_freq_scale(cparams, il);
@@ -8908,18 +8918,16 @@ struct llm_build_gemma3_iswa : public llm_graph_context {
8908
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f, il);
8909
  }
8910
 
 
 
 
 
 
8911
  cur = build_norm(cur,
8912
  model.layers[il].attn_post_norm, NULL,
8913
  LLM_NORM_RMS, il);
8914
  cb(cur, "attn_post_norm", il);
8915
 
8916
- if (il == n_layer - 1) {
8917
- // skip computing output for unused tokens
8918
- ggml_tensor * inp_out_ids = build_inp_out_ids();
8919
- cur = ggml_get_rows(ctx0, cur, inp_out_ids);
8920
- inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
8921
- }
8922
-
8923
  ggml_tensor * sa_out = ggml_add(ctx0, cur, inpL);
8924
  cb(sa_out, "sa_out", il);
8925
 
@@ -8990,6 +8998,8 @@ struct llm_build_starcoder2 : public llm_graph_context {
8990
 
8991
  auto * inp_attn = build_attn_inp_kv_unified();
8992
 
 
 
8993
  for (int il = 0; il < n_layer; ++il) {
8994
  ggml_tensor * inpSA = inpL;
8995
 
@@ -9048,9 +9058,7 @@ struct llm_build_starcoder2 : public llm_graph_context {
9048
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
9049
  }
9050
 
9051
- if (il == n_layer - 1) {
9052
- // skip computing output for unused tokens
9053
- ggml_tensor * inp_out_ids = build_inp_out_ids();
9054
  cur = ggml_get_rows(ctx0, cur, inp_out_ids);
9055
  inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
9056
  }
@@ -9111,7 +9119,9 @@ struct llm_build_mamba : public llm_graph_context {
9111
  // {n_embd, n_tokens}
9112
  inpL = build_inp_embd(model.tok_embd);
9113
 
9114
- ggml_tensor * state_copy = build_inp_s_copy();
 
 
9115
 
9116
  for (int il = 0; il < n_layer; ++il) {
9117
  // norm
@@ -9120,11 +9130,9 @@ struct llm_build_mamba : public llm_graph_context {
9120
  LLM_NORM_RMS, il);
9121
  cb(cur, "attn_norm", il);
9122
 
9123
- cur = build_mamba_layer(gf, cur, state_copy, ubatch, il);
9124
 
9125
- if (il == n_layer - 1) {
9126
- // skip computing output for unused tokens
9127
- ggml_tensor * inp_out_ids = build_inp_out_ids();
9128
  cur = ggml_get_rows(ctx0, cur, inp_out_ids);
9129
  inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
9130
  }
@@ -9158,12 +9166,12 @@ struct llm_build_mamba : public llm_graph_context {
9158
 
9159
  // TODO: split
9160
  ggml_tensor * build_mamba_layer(
9161
- ggml_cgraph * gf,
9162
- ggml_tensor * cur,
9163
- ggml_tensor * state_copy,
9164
- const llama_ubatch & ubatch,
9165
- int il) const {
9166
- const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
9167
 
9168
  const auto kv_head = kv_state->get_head();
9169
 
@@ -9183,17 +9191,17 @@ struct llm_build_mamba : public llm_graph_context {
9183
  GGML_ASSERT(ubatch.equal_seqs);
9184
  GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs);
9185
 
9186
- ggml_tensor * conv_states_all = kv_state->get_k_l(il);
9187
- ggml_tensor * ssm_states_all = kv_state->get_v_l(il);
9188
 
9189
  // (ab)using the KV cache to store the states
9190
- ggml_tensor * conv = build_recurrent_state(
9191
- gf, conv_states_all, state_copy,
9192
- hparams.n_embd_k_s(), n_seqs);
9193
  conv = ggml_reshape_3d(ctx0, conv, d_conv - 1, d_inner, n_seqs);
9194
- ggml_tensor * ssm = build_recurrent_state(
9195
- gf, ssm_states_all, state_copy,
9196
- hparams.n_embd_v_s(), n_seqs);
9197
  ssm = ggml_reshape_3d(ctx0, ssm, d_state, d_inner, n_seqs);
9198
 
9199
  // {n_embd, n_tokens} => {n_embd, n_seq_tokens, n_seqs}
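Note: in this hunk (and in the analogous rwkv6/rwkv7 time-mix helpers further down), the explicit state_copy tensor and the llama_kv_cache_recurrent_state cast drop out of the builder signatures, which suggests the recurrent conv/ssm states are now obtained through the memory-recurrent context added in this sync. A minimal compilable sketch of that call-shape change follows; the types are stand-ins and the exact new parameter list is an assumption, since it is not visible in this excerpt.

// Stand-in types only; this is not the llama.cpp graph API.
struct ggml_tensor_stub {};
struct llama_ubatch_stub {};

// before: the caller threads a state_copy tensor into every recurrent layer
struct builder_before {
    ggml_tensor_stub * build_mamba_layer(ggml_tensor_stub * cur,
                                         ggml_tensor_stub * /*state_copy*/,
                                         const llama_ubatch_stub &, int) const {
        return cur;
    }
};

// after (assumed): the layer reads its conv/ssm state from the shared
// recurrent-memory context instead of an explicit state_copy argument;
// other parameters may also have changed in the real code
struct builder_after {
    ggml_tensor_stub * build_mamba_layer(ggml_tensor_stub * cur,
                                         const llama_ubatch_stub &, int) const {
        return cur;
    }
};

int main() {
    ggml_tensor_stub cur;
    llama_ubatch_stub ub;
    builder_after b;
    return b.build_mamba_layer(&cur, ub, 0) == &cur ? 0 : 1;
}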
@@ -9306,13 +9314,15 @@ struct llm_build_command_r : public llm_graph_context {
9306
 
9307
  auto * inp_attn = build_attn_inp_kv_unified();
9308
 
9309
- for (int il = 0; il < n_layer; ++il) {
9310
 
 
9311
  // norm
9312
  cur = build_norm(inpL,
9313
  model.layers[il].attn_norm, NULL,
9314
  LLM_NORM, il);
9315
  cb(cur, "attn_norm", il);
 
9316
  ggml_tensor * ffn_inp = cur;
9317
 
9318
  // self-attention
@@ -9380,9 +9390,7 @@ struct llm_build_command_r : public llm_graph_context {
9380
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
9381
  }
9382
 
9383
- if (il == n_layer - 1) {
9384
- // skip computing output for unused tokens
9385
- ggml_tensor * inp_out_ids = build_inp_out_ids();
9386
  cur = ggml_get_rows(ctx0, cur, inp_out_ids);
9387
  inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
9388
  ffn_inp = ggml_get_rows(ctx0, ffn_inp, inp_out_ids);
@@ -9453,6 +9461,8 @@ struct llm_build_cohere2_iswa : public llm_graph_context {
9453
 
9454
  auto * inp_attn = build_attn_inp_kv_unified_iswa();
9455
 
 
 
9456
  for (int il = 0; il < n_layer; ++il) {
9457
  const bool is_swa = hparams.is_swa(il);
9458
 
@@ -9515,9 +9525,7 @@ struct llm_build_cohere2_iswa : public llm_graph_context {
9515
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
9516
  }
9517
 
9518
- if (il == n_layer - 1) {
9519
- // skip computing output for unused tokens
9520
- ggml_tensor * inp_out_ids = build_inp_out_ids();
9521
  cur = ggml_get_rows(ctx0, cur, inp_out_ids);
9522
  inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
9523
  ffn_inp = ggml_get_rows(ctx0, ffn_inp, inp_out_ids);
@@ -9588,6 +9596,8 @@ struct llm_build_olmo : public llm_graph_context {
9588
 
9589
  auto * inp_attn = build_attn_inp_kv_unified();
9590
 
 
 
9591
  for (int il = 0; il < n_layer; ++il) {
9592
  ggml_tensor * inpSA = inpL;
9593
 
@@ -9646,9 +9656,7 @@ struct llm_build_olmo : public llm_graph_context {
9646
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
9647
  }
9648
 
9649
- if (il == n_layer - 1) {
9650
- // skip computing output for unused tokens
9651
- ggml_tensor * inp_out_ids = build_inp_out_ids();
9652
  cur = ggml_get_rows(ctx0, cur, inp_out_ids);
9653
  inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
9654
  }
@@ -9716,6 +9724,8 @@ struct llm_build_olmo2 : public llm_graph_context {
9716
 
9717
  auto * inp_attn = build_attn_inp_kv_unified();
9718
 
 
 
9719
  for (int il = 0; il < n_layer; ++il) {
9720
  ggml_tensor * inpSA = inpL;
9721
 
@@ -9766,18 +9776,16 @@ struct llm_build_olmo2 : public llm_graph_context {
9766
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
9767
  }
9768
 
 
 
 
 
 
9769
  cur = build_norm(cur,
9770
  model.layers[il].attn_post_norm, NULL,
9771
  LLM_NORM_RMS, il);
9772
  cb(cur, "attn_post_norm", il);
9773
 
9774
- if (il == n_layer - 1) {
9775
- // skip computing output for unused tokens
9776
- ggml_tensor * inp_out_ids = build_inp_out_ids();
9777
- cur = ggml_get_rows(ctx0, cur, inp_out_ids);
9778
- inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
9779
- }
9780
-
9781
  ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
9782
  cb(ffn_inp, "ffn_inp", il);
9783
 
@@ -9845,6 +9853,8 @@ struct llm_build_olmoe : public llm_graph_context {
9845
 
9846
  auto * inp_attn = build_attn_inp_kv_unified();
9847
 
 
 
9848
  for (int il = 0; il < n_layer; ++il) {
9849
  ggml_tensor * inpSA = inpL;
9850
 
@@ -9899,9 +9909,7 @@ struct llm_build_olmoe : public llm_graph_context {
9899
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
9900
  }
9901
 
9902
- if (il == n_layer - 1) {
9903
- // skip computing output for unused tokens
9904
- ggml_tensor * inp_out_ids = build_inp_out_ids();
9905
  cur = ggml_get_rows(ctx0, cur, inp_out_ids);
9906
  inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
9907
  }
@@ -9971,6 +9979,8 @@ struct llm_build_openelm : public llm_graph_context {
9971
 
9972
  auto * inp_attn = build_attn_inp_kv_unified();
9973
 
 
 
9974
  for (int il = 0; il < n_layer; ++il) {
9975
  const int64_t n_head = hparams.n_head(il);
9976
  const int64_t n_head_kv = hparams.n_head_kv(il);
@@ -10032,11 +10042,9 @@ struct llm_build_openelm : public llm_graph_context {
10032
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
10033
  }
10034
 
10035
- if (il == n_layer - 1) {
10036
- // skip computing output for unused tokens
10037
- ggml_tensor * inp_out_ids = build_inp_out_ids();
10038
  residual = ggml_get_rows(ctx0, residual, inp_out_ids);
10039
- cur = ggml_get_rows(ctx0, cur, inp_out_ids);
10040
  }
10041
 
10042
  ggml_tensor * ffn_inp = ggml_add(ctx0, residual, cur);
@@ -10102,6 +10110,8 @@ struct llm_build_gptneox : public llm_graph_context {
10102
 
10103
  auto * inp_attn = build_attn_inp_kv_unified();
10104
 
 
 
10105
  for (int il = 0; il < n_layer; ++il) {
10106
  cur = build_norm(inpL,
10107
  model.layers[il].attn_norm,
@@ -10146,9 +10156,7 @@ struct llm_build_gptneox : public llm_graph_context {
10146
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
10147
  }
10148
 
10149
- if (il == n_layer - 1) {
10150
- // skip computing output for unused tokens
10151
- ggml_tensor * inp_out_ids = build_inp_out_ids();
10152
  cur = ggml_get_rows(ctx0, cur, inp_out_ids);
10153
  inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
10154
  }
@@ -10250,6 +10258,8 @@ struct llm_build_arctic : public llm_graph_context {
10250
 
10251
  auto * inp_attn = build_attn_inp_kv_unified();
10252
 
 
 
10253
  for (int il = 0; il < n_layer; ++il) {
10254
  ggml_tensor * inpSA = inpL;
10255
 
@@ -10296,9 +10306,7 @@ struct llm_build_arctic : public llm_graph_context {
10296
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
10297
  }
10298
 
10299
- if (il == n_layer - 1) {
10300
- // skip computing output for unused tokens
10301
- ggml_tensor * inp_out_ids = build_inp_out_ids();
10302
  cur = ggml_get_rows(ctx0, cur, inp_out_ids);
10303
  inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
10304
  }
@@ -10390,6 +10398,8 @@ struct llm_build_deepseek : public llm_graph_context {
10390
 
10391
  const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale;
10392
 
 
 
10393
  for (int il = 0; il < n_layer; ++il) {
10394
  ggml_tensor * inpSA = inpL;
10395
 
@@ -10451,14 +10461,11 @@ struct llm_build_deepseek : public llm_graph_context {
10451
  Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il);
10452
  }
10453
 
10454
- if (il == n_layer - 1) {
10455
- // skip computing output for unused tokens
10456
- ggml_tensor * inp_out_ids = build_inp_out_ids();
10457
  cur = ggml_get_rows(ctx0, cur, inp_out_ids);
10458
  inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
10459
  }
10460
 
10461
-
10462
  ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
10463
  cb(ffn_inp, "ffn_inp", il);
10464
 
@@ -10566,6 +10573,8 @@ struct llm_build_deepseek2 : public llm_graph_context {
10566
 
10567
  auto * inp_attn = build_attn_inp_kv_unified();
10568
 
 
 
10569
  for (int il = 0; il < n_layer; ++il) {
10570
  ggml_tensor * inpSA = inpL;
10571
 
@@ -10715,9 +10724,7 @@ struct llm_build_deepseek2 : public llm_graph_context {
10715
  }
10716
  }
10717
 
10718
- if (il == n_layer - 1) {
10719
- // skip computing output for unused tokens
10720
- ggml_tensor * inp_out_ids = build_inp_out_ids();
10721
  cur = ggml_get_rows(ctx0, cur, inp_out_ids);
10722
  inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
10723
  }
@@ -10813,6 +10820,8 @@ struct llm_build_bitnet : public llm_graph_context {
10813
 
10814
  auto * inp_attn = build_attn_inp_kv_unified();
10815
 
 
 
10816
  for (int il = 0; il < n_layer; ++il) {
10817
  ggml_tensor * inpSA = inpL;
10818
 
@@ -10895,9 +10904,7 @@ struct llm_build_bitnet : public llm_graph_context {
10895
  cb(cur, "attn_o_out", il);
10896
  }
10897
 
10898
- if (il == n_layer - 1) {
10899
- // skip computing output for unused tokens
10900
- ggml_tensor * inp_out_ids = build_inp_out_ids();
10901
  cur = ggml_get_rows(ctx0, cur, inp_out_ids);
10902
  inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
10903
  }
@@ -10972,6 +10979,8 @@ struct llm_build_t5_enc : public llm_graph_context {
10972
 
10973
  auto * inp_attn = build_attn_inp_no_cache();
10974
 
 
 
10975
  for (int il = 0; il < n_layer; ++il) {
10976
  ggml_tensor * inpSA = inpL;
10977
 
@@ -11005,9 +11014,7 @@ struct llm_build_t5_enc : public llm_graph_context {
11005
  cb(cur, "kqv_out", il);
11006
  }
11007
 
11008
- if (il == n_layer - 1) {
11009
- // skip computing output for unused tokens
11010
- ggml_tensor * inp_out_ids = build_inp_out_ids();
11011
  cur = ggml_get_rows(ctx0, cur, inp_out_ids);
11012
  inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
11013
  }
@@ -11078,6 +11085,8 @@ struct llm_build_t5_dec : public llm_graph_context {
11078
  auto * inp_attn_self = build_attn_inp_kv_unified();
11079
  auto * inp_attn_cross = build_attn_inp_cross();
11080
 
 
 
11081
  for (int il = 0; il < n_layer; ++il) {
11082
  ggml_tensor * inpSA = inpL;
11083
 
@@ -11169,11 +11178,8 @@ struct llm_build_t5_dec : public llm_graph_context {
11169
  //cb(cur, "kqv_out", il);
11170
  }
11171
 
11172
- if (il == n_layer - 1) {
11173
- // skip computing output for unused tokens
11174
- ggml_tensor * inp_out_ids = build_inp_out_ids();
11175
  cur = ggml_get_rows(ctx0, cur, inp_out_ids);
11176
- inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
11177
  inpCA = ggml_get_rows(ctx0, inpCA, inp_out_ids);
11178
  }
11179
 
@@ -11243,6 +11249,8 @@ struct llm_build_jais : public llm_graph_context {
11243
 
11244
  auto * inp_attn = build_attn_inp_kv_unified();
11245
 
 
 
11246
  for (int il = 0; il < n_layer; ++il) {
11247
  cur = build_norm(inpL,
11248
  model.layers[il].attn_norm,
@@ -11275,9 +11283,7 @@ struct llm_build_jais : public llm_graph_context {
11275
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/float(n_embd_head), il);
11276
  }
11277
 
11278
- if (il == n_layer - 1) {
11279
- // skip computing output for unused tokens
11280
- ggml_tensor * inp_out_ids = build_inp_out_ids();
11281
  cur = ggml_get_rows(ctx0, cur, inp_out_ids);
11282
  inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
11283
  }
@@ -11341,6 +11347,8 @@ struct llm_build_chatglm : public llm_graph_context {
11341
 
11342
  auto * inp_attn = build_attn_inp_kv_unified();
11343
 
 
 
11344
  for (int il = 0; il < n_layer; ++il) {
11345
  ggml_tensor * inpSA = inpL;
11346
 
@@ -11407,9 +11415,7 @@ struct llm_build_chatglm : public llm_graph_context {
11407
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
11408
  }
11409
 
11410
- if (il == n_layer - 1) {
11411
- // skip computing output for unused tokens
11412
- ggml_tensor * inp_out_ids = build_inp_out_ids();
11413
  cur = ggml_get_rows(ctx0, cur, inp_out_ids);
11414
  inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
11415
  }
@@ -11474,6 +11480,8 @@ struct llm_build_glm4 : public llm_graph_context {
11474
 
11475
  auto * inp_attn = build_attn_inp_kv_unified();
11476
 
 
 
11477
  for (int il = 0; il < n_layer; ++il) {
11478
  ggml_tensor * inpSA = inpL;
11479
 
@@ -11540,9 +11548,7 @@ struct llm_build_glm4 : public llm_graph_context {
11540
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
11541
  }
11542
 
11543
- if (il == n_layer - 1) {
11544
- // skip computing output for unused tokens
11545
- ggml_tensor * inp_out_ids = build_inp_out_ids();
11546
  cur = ggml_get_rows(ctx0, cur, inp_out_ids);
11547
  inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
11548
  }
@@ -11625,6 +11631,8 @@ struct llm_build_nemotron : public llm_graph_context {
11625
 
11626
  auto * inp_attn = build_attn_inp_kv_unified();
11627
 
 
 
11628
  for (int il = 0; il < n_layer; ++il) {
11629
  ggml_tensor * inpSA = inpL;
11630
 
@@ -11684,9 +11692,7 @@ struct llm_build_nemotron : public llm_graph_context {
11684
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
11685
  }
11686
 
11687
- if (il == n_layer - 1) {
11688
- // skip computing output for unused tokens
11689
- ggml_tensor * inp_out_ids = build_inp_out_ids();
11690
  cur = ggml_get_rows(ctx0, cur, inp_out_ids);
11691
  inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
11692
  }
@@ -11754,6 +11760,8 @@ struct llm_build_exaone : public llm_graph_context {
11754
 
11755
  auto * inp_attn = build_attn_inp_kv_unified();
11756
 
 
 
11757
  for (int il = 0; il < n_layer; ++il) {
11758
  ggml_tensor * inpSA = inpL;
11759
 
@@ -11815,9 +11823,7 @@ struct llm_build_exaone : public llm_graph_context {
11815
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
11816
  }
11817
 
11818
- if (il == n_layer - 1) {
11819
- // skip computing output for unused tokens
11820
- ggml_tensor * inp_out_ids = build_inp_out_ids();
11821
  cur = ggml_get_rows(ctx0, cur, inp_out_ids);
11822
  inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
11823
  }
@@ -11904,13 +11910,13 @@ struct llm_build_rwkv6_base : public llm_graph_context {
11904
  }
11905
 
11906
  ggml_tensor * build_rwkv6_time_mix(
 
11907
  ggml_cgraph * gf,
11908
  ggml_tensor * cur,
11909
  ggml_tensor * x_prev,
11910
- ggml_tensor * state_copy,
11911
  const llama_ubatch & ubatch,
11912
  int il) const {
11913
- const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
11914
 
11915
  const auto n_tokens = ubatch.n_tokens;
11916
  const auto n_seqs = ubatch.n_seqs;
@@ -12031,9 +12037,9 @@ struct llm_build_rwkv6_base : public llm_graph_context {
12031
  k = ggml_sub(ctx0, k, ggml_mul(ctx0, k, w));
12032
  }
12033
 
12034
- ggml_tensor * wkv_state = build_recurrent_state(
12035
- gf, kv_state->get_v_l(il), state_copy,
12036
- hparams.n_embd_v_s(), n_seqs);
12037
 
12038
  ggml_tensor * wkv_output;
12039
  if (is_qrwkv) {
@@ -12051,9 +12057,9 @@ struct llm_build_rwkv6_base : public llm_graph_context {
12051
  wkv_state,
12052
  ggml_view_1d(
12053
  ctx0,
12054
- kv_state->get_v_l(il),
12055
- hparams.n_embd_v_s() * n_seqs,
12056
- hparams.n_embd_v_s() * kv_head * ggml_element_size(kv_state->get_v_l(il))
12057
  )
12058
  )
12059
  );
@@ -12087,19 +12093,19 @@ struct llm_build_rwkv6 : public llm_build_rwkv6_base {
12087
  inpL = build_inp_embd(model.tok_embd);
12088
  inpL = build_norm(inpL, model.tok_norm, model.tok_norm_b, LLM_NORM, -1);
12089
 
12090
- ggml_tensor * state_copy = build_inp_s_copy();
12091
 
12092
  const auto n_embd = hparams.n_embd;
12093
  const auto n_seq_tokens = ubatch.n_seq_tokens;
12094
  const auto n_seqs = ubatch.n_seqs;
12095
 
 
 
12096
  for (int il = 0; il < n_layer; ++il) {
12097
  const llama_layer * layer = &model.layers[il];
12098
  inpL = ggml_reshape_3d(ctx0, inpL, n_embd, n_seq_tokens, n_seqs);
12099
 
12100
- ggml_tensor * token_shift = build_rwkv_token_shift_load(
12101
- gf, state_copy, ubatch, il
12102
- );
12103
 
12104
  ggml_tensor * att_shift = ggml_view_3d(ctx0, token_shift, n_embd, 1, n_seqs, token_shift->nb[1], token_shift->nb[2], 0);
12105
  ggml_tensor * ffn_shift = ggml_view_3d(ctx0, token_shift, n_embd, 1, n_seqs, token_shift->nb[1], token_shift->nb[2], n_embd * ggml_element_size(token_shift));
@@ -12114,7 +12120,7 @@ struct llm_build_rwkv6 : public llm_build_rwkv6_base {
12114
  1
12115
  );
12116
 
12117
- cur = build_rwkv6_time_mix(gf, att_norm, x_prev, state_copy, ubatch, il);
12118
 
12119
  ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL);
12120
  cb(ffn_inp, "ffn_inp", il);
@@ -12136,13 +12142,16 @@ struct llm_build_rwkv6 : public llm_build_rwkv6_base {
12136
  );
12137
  ggml_build_forward_expand(gf, build_rwkv_token_shift_store(token_shift, ubatch, il));
12138
 
12139
- if (il == n_layer - 1) {
12140
- // skip computing output for unused tokens
12141
- struct ggml_tensor * inp_out_ids = build_inp_out_ids();
12142
- ffn_inp = ggml_get_rows(ctx0, ggml_reshape_2d(ctx0, ffn_inp, n_embd, n_tokens), inp_out_ids);
12143
- ffn_norm = ggml_get_rows(ctx0, ggml_reshape_2d(ctx0, ffn_norm, n_embd, n_tokens), inp_out_ids);
12144
- x_prev = ggml_get_rows(ctx0, ggml_reshape_2d(ctx0, x_prev, n_embd, n_tokens), inp_out_ids);
12145
- cur = ggml_get_rows(ctx0, ggml_reshape_2d(ctx0, cur, n_embd, n_tokens), inp_out_ids);
 
 
 
12146
  }
12147
 
12148
  cur = build_rwkv6_channel_mix(layer, ffn_norm, x_prev, LLM_ARCH_RWKV6);
@@ -12177,26 +12186,26 @@ struct llm_build_rwkv6 : public llm_build_rwkv6_base {
12177
  // ref: https://huggingface.co/recursal/QRWKV6-32B-Instruct-Preview-v0.1/blob/main/modeling_rwkv6qwen2.py
12178
  struct llm_build_rwkv6qwen2 : public llm_build_rwkv6_base {
12179
  llm_build_rwkv6qwen2(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_build_rwkv6_base(model, params) {
12180
- GGML_ASSERT(n_embd == hparams.n_embd_k_s());
12181
 
12182
  ggml_tensor * cur;
12183
  ggml_tensor * inpL;
12184
 
12185
  inpL = build_inp_embd(model.tok_embd);
12186
 
12187
- ggml_tensor * state_copy = build_inp_s_copy();
12188
 
12189
  const auto n_embd = hparams.n_embd;
12190
  const auto n_seq_tokens = ubatch.n_seq_tokens;
12191
  const auto n_seqs = ubatch.n_seqs;
12192
 
 
 
12193
  for (int il = 0; il < n_layer; ++il) {
12194
  const llama_layer * layer = &model.layers[il];
12195
  inpL = ggml_reshape_3d(ctx0, inpL, n_embd, n_seq_tokens, n_seqs);
12196
 
12197
- ggml_tensor * token_shift = build_rwkv_token_shift_load(
12198
- gf, state_copy, ubatch, il
12199
- );
12200
 
12201
  ggml_tensor * att_norm = build_norm(inpL, layer->attn_norm, layer->attn_norm_b, LLM_NORM_RMS, il);
12202
  cb(att_norm, "attn_norm", il);
@@ -12208,7 +12217,7 @@ struct llm_build_rwkv6qwen2 : public llm_build_rwkv6_base {
12208
  1
12209
  );
12210
 
12211
- cur = build_rwkv6_time_mix(gf, att_norm, x_prev, state_copy, ubatch, il);
12212
 
12213
  token_shift = ggml_view_3d(ctx0, att_norm, n_embd, 1, n_seqs, att_norm->nb[1], att_norm->nb[2], (n_seq_tokens-1)*n_embd*ggml_element_size(att_norm));
12214
  ggml_build_forward_expand(gf, build_rwkv_token_shift_store(token_shift, ubatch, il));
@@ -12216,11 +12225,12 @@ struct llm_build_rwkv6qwen2 : public llm_build_rwkv6_base {
12216
  ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL);
12217
  cb(ffn_inp, "ffn_inp", il);
12218
 
12219
- if (il == n_layer - 1) {
12220
- // skip computing output for unused tokens
12221
- struct ggml_tensor * inp_out_ids = build_inp_out_ids();
12222
- cur = ggml_get_rows(ctx0, ggml_reshape_2d(ctx0, cur, n_embd, n_tokens), inp_out_ids);
12223
- ffn_inp = ggml_get_rows(ctx0, ggml_reshape_2d(ctx0, ffn_inp, n_embd, n_tokens), inp_out_ids);
 
12224
  }
12225
 
12226
  // feed-forward network
@@ -12296,14 +12306,14 @@ struct llm_build_rwkv7_base : public llm_graph_context {
12296
  }
12297
 
12298
  ggml_tensor * build_rwkv7_time_mix(
 
12299
  ggml_cgraph * gf,
12300
  ggml_tensor * cur,
12301
  ggml_tensor * x_prev,
12302
- ggml_tensor * state_copy,
12303
  ggml_tensor *& first_layer_value,
12304
  const llama_ubatch & ubatch,
12305
  int il) const {
12306
- const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
12307
 
12308
  const auto n_tokens = ubatch.n_tokens;
12309
  const auto n_seqs = ubatch.n_seqs;
@@ -12382,9 +12392,9 @@ struct llm_build_rwkv7_base : public llm_graph_context {
12382
  v = ggml_reshape_3d(ctx0, v, head_size, head_count, n_tokens);
12383
  a = ggml_reshape_3d(ctx0, a, head_size, head_count, n_tokens);
12384
 
12385
- ggml_tensor * wkv_state = build_recurrent_state(
12386
- gf, kv_state->get_v_l(il), state_copy,
12387
- hparams.n_embd_v_s(), n_seqs);
12388
 
12389
  ggml_tensor * wkv_output = ggml_rwkv_wkv7(ctx0, r, w, k, v, ggml_neg(ctx0, kk), ggml_mul(ctx0, kk, a), wkv_state);
12390
  cur = ggml_view_1d(ctx0, wkv_output, n_embd * n_tokens, 0);
@@ -12397,9 +12407,9 @@ struct llm_build_rwkv7_base : public llm_graph_context {
12397
  wkv_state,
12398
  ggml_view_1d(
12399
  ctx0,
12400
- kv_state->get_v_l(il),
12401
- hparams.n_embd_v_s() * n_seqs,
12402
- hparams.n_embd_v_s() * kv_head * ggml_element_size(kv_state->get_v_l(il))
12403
  )
12404
  )
12405
  );
@@ -12440,19 +12450,19 @@ struct llm_build_rwkv7 : public llm_build_rwkv7_base {
12440
  inpL = build_inp_embd(model.tok_embd);
12441
  inpL = build_norm(inpL, model.tok_norm, model.tok_norm_b, LLM_NORM, -1);
12442
 
12443
- ggml_tensor * state_copy = build_inp_s_copy();
12444
 
12445
  const auto n_embd = hparams.n_embd;
12446
  const auto n_seq_tokens = ubatch.n_seq_tokens;
12447
  const auto n_seqs = ubatch.n_seqs;
12448
 
 
 
12449
  for (int il = 0; il < n_layer; ++il) {
12450
  const llama_layer * layer = &model.layers[il];
12451
  inpL = ggml_reshape_3d(ctx0, inpL, n_embd, n_seq_tokens, n_seqs);
12452
 
12453
- ggml_tensor * token_shift = build_rwkv_token_shift_load(
12454
- gf, state_copy, ubatch, il
12455
- );
12456
 
12457
  ggml_tensor * att_shift = ggml_view_3d(ctx0, token_shift, n_embd, 1, n_seqs, token_shift->nb[1], token_shift->nb[2], 0);
12458
  ggml_tensor * ffn_shift = ggml_view_3d(ctx0, token_shift, n_embd, 1, n_seqs, token_shift->nb[1], token_shift->nb[2], n_embd * ggml_element_size(token_shift));
@@ -12467,7 +12477,7 @@ struct llm_build_rwkv7 : public llm_build_rwkv7_base {
12467
  1
12468
  );
12469
 
12470
- cur = build_rwkv7_time_mix(gf, att_norm, x_prev, state_copy, v_first, ubatch, il);
12471
 
12472
  ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL);
12473
  cb(ffn_inp, "ffn_inp", il);
@@ -12489,12 +12499,14 @@ struct llm_build_rwkv7 : public llm_build_rwkv7_base {
12489
  );
12490
  ggml_build_forward_expand(gf, build_rwkv_token_shift_store(token_shift, ubatch, il));
12491
 
12492
- if (il == n_layer - 1) {
12493
- // skip computing output for unused tokens
12494
- struct ggml_tensor * inp_out_ids = build_inp_out_ids();
12495
- ffn_inp = ggml_get_rows(ctx0, ggml_reshape_2d(ctx0, ffn_inp, n_embd, n_tokens), inp_out_ids);
12496
- ffn_norm = ggml_get_rows(ctx0, ggml_reshape_2d(ctx0, ffn_norm, n_embd, n_tokens), inp_out_ids);
12497
- x_prev = ggml_get_rows(ctx0, ggml_reshape_2d(ctx0, x_prev, n_embd, n_tokens), inp_out_ids);
 
 
12498
  }
12499
 
12500
  cur = build_rwkv7_channel_mix(layer, ffn_norm, x_prev, LLM_ARCH_RWKV7);
@@ -12525,7 +12537,7 @@ struct llm_build_rwkv7 : public llm_build_rwkv7_base {
12525
 
12526
  struct llm_build_arwkv7 : public llm_build_rwkv7_base {
12527
  llm_build_arwkv7(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_build_rwkv7_base(model, params) {
12528
- GGML_ASSERT(n_embd == hparams.n_embd_k_s());
12529
 
12530
  ggml_tensor * cur;
12531
  ggml_tensor * inpL;
@@ -12533,19 +12545,19 @@ struct llm_build_arwkv7 : public llm_build_rwkv7_base {
12533
 
12534
  inpL = build_inp_embd(model.tok_embd);
12535
 
12536
- ggml_tensor * state_copy = build_inp_s_copy();
12537
 
12538
  const auto n_embd = hparams.n_embd;
12539
  const auto n_seq_tokens = ubatch.n_seq_tokens;
12540
  const auto n_seqs = ubatch.n_seqs;
12541
 
 
 
12542
  for (int il = 0; il < n_layer; ++il) {
12543
  const llama_layer * layer = &model.layers[il];
12544
  inpL = ggml_reshape_3d(ctx0, inpL, n_embd, n_seq_tokens, n_seqs);
12545
 
12546
- ggml_tensor * token_shift = build_rwkv_token_shift_load(
12547
- gf, state_copy, ubatch, il
12548
- );
12549
 
12550
  ggml_tensor * att_norm = build_norm(inpL, layer->attn_norm, layer->attn_norm_b, LLM_NORM_RMS, il);
12551
  cb(att_norm, "attn_norm", il);
@@ -12557,7 +12569,7 @@ struct llm_build_arwkv7 : public llm_build_rwkv7_base {
12557
  1
12558
  );
12559
 
12560
- cur = build_rwkv7_time_mix(gf, att_norm, x_prev, state_copy, v_first, ubatch, il);
12561
 
12562
  token_shift = ggml_view_3d(ctx0, att_norm, n_embd, 1, n_seqs, att_norm->nb[1], att_norm->nb[2], (n_seq_tokens-1)*n_embd*ggml_element_size(att_norm));
12563
  ggml_build_forward_expand(gf, build_rwkv_token_shift_store(token_shift, ubatch, il));
@@ -12565,11 +12577,12 @@ struct llm_build_arwkv7 : public llm_build_rwkv7_base {
12565
  ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL);
12566
  cb(ffn_inp, "ffn_inp", il);
12567
 
12568
- if (il == n_layer - 1) {
12569
- // skip computing output for unused tokens
12570
- struct ggml_tensor * inp_out_ids = build_inp_out_ids();
12571
- cur = ggml_get_rows(ctx0, ggml_reshape_2d(ctx0, cur, n_embd, n_tokens), inp_out_ids);
12572
- ffn_inp = ggml_get_rows(ctx0, ggml_reshape_2d(ctx0, ffn_inp, n_embd, n_tokens), inp_out_ids);
 
12573
  }
12574
 
12575
  // feed-forward network
@@ -12638,6 +12651,9 @@ struct llm_build_granite : public llm_graph_context {
12638
  auto * inp_attn = build_attn_inp_kv_unified();
12639
 
12640
  const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale;
 
 
 
12641
  for (int il = 0; il < n_layer; ++il) {
12642
  ggml_tensor * inpSA = inpL;
12643
 
@@ -12700,9 +12716,7 @@ struct llm_build_granite : public llm_graph_context {
12700
  cb(cur, "attn_out", il);
12701
  }
12702
 
12703
- if (il == n_layer - 1) {
12704
- // skip computing output for unused tokens
12705
- ggml_tensor * inp_out_ids = build_inp_out_ids();
12706
  cur = ggml_get_rows(ctx0, cur, inp_out_ids);
12707
  inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
12708
  }
@@ -12821,6 +12835,8 @@ struct llm_build_chameleon : public llm_graph_context {
12821
 
12822
  auto * inp_attn = build_attn_inp_kv_unified();
12823
 
 
 
12824
  for (int il = 0; il < n_layer; ++il) {
12825
  ggml_tensor * inpSA = inpL;
12826
 
@@ -12897,21 +12913,19 @@ struct llm_build_chameleon : public llm_graph_context {
12897
  cur = build_attn(inp_attn, gf,
12898
  model.layers[il].wo, nullptr,
12899
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
12900
-
12901
- if (hparams.swin_norm) {
12902
- cur = build_norm(cur,
12903
- model.layers[il].attn_norm, NULL,
12904
- LLM_NORM_RMS, il);
12905
- }
12906
  }
12907
 
12908
- if (il == n_layer - 1) {
12909
- // skip computing output for unused tokens
12910
- ggml_tensor * inp_out_ids = build_inp_out_ids();
12911
  cur = ggml_get_rows(ctx0, cur, inp_out_ids);
12912
  inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
12913
  }
12914
 
 
 
 
 
 
 
12915
  ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
12916
  cb(ffn_inp, "ffn_inp", il);
12917
 
@@ -13152,6 +13166,8 @@ struct llm_build_plm : public llm_graph_context {
13152
 
13153
  auto * inp_attn = build_attn_inp_kv_unified();
13154
 
 
 
13155
  for (int il = 0; il < n_layer; ++il) {
13156
  ggml_tensor * inpSA = inpL;
13157
 
@@ -13255,9 +13271,7 @@ struct llm_build_plm : public llm_graph_context {
13255
  q_states, k_states, v_states, nullptr, nullptr, kq_scale, il);
13256
  }
13257
 
13258
- if (il == n_layer - 1) {
13259
- // skip computing output for unused tokens
13260
- ggml_tensor * inp_out_ids = build_inp_out_ids();
13261
  cur = ggml_get_rows(ctx0, cur, inp_out_ids);
13262
  inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
13263
  }
@@ -13317,6 +13331,8 @@ struct llm_build_bailingmoe : public llm_graph_context {
13317
 
13318
  auto * inp_attn = build_attn_inp_kv_unified();
13319
 
 
 
13320
  for (int il = 0; il < n_layer; ++il) {
13321
  ggml_tensor * inpSA = inpL;
13322
 
@@ -13378,9 +13394,7 @@ struct llm_build_bailingmoe : public llm_graph_context {
13378
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_rot)), il);
13379
  }
13380
 
13381
- if (il == n_layer - 1) {
13382
- // skip computing output for unused tokens
13383
- ggml_tensor * inp_out_ids = build_inp_out_ids();
13384
  cur = ggml_get_rows(ctx0, cur, inp_out_ids);
13385
  inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
13386
  }
@@ -13466,6 +13480,8 @@ struct llm_build_dots1 : public llm_graph_context {
13466
 
13467
  auto * inp_attn = build_attn_inp_kv_unified();
13468
 
 
 
13469
  for (int il = 0; il < n_layer; ++il) {
13470
  ggml_tensor * inpSA = inpL;
13471
 
@@ -13518,9 +13534,7 @@ struct llm_build_dots1 : public llm_graph_context {
13518
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
13519
  }
13520
 
13521
- if (il == n_layer - 1) {
13522
- // skip computing output for unused tokens
13523
- ggml_tensor * inp_out_ids = build_inp_out_ids();
13524
  cur = ggml_get_rows(ctx0, cur, inp_out_ids);
13525
  inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
13526
  }
@@ -13618,6 +13632,8 @@ struct llm_build_arcee : public llm_graph_context {
13618
 
13619
  const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale;
13620
 
 
 
13621
  for (int il = 0; il < n_layer; ++il) {
13622
  ggml_tensor * inpSA = inpL;
13623
 
@@ -13680,9 +13696,7 @@ struct llm_build_arcee : public llm_graph_context {
13680
  cb(cur, "attn_out", il);
13681
  }
13682
 
13683
- if (il == n_layer - 1) {
13684
- // skip computing output for unused tokens
13685
- ggml_tensor * inp_out_ids = build_inp_out_ids();
13686
  cur = ggml_get_rows(ctx0, cur, inp_out_ids);
13687
  inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
13688
  }
@@ -13738,6 +13752,8 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
13738
  llama_memory_i * res;
13739
 
13740
  switch (arch) {
 
 
13741
  case LLM_ARCH_BERT:
13742
  case LLM_ARCH_JINA_BERT_V2:
13743
  case LLM_ARCH_NOMIC_BERT:
@@ -13747,57 +13763,75 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
13747
  {
13748
  res = nullptr;
13749
  } break;
13750
- case LLM_ARCH_MAMBA:
13751
- case LLM_ARCH_RWKV6:
13752
- case LLM_ARCH_RWKV6QWEN2:
13753
- case LLM_ARCH_RWKV7:
13754
- case LLM_ARCH_ARWKV7:
13755
- {
13756
- res = new llama_kv_cache_recurrent(
13757
- *this,
13758
- GGML_TYPE_F32,
13759
- GGML_TYPE_F32,
13760
- cparams.offload_kqv,
13761
- std::max((uint32_t) 1, cparams.n_seq_max),
13762
- cparams.n_seq_max);
13763
- } break;
13764
  default:
13765
  {
13766
- const auto padding = llama_kv_cache_unified::get_padding(cparams);
13767
-
13768
- cparams.n_ctx = GGML_PAD(cparams.n_ctx, padding);
13769
-
13770
- LLAMA_LOG_DEBUG("%s: n_ctx = %u (padded)\n", __func__, cparams.n_ctx);
13771
-
13772
- if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) {
13773
- GGML_ASSERT(hparams.is_swa_any());
13774
-
13775
- res = new llama_kv_cache_unified_iswa(
13776
- *this,
13777
- params.type_k,
13778
- params.type_v,
13779
- !cparams.flash_attn,
13780
- cparams.offload_kqv,
13781
- params.swa_full,
13782
- cparams.n_ctx,
13783
- cparams.n_seq_max,
13784
- cparams.n_ubatch,
13785
- padding);
13786
- } else {
13787
- GGML_ASSERT(!hparams.is_swa_any());
13788
-
13789
- res = new llama_kv_cache_unified(
13790
  *this,
13791
  nullptr,
13792
- params.type_k,
13793
- params.type_v,
13794
- !cparams.flash_attn,
13795
  cparams.offload_kqv,
13796
- cparams.n_ctx,
13797
- cparams.n_seq_max,
13798
- padding,
13799
- hparams.n_swa,
13800
- hparams.swa_type);
13801
  }
13802
  }
13803
  }
@@ -14377,14 +14411,7 @@ llama_token llama_model_decoder_start_token(const llama_model * model) {
14377
  }
14378
 
14379
  bool llama_model_is_recurrent(const llama_model * model) {
14380
- switch (model->arch) {
14381
- case LLM_ARCH_MAMBA: return true;
14382
- case LLM_ARCH_RWKV6: return true;
14383
- case LLM_ARCH_RWKV6QWEN2: return true;
14384
- case LLM_ARCH_RWKV7: return true;
14385
- case LLM_ARCH_ARWKV7: return true;
14386
- default: return false;
14387
- }
14388
  }
14389
 
14390
  const std::vector<std::pair<std::string, ggml_tensor *>> & llama_internal_get_tensor_map(const llama_model * model) {
 
8
 
9
  #include "llama-kv-cache-unified.h"
10
  #include "llama-kv-cache-unified-iswa.h"
11
+ #include "llama-memory-hybrid.h"
12
+ #include "llama-memory-recurrent.h"
13
 
14
  #include "ggml-cpp.h"
15
 
 
471
  std::fill(hparams.n_head_arr.begin(), hparams.n_head_arr.end(), 0);
472
  std::fill(hparams.n_head_kv_arr.begin(), hparams.n_head_kv_arr.end(), 0);
473
  std::fill(hparams.n_ff_arr.begin(), hparams.n_ff_arr.end(), 0);
474
+ std::fill(
475
+ hparams.recurrent_layer_arr.begin(),
476
+ hparams.recurrent_layer_arr.end(),
477
+ llm_arch_is_recurrent(ml.get_arch()));
478
 
479
  std::fill(hparams.rope_sections.begin(), hparams.rope_sections.end(), 0);
480
 
 
4707
 
4708
  const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale;
4709
 
4710
+ ggml_tensor * inp_out_ids = build_inp_out_ids();
4711
+
4712
  for (int il = 0; il < n_layer; ++il) {
4713
  ggml_tensor * inpSA = inpL;
4714
 
 
4771
  cb(cur, "attn_out", il);
4772
  }
4773
 
4774
+ if (il == n_layer - 1 && inp_out_ids) {
 
 
4775
  cur = ggml_get_rows(ctx0, cur, inp_out_ids);
4776
  inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
4777
  }
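
The change repeated across these graph builders hoists build_inp_out_ids() out of the per-layer loop and guards the final row gather with `il == n_layer - 1 && inp_out_ids`, so output rows are selected only on the last layer. A minimal standalone sketch of the idea, using plain C++ vectors in place of ggml tensors (the names below are illustrative, not the real API):

    #include <cstdio>
    #include <vector>

    // Illustrative stand-in: one row of hidden state per token.
    using Rows = std::vector<std::vector<float>>;

    // Keep only the requested rows, like ggml_get_rows does on the graph.
    static Rows get_rows(const Rows & x, const std::vector<int> & ids) {
        Rows out;
        out.reserve(ids.size());
        for (int id : ids) {
            out.push_back(x[id]);
        }
        return out;
    }

    int main() {
        const int n_layer  = 4;
        const int n_tokens = 3;

        Rows cur(n_tokens, std::vector<float>(2, 1.0f));

        // built once, before the layer loop; empty means "all outputs needed"
        const std::vector<int> inp_out_ids = {2};

        for (int il = 0; il < n_layer; ++il) {
            // ... attention / FFN for layer il would go here ...

            // only the last layer shrinks the computation to the requested rows
            if (il == n_layer - 1 && !inp_out_ids.empty()) {
                cur = get_rows(cur, inp_out_ids);
            }
        }

        printf("rows kept for output: %zu\n", cur.size());
        return 0;
    }
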
 
4867
 
4868
  const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale;
4869
 
4870
+ ggml_tensor * inp_out_ids = build_inp_out_ids();
4871
+
4872
  for (int il = 0; il < n_layer; ++il) {
4873
  ggml_tensor * inpSA = inpL;
4874
 
 
4945
  cb(cur, "attn_out", il);
4946
  }
4947
 
4948
+ if (il == n_layer - 1 && inp_out_ids) {
 
 
4949
  cur = ggml_get_rows(ctx0, cur, inp_out_ids);
4950
  inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
4951
  }
 
5045
  auto * inp_attn = build_attn_inp_kv_unified();
5046
 
5047
  const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale;
5048
+
5049
+ ggml_tensor * inp_out_ids = build_inp_out_ids();
5050
+
5051
  for (int il = 0; il < n_layer; ++il) {
5052
  ggml_tensor * inpSA = inpL;
5053
  const int64_t n_head_kv = hparams.n_head_kv(il);
 
5121
  Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il);
5122
  }
5123
 
5124
+ if (il == n_layer - 1 && inp_out_ids) {
 
 
5125
  cur = ggml_get_rows(ctx0, cur, inp_out_ids);
5126
  inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
5127
  }
 
5200
 
5201
  auto * inp_attn = build_attn_inp_kv_unified();
5202
 
5203
+ ggml_tensor * inp_out_ids = build_inp_out_ids();
5204
+
5205
  for (int il = 0; il < n_layer; ++il) {
5206
  ggml_tensor * inpSA = inpL;
5207
 
 
5253
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
5254
  }
5255
 
5256
+ if (il == n_layer - 1 && inp_out_ids) {
 
 
5257
  cur = ggml_get_rows(ctx0, cur, inp_out_ids);
5258
  inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
5259
  }
 
5322
 
5323
  auto * inp_attn = build_attn_inp_kv_unified();
5324
 
5325
+ ggml_tensor * inp_out_ids = build_inp_out_ids();
5326
+
5327
  for (int il = 0; il < n_layer; ++il) {
5328
  ggml_tensor * inpSA = inpL;
5329
 
 
5368
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
5369
  }
5370
 
5371
+ if (il == n_layer - 1 && inp_out_ids) {
 
 
5372
  cur = ggml_get_rows(ctx0, cur, inp_out_ids);
5373
  inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
5374
  }
 
5436
 
5437
  auto * inp_attn = build_attn_inp_kv_unified();
5438
 
5439
+ ggml_tensor * inp_out_ids = build_inp_out_ids();
5440
+
5441
  for (int il = 0; il < n_layer; ++il) {
5442
  ggml_tensor * attn_norm;
5443
 
 
5493
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
5494
  }
5495
 
5496
+ if (il == n_layer - 1 && inp_out_ids) {
 
 
5497
  cur = ggml_get_rows(ctx0, cur, inp_out_ids);
5498
  inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
5499
  attn_norm = ggml_get_rows(ctx0, attn_norm, inp_out_ids);
 
5562
 
5563
  auto * inp_attn = build_attn_inp_kv_unified();
5564
 
5565
+ ggml_tensor * inp_out_ids = build_inp_out_ids();
5566
+
5567
  for (int il = 0; il < n_layer; ++il) {
5568
  ggml_tensor * inpSA = inpL;
5569
 
 
5623
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f, il);
5624
  }
5625
 
5626
+ if (il == n_layer - 1 && inp_out_ids) {
 
 
5627
  cur = ggml_get_rows(ctx0, cur, inp_out_ids);
5628
  inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
5629
  }
 
5722
 
5723
  auto * inp_attn = build_attn_inp_kv_unified();
5724
 
5725
+ ggml_tensor * inp_out_ids = build_inp_out_ids();
5726
+
5727
  for (int il = 0; il < n_layer; ++il) {
5728
  ggml_tensor * inpSA = inpL;
5729
 
 
5774
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
5775
  }
5776
 
5777
+ if (il == n_layer - 1 && inp_out_ids) {
 
 
5778
  cur = ggml_get_rows(ctx0, cur, inp_out_ids);
5779
  inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
5780
  }
 
5854
  inpL = ggml_add(ctx0, inpL, pos);
5855
  cb(inpL, "inpL", -1);
5856
 
5857
+ ggml_tensor * inp_out_ids = build_inp_out_ids();
5858
+
5859
  for (int il = 0; il < n_layer; ++il) {
5860
  cur = build_norm(inpL,
5861
  model.layers[il].attn_norm,
 
5888
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
5889
  }
5890
 
5891
+ if (il == n_layer - 1 && inp_out_ids) {
 
 
5892
  cur = ggml_get_rows(ctx0, cur, inp_out_ids);
5893
  inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
5894
  }
 
5953
 
5954
  auto * inp_attn = build_attn_inp_kv_unified();
5955
 
5956
+ ggml_tensor * inp_out_ids = build_inp_out_ids();
5957
+
5958
  for (int il = 0; il < n_layer; ++il) {
5959
  ggml_tensor * inpSA = inpL;
5960
 
 
5987
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
5988
  }
5989
 
5990
+ if (il == n_layer - 1 && inp_out_ids) {
 
 
5991
  cur = ggml_get_rows(ctx0, cur, inp_out_ids);
5992
  inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
5993
  }
 
6073
 
6074
  auto * inp_attn = build_attn_inp_no_cache();
6075
 
6076
+ ggml_tensor * inp_out_ids = build_inp_out_ids();
6077
+
6078
  for (int il = 0; il < n_layer; ++il) {
6079
  ggml_tensor * cur = inpL;
6080
 
6081
+ {
6082
+ ggml_tensor * Qcur;
6083
+ ggml_tensor * Kcur;
6084
+ ggml_tensor * Vcur;
6085
 
6086
+ // self-attention
6087
+ if (model.layers[il].wqkv) {
6088
+ cur = build_lora_mm(model.layers[il].wqkv, cur);
6089
+ cb(cur, "wqkv", il);
6090
 
6091
+ if (model.layers[il].bqkv) {
6092
+ cur = ggml_add(ctx0, cur, model.layers[il].bqkv);
6093
+ cb(cur, "bqkv", il);
6094
+ }
6095
 
6096
+ Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)));
6097
+ Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)));
6098
+ Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)));
6099
+ } else {
6100
+ Qcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wq, cur), model.layers[il].bq);
6101
+ Kcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wk, cur), model.layers[il].bk);
6102
+ Vcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wv, cur), model.layers[il].bv);
6103
+ }
6104
 
6105
+ if (model.layers[il].attn_q_norm) {
6106
+ Qcur = build_norm(Qcur,
6107
+ model.layers[il].attn_q_norm,
6108
+ model.layers[il].attn_q_norm_b,
6109
+ LLM_NORM, il);
6110
+ }
6111
 
6112
+ if (model.layers[il].attn_k_norm) {
6113
+ Kcur = build_norm(Kcur,
6114
+ model.layers[il].attn_k_norm,
6115
+ model.layers[il].attn_k_norm_b,
6116
+ LLM_NORM, il);
6117
+ }
6118
 
6119
+ Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
6120
+ Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
6121
+ Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
6122
 
6123
+ // RoPE
6124
+ if (model.arch == LLM_ARCH_NOMIC_BERT || model.arch == LLM_ARCH_NOMIC_BERT_MOE) {
6125
+ Qcur = ggml_rope_ext(
6126
+ ctx0, Qcur, inp_pos, nullptr,
6127
+ n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
6128
+ ext_factor, attn_factor, beta_fast, beta_slow
6129
+ );
6130
 
6131
+ Kcur = ggml_rope_ext(
6132
+ ctx0, Kcur, inp_pos, nullptr,
6133
+ n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
6134
+ ext_factor, attn_factor, beta_fast, beta_slow
6135
+ );
6136
+ }
6137
 
6138
+ cb(Qcur, "Qcur", il);
6139
+ cb(Kcur, "Kcur", il);
6140
+ cb(Vcur, "Vcur", il);
6141
 
6142
+ cur = build_attn(inp_attn, gf,
6143
+ model.layers[il].wo, model.layers[il].bo,
6144
+ Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
6145
+ cb(cur, "kqv_out", il);
6146
+ }
6147
 
6148
+ if (il == n_layer - 1 && inp_out_ids) {
 
 
6149
  cur = ggml_get_rows(ctx0, cur, inp_out_ids);
6150
  inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
6151
  }
 
6242
 
6243
  auto * inp_attn = build_attn_inp_no_cache();
6244
 
6245
+ ggml_tensor * inp_out_ids = build_inp_out_ids();
6246
+
6247
  for (int il = 0; il < n_layer; ++il) {
6248
  ggml_tensor * cur = inpL;
6249
 
 
 
 
 
6250
  // pre-norm
6251
  cur = build_norm(inpL,
6252
  model.layers[il].attn_norm, NULL,
6253
  LLM_NORM_RMS, il);
6254
 
6255
+ {
6256
+ ggml_tensor * Qcur;
6257
+ ggml_tensor * Kcur;
6258
+ ggml_tensor * Vcur;
6259
 
6260
+ // self-attention
6261
+ cur = build_lora_mm(model.layers[il].wqkv, cur);
6262
+ cb(cur, "wqkv", il);
6263
+
6264
+ Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)));
6265
+ Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)));
6266
+ Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)));
6267
 
6268
+ Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
6269
+ Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
6270
+ Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
6271
 
6272
+ // RoPE
6273
+ Qcur = ggml_rope_ext(
6274
+ ctx0, Qcur, inp_pos, nullptr,
6275
+ n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
6276
+ ext_factor, attn_factor, beta_fast, beta_slow
6277
+ );
6278
 
6279
+ Kcur = ggml_rope_ext(
6280
+ ctx0, Kcur, inp_pos, nullptr,
6281
+ n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
6282
+ ext_factor, attn_factor, beta_fast, beta_slow
6283
+ );
6284
+
6285
+ cb(Qcur, "Qcur", il);
6286
+ cb(Kcur, "Kcur", il);
6287
+ cb(Vcur, "Vcur", il);
6288
+
6289
+ cur = build_attn(inp_attn, gf,
6290
+ model.layers[il].wo, nullptr,
6291
+ Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
6292
+ cb(cur, "kqv_out", il);
6293
+ }
6294
+
6295
+ if (il == n_layer - 1 && inp_out_ids) {
6296
  cur = ggml_get_rows(ctx0, cur, inp_out_ids);
6297
  inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
6298
  }
 
6357
  LLM_NORM, -1);
6358
  cb(inpL, "inp_norm", -1);
6359
 
6360
+ ggml_tensor * inp_out_ids = build_inp_out_ids();
6361
+
6362
  for (int il = 0; il < n_layer; ++il) {
6363
  cur = build_norm(inpL,
6364
  model.layers[il].attn_norm,
 
6391
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
6392
  }
6393
 
6394
+ if (il == n_layer - 1 && inp_out_ids) {
 
 
6395
  cur = ggml_get_rows(ctx0, cur, inp_out_ids);
6396
  inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
6397
  }
 
6468
  cb(inpL, "inpL", -1);
6469
  }
6470
 
6471
+ ggml_tensor * inp_out_ids = build_inp_out_ids();
6472
+
6473
  for (int il = 0; il < n_layer; ++il) {
6474
  ggml_tensor * attn_norm;
6475
 
 
6532
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
6533
  }
6534
 
6535
+ if (il == n_layer - 1 && inp_out_ids) {
 
 
6536
  cur = ggml_get_rows(ctx0, cur, inp_out_ids);
6537
  inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
6538
  }
 
6601
 
6602
  auto * inp_attn = build_attn_inp_kv_unified();
6603
 
6604
+ ggml_tensor * inp_out_ids = build_inp_out_ids();
6605
+
6606
  for (int il = 0; il < n_layer; ++il) {
6607
  // norm
6608
  cur = build_norm(inpL,
 
6678
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
6679
  }
6680
 
6681
+ if (il == n_layer - 1 && inp_out_ids) {
 
 
6682
  cur = ggml_get_rows(ctx0, cur, inp_out_ids);
6683
  inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
6684
  inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
 
6753
 
6754
  auto * inp_attn = build_attn_inp_kv_unified();
6755
 
6756
+ ggml_tensor * inp_out_ids = build_inp_out_ids();
6757
+
6758
  for (int il = 0; il < n_layer; ++il) {
6759
  ggml_tensor * inpSA = inpL;
6760
 
 
6801
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
6802
  }
6803
 
6804
+ if (il == n_layer - 1 && inp_out_ids) {
 
 
6805
  cur = ggml_get_rows(ctx0, cur, inp_out_ids);
6806
  inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
6807
  }
 
6870
 
6871
  auto * inp_attn = build_attn_inp_kv_unified();
6872
 
6873
+ ggml_tensor * inp_out_ids = build_inp_out_ids();
6874
+
6875
  for (int il = 0; il < n_layer; ++il) {
6876
  ggml_tensor * inpSA = inpL;
6877
 
 
6921
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
6922
  }
6923
 
6924
+ if (il == n_layer - 1 && inp_out_ids) {
 
 
6925
  cur = ggml_get_rows(ctx0, cur, inp_out_ids);
6926
  inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
6927
  }
 
6991
  int sections[4];
6992
  std::copy(std::begin(hparams.rope_sections), std::begin(hparams.rope_sections) + 4, sections);
6993
 
6994
+ ggml_tensor * inp_out_ids = build_inp_out_ids();
6995
+
6996
  for (int il = 0; il < n_layer; ++il) {
6997
  ggml_tensor * inpSA = inpL;
6998
 
 
7042
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
7043
  }
7044
 
7045
+ if (il == n_layer - 1 && inp_out_ids) {
 
 
7046
  cur = ggml_get_rows(ctx0, cur, inp_out_ids);
7047
  inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
7048
  }
 
7109
 
7110
  auto * inp_attn = build_attn_inp_kv_unified();
7111
 
7112
+ ggml_tensor * inp_out_ids = build_inp_out_ids();
7113
+
7114
  for (int il = 0; il < n_layer; ++il) {
7115
  ggml_tensor * inpSA = inpL;
7116
 
 
7169
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
7170
  }
7171
 
7172
+ if (il == n_layer - 1 && inp_out_ids) {
 
 
7173
  cur = ggml_get_rows(ctx0, cur, inp_out_ids);
7174
  inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
7175
  }
 
7268
 
7269
  auto * inp_attn = build_attn_inp_kv_unified();
7270
 
7271
+ ggml_tensor * inp_out_ids = build_inp_out_ids();
7272
+
7273
  for (int il = 0; il < n_layer; ++il) {
7274
  ggml_tensor * inpSA = inpL;
7275
 
 
7322
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
7323
  }
7324
 
7325
+ if (il == n_layer - 1 && inp_out_ids) {
 
 
7326
  cur = ggml_get_rows(ctx0, cur, inp_out_ids);
7327
  inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
7328
  }
 
7389
 
7390
  auto * inp_attn = build_attn_inp_kv_unified();
7391
 
7392
+ ggml_tensor * inp_out_ids = build_inp_out_ids();
7393
+
7394
  for (int il = 0; il < n_layer; ++il) {
7395
  ggml_tensor * inpSA = inpL;
7396
 
 
7443
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
7444
  }
7445
 
7446
+ if (il == n_layer - 1 && inp_out_ids) {
 
 
7447
  cur = ggml_get_rows(ctx0, cur, inp_out_ids);
7448
  inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
7449
  }
 
7519
 
7520
  auto * inp_attn = build_attn_inp_kv_unified();
7521
 
7522
+ ggml_tensor * inp_out_ids = build_inp_out_ids();
7523
+
7524
  for (int il = 0; il < n_layer; ++il) {
7525
  attn_norm_output = build_norm(inpL,
7526
  model.layers[il].attn_norm,
 
7583
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f, il);
7584
  }
7585
 
7586
+ if (il == n_layer - 1 && inp_out_ids) {
 
 
7587
  cur = ggml_get_rows(ctx0, cur, inp_out_ids);
7588
  inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
7589
  attn_norm_output = ggml_get_rows(ctx0, attn_norm_output, inp_out_ids);
 
7655
  inp_attn = build_attn_inp_kv_unified();
7656
  }
7657
 
7658
+ ggml_tensor * inp_out_ids = build_inp_out_ids();
7659
+
7660
  for (int il = 0; il < n_layer; ++il) {
7661
  auto * residual = inpL;
7662
 
 
7720
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f, il);
7721
  }
7722
 
7723
+ if (il == n_layer - 1 && inp_out_ids) {
 
 
7724
  cur = ggml_get_rows(ctx0, cur, inp_out_ids);
7725
  residual = ggml_get_rows(ctx0, residual, inp_out_ids);
7726
  }
 
7806
 
7807
  auto * inp_attn = build_attn_inp_kv_unified();
7808
 
7809
+ ggml_tensor * inp_out_ids = build_inp_out_ids();
7810
 
7811
+ for (int il = 0; il < n_layer; ++il) {
7812
  // norm
7813
  cur = build_norm(inpL,
7814
  model.layers[il].attn_norm, NULL,
7815
  LLM_NORM_RMS, il);
7816
  cb(cur, "attn_norm", il);
7817
 
7818
+ ggml_tensor * sa_inp = cur;
7819
 
7820
  // self-attention
7821
  {
 
7853
  model.layers[il].wo, NULL,
7854
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
7855
  }
 
 
 
7856
 
7857
+ if (il == n_layer - 1 && inp_out_ids) {
 
 
7858
  cur = ggml_get_rows(ctx0, cur, inp_out_ids);
7859
+ sa_inp = ggml_get_rows(ctx0, sa_inp, inp_out_ids);
7860
  inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
7861
  }
7862
 
7863
+ ggml_tensor * sa_out = cur;
7864
+
7865
+ cur = sa_inp;
7866
+
7867
  // feed-forward network
7868
  {
7869
  cur = build_ffn(cur,
 
7928
  inpL = ggml_add(ctx0, inpL, pos);
7929
  cb(inpL, "inpL", -1);
7930
 
7931
+ ggml_tensor * inp_out_ids = build_inp_out_ids();
7932
+
7933
  for (int il = 0; il < n_layer; ++il) {
7934
  cur = build_norm(inpL,
7935
  model.layers[il].attn_norm,
 
7962
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
7963
  }
7964
 
7965
+ if (il == n_layer - 1 && inp_out_ids) {
 
 
7966
  cur = ggml_get_rows(ctx0, cur, inp_out_ids);
7967
  inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
7968
  }
 
8032
 
8033
  auto * inp_attn = build_attn_inp_kv_unified();
8034
 
8035
+ ggml_tensor * inp_out_ids = build_inp_out_ids();
8036
+
8037
  for (int il = 0; il < n_layer; ++il) {
8038
  cur = build_norm(inpL,
8039
  model.layers[il].attn_norm,
 
8078
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
8079
  }
8080
 
8081
+ if (il == n_layer - 1 && inp_out_ids) {
 
 
8082
  cur = ggml_get_rows(ctx0, cur, inp_out_ids);
8083
  inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
8084
  }
 
8132
 
8133
  struct llm_build_orion : public llm_graph_context {
8134
  llm_build_orion(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
8135
+ const int64_t n_embd_head = hparams.n_embd_head_v;
8136
 
8137
+ GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
8138
+ GGML_ASSERT(n_embd_head == hparams.n_rot);
8139
 
8140
+ ggml_tensor * cur;
8141
+ ggml_tensor * inpL;
8142
 
8143
+ inpL = build_inp_embd(model.tok_embd);
8144
 
8145
+ // inp_pos - contains the positions
8146
+ ggml_tensor * inp_pos = build_inp_pos();
8147
 
8148
+ auto * inp_attn = build_attn_inp_kv_unified();
8149
 
8150
+ ggml_tensor * inp_out_ids = build_inp_out_ids();
 
8151
 
8152
+ for (int il = 0; il < n_layer; ++il) {
8153
+ ggml_tensor * inpSA = inpL;
 
 
 
8154
 
8155
+ // norm
8156
+ cur = build_norm(inpL,
8157
+ model.layers[il].attn_norm, model.layers[il].attn_norm_b,
8158
+ LLM_NORM, il);
8159
+ cb(cur, "attn_norm", il);
8160
 
8161
+ // self-attention
8162
+ {
8163
+ // compute Q and K and RoPE them
8164
+ ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
8165
+ cb(Qcur, "Qcur", il);
8166
+ // if (model.layers[il].bq) {
8167
+ // Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
8168
+ // cb(Qcur, "Qcur", il);
8169
+ // }
8170
+
8171
+ ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
8172
+ cb(Kcur, "Kcur", il);
8173
+ // if (model.layers[il].bk) {
8174
+ // Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
8175
+ // cb(Kcur, "Kcur", il);
8176
+ // }
8177
 
8178
+ ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
8179
+ cb(Vcur, "Vcur", il);
8180
+ // if (model.layers[il].bv) {
8181
+ // Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
8182
+ // cb(Vcur, "Vcur", il);
8183
+ // }
8184
 
8185
+ Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
8186
+ Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
8187
+ Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
 
8188
 
8189
+ Qcur = ggml_rope_ext(
8190
+ ctx0, Qcur, inp_pos, nullptr,
8191
+ n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
8192
+ ext_factor, attn_factor, beta_fast, beta_slow
8193
+ );
 
8194
 
8195
+ Kcur = ggml_rope_ext(
8196
+ ctx0, Kcur, inp_pos, nullptr,
8197
+ n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
8198
+ ext_factor, attn_factor, beta_fast, beta_slow
8199
+ );
8200
 
8201
+ cb(Qcur, "Qcur", il);
8202
+ cb(Kcur, "Kcur", il);
8203
+ cb(Vcur, "Vcur", il);
 
 
8204
 
8205
+ cur = build_attn(inp_attn, gf,
8206
+ model.layers[il].wo, NULL,
8207
+ Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
8208
+ }
 
 
 
8209
 
8210
+ if (il == n_layer - 1 && inp_out_ids) {
8211
+ cur = ggml_get_rows(ctx0, cur, inp_out_ids);
8212
+ inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
8213
+ }
8214
 
8215
+ ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
8216
+ cb(ffn_inp, "ffn_inp", il);
8217
 
8218
+ // feed-forward network
8219
+ cur = build_norm(ffn_inp,
8220
+ model.layers[il].ffn_norm, model.layers[il].ffn_norm_b,
8221
+ LLM_NORM, il);
8222
+ cb(cur, "ffn_norm", il);
8223
+
8224
+ cur = build_ffn(cur,
8225
+ model.layers[il].ffn_up, NULL, NULL,
8226
+ model.layers[il].ffn_gate, NULL, NULL,
8227
+ model.layers[il].ffn_down, NULL, NULL,
8228
+ NULL,
8229
+ LLM_FFN_SILU, LLM_FFN_PAR, il);
8230
+ cb(cur, "ffn_out", il);
8231
+
8232
+ cur = ggml_add(ctx0, cur, ffn_inp);
8233
+
8234
+ cur = build_cvec(cur, il);
8235
+ cb(cur, "l_out", il);
8236
+
8237
+ // input for next layer
8238
+ inpL = cur;
8239
+ }
8240
 
8241
+ cur = inpL;
8242
 
8243
+ cur = build_norm(cur,
8244
+ model.output_norm, model.output_norm_b,
8245
+ LLM_NORM, -1);
8246
 
8247
+ cb(cur, "result_norm", -1);
8248
+ res->t_embd = cur;
8249
 
8250
+ // lm_head
8251
+ cur = build_lora_mm(model.output, cur);
8252
 
8253
+ cb(cur, "result_output", -1);
8254
+ res->t_logits = cur;
8255
 
8256
+ ggml_build_forward_expand(gf, cur);
8257
  }
8258
  };
8259
 
 
8274
 
8275
  auto * inp_attn = build_attn_inp_kv_unified();
8276
 
8277
+ ggml_tensor * inp_out_ids = build_inp_out_ids();
8278
+
8279
  for (int il = 0; il < n_layer; ++il) {
8280
  ggml_tensor * inpSA = inpL;
8281
 
 
8334
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
8335
  }
8336
 
8337
+ if (il == n_layer - 1 && inp_out_ids) {
 
 
8338
  cur = ggml_get_rows(ctx0, cur, inp_out_ids);
8339
  inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
8340
  }
 
8410
 
8411
  auto * inp_attn = build_attn_inp_kv_unified();
8412
 
8413
+ ggml_tensor * inp_out_ids = build_inp_out_ids();
8414
+
8415
  for (int il = 0; il < n_layer; ++il) {
8416
  ggml_tensor * inpSA = inpL;
8417
 
 
8531
  q_states, k_states, v_states, nullptr, nullptr, kq_scale, il);
8532
  }
8533
 
8534
+ if (il == n_layer - 1 && inp_out_ids) {
 
 
8535
  cur = ggml_get_rows(ctx0, cur, inp_out_ids);
8536
  inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
8537
  }
8538
 
8539
  // scale_res - scale the hidden states for residual connection
8540
+ const float scale_res = scale_depth/sqrtf(float(n_layer)); // TODO: is this correct?
8541
  cur = ggml_scale(ctx0, cur, scale_res);
8542
  cb(cur, "hidden_scaled", il);
8543
 
 
8614
 
8615
  auto * inp_attn = build_attn_inp_kv_unified();
8616
 
8617
+ ggml_tensor * inp_out_ids = build_inp_out_ids();
8618
+
8619
  for (int il = 0; il < n_layer; ++il) {
8620
  // norm
8621
  cur = build_norm(inpL,
 
8661
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f, il);
8662
  }
8663
 
8664
+ if (il == n_layer - 1 && inp_out_ids) {
 
 
8665
  cur = ggml_get_rows(ctx0, cur, inp_out_ids);
8666
  inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
8667
  }
 
8730
 
8731
  auto * inp_attn = build_attn_inp_kv_unified_iswa();
8732
 
8733
+ ggml_tensor * inp_out_ids = build_inp_out_ids();
8734
+
8735
  for (int il = 0; il < n_layer; ++il) {
8736
  // norm
8737
  cur = build_norm(inpL,
 
8776
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f, il);
8777
  }
8778
 
8779
+ if (il == n_layer - 1 && inp_out_ids) {
8780
+ cur = ggml_get_rows(ctx0, cur, inp_out_ids);
8781
+ inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
8782
+ }
8783
+
8784
  cur = build_norm(cur,
8785
  model.layers[il].attn_post_norm, NULL,
8786
  LLM_NORM_RMS, il);
8787
  cb(cur, "attn_post_norm", il);
8788
 
 
 
 
 
 
 
 
8789
  ggml_tensor * sa_out = ggml_add(ctx0, cur, inpL);
8790
  cb(sa_out, "sa_out", il);
8791
 
 
8864
  // TODO: is causal == true correct? might need some changes
8865
  auto * inp_attn = build_attn_inp_kv_unified_iswa();
8866
 
8867
+ ggml_tensor * inp_out_ids = build_inp_out_ids();
8868
+
8869
  for (int il = 0; il < n_layer; ++il) {
8870
  const float freq_base_l = model.get_rope_freq_base (cparams, il);
8871
  const float freq_scale_l = model.get_rope_freq_scale(cparams, il);
 
8918
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f, il);
8919
  }
8920
 
8921
+ if (il == n_layer - 1 && inp_out_ids) {
8922
+ cur = ggml_get_rows(ctx0, cur, inp_out_ids);
8923
+ inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
8924
+ }
8925
+
8926
  cur = build_norm(cur,
8927
  model.layers[il].attn_post_norm, NULL,
8928
  LLM_NORM_RMS, il);
8929
  cb(cur, "attn_post_norm", il);
8930
 
 
 
 
 
 
 
 
8931
  ggml_tensor * sa_out = ggml_add(ctx0, cur, inpL);
8932
  cb(sa_out, "sa_out", il);
8933
 
 
8998
 
8999
  auto * inp_attn = build_attn_inp_kv_unified();
9000
 
9001
+ ggml_tensor * inp_out_ids = build_inp_out_ids();
9002
+
9003
  for (int il = 0; il < n_layer; ++il) {
9004
  ggml_tensor * inpSA = inpL;
9005
 
 
9058
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
9059
  }
9060
 
9061
+ if (il == n_layer - 1 && inp_out_ids) {
 
 
9062
  cur = ggml_get_rows(ctx0, cur, inp_out_ids);
9063
  inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
9064
  }
 
9119
  // {n_embd, n_tokens}
9120
  inpL = build_inp_embd(model.tok_embd);
9121
 
9122
+ auto * rs_inp = build_rs_inp();
9123
+
9124
+ ggml_tensor * inp_out_ids = build_inp_out_ids();
9125
 
9126
  for (int il = 0; il < n_layer; ++il) {
9127
  // norm
 
9130
  LLM_NORM_RMS, il);
9131
  cb(cur, "attn_norm", il);
9132
 
9133
+ cur = build_mamba_layer(rs_inp, gf, cur, ubatch, il);
9134
 
9135
+ if (il == n_layer - 1 && inp_out_ids) {
 
 
9136
  cur = ggml_get_rows(ctx0, cur, inp_out_ids);
9137
  inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
9138
  }
 
9166
 
9167
  // TODO: split
9168
  ggml_tensor * build_mamba_layer(
9169
+ llm_graph_input_rs * inp,
9170
+ ggml_cgraph * gf,
9171
+ ggml_tensor * cur,
9172
+ const llama_ubatch & ubatch,
9173
+ int il) const {
9174
+ const auto * kv_state = static_cast<const llama_memory_recurrent_state *>(mstate);
9175
 
9176
  const auto kv_head = kv_state->get_head();
9177
 
 
9191
  GGML_ASSERT(ubatch.equal_seqs);
9192
  GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs);
9193
 
9194
+ ggml_tensor * conv_states_all = kv_state->get_r_l(il);
9195
+ ggml_tensor * ssm_states_all = kv_state->get_s_l(il);
9196
 
9197
  // (ab)using the KV cache to store the states
9198
+ ggml_tensor * conv = build_rs(
9199
+ inp, gf, conv_states_all,
9200
+ hparams.n_embd_r(), n_seqs);
9201
  conv = ggml_reshape_3d(ctx0, conv, d_conv - 1, d_inner, n_seqs);
9202
+ ggml_tensor * ssm = build_rs(
9203
+ inp, gf, ssm_states_all,
9204
+ hparams.n_embd_s(), n_seqs);
9205
  ssm = ggml_reshape_3d(ctx0, ssm, d_state, d_inner, n_seqs);
9206
 
9207
  // {n_embd, n_tokens} => {n_embd, n_seq_tokens, n_seqs}
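
In the Mamba path the per-layer states now come from the recurrent memory object rather than the old recurrent KV cache: `get_r_l(il)` holds the conv state, `get_s_l(il)` holds the SSM state, and `build_rs` loads the slots for the sequences in the current ubatch. As a rough standalone illustration of a per-layer state store (plain C++, hypothetical names, not the llama.cpp classes):

    #include <cstdio>
    #include <vector>

    // Hypothetical per-layer recurrent state store, loosely mirroring the
    // get_r_l / get_s_l accessors used in the diff.
    struct recurrent_store {
        std::vector<std::vector<float>> r_l; // conv-style state per layer
        std::vector<std::vector<float>> s_l; // ssm-style state per layer

        recurrent_store(int n_layer, size_t n_r, size_t n_s)
            : r_l(n_layer, std::vector<float>(n_r, 0.0f)),
              s_l(n_layer, std::vector<float>(n_s, 0.0f)) {}

        std::vector<float> & get_r_l(int il) { return r_l[il]; }
        std::vector<float> & get_s_l(int il) { return s_l[il]; }
    };

    int main() {
        recurrent_store store(/*n_layer =*/ 2, /*n_r =*/ 4, /*n_s =*/ 8);

        // a layer reads its state, updates it, and the update persists for the
        // next ubatch -- the role build_rs plus the copy-back plays on the graph
        std::vector<float> & ssm = store.get_s_l(0);
        for (float & v : ssm) {
            v += 1.0f;
        }

        printf("layer 0 ssm[0] = %.1f\n", store.get_s_l(0)[0]);
        return 0;
    }
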
 
9314
 
9315
  auto * inp_attn = build_attn_inp_kv_unified();
9316
 
9317
+ ggml_tensor * inp_out_ids = build_inp_out_ids();
9318
 
9319
+ for (int il = 0; il < n_layer; ++il) {
9320
  // norm
9321
  cur = build_norm(inpL,
9322
  model.layers[il].attn_norm, NULL,
9323
  LLM_NORM, il);
9324
  cb(cur, "attn_norm", il);
9325
+
9326
  ggml_tensor * ffn_inp = cur;
9327
 
9328
  // self-attention
 
9390
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
9391
  }
9392
 
9393
+ if (il == n_layer - 1 && inp_out_ids) {
 
 
9394
  cur = ggml_get_rows(ctx0, cur, inp_out_ids);
9395
  inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
9396
  ffn_inp = ggml_get_rows(ctx0, ffn_inp, inp_out_ids);
 
9461
 
9462
  auto * inp_attn = build_attn_inp_kv_unified_iswa();
9463
 
9464
+ ggml_tensor * inp_out_ids = build_inp_out_ids();
9465
+
9466
  for (int il = 0; il < n_layer; ++il) {
9467
  const bool is_swa = hparams.is_swa(il);
9468
 
 
9525
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
9526
  }
9527
 
9528
+ if (il == n_layer - 1 && inp_out_ids) {
 
 
9529
  cur = ggml_get_rows(ctx0, cur, inp_out_ids);
9530
  inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
9531
  ffn_inp = ggml_get_rows(ctx0, ffn_inp, inp_out_ids);
 
9596
 
9597
  auto * inp_attn = build_attn_inp_kv_unified();
9598
 
9599
+ ggml_tensor * inp_out_ids = build_inp_out_ids();
9600
+
9601
  for (int il = 0; il < n_layer; ++il) {
9602
  ggml_tensor * inpSA = inpL;
9603
 
 
9656
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
9657
  }
9658
 
9659
+ if (il == n_layer - 1 && inp_out_ids) {
 
 
9660
  cur = ggml_get_rows(ctx0, cur, inp_out_ids);
9661
  inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
9662
  }
 
9724
 
9725
  auto * inp_attn = build_attn_inp_kv_unified();
9726
 
9727
+ ggml_tensor * inp_out_ids = build_inp_out_ids();
9728
+
9729
  for (int il = 0; il < n_layer; ++il) {
9730
  ggml_tensor * inpSA = inpL;
9731
 
 
9776
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
9777
  }
9778
 
9779
+ if (il == n_layer - 1 && inp_out_ids) {
9780
+ cur = ggml_get_rows(ctx0, cur, inp_out_ids);
9781
+ inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
9782
+ }
9783
+
9784
  cur = build_norm(cur,
9785
  model.layers[il].attn_post_norm, NULL,
9786
  LLM_NORM_RMS, il);
9787
  cb(cur, "attn_post_norm", il);
9788
 
 
 
 
 
 
 
 
9789
  ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
9790
  cb(ffn_inp, "ffn_inp", il);
9791
 
 
9853
 
9854
  auto * inp_attn = build_attn_inp_kv_unified();
9855
 
9856
+ ggml_tensor * inp_out_ids = build_inp_out_ids();
9857
+
9858
  for (int il = 0; il < n_layer; ++il) {
9859
  ggml_tensor * inpSA = inpL;
9860
 
 
9909
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
9910
  }
9911
 
9912
+ if (il == n_layer - 1 && inp_out_ids) {
 
 
9913
  cur = ggml_get_rows(ctx0, cur, inp_out_ids);
9914
  inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
9915
  }
 
9979
 
9980
  auto * inp_attn = build_attn_inp_kv_unified();
9981
 
9982
+ ggml_tensor * inp_out_ids = build_inp_out_ids();
9983
+
9984
  for (int il = 0; il < n_layer; ++il) {
9985
  const int64_t n_head = hparams.n_head(il);
9986
  const int64_t n_head_kv = hparams.n_head_kv(il);
 
10042
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
10043
  }
10044
 
10045
+ if (il == n_layer - 1 && inp_out_ids) {
 
 
10046
  residual = ggml_get_rows(ctx0, residual, inp_out_ids);
10047
+ cur = ggml_get_rows(ctx0, cur, inp_out_ids);
10048
  }
10049
 
10050
  ggml_tensor * ffn_inp = ggml_add(ctx0, residual, cur);
 
10110
 
10111
  auto * inp_attn = build_attn_inp_kv_unified();
10112
 
10113
+ ggml_tensor * inp_out_ids = build_inp_out_ids();
10114
+
10115
  for (int il = 0; il < n_layer; ++il) {
10116
  cur = build_norm(inpL,
10117
  model.layers[il].attn_norm,
 
10156
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
10157
  }
10158
 
10159
+ if (il == n_layer - 1 && inp_out_ids) {
 
 
10160
  cur = ggml_get_rows(ctx0, cur, inp_out_ids);
10161
  inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
10162
  }
 
10258
 
10259
  auto * inp_attn = build_attn_inp_kv_unified();
10260
 
10261
+ ggml_tensor * inp_out_ids = build_inp_out_ids();
10262
+
10263
  for (int il = 0; il < n_layer; ++il) {
10264
  ggml_tensor * inpSA = inpL;
10265
 
 
10306
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
10307
  }
10308
 
10309
+ if (il == n_layer - 1 && inp_out_ids) {
 
 
10310
  cur = ggml_get_rows(ctx0, cur, inp_out_ids);
10311
  inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
10312
  }
 
10398
 
10399
  const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale;
10400
 
10401
+ ggml_tensor * inp_out_ids = build_inp_out_ids();
10402
+
10403
  for (int il = 0; il < n_layer; ++il) {
10404
  ggml_tensor * inpSA = inpL;
10405
 
 
10461
  Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il);
10462
  }
10463
 
10464
+ if (il == n_layer - 1 && inp_out_ids) {
 
 
10465
  cur = ggml_get_rows(ctx0, cur, inp_out_ids);
10466
  inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
10467
  }
10468
 
 
10469
  ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
10470
  cb(ffn_inp, "ffn_inp", il);
10471
 
 
10573
 
10574
  auto * inp_attn = build_attn_inp_kv_unified();
10575
 
10576
+ ggml_tensor * inp_out_ids = build_inp_out_ids();
10577
+
10578
  for (int il = 0; il < n_layer; ++il) {
10579
  ggml_tensor * inpSA = inpL;
10580
 
 
10724
  }
10725
  }
10726
 
10727
+ if (il == n_layer - 1 && inp_out_ids) {
 
 
10728
  cur = ggml_get_rows(ctx0, cur, inp_out_ids);
10729
  inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
10730
  }
 
10820
 
10821
  auto * inp_attn = build_attn_inp_kv_unified();
10822
 
10823
+ ggml_tensor * inp_out_ids = build_inp_out_ids();
10824
+
10825
  for (int il = 0; il < n_layer; ++il) {
10826
  ggml_tensor * inpSA = inpL;
10827
 
 
10904
  cb(cur, "attn_o_out", il);
10905
  }
10906
 
10907
+ if (il == n_layer - 1 && inp_out_ids) {
 
 
10908
  cur = ggml_get_rows(ctx0, cur, inp_out_ids);
10909
  inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
10910
  }
 
10979
 
10980
  auto * inp_attn = build_attn_inp_no_cache();
10981
 
10982
+ ggml_tensor * inp_out_ids = build_inp_out_ids();
10983
+
10984
  for (int il = 0; il < n_layer; ++il) {
10985
  ggml_tensor * inpSA = inpL;
10986
 
 
11014
  cb(cur, "kqv_out", il);
11015
  }
11016
 
11017
+ if (il == n_layer - 1 && inp_out_ids) {
 
 
11018
  cur = ggml_get_rows(ctx0, cur, inp_out_ids);
11019
  inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
11020
  }
 
11085
  auto * inp_attn_self = build_attn_inp_kv_unified();
11086
  auto * inp_attn_cross = build_attn_inp_cross();
11087
 
11088
+ ggml_tensor * inp_out_ids = build_inp_out_ids();
11089
+
11090
  for (int il = 0; il < n_layer; ++il) {
11091
  ggml_tensor * inpSA = inpL;
11092
 
 
11178
  //cb(cur, "kqv_out", il);
11179
  }
11180
 
11181
+ if (il == n_layer - 1 && inp_out_ids) {
 
 
11182
  cur = ggml_get_rows(ctx0, cur, inp_out_ids);
 
11183
  inpCA = ggml_get_rows(ctx0, inpCA, inp_out_ids);
11184
  }
11185
 
 
11249
 
11250
  auto * inp_attn = build_attn_inp_kv_unified();
11251
 
11252
+ ggml_tensor * inp_out_ids = build_inp_out_ids();
11253
+
11254
  for (int il = 0; il < n_layer; ++il) {
11255
  cur = build_norm(inpL,
11256
  model.layers[il].attn_norm,
 
11283
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/float(n_embd_head), il);
11284
  }
11285
 
11286
+ if (il == n_layer - 1 && inp_out_ids) {
 
 
11287
  cur = ggml_get_rows(ctx0, cur, inp_out_ids);
11288
  inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
11289
  }
 
11347
 
11348
  auto * inp_attn = build_attn_inp_kv_unified();
11349
 
11350
+ ggml_tensor * inp_out_ids = build_inp_out_ids();
11351
+
11352
  for (int il = 0; il < n_layer; ++il) {
11353
  ggml_tensor * inpSA = inpL;
11354
 
 
11415
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
11416
  }
11417
 
11418
+ if (il == n_layer - 1 && inp_out_ids) {
 
 
11419
  cur = ggml_get_rows(ctx0, cur, inp_out_ids);
11420
  inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
11421
  }
 
11480
 
11481
  auto * inp_attn = build_attn_inp_kv_unified();
11482
 
11483
+ ggml_tensor * inp_out_ids = build_inp_out_ids();
11484
+
11485
  for (int il = 0; il < n_layer; ++il) {
11486
  ggml_tensor * inpSA = inpL;
11487
 
 
11548
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
11549
  }
11550
 
11551
+ if (il == n_layer - 1 && inp_out_ids) {
 
 
11552
  cur = ggml_get_rows(ctx0, cur, inp_out_ids);
11553
  inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
11554
  }
 
11631
 
11632
  auto * inp_attn = build_attn_inp_kv_unified();
11633
 
11634
+ ggml_tensor * inp_out_ids = build_inp_out_ids();
11635
+
11636
  for (int il = 0; il < n_layer; ++il) {
11637
  ggml_tensor * inpSA = inpL;
11638
 
 
11692
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
11693
  }
11694
 
11695
+ if (il == n_layer - 1 && inp_out_ids) {
 
 
11696
  cur = ggml_get_rows(ctx0, cur, inp_out_ids);
11697
  inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
11698
  }
 
11760
 
11761
  auto * inp_attn = build_attn_inp_kv_unified();
11762
 
11763
+ ggml_tensor * inp_out_ids = build_inp_out_ids();
11764
+
11765
  for (int il = 0; il < n_layer; ++il) {
11766
  ggml_tensor * inpSA = inpL;
11767
 
 
11823
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
11824
  }
11825
 
11826
+ if (il == n_layer - 1 && inp_out_ids) {
 
 
11827
  cur = ggml_get_rows(ctx0, cur, inp_out_ids);
11828
  inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
11829
  }
 
11910
  }
11911
 
11912
  ggml_tensor * build_rwkv6_time_mix(
11913
+ llm_graph_input_rs * inp,
11914
  ggml_cgraph * gf,
11915
  ggml_tensor * cur,
11916
  ggml_tensor * x_prev,
 
11917
  const llama_ubatch & ubatch,
11918
  int il) const {
11919
+ const auto * kv_state = static_cast<const llama_memory_recurrent_state *>(mstate);
11920
 
11921
  const auto n_tokens = ubatch.n_tokens;
11922
  const auto n_seqs = ubatch.n_seqs;
 
12037
  k = ggml_sub(ctx0, k, ggml_mul(ctx0, k, w));
12038
  }
12039
 
12040
+ ggml_tensor * wkv_state = build_rs(
12041
+ inp, gf, kv_state->get_s_l(il),
12042
+ hparams.n_embd_s(), n_seqs);
12043
 
12044
  ggml_tensor * wkv_output;
12045
  if (is_qrwkv) {
 
12057
  wkv_state,
12058
  ggml_view_1d(
12059
  ctx0,
12060
+ kv_state->get_s_l(il),
12061
+ hparams.n_embd_s() * n_seqs,
12062
+ hparams.n_embd_s() * kv_head * ggml_element_size(kv_state->get_s_l(il))
12063
  )
12064
  )
12065
  );
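
The RWKV time-mix update follows the same recurrent-memory pattern: the wkv state for layer `il` is read with `build_rs` from `get_s_l(il)`, processed, and then written back into the same buffer at offset `n_embd_s() * kv_head`. A small standalone sketch of that read-update-write-back addressing (plain C++, sizes chosen only for illustration):

    #include <cstdio>
    #include <vector>

    int main() {
        // illustrative sizes: one flat state buffer per layer, one slot per sequence
        const size_t n_embd_s = 4;  // elements of wkv state per sequence
        const size_t n_slots  = 3;  // sequence slots in the buffer
        std::vector<float> s_l(n_embd_s * n_slots, 0.0f);

        const size_t kv_head = 1;   // slot where the current sequences start
        const size_t n_seqs  = 1;

        // read the active slice, update it, and write it back at the same offset,
        // mirroring the build_rs load and the ggml_view_1d copy-back in the diff
        float * slice = s_l.data() + n_embd_s * kv_head;
        for (size_t i = 0; i < n_embd_s * n_seqs; ++i) {
            slice[i] += 1.0f;
        }

        printf("s_l[%zu] = %.1f\n", n_embd_s * kv_head, s_l[n_embd_s * kv_head]);
        return 0;
    }
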
 
12093
  inpL = build_inp_embd(model.tok_embd);
12094
  inpL = build_norm(inpL, model.tok_norm, model.tok_norm_b, LLM_NORM, -1);
12095
 
12096
+ auto * rs_inp = build_rs_inp();
12097
 
12098
  const auto n_embd = hparams.n_embd;
12099
  const auto n_seq_tokens = ubatch.n_seq_tokens;
12100
  const auto n_seqs = ubatch.n_seqs;
12101
 
12102
+ ggml_tensor * inp_out_ids = build_inp_out_ids();
12103
+
12104
  for (int il = 0; il < n_layer; ++il) {
12105
  const llama_layer * layer = &model.layers[il];
12106
  inpL = ggml_reshape_3d(ctx0, inpL, n_embd, n_seq_tokens, n_seqs);
12107
 
12108
+ ggml_tensor * token_shift = build_rwkv_token_shift_load(rs_inp, gf, ubatch, il);
 
 
12109
 
12110
  ggml_tensor * att_shift = ggml_view_3d(ctx0, token_shift, n_embd, 1, n_seqs, token_shift->nb[1], token_shift->nb[2], 0);
12111
  ggml_tensor * ffn_shift = ggml_view_3d(ctx0, token_shift, n_embd, 1, n_seqs, token_shift->nb[1], token_shift->nb[2], n_embd * ggml_element_size(token_shift));
 
12120
  1
12121
  );
12122
 
12123
+ cur = build_rwkv6_time_mix(rs_inp, gf, att_norm, x_prev, ubatch, il);
12124
 
12125
  ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL);
12126
  cb(ffn_inp, "ffn_inp", il);
 
12142
  );
12143
  ggml_build_forward_expand(gf, build_rwkv_token_shift_store(token_shift, ubatch, il));
12144
 
12145
+ ffn_inp = ggml_reshape_2d(ctx0, ffn_inp, n_embd, n_tokens);
12146
+ ffn_norm = ggml_reshape_2d(ctx0, ffn_norm, n_embd, n_tokens);
12147
+ x_prev = ggml_reshape_2d(ctx0, x_prev, n_embd, n_tokens);
12148
+ cur = ggml_reshape_2d(ctx0, cur, n_embd, n_tokens);
12149
+
12150
+ if (il == n_layer - 1 && inp_out_ids) {
12151
+ ffn_inp = ggml_get_rows(ctx0, ffn_inp, inp_out_ids);
12152
+ ffn_norm = ggml_get_rows(ctx0, ffn_norm, inp_out_ids);
12153
+ x_prev = ggml_get_rows(ctx0, x_prev, inp_out_ids);
12154
+ cur = ggml_get_rows(ctx0, cur, inp_out_ids);
12155
  }
12156
 
12157
  cur = build_rwkv6_channel_mix(layer, ffn_norm, x_prev, LLM_ARCH_RWKV6);
 
12186
  // ref: https://huggingface.co/recursal/QRWKV6-32B-Instruct-Preview-v0.1/blob/main/modeling_rwkv6qwen2.py
12187
  struct llm_build_rwkv6qwen2 : public llm_build_rwkv6_base {
12188
  llm_build_rwkv6qwen2(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_build_rwkv6_base(model, params) {
12189
+ GGML_ASSERT(n_embd == hparams.n_embd_r());
12190
 
12191
  ggml_tensor * cur;
12192
  ggml_tensor * inpL;
12193
 
12194
  inpL = build_inp_embd(model.tok_embd);
12195
 
12196
+ auto * rs_inp = build_rs_inp();
12197
 
12198
  const auto n_embd = hparams.n_embd;
12199
  const auto n_seq_tokens = ubatch.n_seq_tokens;
12200
  const auto n_seqs = ubatch.n_seqs;
12201
 
12202
+ ggml_tensor * inp_out_ids = build_inp_out_ids();
12203
+
12204
  for (int il = 0; il < n_layer; ++il) {
12205
  const llama_layer * layer = &model.layers[il];
12206
  inpL = ggml_reshape_3d(ctx0, inpL, n_embd, n_seq_tokens, n_seqs);
12207
 
12208
+ ggml_tensor * token_shift = build_rwkv_token_shift_load(rs_inp, gf, ubatch, il);
 
 
12209
 
12210
  ggml_tensor * att_norm = build_norm(inpL, layer->attn_norm, layer->attn_norm_b, LLM_NORM_RMS, il);
12211
  cb(att_norm, "attn_norm", il);
 
12217
  1
12218
  );
12219
 
12220
+ cur = build_rwkv6_time_mix(rs_inp, gf, att_norm, x_prev, ubatch, il);
12221
 
12222
  token_shift = ggml_view_3d(ctx0, att_norm, n_embd, 1, n_seqs, att_norm->nb[1], att_norm->nb[2], (n_seq_tokens-1)*n_embd*ggml_element_size(att_norm));
12223
  ggml_build_forward_expand(gf, build_rwkv_token_shift_store(token_shift, ubatch, il));
 
12225
  ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL);
12226
  cb(ffn_inp, "ffn_inp", il);
12227
 
12228
+ cur = ggml_reshape_2d(ctx0, cur, n_embd, n_tokens);
12229
+ ffn_inp = ggml_reshape_2d(ctx0, ffn_inp, n_embd, n_tokens);
12230
+
12231
+ if (il == n_layer - 1 && inp_out_ids) {
12232
+ cur = ggml_get_rows(ctx0, cur, inp_out_ids);
12233
+ ffn_inp = ggml_get_rows(ctx0, ffn_inp, inp_out_ids);
12234
  }
12235
 
12236
  // feed-forward network
 
12306
  }
12307
 
12308
  ggml_tensor * build_rwkv7_time_mix(
12309
+ llm_graph_input_rs * inp,
12310
  ggml_cgraph * gf,
12311
  ggml_tensor * cur,
12312
  ggml_tensor * x_prev,
 
12313
  ggml_tensor *& first_layer_value,
12314
  const llama_ubatch & ubatch,
12315
  int il) const {
12316
+ const auto * kv_state = static_cast<const llama_memory_recurrent_state *>(mstate);
12317
 
12318
  const auto n_tokens = ubatch.n_tokens;
12319
  const auto n_seqs = ubatch.n_seqs;
 
12392
  v = ggml_reshape_3d(ctx0, v, head_size, head_count, n_tokens);
12393
  a = ggml_reshape_3d(ctx0, a, head_size, head_count, n_tokens);
12394
 
12395
+ ggml_tensor * wkv_state = build_rs(
12396
+ inp, gf, kv_state->get_s_l(il),
12397
+ hparams.n_embd_s(), n_seqs);
12398
 
12399
  ggml_tensor * wkv_output = ggml_rwkv_wkv7(ctx0, r, w, k, v, ggml_neg(ctx0, kk), ggml_mul(ctx0, kk, a), wkv_state);
12400
  cur = ggml_view_1d(ctx0, wkv_output, n_embd * n_tokens, 0);
 
12407
  wkv_state,
12408
  ggml_view_1d(
12409
  ctx0,
12410
+ kv_state->get_s_l(il),
12411
+ hparams.n_embd_s() * n_seqs,
12412
+ hparams.n_embd_s() * kv_head * ggml_element_size(kv_state->get_s_l(il))
12413
  )
12414
  )
12415
  );
 
12450
  inpL = build_inp_embd(model.tok_embd);
12451
  inpL = build_norm(inpL, model.tok_norm, model.tok_norm_b, LLM_NORM, -1);
12452
 
12453
+ auto * rs_inp = build_rs_inp();
12454
 
12455
  const auto n_embd = hparams.n_embd;
12456
  const auto n_seq_tokens = ubatch.n_seq_tokens;
12457
  const auto n_seqs = ubatch.n_seqs;
12458
 
12459
+ ggml_tensor * inp_out_ids = build_inp_out_ids();
12460
+
12461
  for (int il = 0; il < n_layer; ++il) {
12462
  const llama_layer * layer = &model.layers[il];
12463
  inpL = ggml_reshape_3d(ctx0, inpL, n_embd, n_seq_tokens, n_seqs);
12464
 
12465
+ ggml_tensor * token_shift = build_rwkv_token_shift_load(rs_inp, gf, ubatch, il);
 
 
12466
 
12467
  ggml_tensor * att_shift = ggml_view_3d(ctx0, token_shift, n_embd, 1, n_seqs, token_shift->nb[1], token_shift->nb[2], 0);
12468
  ggml_tensor * ffn_shift = ggml_view_3d(ctx0, token_shift, n_embd, 1, n_seqs, token_shift->nb[1], token_shift->nb[2], n_embd * ggml_element_size(token_shift));
 
12477
  1
12478
  );
12479
 
12480
+ cur = build_rwkv7_time_mix(rs_inp, gf, att_norm, x_prev, v_first, ubatch, il);
12481
 
12482
  ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL);
12483
  cb(ffn_inp, "ffn_inp", il);
 
12499
  );
12500
  ggml_build_forward_expand(gf, build_rwkv_token_shift_store(token_shift, ubatch, il));
12501
 
12502
+ ffn_inp = ggml_reshape_2d(ctx0, ffn_inp, n_embd, n_tokens);
12503
+ ffn_norm = ggml_reshape_2d(ctx0, ffn_norm, n_embd, n_tokens);
12504
+ x_prev = ggml_reshape_2d(ctx0, x_prev, n_embd, n_tokens);
12505
+
12506
+ if (il == n_layer - 1 && inp_out_ids) {
12507
+ ffn_inp = ggml_get_rows(ctx0, ffn_inp, inp_out_ids);
12508
+ ffn_norm = ggml_get_rows(ctx0, ffn_norm, inp_out_ids);
12509
+ x_prev = ggml_get_rows(ctx0, x_prev, inp_out_ids);
12510
  }
12511
 
12512
  cur = build_rwkv7_channel_mix(layer, ffn_norm, x_prev, LLM_ARCH_RWKV7);
 
12537
 
12538
  struct llm_build_arwkv7 : public llm_build_rwkv7_base {
12539
  llm_build_arwkv7(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_build_rwkv7_base(model, params) {
12540
+ GGML_ASSERT(n_embd == hparams.n_embd_r());
12541
 
12542
  ggml_tensor * cur;
12543
  ggml_tensor * inpL;
 
12545
 
12546
  inpL = build_inp_embd(model.tok_embd);
12547
 
12548
+ auto * rs_inp = build_rs_inp();
12549
 
12550
  const auto n_embd = hparams.n_embd;
12551
  const auto n_seq_tokens = ubatch.n_seq_tokens;
12552
  const auto n_seqs = ubatch.n_seqs;
12553
 
12554
+ ggml_tensor * inp_out_ids = build_inp_out_ids();
12555
+
12556
  for (int il = 0; il < n_layer; ++il) {
12557
  const llama_layer * layer = &model.layers[il];
12558
  inpL = ggml_reshape_3d(ctx0, inpL, n_embd, n_seq_tokens, n_seqs);
12559
 
12560
+ ggml_tensor * token_shift = build_rwkv_token_shift_load(rs_inp, gf, ubatch, il);
 
 
12561
 
12562
  ggml_tensor * att_norm = build_norm(inpL, layer->attn_norm, layer->attn_norm_b, LLM_NORM_RMS, il);
12563
  cb(att_norm, "attn_norm", il);
 
12569
  1
12570
  );
12571
 
12572
+ cur = build_rwkv7_time_mix(rs_inp, gf, att_norm, x_prev, v_first, ubatch, il);
12573
 
12574
  token_shift = ggml_view_3d(ctx0, att_norm, n_embd, 1, n_seqs, att_norm->nb[1], att_norm->nb[2], (n_seq_tokens-1)*n_embd*ggml_element_size(att_norm));
12575
  ggml_build_forward_expand(gf, build_rwkv_token_shift_store(token_shift, ubatch, il));
 
12577
  ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL);
12578
  cb(ffn_inp, "ffn_inp", il);
12579
 
12580
+ cur = ggml_reshape_2d(ctx0, cur, n_embd, n_tokens);
12581
+ ffn_inp = ggml_reshape_2d(ctx0, ffn_inp, n_embd, n_tokens);
12582
+
12583
+ if (il == n_layer - 1 && inp_out_ids) {
12584
+ cur = ggml_get_rows(ctx0, cur, inp_out_ids);
12585
+ ffn_inp = ggml_get_rows(ctx0, ffn_inp, inp_out_ids);
12586
  }
12587
 
12588
  // feed-forward network
 
12651
  auto * inp_attn = build_attn_inp_kv_unified();
12652
 
12653
  const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale;
12654
+
12655
+ ggml_tensor * inp_out_ids = build_inp_out_ids();
12656
+
12657
  for (int il = 0; il < n_layer; ++il) {
12658
  ggml_tensor * inpSA = inpL;
12659
 
 
12716
  cb(cur, "attn_out", il);
12717
  }
12718
 
12719
+ if (il == n_layer - 1 && inp_out_ids) {
 
 
12720
  cur = ggml_get_rows(ctx0, cur, inp_out_ids);
12721
  inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
12722
  }
 
12835
 
12836
  auto * inp_attn = build_attn_inp_kv_unified();
12837
 
12838
+ ggml_tensor * inp_out_ids = build_inp_out_ids();
12839
+
12840
  for (int il = 0; il < n_layer; ++il) {
12841
  ggml_tensor * inpSA = inpL;
12842
 
 
12913
  cur = build_attn(inp_attn, gf,
12914
  model.layers[il].wo, nullptr,
12915
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
 
 
 
 
 
 
12916
  }
12917
 
12918
+ if (il == n_layer - 1 && inp_out_ids) {
 
 
12919
  cur = ggml_get_rows(ctx0, cur, inp_out_ids);
12920
  inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
12921
  }
12922
 
12923
+ if (hparams.swin_norm) {
12924
+ cur = build_norm(cur,
12925
+ model.layers[il].attn_norm, NULL,
12926
+ LLM_NORM_RMS, il);
12927
+ }
12928
+
12929
  ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
12930
  cb(ffn_inp, "ffn_inp", il);
12931
 
 
13166
 
13167
  auto * inp_attn = build_attn_inp_kv_unified();
13168
 
13169
+ ggml_tensor * inp_out_ids = build_inp_out_ids();
13170
+
13171
  for (int il = 0; il < n_layer; ++il) {
13172
  ggml_tensor * inpSA = inpL;
13173
 
 
13271
  q_states, k_states, v_states, nullptr, nullptr, kq_scale, il);
13272
  }
13273
 
13274
+ if (il == n_layer - 1 && inp_out_ids) {
 
 
13275
  cur = ggml_get_rows(ctx0, cur, inp_out_ids);
13276
  inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
13277
  }
 
13331
 
13332
  auto * inp_attn = build_attn_inp_kv_unified();
13333
 
13334
+ ggml_tensor * inp_out_ids = build_inp_out_ids();
13335
+
13336
  for (int il = 0; il < n_layer; ++il) {
13337
  ggml_tensor * inpSA = inpL;
13338
 
 
13394
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_rot)), il);
13395
  }
13396
 
13397
+ if (il == n_layer - 1 && inp_out_ids) {
 
 
13398
  cur = ggml_get_rows(ctx0, cur, inp_out_ids);
13399
  inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
13400
  }
 
13480
 
13481
  auto * inp_attn = build_attn_inp_kv_unified();
13482
 
13483
+ ggml_tensor * inp_out_ids = build_inp_out_ids();
13484
+
13485
  for (int il = 0; il < n_layer; ++il) {
13486
  ggml_tensor * inpSA = inpL;
13487
 
 
13534
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
13535
  }
13536
 
13537
+ if (il == n_layer - 1 && inp_out_ids) {
 
 
13538
  cur = ggml_get_rows(ctx0, cur, inp_out_ids);
13539
  inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
13540
  }
 
13632
 
13633
  const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale;
13634
 
13635
+ ggml_tensor * inp_out_ids = build_inp_out_ids();
13636
+
13637
  for (int il = 0; il < n_layer; ++il) {
13638
  ggml_tensor * inpSA = inpL;
13639
 
 
13696
  cb(cur, "attn_out", il);
13697
  }
13698
 
13699
+ if (il == n_layer - 1 && inp_out_ids) {
 
 
13700
  cur = ggml_get_rows(ctx0, cur, inp_out_ids);
13701
  inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
13702
  }
 
13752
  llama_memory_i * res;
13753
 
13754
  switch (arch) {
13755
+ // Models that need specific instantiation should be handled in the
13756
+ // switch statement
13757
  case LLM_ARCH_BERT:
13758
  case LLM_ARCH_JINA_BERT_V2:
13759
  case LLM_ARCH_NOMIC_BERT:
 
13763
  {
13764
  res = nullptr;
13765
  } break;
13766
+ // Models that need standard caching should rely on recurrent/hybrid
13767
+ // checks
13768
  default:
13769
  {
13770
+ if (llm_arch_is_recurrent(arch)) {
13771
+ res = new llama_memory_recurrent(
13772
  *this,
13773
  nullptr,
13774
+ GGML_TYPE_F32,
13775
+ GGML_TYPE_F32,
 
13776
  cparams.offload_kqv,
13777
+ std::max((uint32_t) 1, cparams.n_seq_max),
13778
+ cparams.n_seq_max);
13779
+ } else if (llm_arch_is_hybrid(arch)) {
13780
+ const auto padding = llama_kv_cache_unified::get_padding(cparams);
13781
+
13782
+ cparams.n_ctx = GGML_PAD(cparams.n_ctx, padding);
13783
+
13784
+ res = new llama_memory_hybrid(
13785
+ /* model */ *this,
13786
+ /* attn_type_k */ params.type_k,
13787
+ /* attn_type_v */ params.type_v,
13788
+ /* attn_v_trans */ !cparams.flash_attn,
13789
+ /* attn_kv_size */ cparams.n_ctx,
13790
+ /* attn_n_pad */ padding,
13791
+ /* attn_n_swa */ hparams.n_swa,
13792
+ /* attn_swa_type */ hparams.swa_type,
13793
+ /* recurrent_type_k */ GGML_TYPE_F32,
13794
+ /* recurrent_type_v */ GGML_TYPE_F32,
13795
+ /* recurrent_kv_size */ std::max((uint32_t) 1, cparams.n_seq_max),
13796
+ /* n_seq_max */ cparams.n_seq_max,
13797
+ /* offload */ cparams.offload_kqv);
13798
+ } else {
13799
+ const auto padding = llama_kv_cache_unified::get_padding(cparams);
13800
+
13801
+ cparams.n_ctx = GGML_PAD(cparams.n_ctx, padding);
13802
+
13803
+ LLAMA_LOG_DEBUG("%s: n_ctx = %u (padded)\n", __func__, cparams.n_ctx);
13804
+
13805
+ if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) {
13806
+ GGML_ASSERT(hparams.is_swa_any());
13807
+
13808
+ res = new llama_kv_cache_unified_iswa(
13809
+ *this,
13810
+ params.type_k,
13811
+ params.type_v,
13812
+ !cparams.flash_attn,
13813
+ cparams.offload_kqv,
13814
+ params.swa_full,
13815
+ cparams.n_ctx,
13816
+ cparams.n_seq_max,
13817
+ cparams.n_ubatch,
13818
+ padding);
13819
+ } else {
13820
+ GGML_ASSERT(!hparams.is_swa_any());
13821
+
13822
+ res = new llama_kv_cache_unified(
13823
+ *this,
13824
+ nullptr,
13825
+ params.type_k,
13826
+ params.type_v,
13827
+ !cparams.flash_attn,
13828
+ cparams.offload_kqv,
13829
+ cparams.n_ctx,
13830
+ cparams.n_seq_max,
13831
+ padding,
13832
+ hparams.n_swa,
13833
+ hparams.swa_type);
13834
+ }
13835
  }
13836
  }
13837
  }
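
create_memory() now dispatches on the architecture family instead of a hard-coded list of recurrent architectures: llm_arch_is_recurrent() selects llama_memory_recurrent, llm_arch_is_hybrid() selects llama_memory_hybrid, and everything else falls back to the unified (optionally SWA) KV cache. The shape of that dispatch, sketched as self-contained C++ with placeholder types instead of the real classes:

    #include <cstdio>
    #include <memory>

    // Placeholder types standing in for the real memory implementations.
    struct memory_i         { virtual ~memory_i() = default; virtual const char * name() const = 0; };
    struct memory_recurrent : memory_i { const char * name() const override { return "recurrent"; } };
    struct memory_hybrid    : memory_i { const char * name() const override { return "hybrid"; } };
    struct kv_cache_unified : memory_i { const char * name() const override { return "unified kv cache"; } };

    enum class arch { llama, mamba, rwkv7, hybrid_example };

    static bool arch_is_recurrent(arch a) { return a == arch::mamba || a == arch::rwkv7; }
    static bool arch_is_hybrid   (arch a) { return a == arch::hybrid_example; }

    // Mirrors the new default branch: family checks first, unified cache as fallback.
    static std::unique_ptr<memory_i> create_memory(arch a) {
        if (arch_is_recurrent(a)) { return std::make_unique<memory_recurrent>(); }
        if (arch_is_hybrid(a))    { return std::make_unique<memory_hybrid>();    }
        return std::make_unique<kv_cache_unified>();
    }

    int main() {
        printf("mamba -> %s\n", create_memory(arch::mamba)->name());
        printf("llama -> %s\n", create_memory(arch::llama)->name());
        return 0;
    }

The actual constructors additionally take the padding, SWA, type and offload parameters shown in the hunk above.
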
 
14411
  }
14412
 
14413
  bool llama_model_is_recurrent(const llama_model * model) {
14414
+ return llm_arch_is_recurrent(model->arch);
 
 
 
 
 
 
 
14415
  }
14416
 
14417
  const std::vector<std::pair<std::string, ggml_tensor *>> & llama_internal_get_tensor_map(const llama_model * model) {
examples/talk-llama/llama-vocab.cpp CHANGED
@@ -1269,6 +1269,7 @@ struct llama_vocab::impl {
1269
  bool add_space_prefix = false;
1270
  bool add_bos = false;
1271
  bool add_eos = false;
 
1272
  bool ignore_merges = false;
1273
  bool clean_spaces = false; // clean_up_tokenization_spaces
1274
  bool remove_extra_whitespaces = false;
@@ -1421,6 +1422,8 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
1421
  special_sep_id = 102;
1422
  special_pad_id = 0;
1423
  special_mask_id = 103;
 
 
1424
  } else if (tokenizer_model == "gpt2") {
1425
  type = LLAMA_VOCAB_TYPE_BPE;
1426
 
@@ -1550,12 +1553,15 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
1550
  tokenizer_pre == "jina-es" ||
1551
  tokenizer_pre == "jina-de" ||
1552
  tokenizer_pre == "gigachat" ||
1553
- tokenizer_pre == "jina-v1-en" ||
1554
  tokenizer_pre == "jina-v2-es" ||
1555
- tokenizer_pre == "jina-v2-de" ||
 
 
 
1556
  tokenizer_pre == "jina-v2-code" ||
1557
  tokenizer_pre == "roberta-bpe") {
1558
  pre_type = LLAMA_VOCAB_PRE_TYPE_GPT2;
 
1559
  } else if (
1560
  tokenizer_pre == "refact") {
1561
  pre_type = LLAMA_VOCAB_PRE_TYPE_REFACT;
@@ -1665,6 +1671,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
1665
  clean_spaces = true;
1666
  add_bos = true;
1667
  add_eos = false;
 
1668
  } else if (type == LLAMA_VOCAB_TYPE_UGM) {
1669
  pre_type = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
1670
  add_bos = false;
@@ -1801,7 +1808,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
1801
  }
1802
  }
1803
 
1804
- // Handle add_bos and add_eos
1805
  {
1806
  bool temp = true;
1807
 
@@ -1811,6 +1818,9 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
1811
  if (ml.get_key(LLM_KV_TOKENIZER_ADD_EOS, temp, false)) {
1812
  add_eos = temp;
1813
  }
 
 
 
1814
  }
1815
 
1816
  // auto-detect special tokens by text
@@ -2060,9 +2070,9 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
2060
  //NOTE: Per token attributes are missing from the GGUF file.
2061
  //TODO: Extract attributes from GGUF file.
2062
  {
2063
- auto _contains_any = [] (const std::string & str, const std::vector<std::string> & substrs) -> bool {
2064
  for (const auto & substr : substrs) {
2065
- if (str.find(substr) < std::string::npos) {
2066
  return true;
2067
  }
2068
  }
@@ -3000,6 +3010,10 @@ bool llama_vocab::get_add_eos() const {
3000
  return pimpl->add_eos;
3001
  }
3002
 
 
 
 
 
3003
  bool llama_vocab::get_ignore_merges() const {
3004
  return pimpl->ignore_merges;
3005
  }
@@ -3060,6 +3074,11 @@ int32_t llama_vocab::tokenize(
3060
  bool add_special,
3061
  bool parse_special) const {
3062
  auto res = tokenize(std::string(text, text_len), add_special, parse_special);
 
 
 
 
 
3063
  if (n_tokens_max < (int) res.size()) {
3064
  // LLAMA_LOG_ERROR("%s: too many tokens\n", __func__);
3065
  return -((int) res.size());
@@ -3191,6 +3210,10 @@ bool llama_vocab_get_add_eos(const struct llama_vocab * vocab) {
3191
  return vocab->get_add_eos();
3192
  }
3193
 
 
 
 
 
3194
  llama_token llama_vocab_fim_pre(const struct llama_vocab * vocab) {
3195
  return vocab->token_fim_pre();
3196
  }
 
1269
  bool add_space_prefix = false;
1270
  bool add_bos = false;
1271
  bool add_eos = false;
1272
+ bool add_sep = false;
1273
  bool ignore_merges = false;
1274
  bool clean_spaces = false; // clean_up_tokenization_spaces
1275
  bool remove_extra_whitespaces = false;
 
1422
  special_sep_id = 102;
1423
  special_pad_id = 0;
1424
  special_mask_id = 103;
1425
+
1426
+ add_sep = true;
1427
  } else if (tokenizer_model == "gpt2") {
1428
  type = LLAMA_VOCAB_TYPE_BPE;
1429
 
 
1553
  tokenizer_pre == "jina-es" ||
1554
  tokenizer_pre == "jina-de" ||
1555
  tokenizer_pre == "gigachat" ||
 
1556
  tokenizer_pre == "jina-v2-es" ||
1557
+ tokenizer_pre == "jina-v2-de") {
1558
+ pre_type = LLAMA_VOCAB_PRE_TYPE_GPT2;
1559
+ } else if (
1560
+ tokenizer_pre == "jina-v1-en" ||
1561
  tokenizer_pre == "jina-v2-code" ||
1562
  tokenizer_pre == "roberta-bpe") {
1563
  pre_type = LLAMA_VOCAB_PRE_TYPE_GPT2;
1564
+ add_sep = true;
1565
  } else if (
1566
  tokenizer_pre == "refact") {
1567
  pre_type = LLAMA_VOCAB_PRE_TYPE_REFACT;
 
1671
  clean_spaces = true;
1672
  add_bos = true;
1673
  add_eos = false;
1674
+ add_sep = true;
1675
  } else if (type == LLAMA_VOCAB_TYPE_UGM) {
1676
  pre_type = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
1677
  add_bos = false;
 
1808
  }
1809
  }
1810
 
1811
+ // Handle add_bos, add_eos and add_sep
1812
  {
1813
  bool temp = true;
1814
 
 
1818
  if (ml.get_key(LLM_KV_TOKENIZER_ADD_EOS, temp, false)) {
1819
  add_eos = temp;
1820
  }
1821
+ if (ml.get_key(LLM_KV_TOKENIZER_ADD_SEP, temp, false)) {
1822
+ add_sep = temp;
1823
+ }
1824
  }
1825
 
1826
  // auto-detect special tokens by text
 
2070
  //NOTE: Per token attributes are missing from the GGUF file.
2071
  //TODO: Extract attributes from GGUF file.
2072
  {
2073
+ auto _contains_any = [] (const std::string & str, const std::vector<std::string_view> & substrs) -> bool {
2074
  for (const auto & substr : substrs) {
2075
+ if (str.find(substr) != std::string::npos) {
2076
  return true;
2077
  }
2078
  }
 
3010
  return pimpl->add_eos;
3011
  }
3012
 
3013
+ bool llama_vocab::get_add_sep() const {
3014
+ return pimpl->add_sep;
3015
+ }
3016
+
3017
  bool llama_vocab::get_ignore_merges() const {
3018
  return pimpl->ignore_merges;
3019
  }
 
3074
  bool add_special,
3075
  bool parse_special) const {
3076
  auto res = tokenize(std::string(text, text_len), add_special, parse_special);
3077
+ if (res.size() >= static_cast<size_t>(std::numeric_limits<int32_t>::max())) {
3078
+ LLAMA_LOG_ERROR("%s: tokenization result size %zu exceeds int32_t limit\n", __func__, res.size());
3079
+ return std::numeric_limits<int32_t>::min();
3080
+ }
3081
+
3082
  if (n_tokens_max < (int) res.size()) {
3083
  // LLAMA_LOG_ERROR("%s: too many tokens\n", __func__);
3084
  return -((int) res.size());
 
3210
  return vocab->get_add_eos();
3211
  }
3212
 
3213
+ bool llama_vocab_get_add_sep(const struct llama_vocab * vocab) {
3214
+ return vocab->get_add_sep();
3215
+ }
3216
+
3217
  llama_token llama_vocab_fim_pre(const struct llama_vocab * vocab) {
3218
  return vocab->token_fim_pre();
3219
  }
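
The add_sep flag threaded through the vocab loader above (set for the BERT-style branch with SEP id 102, for the jina-v1-en/jina-v2-code/roberta-bpe pretokenizers, and via the tokenizer.ggml.add_sep_token key) is exposed to clients through the new llama_vocab_get_add_sep(). A hedged usage sketch, assuming a valid llama_vocab pointer, two already-tokenized fragments, and the existing llama_vocab_sep() accessor for the separator token:

    #include "llama.h"

    #include <vector>

    // Sketch only: joins two token sequences, inserting the SEP token only when
    // the vocabulary is configured to expect one (add_sep == true).
    static std::vector<llama_token> join_segments(const llama_vocab * vocab,
                                                  const std::vector<llama_token> & a,
                                                  const std::vector<llama_token> & b) {
        std::vector<llama_token> out = a;
        if (llama_vocab_get_add_sep(vocab)) {
            out.push_back(llama_vocab_sep(vocab));
        }
        out.insert(out.end(), b.begin(), b.end());
        return out;
    }
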
examples/talk-llama/llama-vocab.h CHANGED
@@ -74,6 +74,7 @@ struct llama_vocab {
74
  bool get_add_space_prefix () const;
75
  bool get_add_bos () const;
76
  bool get_add_eos () const;
 
77
  bool get_ignore_merges () const;
78
  bool get_clean_spaces () const;
79
  bool get_remove_extra_whitespaces () const;
 
74
  bool get_add_space_prefix () const;
75
  bool get_add_bos () const;
76
  bool get_add_eos () const;
77
+ bool get_add_sep () const;
78
  bool get_ignore_merges () const;
79
  bool get_clean_spaces () const;
80
  bool get_remove_extra_whitespaces () const;
examples/talk-llama/llama.h CHANGED
@@ -1044,6 +1044,7 @@ extern "C" {
1044
 
1045
  LLAMA_API bool llama_vocab_get_add_bos(const struct llama_vocab * vocab);
1046
  LLAMA_API bool llama_vocab_get_add_eos(const struct llama_vocab * vocab);
 
1047
 
1048
  LLAMA_API llama_token llama_vocab_fim_pre(const struct llama_vocab * vocab);
1049
  LLAMA_API llama_token llama_vocab_fim_suf(const struct llama_vocab * vocab);
@@ -1087,6 +1088,7 @@ extern "C" {
1087
  /// @param tokens The tokens pointer must be large enough to hold the resulting tokens.
1088
  /// @return Returns the number of tokens on success, no more than n_tokens_max
1089
  /// @return Returns a negative number on failure - the number of tokens that would have been returned
 
1090
  /// @param add_special Allow to add BOS and EOS tokens if model is configured to do so.
1091
  /// @param parse_special Allow tokenizing special and/or control tokens which otherwise are not exposed and treated
1092
  /// as plaintext. Does not insert a leading space.
 
1044
 
1045
  LLAMA_API bool llama_vocab_get_add_bos(const struct llama_vocab * vocab);
1046
  LLAMA_API bool llama_vocab_get_add_eos(const struct llama_vocab * vocab);
1047
+ LLAMA_API bool llama_vocab_get_add_sep(const struct llama_vocab * vocab);
1048
 
1049
  LLAMA_API llama_token llama_vocab_fim_pre(const struct llama_vocab * vocab);
1050
  LLAMA_API llama_token llama_vocab_fim_suf(const struct llama_vocab * vocab);
 
1088
  /// @param tokens The tokens pointer must be large enough to hold the resulting tokens.
1089
  /// @return Returns the number of tokens on success, no more than n_tokens_max
1090
  /// @return Returns a negative number on failure - the number of tokens that would have been returned
1091
+ /// @return Returns INT32_MIN on overflow (e.g., tokenization result size exceeds int32_t limit)
1092
  /// @param add_special Allow to add BOS and EOS tokens if model is configured to do so.
1093
  /// @param parse_special Allow tokenizing special and/or control tokens which otherwise are not exposed and treated
1094
  /// as plaintext. Does not insert a leading space.
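
With the overflow guard added to llama_vocab::tokenize() and documented above, llama_tokenize() now has two distinct failure shapes: the usual negative count when n_tokens_max is too small, and INT32_MIN when the token count itself cannot be represented in an int32_t. A hedged caller sketch that handles both; the initial buffer size is just a heuristic chosen for the example, not part of the API:

    #include "llama.h"

    #include <climits>
    #include <string>
    #include <vector>

    // Sketch only: `vocab` is assumed valid; add_special/parse_special are example choices.
    static std::vector<llama_token> tokenize_checked(const llama_vocab * vocab, const std::string & text) {
        std::vector<llama_token> tokens(text.size() + 2); // heuristic first guess
        int32_t n = llama_tokenize(vocab, text.c_str(), (int32_t) text.size(),
                                   tokens.data(), (int32_t) tokens.size(),
                                   /*add_special*/ true, /*parse_special*/ false);
        if (n == INT32_MIN) {
            return {}; // overflow: the token count does not fit in int32_t
        }
        if (n < 0) {
            tokens.resize((size_t) -n); // buffer too small: -n is the required size
            n = llama_tokenize(vocab, text.c_str(), (int32_t) text.size(),
                               tokens.data(), (int32_t) tokens.size(),
                               /*add_special*/ true, /*parse_special*/ false);
        }
        tokens.resize(n > 0 ? (size_t) n : 0);
        return tokens;
    }
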
examples/talk-llama/unicode.cpp CHANGED
@@ -204,12 +204,17 @@ static inline std::wstring unicode_wstring_from_utf8(const std::string & s) {
204
  // disable C++17 deprecation warning for std::codecvt_utf8
205
  # pragma clang diagnostic push
206
  # pragma clang diagnostic ignored "-Wdeprecated-declarations"
 
 
 
207
  #endif
208
 
209
  std::wstring_convert<std::codecvt_utf8<wchar_t>> conv;
210
 
211
  #if defined(__clang__)
212
  # pragma clang diagnostic pop
 
 
213
  #endif
214
 
215
  return conv.from_bytes(s);
 
204
  // disable C++17 deprecation warning for std::codecvt_utf8
205
  # pragma clang diagnostic push
206
  # pragma clang diagnostic ignored "-Wdeprecated-declarations"
207
+ #elif defined(__GNUC__)
208
+ # pragma GCC diagnostic push
209
+ # pragma GCC diagnostic ignored "-Wdeprecated-declarations"
210
  #endif
211
 
212
  std::wstring_convert<std::codecvt_utf8<wchar_t>> conv;
213
 
214
  #if defined(__clang__)
215
  # pragma clang diagnostic pop
216
+ #elif defined(__GNUC__)
217
+ # pragma GCC diagnostic pop
218
  #endif
219
 
220
  return conv.from_bytes(s);
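
For reference, the merged body of unicode_wstring_from_utf8 from the hunk above, restated as one self-contained function (illustrative only, not a drop-in replacement): the deprecation warning for std::codecvt_utf8, which is deprecated since C++17 but still relied on here, is now silenced on GCC as well as clang.

    #include <codecvt>
    #include <locale>
    #include <string>

    // Illustrative restatement of the pattern above: suppress the C++17 deprecation
    // warning locally on both clang and GCC around the codecvt-based conversion.
    static std::wstring wstring_from_utf8(const std::string & s) {
    #if defined(__clang__)
    #    pragma clang diagnostic push
    #    pragma clang diagnostic ignored "-Wdeprecated-declarations"
    #elif defined(__GNUC__)
    #    pragma GCC diagnostic push
    #    pragma GCC diagnostic ignored "-Wdeprecated-declarations"
    #endif
        std::wstring_convert<std::codecvt_utf8<wchar_t>> conv;
    #if defined(__clang__)
    #    pragma clang diagnostic pop
    #elif defined(__GNUC__)
    #    pragma GCC diagnostic pop
    #endif
        return conv.from_bytes(s);
    }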