//
// llama_memory_recurrent
//

llama_memory_recurrent::llama_memory_recurrent(
        const llama_model & model,
        layer_filter_cb && filter,
        ggml_type type_r,
        ggml_type type_s,
        bool offload,
        uint32_t mem_size,
        uint32_t n_seq_max) : hparams(model.hparams), n_seq_max(n_seq_max) {
    const int32_t n_layer = hparams.n_layer;

    head = 0;
    size = mem_size;
    used = 0;

    cells.clear();
    cells.resize(mem_size);

    // create a context for each buffer type
    std::map<ggml_backend_buffer_type_t, ggml_context *> ctx_map;
    auto ctx_for_buft = [&](ggml_backend_buffer_type_t buft) -> ggml_context * {
        auto it = ctx_map.find(buft);
        if (it == ctx_map.end()) {
            ggml_init_params params = {
                /*.mem_size   =*/ size_t(2u*n_layer*ggml_tensor_overhead()),
                /*.mem_buffer =*/ NULL,
                /*.no_alloc   =*/ true,
            };

            ggml_context * ctx = ggml_init(params);
            if (!ctx) {
                return nullptr;
            }

            ctx_map[buft] = ctx;
            ctxs.emplace_back(ctx);

            return ctx;
        }

        return it->second;
    };

    r_l.resize(n_layer);
    s_l.resize(n_layer);

    for (int i = 0; i < n_layer; i++) {
        if (filter && !filter(i)) {
            LLAMA_LOG_DEBUG("%s: layer %3d: skipped\n", __func__, i);
            continue;
        }

        const char * dev_name = "CPU";

        ggml_backend_buffer_type_t buft = ggml_backend_cpu_buffer_type();

        if (offload) {
            auto * dev = model.dev_layer(i);
            buft = ggml_backend_dev_buffer_type(dev);

            dev_name = ggml_backend_dev_name(dev);
        }

        LLAMA_LOG_DEBUG("%s, layer %3d: dev = %s\n", __func__, i, dev_name);

        ggml_context * ctx = ctx_for_buft(buft);
        if (!ctx) {
            throw std::runtime_error("failed to create ggml context for rs cache");
        }

        ggml_tensor * r = ggml_new_tensor_1d(ctx, type_r, hparams.n_embd_r()*mem_size);
        ggml_tensor * s = ggml_new_tensor_1d(ctx, type_s, hparams.n_embd_s()*mem_size);
        ggml_format_name(r, "cache_r_l%d", i);
        ggml_format_name(s, "cache_s_l%d", i);
        r_l[i] = r;
        s_l[i] = s;
    }

    // allocate tensors and initialize the buffers to avoid NaNs in the padding
    for (auto it : ctx_map) {
        auto * buft = it.first;
        auto * ctx  = it.second;

        ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft);
        if (!buf) {
            throw std::runtime_error("failed to allocate buffer for rs cache");
        }

        ggml_backend_buffer_clear(buf, 0);
        LLAMA_LOG_INFO("%s: %10s RS buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf), ggml_backend_buffer_get_size(buf)/1024.0/1024.0);
        bufs.emplace_back(buf);
    }

    {
        const size_t memory_size_r = size_r_bytes();
        const size_t memory_size_s = size_s_bytes();

        LLAMA_LOG_INFO("%s: size = %7.2f MiB (%6u cells, %3d layers, %2u seqs), R (%s): %7.2f MiB, S (%s): %7.2f MiB\n", __func__,
                (float)(memory_size_r + memory_size_s) / (1024.0f * 1024.0f), mem_size, n_layer, n_seq_max,
                ggml_type_name(type_r), (float)memory_size_r / (1024.0f * 1024.0f),
                ggml_type_name(type_s), (float)memory_size_s / (1024.0f * 1024.0f));
    }
}
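
// clear() resets the cell metadata (pos, seq_id, src, tail); when data == true it
// also zeroes the backing R/S buffers so no stale state survives the reset.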
void llama_memory_recurrent::clear(bool data) {
    for (int32_t i = 0; i < (int32_t) size; ++i) {
        cells[i].pos = -1;
        cells[i].seq_id.clear();
        cells[i].src = -1;
        cells[i].tail = -1;
    }

    head = 0;
    used = 0;

    if (data) {
        for (auto & buf : bufs) {
            ggml_backend_buffer_clear(buf.get(), 0);
        }
    }
}
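
// A recurrent state cannot be partially rewound: a cell only holds the state after
// the last token of its sequence. seq_rm() therefore rejects ranges that intersect
// a tail cell without covering it; e.g. seq_rm(seq, -1, -1) erases the whole
// sequence, while a mid-sequence range returns false.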
bool llama_memory_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
    uint32_t new_head = size;

    if (p0 < 0) {
        p0 = 0;
    }

    if (p1 < 0) {
        p1 = std::numeric_limits<llama_pos>::max();
    }

    // models like Mamba or RWKV can't have a state partially erased
    if (seq_id >= (int64_t) size) {
        // could be fatal
        return false;
    }
    if (0 <= seq_id) {
        int32_t & tail_id = cells[seq_id].tail;
        if (tail_id >= 0) {
            const auto & cell = cells[tail_id];
            // partial intersection is invalid
            if ((0 < p0 && p0 <= cell.pos) || (0 < p1 && p1 <= cell.pos)) {
                return false;
            }
            // invalidate tails which will be cleared
            if (p0 <= cell.pos && cell.pos < p1) {
                tail_id = -1;
            }
        }
    } else {
        // when seq_id is negative, the range should include everything or nothing
        if (p0 != p1 && (p0 != 0 || p1 != std::numeric_limits<llama_pos>::max())) {
            return false;
        }
    }

    for (uint32_t i = 0; i < size; ++i) {
        if (cells[i].pos >= p0 && cells[i].pos < p1) {
            if (seq_id < 0) {
                cells[i].seq_id.clear();
            } else if (cells[i].has_seq_id(seq_id)) {
                cells[i].seq_id.erase(seq_id);
            } else {
                continue;
            }
            if (cells[i].is_empty()) {
                // keep count of the number of used cells
                if (cells[i].pos >= 0) {
                    used--;
                }
                cells[i].pos = -1;
                cells[i].src = -1;
                if (new_head == size) {
                    new_head = i;
                }
            }
        }
    }

    // If we freed up a slot, set head to it so searching can start there.
    if (new_head != size && new_head < head) {
        head = new_head;
    }

    return true;
}
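
// seq_cp() only touches metadata: the destination sequence starts sharing the
// source's tail cell. The state data itself is duplicated lazily, when a later
// find_slot() gives each sequence a cell of its own (sourced via the src index).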
void llama_memory_recurrent::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
    if (seq_id_src == seq_id_dst) {
        return;
    }

    if (p0 < 0) {
        p0 = 0;
    }

    if (p1 < 0) {
        p1 = std::numeric_limits<llama_pos>::max();
    }

    if ((uint32_t) seq_id_dst < size && (uint32_t) seq_id_src < size) {
        auto & tail_src = cells[seq_id_src];
        auto & tail_dst = cells[seq_id_dst];
        if (tail_dst.tail >= 0) {
            // clear destination seq_id if it wasn't empty
            auto & cell_dst = cells[tail_dst.tail];

            cell_dst.seq_id.erase(seq_id_dst);
            tail_dst.tail = -1;
            if (cell_dst.seq_id.empty()) {
                cell_dst.pos = -1;
                cell_dst.src = -1;
                used -= 1;
            }
        }
        if (tail_src.tail >= 0) {
            auto & cell_src = cells[tail_src.tail];

            cell_src.seq_id.insert(seq_id_dst);
            tail_dst.tail = tail_src.tail;
        }
    }
}
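
// seq_keep() drops every sequence except seq_id: cells not tagged with it are freed,
// and the surviving cells are re-tagged so they belong to seq_id exclusively.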
void llama_memory_recurrent::seq_keep(llama_seq_id seq_id) {
    uint32_t new_head = size;

    for (uint32_t i = 0; i < size; ++i) {
        if ((llama_seq_id) i != seq_id) {
            cells[i].tail = -1;
        }

        if (!cells[i].has_seq_id(seq_id)) {
            if (cells[i].pos >= 0) {
                used--;
            }

            cells[i].pos = -1;
            cells[i].src = -1;
            cells[i].seq_id.clear();

            if (new_head == size) {
                new_head = i;
            }
        } else {
            cells[i].seq_id.clear();
            cells[i].seq_id.insert(seq_id);
        }
    }

    // If we freed up a slot, set head to it so searching can start there.
    if (new_head != size && new_head < head) {
        head = new_head;
    }
}
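
// Position shifts are metadata-only for recurrent models: the stored state does not
// depend on absolute positions, so only the tail cell's pos is adjusted.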
void llama_memory_recurrent::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) {
    if (shift == 0) {
        return;
    }

    if (p0 < 0) {
        p0 = 0;
    }

    if (p1 < 0) {
        p1 = std::numeric_limits<llama_pos>::max();
    }

    // If there is no range then return early to avoid looping over the cache.
    if (p0 == p1) {
        return;
    }

    // for Mamba-like or RWKV models, only the pos needs to be shifted
    if (0 <= seq_id && seq_id < (int64_t) size) {
        const int32_t tail_id = cells[seq_id].tail;
        if (tail_id >= 0) {
            auto & cell = cells[tail_id];
            if (cell.has_seq_id(seq_id) && p0 <= cell.pos && cell.pos < p1) {
                cell.pos += shift;
            }
        }
    }
}
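
// Same reasoning as seq_add(): dividing positions only needs to touch the tail
// cell's pos.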
void llama_memory_recurrent::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
    if (d == 1) {
        return;
    }

    if (p0 < 0) {
        p0 = 0;
    }

    if (p1 < 0) {
        p1 = std::numeric_limits<llama_pos>::max();
    }

    // If there is no range then return early to avoid looping over the cache.
    if (p0 == p1) {
        return;
    }

    // for Mamba-like or RWKV models, only the pos needs to be changed
    if (0 <= seq_id && seq_id < (int64_t) size) {
        const int32_t tail_id = cells[seq_id].tail;
        if (tail_id >= 0) {
            auto & cell = cells[tail_id];
            if (cell.has_seq_id(seq_id) && p0 <= cell.pos && cell.pos < p1) {
                cell.pos /= d;
            }
        }
    }
}

llama_pos llama_memory_recurrent::seq_pos_min(llama_seq_id seq_id) const {
    llama_pos result = std::numeric_limits<llama_pos>::max();

    for (uint32_t i = 0; i < size; ++i) {
        if (cells[i].has_seq_id(seq_id)) {
            result = std::min(result, cells[i].pos);
        }
    }

    if (result == std::numeric_limits<llama_pos>::max()) {
        result = -1;
    }

    return result;
}

llama_pos llama_memory_recurrent::seq_pos_max(llama_seq_id seq_id) const {
    llama_pos result = -1;

    for (uint32_t i = 0; i < size; ++i) {
        if (cells[i].has_seq_id(seq_id)) {
            result = std::max(result, cells[i].pos);
        }
    }

    return result;
}
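
// init_batch() splits the incoming batch into ubatches (by sequence when all token
// embeddings are output, otherwise into equal-sized splits), then dry-runs the slot
// assignment via prepare(); any failure yields a FAILED_PREPARE context instead.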
llama_memory_context_ptr llama_memory_recurrent::init_batch(llama_batch_allocr & balloc, uint32_t n_ubatch, bool embd_all) {
    do {
        balloc.split_reset();

        std::vector<llama_ubatch> ubatches;
        while (true) {
            llama_ubatch ubatch;

            if (embd_all) {
                // if all tokens are output, split by sequence
                ubatch = balloc.split_seq(n_ubatch);
            } else {
                ubatch = balloc.split_equal(n_ubatch, false);
            }

            if (ubatch.n_tokens == 0) {
                break;
            }

            ubatches.push_back(std::move(ubatch)); // NOLINT
        }

        if (balloc.get_n_used() < balloc.get_n_tokens()) {
            // failed to find a suitable split
            break;
        }

        if (!prepare(ubatches)) {
            break;
        }

        return std::make_unique<llama_memory_recurrent_context>(this, std::move(ubatches));
    } while (false);

    return std::make_unique<llama_memory_recurrent_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
}

llama_memory_context_ptr llama_memory_recurrent::init_full() {
    return std::make_unique<llama_memory_recurrent_context>(this);
}

llama_memory_context_ptr llama_memory_recurrent::init_update(llama_context * lctx, bool optimize) {
    GGML_UNUSED(lctx);
    GGML_UNUSED(optimize);

    return std::make_unique<llama_memory_recurrent_context>(LLAMA_MEMORY_STATUS_NO_UPDATE);
}

bool llama_memory_recurrent::prepare(const std::vector<llama_ubatch> & ubatches) {
    // simply remember the full state because it is very small for this type of cache
    // TODO: optimize
    auto org_cells = cells;
    auto org_used = used;
    auto org_head = head;

    bool success = true;

    for (const auto & ubatch : ubatches) {
        if (!find_slot(ubatch)) {
            success = false;
            break;
        }
    }

    // restore the original state
    cells = std::move(org_cells);
    used = org_used;
    head = org_head;

    return success;
}
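
// find_slot() assigns one cell per sequence of the ubatch. Roughly:
//   1. break up tail cells that became shared between multiple seq_ids
//   2. give every sequence a cell it exclusively owns (reusing its tail if possible,
//      otherwise claiming an empty cell and staging a copy of the old tail)
//   3. swap cells around so the used cells form one contiguous block [min, max]
//   4. update pos/seq_id for the block and pick rs_z, a cell with no incoming src
//      references, to act as the zeroed source state
// head and n are then set so the graph can address the block as cells head..head+n-1.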
bool llama_memory_recurrent::find_slot(const llama_ubatch & ubatch) {
    const uint32_t n_seq_tokens = ubatch.n_seq_tokens;
    const uint32_t n_seqs = ubatch.n_seqs;

    // if we have enough unused cells before the current head ->
    //   better to start searching from the beginning of the cache, hoping to fill it
    if (head > used + 2*n_seqs) {
        head = 0;
    }

    // For recurrent state architectures (like Mamba or RWKV),
    // each cache cell can store the state for a whole sequence.
    // A slot should always be contiguous.

    // can only process batches with an equal number of new tokens in each sequence
    GGML_ASSERT(ubatch.equal_seqs());

    int32_t min = size - 1;
    int32_t max = 0;

    // everything should fit if all seq_ids are smaller than the max
    for (uint32_t s = 0; s < n_seqs; ++s) {
        const uint32_t i = s*n_seq_tokens; // first token of sequence set s
        const uint32_t n_seq_id = ubatch.n_seq_id[i];

        for (uint32_t j = 0; j < n_seq_id; ++j) {
            const llama_seq_id seq_id = ubatch.seq_id[i][j];

            if (seq_id < 0 || (uint32_t) seq_id >= size) {
                // too big seq_id
                // TODO: would it be possible to resize the cache instead?
                LLAMA_LOG_ERROR("%s: seq_id=%d >= n_seq_max=%u. Try using a bigger --parallel value\n", __func__, seq_id, n_seq_max);
                return false;
            }
            if (j > 0) {
                auto & seq = cells[seq_id];
                if (seq.tail >= 0) {
                    auto & cell = cells[seq.tail];
                    // clear cells from seq_ids that become shared
                    // (should not normally happen, but let's handle it anyway)
                    cell.seq_id.erase(seq_id);
                    seq.tail = -1;
                    if (cell.seq_id.empty()) {
                        cell.pos = -1;
                        cell.src = -1;
                        used -= 1;
                    }
                }
            }
        }
    }

    {
        std::vector<int32_t> tails_verif;
        tails_verif.assign(size, -1);
        for (uint32_t i = 0; i < size; ++i) {
            auto & cell = cells[i];
            for (llama_seq_id seq_id : cell.seq_id) {
                if (tails_verif[seq_id] != -1) {
                    LLAMA_LOG_ERROR("%s: duplicate tail for seq_id %d in cell %d and %d\n", __func__, seq_id, i, tails_verif[seq_id]);
                }
                tails_verif[seq_id] = i;
            }
        }
        for (uint32_t i = 0; i < size; ++i) {
            if (tails_verif[i] != cells[i].tail) {
                LLAMA_LOG_ERROR("%s: wrong tail for seq_id %d, (%d instead of %d)\n", __func__, i, cells[i].tail, tails_verif[i]);
            }
        }
    }

    // find next empty cell
    uint32_t next_empty_cell = head;

    for (uint32_t i = 0; i < size; ++i) {
        if (next_empty_cell >= size) { next_empty_cell -= size; }
        auto & cell = cells[next_empty_cell];
        if (cell.is_empty()) { break; }
        next_empty_cell += 1;
    }
    // find usable cell range
    for (uint32_t s = 0; s < n_seqs; ++s) {
        const uint32_t i = s*n_seq_tokens;
        const llama_seq_id seq_id = ubatch.seq_id[i][0];
        auto & seq_meta = cells[seq_id];
        bool has_cell = false;
        if (seq_meta.tail >= 0) {
            auto & cell = cells[seq_meta.tail];
            GGML_ASSERT(cell.has_seq_id(seq_id));
            // does this seq_id "own" the cell?
            if (cell.seq_id.size() == 1) { has_cell = true; }
        }
        if (!has_cell) {
            auto & empty_cell = cells[next_empty_cell];
            GGML_ASSERT(empty_cell.is_empty());
            // copy old tail into the empty cell
            if (seq_meta.tail >= 0) {
                auto & orig_cell = cells[seq_meta.tail];
                empty_cell.pos = orig_cell.pos;
                empty_cell.src = orig_cell.src;
                orig_cell.seq_id.erase(seq_id);
                empty_cell.seq_id.insert(seq_id); // will be overwritten
                GGML_ASSERT(!orig_cell.is_empty()); // has at least one remaining seq_id
            }
            seq_meta.tail = next_empty_cell;
            // find next empty cell
            if (s + 1 < n_seqs) {
                for (uint32_t j = 0; j < size; ++j) {
                    next_empty_cell += 1;
                    if (next_empty_cell >= size) { next_empty_cell -= size; }
                    auto & cell = cells[next_empty_cell];
                    if (cell.is_empty()) { break; }
                }
            }
        }
        if (min > seq_meta.tail) { min = seq_meta.tail; }
        if (max < seq_meta.tail) { max = seq_meta.tail; }
    }

    // gather and re-order
    for (uint32_t s = 0; s < n_seqs; ++s) {
        const uint32_t i = s*n_seq_tokens;
        const int32_t dst_id = s + min;
        const int32_t src_id = cells[ubatch.seq_id[i][0]].tail;
        if (dst_id != src_id) {
            auto & dst_cell = cells[dst_id];
            auto & src_cell = cells[src_id];
            std::swap(dst_cell.pos, src_cell.pos);
            std::swap(dst_cell.src, src_cell.src);
            std::swap(dst_cell.seq_id, src_cell.seq_id);

            // swap tails
            for (uint32_t j = 0; j < size; ++j) {
                int32_t & tail = cells[j].tail;
                if (tail == src_id) {
                    tail = dst_id;
                } else if (tail == dst_id) {
                    tail = src_id;
                }
            }
        }
    }
    // update the pos of the used seqs
    for (uint32_t s = 0; s < n_seqs; ++s) {
        const uint32_t i = s*n_seq_tokens;
        const llama_pos last_pos = ubatch.pos[i + n_seq_tokens - 1];
        const int32_t cell_id = s + min;
        auto & cell = cells[cell_id];

        if (cell.pos >= 0 && last_pos != cell.pos + (llama_pos) n_seq_tokens) {
            // What should happen when the pos backtracks or skips a value?
            // Clearing the state mid-batch would require special-casing which isn't done.
            LLAMA_LOG_WARN("%s: non-consecutive token position %d after %d for sequence %d with %u new tokens\n",
                __func__, last_pos, cell.pos, ubatch.seq_id[i][0], n_seq_tokens);
        }
        cell.pos = last_pos;
        cell.seq_id.clear();
        for (int32_t j = 0; j < ubatch.n_seq_id[i]; ++j) {
            const llama_seq_id seq_id = ubatch.seq_id[i][j];
            cell.seq_id.insert(seq_id);
            cells[seq_id].tail = cell_id;
        }
    }

    // Find the first cell without src refs, to use as the zeroed state
    {
        // TODO: bake-in src refcounts in the cell metadata
        std::vector<int32_t> refcounts(size, 0);
        for (size_t i = 0; i < size; ++i) {
            const int32_t src = cells[i].src;
            if (src >= 0) {
                refcounts[src] += 1;
            }
        }

        rs_z = -1;
        for (int i = min; i <= max; ++i) {
            if (refcounts[i] == 0) {
                rs_z = i;
                break;
            }
        }

        for (int i = min; i <= max; ++i) {
            if (cells[i].src < 0) {
                GGML_ASSERT(rs_z >= 0);
                cells[i].src0 = rs_z;
            } else {
                // Stage the source ids for all used cells to allow correct seq_* behavior
                // and still make these values available when setting the inputs
                cells[i].src0 = cells[i].src;
            }
            cells[i].src = i; // avoid moving or clearing twice
        }
    }

    // allow getting the range of used cells, from head to head + n
    head = min;
    n = max - min + 1;
    used = std::count_if(cells.begin(), cells.end(),
        [](const mem_cell & cell){ return !cell.is_empty(); });

    // sanity check
    return n >= n_seqs;
}
bool llama_memory_recurrent::get_can_shift() const {
    // shifting the pos is trivial for recurrent models
    return true;
}

size_t llama_memory_recurrent::total_size() const {
    size_t size = 0;
    for (const auto & buf : bufs) {
        size += ggml_backend_buffer_get_size(buf.get());
    }

    return size;
}

size_t llama_memory_recurrent::size_r_bytes() const {
    size_t size_r_bytes = 0;

    for (const auto & r : r_l) {
        if (r != nullptr) {
            size_r_bytes += ggml_nbytes(r);
        }
    }

    return size_r_bytes;
}

size_t llama_memory_recurrent::size_s_bytes() const {
    size_t size_s_bytes = 0;

    for (const auto & s : s_l) {
        if (s != nullptr) {
            size_s_bytes += ggml_nbytes(s);
        }
    }

    return size_s_bytes;
}
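
// Serialized layout, as produced by the helpers below: cell_count, then per-cell
// metadata (pos, n_seq_id, and the seq_ids themselves when saving all sequences),
// then for each layer the R tensor type, row size and cell rows, followed by the
// same for the S tensors.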
void llama_memory_recurrent::state_write(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) const {
    GGML_UNUSED(flags);

    std::vector<std::pair<uint32_t, uint32_t>> cell_ranges; // ranges, from inclusive, to exclusive
    uint32_t cell_count = 0;

    // Count the number of cells with the specified seq_id
    // Find all the ranges of cells with this seq id (or all, when -1)
    uint32_t cell_range_begin = size;
    for (uint32_t i = 0; i < size; ++i) {
        const auto & cell = cells[i];
        if ((seq_id == -1 && !cell.is_empty()) || cell.has_seq_id(seq_id)) {
            ++cell_count;
            if (cell_range_begin == size) {
                cell_range_begin = i;
            }
        } else {
            if (cell_range_begin != size) {
                cell_ranges.emplace_back(cell_range_begin, i);
                cell_range_begin = size;
            }
        }
    }
    if (cell_range_begin != size) {
        cell_ranges.emplace_back(cell_range_begin, size);
    }

    // DEBUG CHECK: Sum of cell counts in ranges should equal the total cell count
    uint32_t cell_count_check = 0;
    for (const auto & range : cell_ranges) {
        cell_count_check += range.second - range.first;
    }
    GGML_ASSERT(cell_count == cell_count_check);

    io.write(&cell_count, sizeof(cell_count));

    state_write_meta(io, cell_ranges, seq_id);
    state_write_data(io, cell_ranges);
}

void llama_memory_recurrent::state_read(llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) {
    GGML_UNUSED(flags);

    uint32_t cell_count;
    io.read_to(&cell_count, sizeof(cell_count));

    bool res = true;

    res = res && state_read_meta(io, cell_count, seq_id);
    res = res && state_read_data(io, cell_count);

    if (!res) {
        if (seq_id == -1) {
            clear(true);
        } else {
            seq_rm(seq_id, -1, -1);
        }
        throw std::runtime_error("failed to restore kv cache");
    }
}
void llama_memory_recurrent::state_write_meta(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id) const {
    for (const auto & range : cell_ranges) {
        for (uint32_t i = range.first; i < range.second; ++i) {
            const auto & cell = cells[i];
            const llama_pos pos      = cell.pos;
            const uint32_t  n_seq_id = seq_id == -1 ? cell.seq_id.size() : 0;

            io.write(&pos,      sizeof(pos));
            io.write(&n_seq_id, sizeof(n_seq_id));

            if (n_seq_id) {
                for (auto seq_id : cell.seq_id) {
                    io.write(&seq_id, sizeof(seq_id));
                }
            }
        }
    }
}
void llama_memory_recurrent::state_write_data(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges) const {
    const uint32_t s_trans = 0;
    const uint32_t n_layer = hparams.n_layer;

    io.write(&s_trans, sizeof(s_trans));
    io.write(&n_layer, sizeof(n_layer));

    // Iterate and write all the R states first, each row is a cell
    // Get whole range at a time
    for (uint32_t il = 0; il < n_layer; ++il) {
        // skip null layers (read_data will handle this by checking "r_l" and "s_l" for null)
        if (r_l[il] == nullptr) continue;

        // Write R tensor type
        const int32_t r_type_i = (int32_t)r_l[il]->type;
        io.write(&r_type_i, sizeof(r_type_i));

        // Write row size of R
        const uint64_t r_size_row = ggml_row_size(r_l[il]->type, hparams.n_embd_r());
        io.write(&r_size_row, sizeof(r_size_row));

        // Write each range of cells directly from the tensor, one row (r_size_row bytes) per cell
        for (const auto & range : cell_ranges) {
            const size_t range_size = range.second - range.first;
            const size_t buf_size = range_size * r_size_row;
            io.write_tensor(r_l[il], range.first * r_size_row, buf_size);
        }
    }

    if (!s_trans) {
        for (uint32_t il = 0; il < n_layer; ++il) {
            // skip null layers (read_data will handle this by checking "r_l" and "s_l" for null)
            if (s_l[il] == nullptr) continue;

            // Write S tensor type
            const int32_t s_type_i = (int32_t)s_l[il]->type;
            io.write(&s_type_i, sizeof(s_type_i));

            // Write row size of S
            const uint64_t s_size_row = ggml_row_size(s_l[il]->type, hparams.n_embd_s());
            io.write(&s_size_row, sizeof(s_size_row));

            // Write each range of cells directly from the tensor, one row (s_size_row bytes) per cell
            for (const auto & range : cell_ranges) {
                const size_t range_size = range.second - range.first;
                const size_t buf_size = range_size * s_size_row;
                io.write_tensor(s_l[il], range.first * s_size_row, buf_size);
            }
        }
    } else {
        // When S is transposed, we also need the element size and get the element ranges from each row
        const uint32_t mem_size = size;
        for (uint32_t il = 0; il < n_layer; ++il) {
            // skip null layers (read_data will handle this by checking "r_l" and "s_l" for null)
            if (s_l[il] == nullptr) continue;

            const uint32_t n_embd_s = hparams.n_embd_s();

            // Write S tensor type
            const int32_t s_type_i = (int32_t)s_l[il]->type;
            io.write(&s_type_i, sizeof(s_type_i));

            // Write element size
            const uint32_t s_size_el = ggml_type_size(s_l[il]->type);
            io.write(&s_size_el, sizeof(s_size_el));

            // Write state embedding size
            io.write(&n_embd_s, sizeof(n_embd_s));

            // For each row, write the cell ranges of that row, s_size_el bytes per element
            for (uint32_t j = 0; j < n_embd_s; ++j) {
                for (const auto & range : cell_ranges) {
                    const size_t range_size = range.second - range.first;
                    const size_t src_offset = (range.first + j * mem_size) * s_size_el;
                    const size_t buf_size = range_size * s_size_el;
                    io.write_tensor(s_l[il], src_offset, buf_size);
                }
            }
        }
    }
}
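
// state_read_meta() mirrors state_write_meta(): a single-sequence restore is routed
// through find_slot() with a synthetic ubatch so the cells are placed exactly like a
// normal decode, while a full restore rebuilds every cell (and its tail) in place.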
bool llama_memory_recurrent::state_read_meta(llama_io_read_i & io, uint32_t cell_count, llama_seq_id dest_seq_id) {
    if (dest_seq_id != -1) {
        // single sequence

        seq_rm(dest_seq_id, -1, -1);

        llama_batch_allocr balloc(hparams.n_pos_per_embd());

        llama_ubatch ubatch = balloc.ubatch_reserve(cell_count, 1);

        for (uint32_t i = 0; i < cell_count; ++i) {
            llama_pos pos;
            uint32_t  n_seq_id;

            io.read_to(&pos,      sizeof(pos));
            io.read_to(&n_seq_id, sizeof(n_seq_id));

            if (n_seq_id != 0) {
                LLAMA_LOG_ERROR("%s: invalid seq_id-agnostic kv cell\n", __func__);
                return false;
            }

            ubatch.pos[i] = pos;
        }
        ubatch.n_seq_id[0] = 1;
        ubatch.seq_id[0] = &dest_seq_id;

        if (!find_slot(ubatch)) {
            LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__);
            return false;
        }

        // DEBUG CHECK: head should be our first cell, head + cell_count - 1 should be our last cell (verify seq_id and pos values)
        // Assume that this is one contiguous block of cells
        GGML_ASSERT(head + cell_count <= size);
        GGML_ASSERT(cells[head].pos == ubatch.pos[0]);
        GGML_ASSERT(cells[head + cell_count - 1].pos == ubatch.pos[cell_count - 1]);
        GGML_ASSERT(cells[head].has_seq_id(dest_seq_id));
        GGML_ASSERT(cells[head + cell_count - 1].has_seq_id(dest_seq_id));
    } else {
        // whole state restore

        if (cell_count > size) {
            LLAMA_LOG_ERROR("%s: not enough cells in kv cache\n", __func__);
            return false;
        }

        clear(true);

        for (uint32_t i = 0; i < cell_count; ++i) {
            auto & cell = cells[i];

            llama_pos pos;
            uint32_t  n_seq_id;

            io.read_to(&pos,      sizeof(pos));
            io.read_to(&n_seq_id, sizeof(n_seq_id));

            cell.pos = pos;

            for (uint32_t j = 0; j < n_seq_id; ++j) {
                llama_seq_id seq_id;
                io.read_to(&seq_id, sizeof(seq_id));

                // TODO: llama_memory_recurrent should have a notion of max sequences
                //if (seq_id < 0 || (uint32_t) seq_id >= llama_n_seq_max(ctx)) {
                if (seq_id < 0) {
                    //LLAMA_LOG_ERROR("%s: invalid seq_id, %d is out of range [0, %u)\n", __func__, seq_id, llama_n_seq_max(ctx));
                    LLAMA_LOG_ERROR("%s: invalid seq_id, %d is out of range [0, inf)\n", __func__, seq_id);
                    return false;
                }

                cell.seq_id.insert(seq_id);

                int32_t & tail = cells[seq_id].tail;
                if (tail != -1) {
                    LLAMA_LOG_ERROR("%s: duplicate tail for seq_id %d in cell %d and %d\n", __func__, seq_id, i, tail);
                    return false;
                }
                tail = i;
            }
        }

        head = 0;
        used = cell_count;
    }

    for (uint32_t i = 0; i < cell_count; ++i) {
        uint32_t cell_id = head + i;
        // make sure the recurrent states will keep their restored state
        cells[cell_id].src = cell_id;
    }

    return true;
}
bool llama_memory_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell_count) {
    uint32_t s_trans;
    uint32_t n_layer;
    io.read_to(&s_trans, sizeof(s_trans));
    io.read_to(&n_layer, sizeof(n_layer));

    if (n_layer != hparams.n_layer) {
        LLAMA_LOG_ERROR("%s: mismatched layer count (%u instead of %u)\n", __func__, n_layer, hparams.n_layer);
        return false;
    }
    if (cell_count > size) {
        LLAMA_LOG_ERROR("%s: not enough cells in kv cache to restore state (%u > %u)\n", __func__, cell_count, size);
        return false;
    }
    if (s_trans) {
        LLAMA_LOG_ERROR("%s: incompatible s transposition\n", __func__);
        return false;
    }

    // For each layer, read the R states for each cell, one row is one cell, read as one contiguous block
    for (uint32_t il = 0; il < n_layer; ++il) {
        // skip null layers
        if (r_l[il] == nullptr) continue;

        // Read R tensor type
        int32_t r_type_i_ref;
        io.read_to(&r_type_i_ref, sizeof(r_type_i_ref));
        const int32_t r_type_i = (int32_t) r_l[il]->type;
        if (r_type_i != r_type_i_ref) {
            LLAMA_LOG_ERROR("%s: mismatched r type (%d != %d, layer %d)\n", __func__, r_type_i, r_type_i_ref, il);
            return false;
        }

        // Read row size of R
        uint64_t r_size_row_ref;
        io.read_to(&r_size_row_ref, sizeof(r_size_row_ref));
        const size_t r_size_row = ggml_row_size(r_l[il]->type, hparams.n_embd_r());
        if (r_size_row != r_size_row_ref) {
            LLAMA_LOG_ERROR("%s: mismatched r row size (%zu != %zu, layer %d)\n", __func__, r_size_row, (size_t) r_size_row_ref, il);
            return false;
        }

        if (cell_count) {
            // Read and set the R states for the whole cell range
            ggml_backend_tensor_set(r_l[il], io.read(cell_count * r_size_row), head * r_size_row, cell_count * r_size_row);
        }
    }

    if (!s_trans) {
        for (uint32_t il = 0; il < n_layer; ++il) {
            // skip null layers
            if (s_l[il] == nullptr) continue;

            // Read S tensor type
            int32_t s_type_i_ref;
            io.read_to(&s_type_i_ref, sizeof(s_type_i_ref));
            const int32_t s_type_i = (int32_t)s_l[il]->type;
            if (s_type_i != s_type_i_ref) {
                LLAMA_LOG_ERROR("%s: mismatched s type (%d != %d, layer %d)\n", __func__, s_type_i, s_type_i_ref, il);
                return false;
            }

            // Read row size of S
            uint64_t s_size_row_ref;
            io.read_to(&s_size_row_ref, sizeof(s_size_row_ref));
            const size_t s_size_row = ggml_row_size(s_l[il]->type, hparams.n_embd_s());
            if (s_size_row != s_size_row_ref) {
                LLAMA_LOG_ERROR("%s: mismatched s row size (%zu != %zu, layer %d)\n", __func__, s_size_row, (size_t) s_size_row_ref, il);
                return false;
            }

            if (cell_count) {
                // Read and set the S states for the whole cell range
                ggml_backend_tensor_set(s_l[il], io.read(cell_count * s_size_row), head * s_size_row, cell_count * s_size_row);
            }
        }
    } else {
        // For each layer, read the S states for each cell (transposed)
        for (uint32_t il = 0; il < n_layer; ++il) {
            // skip null layers
            if (s_l[il] == nullptr) continue;

            const uint32_t n_embd_s = hparams.n_embd_s();

            // Read S tensor type
            int32_t s_type_i_ref;
            io.read_to(&s_type_i_ref, sizeof(s_type_i_ref));
            const int32_t s_type_i = (int32_t)s_l[il]->type;
            if (s_type_i != s_type_i_ref) {
                LLAMA_LOG_ERROR("%s: mismatched s type (%d != %d, layer %d)\n", __func__, s_type_i, s_type_i_ref, il);
                return false;
            }

            // Read element size of S
            uint32_t s_size_el_ref;
            io.read_to(&s_size_el_ref, sizeof(s_size_el_ref));
            const size_t s_size_el = ggml_type_size(s_l[il]->type);
            if (s_size_el != s_size_el_ref) {
                LLAMA_LOG_ERROR("%s: mismatched s element size (%zu != %zu, layer %d)\n", __func__, s_size_el, (size_t) s_size_el_ref, il);
                return false;
            }

            // Read state embedding size
            uint32_t n_embd_s_ref;
            io.read_to(&n_embd_s_ref, sizeof(n_embd_s_ref));
            if (n_embd_s != n_embd_s_ref) {
                LLAMA_LOG_ERROR("%s: mismatched s embedding size (%u != %u, layer %d)\n", __func__, n_embd_s, n_embd_s_ref, il);
                return false;
            }

            if (cell_count) {
                // For each row in the transposed matrix, read the values for the whole cell range
                for (uint32_t j = 0; j < n_embd_s; ++j) {
                    const size_t dst_offset = (head + j * size) * s_size_el;
                    ggml_backend_tensor_set(s_l[il], io.read(cell_count * s_size_el), dst_offset, cell_count * s_size_el);
                }
            }
        }
    }

    return true;
}
//
// llama_memory_recurrent_context
//
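
// How a context is typically consumed by the decode loop (a hedged sketch; the real
// loop lives in llama_context and may differ in detail):
//
//   auto mctx = mem->init_batch(balloc, n_ubatch, embd_all);
//   if (mctx->get_status() == LLAMA_MEMORY_STATUS_SUCCESS) {
//       do {
//           mctx->apply();                        // reserve cells for the current ubatch
//           const auto & ub = mctx->get_ubatch(); // build/execute the graph against ub
//       } while (mctx->next());
//   }
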
llama_memory_recurrent_context::llama_memory_recurrent_context(llama_memory_status status) : status(status) {}

llama_memory_recurrent_context::llama_memory_recurrent_context(
        llama_memory_recurrent * mem) : status(LLAMA_MEMORY_STATUS_SUCCESS), mem(mem), is_full(true) {
}

llama_memory_recurrent_context::llama_memory_recurrent_context(
        llama_memory_recurrent * mem,
        std::vector<llama_ubatch> ubatches) : status(LLAMA_MEMORY_STATUS_SUCCESS), mem(mem), ubatches(std::move(ubatches)) {}

llama_memory_recurrent_context::~llama_memory_recurrent_context() = default;

bool llama_memory_recurrent_context::next() {
    assert(status == LLAMA_MEMORY_STATUS_SUCCESS);

    if (++i_next >= ubatches.size()) {
        return false;
    }

    return true;
}

bool llama_memory_recurrent_context::apply() {
    assert(!llama_memory_status_is_fail(status));

    // no ubatches -> this is an update
    if (ubatches.empty()) {
        // recurrent cache never performs updates
        assert(status == LLAMA_MEMORY_STATUS_NO_UPDATE);
        return true;
    }

    mem->find_slot(ubatches[i_next]);

    return true;
}

llama_memory_status llama_memory_recurrent_context::get_status() const {
    return status;
}

const llama_ubatch & llama_memory_recurrent_context::get_ubatch() const {
    assert(status == LLAMA_MEMORY_STATUS_SUCCESS);

    return ubatches[i_next];
}

uint32_t llama_memory_recurrent_context::get_n_rs() const {
    return is_full ? mem->size : mem->n;
}

uint32_t llama_memory_recurrent_context::get_head() const {
    return is_full ? 0 : mem->head;
}

int32_t llama_memory_recurrent_context::get_rs_z() const {
    return is_full ? 0 : mem->rs_z;
}

uint32_t llama_memory_recurrent_context::get_size() const {
    return mem->size;
}

ggml_tensor * llama_memory_recurrent_context::get_r_l(int32_t il) const {
    return mem->r_l[il];
}

ggml_tensor * llama_memory_recurrent_context::get_s_l(int32_t il) const {
    return mem->s_l[il];
}

int32_t llama_memory_recurrent_context::s_copy(int i) const {
    return mem->cells[i + mem->head].src0;
}