talk-llama : sync llama.cpp (#2709)
- examples/talk-llama/CMakeLists.txt +17 -1
- examples/talk-llama/llama-adapter.cpp +334 -0
- examples/talk-llama/llama-adapter.h +66 -0
- examples/talk-llama/llama-arch.cpp +1434 -0
- examples/talk-llama/llama-arch.h +395 -0
- examples/talk-llama/llama-batch.cpp +368 -0
- examples/talk-llama/llama-batch.h +88 -0
- examples/talk-llama/llama-chat.cpp +567 -0
- examples/talk-llama/llama-chat.h +51 -0
- examples/talk-llama/llama-context.cpp +1771 -0
- examples/talk-llama/llama-context.h +128 -0
- examples/talk-llama/llama-cparams.cpp +1 -0
- examples/talk-llama/llama-cparams.h +37 -0
- examples/talk-llama/llama-grammar.cpp +16 -15
- examples/talk-llama/llama-grammar.h +5 -6
- examples/talk-llama/llama-hparams.cpp +71 -0
- examples/talk-llama/llama-hparams.h +140 -0
- examples/talk-llama/llama-impl.cpp +166 -0
- examples/talk-llama/llama-impl.h +16 -136
- examples/talk-llama/llama-kv-cache.cpp +718 -0
- examples/talk-llama/llama-kv-cache.h +218 -0
- examples/talk-llama/llama-mmap.cpp +589 -0
- examples/talk-llama/llama-mmap.h +67 -0
- examples/talk-llama/llama-model-loader.cpp +1010 -0
- examples/talk-llama/llama-model-loader.h +158 -0
- examples/talk-llama/llama-model.cpp +0 -0
- examples/talk-llama/llama-model.h +391 -0
- examples/talk-llama/llama-quant.cpp +929 -0
- examples/talk-llama/llama-quant.h +1 -0
- examples/talk-llama/llama-sampling.cpp +117 -4
- examples/talk-llama/llama-vocab.cpp +26 -29
- examples/talk-llama/llama-vocab.h +14 -2
- examples/talk-llama/llama.cpp +0 -0
- examples/talk-llama/llama.h +31 -9
- examples/talk-llama/talk-llama.cpp +1 -1
- examples/talk-llama/unicode.cpp +6 -0
examples/talk-llama/CMakeLists.txt
CHANGED
@@ -1,10 +1,26 @@
 if (WHISPER_SDL2)
+    set(CMAKE_CXX_STANDARD 17)
+    set(CMAKE_CXX_STANDARD_REQUIRED ON)
+
     set(TARGET whisper-talk-llama)
     add_executable(${TARGET} talk-llama.cpp
         llama.cpp
-        llama-
+        llama-adapter.cpp
+        llama-arch.cpp
+        llama-batch.cpp
+        llama-chat.cpp
+        llama-context.cpp
+        llama-cparams.cpp
         llama-grammar.cpp
+        llama-hparams.cpp
+        llama-impl.cpp
+        llama-kv-cache.cpp
+        llama-mmap.cpp
+        llama-model-loader.cpp
+        llama-model.cpp
+        llama-quant.cpp
         llama-sampling.cpp
+        llama-vocab.cpp
         unicode.cpp
         unicode-data.cpp)
     target_include_directories(${TARGET} PRIVATE ${SDL2_INCLUDE_DIRS})
examples/talk-llama/llama-adapter.cpp
ADDED
@@ -0,0 +1,334 @@
#include "llama-adapter.h"

#include "llama-model.h"

#include <algorithm>
#include <map>
#include <cassert>
#include <stdexcept>

// vec

struct ggml_tensor * llama_control_vector::tensor_for(int il) const {
    if (il < 0 || il < layer_start || il > layer_end || (size_t) il >= tensors.size()) {
        return nullptr;
    }

    return tensors[il];
}

struct ggml_tensor * llama_control_vector::apply_to(struct ggml_context * ctx, struct ggml_tensor * cur, int il) const {
    ggml_tensor * layer_dir = tensor_for(il);
    if (layer_dir != nullptr) {
        cur = ggml_add(ctx, cur, layer_dir);
    }

    return cur;
}

static bool llama_control_vector_init(struct llama_control_vector & cvec, const llama_model & model) {
    const auto & hparams = model.hparams;

    GGML_ASSERT(cvec.tensors.empty());
    GGML_ASSERT(cvec.ctxs.empty());
    GGML_ASSERT(cvec.bufs.empty());

    // create a context for each buffer type
    std::map<ggml_backend_buffer_type_t, ggml_context *> ctx_map;
    auto ctx_for_buft = [&](ggml_backend_buffer_type_t buft) -> ggml_context * {
        auto it = ctx_map.find(buft);
        if (it == ctx_map.end()) {
            struct ggml_init_params params = {
                /*.mem_size   =*/ hparams.n_layer*ggml_tensor_overhead(),
                /*.mem_buffer =*/ NULL,
                /*.no_alloc   =*/ true,
            };

            ggml_context * ctx = ggml_init(params);
            if (!ctx) {
                return nullptr;
            }

            ctx_map[buft] = ctx;
            cvec.ctxs.emplace_back(ctx);

            return ctx;
        }

        return it->second;
    };

    // make tensors
    cvec.tensors.reserve(hparams.n_layer);
    cvec.tensors.push_back(nullptr); // there's never a tensor for layer 0
    for (size_t il = 1; il < hparams.n_layer; il++) {
        ggml_backend_buffer_type_t buft = llama_model_select_buft(model, il);
        ggml_context * ctx = ctx_for_buft(buft);
        if (!ctx) {
            LLAMA_LOG_ERROR("%s: failed to allocate context for control vector\n", __func__);
            return false;
        }
        ggml_tensor * tensor = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hparams.n_embd);
        cvec.tensors.push_back(tensor);
    }

    // allocate tensors / buffers and zero
    cvec.bufs.reserve(ctx_map.size());
    for (auto it : ctx_map) {
        ggml_backend_buffer_type_t buft = it.first;
        ggml_context * ctx = it.second;
        ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft);
        if (!buf) {
            LLAMA_LOG_ERROR("%s: failed to allocate buffer for control vector\n", __func__);
            return false;
        }
        ggml_backend_buffer_clear(buf, 0);
        cvec.bufs.emplace_back(buf);
    }

    return true;
}

int32_t llama_control_vector_apply(
        struct llama_control_vector & cvec,
        const llama_model & model,
        const float * data,
        size_t len,
        int32_t n_embd,
        int32_t il_start,
        int32_t il_end) {
    const auto & hparams = model.hparams;

    if (data == nullptr) {
        // disable the current control vector (but leave allocated for later)
        cvec.layer_start = -1;
        cvec.layer_end   = -1;
        return 0;
    }

    if (n_embd != (int) hparams.n_embd) {
        LLAMA_LOG_ERROR("%s: control vector n_embd does not match model\n", __func__);
        return 1;
    }

    if (cvec.tensors.empty()) {
        if (!llama_control_vector_init(cvec, model)) {
            return 1;
        }
    }

    cvec.layer_start = il_start;
    cvec.layer_end   = il_end;

    for (size_t il = 1; il < hparams.n_layer; il++) {
        assert(cvec.tensors[il] != nullptr);

        const size_t off = n_embd * (il - 1); // buffer doesn't have data for layer 0, since it's never present
        if (off + n_embd <= len) {
            ggml_backend_tensor_set(cvec.tensors[il], data + off, 0, n_embd * ggml_element_size(cvec.tensors[il]));
        }
    }

    return 0;
}

// lora

llama_lora_weight * llama_lora_adapter::get_weight(struct ggml_tensor * w) {
    const std::string name(w->name);

    const auto pos = ab_map.find(name);
    if (pos != ab_map.end()) {
        return &pos->second;
    }

    return nullptr;
}

void llama_lora_adapter_free(struct llama_lora_adapter * adapter) {
    delete adapter;
}

static void llama_lora_adapter_init_impl(struct llama_model & model, const char * path_lora, struct llama_lora_adapter & adapter) {
    LLAMA_LOG_INFO("%s: loading lora adapter from '%s' ...\n", __func__, path_lora);

    ggml_context * ctx_init;
    struct gguf_init_params meta_gguf_params = {
        /* .no_alloc = */ true,
        /* .ctx      = */ &ctx_init,
    };

    gguf_context_ptr ctx_gguf { gguf_init_from_file(path_lora, meta_gguf_params) };
    if (!ctx_gguf) {
        throw std::runtime_error("failed to load lora adapter file from " + std::string(path_lora));
    }

    ggml_context_ptr ctx { ctx_init };

    // check metadata
    {
        auto get_kv_str = [&](const std::string & key) -> std::string {
            int id = gguf_find_key(ctx_gguf.get(), key.c_str());
            return id < 0 ? "" : std::string(gguf_get_val_str(ctx_gguf.get(), id));
        };
        auto get_kv_f32 = [&](const std::string & key) -> float {
            int id = gguf_find_key(ctx_gguf.get(), key.c_str());
            return id < 0 ? 0.0f : gguf_get_val_f32(ctx_gguf.get(), id);
        };
        LLM_KV llm_kv = LLM_KV(LLM_ARCH_UNKNOWN);

        auto general_type = get_kv_str(llm_kv(LLM_KV_GENERAL_TYPE));
        if (general_type != "adapter") {
            throw std::runtime_error("expect general.type to be 'adapter', but got: " + general_type);
        }

        auto general_arch_str = get_kv_str(llm_kv(LLM_KV_GENERAL_ARCHITECTURE));
        auto general_arch = llm_arch_from_string(general_arch_str);
        if (general_arch != model.arch) {
            throw std::runtime_error("model arch and LoRA arch mismatch");
        }

        auto adapter_type = get_kv_str(llm_kv(LLM_KV_ADAPTER_TYPE));
        if (adapter_type != "lora") {
            throw std::runtime_error("expect adapter.type to be 'lora', but got: " + adapter_type);
        }

        adapter.alpha = get_kv_f32(llm_kv(LLM_KV_ADAPTER_LORA_ALPHA));
    }

    int n_tensors = gguf_get_n_tensors(ctx_gguf.get());

    // contexts for each buffer type
    std::map<ggml_backend_buffer_type_t, ggml_context *> ctx_map;
    auto ctx_for_buft = [&](ggml_backend_buffer_type_t buft) -> ggml_context * {
        auto it = ctx_map.find(buft);
        if (it == ctx_map.end()) {
            // add a new context
            struct ggml_init_params params = {
                /*.mem_size   =*/ n_tensors*ggml_tensor_overhead(),
                /*.mem_buffer =*/ NULL,
                /*.no_alloc   =*/ true,
            };
            ggml_context * buft_ctx = ggml_init(params);
            if (!buft_ctx) {
                return nullptr;
            }
            ctx_map[buft] = buft_ctx;
            adapter.ctxs.emplace_back(buft_ctx);
            return buft_ctx;
        };
        return it->second;
    };

    // bundle lora_a and lora_b into pairs
    std::map<std::string, llama_lora_weight> ab_map;
    auto str_endswith = [](const std::string & str, const std::string & suffix) {
        return str.size() >= suffix.size() && str.compare(str.size()-suffix.size(), suffix.size(), suffix) == 0;
    };

    for (ggml_tensor * cur = ggml_get_first_tensor(ctx.get()); cur; cur = ggml_get_next_tensor(ctx.get(), cur)) {
        std::string name(cur->name);
        if (str_endswith(name, ".lora_a")) {
            replace_all(name, ".lora_a", "");
            if (ab_map.find(name) == ab_map.end()) {
                ab_map[name] = llama_lora_weight(cur, nullptr);
            } else {
                ab_map[name].a = cur;
            }
        } else if (str_endswith(name, ".lora_b")) {
            replace_all(name, ".lora_b", "");
            if (ab_map.find(name) == ab_map.end()) {
                ab_map[name] = llama_lora_weight(nullptr, cur);
            } else {
                ab_map[name].b = cur;
            }
        } else {
            throw std::runtime_error("LoRA tensor '" + name + "' has unexpected suffix");
        }
    }

    // add tensors
    for (auto & it : ab_map) {
        const std::string & name = it.first;
        llama_lora_weight & w = it.second;

        if (!w.a || !w.b) {
            throw std::runtime_error("LoRA tensor pair for '" + name + "' is missing one component");
        }

        // device buft and device ctx
        auto * model_tensor = llama_model_get_tensor(model, name.c_str());
        if (!model_tensor) {
            throw std::runtime_error("LoRA tensor '" + name + "' does not exist in base model");
        }

        struct ggml_context * dev_ctx = ctx_for_buft(ggml_backend_buffer_get_type(model_tensor->buffer));
        // validate tensor shape
        if (model_tensor->ne[0] != w.a->ne[0] || model_tensor->ne[1] != w.b->ne[1]) {
            throw std::runtime_error("tensor '" + name + "' has incorrect shape");
        }
        if (w.a->ne[1] != w.b->ne[0]) {
            throw std::runtime_error("lora_a tensor is not transposed (hint: adapter from \"finetune\" example is no longer supported)");
        }

        // save tensor to adapter
        struct ggml_tensor * tensor_a = ggml_dup_tensor(dev_ctx, w.a);
        struct ggml_tensor * tensor_b = ggml_dup_tensor(dev_ctx, w.b);
        ggml_set_name(tensor_a, w.a->name);
        ggml_set_name(tensor_b, w.b->name);
        adapter.ab_map[name] = llama_lora_weight(tensor_a, tensor_b);
    }

    // allocate tensors / buffers and zero
    {
        adapter.ctxs.reserve(ctx_map.size());
        adapter.bufs.reserve(ctx_map.size());
        for (auto & it : ctx_map) {
            ggml_backend_buffer_type_t buft = it.first;
            ggml_context * ctx_dev = it.second;
            ggml_backend_buffer_ptr buf { ggml_backend_alloc_ctx_tensors_from_buft(ctx_dev, buft) };
            if (!buf) {
                throw std::runtime_error("failed to allocate buffer for lora adapter\n");
            }
            LLAMA_LOG_INFO("%s: %10s LoRA buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf.get()), ggml_backend_buffer_get_size(buf.get())/1024.0/1024.0);
            adapter.bufs.emplace_back(std::move(buf));
        }
    }

    // set tensor data
    {
        llama_file gguf_file(path_lora, "rb");
        std::vector<uint8_t> read_buf;
        auto set_tensor = [&](struct ggml_tensor * orig, struct ggml_tensor * dev) {
            size_t offs = gguf_get_data_offset(ctx_gguf.get()) + gguf_get_tensor_offset(ctx_gguf.get(), gguf_find_tensor(ctx_gguf.get(), orig->name));
            size_t size = ggml_nbytes(orig);
            read_buf.resize(size);
            gguf_file.seek(offs, SEEK_SET);
            gguf_file.read_raw(read_buf.data(), size);
            ggml_backend_tensor_set(dev, read_buf.data(), 0, size);
        };
        for (auto & it : adapter.ab_map) {
            auto orig = ab_map[it.first];
            auto dev  = it.second;
            set_tensor(orig.a, dev.a);
            set_tensor(orig.b, dev.b);
        }
    }

    LLAMA_LOG_INFO("%s: loaded %zu tensors from lora file\n", __func__, adapter.ab_map.size()*2);
}

struct llama_lora_adapter * llama_lora_adapter_init(struct llama_model * model, const char * path_lora) {
    struct llama_lora_adapter * adapter = new llama_lora_adapter();

    try {
        llama_lora_adapter_init_impl(*model, path_lora, *adapter);
        return adapter;
    } catch (const std::exception & err) {
        LLAMA_LOG_ERROR("%s: failed to apply lora adapter: %s\n", __func__, err.what());

        delete adapter;
    }

    return nullptr;
}
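For orientation, here is a minimal sketch of how the loader above is driven end to end. Only llama_lora_adapter_init and llama_lora_adapter_free are defined in this file; the model lifecycle calls (llama_model_default_params, llama_load_model_from_file, llama_free_model) are the usual llama.h entry points and are assumed here, since llama.h itself is not shown in this diff.

```cpp
// Hypothetical driver, not part of this commit: load a base model, then a LoRA adapter
// in GGUF format, using the functions defined in llama-adapter.cpp above.
#include "llama.h"

#include <cstdio>

int main(int argc, char ** argv) {
    if (argc < 3) {
        std::fprintf(stderr, "usage: %s <model.gguf> <adapter.gguf>\n", argv[0]);
        return 1;
    }

    llama_model_params mparams = llama_model_default_params();        // assumed llama.h API
    llama_model * model = llama_load_model_from_file(argv[1], mparams); // assumed llama.h API
    if (model == nullptr) {
        return 1;
    }

    // Parses the adapter GGUF, checks general.type == "adapter" and adapter.type == "lora",
    // verifies the architecture matches the base model, and copies the lora_a/lora_b tensor
    // pairs into backend buffers. Returns nullptr (and logs the error) on failure.
    llama_lora_adapter * adapter = llama_lora_adapter_init(model, argv[2]);
    if (adapter == nullptr) {
        llama_free_model(model);
        return 1;
    }

    // ... create a context, attach the adapter, run inference ...

    llama_lora_adapter_free(adapter);
    llama_free_model(model);
    return 0;
}
```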
examples/talk-llama/llama-adapter.h
ADDED
@@ -0,0 +1,66 @@
#pragma once

#include "llama-impl.h"
#include "llama-hparams.h"

#include "ggml-cpp.h"

#include <unordered_map>
#include <vector>

//
// llama_adapter_cvec
//

// TODO: rename to llama_adapter_cvec
struct llama_control_vector {
    std::vector<ggml_context_ptr> ctxs;
    std::vector<ggml_backend_buffer_ptr> bufs;

    std::vector<struct ggml_tensor *> tensors; // per layer

    int32_t layer_start = -1;
    int32_t layer_end   = -1;

    struct ggml_tensor * tensor_for(int il) const;

    struct ggml_tensor * apply_to(struct ggml_context * ctx, struct ggml_tensor * cur, int il) const;
};

int32_t llama_control_vector_apply(
        struct llama_control_vector & cvec,
        const llama_model & model,
        const float * data,
        size_t len,
        int32_t n_embd,
        int32_t il_start,
        int32_t il_end);

//
// llama_adapter_lora
//

// TODO: rename to llama_adapter_lora_weight
struct llama_lora_weight {
    struct ggml_tensor * a = nullptr;
    struct ggml_tensor * b = nullptr;

    llama_lora_weight() = default;
    llama_lora_weight(struct ggml_tensor * a, struct ggml_tensor * b) : a(a), b(b) {}
};

// TODO: rename to llama_adapter_lora
struct llama_lora_adapter {
    // map tensor name to lora_a_b
    std::unordered_map<std::string, struct llama_lora_weight> ab_map;

    std::vector<ggml_context_ptr> ctxs;
    std::vector<ggml_backend_buffer_ptr> bufs;

    float alpha;

    llama_lora_adapter() = default;
    ~llama_lora_adapter() = default;

    llama_lora_weight * get_weight(struct ggml_tensor * w);
};
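The control-vector half of this header is driven through llama_control_vector_apply. The sketch below is an assumption rather than code from this change: it shows the expected data layout, one n_embd-sized row per layer starting at layer 1, because layer 0 never gets a tensor (see llama_control_vector_init above). The helper name set_direction_on_all_layers and the idea of reusing a single direction on every layer are illustrative only; a real caller would normally reach this through the public llama.h wrapper, which is not part of this header.

```cpp
// Illustrative only: fill a control-vector buffer and hand it to the helper declared above.
// "model" is the library-internal llama_model, assumed to be already loaded.
#include "llama-adapter.h"
#include "llama-model.h"

#include <algorithm>
#include <vector>

static bool set_direction_on_all_layers(
        llama_control_vector & cvec,
        const llama_model & model,
        const std::vector<float> & direction) {  // one n_embd-sized direction vector
    const int32_t n_embd  = model.hparams.n_embd;
    const int32_t n_layer = model.hparams.n_layer;

    // rows are packed starting at layer 1; layer 0 has no control-vector tensor
    std::vector<float> data((size_t) n_embd * (n_layer - 1));
    for (int32_t il = 1; il < n_layer; ++il) {
        std::copy(direction.begin(), direction.end(), data.begin() + (size_t) n_embd * (il - 1));
    }

    // passing data == nullptr instead would disable the current control vector
    return llama_control_vector_apply(cvec, model, data.data(), data.size(), n_embd, 1, n_layer - 1) == 0;
}
```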
examples/talk-llama/llama-arch.cpp
ADDED
@@ -0,0 +1,1434 @@
#include "llama-arch.h"

#include "llama-impl.h"

#include <map>

static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
    { LLM_ARCH_LLAMA, "llama" },
    { LLM_ARCH_DECI, "deci" },
    { LLM_ARCH_FALCON, "falcon" },
    { LLM_ARCH_GROK, "grok" },
    { LLM_ARCH_GPT2, "gpt2" },
    { LLM_ARCH_GPTJ, "gptj" },
    { LLM_ARCH_GPTNEOX, "gptneox" },
    { LLM_ARCH_MPT, "mpt" },
    { LLM_ARCH_BAICHUAN, "baichuan" },
    { LLM_ARCH_STARCODER, "starcoder" },
    { LLM_ARCH_REFACT, "refact" },
    { LLM_ARCH_BERT, "bert" },
    { LLM_ARCH_NOMIC_BERT, "nomic-bert" },
    { LLM_ARCH_JINA_BERT_V2, "jina-bert-v2" },
    { LLM_ARCH_BLOOM, "bloom" },
    { LLM_ARCH_STABLELM, "stablelm" },
    { LLM_ARCH_QWEN, "qwen" },
    { LLM_ARCH_QWEN2, "qwen2" },
    { LLM_ARCH_QWEN2MOE, "qwen2moe" },
    { LLM_ARCH_QWEN2VL, "qwen2vl" },
    { LLM_ARCH_PHI2, "phi2" },
    { LLM_ARCH_PHI3, "phi3" },
    { LLM_ARCH_PLAMO, "plamo" },
    { LLM_ARCH_CODESHELL, "codeshell" },
    { LLM_ARCH_ORION, "orion" },
    { LLM_ARCH_INTERNLM2, "internlm2" },
    { LLM_ARCH_MINICPM, "minicpm" },
    { LLM_ARCH_MINICPM3, "minicpm3" },
    { LLM_ARCH_GEMMA, "gemma" },
    { LLM_ARCH_GEMMA2, "gemma2" },
    { LLM_ARCH_STARCODER2, "starcoder2" },
    { LLM_ARCH_MAMBA, "mamba" },
    { LLM_ARCH_XVERSE, "xverse" },
    { LLM_ARCH_COMMAND_R, "command-r" },
    { LLM_ARCH_COHERE2, "cohere2" },
    { LLM_ARCH_DBRX, "dbrx" },
    { LLM_ARCH_OLMO, "olmo" },
    { LLM_ARCH_OLMO2, "olmo2" },
    { LLM_ARCH_OLMOE, "olmoe" },
    { LLM_ARCH_OPENELM, "openelm" },
    { LLM_ARCH_ARCTIC, "arctic" },
    { LLM_ARCH_DEEPSEEK, "deepseek" },
    { LLM_ARCH_DEEPSEEK2, "deepseek2" },
    { LLM_ARCH_CHATGLM, "chatglm" },
    { LLM_ARCH_BITNET, "bitnet" },
    { LLM_ARCH_T5, "t5" },
    { LLM_ARCH_T5ENCODER, "t5encoder" },
    { LLM_ARCH_JAIS, "jais" },
    { LLM_ARCH_NEMOTRON, "nemotron" },
    { LLM_ARCH_EXAONE, "exaone" },
    { LLM_ARCH_RWKV6, "rwkv6" },
    { LLM_ARCH_GRANITE, "granite" },
    { LLM_ARCH_GRANITE_MOE, "granitemoe" },
    { LLM_ARCH_CHAMELEON, "chameleon" },
    { LLM_ARCH_WAVTOKENIZER_DEC, "wavtokenizer-dec" },
    { LLM_ARCH_UNKNOWN, "(unknown)" },
};

static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
    { LLM_KV_GENERAL_TYPE, "general.type" },
    { LLM_KV_GENERAL_ARCHITECTURE, "general.architecture" },
    { LLM_KV_GENERAL_QUANTIZATION_VERSION, "general.quantization_version" },
    { LLM_KV_GENERAL_ALIGNMENT, "general.alignment" },
    { LLM_KV_GENERAL_NAME, "general.name" },
    { LLM_KV_GENERAL_AUTHOR, "general.author" },
    { LLM_KV_GENERAL_VERSION, "general.version" },
    { LLM_KV_GENERAL_URL, "general.url" },
    { LLM_KV_GENERAL_DESCRIPTION, "general.description" },
    { LLM_KV_GENERAL_LICENSE, "general.license" },
    { LLM_KV_GENERAL_SOURCE_URL, "general.source.url" },
    { LLM_KV_GENERAL_SOURCE_HF_REPO, "general.source.huggingface.repository" },

    { LLM_KV_VOCAB_SIZE, "%s.vocab_size" },
    { LLM_KV_CONTEXT_LENGTH, "%s.context_length" },
    { LLM_KV_EMBEDDING_LENGTH, "%s.embedding_length" },
    { LLM_KV_FEATURES_LENGTH, "%s.features_length" },
    { LLM_KV_BLOCK_COUNT, "%s.block_count" },
    { LLM_KV_LEADING_DENSE_BLOCK_COUNT, "%s.leading_dense_block_count" },
    { LLM_KV_FEED_FORWARD_LENGTH, "%s.feed_forward_length" },
    { LLM_KV_EXPERT_FEED_FORWARD_LENGTH, "%s.expert_feed_forward_length" },
    { LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, "%s.expert_shared_feed_forward_length" },
    { LLM_KV_USE_PARALLEL_RESIDUAL, "%s.use_parallel_residual" },
    { LLM_KV_TENSOR_DATA_LAYOUT, "%s.tensor_data_layout" },
    { LLM_KV_EXPERT_COUNT, "%s.expert_count" },
    { LLM_KV_EXPERT_USED_COUNT, "%s.expert_used_count" },
    { LLM_KV_EXPERT_SHARED_COUNT, "%s.expert_shared_count" },
    { LLM_KV_EXPERT_WEIGHTS_SCALE, "%s.expert_weights_scale" },
    { LLM_KV_EXPERT_WEIGHTS_NORM, "%s.expert_weights_norm" },
    { LLM_KV_EXPERT_GATING_FUNC, "%s.expert_gating_func" },
    { LLM_KV_POOLING_TYPE, "%s.pooling_type" },
    { LLM_KV_LOGIT_SCALE, "%s.logit_scale" },
    { LLM_KV_DECODER_START_TOKEN_ID, "%s.decoder_start_token_id" },
    { LLM_KV_ATTN_LOGIT_SOFTCAPPING, "%s.attn_logit_softcapping" },
    { LLM_KV_FINAL_LOGIT_SOFTCAPPING, "%s.final_logit_softcapping" },
    { LLM_KV_SWIN_NORM, "%s.swin_norm" },
    { LLM_KV_RESCALE_EVERY_N_LAYERS, "%s.rescale_every_n_layers" },
    { LLM_KV_TIME_MIX_EXTRA_DIM, "%s.time_mix_extra_dim" },
    { LLM_KV_TIME_DECAY_EXTRA_DIM, "%s.time_decay_extra_dim" },
    { LLM_KV_RESIDUAL_SCALE, "%s.residual_scale" },
    { LLM_KV_EMBEDDING_SCALE, "%s.embedding_scale" },

    { LLM_KV_ATTENTION_HEAD_COUNT, "%s.attention.head_count" },
    { LLM_KV_ATTENTION_HEAD_COUNT_KV, "%s.attention.head_count_kv" },
    { LLM_KV_ATTENTION_MAX_ALIBI_BIAS, "%s.attention.max_alibi_bias" },
    { LLM_KV_ATTENTION_CLAMP_KQV, "%s.attention.clamp_kqv" },
    { LLM_KV_ATTENTION_KEY_LENGTH, "%s.attention.key_length" },
    { LLM_KV_ATTENTION_VALUE_LENGTH, "%s.attention.value_length" },
    { LLM_KV_ATTENTION_LAYERNORM_EPS, "%s.attention.layer_norm_epsilon" },
    { LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, "%s.attention.layer_norm_rms_epsilon" },
    { LLM_KV_ATTENTION_GROUPNORM_EPS, "%s.attention.group_norm_epsilon" },
    { LLM_KV_ATTENTION_GROUPNORM_GROUPS, "%s.attention.group_norm_groups" },
    { LLM_KV_ATTENTION_CAUSAL, "%s.attention.causal" },
    { LLM_KV_ATTENTION_Q_LORA_RANK, "%s.attention.q_lora_rank" },
    { LLM_KV_ATTENTION_KV_LORA_RANK, "%s.attention.kv_lora_rank" },
    { LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT, "%s.attention.relative_buckets_count" },
    { LLM_KV_ATTENTION_SLIDING_WINDOW, "%s.attention.sliding_window" },
    { LLM_KV_ATTENTION_SCALE, "%s.attention.scale" },

    { LLM_KV_ROPE_DIMENSION_COUNT, "%s.rope.dimension_count" },
    { LLM_KV_ROPE_DIMENSION_SECTIONS, "%s.rope.dimension_sections" },
    { LLM_KV_ROPE_FREQ_BASE, "%s.rope.freq_base" },
    { LLM_KV_ROPE_SCALE_LINEAR, "%s.rope.scale_linear" },
    { LLM_KV_ROPE_SCALING_TYPE, "%s.rope.scaling.type" },
    { LLM_KV_ROPE_SCALING_FACTOR, "%s.rope.scaling.factor" },
    { LLM_KV_ROPE_SCALING_ATTN_FACTOR, "%s.rope.scaling.attn_factor" },
    { LLM_KV_ROPE_SCALING_ORIG_CTX_LEN, "%s.rope.scaling.original_context_length" },
    { LLM_KV_ROPE_SCALING_FINETUNED, "%s.rope.scaling.finetuned" },
    { LLM_KV_ROPE_SCALING_YARN_LOG_MUL, "%s.rope.scaling.yarn_log_multiplier" },

    { LLM_KV_SPLIT_NO, "split.no" },
    { LLM_KV_SPLIT_COUNT, "split.count" },
    { LLM_KV_SPLIT_TENSORS_COUNT, "split.tensors.count" },

    { LLM_KV_SSM_CONV_KERNEL, "%s.ssm.conv_kernel" },
    { LLM_KV_SSM_INNER_SIZE, "%s.ssm.inner_size" },
    { LLM_KV_SSM_STATE_SIZE, "%s.ssm.state_size" },
    { LLM_KV_SSM_TIME_STEP_RANK, "%s.ssm.time_step_rank" },
    { LLM_KV_SSM_DT_B_C_RMS, "%s.ssm.dt_b_c_rms" },

    { LLM_KV_WKV_HEAD_SIZE, "%s.wkv.head_size" },

    { LLM_KV_POSNET_EMBEDDING_LENGTH, "%s.posnet.embedding_length" },
    { LLM_KV_POSNET_BLOCK_COUNT, "%s.posnet.block_count" },

    { LLM_KV_CONVNEXT_EMBEDDING_LENGTH, "%s.convnext.embedding_length" },
    { LLM_KV_CONVNEXT_BLOCK_COUNT, "%s.convnext.block_count" },

    { LLM_KV_TOKENIZER_MODEL, "tokenizer.ggml.model" },
    { LLM_KV_TOKENIZER_PRE, "tokenizer.ggml.pre" },
    { LLM_KV_TOKENIZER_LIST, "tokenizer.ggml.tokens" },
    { LLM_KV_TOKENIZER_TOKEN_TYPE, "tokenizer.ggml.token_type" },
    { LLM_KV_TOKENIZER_TOKEN_TYPE_COUNT, "tokenizer.ggml.token_type_count" },
    { LLM_KV_TOKENIZER_SCORES, "tokenizer.ggml.scores" },
    { LLM_KV_TOKENIZER_MERGES, "tokenizer.ggml.merges" },
    { LLM_KV_TOKENIZER_BOS_ID, "tokenizer.ggml.bos_token_id" },
    { LLM_KV_TOKENIZER_EOS_ID, "tokenizer.ggml.eos_token_id" },
    { LLM_KV_TOKENIZER_EOT_ID, "tokenizer.ggml.eot_token_id" },
    { LLM_KV_TOKENIZER_EOM_ID, "tokenizer.ggml.eom_token_id" },
    { LLM_KV_TOKENIZER_UNK_ID, "tokenizer.ggml.unknown_token_id" },
    { LLM_KV_TOKENIZER_SEP_ID, "tokenizer.ggml.seperator_token_id" },
    { LLM_KV_TOKENIZER_PAD_ID, "tokenizer.ggml.padding_token_id" },
    { LLM_KV_TOKENIZER_CLS_ID, "tokenizer.ggml.cls_token_id" },
    { LLM_KV_TOKENIZER_MASK_ID, "tokenizer.ggml.mask_token_id" },
    { LLM_KV_TOKENIZER_ADD_BOS, "tokenizer.ggml.add_bos_token" },
    { LLM_KV_TOKENIZER_ADD_EOS, "tokenizer.ggml.add_eos_token" },
    { LLM_KV_TOKENIZER_ADD_PREFIX, "tokenizer.ggml.add_space_prefix" },
    { LLM_KV_TOKENIZER_REMOVE_EXTRA_WS, "tokenizer.ggml.remove_extra_whitespaces" },
    { LLM_KV_TOKENIZER_PRECOMPILED_CHARSMAP, "tokenizer.ggml.precompiled_charsmap" },
    { LLM_KV_TOKENIZER_HF_JSON, "tokenizer.huggingface.json" },
    { LLM_KV_TOKENIZER_RWKV, "tokenizer.rwkv.world" },
    { LLM_KV_TOKENIZER_FIM_PRE_ID, "tokenizer.ggml.fim_pre_token_id" },
    { LLM_KV_TOKENIZER_FIM_SUF_ID, "tokenizer.ggml.fim_suf_token_id" },
    { LLM_KV_TOKENIZER_FIM_MID_ID, "tokenizer.ggml.fim_mid_token_id" },
    { LLM_KV_TOKENIZER_FIM_PAD_ID, "tokenizer.ggml.fim_pad_token_id" },
    { LLM_KV_TOKENIZER_FIM_REP_ID, "tokenizer.ggml.fim_rep_token_id" },
    { LLM_KV_TOKENIZER_FIM_SEP_ID, "tokenizer.ggml.fim_sep_token_id" },

    { LLM_KV_ADAPTER_TYPE, "adapter.type" },
    { LLM_KV_ADAPTER_LORA_ALPHA, "adapter.lora.alpha" },

    // deprecated
    { LLM_KV_TOKENIZER_PREFIX_ID, "tokenizer.ggml.prefix_token_id" },
    { LLM_KV_TOKENIZER_SUFFIX_ID, "tokenizer.ggml.suffix_token_id" },
    { LLM_KV_TOKENIZER_MIDDLE_ID, "tokenizer.ggml.middle_token_id" },
};

static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_NAMES = {
    {
        LLM_ARCH_LLAMA,
        {
            { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
            { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
            { LLM_TENSOR_OUTPUT, "output" },
            { LLM_TENSOR_ROPE_FREQS, "rope_freqs" },
            { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
            { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
            { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
            { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
            { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
            { LLM_TENSOR_ATTN_ROT_EMBD, "blk.%d.attn_rot_embd" },
            { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
            { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
            { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
            { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
            { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
            { LLM_TENSOR_FFN_GATE_EXP, "blk.%d.ffn_gate.%d" },
            { LLM_TENSOR_FFN_DOWN_EXP, "blk.%d.ffn_down.%d" },
            { LLM_TENSOR_FFN_UP_EXP, "blk.%d.ffn_up.%d" },
            { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
            { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
            { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
        },
    },
    {
        LLM_ARCH_DECI,
        {
            { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
            { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
            { LLM_TENSOR_OUTPUT, "output" },
            { LLM_TENSOR_ROPE_FREQS, "rope_freqs" },
            { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
            { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
            { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
            { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
            { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
            { LLM_TENSOR_ATTN_ROT_EMBD, "blk.%d.attn_rot_embd" },
            { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
            { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
            { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
            { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
            { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
            { LLM_TENSOR_FFN_GATE_EXP, "blk.%d.ffn_gate.%d" },
            { LLM_TENSOR_FFN_DOWN_EXP, "blk.%d.ffn_down.%d" },
            { LLM_TENSOR_FFN_UP_EXP, "blk.%d.ffn_up.%d" },
            { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
            { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
            { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
        },
    },
    {
        LLM_ARCH_BAICHUAN,
        {
            { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
            { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
            { LLM_TENSOR_OUTPUT, "output" },
            { LLM_TENSOR_ROPE_FREQS, "rope_freqs" },
            { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
            { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
            { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
            { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
            { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
            { LLM_TENSOR_ATTN_ROT_EMBD, "blk.%d.attn_rot_embd" },
            { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
            { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
            { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
            { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
        },
    },
    {
        LLM_ARCH_FALCON,
        {
            { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
            { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
            { LLM_TENSOR_OUTPUT, "output" },
            { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
            { LLM_TENSOR_ATTN_NORM_2, "blk.%d.attn_norm_2" },
            { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" },
            { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
            { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
            { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
        },
    },
    {
        LLM_ARCH_GROK,
        {
            { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
            { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
            { LLM_TENSOR_OUTPUT, "output" },
            { LLM_TENSOR_ROPE_FREQS, "rope_freqs" },
            { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
            { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
            { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
            { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
            { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
            { LLM_TENSOR_ATTN_ROT_EMBD, "blk.%d.attn_rot_embd" },
            { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
            { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
            { LLM_TENSOR_FFN_GATE_EXP, "blk.%d.ffn_gate.%d" },
            { LLM_TENSOR_FFN_DOWN_EXP, "blk.%d.ffn_down.%d" },
            { LLM_TENSOR_FFN_UP_EXP, "blk.%d.ffn_up.%d" },
            { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
            { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
            { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
            { LLM_TENSOR_LAYER_OUT_NORM, "blk.%d.layer_output_norm" },
            { LLM_TENSOR_ATTN_OUT_NORM, "blk.%d.attn_output_norm" },
        },
    },
    {
        LLM_ARCH_GPT2,
        {
            { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
            { LLM_TENSOR_POS_EMBD, "position_embd" },
            { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
            { LLM_TENSOR_OUTPUT, "output" },
            { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
            { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" },
            { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
            { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
            { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
            { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
        },
    },
    {
        LLM_ARCH_GPTJ,
        {
            { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
        },
    },
    {
        LLM_ARCH_GPTNEOX,
        {
            { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
            { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
            { LLM_TENSOR_OUTPUT, "output" },
            { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
            { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" },
            { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
            { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
            { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
            { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
        },
    },
    {
        LLM_ARCH_MPT,
        {
            { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
            { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
            { LLM_TENSOR_OUTPUT, "output"},
            { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
            { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
            { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" },
            { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
            { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
            { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
            { LLM_TENSOR_FFN_ACT, "blk.%d.ffn.act" },
            { LLM_TENSOR_POS_EMBD, "position_embd" },
            { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm"},
            { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm"},
        },
    },
    {
        LLM_ARCH_STARCODER,
        {
            { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
            { LLM_TENSOR_POS_EMBD, "position_embd" },
            { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
            { LLM_TENSOR_OUTPUT, "output" },
            { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
            { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" },
            { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
            { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
            { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
            { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
        },
    },
    {
        LLM_ARCH_REFACT,
        {
            { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
            { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
            { LLM_TENSOR_OUTPUT, "output" },
            { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
            { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
            { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
            { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
            { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
            { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
            { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
            { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
            { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
        },
    },
    {
        LLM_ARCH_BERT,
        {
            { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
            { LLM_TENSOR_TOKEN_EMBD_NORM, "token_embd_norm" },
            { LLM_TENSOR_TOKEN_TYPES, "token_types" },
            { LLM_TENSOR_POS_EMBD, "position_embd" },
            { LLM_TENSOR_ATTN_OUT_NORM, "blk.%d.attn_output_norm" },
            { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
            { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
            { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
            { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
            { LLM_TENSOR_LAYER_OUT_NORM, "blk.%d.layer_output_norm" },
            { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
            { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
            { LLM_TENSOR_CLS, "cls" },
            { LLM_TENSOR_CLS_OUT, "cls.output" },
        },
    },
    {
        LLM_ARCH_NOMIC_BERT,
        {
            { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
            { LLM_TENSOR_TOKEN_EMBD_NORM, "token_embd_norm" },
            { LLM_TENSOR_TOKEN_TYPES, "token_types" },
            { LLM_TENSOR_ATTN_OUT_NORM, "blk.%d.attn_output_norm" },
            { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" },
            { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
            { LLM_TENSOR_LAYER_OUT_NORM, "blk.%d.layer_output_norm" },
            { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
            { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
            { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
        },
    },
    {
        LLM_ARCH_JINA_BERT_V2,
        {
            { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
            { LLM_TENSOR_TOKEN_EMBD_NORM, "token_embd_norm" },
            { LLM_TENSOR_TOKEN_TYPES, "token_types" },
            { LLM_TENSOR_ATTN_NORM_2, "blk.%d.attn_norm_2" },
            { LLM_TENSOR_ATTN_OUT_NORM, "blk.%d.attn_output_norm" },
            { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
            { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" },
            { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
            { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" },
            { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
            { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
            { LLM_TENSOR_LAYER_OUT_NORM, "blk.%d.layer_output_norm" },
            { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
            { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
            { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
            { LLM_TENSOR_CLS, "cls" },
        },
    },
    {
        LLM_ARCH_BLOOM,
        {
            { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
            { LLM_TENSOR_TOKEN_EMBD_NORM, "token_embd_norm" },
            { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
            { LLM_TENSOR_OUTPUT, "output" },
            { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
            { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" },
            { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
            { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
            { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
            { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
        },
    },
    {
        LLM_ARCH_STABLELM,
        {
            { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
            { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
            { LLM_TENSOR_OUTPUT, "output" },
            { LLM_TENSOR_ROPE_FREQS, "rope_freqs" },
            { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
            { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
            { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
            { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
            { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
            { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
            { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
            { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
            { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
            { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" },
            { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" },
        },
    },
    {
        LLM_ARCH_QWEN,
        {
            { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
            { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
            { LLM_TENSOR_OUTPUT, "output" },
            { LLM_TENSOR_ROPE_FREQS, "rope_freqs" },
            { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
            { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" },
            { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
            { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
            { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
            { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
            { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
        },
    },
    {
        LLM_ARCH_QWEN2,
        {
            { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
            { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
            { LLM_TENSOR_OUTPUT, "output" },
            { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
            { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
            { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
            { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
            { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
            { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
            { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
            { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
            { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
        },
    },
    {
        LLM_ARCH_QWEN2VL,
        {
            { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
            { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
            { LLM_TENSOR_OUTPUT, "output" },
            { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
            { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
            { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
            { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
            { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
            { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
            { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
            { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
            { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
        },
    },
    {
        LLM_ARCH_QWEN2MOE,
        {
            { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
            { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
            { LLM_TENSOR_OUTPUT, "output" },
            { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
            { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
            { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
            { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
            { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
            { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
            { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
            { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
            { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
            { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
            { LLM_TENSOR_FFN_GATE_INP_SHEXP, "blk.%d.ffn_gate_inp_shexp" },
            { LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" },
            { LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" },
            { LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" },
        },
    },
    {
        LLM_ARCH_PHI2,
        {
            { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
            { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
            { LLM_TENSOR_OUTPUT, "output" },
            { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
            { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" },
            { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
            { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
            { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
            { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
            { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
            { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
        },
    },
    {
        LLM_ARCH_PHI3,
        {
            { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
            { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
            { LLM_TENSOR_OUTPUT, "output" },
            { LLM_TENSOR_ROPE_FACTORS_LONG, "rope_factors_long" },
            { LLM_TENSOR_ROPE_FACTORS_SHORT, "rope_factors_short" },
            { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
            { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" },
            { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
            { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
            { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
            { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
            { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
            { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
            { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
        },
    },
    {
        LLM_ARCH_PLAMO,
        {
            { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
            { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
            { LLM_TENSOR_OUTPUT, "output" },
            { LLM_TENSOR_ROPE_FREQS, "rope_freqs" },
            { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
            { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
            { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
            { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
            { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
            { LLM_TENSOR_ATTN_ROT_EMBD, "blk.%d.attn_rot_embd" },
            { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
            { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
            { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
        },
    },
    {
        LLM_ARCH_CODESHELL,
        {
            { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
            { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
            { LLM_TENSOR_OUTPUT, "output" },
            { LLM_TENSOR_ROPE_FREQS, "rope_freqs" },
            { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
            { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
            { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
            { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
            { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" },
            { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
            { LLM_TENSOR_ATTN_ROT_EMBD, "blk.%d.attn_rot_embd" },
            { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
            { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
            { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
            { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
        },
    },
    {
        LLM_ARCH_ORION,
        {
            { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
            { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
            { LLM_TENSOR_OUTPUT, "output" },
| 631 |
+
{ LLM_TENSOR_ROPE_FREQS, "rope_freqs" },
|
| 632 |
+
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
|
| 633 |
+
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
|
| 634 |
+
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
|
| 635 |
+
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
|
| 636 |
+
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
|
| 637 |
+
{ LLM_TENSOR_ATTN_ROT_EMBD, "blk.%d.attn_rot_embd" },
|
| 638 |
+
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
|
| 639 |
+
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
|
| 640 |
+
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
|
| 641 |
+
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
|
| 642 |
+
},
|
| 643 |
+
},
|
| 644 |
+
{
|
| 645 |
+
LLM_ARCH_INTERNLM2,
|
| 646 |
+
{
|
| 647 |
+
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
|
| 648 |
+
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
|
| 649 |
+
{ LLM_TENSOR_OUTPUT, "output" },
|
| 650 |
+
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
|
| 651 |
+
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
|
| 652 |
+
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
|
| 653 |
+
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
|
| 654 |
+
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
|
| 655 |
+
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
|
| 656 |
+
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
|
| 657 |
+
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
|
| 658 |
+
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
|
| 659 |
+
},
|
| 660 |
+
},
|
| 661 |
+
{
|
| 662 |
+
LLM_ARCH_MINICPM,
|
| 663 |
+
{
|
| 664 |
+
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
|
| 665 |
+
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
|
| 666 |
+
{ LLM_TENSOR_OUTPUT, "output" },
|
| 667 |
+
{ LLM_TENSOR_ROPE_FREQS, "rope_freqs" },
|
| 668 |
+
{ LLM_TENSOR_ROPE_FACTORS_LONG, "rope_factors_long" },
|
| 669 |
+
{ LLM_TENSOR_ROPE_FACTORS_SHORT, "rope_factors_short" },
|
| 670 |
+
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
|
| 671 |
+
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
|
| 672 |
+
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
|
| 673 |
+
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
|
| 674 |
+
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
|
| 675 |
+
{ LLM_TENSOR_ATTN_ROT_EMBD, "blk.%d.attn_rot_embd" },
|
| 676 |
+
{ LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
|
| 677 |
+
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
|
| 678 |
+
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
|
| 679 |
+
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
|
| 680 |
+
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
|
| 681 |
+
{ LLM_TENSOR_FFN_GATE_EXP, "blk.%d.ffn_gate.%d" },
|
| 682 |
+
{ LLM_TENSOR_FFN_DOWN_EXP, "blk.%d.ffn_down.%d" },
|
| 683 |
+
{ LLM_TENSOR_FFN_UP_EXP, "blk.%d.ffn_up.%d" },
|
| 684 |
+
},
|
| 685 |
+
},
|
| 686 |
+
{
|
| 687 |
+
LLM_ARCH_MINICPM3,
|
| 688 |
+
{
|
| 689 |
+
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
|
| 690 |
+
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
|
| 691 |
+
{ LLM_TENSOR_OUTPUT, "output" },
|
| 692 |
+
{ LLM_TENSOR_ROPE_FACTORS_LONG, "rope_factors_long" },
|
| 693 |
+
{ LLM_TENSOR_ROPE_FACTORS_SHORT, "rope_factors_short" },
|
| 694 |
+
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
|
| 695 |
+
{ LLM_TENSOR_ATTN_Q_A_NORM, "blk.%d.attn_q_a_norm" },
|
| 696 |
+
{ LLM_TENSOR_ATTN_KV_A_NORM, "blk.%d.attn_kv_a_norm" },
|
| 697 |
+
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
|
| 698 |
+
{ LLM_TENSOR_ATTN_Q_A, "blk.%d.attn_q_a" },
|
| 699 |
+
{ LLM_TENSOR_ATTN_Q_B, "blk.%d.attn_q_b" },
|
| 700 |
+
{ LLM_TENSOR_ATTN_KV_A_MQA, "blk.%d.attn_kv_a_mqa" },
|
| 701 |
+
{ LLM_TENSOR_ATTN_KV_B, "blk.%d.attn_kv_b" },
|
| 702 |
+
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
|
| 703 |
+
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
|
| 704 |
+
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
|
| 705 |
+
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
|
| 706 |
+
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
|
| 707 |
+
},
|
| 708 |
+
},
|
| 709 |
+
{
|
| 710 |
+
LLM_ARCH_GEMMA,
|
| 711 |
+
{
|
| 712 |
+
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
|
| 713 |
+
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
|
| 714 |
+
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
|
| 715 |
+
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
|
| 716 |
+
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
|
| 717 |
+
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
|
| 718 |
+
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
|
| 719 |
+
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
|
| 720 |
+
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
|
| 721 |
+
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
|
| 722 |
+
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
|
| 723 |
+
},
|
| 724 |
+
},
|
| 725 |
+
{
|
| 726 |
+
LLM_ARCH_GEMMA2,
|
| 727 |
+
{
|
| 728 |
+
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
|
| 729 |
+
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
|
| 730 |
+
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
|
| 731 |
+
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
|
| 732 |
+
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
|
| 733 |
+
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
|
| 734 |
+
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
|
| 735 |
+
{ LLM_TENSOR_ATTN_POST_NORM, "blk.%d.post_attention_norm" },
|
| 736 |
+
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
|
| 737 |
+
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
|
| 738 |
+
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
|
| 739 |
+
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
|
| 740 |
+
{ LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" },
|
| 741 |
+
},
|
| 742 |
+
},
|
| 743 |
+
{
|
| 744 |
+
LLM_ARCH_STARCODER2,
|
| 745 |
+
{
|
| 746 |
+
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
|
| 747 |
+
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
|
| 748 |
+
{ LLM_TENSOR_OUTPUT, "output" },
|
| 749 |
+
{ LLM_TENSOR_ROPE_FREQS, "rope_freqs" },
|
| 750 |
+
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
|
| 751 |
+
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
|
| 752 |
+
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
|
| 753 |
+
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
|
| 754 |
+
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
|
| 755 |
+
{ LLM_TENSOR_ATTN_ROT_EMBD, "blk.%d.attn_rot_embd" },
|
| 756 |
+
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
|
| 757 |
+
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
|
| 758 |
+
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
|
| 759 |
+
},
|
| 760 |
+
},
|
| 761 |
+
{
|
| 762 |
+
LLM_ARCH_MAMBA,
|
| 763 |
+
{
|
| 764 |
+
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
|
| 765 |
+
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
|
| 766 |
+
{ LLM_TENSOR_OUTPUT, "output" },
|
| 767 |
+
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
|
| 768 |
+
{ LLM_TENSOR_SSM_IN, "blk.%d.ssm_in" },
|
| 769 |
+
{ LLM_TENSOR_SSM_CONV1D, "blk.%d.ssm_conv1d" },
|
| 770 |
+
{ LLM_TENSOR_SSM_X, "blk.%d.ssm_x" },
|
| 771 |
+
{ LLM_TENSOR_SSM_DT, "blk.%d.ssm_dt" },
|
| 772 |
+
{ LLM_TENSOR_SSM_A, "blk.%d.ssm_a" },
|
| 773 |
+
{ LLM_TENSOR_SSM_D, "blk.%d.ssm_d" },
|
| 774 |
+
{ LLM_TENSOR_SSM_OUT, "blk.%d.ssm_out" },
|
| 775 |
+
},
|
| 776 |
+
},
|
| 777 |
+
{
|
| 778 |
+
LLM_ARCH_XVERSE,
|
| 779 |
+
{
|
| 780 |
+
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
|
| 781 |
+
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
|
| 782 |
+
{ LLM_TENSOR_OUTPUT, "output" },
|
| 783 |
+
{ LLM_TENSOR_ROPE_FREQS, "rope_freqs" },
|
| 784 |
+
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
|
| 785 |
+
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
|
| 786 |
+
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
|
| 787 |
+
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
|
| 788 |
+
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
|
| 789 |
+
{ LLM_TENSOR_ATTN_ROT_EMBD, "blk.%d.attn_rot_embd" },
|
| 790 |
+
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
|
| 791 |
+
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
|
| 792 |
+
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
|
| 793 |
+
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
|
| 794 |
+
},
|
| 795 |
+
},
|
| 796 |
+
{
|
| 797 |
+
LLM_ARCH_COMMAND_R,
|
| 798 |
+
{
|
| 799 |
+
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
|
| 800 |
+
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
|
| 801 |
+
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
|
| 802 |
+
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
|
| 803 |
+
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
|
| 804 |
+
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
|
| 805 |
+
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
|
| 806 |
+
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
|
| 807 |
+
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
|
| 808 |
+
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
|
| 809 |
+
{ LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" },
|
| 810 |
+
{ LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" },
|
| 811 |
+
},
|
| 812 |
+
},
|
| 813 |
+
{
|
| 814 |
+
LLM_ARCH_COHERE2,
|
| 815 |
+
{
|
| 816 |
+
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
|
| 817 |
+
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
|
| 818 |
+
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
|
| 819 |
+
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
|
| 820 |
+
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
|
| 821 |
+
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
|
| 822 |
+
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
|
| 823 |
+
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
|
| 824 |
+
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
|
| 825 |
+
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
|
| 826 |
+
},
|
| 827 |
+
},
|
| 828 |
+
{
|
| 829 |
+
LLM_ARCH_DBRX,
|
| 830 |
+
{
|
| 831 |
+
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
|
| 832 |
+
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
|
| 833 |
+
{ LLM_TENSOR_OUTPUT, "output" },
|
| 834 |
+
{ LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" },
|
| 835 |
+
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
|
| 836 |
+
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
|
| 837 |
+
{ LLM_TENSOR_ATTN_OUT_NORM, "blk.%d.attn_output_norm" },
|
| 838 |
+
{ LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
|
| 839 |
+
{ LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
|
| 840 |
+
{ LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
|
| 841 |
+
{ LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
|
| 842 |
+
},
|
| 843 |
+
},
|
| 844 |
+
{
|
| 845 |
+
LLM_ARCH_OLMO,
|
| 846 |
+
{
|
| 847 |
+
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
|
| 848 |
+
{ LLM_TENSOR_OUTPUT, "output" },
|
| 849 |
+
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
|
| 850 |
+
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
|
| 851 |
+
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
|
| 852 |
+
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
|
| 853 |
+
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
|
| 854 |
+
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
|
| 855 |
+
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
|
| 856 |
+
},
|
| 857 |
+
},
|
| 858 |
+
{
|
| 859 |
+
LLM_ARCH_OLMO2,
|
| 860 |
+
{
|
| 861 |
+
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
|
| 862 |
+
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
|
| 863 |
+
{ LLM_TENSOR_OUTPUT, "output" },
|
| 864 |
+
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
|
| 865 |
+
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
|
| 866 |
+
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
|
| 867 |
+
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
|
| 868 |
+
{ LLM_TENSOR_ATTN_POST_NORM, "blk.%d.post_attention_norm" },
|
| 869 |
+
{ LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" },
|
| 870 |
+
{ LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" },
|
| 871 |
+
{ LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" },
|
| 872 |
+
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
|
| 873 |
+
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
|
| 874 |
+
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
|
| 875 |
+
},
|
| 876 |
+
},
|
| 877 |
+
{
|
| 878 |
+
LLM_ARCH_OLMOE,
|
| 879 |
+
{
|
| 880 |
+
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
|
| 881 |
+
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
|
| 882 |
+
{ LLM_TENSOR_OUTPUT, "output" },
|
| 883 |
+
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
|
| 884 |
+
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
|
| 885 |
+
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
|
| 886 |
+
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
|
| 887 |
+
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
|
| 888 |
+
{ LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" },
|
| 889 |
+
{ LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" },
|
| 890 |
+
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
|
| 891 |
+
{ LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
|
| 892 |
+
{ LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
|
| 893 |
+
{ LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
|
| 894 |
+
{ LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
|
| 895 |
+
},
|
| 896 |
+
},
|
| 897 |
+
{
|
| 898 |
+
LLM_ARCH_OPENELM,
|
| 899 |
+
{
|
| 900 |
+
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
|
| 901 |
+
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
|
| 902 |
+
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
|
| 903 |
+
{ LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" },
|
| 904 |
+
{ LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" },
|
| 905 |
+
{ LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" },
|
| 906 |
+
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
|
| 907 |
+
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
|
| 908 |
+
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
|
| 909 |
+
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
|
| 910 |
+
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
|
| 911 |
+
},
|
| 912 |
+
},
|
| 913 |
+
{
|
| 914 |
+
LLM_ARCH_ARCTIC,
|
| 915 |
+
{
|
| 916 |
+
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
|
| 917 |
+
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
|
| 918 |
+
{ LLM_TENSOR_OUTPUT, "output" },
|
| 919 |
+
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
|
| 920 |
+
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
|
| 921 |
+
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
|
| 922 |
+
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
|
| 923 |
+
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
|
| 924 |
+
{ LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
|
| 925 |
+
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
|
| 926 |
+
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
|
| 927 |
+
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
|
| 928 |
+
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
|
| 929 |
+
{ LLM_TENSOR_FFN_NORM_EXPS, "blk.%d.ffn_norm_exps" },
|
| 930 |
+
{ LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
|
| 931 |
+
{ LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
|
| 932 |
+
{ LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
|
| 933 |
+
},
|
| 934 |
+
},
|
| 935 |
+
{
|
| 936 |
+
LLM_ARCH_DEEPSEEK,
|
| 937 |
+
{
|
| 938 |
+
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
|
| 939 |
+
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
|
| 940 |
+
{ LLM_TENSOR_OUTPUT, "output" },
|
| 941 |
+
{ LLM_TENSOR_ROPE_FREQS, "rope_freqs" },
|
| 942 |
+
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
|
| 943 |
+
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
|
| 944 |
+
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
|
| 945 |
+
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
|
| 946 |
+
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
|
| 947 |
+
{ LLM_TENSOR_ATTN_ROT_EMBD, "blk.%d.attn_rot_embd" },
|
| 948 |
+
{ LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
|
| 949 |
+
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
|
| 950 |
+
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
|
| 951 |
+
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
|
| 952 |
+
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
|
| 953 |
+
{ LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
|
| 954 |
+
{ LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
|
| 955 |
+
{ LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
|
| 956 |
+
{ LLM_TENSOR_FFN_GATE_INP_SHEXP, "blk.%d.ffn_gate_inp_shexp" },
|
| 957 |
+
{ LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" },
|
| 958 |
+
{ LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" },
|
| 959 |
+
{ LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" },
|
| 960 |
+
},
|
| 961 |
+
},
|
| 962 |
+
{
|
| 963 |
+
LLM_ARCH_DEEPSEEK2,
|
| 964 |
+
{
|
| 965 |
+
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
|
| 966 |
+
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
|
| 967 |
+
{ LLM_TENSOR_OUTPUT, "output" },
|
| 968 |
+
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
|
| 969 |
+
{ LLM_TENSOR_ATTN_Q_A_NORM, "blk.%d.attn_q_a_norm" },
|
| 970 |
+
{ LLM_TENSOR_ATTN_KV_A_NORM, "blk.%d.attn_kv_a_norm" },
|
| 971 |
+
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
|
| 972 |
+
{ LLM_TENSOR_ATTN_Q_A, "blk.%d.attn_q_a" },
|
| 973 |
+
{ LLM_TENSOR_ATTN_Q_B, "blk.%d.attn_q_b" },
|
| 974 |
+
{ LLM_TENSOR_ATTN_KV_A_MQA, "blk.%d.attn_kv_a_mqa" },
|
| 975 |
+
{ LLM_TENSOR_ATTN_KV_B, "blk.%d.attn_kv_b" },
|
| 976 |
+
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
|
| 977 |
+
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
|
| 978 |
+
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
|
| 979 |
+
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
|
| 980 |
+
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
|
| 981 |
+
{ LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
|
| 982 |
+
{ LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
|
| 983 |
+
{ LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
|
| 984 |
+
{ LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
|
| 985 |
+
{ LLM_TENSOR_FFN_GATE_INP_SHEXP, "blk.%d.ffn_gate_inp_shexp" },
|
| 986 |
+
{ LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" },
|
| 987 |
+
{ LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" },
|
| 988 |
+
{ LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" },
|
| 989 |
+
{ LLM_TENSOR_FFN_EXP_PROBS_B, "blk.%d.exp_probs_b" },
|
| 990 |
+
},
|
| 991 |
+
},
|
| 992 |
+
{
|
| 993 |
+
LLM_ARCH_CHATGLM,
|
| 994 |
+
{
|
| 995 |
+
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
|
| 996 |
+
{ LLM_TENSOR_ROPE_FREQS, "rope_freqs" },
|
| 997 |
+
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
|
| 998 |
+
{ LLM_TENSOR_OUTPUT, "output" },
|
| 999 |
+
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
|
| 1000 |
+
{ LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" },
|
| 1001 |
+
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
|
| 1002 |
+
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
|
| 1003 |
+
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
|
| 1004 |
+
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
|
| 1005 |
+
},
|
| 1006 |
+
},
|
| 1007 |
+
{
|
| 1008 |
+
LLM_ARCH_BITNET,
|
| 1009 |
+
{
|
| 1010 |
+
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
|
| 1011 |
+
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
|
| 1012 |
+
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
|
| 1013 |
+
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
|
| 1014 |
+
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
|
| 1015 |
+
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
|
| 1016 |
+
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
|
| 1017 |
+
{ LLM_TENSOR_ATTN_SUB_NORM, "blk.%d.attn_sub_norm" },
|
| 1018 |
+
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
|
| 1019 |
+
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
|
| 1020 |
+
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
|
| 1021 |
+
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
|
| 1022 |
+
{ LLM_TENSOR_FFN_SUB_NORM, "blk.%d.ffn_sub_norm" },
|
| 1023 |
+
},
|
| 1024 |
+
},
|
| 1025 |
+
{
|
| 1026 |
+
LLM_ARCH_T5,
|
| 1027 |
+
{
|
| 1028 |
+
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
|
| 1029 |
+
{ LLM_TENSOR_OUTPUT, "output" },
|
| 1030 |
+
{ LLM_TENSOR_DEC_OUTPUT_NORM, "dec.output_norm" },
|
| 1031 |
+
{ LLM_TENSOR_DEC_ATTN_NORM, "dec.blk.%d.attn_norm" },
|
| 1032 |
+
{ LLM_TENSOR_DEC_ATTN_Q, "dec.blk.%d.attn_q" },
|
| 1033 |
+
{ LLM_TENSOR_DEC_ATTN_K, "dec.blk.%d.attn_k" },
|
| 1034 |
+
{ LLM_TENSOR_DEC_ATTN_V, "dec.blk.%d.attn_v" },
|
| 1035 |
+
{ LLM_TENSOR_DEC_ATTN_OUT, "dec.blk.%d.attn_o" },
|
| 1036 |
+
{ LLM_TENSOR_DEC_ATTN_REL_B, "dec.blk.%d.attn_rel_b" },
|
| 1037 |
+
{ LLM_TENSOR_DEC_CROSS_ATTN_NORM, "dec.blk.%d.cross_attn_norm" },
|
| 1038 |
+
{ LLM_TENSOR_DEC_CROSS_ATTN_Q, "dec.blk.%d.cross_attn_q" },
|
| 1039 |
+
{ LLM_TENSOR_DEC_CROSS_ATTN_K, "dec.blk.%d.cross_attn_k" },
|
| 1040 |
+
{ LLM_TENSOR_DEC_CROSS_ATTN_V, "dec.blk.%d.cross_attn_v" },
|
| 1041 |
+
{ LLM_TENSOR_DEC_CROSS_ATTN_OUT, "dec.blk.%d.cross_attn_o" },
|
| 1042 |
+
{ LLM_TENSOR_DEC_CROSS_ATTN_REL_B, "dec.blk.%d.cross_attn_rel_b" },
|
| 1043 |
+
{ LLM_TENSOR_DEC_FFN_NORM, "dec.blk.%d.ffn_norm" },
|
| 1044 |
+
{ LLM_TENSOR_DEC_FFN_GATE, "dec.blk.%d.ffn_gate" },
|
| 1045 |
+
{ LLM_TENSOR_DEC_FFN_DOWN, "dec.blk.%d.ffn_down" },
|
| 1046 |
+
{ LLM_TENSOR_DEC_FFN_UP, "dec.blk.%d.ffn_up" },
|
| 1047 |
+
{ LLM_TENSOR_ENC_OUTPUT_NORM, "enc.output_norm" },
|
| 1048 |
+
{ LLM_TENSOR_ENC_ATTN_NORM, "enc.blk.%d.attn_norm" },
|
| 1049 |
+
{ LLM_TENSOR_ENC_ATTN_Q, "enc.blk.%d.attn_q" },
|
| 1050 |
+
{ LLM_TENSOR_ENC_ATTN_K, "enc.blk.%d.attn_k" },
|
| 1051 |
+
{ LLM_TENSOR_ENC_ATTN_V, "enc.blk.%d.attn_v" },
|
| 1052 |
+
{ LLM_TENSOR_ENC_ATTN_OUT, "enc.blk.%d.attn_o" },
|
| 1053 |
+
{ LLM_TENSOR_ENC_ATTN_REL_B, "enc.blk.%d.attn_rel_b" },
|
| 1054 |
+
{ LLM_TENSOR_ENC_FFN_NORM, "enc.blk.%d.ffn_norm" },
|
| 1055 |
+
{ LLM_TENSOR_ENC_FFN_GATE, "enc.blk.%d.ffn_gate" },
|
| 1056 |
+
{ LLM_TENSOR_ENC_FFN_DOWN, "enc.blk.%d.ffn_down" },
|
| 1057 |
+
{ LLM_TENSOR_ENC_FFN_UP, "enc.blk.%d.ffn_up" },
|
| 1058 |
+
},
|
| 1059 |
+
},
|
| 1060 |
+
{
|
| 1061 |
+
LLM_ARCH_T5ENCODER,
|
| 1062 |
+
{
|
| 1063 |
+
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
|
| 1064 |
+
{ LLM_TENSOR_OUTPUT, "output" },
|
| 1065 |
+
{ LLM_TENSOR_ENC_OUTPUT_NORM, "enc.output_norm" },
|
| 1066 |
+
{ LLM_TENSOR_ENC_ATTN_NORM, "enc.blk.%d.attn_norm" },
|
| 1067 |
+
{ LLM_TENSOR_ENC_ATTN_Q, "enc.blk.%d.attn_q" },
|
| 1068 |
+
{ LLM_TENSOR_ENC_ATTN_K, "enc.blk.%d.attn_k" },
|
| 1069 |
+
{ LLM_TENSOR_ENC_ATTN_V, "enc.blk.%d.attn_v" },
|
| 1070 |
+
{ LLM_TENSOR_ENC_ATTN_OUT, "enc.blk.%d.attn_o" },
|
| 1071 |
+
{ LLM_TENSOR_ENC_ATTN_REL_B, "enc.blk.%d.attn_rel_b" },
|
| 1072 |
+
{ LLM_TENSOR_ENC_FFN_NORM, "enc.blk.%d.ffn_norm" },
|
| 1073 |
+
{ LLM_TENSOR_ENC_FFN_GATE, "enc.blk.%d.ffn_gate" },
|
| 1074 |
+
{ LLM_TENSOR_ENC_FFN_DOWN, "enc.blk.%d.ffn_down" },
|
| 1075 |
+
{ LLM_TENSOR_ENC_FFN_UP, "enc.blk.%d.ffn_up" },
|
| 1076 |
+
},
|
| 1077 |
+
},
|
| 1078 |
+
{
|
| 1079 |
+
LLM_ARCH_JAIS,
|
| 1080 |
+
{
|
| 1081 |
+
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
|
| 1082 |
+
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
|
| 1083 |
+
{ LLM_TENSOR_OUTPUT, "output" },
|
| 1084 |
+
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
|
| 1085 |
+
{ LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" },
|
| 1086 |
+
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
|
| 1087 |
+
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
|
| 1088 |
+
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
|
| 1089 |
+
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
|
| 1090 |
+
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
|
| 1091 |
+
},
|
| 1092 |
+
},
|
| 1093 |
+
{
|
| 1094 |
+
LLM_ARCH_NEMOTRON,
|
| 1095 |
+
{
|
| 1096 |
+
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
|
| 1097 |
+
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
|
| 1098 |
+
{ LLM_TENSOR_OUTPUT, "output" },
|
| 1099 |
+
{ LLM_TENSOR_ROPE_FREQS, "rope_freqs" },
|
| 1100 |
+
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
|
| 1101 |
+
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
|
| 1102 |
+
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
|
| 1103 |
+
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
|
| 1104 |
+
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
|
| 1105 |
+
{ LLM_TENSOR_ATTN_ROT_EMBD, "blk.%d.attn_rot_embd" },
|
| 1106 |
+
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
|
| 1107 |
+
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
|
| 1108 |
+
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
|
| 1109 |
+
},
|
| 1110 |
+
},
|
| 1111 |
+
{
|
| 1112 |
+
LLM_ARCH_EXAONE,
|
| 1113 |
+
{
|
| 1114 |
+
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
|
| 1115 |
+
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
|
| 1116 |
+
{ LLM_TENSOR_OUTPUT, "output" },
|
| 1117 |
+
{ LLM_TENSOR_ROPE_FREQS, "rope_freqs" },
|
| 1118 |
+
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
|
| 1119 |
+
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
|
| 1120 |
+
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
|
| 1121 |
+
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
|
| 1122 |
+
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
|
| 1123 |
+
{ LLM_TENSOR_ATTN_ROT_EMBD, "blk.%d.attn_rot_embd" },
|
| 1124 |
+
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
|
| 1125 |
+
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
|
| 1126 |
+
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
|
| 1127 |
+
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
|
| 1128 |
+
},
|
| 1129 |
+
},
|
| 1130 |
+
{
|
| 1131 |
+
LLM_ARCH_RWKV6,
|
| 1132 |
+
{
|
| 1133 |
+
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
|
| 1134 |
+
{ LLM_TENSOR_TOKEN_EMBD_NORM, "token_embd_norm" },
|
| 1135 |
+
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
|
| 1136 |
+
{ LLM_TENSOR_OUTPUT, "output" },
|
| 1137 |
+
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
|
| 1138 |
+
{ LLM_TENSOR_ATTN_NORM_2, "blk.%d.attn_norm_2" },
|
| 1139 |
+
{ LLM_TENSOR_TIME_MIX_W1, "blk.%d.time_mix_w1" },
|
| 1140 |
+
{ LLM_TENSOR_TIME_MIX_W2, "blk.%d.time_mix_w2" },
|
| 1141 |
+
{ LLM_TENSOR_TIME_MIX_LERP_X, "blk.%d.time_mix_lerp_x" },
|
| 1142 |
+
{ LLM_TENSOR_TIME_MIX_LERP_W, "blk.%d.time_mix_lerp_w" },
|
| 1143 |
+
{ LLM_TENSOR_TIME_MIX_LERP_K, "blk.%d.time_mix_lerp_k" },
|
| 1144 |
+
{ LLM_TENSOR_TIME_MIX_LERP_V, "blk.%d.time_mix_lerp_v" },
|
| 1145 |
+
{ LLM_TENSOR_TIME_MIX_LERP_R, "blk.%d.time_mix_lerp_r" },
|
| 1146 |
+
{ LLM_TENSOR_TIME_MIX_LERP_G, "blk.%d.time_mix_lerp_g" },
|
| 1147 |
+
{ LLM_TENSOR_TIME_MIX_FIRST, "blk.%d.time_mix_first" },
|
| 1148 |
+
{ LLM_TENSOR_TIME_MIX_DECAY, "blk.%d.time_mix_decay" },
|
| 1149 |
+
{ LLM_TENSOR_TIME_MIX_DECAY_W1, "blk.%d.time_mix_decay_w1" },
|
| 1150 |
+
{ LLM_TENSOR_TIME_MIX_DECAY_W2, "blk.%d.time_mix_decay_w2" },
|
| 1151 |
+
{ LLM_TENSOR_TIME_MIX_KEY, "blk.%d.time_mix_key" },
|
| 1152 |
+
{ LLM_TENSOR_TIME_MIX_VALUE, "blk.%d.time_mix_value" },
|
| 1153 |
+
{ LLM_TENSOR_TIME_MIX_RECEPTANCE, "blk.%d.time_mix_receptance" },
|
| 1154 |
+
{ LLM_TENSOR_TIME_MIX_GATE, "blk.%d.time_mix_gate" },
|
| 1155 |
+
{ LLM_TENSOR_TIME_MIX_LN, "blk.%d.time_mix_ln" },
|
| 1156 |
+
{ LLM_TENSOR_TIME_MIX_OUTPUT, "blk.%d.time_mix_output" },
|
| 1157 |
+
{ LLM_TENSOR_CHANNEL_MIX_LERP_K, "blk.%d.channel_mix_lerp_k" },
|
| 1158 |
+
{ LLM_TENSOR_CHANNEL_MIX_LERP_R, "blk.%d.channel_mix_lerp_r" },
|
| 1159 |
+
{ LLM_TENSOR_CHANNEL_MIX_KEY, "blk.%d.channel_mix_key" },
|
| 1160 |
+
{ LLM_TENSOR_CHANNEL_MIX_VALUE, "blk.%d.channel_mix_value" },
|
| 1161 |
+
{ LLM_TENSOR_CHANNEL_MIX_RECEPTANCE, "blk.%d.channel_mix_receptance" },
|
| 1162 |
+
},
|
| 1163 |
+
},
|
| 1164 |
+
{
|
| 1165 |
+
LLM_ARCH_GRANITE,
|
| 1166 |
+
{
|
| 1167 |
+
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
|
| 1168 |
+
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
|
| 1169 |
+
{ LLM_TENSOR_OUTPUT, "output" },
|
| 1170 |
+
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
|
| 1171 |
+
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
|
| 1172 |
+
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
|
| 1173 |
+
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
|
| 1174 |
+
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
|
| 1175 |
+
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
|
| 1176 |
+
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
|
| 1177 |
+
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
|
| 1178 |
+
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
|
| 1179 |
+
},
|
| 1180 |
+
},
|
| 1181 |
+
{
|
| 1182 |
+
LLM_ARCH_GRANITE_MOE,
|
| 1183 |
+
{
|
| 1184 |
+
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
|
| 1185 |
+
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
|
| 1186 |
+
{ LLM_TENSOR_OUTPUT, "output" },
|
| 1187 |
+
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
|
| 1188 |
+
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
|
| 1189 |
+
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
|
| 1190 |
+
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
|
| 1191 |
+
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
|
| 1192 |
+
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
|
| 1193 |
+
{ LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
|
| 1194 |
+
{ LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
|
| 1195 |
+
{ LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
|
| 1196 |
+
{ LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
|
| 1197 |
+
},
|
| 1198 |
+
},
|
| 1199 |
+
{
|
| 1200 |
+
LLM_ARCH_CHAMELEON,
|
| 1201 |
+
{
|
| 1202 |
+
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
|
| 1203 |
+
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
|
| 1204 |
+
{ LLM_TENSOR_OUTPUT, "output" },
|
| 1205 |
+
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
|
| 1206 |
+
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
|
| 1207 |
+
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
|
| 1208 |
+
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
|
| 1209 |
+
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
|
| 1210 |
+
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
|
| 1211 |
+
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
|
| 1212 |
+
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
|
| 1213 |
+
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
|
| 1214 |
+
{ LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" },
|
| 1215 |
+
{ LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" },
|
| 1216 |
+
},
|
| 1217 |
+
},
|
| 1218 |
+
{
|
| 1219 |
+
LLM_ARCH_WAVTOKENIZER_DEC,
|
| 1220 |
+
{
|
| 1221 |
+
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
|
| 1222 |
+
{ LLM_TENSOR_TOKEN_EMBD_NORM, "token_embd_norm" },
|
| 1223 |
+
{ LLM_TENSOR_CONV1D, "conv1d" },
|
| 1224 |
+
{ LLM_TENSOR_CONVNEXT_DW, "convnext.%d.dw" },
|
| 1225 |
+
{ LLM_TENSOR_CONVNEXT_NORM, "convnext.%d.norm" },
|
| 1226 |
+
{ LLM_TENSOR_CONVNEXT_PW1, "convnext.%d.pw1" },
|
| 1227 |
+
{ LLM_TENSOR_CONVNEXT_PW2, "convnext.%d.pw2" },
|
| 1228 |
+
{ LLM_TENSOR_CONVNEXT_GAMMA, "convnext.%d.gamma" },
|
| 1229 |
+
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
|
| 1230 |
+
{ LLM_TENSOR_OUTPUT, "output" },
|
| 1231 |
+
{ LLM_TENSOR_POS_NET_CONV1, "posnet.%d.conv1" },
|
| 1232 |
+
{ LLM_TENSOR_POS_NET_CONV2, "posnet.%d.conv2" },
|
| 1233 |
+
{ LLM_TENSOR_POS_NET_NORM, "posnet.%d.norm" },
|
| 1234 |
+
{ LLM_TENSOR_POS_NET_NORM1, "posnet.%d.norm1" },
|
| 1235 |
+
{ LLM_TENSOR_POS_NET_NORM2, "posnet.%d.norm2" },
|
| 1236 |
+
{ LLM_TENSOR_POS_NET_ATTN_NORM, "posnet.%d.attn_norm" },
|
| 1237 |
+
{ LLM_TENSOR_POS_NET_ATTN_Q, "posnet.%d.attn_q" },
|
| 1238 |
+
{ LLM_TENSOR_POS_NET_ATTN_K, "posnet.%d.attn_k" },
|
| 1239 |
+
{ LLM_TENSOR_POS_NET_ATTN_V, "posnet.%d.attn_v" },
|
| 1240 |
+
{ LLM_TENSOR_POS_NET_ATTN_OUT, "posnet.%d.attn_output" },
|
| 1241 |
+
},
|
| 1242 |
+
},
|
| 1243 |
+
{
|
| 1244 |
+
LLM_ARCH_UNKNOWN,
|
| 1245 |
+
{
|
| 1246 |
+
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
|
| 1247 |
+
},
|
| 1248 |
+
},
|
| 1249 |
+
};
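The table above stores only printf-style patterns; the per-block index is substituted later, when a concrete tensor name is needed (LLM_TN_IMPL::str() further down does this via the project's ::format helper and then appends a suffix). As a rough illustration of that expansion, here is a minimal standalone sketch using plain snprintf; the helper name tensor_name and the fixed buffer size are illustrative choices, not part of the library:

#include <cstdio>
#include <string>

// Illustrative only: expand a pattern such as "blk.%d.attn_q" for block `bid`
// and append an optional suffix, roughly what LLM_TN_IMPL::str() does via ::format.
static std::string tensor_name(const char * pattern, int bid, const char * suffix) {
    char buf[256];
    std::snprintf(buf, sizeof(buf), pattern, bid);
    std::string name = buf;
    if (suffix != nullptr) {
        name += ".";
        name += suffix;
    }
    return name;
}

// tensor_name("blk.%d.attn_q", 0, "weight") -> "blk.0.attn_q.weight"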
|
| 1250 |
+
|
| 1251 |
+
static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
|
| 1252 |
+
{LLM_TENSOR_TOKEN_EMBD, {LLM_TENSOR_LAYER_INPUT, GGML_OP_GET_ROWS}},
|
| 1253 |
+
{LLM_TENSOR_POS_EMBD, {LLM_TENSOR_LAYER_INPUT, GGML_OP_GET_ROWS}},
|
| 1254 |
+
{LLM_TENSOR_TOKEN_EMBD_NORM, {LLM_TENSOR_LAYER_INPUT, GGML_OP_GET_ROWS}},
|
| 1255 |
+
{LLM_TENSOR_TOKEN_TYPES, {LLM_TENSOR_LAYER_INPUT, GGML_OP_GET_ROWS}},
|
| 1256 |
+
{LLM_TENSOR_OUTPUT, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}},
|
| 1257 |
+
{LLM_TENSOR_CLS, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}},
|
| 1258 |
+
{LLM_TENSOR_CLS_OUT, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}},
|
| 1259 |
+
{LLM_TENSOR_OUTPUT_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}},
|
| 1260 |
+
{LLM_TENSOR_DEC_OUTPUT_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}},
|
| 1261 |
+
{LLM_TENSOR_ENC_OUTPUT_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}},
|
| 1262 |
+
{LLM_TENSOR_ROPE_FREQS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ROPE}},
|
| 1263 |
+
{LLM_TENSOR_ROPE_FACTORS_LONG, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ROPE}},
|
| 1264 |
+
{LLM_TENSOR_ROPE_FACTORS_SHORT, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ROPE}},
|
| 1265 |
+
{LLM_TENSOR_ATTN_Q, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
| 1266 |
+
{LLM_TENSOR_ATTN_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
| 1267 |
+
{LLM_TENSOR_ATTN_V, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
| 1268 |
+
{LLM_TENSOR_ATTN_QKV, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
| 1269 |
+
{LLM_TENSOR_ATTN_OUT, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
| 1270 |
+
{LLM_TENSOR_FFN_GATE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
| 1271 |
+
{LLM_TENSOR_FFN_DOWN, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
| 1272 |
+
{LLM_TENSOR_FFN_UP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
| 1273 |
+
{LLM_TENSOR_FFN_DOWN_SHEXP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
| 1274 |
+
{LLM_TENSOR_FFN_GATE_SHEXP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
| 1275 |
+
{LLM_TENSOR_FFN_UP_SHEXP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
| 1276 |
+
{LLM_TENSOR_ATTN_Q_A, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
| 1277 |
+
{LLM_TENSOR_ATTN_Q_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
| 1278 |
+
{LLM_TENSOR_ATTN_KV_A_MQA, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
| 1279 |
+
{LLM_TENSOR_ATTN_KV_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
| 1280 |
+
{LLM_TENSOR_DEC_ATTN_Q, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
| 1281 |
+
{LLM_TENSOR_DEC_ATTN_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
| 1282 |
+
{LLM_TENSOR_ATTN_Q, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
| 1283 |
+
{LLM_TENSOR_ATTN_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
| 1284 |
+
{LLM_TENSOR_ATTN_V, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
| 1285 |
+
{LLM_TENSOR_ATTN_QKV, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
| 1286 |
+
{LLM_TENSOR_ATTN_OUT, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
| 1287 |
+
{LLM_TENSOR_FFN_GATE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
| 1288 |
+
{LLM_TENSOR_FFN_DOWN, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
| 1289 |
+
{LLM_TENSOR_FFN_UP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
| 1290 |
+
{LLM_TENSOR_FFN_DOWN_SHEXP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
| 1291 |
+
{LLM_TENSOR_FFN_GATE_SHEXP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
| 1292 |
+
{LLM_TENSOR_FFN_UP_SHEXP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
| 1293 |
+
{LLM_TENSOR_ATTN_Q_A, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
| 1294 |
+
{LLM_TENSOR_ATTN_Q_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
| 1295 |
+
{LLM_TENSOR_ATTN_KV_A_MQA, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
| 1296 |
+
{LLM_TENSOR_ATTN_KV_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
| 1297 |
+
{LLM_TENSOR_DEC_ATTN_Q, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
| 1298 |
+
{LLM_TENSOR_DEC_ATTN_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
| 1299 |
+
{LLM_TENSOR_DEC_ATTN_V, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
| 1300 |
+
{LLM_TENSOR_DEC_ATTN_OUT, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
| 1301 |
+
{LLM_TENSOR_DEC_CROSS_ATTN_Q, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
| 1302 |
+
{LLM_TENSOR_DEC_CROSS_ATTN_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
| 1303 |
+
{LLM_TENSOR_DEC_CROSS_ATTN_V, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
| 1304 |
+
{LLM_TENSOR_DEC_CROSS_ATTN_OUT, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
| 1305 |
+
{LLM_TENSOR_DEC_FFN_GATE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
| 1306 |
+
{LLM_TENSOR_DEC_FFN_DOWN, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
| 1307 |
+
{LLM_TENSOR_DEC_FFN_UP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
| 1308 |
+
{LLM_TENSOR_ENC_ATTN_Q, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
| 1309 |
+
{LLM_TENSOR_ENC_ATTN_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
| 1310 |
+
{LLM_TENSOR_ENC_ATTN_V, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
| 1311 |
+
{LLM_TENSOR_ENC_ATTN_OUT, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
| 1312 |
+
{LLM_TENSOR_ENC_FFN_GATE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
| 1313 |
+
{LLM_TENSOR_ENC_FFN_DOWN, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
| 1314 |
+
{LLM_TENSOR_ENC_FFN_UP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
| 1315 |
+
{LLM_TENSOR_FFN_GATE_INP_SHEXP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
| 1316 |
+
{LLM_TENSOR_FFN_GATE_INP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
| 1317 |
+
{LLM_TENSOR_SSM_IN, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
| 1318 |
+
{LLM_TENSOR_SSM_X, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
| 1319 |
+
{LLM_TENSOR_SSM_DT, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
| 1320 |
+
{LLM_TENSOR_SSM_OUT, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
| 1321 |
+
{LLM_TENSOR_TIME_MIX_W1, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
| 1322 |
+
{LLM_TENSOR_TIME_MIX_W2, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
| 1323 |
+
{LLM_TENSOR_TIME_MIX_DECAY_W1, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
| 1324 |
+
{LLM_TENSOR_TIME_MIX_DECAY_W2, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
| 1325 |
+
{LLM_TENSOR_TIME_MIX_KEY, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
| 1326 |
+
{LLM_TENSOR_TIME_MIX_VALUE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
| 1327 |
+
{LLM_TENSOR_TIME_MIX_RECEPTANCE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
| 1328 |
+
{LLM_TENSOR_TIME_MIX_GATE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
| 1329 |
+
{LLM_TENSOR_TIME_MIX_OUTPUT, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
| 1330 |
+
{LLM_TENSOR_CHANNEL_MIX_KEY, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
| 1331 |
+
{LLM_TENSOR_CHANNEL_MIX_RECEPTANCE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
| 1332 |
+
{LLM_TENSOR_CHANNEL_MIX_VALUE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
| 1333 |
+
{LLM_TENSOR_FFN_ACT, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_DIV}},
|
| 1334 |
+
{LLM_TENSOR_SSM_CONV1D, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_SSM_CONV}},
|
| 1335 |
+
{LLM_TENSOR_SSM_A, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_SSM_SCAN}},
|
| 1336 |
+
{LLM_TENSOR_SSM_D, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
|
| 1337 |
+
{LLM_TENSOR_TIME_MIX_LERP_X, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
|
| 1338 |
+
{LLM_TENSOR_TIME_MIX_LN, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
|
| 1339 |
+
{LLM_TENSOR_CHANNEL_MIX_LERP_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
|
| 1340 |
+
{LLM_TENSOR_CHANNEL_MIX_LERP_R, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
|
| 1341 |
+
{LLM_TENSOR_TIME_MIX_LERP_W, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
|
| 1342 |
+
{LLM_TENSOR_TIME_MIX_LERP_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
|
| 1343 |
+
{LLM_TENSOR_TIME_MIX_LERP_V, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
|
| 1344 |
+
{LLM_TENSOR_TIME_MIX_LERP_R, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
|
| 1345 |
+
{LLM_TENSOR_TIME_MIX_LERP_G, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
|
| 1346 |
+
{LLM_TENSOR_TIME_MIX_DECAY, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
|
| 1347 |
+
{LLM_TENSOR_TIME_MIX_FIRST, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_RWKV_WKV6}},
|
| 1348 |
+
{LLM_TENSOR_ATTN_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
|
| 1349 |
+
{LLM_TENSOR_ATTN_NORM_2, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
|
| 1350 |
+
{LLM_TENSOR_ATTN_OUT_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
|
| 1351 |
+
{LLM_TENSOR_ATTN_POST_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
|
| 1352 |
+
{LLM_TENSOR_FFN_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
|
| 1353 |
+
{LLM_TENSOR_FFN_POST_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
|
| 1354 |
+
{LLM_TENSOR_FFN_NORM_EXPS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
|
| 1355 |
+
{LLM_TENSOR_ATTN_Q_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
|
| 1356 |
+
{LLM_TENSOR_ATTN_K_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
|
| 1357 |
+
{LLM_TENSOR_LAYER_OUT_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
|
| 1358 |
+
{LLM_TENSOR_ATTN_Q_A_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
|
| 1359 |
+
{LLM_TENSOR_ATTN_KV_A_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
|
| 1360 |
+
{LLM_TENSOR_ATTN_SUB_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
|
| 1361 |
+
{LLM_TENSOR_FFN_SUB_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
|
| 1362 |
+
{LLM_TENSOR_DEC_ATTN_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
|
| 1363 |
+
{LLM_TENSOR_DEC_CROSS_ATTN_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
|
| 1364 |
+
{LLM_TENSOR_DEC_FFN_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
|
| 1365 |
+
{LLM_TENSOR_ENC_ATTN_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
|
| 1366 |
+
{LLM_TENSOR_ENC_FFN_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
|
| 1367 |
+
{LLM_TENSOR_DEC_ATTN_REL_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_GET_ROWS}},
|
| 1368 |
+
{LLM_TENSOR_ENC_ATTN_REL_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_GET_ROWS}},
|
| 1369 |
+
{LLM_TENSOR_FFN_DOWN_EXPS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT_ID}},
|
| 1370 |
+
{LLM_TENSOR_FFN_GATE_EXPS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT_ID}},
|
| 1371 |
+
{LLM_TENSOR_FFN_UP_EXPS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT_ID}},
|
| 1372 |
+
{LLM_TENSOR_FFN_EXP_PROBS_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
|
| 1373 |
+
// this tensor is loaded for T5, but never used
|
| 1374 |
+
{LLM_TENSOR_DEC_CROSS_ATTN_REL_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_NONE}},
|
| 1375 |
+
{LLM_TENSOR_CONV1D, {LLM_TENSOR_LAYER_INPUT, GGML_OP_IM2COL}},
|
| 1376 |
+
{LLM_TENSOR_POS_NET_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
|
| 1377 |
+
{LLM_TENSOR_POS_NET_NORM1, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
|
| 1378 |
+
{LLM_TENSOR_POS_NET_NORM2, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
|
| 1379 |
+
{LLM_TENSOR_POS_NET_CONV1, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_IM2COL}},
|
| 1380 |
+
{LLM_TENSOR_POS_NET_CONV2, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_IM2COL}},
|
| 1381 |
+
{LLM_TENSOR_POS_NET_ATTN_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
|
| 1382 |
+
{LLM_TENSOR_POS_NET_ATTN_Q, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
| 1383 |
+
{LLM_TENSOR_POS_NET_ATTN_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
| 1384 |
+
{LLM_TENSOR_POS_NET_ATTN_V, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
| 1385 |
+
{LLM_TENSOR_POS_NET_ATTN_OUT, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
| 1386 |
+
{LLM_TENSOR_CONVNEXT_DW, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_IM2COL}},
|
| 1387 |
+
{LLM_TENSOR_CONVNEXT_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
|
| 1388 |
+
{LLM_TENSOR_CONVNEXT_PW1, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
| 1389 |
+
{LLM_TENSOR_CONVNEXT_PW2, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
| 1390 |
+
{LLM_TENSOR_CONVNEXT_GAMMA, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
|
| 1391 |
+
};
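Each entry above pairs a tensor type with the layer group it belongs to (input, repeating per block, or output) and the ggml op that consumes it; llm_tensor_info_for() just below is a thin lookup over this table. A hedged sketch of a typical query, assuming llm_tensor_info exposes `layer` and `op` members as declared in llama-arch.h (that declaration is not part of this hunk):

#include "llama-arch.h"

// Sketch: does this tensor type repeat once per transformer block?
static bool is_per_layer_tensor(llm_tensor t) {
    const llm_tensor_info & info = llm_tensor_info_for(t);
    return info.layer == LLM_TENSOR_LAYER_REPEATING;
}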

LLM_KV::LLM_KV(llm_arch arch) : arch(arch) {}

std::string LLM_KV::operator()(llm_kv kv) const {
    return ::format(LLM_KV_NAMES.at(kv), LLM_ARCH_NAMES.at(arch));
}

std::string LLM_TN_IMPL::str() const {
    if (LLM_TENSOR_NAMES.at(arch).find(tensor) == LLM_TENSOR_NAMES.at(arch).end()) {
        return "__missing__";
    }

    std::string name = ::format(LLM_TENSOR_NAMES.at(arch).at(tensor), bid, xid);

    if (suffix != nullptr) {
        name += ".";
        name += suffix;
    }

    return name;
}

const char * llm_arch_name(llm_arch arch) {
    auto it = LLM_ARCH_NAMES.find(arch);
    if (it == LLM_ARCH_NAMES.end()) {
        return "unknown";
    }
    return it->second;
}

llm_arch llm_arch_from_string(const std::string & name) {
    for (const auto & kv : LLM_ARCH_NAMES) { // NOLINT
        if (kv.second == name) {
            return kv.first;
        }
    }

    return LLM_ARCH_UNKNOWN;
}

const llm_tensor_info & llm_tensor_info_for(llm_tensor tensor) {
    return LLM_TENSOR_INFOS.at(tensor);
}
|
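For context, a short usage sketch of the helpers above. The string "llama" is assumed to be the canonical name registered for LLM_ARCH_LLAMA in LLM_ARCH_NAMES (defined earlier in this file, outside this hunk), and the key shown in the comment is the typical "<arch>.context_length" form; both are assumptions rather than facts taken from this diff:

#include "llama-arch.h"

#include <cstdio>
#include <string>

static void demo_arch_helpers() {
    // Round-trip an architecture name; unknown names map to LLM_ARCH_UNKNOWN.
    llm_arch arch = llm_arch_from_string("llama");
    std::printf("arch: %s\n", llm_arch_name(arch));

    // Build an architecture-scoped GGUF metadata key,
    // e.g. something like "llama.context_length" for this arch.
    LLM_KV kv(arch);
    std::string key = kv(LLM_KV_CONTEXT_LENGTH);
    std::printf("key: %s\n", key.c_str());
}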
examples/talk-llama/llama-arch.h
ADDED
|
@@ -0,0 +1,395 @@
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include "ggml.h" // ggml_op
|
| 4 |
+
|
| 5 |
+
#include <string>
|
| 6 |
+
|
| 7 |
+
//
|
| 8 |
+
// gguf constants (sync with gguf.py)
|
| 9 |
+
//
|
| 10 |
+
|
| 11 |
+
enum llm_arch {
|
| 12 |
+
LLM_ARCH_LLAMA,
|
| 13 |
+
LLM_ARCH_DECI,
|
| 14 |
+
LLM_ARCH_FALCON,
|
| 15 |
+
LLM_ARCH_BAICHUAN,
|
| 16 |
+
LLM_ARCH_GROK,
|
| 17 |
+
LLM_ARCH_GPT2,
|
| 18 |
+
LLM_ARCH_GPTJ,
|
| 19 |
+
LLM_ARCH_GPTNEOX,
|
| 20 |
+
LLM_ARCH_MPT,
|
| 21 |
+
LLM_ARCH_STARCODER,
|
| 22 |
+
LLM_ARCH_REFACT,
|
| 23 |
+
LLM_ARCH_BERT,
|
| 24 |
+
LLM_ARCH_NOMIC_BERT,
|
| 25 |
+
LLM_ARCH_JINA_BERT_V2,
|
| 26 |
+
LLM_ARCH_BLOOM,
|
| 27 |
+
LLM_ARCH_STABLELM,
|
| 28 |
+
LLM_ARCH_QWEN,
|
| 29 |
+
LLM_ARCH_QWEN2,
|
| 30 |
+
LLM_ARCH_QWEN2MOE,
|
| 31 |
+
LLM_ARCH_QWEN2VL,
|
| 32 |
+
LLM_ARCH_PHI2,
|
| 33 |
+
LLM_ARCH_PHI3,
|
| 34 |
+
LLM_ARCH_PLAMO,
|
| 35 |
+
LLM_ARCH_CODESHELL,
|
| 36 |
+
LLM_ARCH_ORION,
|
| 37 |
+
LLM_ARCH_INTERNLM2,
|
| 38 |
+
LLM_ARCH_MINICPM,
|
| 39 |
+
LLM_ARCH_MINICPM3,
|
| 40 |
+
LLM_ARCH_GEMMA,
|
| 41 |
+
LLM_ARCH_GEMMA2,
|
| 42 |
+
LLM_ARCH_STARCODER2,
|
| 43 |
+
LLM_ARCH_MAMBA,
|
| 44 |
+
LLM_ARCH_XVERSE,
|
| 45 |
+
LLM_ARCH_COMMAND_R,
|
| 46 |
+
LLM_ARCH_COHERE2,
|
| 47 |
+
LLM_ARCH_DBRX,
|
| 48 |
+
LLM_ARCH_OLMO,
|
| 49 |
+
LLM_ARCH_OLMO2,
|
| 50 |
+
LLM_ARCH_OLMOE,
|
| 51 |
+
LLM_ARCH_OPENELM,
|
| 52 |
+
LLM_ARCH_ARCTIC,
|
| 53 |
+
LLM_ARCH_DEEPSEEK,
|
| 54 |
+
LLM_ARCH_DEEPSEEK2,
|
| 55 |
+
LLM_ARCH_CHATGLM,
|
| 56 |
+
LLM_ARCH_BITNET,
|
| 57 |
+
LLM_ARCH_T5,
|
| 58 |
+
LLM_ARCH_T5ENCODER,
|
| 59 |
+
LLM_ARCH_JAIS,
|
| 60 |
+
LLM_ARCH_NEMOTRON,
|
| 61 |
+
LLM_ARCH_EXAONE,
|
| 62 |
+
LLM_ARCH_RWKV6,
|
| 63 |
+
LLM_ARCH_GRANITE,
|
| 64 |
+
LLM_ARCH_GRANITE_MOE,
|
| 65 |
+
LLM_ARCH_CHAMELEON,
|
| 66 |
+
LLM_ARCH_WAVTOKENIZER_DEC,
|
| 67 |
+
LLM_ARCH_UNKNOWN,
|
| 68 |
+
};
|
| 69 |
+
|
| 70 |
+
enum llm_kv {
|
| 71 |
+
LLM_KV_GENERAL_TYPE,
|
| 72 |
+
LLM_KV_GENERAL_ARCHITECTURE,
|
| 73 |
+
LLM_KV_GENERAL_QUANTIZATION_VERSION,
|
| 74 |
+
LLM_KV_GENERAL_ALIGNMENT,
|
| 75 |
+
LLM_KV_GENERAL_NAME,
|
| 76 |
+
LLM_KV_GENERAL_AUTHOR,
|
| 77 |
+
LLM_KV_GENERAL_VERSION,
|
| 78 |
+
LLM_KV_GENERAL_URL,
|
| 79 |
+
LLM_KV_GENERAL_DESCRIPTION,
|
| 80 |
+
LLM_KV_GENERAL_LICENSE,
|
| 81 |
+
LLM_KV_GENERAL_SOURCE_URL,
|
| 82 |
+
LLM_KV_GENERAL_SOURCE_HF_REPO,
|
| 83 |
+
|
| 84 |
+
LLM_KV_VOCAB_SIZE,
|
| 85 |
+
LLM_KV_CONTEXT_LENGTH,
|
| 86 |
+
LLM_KV_EMBEDDING_LENGTH,
|
| 87 |
+
LLM_KV_FEATURES_LENGTH,
|
| 88 |
+
LLM_KV_BLOCK_COUNT,
|
| 89 |
+
LLM_KV_LEADING_DENSE_BLOCK_COUNT,
|
| 90 |
+
LLM_KV_FEED_FORWARD_LENGTH,
|
| 91 |
+
LLM_KV_EXPERT_FEED_FORWARD_LENGTH,
|
| 92 |
+
LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH,
|
| 93 |
+
LLM_KV_USE_PARALLEL_RESIDUAL,
|
| 94 |
+
LLM_KV_TENSOR_DATA_LAYOUT,
|
| 95 |
+
LLM_KV_EXPERT_COUNT,
|
| 96 |
+
LLM_KV_EXPERT_USED_COUNT,
|
| 97 |
+
LLM_KV_EXPERT_SHARED_COUNT,
|
| 98 |
+
LLM_KV_EXPERT_WEIGHTS_SCALE,
|
| 99 |
+
LLM_KV_EXPERT_WEIGHTS_NORM,
|
| 100 |
+
LLM_KV_EXPERT_GATING_FUNC,
|
| 101 |
+
LLM_KV_POOLING_TYPE,
|
| 102 |
+
LLM_KV_LOGIT_SCALE,
|
| 103 |
+
LLM_KV_DECODER_START_TOKEN_ID,
|
| 104 |
+
LLM_KV_ATTN_LOGIT_SOFTCAPPING,
|
| 105 |
+
LLM_KV_FINAL_LOGIT_SOFTCAPPING,
|
| 106 |
+
LLM_KV_SWIN_NORM,
|
| 107 |
+
LLM_KV_RESCALE_EVERY_N_LAYERS,
|
| 108 |
+
LLM_KV_TIME_MIX_EXTRA_DIM,
|
| 109 |
+
LLM_KV_TIME_DECAY_EXTRA_DIM,
|
| 110 |
+
LLM_KV_RESIDUAL_SCALE,
|
| 111 |
+
LLM_KV_EMBEDDING_SCALE,
|
| 112 |
+
|
| 113 |
+
LLM_KV_ATTENTION_HEAD_COUNT,
|
| 114 |
+
LLM_KV_ATTENTION_HEAD_COUNT_KV,
|
| 115 |
+
LLM_KV_ATTENTION_MAX_ALIBI_BIAS,
|
| 116 |
+
LLM_KV_ATTENTION_CLAMP_KQV,
|
| 117 |
+
LLM_KV_ATTENTION_KEY_LENGTH,
|
| 118 |
+
LLM_KV_ATTENTION_VALUE_LENGTH,
|
| 119 |
+
LLM_KV_ATTENTION_LAYERNORM_EPS,
|
| 120 |
+
LLM_KV_ATTENTION_LAYERNORM_RMS_EPS,
|
| 121 |
+
LLM_KV_ATTENTION_GROUPNORM_EPS,
|
| 122 |
+
LLM_KV_ATTENTION_GROUPNORM_GROUPS,
|
| 123 |
+
LLM_KV_ATTENTION_CAUSAL,
|
| 124 |
+
LLM_KV_ATTENTION_Q_LORA_RANK,
|
| 125 |
+
LLM_KV_ATTENTION_KV_LORA_RANK,
|
| 126 |
+
LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT,
|
| 127 |
+
LLM_KV_ATTENTION_SLIDING_WINDOW,
|
| 128 |
+
LLM_KV_ATTENTION_SCALE,
|
| 129 |
+
|
| 130 |
+
LLM_KV_ROPE_DIMENSION_COUNT,
|
| 131 |
+
LLM_KV_ROPE_DIMENSION_SECTIONS,
|
| 132 |
+
LLM_KV_ROPE_FREQ_BASE,
|
| 133 |
+
LLM_KV_ROPE_SCALE_LINEAR,
|
    LLM_KV_ROPE_SCALING_TYPE,
    LLM_KV_ROPE_SCALING_FACTOR,
    LLM_KV_ROPE_SCALING_ATTN_FACTOR,
    LLM_KV_ROPE_SCALING_ORIG_CTX_LEN,
    LLM_KV_ROPE_SCALING_FINETUNED,
    LLM_KV_ROPE_SCALING_YARN_LOG_MUL,

    LLM_KV_SPLIT_NO,
    LLM_KV_SPLIT_COUNT,
    LLM_KV_SPLIT_TENSORS_COUNT,

    LLM_KV_SSM_INNER_SIZE,
    LLM_KV_SSM_CONV_KERNEL,
    LLM_KV_SSM_STATE_SIZE,
    LLM_KV_SSM_TIME_STEP_RANK,
    LLM_KV_SSM_DT_B_C_RMS,

    LLM_KV_WKV_HEAD_SIZE,

    LLM_KV_TOKENIZER_MODEL,
    LLM_KV_TOKENIZER_PRE,
    LLM_KV_TOKENIZER_LIST,
    LLM_KV_TOKENIZER_TOKEN_TYPE,
    LLM_KV_TOKENIZER_TOKEN_TYPE_COUNT,
    LLM_KV_TOKENIZER_SCORES,
    LLM_KV_TOKENIZER_MERGES,
    LLM_KV_TOKENIZER_BOS_ID,
    LLM_KV_TOKENIZER_EOS_ID,
    LLM_KV_TOKENIZER_EOT_ID,
    LLM_KV_TOKENIZER_EOM_ID,
    LLM_KV_TOKENIZER_UNK_ID,
    LLM_KV_TOKENIZER_SEP_ID,
    LLM_KV_TOKENIZER_PAD_ID,
    LLM_KV_TOKENIZER_CLS_ID,
    LLM_KV_TOKENIZER_MASK_ID,
    LLM_KV_TOKENIZER_ADD_BOS,
    LLM_KV_TOKENIZER_ADD_EOS,
    LLM_KV_TOKENIZER_ADD_PREFIX,
    LLM_KV_TOKENIZER_REMOVE_EXTRA_WS,
    LLM_KV_TOKENIZER_PRECOMPILED_CHARSMAP,
    LLM_KV_TOKENIZER_HF_JSON,
    LLM_KV_TOKENIZER_RWKV,
    LLM_KV_TOKENIZER_FIM_PRE_ID,
    LLM_KV_TOKENIZER_FIM_SUF_ID,
    LLM_KV_TOKENIZER_FIM_MID_ID,
    LLM_KV_TOKENIZER_FIM_PAD_ID,
    LLM_KV_TOKENIZER_FIM_REP_ID,
    LLM_KV_TOKENIZER_FIM_SEP_ID,

    LLM_KV_ADAPTER_TYPE,
    LLM_KV_ADAPTER_LORA_ALPHA,

    LLM_KV_POSNET_EMBEDDING_LENGTH,
    LLM_KV_POSNET_BLOCK_COUNT,

    LLM_KV_CONVNEXT_EMBEDDING_LENGTH,
    LLM_KV_CONVNEXT_BLOCK_COUNT,

    // deprecated:
    LLM_KV_TOKENIZER_PREFIX_ID,
    LLM_KV_TOKENIZER_SUFFIX_ID,
    LLM_KV_TOKENIZER_MIDDLE_ID,
};

enum llm_tensor {
    LLM_TENSOR_TOKEN_EMBD,
    LLM_TENSOR_TOKEN_EMBD_NORM,
    LLM_TENSOR_TOKEN_TYPES,
    LLM_TENSOR_POS_EMBD,
    LLM_TENSOR_OUTPUT,
    LLM_TENSOR_OUTPUT_NORM,
    LLM_TENSOR_ROPE_FREQS,
    LLM_TENSOR_ROPE_FACTORS_LONG,
    LLM_TENSOR_ROPE_FACTORS_SHORT,
    LLM_TENSOR_ATTN_Q,
    LLM_TENSOR_ATTN_K,
    LLM_TENSOR_ATTN_V,
    LLM_TENSOR_ATTN_QKV,
    LLM_TENSOR_ATTN_OUT,
    LLM_TENSOR_ATTN_NORM,
    LLM_TENSOR_ATTN_NORM_2,
    LLM_TENSOR_ATTN_OUT_NORM,
    LLM_TENSOR_ATTN_POST_NORM,
    LLM_TENSOR_ATTN_ROT_EMBD,
    LLM_TENSOR_FFN_GATE_INP,
    LLM_TENSOR_FFN_GATE_INP_SHEXP,
    LLM_TENSOR_FFN_NORM,
    LLM_TENSOR_FFN_POST_NORM,
    LLM_TENSOR_FFN_GATE,
    LLM_TENSOR_FFN_DOWN,
    LLM_TENSOR_FFN_UP,
    LLM_TENSOR_FFN_ACT,
    LLM_TENSOR_FFN_DOWN_EXP,  // split experts for backward compatibility
    LLM_TENSOR_FFN_GATE_EXP,
    LLM_TENSOR_FFN_UP_EXP,
    LLM_TENSOR_FFN_NORM_EXPS,
    LLM_TENSOR_FFN_DOWN_EXPS, // merged experts
    LLM_TENSOR_FFN_GATE_EXPS,
    LLM_TENSOR_FFN_UP_EXPS,
    LLM_TENSOR_FFN_DOWN_SHEXP,
    LLM_TENSOR_FFN_GATE_SHEXP,
    LLM_TENSOR_FFN_UP_SHEXP,
    LLM_TENSOR_FFN_EXP_PROBS_B,
    LLM_TENSOR_ATTN_Q_NORM,
    LLM_TENSOR_ATTN_K_NORM,
    LLM_TENSOR_LAYER_OUT_NORM,
    LLM_TENSOR_SSM_IN,
    LLM_TENSOR_SSM_CONV1D,
    LLM_TENSOR_SSM_X,
    LLM_TENSOR_SSM_DT,
    LLM_TENSOR_SSM_A,
    LLM_TENSOR_SSM_D,
    LLM_TENSOR_SSM_OUT,
    LLM_TENSOR_TIME_MIX_W1,
    LLM_TENSOR_TIME_MIX_W2,
    LLM_TENSOR_TIME_MIX_LERP_X,
    LLM_TENSOR_TIME_MIX_LERP_W,
    LLM_TENSOR_TIME_MIX_LERP_K,
    LLM_TENSOR_TIME_MIX_LERP_V,
    LLM_TENSOR_TIME_MIX_LERP_R,
    LLM_TENSOR_TIME_MIX_LERP_G,
    LLM_TENSOR_TIME_MIX_FIRST,
    LLM_TENSOR_TIME_MIX_DECAY,
    LLM_TENSOR_TIME_MIX_DECAY_W1,
    LLM_TENSOR_TIME_MIX_DECAY_W2,
    LLM_TENSOR_TIME_MIX_KEY,
    LLM_TENSOR_TIME_MIX_VALUE,
    LLM_TENSOR_TIME_MIX_RECEPTANCE,
    LLM_TENSOR_TIME_MIX_GATE,
    LLM_TENSOR_TIME_MIX_LN,
    LLM_TENSOR_TIME_MIX_OUTPUT,
    LLM_TENSOR_CHANNEL_MIX_LERP_K,
    LLM_TENSOR_CHANNEL_MIX_LERP_R,
    LLM_TENSOR_CHANNEL_MIX_KEY,
    LLM_TENSOR_CHANNEL_MIX_RECEPTANCE,
    LLM_TENSOR_CHANNEL_MIX_VALUE,
    LLM_TENSOR_ATTN_Q_A,
    LLM_TENSOR_ATTN_Q_B,
    LLM_TENSOR_ATTN_KV_A_MQA,
    LLM_TENSOR_ATTN_KV_B,
    LLM_TENSOR_ATTN_Q_A_NORM,
    LLM_TENSOR_ATTN_KV_A_NORM,
    LLM_TENSOR_ATTN_SUB_NORM,
    LLM_TENSOR_FFN_SUB_NORM,
    LLM_TENSOR_DEC_ATTN_NORM,
    LLM_TENSOR_DEC_ATTN_Q,
    LLM_TENSOR_DEC_ATTN_K,
    LLM_TENSOR_DEC_ATTN_V,
    LLM_TENSOR_DEC_ATTN_OUT,
    LLM_TENSOR_DEC_ATTN_REL_B,
    LLM_TENSOR_DEC_CROSS_ATTN_NORM,
    LLM_TENSOR_DEC_CROSS_ATTN_Q,
    LLM_TENSOR_DEC_CROSS_ATTN_K,
    LLM_TENSOR_DEC_CROSS_ATTN_V,
    LLM_TENSOR_DEC_CROSS_ATTN_OUT,
    LLM_TENSOR_DEC_CROSS_ATTN_REL_B,
    LLM_TENSOR_DEC_FFN_NORM,
    LLM_TENSOR_DEC_FFN_GATE,
    LLM_TENSOR_DEC_FFN_DOWN,
    LLM_TENSOR_DEC_FFN_UP,
    LLM_TENSOR_DEC_OUTPUT_NORM,
    LLM_TENSOR_ENC_ATTN_NORM,
    LLM_TENSOR_ENC_ATTN_Q,
    LLM_TENSOR_ENC_ATTN_K,
    LLM_TENSOR_ENC_ATTN_V,
    LLM_TENSOR_ENC_ATTN_OUT,
    LLM_TENSOR_ENC_ATTN_REL_B,
    LLM_TENSOR_ENC_FFN_NORM,
    LLM_TENSOR_ENC_FFN_GATE,
    LLM_TENSOR_ENC_FFN_DOWN,
    LLM_TENSOR_ENC_FFN_UP,
    LLM_TENSOR_ENC_OUTPUT_NORM,
    LLM_TENSOR_CLS,
    LLM_TENSOR_CLS_OUT,
    LLM_TENSOR_CONV1D,
    LLM_TENSOR_CONVNEXT_DW,
    LLM_TENSOR_CONVNEXT_NORM,
    LLM_TENSOR_CONVNEXT_PW1,
    LLM_TENSOR_CONVNEXT_PW2,
    LLM_TENSOR_CONVNEXT_GAMMA,
    LLM_TENSOR_POS_NET_CONV1,
    LLM_TENSOR_POS_NET_CONV2,
    LLM_TENSOR_POS_NET_NORM,
    LLM_TENSOR_POS_NET_NORM1,
    LLM_TENSOR_POS_NET_NORM2,
    LLM_TENSOR_POS_NET_ATTN_NORM,
    LLM_TENSOR_POS_NET_ATTN_Q,
    LLM_TENSOR_POS_NET_ATTN_K,
    LLM_TENSOR_POS_NET_ATTN_V,
    LLM_TENSOR_POS_NET_ATTN_OUT,
};

enum llm_tensor_layer {
    LLM_TENSOR_LAYER_INPUT,
    LLM_TENSOR_LAYER_REPEATING,
    LLM_TENSOR_LAYER_OUTPUT,
};

struct LLM_KV {
    LLM_KV(llm_arch arch);

    llm_arch arch;

    std::string operator()(llm_kv kv) const;
};

// helper to handle gguf constants
// usage:
//
//   const auto tn = LLM_TN(LLM_ARCH_LLAMA);
//
//   std::string name = tn(LLM_TENSOR_OUTPUT);                 -> "output"
//   std::string name = tn(LLM_TENSOR_TOKEN_EMBD, "bias");     -> "token_embd.bias"
//   std::string name = tn(LLM_TENSOR_ATTN_NORM, "weight", 3); -> "blk.3.attn_norm.weight"
//
struct LLM_TN_IMPL {
    const llm_arch arch;
    const llm_tensor tensor;
    const char * const suffix;
    const int bid;
    const int xid;

    std::string str() const;

    operator std::string() const {
        return str();
    }

    friend bool operator==(const std::string & str, const LLM_TN_IMPL & tn) {
        return str == tn.str();
    }

    friend bool operator!=(const std::string & str, const LLM_TN_IMPL & tn) {
        return str != tn.str();
    }
};

struct LLM_TN {
    LLM_TN(llm_arch arch) : arch(arch) {}

    llm_arch arch;

    LLM_TN_IMPL operator()(llm_tensor tensor, const char * suffix, int bid = -1, int xid = -1) const {
        return { arch, tensor, suffix, bid, xid };
    }

    LLM_TN_IMPL operator()(llm_tensor tensor, int bid = -1, int xid = -1) const {
        return { arch, tensor, nullptr, bid, xid };
    }
};

struct llm_tensor_info {
    llm_tensor_layer layer;
    ggml_op op;
};

const char * llm_arch_name(llm_arch arch);

llm_arch llm_arch_from_string(const std::string & name);

const llm_tensor_info & llm_tensor_info_for(llm_tensor tensor);
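To make the intent of these naming helpers concrete, here is a minimal usage sketch assembled only from the declarations and the usage comment above; it is illustrative, not code added by this sync:

    // build GGUF tensor-name helpers for a given architecture
    const auto tn = LLM_TN(LLM_ARCH_LLAMA);

    std::string output_name = tn(LLM_TENSOR_OUTPUT);                 // "output"
    std::string embd_bias   = tn(LLM_TENSOR_TOKEN_EMBD, "bias");     // "token_embd.bias"
    std::string norm_w      = tn(LLM_TENSOR_ATTN_NORM, "weight", 3); // "blk.3.attn_norm.weight"

    // architectures can also be resolved from a string name
    // ("llama" is assumed here purely for illustration)
    llm_arch arch = llm_arch_from_string("llama");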
examples/talk-llama/llama-batch.cpp
ADDED
@@ -0,0 +1,368 @@
#include "llama-batch.h"

#include <cstring>
#include <algorithm>

llama_ubatch llama_sbatch::reserve_ubatch(size_t n_ubatch, bool has_embd) {
    // clear empty sequences
    // the previous ubatch is assumed to be gone,
    // so nothing should refer to values in these sequences anymore.
    for (size_t i = seq.size(); i-- > 0;) {
        if (seq[i].length == 0) {
            seq.pop_back();
        } else {
            break;
        }
    }
    ubatch_token.resize(!has_embd ? n_ubatch : 0);
    ubatch_embd.resize(has_embd ? n_embd * n_ubatch : 0);
    ubatch_pos.resize(n_ubatch);
    ubatch_n_seq_id.resize(n_ubatch);
    ubatch_seq_id.resize(n_ubatch);
    ubatch_output.resize(n_ubatch);
    llama_ubatch ubatch = {
        /*equal_seqs   =*/ true,
        /*n_tokens     =*/ 0,
        /*n_seq_tokens =*/ 0,
        /*n_seqs       =*/ 0,
        /*token        =*/ !has_embd ? ubatch_token.data() : nullptr,
        /*embd         =*/ has_embd ? ubatch_embd.data() : nullptr,
        /*pos          =*/ ubatch_pos.data(),
        /*n_seq_id     =*/ ubatch_n_seq_id.data(),
        /*seq_id       =*/ ubatch_seq_id.data(),
        /*output       =*/ ubatch_output.data(),
    };
    return ubatch;
}

void llama_sbatch::add_seq_to_ubatch(llama_ubatch & ubatch, llama_sbatch_seq & seq, size_t length) {
    GGML_ASSERT(batch != nullptr);
    GGML_ASSERT(length <= seq.length);
    // Can only add sequences of equal lengths to a batch,
    // otherwise it isn't clear to which sequence a token belongs
    GGML_ASSERT(seq.n_seq_id == 0 || ubatch.n_seqs == 0 || length == (size_t) ubatch.n_tokens / ubatch.n_seqs);
    GGML_ASSERT((seq.n_seq_id != 0) == ubatch.equal_seqs);
    // NOTE: loops are separated for cache-friendliness
    if (batch->token) {
        if (ubatch.equal_seqs) {
            for (size_t i = 0; i < length; ++i) {
                ubatch.token[ubatch.n_tokens + i] = batch->token[ids[seq.offset + i]];
            }
        } else {
            // simple split
            ubatch.token = batch->token + seq.offset;
        }
    } else {
        ubatch.token = nullptr;
    }
    if (batch->embd) {
        if (ubatch.equal_seqs) {
            for (size_t i = 0; i < length; ++i) {
                memcpy(
                    ubatch.embd + (n_embd * (ubatch.n_tokens + i)),
                    batch->embd + (n_embd * ids[seq.offset + i]),
                    n_embd * sizeof(float)
                );
            }
        } else {
            // simple split
            ubatch.embd = batch->embd + (n_embd * seq.offset);
        }
    } else {
        ubatch.embd = nullptr;
    }
    if (ubatch.equal_seqs) {
        for (size_t i = 0; i < length; ++i) {
            ubatch.pos[ubatch.n_tokens + i] = batch->pos[ids[seq.offset + i]];
        }
    } else {
        // simple split
        ubatch.pos = batch->pos + seq.offset;
    }
    if (ubatch.equal_seqs) {
        ubatch.n_seq_id[ubatch.n_seqs] = seq.n_seq_id;
        if (seq.seq_id) {
            ubatch.seq_id[ubatch.n_seqs] = seq.seq_id;
        }
    } else {
        // simple split
        if (batch->n_seq_id) {
            ubatch.n_seq_id = batch->n_seq_id + seq.offset;
        } else {
            for (size_t i = 0; i < length; ++i) {
                ubatch.n_seq_id[ubatch.n_seqs + i] = 1;
            }
        }
        if (batch->seq_id) {
            ubatch.seq_id = batch->seq_id + seq.offset;
        }
    }
    if (logits_all) {
        for (size_t i = 0; i < length; ++i) {
            ubatch.output[ubatch.n_tokens + i] = 1;
            out_ids.push_back(ids[seq.offset + i]);
        }
    } else if (batch->logits) {
        if (ubatch.equal_seqs) {
            for (size_t i = 0; i < length; ++i) {
                size_t id = ids[seq.offset + i];
                int8_t is_output = batch->logits[id];
                ubatch.output[ubatch.n_tokens + i] = is_output;
                if (is_output) { out_ids.push_back(id); }
            }
        } else {
            // simple split
            ubatch.output = batch->logits + seq.offset;
            for (size_t i = 0; i < length; ++i) {
                if (ubatch.output[i] != 0) { out_ids.push_back(seq.offset + i); }
            }
        }
    } else {
        // only get last output
        for (size_t i = 0; i < length; ++i) {
            size_t id = ids[seq.offset + i];
            int8_t is_last = id == ids.size() - 1;
            ubatch.output[ubatch.n_tokens + i] = is_last;
            if (is_last) { out_ids.push_back(id); }
        }
    }
    if (ubatch.n_tokens == 0 && ubatch.n_seqs == 0) {
        ubatch.n_seq_tokens = ubatch.equal_seqs ? length : 1;
    }
    ubatch.n_tokens += length;
    ubatch.n_seqs += ubatch.equal_seqs ? 1 : length; // virtual sequences for simple splits
    seq.offset += length;
    seq.length -= length;
    n_tokens -= length;
    GGML_ASSERT(ubatch.n_tokens == ubatch.n_seq_tokens * ubatch.n_seqs);
}

llama_ubatch llama_sbatch::split_simple(size_t n_ubatch) {
    n_ubatch = n_tokens < n_ubatch ? n_tokens : n_ubatch;
    llama_ubatch ubatch = reserve_ubatch(n_ubatch, /* has_embd */ batch->embd != nullptr);
    ubatch.equal_seqs = false;
    if (!seq.empty()) {
        llama_sbatch_seq & s = seq[0];
        size_t length = s.length < n_ubatch ? s.length : n_ubatch;
        GGML_ASSERT(seq.size() == 1 && s.n_seq_id == 0); // don't mix with other splits
        add_seq_to_ubatch(ubatch, s, length);
    }
    return ubatch;
}

llama_ubatch llama_sbatch::split_equal(size_t n_ubatch) {
    n_ubatch = n_tokens < n_ubatch ? n_tokens : n_ubatch;
    llama_ubatch ubatch = reserve_ubatch(n_ubatch, /* has_embd */ batch->embd != nullptr);
    if (!seq.empty()) {
        size_t length = 0;
        size_t n_tokens_in_ubatch = 0;
        GGML_ASSERT(seq[0].n_seq_id > 0); // should not be mixed with simple splits
        // smallest first, because it's easier to split this way;
        // starting from the end to pop in constant time.
        for (size_t i = seq.size(); i-- > 0;) {
            llama_sbatch_seq & s = seq[i];
            GGML_ASSERT(s.length > 0);
            if (length == 0) {
                length = s.length < n_ubatch ? s.length : n_ubatch;
            }
            add_seq_to_ubatch(ubatch, s, length);
            n_tokens_in_ubatch += length;
            // shared prompts can't be mixed with any of their sequences,
            // so it's safer to compute them in their own ubatch
            if (s.n_seq_id > 1) { break; }
            // stop when there isn't enough space for another sequence
            if (length + n_tokens_in_ubatch > n_ubatch) { break; }
        }
    }
    return ubatch;
}

llama_ubatch llama_sbatch::split_seq(size_t n_ubatch) {
    n_ubatch = n_tokens < n_ubatch ? n_tokens : n_ubatch;
    llama_ubatch ubatch = reserve_ubatch(n_ubatch, /* has_embd */ batch->embd != nullptr);
    if (!seq.empty()) {
        llama_sbatch_seq & s = seq[seq.size() - 1];
        size_t length = s.length < n_ubatch ? s.length : n_ubatch;
        GGML_ASSERT(s.n_seq_id > 0); // should not be mixed with simple splits
        add_seq_to_ubatch(ubatch, s, length);
    }
    return ubatch;
}

void llama_sbatch::from_batch(const llama_batch & batch, size_t n_embd, bool simple_split, bool logits_all) {
    GGML_ASSERT(batch.n_tokens >= 0);
    this->batch = &batch;
    this->n_embd = n_embd;
    this->logits_all = logits_all;

    n_tokens = batch.n_tokens;
    ids.resize(n_tokens);
    out_ids.clear();
    // TODO: reserve out_ids and seq

    for (size_t i = 0; i < n_tokens; ++i) {
        ids[i] = i;
    }
    if (simple_split) {
        seq.resize(1);
        llama_sbatch_seq & s = seq[0];
        s.n_seq_id = 0;
        s.seq_id = nullptr;
        s.offset = 0;
        s.length = n_tokens;
        return;
    }
    std::sort(ids.begin(), ids.end(),
        [&batch](size_t a, size_t b) {
            int32_t n_seq_a = batch.n_seq_id ? batch.n_seq_id[a] : 1;
            int32_t n_seq_b = batch.n_seq_id ? batch.n_seq_id[b] : 1;
            // sort by seq_id, then by pos
            if (n_seq_a == n_seq_b) {
                if (batch.seq_id) {
                    for (int32_t i = 0; i < n_seq_a; ++i) {
                        llama_seq_id seq_id_a = batch.seq_id[a][i];
                        llama_seq_id seq_id_b = batch.seq_id[b][i];
                        // smaller seq_ids go first
                        if (seq_id_a != seq_id_b) {
                            return seq_id_a < seq_id_b;
                        }
                    }
                }
                // when all else is equal, sort by pos
                if (batch.pos) {
                    return batch.pos[a] < batch.pos[b];
                }
                // no pos, sort by id
                return a < b;
            }
            // shared prompts go first
            return n_seq_a > n_seq_b;
        }
    );
    // init seq
    llama_sbatch_seq * last_seq = nullptr;

    for (size_t i = 0; i < n_tokens; ++i) {
        const size_t bi = ids[i];
        const int32_t n_seqs = batch.n_seq_id[bi];
        llama_seq_id * seq_ids = batch.seq_id[bi];
        if (last_seq != nullptr) {
            bool same = n_seqs == last_seq->n_seq_id;
            for (int32_t j = 0; same && j < n_seqs; ++j) {
                if (seq_ids[j] != last_seq->seq_id[j]) {
                    same = false;
                }
            }
            if (same) {
                last_seq->length += 1;
                continue;
            }
        }
        llama_sbatch_seq new_seq = {n_seqs, seq_ids, i, 1};
        seq.push_back(new_seq);
        last_seq = &seq.back();
    }
    // keep shared prompts first at the end, then sort by length descending.
    std::sort(seq.begin(), seq.end(),
        [](llama_sbatch_seq & a, llama_sbatch_seq & b) {
            if (a.n_seq_id == b.n_seq_id) {
                return a.length > b.length;
            }
            return a.n_seq_id < b.n_seq_id;
        }
    );
}

llama_batch_allocr::llama_batch_allocr(struct llama_batch in_batch, llama_pos p0) {
    batch = in_batch;
    GGML_ASSERT(batch.n_tokens > 0);
    if (!batch.pos) {
        pos.resize(batch.n_tokens);
        for (int32_t i = 0; i < batch.n_tokens; i++) {
            pos[i] = i + p0;
        }
        batch.pos = pos.data();
    }
    if (!batch.n_seq_id) {
        n_seq_id.resize(batch.n_tokens);
        for (int32_t i = 0; i < batch.n_tokens; i++) {
            n_seq_id[i] = seq_id_0.size();
        }
        batch.n_seq_id = n_seq_id.data();
    }
    if (!batch.seq_id) {
        seq_id.resize(batch.n_tokens + 1);
        seq_id[batch.n_tokens] = NULL;
        for (int32_t i = 0; i < batch.n_tokens; i++) {
            seq_id[i] = seq_id_0.data();
        }
        batch.seq_id = seq_id.data();
    }
    if (!batch.logits) {
        logits.resize(batch.n_tokens);
        logits[logits.size() - 1] = true;
        batch.logits = logits.data();
    }
}

//
// interface implementation
//

struct llama_batch llama_batch_get_one(
        llama_token * tokens,
        int32_t n_tokens) {
    return {
        /*n_tokens =*/ n_tokens,
        /*tokens   =*/ tokens,
        /*embd     =*/ nullptr,
        /*pos      =*/ nullptr,
        /*n_seq_id =*/ nullptr,
        /*seq_id   =*/ nullptr,
        /*logits   =*/ nullptr,
    };
}

struct llama_batch llama_batch_init(int32_t n_tokens_alloc, int32_t embd, int32_t n_seq_max) {
    llama_batch batch = {
        /*n_tokens =*/ 0,
        /*tokens   =*/ nullptr,
        /*embd     =*/ nullptr,
        /*pos      =*/ nullptr,
        /*n_seq_id =*/ nullptr,
        /*seq_id   =*/ nullptr,
        /*logits   =*/ nullptr,
    };

    if (embd) {
        batch.embd = (float *) malloc(sizeof(float) * n_tokens_alloc * embd);
    } else {
        batch.token = (llama_token *) malloc(sizeof(llama_token) * n_tokens_alloc);
    }

    batch.pos      = (llama_pos *)     malloc(sizeof(llama_pos)      * n_tokens_alloc);
    batch.n_seq_id = (int32_t *)       malloc(sizeof(int32_t)        * n_tokens_alloc);
    batch.seq_id   = (llama_seq_id **) malloc(sizeof(llama_seq_id *) * (n_tokens_alloc + 1));
    for (int i = 0; i < n_tokens_alloc; ++i) {
        batch.seq_id[i] = (llama_seq_id *) malloc(sizeof(llama_seq_id) * n_seq_max);
    }
    batch.seq_id[n_tokens_alloc] = nullptr;

    batch.logits   = (int8_t *)        malloc(sizeof(int8_t)         * n_tokens_alloc);

    return batch;
}

void llama_batch_free(struct llama_batch batch) {
    if (batch.token)    free(batch.token);
    if (batch.embd)     free(batch.embd);
    if (batch.pos)      free(batch.pos);
    if (batch.n_seq_id) free(batch.n_seq_id);
    if (batch.seq_id) {
        for (int i = 0; batch.seq_id[i] != nullptr; ++i) {
            free(batch.seq_id[i]);
        }
        free(batch.seq_id);
    }
    if (batch.logits)   free(batch.logits);
}
examples/talk-llama/llama-batch.h
ADDED
@@ -0,0 +1,88 @@
#pragma once

#include "llama.h"

#include <array>
#include <vector>

// very similar to llama_batch,
// but has more metadata about sequences
struct llama_ubatch {
    bool equal_seqs;
    // TODO: whole_seqs for embeddings?

    uint32_t n_tokens;     // total tokens (n_seq_tokens * n_seqs)
    uint32_t n_seq_tokens; // tokens per sequence
    uint32_t n_seqs;

    llama_token  *  token;    // [n_tokens]
    float        *  embd;     // [n_embd, n_tokens]
    llama_pos    *  pos;      // [n_tokens]
    int32_t      *  n_seq_id; // [n_seqs]
    llama_seq_id ** seq_id;   // [n_seqs]
    int8_t       *  output;   // [n_tokens]
};

struct llama_sbatch_seq {
    int32_t n_seq_id;

    llama_seq_id * seq_id;

    size_t offset;
    size_t length;
};

// sequence-length-aware batch splitting
struct llama_sbatch {
    // tokens left in this batch
    size_t n_tokens;

    size_t n_embd;

    bool logits_all; // TODO: remove once lctx.logits_all is removed too

    // sorted indices into the batch
    std::vector<size_t> ids;
    // batch indices of the output
    std::vector<size_t> out_ids;
    std::vector<llama_sbatch_seq> seq;

    const llama_batch * batch = nullptr;

    // buffers for the ubatch
    std::vector<llama_token>    ubatch_token;
    std::vector<float>          ubatch_embd;
    std::vector<llama_pos>      ubatch_pos;
    std::vector<int32_t>        ubatch_n_seq_id;
    std::vector<llama_seq_id *> ubatch_seq_id;
    std::vector<int8_t>         ubatch_output;

    llama_ubatch reserve_ubatch(size_t n_ubatch, bool has_embd = false);

    void add_seq_to_ubatch(llama_ubatch & ubatch, llama_sbatch_seq & seq, size_t length);

    // simple split, unknown number of sequences of unequal lengths
    llama_ubatch split_simple(size_t n_ubatch);

    // make batches of equal-length sequences
    llama_ubatch split_equal(size_t n_ubatch);

    // sequence-wise split
    llama_ubatch split_seq(size_t n_ubatch);

    void from_batch(const llama_batch & batch, size_t n_embd, bool simple_split = false, bool logits_all = false);
};

// temporary allocate memory for the input batch if needed
struct llama_batch_allocr {
    struct llama_batch batch;

    std::array<llama_seq_id, 1> seq_id_0 = { 0 }; // default sequence id
    std::vector<llama_pos>      pos;
    std::vector<int32_t>        n_seq_id;
    std::vector<llama_seq_id *> seq_id;
    std::vector<int8_t>         logits;

    // optionally fulfill the batch returned by llama_batch_get_one
    llama_batch_allocr(struct llama_batch in_batch, llama_pos p0);
};
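The intended driving loop for this splitter, as a minimal sketch inferred from the interface above (batch, n_embd and n_ubatch are assumed to come from the caller; the decode step is elided):

    llama_sbatch sbatch;
    sbatch.from_batch(batch, n_embd, /*simple_split =*/ true, /*logits_all =*/ false);

    while (sbatch.n_tokens > 0) {
        // carve off at most n_ubatch tokens; split_equal/split_seq are the
        // sequence-aware alternatives used for recurrent-style models
        llama_ubatch ubatch = sbatch.split_simple(n_ubatch);
        // ... process ubatch ...
    }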
examples/talk-llama/llama-chat.cpp
ADDED
@@ -0,0 +1,567 @@
#include "llama-chat.h"

#include "llama.h"

#include <map>
#include <sstream>

#if __cplusplus >= 202000L
    #define LU8(x) (const char*)(u8##x)
#else
    #define LU8(x) u8##x
#endif

// trim whitespace from the beginning and end of a string
static std::string trim(const std::string & str) {
    size_t start = 0;
    size_t end = str.size();
    while (start < end && isspace(str[start])) {
        start += 1;
    }
    while (end > start && isspace(str[end - 1])) {
        end -= 1;
    }
    return str.substr(start, end - start);
}

static const std::map<std::string, llm_chat_template> LLM_CHAT_TEMPLATES = {
    { "chatml",            LLM_CHAT_TEMPLATE_CHATML            },
    { "llama2",            LLM_CHAT_TEMPLATE_LLAMA_2           },
    { "llama2-sys",        LLM_CHAT_TEMPLATE_LLAMA_2_SYS       },
    { "llama2-sys-bos",    LLM_CHAT_TEMPLATE_LLAMA_2_SYS_BOS   },
    { "llama2-sys-strip",  LLM_CHAT_TEMPLATE_LLAMA_2_SYS_STRIP },
    { "mistral-v1",        LLM_CHAT_TEMPLATE_MISTRAL_V1        },
    { "mistral-v3",        LLM_CHAT_TEMPLATE_MISTRAL_V3        },
    { "mistral-v3-tekken", LLM_CHAT_TEMPLATE_MISTRAL_V3_TEKKEN },
    { "mistral-v7",        LLM_CHAT_TEMPLATE_MISTRAL_V7        },
    { "phi3",              LLM_CHAT_TEMPLATE_PHI_3             },
    { "falcon3",           LLM_CHAT_TEMPLATE_FALCON_3          },
    { "zephyr",            LLM_CHAT_TEMPLATE_ZEPHYR            },
    { "monarch",           LLM_CHAT_TEMPLATE_MONARCH           },
    { "gemma",             LLM_CHAT_TEMPLATE_GEMMA             },
    { "orion",             LLM_CHAT_TEMPLATE_ORION             },
    { "openchat",          LLM_CHAT_TEMPLATE_OPENCHAT          },
    { "vicuna",            LLM_CHAT_TEMPLATE_VICUNA            },
    { "vicuna-orca",       LLM_CHAT_TEMPLATE_VICUNA_ORCA       },
    { "deepseek",          LLM_CHAT_TEMPLATE_DEEPSEEK          },
    { "deepseek2",         LLM_CHAT_TEMPLATE_DEEPSEEK_2        },
    { "deepseek3",         LLM_CHAT_TEMPLATE_DEEPSEEK_3        },
    { "command-r",         LLM_CHAT_TEMPLATE_COMMAND_R         },
    { "llama3",            LLM_CHAT_TEMPLATE_LLAMA_3           },
    { "chatglm3",          LLM_CHAT_TEMPLATE_CHATGML_3         },
    { "chatglm4",          LLM_CHAT_TEMPLATE_CHATGML_4         },
    { "minicpm",           LLM_CHAT_TEMPLATE_MINICPM           },
    { "exaone3",           LLM_CHAT_TEMPLATE_EXAONE_3          },
    { "rwkv-world",        LLM_CHAT_TEMPLATE_RWKV_WORLD        },
    { "granite",           LLM_CHAT_TEMPLATE_GRANITE           },
    { "gigachat",          LLM_CHAT_TEMPLATE_GIGACHAT          },
    { "megrez",            LLM_CHAT_TEMPLATE_MEGREZ            },
};

llm_chat_template llm_chat_template_from_str(const std::string & name) {
    return LLM_CHAT_TEMPLATES.at(name);
}

llm_chat_template llm_chat_detect_template(const std::string & tmpl) {
    try {
        return llm_chat_template_from_str(tmpl);
    } catch (const std::out_of_range &) {
        // ignore
    }

    auto tmpl_contains = [&tmpl](const char * haystack) -> bool {
        return tmpl.find(haystack) != std::string::npos;
    };
    if (tmpl_contains("<|im_start|>")) {
        return LLM_CHAT_TEMPLATE_CHATML;
    } else if (tmpl.find("mistral") == 0 || tmpl_contains("[INST]")) {
        if (tmpl_contains("[SYSTEM_PROMPT]")) {
            return LLM_CHAT_TEMPLATE_MISTRAL_V7;
        } else if (
            // catches official 'v1' template
            tmpl_contains("' [INST] ' + system_message")
            // catches official 'v3' and 'v3-tekken' templates
            || tmpl_contains("[AVAILABLE_TOOLS]")
        ) {
            // Official mistral 'v1', 'v3' and 'v3-tekken' templates
            // See: https://github.com/mistralai/cookbook/blob/main/concept-deep-dive/tokenization/chat_templates.md
            // See: https://github.com/mistralai/cookbook/blob/main/concept-deep-dive/tokenization/templates.md
            if (tmpl_contains(" [INST]")) {
                return LLM_CHAT_TEMPLATE_MISTRAL_V1;
            } else if (tmpl_contains("\"[INST]\"")) {
                return LLM_CHAT_TEMPLATE_MISTRAL_V3_TEKKEN;
            }
            return LLM_CHAT_TEMPLATE_MISTRAL_V3;
        } else {
            // llama2 template and its variants
            // [variant] support system message
            // See: https://huggingface.co/blog/llama2#how-to-prompt-llama-2
            bool support_system_message = tmpl_contains("<<SYS>>");
            bool add_bos_inside_history = tmpl_contains("bos_token + '[INST]");
            bool strip_message = tmpl_contains("content.strip()");
            if (strip_message) {
                return LLM_CHAT_TEMPLATE_LLAMA_2_SYS_STRIP;
            } else if (add_bos_inside_history) {
                return LLM_CHAT_TEMPLATE_LLAMA_2_SYS_BOS;
            } else if (support_system_message) {
                return LLM_CHAT_TEMPLATE_LLAMA_2_SYS;
            } else {
                return LLM_CHAT_TEMPLATE_LLAMA_2;
            }
        }
    } else if (tmpl_contains("<|assistant|>") && tmpl_contains("<|end|>")) {
        return LLM_CHAT_TEMPLATE_PHI_3;
    } else if (tmpl_contains("<|assistant|>") && tmpl_contains("<|user|>")) {
        return LLM_CHAT_TEMPLATE_FALCON_3;
    } else if (tmpl_contains("<|user|>") && tmpl_contains("<|endoftext|>")) {
        return LLM_CHAT_TEMPLATE_ZEPHYR;
    } else if (tmpl_contains("bos_token + message['role']")) {
        return LLM_CHAT_TEMPLATE_MONARCH;
    } else if (tmpl_contains("<start_of_turn>")) {
        return LLM_CHAT_TEMPLATE_GEMMA;
    } else if (tmpl_contains("'\\n\\nAssistant: ' + eos_token")) {
        // OrionStarAI/Orion-14B-Chat
        return LLM_CHAT_TEMPLATE_ORION;
    } else if (tmpl_contains("GPT4 Correct ")) {
        // openchat/openchat-3.5-0106
        return LLM_CHAT_TEMPLATE_OPENCHAT;
    } else if (tmpl_contains("USER: ") && tmpl_contains("ASSISTANT: ")) {
        // eachadea/vicuna-13b-1.1 (and Orca variant)
        if (tmpl_contains("SYSTEM: ")) {
            return LLM_CHAT_TEMPLATE_VICUNA_ORCA;
        }
        return LLM_CHAT_TEMPLATE_VICUNA;
    } else if (tmpl_contains("### Instruction:") && tmpl_contains("<|EOT|>")) {
        // deepseek-ai/deepseek-coder-33b-instruct
        return LLM_CHAT_TEMPLATE_DEEPSEEK;
    } else if (tmpl_contains("<|START_OF_TURN_TOKEN|>") && tmpl_contains("<|USER_TOKEN|>")) {
        // CohereForAI/c4ai-command-r-plus
        return LLM_CHAT_TEMPLATE_COMMAND_R;
    } else if (tmpl_contains("<|start_header_id|>") && tmpl_contains("<|end_header_id|>")) {
        return LLM_CHAT_TEMPLATE_LLAMA_3;
    } else if (tmpl_contains("[gMASK]sop")) {
        // chatglm3-6b
        return LLM_CHAT_TEMPLATE_CHATGML_3;
    } else if (tmpl_contains("[gMASK]<sop>")) {
        return LLM_CHAT_TEMPLATE_CHATGML_4;
    } else if (tmpl_contains(LU8("<用户>"))) {
        // MiniCPM-3B-OpenHermes-2.5-v2-GGUF
        return LLM_CHAT_TEMPLATE_MINICPM;
    } else if (tmpl_contains("'Assistant: ' + message['content'] + eos_token")) {
        return LLM_CHAT_TEMPLATE_DEEPSEEK_2;
    } else if (tmpl_contains(LU8("'<|Assistant|>' + message['content'] + '<|end▁of▁sentence|>'"))) {
        return LLM_CHAT_TEMPLATE_DEEPSEEK_3;
    } else if (tmpl_contains("[|system|]") && tmpl_contains("[|assistant|]") && tmpl_contains("[|endofturn|]")) {
        // ref: https://huggingface.co/LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct/discussions/8#66bae61b1893d14ee8ed85bb
        // EXAONE-3.0-7.8B-Instruct
        return LLM_CHAT_TEMPLATE_EXAONE_3;
    } else if (tmpl_contains("rwkv-world")) {
        return LLM_CHAT_TEMPLATE_RWKV_WORLD;
    } else if (tmpl_contains("<|start_of_role|>")) {
        return LLM_CHAT_TEMPLATE_GRANITE;
    } else if (tmpl_contains("message['role'] + additional_special_tokens[0] + message['content'] + additional_special_tokens[1]")) {
        return LLM_CHAT_TEMPLATE_GIGACHAT;
    } else if (tmpl_contains("<|role_start|>")) {
        return LLM_CHAT_TEMPLATE_MEGREZ;
    }
    return LLM_CHAT_TEMPLATE_UNKNOWN;
}

// Simple version of "llama_apply_chat_template" that only works with strings
// This function uses heuristic checks to determine commonly used template. It is not a jinja parser.
int32_t llm_chat_apply_template(
    llm_chat_template tmpl,
    const std::vector<const llama_chat_message *> & chat,
    std::string & dest, bool add_ass) {
    // Taken from the research: https://github.com/ggerganov/llama.cpp/issues/5527
    std::stringstream ss;
    if (tmpl == LLM_CHAT_TEMPLATE_CHATML) {
        // chatml template
        for (auto message : chat) {
            ss << "<|im_start|>" << message->role << "\n" << message->content << "<|im_end|>\n";
        }
        if (add_ass) {
            ss << "<|im_start|>assistant\n";
        }
    } else if (tmpl == LLM_CHAT_TEMPLATE_MISTRAL_V7) {
        // Official mistral 'v7' template
        // See: https://huggingface.co/mistralai/Mistral-Large-Instruct-2411#basic-instruct-template-v7
        for (auto message : chat) {
            std::string role(message->role);
            std::string content(message->content);
            if (role == "system") {
                ss << "[SYSTEM_PROMPT] " << content << "[/SYSTEM_PROMPT]";
            } else if (role == "user") {
                ss << "[INST] " << content << "[/INST]";
            }
            else {
                ss << " " << content << "</s>";
            }
        }
    } else if (tmpl == LLM_CHAT_TEMPLATE_MISTRAL_V1
            || tmpl == LLM_CHAT_TEMPLATE_MISTRAL_V3
            || tmpl == LLM_CHAT_TEMPLATE_MISTRAL_V3_TEKKEN) {
        // See: https://github.com/mistralai/cookbook/blob/main/concept-deep-dive/tokenization/chat_templates.md
        // See: https://github.com/mistralai/cookbook/blob/main/concept-deep-dive/tokenization/templates.md
        std::string leading_space = tmpl == LLM_CHAT_TEMPLATE_MISTRAL_V1 ? " " : "";
        std::string trailing_space = tmpl == LLM_CHAT_TEMPLATE_MISTRAL_V3_TEKKEN ? "" : " ";
        bool trim_assistant_message = tmpl == LLM_CHAT_TEMPLATE_MISTRAL_V3;
        bool is_inside_turn = false;
        for (auto message : chat) {
            if (!is_inside_turn) {
                ss << leading_space << "[INST]" << trailing_space;
                is_inside_turn = true;
            }
            std::string role(message->role);
            std::string content(message->content);
            if (role == "system") {
                ss << content << "\n\n";
            } else if (role == "user") {
                ss << content << leading_space << "[/INST]";
            } else {
                ss << trailing_space << (trim_assistant_message ? trim(content) : content) << "</s>";
                is_inside_turn = false;
            }
        }
    } else if (
            tmpl == LLM_CHAT_TEMPLATE_LLAMA_2
            || tmpl == LLM_CHAT_TEMPLATE_LLAMA_2_SYS
            || tmpl == LLM_CHAT_TEMPLATE_LLAMA_2_SYS_BOS
            || tmpl == LLM_CHAT_TEMPLATE_LLAMA_2_SYS_STRIP) {
        // llama2 template and its variants
        // [variant] support system message
        // See: https://huggingface.co/blog/llama2#how-to-prompt-llama-2
        bool support_system_message = tmpl != LLM_CHAT_TEMPLATE_LLAMA_2;
        // [variant] add BOS inside history
        bool add_bos_inside_history = tmpl == LLM_CHAT_TEMPLATE_LLAMA_2_SYS_BOS;
        // [variant] trim spaces from the input message
        bool strip_message = tmpl == LLM_CHAT_TEMPLATE_LLAMA_2_SYS_STRIP;
        // construct the prompt
        bool is_inside_turn = true; // skip BOS at the beginning
        ss << "[INST] ";
        for (auto message : chat) {
            std::string content = strip_message ? trim(message->content) : message->content;
            std::string role(message->role);
            if (!is_inside_turn) {
                is_inside_turn = true;
                ss << (add_bos_inside_history ? "<s>[INST] " : "[INST] ");
            }
            if (role == "system") {
                if (support_system_message) {
                    ss << "<<SYS>>\n" << content << "\n<</SYS>>\n\n";
                } else {
                    // if the model does not support system message, we still include it in the first message, but without <<SYS>>
                    ss << content << "\n";
                }
            } else if (role == "user") {
                ss << content << " [/INST]";
            } else {
                ss << content << "</s>";
                is_inside_turn = false;
            }
        }
    } else if (tmpl == LLM_CHAT_TEMPLATE_PHI_3) {
        // Phi 3
        for (auto message : chat) {
            std::string role(message->role);
            ss << "<|" << role << "|>\n" << message->content << "<|end|>\n";
        }
        if (add_ass) {
            ss << "<|assistant|>\n";
        }
    } else if (tmpl == LLM_CHAT_TEMPLATE_FALCON_3) {
        // Falcon 3
        for (auto message : chat) {
            std::string role(message->role);
            ss << "<|" << role << "|>\n" << message->content << "\n";
        }
        if (add_ass) {
            ss << "<|assistant|>\n";
        }
    } else if (tmpl == LLM_CHAT_TEMPLATE_ZEPHYR) {
        // zephyr template
        for (auto message : chat) {
            ss << "<|" << message->role << "|>" << "\n" << message->content << "<|endoftext|>\n";
        }
        if (add_ass) {
            ss << "<|assistant|>\n";
        }
    } else if (tmpl == LLM_CHAT_TEMPLATE_MONARCH) {
        // mlabonne/AlphaMonarch-7B template (the <s> is included inside history)
        for (auto message : chat) {
            std::string bos = (message == chat.front()) ? "" : "<s>"; // skip BOS for first message
            ss << bos << message->role << "\n" << message->content << "</s>\n";
        }
        if (add_ass) {
            ss << "<s>assistant\n";
        }
    } else if (tmpl == LLM_CHAT_TEMPLATE_GEMMA) {
        // google/gemma-7b-it
        std::string system_prompt = "";
        for (auto message : chat) {
            std::string role(message->role);
            if (role == "system") {
                // there is no system message for gemma, but we will merge it with user prompt, so nothing is broken
                system_prompt = trim(message->content);
                continue;
            }
            // in gemma, "assistant" is "model"
            role = role == "assistant" ? "model" : message->role;
            ss << "<start_of_turn>" << role << "\n";
            if (!system_prompt.empty() && role != "model") {
                ss << system_prompt << "\n\n";
                system_prompt = "";
            }
            ss << trim(message->content) << "<end_of_turn>\n";
        }
        if (add_ass) {
            ss << "<start_of_turn>model\n";
        }
    } else if (tmpl == LLM_CHAT_TEMPLATE_ORION) {
        // OrionStarAI/Orion-14B-Chat
        std::string system_prompt = "";
        for (auto message : chat) {
            std::string role(message->role);
            if (role == "system") {
                // there is no system message support, we will merge it with user prompt
                system_prompt = message->content;
                continue;
            } else if (role == "user") {
                ss << "Human: ";
                if (!system_prompt.empty()) {
                    ss << system_prompt << "\n\n";
                    system_prompt = "";
                }
                ss << message->content << "\n\nAssistant: </s>";
            } else {
                ss << message->content << "</s>";
            }
        }
    } else if (tmpl == LLM_CHAT_TEMPLATE_OPENCHAT) {
        // openchat/openchat-3.5-0106,
        for (auto message : chat) {
            std::string role(message->role);
            if (role == "system") {
                ss << message->content << "<|end_of_turn|>";
            } else {
                role[0] = toupper(role[0]);
                ss << "GPT4 Correct " << role << ": " << message->content << "<|end_of_turn|>";
            }
        }
        if (add_ass) {
            ss << "GPT4 Correct Assistant:";
        }
    } else if (tmpl == LLM_CHAT_TEMPLATE_VICUNA || tmpl == LLM_CHAT_TEMPLATE_VICUNA_ORCA) {
        // eachadea/vicuna-13b-1.1 (and Orca variant)
        for (auto message : chat) {
            std::string role(message->role);
            if (role == "system") {
                // Orca-Vicuna variant uses a system prefix
                if (tmpl == LLM_CHAT_TEMPLATE_VICUNA_ORCA) {
                    ss << "SYSTEM: " << message->content << "\n";
                } else {
                    ss << message->content << "\n\n";
                }
            } else if (role == "user") {
                ss << "USER: " << message->content << "\n";
            } else if (role == "assistant") {
                ss << "ASSISTANT: " << message->content << "</s>\n";
            }
        }
        if (add_ass) {
            ss << "ASSISTANT:";
        }
    } else if (tmpl == LLM_CHAT_TEMPLATE_DEEPSEEK) {
        // deepseek-ai/deepseek-coder-33b-instruct
        for (auto message : chat) {
            std::string role(message->role);
            if (role == "system") {
                ss << message->content;
            } else if (role == "user") {
                ss << "### Instruction:\n" << message->content << "\n";
            } else if (role == "assistant") {
                ss << "### Response:\n" << message->content << "\n<|EOT|>\n";
            }
        }
        if (add_ass) {
            ss << "### Response:\n";
        }
    } else if (tmpl == LLM_CHAT_TEMPLATE_COMMAND_R) {
        // CohereForAI/c4ai-command-r-plus
        for (auto message : chat) {
            std::string role(message->role);
            if (role == "system") {
                ss << "<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>" << trim(message->content) << "<|END_OF_TURN_TOKEN|>";
            } else if (role == "user") {
                ss << "<|START_OF_TURN_TOKEN|><|USER_TOKEN|>" << trim(message->content) << "<|END_OF_TURN_TOKEN|>";
            } else if (role == "assistant") {
                ss << "<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>" << trim(message->content) << "<|END_OF_TURN_TOKEN|>";
            }
        }
        if (add_ass) {
            ss << "<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>";
        }
    } else if (tmpl == LLM_CHAT_TEMPLATE_LLAMA_3) {
        // Llama 3
        for (auto message : chat) {
            std::string role(message->role);
            ss << "<|start_header_id|>" << role << "<|end_header_id|>\n\n" << trim(message->content) << "<|eot_id|>";
        }
        if (add_ass) {
            ss << "<|start_header_id|>assistant<|end_header_id|>\n\n";
        }
    } else if (tmpl == LLM_CHAT_TEMPLATE_CHATGML_3) {
        // chatglm3-6b
        ss << "[gMASK]" << "sop";
        for (auto message : chat) {
            std::string role(message->role);
            ss << "<|" << role << "|>" << "\n " << message->content;
        }
        if (add_ass) {
            ss << "<|assistant|>";
        }
    } else if (tmpl == LLM_CHAT_TEMPLATE_CHATGML_4) {
        ss << "[gMASK]" << "<sop>";
        for (auto message : chat) {
            std::string role(message->role);
            ss << "<|" << role << "|>" << "\n" << message->content;
        }
        if (add_ass) {
            ss << "<|assistant|>";
        }
    } else if (tmpl == LLM_CHAT_TEMPLATE_MINICPM) {
        // MiniCPM-3B-OpenHermes-2.5-v2-GGUF
        for (auto message : chat) {
            std::string role(message->role);
            if (role == "user") {
                ss << LU8("<用户>");
                ss << trim(message->content);
                ss << "<AI>";
            } else {
                ss << trim(message->content);
            }
        }
    } else if (tmpl == LLM_CHAT_TEMPLATE_DEEPSEEK_2) {
        // DeepSeek-V2
        for (auto message : chat) {
            std::string role(message->role);
            if (role == "system") {
                ss << message->content << "\n\n";
            } else if (role == "user") {
                ss << "User: " << message->content << "\n\n";
            } else if (role == "assistant") {
                ss << "Assistant: " << message->content << LU8("<|end▁of▁sentence|>");
            }
        }
        if (add_ass) {
            ss << "Assistant:";
        }
    } else if (tmpl == LLM_CHAT_TEMPLATE_DEEPSEEK_3) {
        // DeepSeek-V3
        for (auto message : chat) {
            std::string role(message->role);
            if (role == "system") {
                ss << message->content << "\n\n";
            } else if (role == "user") {
                ss << LU8("<|User|>") << message->content;
            } else if (role == "assistant") {
                ss << LU8("<|Assistant|>") << message->content << LU8("<|end▁of▁sentence|>");
            }
        }
        if (add_ass) {
            ss << LU8("<|Assistant|>");
        }
    } else if (tmpl == LLM_CHAT_TEMPLATE_EXAONE_3) {
        // ref: https://huggingface.co/LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct/discussions/8#66bae61b1893d14ee8ed85bb
        // EXAONE-3.0-7.8B-Instruct
        for (auto message : chat) {
            std::string role(message->role);
            if (role == "system") {
                ss << "[|system|]" << trim(message->content) << "[|endofturn|]\n";
            } else if (role == "user") {
                ss << "[|user|]" << trim(message->content) << "\n";
            } else if (role == "assistant") {
                ss << "[|assistant|]" << trim(message->content) << "[|endofturn|]\n";
            }
        }
        if (add_ass) {
            ss << "[|assistant|]";
        }
    } else if (tmpl == LLM_CHAT_TEMPLATE_RWKV_WORLD) {
        // this template requires the model to have "\n\n" as EOT token
        for (auto message : chat) {
            std::string role(message->role);
            if (role == "user") {
                ss << "User: " << message->content << "\n\nAssistant:";
            } else {
                ss << message->content << "\n\n";
            }
        }
    } else if (tmpl == LLM_CHAT_TEMPLATE_GRANITE) {
        // IBM Granite template
        for (const auto & message : chat) {
            std::string role(message->role);
            ss << "<|start_of_role|>" << role << "<|end_of_role|>";
            if (role == "assistant_tool_call") {
                ss << "<|tool_call|>";
            }
            ss << message->content << "<|end_of_text|>\n";
        }
        if (add_ass) {
            ss << "<|start_of_role|>assistant<|end_of_role|>\n";
        }
    } else if (tmpl == LLM_CHAT_TEMPLATE_GIGACHAT) {
        // GigaChat template
        bool has_system = !chat.empty() && std::string(chat[0]->role) == "system";

        // Handle system message if present
        if (has_system) {
            ss << "<s>" << chat[0]->content << "<|message_sep|>";
        } else {
            ss << "<s>";
        }

        // Process remaining messages
        for (size_t i = has_system ? 1 : 0; i < chat.size(); i++) {
            std::string role(chat[i]->role);
            if (role == "user") {
                ss << "user<|role_sep|>" << chat[i]->content << "<|message_sep|>"
                   << "available functions<|role_sep|>[]<|message_sep|>";
            } else if (role == "assistant") {
                ss << "assistant<|role_sep|>" << chat[i]->content << "<|message_sep|>";
            }
        }

        // Add generation prompt if needed
        if (add_ass) {
            ss << "assistant<|role_sep|>";
        }
    } else if (tmpl == LLM_CHAT_TEMPLATE_MEGREZ) {
        // Megrez template
        for (auto message : chat) {
            std::string role(message->role);
            ss << "<|role_start|>" << role << "<|role_end|>" << message->content << "<|turn_end|>";
        }

        if (add_ass) {
            ss << "<|role_start|>assistant<|role_end|>";
        }
    } else {
        // template not supported
        return -1;
    }
    dest = ss.str();
    return dest.size();
}

// public interface

int32_t llama_chat_builtin_templates(const char ** output, size_t len) {
    auto it = LLM_CHAT_TEMPLATES.begin();
    for (size_t i = 0; i < std::min(len, LLM_CHAT_TEMPLATES.size()); i++) {
        output[i] = it->first.c_str();
        std::advance(it, 1);
    }
    return (int32_t) LLM_CHAT_TEMPLATES.size();
}
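A hedged example of how the two functions above are meant to be chained (a sketch with a hard-coded message list; the messages and the "chatml" name are placeholders, not code from this sync):

    std::vector<const llama_chat_message *> chat;
    llama_chat_message sys  = { "system", "You are a helpful assistant." };
    llama_chat_message user = { "user",   "Hello!" };
    chat.push_back(&sys);
    chat.push_back(&user);

    // detect from a template string (or a built-in name such as "chatml"), then render
    llm_chat_template tmpl = llm_chat_detect_template("chatml");
    std::string prompt;
    int32_t n = llm_chat_apply_template(tmpl, chat, prompt, /*add_ass =*/ true);
    // n < 0 means the template is not supported; otherwise prompt holds the rendered text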
examples/talk-llama/llama-chat.h
ADDED
@@ -0,0 +1,51 @@
| 1 | + #pragma once
| 2 | +
| 3 | + #include <string>
| 4 | + #include <vector>
| 5 | + #include <cstdint>
| 6 | +
| 7 | + enum llm_chat_template {
| 8 | +     LLM_CHAT_TEMPLATE_CHATML,
| 9 | +     LLM_CHAT_TEMPLATE_LLAMA_2,
| 10 | +     LLM_CHAT_TEMPLATE_LLAMA_2_SYS,
| 11 | +     LLM_CHAT_TEMPLATE_LLAMA_2_SYS_BOS,
| 12 | +     LLM_CHAT_TEMPLATE_LLAMA_2_SYS_STRIP,
| 13 | +     LLM_CHAT_TEMPLATE_MISTRAL_V1,
| 14 | +     LLM_CHAT_TEMPLATE_MISTRAL_V3,
| 15 | +     LLM_CHAT_TEMPLATE_MISTRAL_V3_TEKKEN,
| 16 | +     LLM_CHAT_TEMPLATE_MISTRAL_V7,
| 17 | +     LLM_CHAT_TEMPLATE_PHI_3,
| 18 | +     LLM_CHAT_TEMPLATE_FALCON_3,
| 19 | +     LLM_CHAT_TEMPLATE_ZEPHYR,
| 20 | +     LLM_CHAT_TEMPLATE_MONARCH,
| 21 | +     LLM_CHAT_TEMPLATE_GEMMA,
| 22 | +     LLM_CHAT_TEMPLATE_ORION,
| 23 | +     LLM_CHAT_TEMPLATE_OPENCHAT,
| 24 | +     LLM_CHAT_TEMPLATE_VICUNA,
| 25 | +     LLM_CHAT_TEMPLATE_VICUNA_ORCA,
| 26 | +     LLM_CHAT_TEMPLATE_DEEPSEEK,
| 27 | +     LLM_CHAT_TEMPLATE_DEEPSEEK_2,
| 28 | +     LLM_CHAT_TEMPLATE_DEEPSEEK_3,
| 29 | +     LLM_CHAT_TEMPLATE_COMMAND_R,
| 30 | +     LLM_CHAT_TEMPLATE_LLAMA_3,
| 31 | +     LLM_CHAT_TEMPLATE_CHATGML_3,
| 32 | +     LLM_CHAT_TEMPLATE_CHATGML_4,
| 33 | +     LLM_CHAT_TEMPLATE_MINICPM,
| 34 | +     LLM_CHAT_TEMPLATE_EXAONE_3,
| 35 | +     LLM_CHAT_TEMPLATE_RWKV_WORLD,
| 36 | +     LLM_CHAT_TEMPLATE_GRANITE,
| 37 | +     LLM_CHAT_TEMPLATE_GIGACHAT,
| 38 | +     LLM_CHAT_TEMPLATE_MEGREZ,
| 39 | +     LLM_CHAT_TEMPLATE_UNKNOWN,
| 40 | + };
| 41 | +
| 42 | + struct llama_chat_message;
| 43 | +
| 44 | + llm_chat_template llm_chat_template_from_str(const std::string & name);
| 45 | +
| 46 | + llm_chat_template llm_chat_detect_template(const std::string & tmpl);
| 47 | +
| 48 | + int32_t llm_chat_apply_template(
| 49 | +     llm_chat_template tmpl,
| 50 | +     const std::vector<const llama_chat_message *> & chat,
| 51 | +     std::string & dest, bool add_ass);
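Together these declarations describe the internal template pipeline: resolve a template either by its short name (llm_chat_template_from_str) or by sniffing a full Jinja template string (llm_chat_detect_template), then render a message list with llm_chat_apply_template, which writes the prompt into dest and returns -1 when the template is unsupported. A rough sketch of that flow (illustrative only, not part of the diff; it includes these internal headers directly and hard-codes the "chatml" name instead of detecting a real template string):

    #include "llama-chat.h"
    #include "llama.h" // definition of llama_chat_message

    static std::string render_chatml_example() {
        llama_chat_message msgs[] = {
            { "system", "You are a helpful assistant." },
            { "user",   "Hello!"                       },
        };
        std::vector<const llama_chat_message *> chat = { &msgs[0], &msgs[1] };

        // Known short name; an unknown Jinja string would go through llm_chat_detect_template().
        const llm_chat_template tmpl = llm_chat_template_from_str("chatml");

        std::string prompt;
        if (llm_chat_apply_template(tmpl, chat, prompt, /*add_ass=*/true) < 0) {
            prompt.clear(); // template not supported
        }
        return prompt;
    }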
examples/talk-llama/llama-context.cpp
ADDED
|
@@ -0,0 +1,1771 @@
| 1 |
+
#include "llama-context.h"
|
| 2 |
+
|
| 3 |
+
#include <cassert>
|
| 4 |
+
#include <cmath>
|
| 5 |
+
#include <cstring>
|
| 6 |
+
#include <stdexcept>
|
| 7 |
+
|
| 8 |
+
void llama_set_k_shift(struct llama_context & lctx) {
|
| 9 |
+
const int64_t kv_size = lctx.kv_self.size;
|
| 10 |
+
|
| 11 |
+
assert(ggml_backend_buffer_is_host(lctx.inp_K_shift->buffer));
|
| 12 |
+
|
| 13 |
+
int32_t * data = (int32_t *) lctx.inp_K_shift->data;
|
| 14 |
+
|
| 15 |
+
for (int i = 0; i < kv_size; ++i) {
|
| 16 |
+
data[i] = lctx.kv_self.cells[i].delta;
|
| 17 |
+
}
|
| 18 |
+
}
|
| 19 |
+
|
| 20 |
+
void llama_set_s_copy(struct llama_context & lctx) {
|
| 21 |
+
const int64_t kv_size = lctx.kv_self.size;
|
| 22 |
+
|
| 23 |
+
assert(ggml_backend_buffer_is_host(lctx.inp_s_copy->buffer));
|
| 24 |
+
|
| 25 |
+
int32_t * data = (int32_t *) lctx.inp_s_copy->data;
|
| 26 |
+
|
| 27 |
+
for (int i = 0; i < kv_size; ++i) {
|
| 28 |
+
data[i] = lctx.kv_self.cells[i].src;
|
| 29 |
+
}
|
| 30 |
+
}
|
| 31 |
+
|
| 32 |
+
// llama input
|
| 33 |
+
|
| 34 |
+
static int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t n_buckets, bool bidirectional) {
|
| 35 |
+
// TODO move to hparams if a T5 variant appears that uses a different value
|
| 36 |
+
const int64_t max_distance = 128;
|
| 37 |
+
|
| 38 |
+
if (bidirectional) {
|
| 39 |
+
n_buckets >>= 1;
|
| 40 |
+
}
|
| 41 |
+
|
| 42 |
+
const int64_t max_exact = n_buckets >> 1;
|
| 43 |
+
|
| 44 |
+
int32_t relative_position = x - y;
|
| 45 |
+
int32_t relative_bucket = 0;
|
| 46 |
+
if (bidirectional) {
|
| 47 |
+
relative_bucket += (relative_position > 0) * n_buckets;
|
| 48 |
+
relative_position = abs(relative_position);
|
| 49 |
+
} else {
|
| 50 |
+
relative_position = -std::min<int32_t>(relative_position, 0);
|
| 51 |
+
}
|
| 52 |
+
int32_t relative_position_if_large = floorf(max_exact + logf(1.0 * relative_position / max_exact) * (n_buckets - max_exact) / log(1.0 * max_distance / max_exact));
|
| 53 |
+
relative_position_if_large = std::min<int32_t>(relative_position_if_large, n_buckets - 1);
|
| 54 |
+
relative_bucket += (relative_position < max_exact ? relative_position : relative_position_if_large);
|
| 55 |
+
return relative_bucket;
|
| 56 |
+
}
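llama_relative_position_bucket() above is the T5-style bucketing: every distance below max_exact gets its own bucket, larger distances share logarithmically spaced buckets up to max_distance = 128, and anything beyond that is clamped into the last bucket. A standalone illustration of the same mapping (not part of the diff; it assumes the unidirectional decoder case and n_buckets = 32, whereas the real bucket count comes from hparams.n_rel_attn_bkts):

    #include <algorithm>
    #include <cmath>
    #include <cstdio>
    #include <initializer_list>

    // Mirrors the formula above for bidirectional == false.
    static int bucket_for_distance(int d, int n_buckets = 32, int max_distance = 128) {
        const int max_exact = n_buckets / 2; // small distances map 1:1 to buckets
        if (d < max_exact) {
            return d;
        }
        const int large = max_exact + (int) floorf(logf(1.0f * d / max_exact)
                                                   * (n_buckets - max_exact)
                                                   / logf(1.0f * max_distance / max_exact));
        return std::min(large, n_buckets - 1);
    }

    int main() {
        for (int d : { 0, 1, 15, 16, 32, 64, 127, 500 }) {
            printf("distance %3d -> bucket %d\n", d, bucket_for_distance(d));
        }
        return 0;
    }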
|
| 57 |
+
|
| 58 |
+
void llama_set_inputs(llama_context & lctx, const llama_ubatch & ubatch) {
|
| 59 |
+
//
|
| 60 |
+
// set input data
|
| 61 |
+
//
|
| 62 |
+
|
| 63 |
+
const auto & hparams = lctx.model.hparams;
|
| 64 |
+
const auto & cparams = lctx.cparams;
|
| 65 |
+
const auto & kv_self = lctx.kv_self;
|
| 66 |
+
|
| 67 |
+
if (ubatch.token) {
|
| 68 |
+
const int64_t n_tokens = ubatch.n_tokens;
|
| 69 |
+
|
| 70 |
+
ggml_backend_tensor_set(lctx.inp_tokens, ubatch.token, 0, n_tokens*ggml_element_size(lctx.inp_tokens));
|
| 71 |
+
}
|
| 72 |
+
|
| 73 |
+
if (ubatch.embd) {
|
| 74 |
+
const int64_t n_embd = hparams.n_embd;
|
| 75 |
+
const int64_t n_tokens = ubatch.n_tokens;
|
| 76 |
+
|
| 77 |
+
ggml_backend_tensor_set(lctx.inp_embd, ubatch.embd, 0, n_tokens*n_embd*ggml_element_size(lctx.inp_embd));
|
| 78 |
+
}
|
| 79 |
+
|
| 80 |
+
if (ubatch.pos && lctx.inp_pos) {
|
| 81 |
+
const int64_t n_tokens = ubatch.n_tokens;
|
| 82 |
+
auto n_pos = lctx.n_pos_per_token;
|
| 83 |
+
ggml_backend_tensor_set(lctx.inp_pos, ubatch.pos, 0, n_tokens*n_pos*ggml_element_size(lctx.inp_pos));
|
| 84 |
+
}
|
| 85 |
+
|
| 86 |
+
if (hparams.causal_attn || cparams.pooling_type == LLAMA_POOLING_TYPE_NONE) {
|
| 87 |
+
//GGML_ASSERT(lctx.inp_out_ids && "every model that can must skip unused outputs");
|
| 88 |
+
|
| 89 |
+
if (!lctx.inp_out_ids) {
|
| 90 |
+
LLAMA_LOG_WARN("%s: 'lctx.inp_out_ids' is not created\n", __func__);
|
| 91 |
+
} else {
|
| 92 |
+
const int64_t n_tokens = ubatch.n_tokens;
|
| 93 |
+
|
| 94 |
+
GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_out_ids->buffer));
|
| 95 |
+
int32_t * data = (int32_t *) lctx.inp_out_ids->data;
|
| 96 |
+
|
| 97 |
+
if (lctx.n_outputs == n_tokens) {
|
| 98 |
+
for (int i = 0; i < n_tokens; ++i) {
|
| 99 |
+
data[i] = i;
|
| 100 |
+
}
|
| 101 |
+
} else if (ubatch.output) {
|
| 102 |
+
int32_t n_outputs = 0;
|
| 103 |
+
for (int i = 0; i < n_tokens; ++i) {
|
| 104 |
+
if (ubatch.output[i]) {
|
| 105 |
+
data[n_outputs++] = i;
|
| 106 |
+
}
|
| 107 |
+
}
|
| 108 |
+
// the graph needs to have been passed the correct number of outputs
|
| 109 |
+
GGML_ASSERT(lctx.n_outputs == n_outputs);
|
| 110 |
+
} else if (lctx.n_outputs == 1) {
|
| 111 |
+
// only keep last output
|
| 112 |
+
data[0] = n_tokens - 1;
|
| 113 |
+
} else {
|
| 114 |
+
GGML_ASSERT(lctx.n_outputs == 0);
|
| 115 |
+
}
|
| 116 |
+
}
|
| 117 |
+
}
|
| 118 |
+
|
| 119 |
+
GGML_ASSERT(
|
| 120 |
+
// (!a || b) is a logical implication (a -> b)
|
| 121 |
+
// !hparams.causal_attn -> !cparams.causal_attn
|
| 122 |
+
(hparams.causal_attn || !cparams.causal_attn) &&
|
| 123 |
+
"causal attention is not supported by this model"
|
| 124 |
+
);
|
| 125 |
+
|
| 126 |
+
if (lctx.inp_KQ_mask || lctx.inp_KQ_mask_swa) {
|
| 127 |
+
// NOTE: hparams.causal_attn indicates the model is capable of generation and uses the kv cache.
|
| 128 |
+
if (cparams.causal_attn && !lctx.is_encoding) {
|
| 129 |
+
const int64_t n_kv = kv_self.n;
|
| 130 |
+
const int64_t n_tokens = ubatch.n_tokens;
|
| 131 |
+
const int64_t n_seq_tokens = ubatch.n_seq_tokens;
|
| 132 |
+
const int64_t n_seqs = ubatch.n_seqs;
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
float * data = nullptr;
|
| 136 |
+
float * data_swa = nullptr;
|
| 137 |
+
|
| 138 |
+
if (lctx.inp_KQ_mask) {
|
| 139 |
+
GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_KQ_mask->buffer));
|
| 140 |
+
data = (float *) lctx.inp_KQ_mask->data;
|
| 141 |
+
}
|
| 142 |
+
|
| 143 |
+
if (lctx.inp_KQ_mask_swa) {
|
| 144 |
+
GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_KQ_mask_swa->buffer));
|
| 145 |
+
data_swa = (float *) lctx.inp_KQ_mask_swa->data;
|
| 146 |
+
}
|
| 147 |
+
|
| 148 |
+
// For causal attention, use only the previous KV cells
|
| 149 |
+
// of the correct sequence for each token of the ubatch.
|
| 150 |
+
// It's assumed that if a token in the batch has multiple sequences, they are equivalent.
|
| 151 |
+
for (int h = 0; h < 1; ++h) {
|
| 152 |
+
for (int s = 0; s < n_seqs; ++s) {
|
| 153 |
+
const llama_seq_id seq_id = ubatch.seq_id[s][0];
|
| 154 |
+
|
| 155 |
+
for (int j = 0; j < n_seq_tokens; ++j) {
|
| 156 |
+
const llama_pos pos = ubatch.pos[s*n_seq_tokens + j];
|
| 157 |
+
|
| 158 |
+
for (int i = 0; i < n_kv; ++i) {
|
| 159 |
+
float f;
|
| 160 |
+
if (!kv_self.cells[i].has_seq_id(seq_id) || kv_self.cells[i].pos > pos) {
|
| 161 |
+
f = -INFINITY;
|
| 162 |
+
} else {
|
| 163 |
+
if (hparams.use_alibi) {
|
| 164 |
+
f = -std::abs(kv_self.cells[i].pos - pos);
|
| 165 |
+
} else {
|
| 166 |
+
f = 0.0f;
|
| 167 |
+
}
|
| 168 |
+
}
|
| 169 |
+
|
| 170 |
+
if (data) {
|
| 171 |
+
data[h*(n_kv*n_tokens) + s*(n_kv*n_seq_tokens) + j*n_kv + i] = f;
|
| 172 |
+
}
|
| 173 |
+
|
| 174 |
+
// may need to cut off old tokens for sliding window
|
| 175 |
+
if (data_swa) {
|
| 176 |
+
if (pos - kv_self.cells[i].pos >= (int32_t)hparams.n_swa) {
|
| 177 |
+
f = -INFINITY;
|
| 178 |
+
}
|
| 179 |
+
data_swa[h*(n_kv*n_tokens) + s*(n_kv*n_seq_tokens) + j*n_kv + i] = f;
|
| 180 |
+
}
|
| 181 |
+
}
|
| 182 |
+
}
|
| 183 |
+
}
|
| 184 |
+
|
| 185 |
+
if (data) {
|
| 186 |
+
for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
|
| 187 |
+
for (int j = 0; j < n_kv; ++j) {
|
| 188 |
+
data[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
|
| 189 |
+
}
|
| 190 |
+
}
|
| 191 |
+
}
|
| 192 |
+
|
| 193 |
+
if (data_swa) {
|
| 194 |
+
for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
|
| 195 |
+
for (int j = 0; j < n_kv; ++j) {
|
| 196 |
+
data_swa[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
|
| 197 |
+
}
|
| 198 |
+
}
|
| 199 |
+
}
|
| 200 |
+
}
|
| 201 |
+
} else {
|
| 202 |
+
const int64_t n_tokens = ubatch.n_tokens;
|
| 203 |
+
const int64_t n_seq_tokens = ubatch.n_seq_tokens;
|
| 204 |
+
const int64_t n_seqs = ubatch.n_seqs;
|
| 205 |
+
// when using kv cache, the mask needs to match the kv cache size
|
| 206 |
+
const int64_t n_stride = hparams.causal_attn && !lctx.is_encoding ? kv_self.n : n_tokens;
|
| 207 |
+
|
| 208 |
+
GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_KQ_mask->buffer));
|
| 209 |
+
|
| 210 |
+
float * data = (float *) lctx.inp_KQ_mask->data;
|
| 211 |
+
|
| 212 |
+
for (int h = 0; h < 1; ++h) {
|
| 213 |
+
for (int s1 = 0; s1 < n_seqs; ++s1) {
|
| 214 |
+
const llama_seq_id seq_id = ubatch.seq_id[s1][0];
|
| 215 |
+
|
| 216 |
+
for (int j = 0; j < n_seq_tokens; ++j) {
|
| 217 |
+
const int32_t tj = s1*n_seq_tokens + j;
|
| 218 |
+
|
| 219 |
+
for (int s0 = 0; s0 < n_seqs; ++s0) {
|
| 220 |
+
for (int i = 0; i < n_seq_tokens; ++i) {
|
| 221 |
+
const int32_t ti = s0*n_seq_tokens + i;
|
| 222 |
+
float f = -INFINITY;
|
| 223 |
+
|
| 224 |
+
for (int s = 0; s < ubatch.n_seq_id[s0]; ++s) {
|
| 225 |
+
if (ubatch.seq_id[s0][s] == seq_id) {
|
| 226 |
+
if (hparams.use_alibi) {
|
| 227 |
+
f = -std::abs(ubatch.pos[ti] - ubatch.pos[tj]);
|
| 228 |
+
} else {
|
| 229 |
+
f = 0.0f;
|
| 230 |
+
}
|
| 231 |
+
break;
|
| 232 |
+
}
|
| 233 |
+
}
|
| 234 |
+
|
| 235 |
+
data[h*(n_tokens*n_tokens) + tj*n_stride + ti] = f;
|
| 236 |
+
}
|
| 237 |
+
}
|
| 238 |
+
|
| 239 |
+
for (int i = n_tokens; i < n_stride; ++i) {
|
| 240 |
+
data[h*(n_tokens*n_tokens) + tj*n_stride + i] = -INFINITY;
|
| 241 |
+
}
|
| 242 |
+
}
|
| 243 |
+
}
|
| 244 |
+
}
|
| 245 |
+
}
|
| 246 |
+
}
|
| 247 |
+
|
| 248 |
+
if (cparams.embeddings && cparams.pooling_type == LLAMA_POOLING_TYPE_MEAN) {
|
| 249 |
+
const int64_t n_tokens = ubatch.n_tokens;
|
| 250 |
+
const int64_t n_seq_tokens = ubatch.n_seq_tokens;
|
| 251 |
+
const int64_t n_seqs = ubatch.n_seqs;
|
| 252 |
+
|
| 253 |
+
GGML_ASSERT(lctx.inp_mean);
|
| 254 |
+
GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_mean->buffer));
|
| 255 |
+
|
| 256 |
+
float * data = (float *) lctx.inp_mean->data;
|
| 257 |
+
memset(lctx.inp_mean->data, 0, n_tokens * n_tokens * ggml_element_size(lctx.inp_mean));
|
| 258 |
+
|
| 259 |
+
std::vector<uint64_t> sum(n_tokens, 0);
|
| 260 |
+
|
| 261 |
+
for (int s = 0; s < n_seqs; ++s) {
|
| 262 |
+
const llama_seq_id seq_id = ubatch.seq_id[s][0];
|
| 263 |
+
|
| 264 |
+
// TODO: adapt limits to n_seqs when ubatch.equal_seqs is true
|
| 265 |
+
GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == MEAN");
|
| 266 |
+
|
| 267 |
+
sum[seq_id] += ubatch.n_seq_tokens;
|
| 268 |
+
}
|
| 269 |
+
|
| 270 |
+
std::vector<float> div(n_tokens, 0.0f);
|
| 271 |
+
for (int i = 0; i < n_tokens; ++i) {
|
| 272 |
+
const uint64_t s = sum[i];
|
| 273 |
+
if (s > 0) {
|
| 274 |
+
div[i] = 1.0f/float(s);
|
| 275 |
+
}
|
| 276 |
+
}
|
| 277 |
+
|
| 278 |
+
for (int s = 0; s < n_seqs; ++s) {
|
| 279 |
+
const llama_seq_id seq_id = ubatch.seq_id[s][0];
|
| 280 |
+
|
| 281 |
+
for (int i = 0; i < n_seq_tokens; ++i) {
|
| 282 |
+
data[seq_id*n_tokens + s*n_seq_tokens + i] = div[seq_id];
|
| 283 |
+
}
|
| 284 |
+
}
|
| 285 |
+
}
|
| 286 |
+
|
| 287 |
+
if (cparams.embeddings && (
|
| 288 |
+
cparams.pooling_type == LLAMA_POOLING_TYPE_CLS ||
|
| 289 |
+
cparams.pooling_type == LLAMA_POOLING_TYPE_RANK)) {
|
| 290 |
+
const int64_t n_tokens = ubatch.n_tokens;
|
| 291 |
+
const int64_t n_seq_tokens = ubatch.n_seq_tokens;
|
| 292 |
+
const int64_t n_seqs = ubatch.n_seqs;
|
| 293 |
+
|
| 294 |
+
GGML_ASSERT(lctx.inp_cls);
|
| 295 |
+
GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_cls->buffer));
|
| 296 |
+
|
| 297 |
+
uint32_t * data = (uint32_t *) lctx.inp_cls->data;
|
| 298 |
+
memset(lctx.inp_cls->data, 0, n_tokens * ggml_element_size(lctx.inp_cls));
|
| 299 |
+
|
| 300 |
+
for (int s = 0; s < n_seqs; ++s) {
|
| 301 |
+
const llama_seq_id seq_id = ubatch.seq_id[s][0];
|
| 302 |
+
|
| 303 |
+
// TODO: adapt limits to n_seqs when ubatch.equal_seqs is true
|
| 304 |
+
GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == CLS or RANK");
|
| 305 |
+
|
| 306 |
+
for (int i = 0; i < n_seq_tokens; ++i) {
|
| 307 |
+
const llama_pos pos = ubatch.pos[s*n_seq_tokens + i];
|
| 308 |
+
|
| 309 |
+
if (pos == 0) {
|
| 310 |
+
data[seq_id] = s*n_seq_tokens + i;
|
| 311 |
+
}
|
| 312 |
+
}
|
| 313 |
+
}
|
| 314 |
+
}
|
| 315 |
+
|
| 316 |
+
if (cparams.embeddings && cparams.pooling_type == LLAMA_POOLING_TYPE_LAST) {
|
| 317 |
+
const int64_t n_tokens = ubatch.n_tokens;
|
| 318 |
+
const int64_t n_seq_tokens = ubatch.n_seq_tokens;
|
| 319 |
+
const int64_t n_seqs = ubatch.n_seqs;
|
| 320 |
+
|
| 321 |
+
GGML_ASSERT(lctx.inp_cls);
|
| 322 |
+
GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_cls->buffer));
|
| 323 |
+
|
| 324 |
+
uint32_t * data = (uint32_t *) lctx.inp_cls->data;
|
| 325 |
+
memset(lctx.inp_cls->data, 0, n_tokens * ggml_element_size(lctx.inp_cls));
|
| 326 |
+
|
| 327 |
+
std::vector<int> last_pos(n_tokens, -1);
|
| 328 |
+
std::vector<int> last_row(n_tokens, -1);
|
| 329 |
+
|
| 330 |
+
for (int s = 0; s < n_seqs; ++s) {
|
| 331 |
+
const llama_seq_id seq_id = ubatch.seq_id[s][0];
|
| 332 |
+
|
| 333 |
+
// TODO: adapt limits to n_seqs when ubatch.equal_seqs is true
|
| 334 |
+
GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == LAST");
|
| 335 |
+
|
| 336 |
+
for (int i = 0; i < n_seq_tokens; ++i) {
|
| 337 |
+
const llama_pos pos = ubatch.pos[s*n_seq_tokens + i];
|
| 338 |
+
|
| 339 |
+
if (pos >= last_pos[seq_id]) {
|
| 340 |
+
last_pos[seq_id] = pos;
|
| 341 |
+
last_row[seq_id] = s*n_seq_tokens + i;
|
| 342 |
+
}
|
| 343 |
+
}
|
| 344 |
+
}
|
| 345 |
+
|
| 346 |
+
for (int i = 0; i < n_tokens; ++i) {
|
| 347 |
+
if (last_row[i] >= 0) {
|
| 348 |
+
data[i] = last_row[i];
|
| 349 |
+
}
|
| 350 |
+
}
|
| 351 |
+
}
|
| 352 |
+
|
| 353 |
+
if (kv_self.recurrent) {
|
| 354 |
+
const int64_t n_kv = kv_self.n;
|
| 355 |
+
|
| 356 |
+
if (lctx.inp_s_mask) {
|
| 357 |
+
GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_s_mask->buffer));
|
| 358 |
+
float * data = (float *) lctx.inp_s_mask->data;
|
| 359 |
+
|
| 360 |
+
// clear unused states
|
| 361 |
+
for (int i = 0; i < n_kv; ++i) {
|
| 362 |
+
const uint32_t cell_id = i + kv_self.head;
|
| 363 |
+
llama_kv_cell & kv_cell = lctx.kv_self.cells[cell_id];
|
| 364 |
+
|
| 365 |
+
data[i] = (float) (kv_cell.src >= 0);
|
| 366 |
+
|
| 367 |
+
// only clear once
|
| 368 |
+
if (kv_cell.src < 0) {
|
| 369 |
+
kv_cell.src = cell_id;
|
| 370 |
+
}
|
| 371 |
+
}
|
| 372 |
+
}
|
| 373 |
+
|
| 374 |
+
if (lctx.inp_s_copy) {
|
| 375 |
+
GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_s_copy->buffer));
|
| 376 |
+
int32_t * data = (int32_t *) lctx.inp_s_copy->data;
|
| 377 |
+
|
| 378 |
+
// assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
|
| 379 |
+
for (uint32_t i = 0; i < n_kv; ++i) {
|
| 380 |
+
const uint32_t cell_id = i + kv_self.head;
|
| 381 |
+
llama_kv_cell & kv_cell = lctx.kv_self.cells[cell_id];
|
| 382 |
+
|
| 383 |
+
// prevent out-of-bound sources
|
| 384 |
+
if (kv_cell.src < 0 || (uint32_t) kv_cell.src >= kv_self.size) {
|
| 385 |
+
kv_cell.src = cell_id;
|
| 386 |
+
}
|
| 387 |
+
|
| 388 |
+
data[i] = kv_cell.src;
|
| 389 |
+
|
| 390 |
+
// ensure copy only happens once
|
| 391 |
+
if (kv_cell.src != (int32_t) cell_id) {
|
| 392 |
+
kv_cell.src = cell_id;
|
| 393 |
+
}
|
| 394 |
+
}
|
| 395 |
+
}
|
| 396 |
+
}
|
| 397 |
+
|
| 398 |
+
if (lctx.inp_pos_bucket) {
|
| 399 |
+
const int64_t n_tokens = ubatch.n_tokens;
|
| 400 |
+
|
| 401 |
+
GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_pos_bucket->buffer));
|
| 402 |
+
GGML_ASSERT(!ubatch.equal_seqs); // TODO: use ubatch.n_seqs instead of failing
|
| 403 |
+
|
| 404 |
+
int32_t * data = (int32_t *) lctx.inp_pos_bucket->data;
|
| 405 |
+
|
| 406 |
+
if (!lctx.is_encoding) {
|
| 407 |
+
const int64_t n_kv = kv_self.n;
|
| 408 |
+
for (int h = 0; h < 1; ++h) {
|
| 409 |
+
for (int j = 0; j < n_tokens; ++j) {
|
| 410 |
+
for (int i = 0; i < n_kv; ++i) {
|
| 411 |
+
data[h*(n_kv*n_tokens) + j*n_kv + i] = llama_relative_position_bucket(lctx.kv_self.cells[i].pos, ubatch.pos[j], hparams.n_rel_attn_bkts, lctx.is_encoding);
|
| 412 |
+
}
|
| 413 |
+
}
|
| 414 |
+
}
|
| 415 |
+
} else {
|
| 416 |
+
for (int h = 0; h < 1; ++h) {
|
| 417 |
+
for (int j = 0; j < n_tokens; ++j) {
|
| 418 |
+
for (int i = 0; i < n_tokens; ++i) {
|
| 419 |
+
data[h*(n_tokens*n_tokens) + j*n_tokens + i] = llama_relative_position_bucket(ubatch.pos[i], ubatch.pos[j], hparams.n_rel_attn_bkts, lctx.is_encoding);
|
| 420 |
+
}
|
| 421 |
+
}
|
| 422 |
+
}
|
| 423 |
+
}
|
| 424 |
+
}
|
| 425 |
+
|
| 426 |
+
if (!lctx.is_encoding && lctx.inp_embd_enc) {
|
| 427 |
+
assert(lctx.inp_embd_enc->type == GGML_TYPE_F32);
|
| 428 |
+
assert((size_t) ggml_nelements(lctx.inp_embd_enc) == lctx.embd_enc.size());
|
| 429 |
+
|
| 430 |
+
ggml_backend_tensor_set(lctx.inp_embd_enc, lctx.embd_enc.data(), 0, ggml_nbytes(lctx.inp_embd_enc));
|
| 431 |
+
}
|
| 432 |
+
|
| 433 |
+
if (!lctx.is_encoding && lctx.inp_KQ_mask_cross) {
|
| 434 |
+
const int64_t n_output_enc = lctx.embd_enc.size() / hparams.n_embd;
|
| 435 |
+
const int64_t n_tokens = ubatch.n_tokens;
|
| 436 |
+
|
| 437 |
+
GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_KQ_mask_cross->buffer));
|
| 438 |
+
GGML_ASSERT(!ubatch.equal_seqs); // TODO: use ubatch.n_seqs instead of failing
|
| 439 |
+
|
| 440 |
+
float * data = (float *) lctx.inp_KQ_mask_cross->data;
|
| 441 |
+
|
| 442 |
+
for (int h = 0; h < 1; ++h) {
|
| 443 |
+
for (int j = 0; j < n_tokens; ++j) {
|
| 444 |
+
for (int i = 0; i < n_output_enc; ++i) {
|
| 445 |
+
float f = -INFINITY;
|
| 446 |
+
for (int s = 0; s < ubatch.n_seq_id[j]; ++s) {
|
| 447 |
+
const llama_seq_id seq_id = ubatch.seq_id[j][s];
|
| 448 |
+
if (lctx.seq_ids_enc[i].find(seq_id) != lctx.seq_ids_enc[i].end()) {
|
| 449 |
+
f = 0.0f;
|
| 450 |
+
}
|
| 451 |
+
}
|
| 452 |
+
data[h*(n_output_enc*n_tokens) + j*n_output_enc + i] = f;
|
| 453 |
+
}
|
| 454 |
+
}
|
| 455 |
+
|
| 456 |
+
for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
|
| 457 |
+
for (int j = 0; j < n_output_enc; ++j) {
|
| 458 |
+
data[h*(n_output_enc*n_tokens) + i*n_output_enc + j] = -INFINITY;
|
| 459 |
+
}
|
| 460 |
+
}
|
| 461 |
+
}
|
| 462 |
+
}
|
| 463 |
+
}
|
| 464 |
+
|
| 465 |
+
// llama output
|
| 466 |
+
|
| 467 |
+
size_t llama_output_reserve(struct llama_context & lctx, size_t n_outputs) {
|
| 468 |
+
const auto & cparams = lctx.cparams;
|
| 469 |
+
const auto & hparams = lctx.model.hparams;
|
| 470 |
+
|
| 471 |
+
const size_t n_outputs_max = std::max(n_outputs, (size_t) cparams.n_seq_max);
|
| 472 |
+
|
| 473 |
+
const auto n_batch = cparams.n_batch;
|
| 474 |
+
const auto n_vocab = hparams.n_vocab;
|
| 475 |
+
const auto n_embd = hparams.n_embd;
|
| 476 |
+
|
| 477 |
+
// TODO: use a per-batch flag for logits presence instead
|
| 478 |
+
const bool has_logits = !cparams.embeddings;
|
| 479 |
+
const bool has_embd = cparams.embeddings && (cparams.pooling_type == LLAMA_POOLING_TYPE_NONE);
|
| 480 |
+
|
| 481 |
+
const size_t logits_size = has_logits ? n_vocab*n_outputs_max : 0;
|
| 482 |
+
const size_t embd_size = has_embd ? n_embd*n_outputs_max : 0;
|
| 483 |
+
|
| 484 |
+
if (lctx.output_ids.empty()) {
|
| 485 |
+
// init, never resized afterwards
|
| 486 |
+
lctx.output_ids.resize(n_batch);
|
| 487 |
+
}
|
| 488 |
+
|
| 489 |
+
const size_t prev_size = lctx.buf_output ? ggml_backend_buffer_get_size(lctx.buf_output.get()) : 0;
|
| 490 |
+
const size_t new_size = (logits_size + embd_size) * sizeof(float);
|
| 491 |
+
|
| 492 |
+
// alloc only when more than the current capacity is required
|
| 493 |
+
// TODO: also consider shrinking the buffer
|
| 494 |
+
if (!lctx.buf_output || prev_size < new_size) {
|
| 495 |
+
if (lctx.buf_output) {
|
| 496 |
+
#ifndef NDEBUG
|
| 497 |
+
// This doesn't happen often, but may be annoying in some cases (like the HellaSwag benchmark)
|
| 498 |
+
LLAMA_LOG_INFO("%s: reallocating output buffer from size %.02f MiB to %.02f MiB\n", __func__, prev_size / 1024.0 / 1024.0, new_size / 1024.0 / 1024.0);
|
| 499 |
+
#endif
|
| 500 |
+
lctx.buf_output = nullptr;
|
| 501 |
+
lctx.logits = nullptr;
|
| 502 |
+
lctx.embd = nullptr;
|
| 503 |
+
}
|
| 504 |
+
|
| 505 |
+
auto * buft = ggml_backend_cpu_buffer_type();
|
| 506 |
+
// try to use the host buffer of the device where the output tensor is allocated for faster transfer to system memory
|
| 507 |
+
auto * output_dev = lctx.model.dev_output.dev;
|
| 508 |
+
auto * output_dev_host_buft = output_dev ? ggml_backend_dev_host_buffer_type(output_dev) : nullptr;
|
| 509 |
+
if (output_dev_host_buft) {
|
| 510 |
+
buft = output_dev_host_buft;
|
| 511 |
+
}
|
| 512 |
+
lctx.buf_output.reset(ggml_backend_buft_alloc_buffer(buft, new_size));
|
| 513 |
+
if (lctx.buf_output == nullptr) {
|
| 514 |
+
LLAMA_LOG_ERROR("%s: failed to allocate output buffer of size %.2f MiB\n", __func__, new_size / (1024.0 * 1024.0));
|
| 515 |
+
return 0;
|
| 516 |
+
}
|
| 517 |
+
}
|
| 518 |
+
|
| 519 |
+
float * output_base = (float *) ggml_backend_buffer_get_base(lctx.buf_output.get());
|
| 520 |
+
|
| 521 |
+
lctx.logits = has_logits ? output_base : nullptr;
|
| 522 |
+
lctx.embd = has_embd ? output_base + logits_size : nullptr;
|
| 523 |
+
|
| 524 |
+
lctx.output_size = n_outputs_max;
|
| 525 |
+
lctx.logits_size = logits_size;
|
| 526 |
+
lctx.embd_size = embd_size;
|
| 527 |
+
|
| 528 |
+
// set all ids as invalid (negative)
|
| 529 |
+
std::fill(lctx.output_ids.begin(), lctx.output_ids.end(), -1);
|
| 530 |
+
|
| 531 |
+
ggml_backend_buffer_clear(lctx.buf_output.get(), 0);
|
| 532 |
+
|
| 533 |
+
lctx.n_outputs = 0;
|
| 534 |
+
|
| 535 |
+
return n_outputs_max;
|
| 536 |
+
}
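llama_output_reserve() above lays the logits block and the embeddings block out back to back in one flat float buffer and keeps room for at least n_seq_max output rows. As a rough worked example (assumed numbers, not taken from this diff): with a 32000-entry vocabulary, embeddings disabled and n_outputs_max = 512, the allocation is 32000 * 512 * 4 bytes = 62.5 MiB, all of it addressed through lctx.logits while lctx.embd stays null.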
|
| 537 |
+
|
| 538 |
+
void llama_output_reorder(struct llama_context & ctx) {
|
| 539 |
+
std::vector<size_t> & out_ids = ctx.sbatch.out_ids;
|
| 540 |
+
if (!out_ids.empty()) {
|
| 541 |
+
const uint32_t n_vocab = ctx.model.hparams.n_vocab;
|
| 542 |
+
const uint32_t n_embd = ctx.model.hparams.n_embd;
|
| 543 |
+
|
| 544 |
+
const int32_t n_outputs = ctx.n_outputs;
|
| 545 |
+
GGML_ASSERT((size_t) n_outputs == out_ids.size());
|
| 546 |
+
|
| 547 |
+
// TODO: is there something more efficient which also minimizes swaps?
|
| 548 |
+
// selection sort, to minimize swaps (from https://en.wikipedia.org/wiki/Selection_sort)
|
| 549 |
+
for (int32_t i = 0; i < n_outputs - 1; ++i) {
|
| 550 |
+
int32_t j_min = i;
|
| 551 |
+
for (int32_t j = i + 1; j < n_outputs; ++j) {
|
| 552 |
+
if (out_ids[j] < out_ids[j_min]) {
|
| 553 |
+
j_min = j;
|
| 554 |
+
}
|
| 555 |
+
}
|
| 556 |
+
if (j_min == i) { continue; }
|
| 557 |
+
std::swap(out_ids[i], out_ids[j_min]);
|
| 558 |
+
if (ctx.logits_size > 0) {
|
| 559 |
+
for (uint32_t k = 0; k < n_vocab; k++) {
|
| 560 |
+
std::swap(ctx.logits[i*n_vocab + k], ctx.logits[j_min*n_vocab + k]);
|
| 561 |
+
}
|
| 562 |
+
}
|
| 563 |
+
if (ctx.embd_size > 0) {
|
| 564 |
+
for (uint32_t k = 0; k < n_embd; k++) {
|
| 565 |
+
std::swap(ctx.embd[i*n_embd + k], ctx.embd[j_min*n_embd + k]);
|
| 566 |
+
}
|
| 567 |
+
}
|
| 568 |
+
}
|
| 569 |
+
std::fill(ctx.output_ids.begin(), ctx.output_ids.end(), -1);
|
| 570 |
+
for (int32_t i = 0; i < n_outputs; ++i) {
|
| 571 |
+
ctx.output_ids[out_ids[i]] = i;
|
| 572 |
+
}
|
| 573 |
+
out_ids.clear();
|
| 574 |
+
}
|
| 575 |
+
}
|
| 576 |
+
|
| 577 |
+
//
|
| 578 |
+
// interface implementation
|
| 579 |
+
//
|
| 580 |
+
|
| 581 |
+
void llama_free(struct llama_context * ctx) {
|
| 582 |
+
delete ctx;
|
| 583 |
+
}
|
| 584 |
+
|
| 585 |
+
uint32_t llama_n_ctx(const struct llama_context * ctx) {
|
| 586 |
+
return ctx->cparams.n_ctx;
|
| 587 |
+
}
|
| 588 |
+
|
| 589 |
+
uint32_t llama_n_batch(const struct llama_context * ctx) {
|
| 590 |
+
return ctx->cparams.n_batch;
|
| 591 |
+
}
|
| 592 |
+
|
| 593 |
+
uint32_t llama_n_ubatch(const struct llama_context * ctx) {
|
| 594 |
+
return ctx->cparams.n_ubatch;
|
| 595 |
+
}
|
| 596 |
+
|
| 597 |
+
uint32_t llama_n_seq_max(const struct llama_context * ctx) {
|
| 598 |
+
return ctx->kv_self.size;
|
| 599 |
+
}
|
| 600 |
+
|
| 601 |
+
const struct llama_model * llama_get_model(const struct llama_context * ctx) {
|
| 602 |
+
return &ctx->model;
|
| 603 |
+
}
|
| 604 |
+
|
| 605 |
+
enum llama_pooling_type llama_pooling_type(const struct llama_context * ctx) {
|
| 606 |
+
return ctx->cparams.pooling_type;
|
| 607 |
+
}
|
| 608 |
+
|
| 609 |
+
void llama_attach_threadpool(
|
| 610 |
+
struct llama_context * ctx,
|
| 611 |
+
ggml_threadpool_t threadpool,
|
| 612 |
+
ggml_threadpool_t threadpool_batch) {
|
| 613 |
+
ctx->threadpool = threadpool;
|
| 614 |
+
ctx->threadpool_batch = threadpool_batch ? threadpool_batch : threadpool;
|
| 615 |
+
}
|
| 616 |
+
|
| 617 |
+
void llama_detach_threadpool(struct llama_context * ctx) {
|
| 618 |
+
ctx->threadpool = nullptr;
|
| 619 |
+
ctx->threadpool_batch = nullptr;
|
| 620 |
+
}
|
| 621 |
+
|
| 622 |
+
void llama_set_n_threads(struct llama_context * ctx, int32_t n_threads, int32_t n_threads_batch) {
|
| 623 |
+
ctx->cparams.n_threads = n_threads;
|
| 624 |
+
ctx->cparams.n_threads_batch = n_threads_batch;
|
| 625 |
+
}
|
| 626 |
+
|
| 627 |
+
int32_t llama_n_threads(struct llama_context * ctx) {
|
| 628 |
+
return ctx->cparams.n_threads;
|
| 629 |
+
}
|
| 630 |
+
|
| 631 |
+
int32_t llama_n_threads_batch(struct llama_context * ctx) {
|
| 632 |
+
return ctx->cparams.n_threads_batch;
|
| 633 |
+
}
|
| 634 |
+
|
| 635 |
+
void llama_set_abort_callback(struct llama_context * ctx, bool (*abort_callback)(void * data), void * abort_callback_data) {
|
| 636 |
+
ctx->abort_callback = abort_callback;
|
| 637 |
+
ctx->abort_callback_data = abort_callback_data;
|
| 638 |
+
|
| 639 |
+
for (auto & backend : ctx->backends) {
|
| 640 |
+
auto * reg = ggml_backend_dev_backend_reg(ggml_backend_get_device(backend.get()));
|
| 641 |
+
auto * set_abort_callback_fn = (ggml_backend_set_abort_callback_t) ggml_backend_reg_get_proc_address(reg, "ggml_backend_set_abort_callback");
|
| 642 |
+
if (set_abort_callback_fn) {
|
| 643 |
+
set_abort_callback_fn(backend.get(), ctx->abort_callback, ctx->abort_callback_data);
|
| 644 |
+
}
|
| 645 |
+
}
|
| 646 |
+
}
|
| 647 |
+
|
| 648 |
+
void llama_set_embeddings(struct llama_context * ctx, bool embeddings) {
|
| 649 |
+
ctx->cparams.embeddings = embeddings;
|
| 650 |
+
}
|
| 651 |
+
|
| 652 |
+
void llama_set_causal_attn(struct llama_context * ctx, bool causal_attn) {
|
| 653 |
+
ctx->cparams.causal_attn = causal_attn;
|
| 654 |
+
}
|
| 655 |
+
|
| 656 |
+
void llama_synchronize(struct llama_context * ctx) {
|
| 657 |
+
ggml_backend_sched_synchronize(ctx->sched.get());
|
| 658 |
+
|
| 659 |
+
// FIXME: if multiple single tokens are evaluated without a synchronization,
|
| 660 |
+
// the stats will be added to the prompt evaluation stats
|
| 661 |
+
// this should only happen when using batch size 1 to evaluate a batch
|
| 662 |
+
|
| 663 |
+
// add the evaluation to the stats
|
| 664 |
+
if (ctx->n_queued_tokens == 1) {
|
| 665 |
+
if (!ctx->cparams.no_perf) {
|
| 666 |
+
ctx->t_eval_us += ggml_time_us() - ctx->t_compute_start_us;
|
| 667 |
+
}
|
| 668 |
+
ctx->n_eval++;
|
| 669 |
+
} else if (ctx->n_queued_tokens > 1) {
|
| 670 |
+
if (!ctx->cparams.no_perf) {
|
| 671 |
+
ctx->t_p_eval_us += ggml_time_us() - ctx->t_compute_start_us;
|
| 672 |
+
}
|
| 673 |
+
ctx->n_p_eval += ctx->n_queued_tokens;
|
| 674 |
+
}
|
| 675 |
+
|
| 676 |
+
// get a more accurate load time, upon first eval
|
| 677 |
+
if (ctx->n_queued_tokens > 0 && !ctx->has_evaluated_once) {
|
| 678 |
+
ctx->t_load_us = ggml_time_us() - ctx->t_start_us;
|
| 679 |
+
ctx->has_evaluated_once = true;
|
| 680 |
+
}
|
| 681 |
+
|
| 682 |
+
ctx->n_queued_tokens = 0;
|
| 683 |
+
ctx->t_compute_start_us = 0;
|
| 684 |
+
}
|
| 685 |
+
|
| 686 |
+
float * llama_get_logits(struct llama_context * ctx) {
|
| 687 |
+
llama_synchronize(ctx);
|
| 688 |
+
|
| 689 |
+
// reorder logits for backward compatibility
|
| 690 |
+
// TODO: maybe deprecate this
|
| 691 |
+
llama_output_reorder(*ctx);
|
| 692 |
+
|
| 693 |
+
return ctx->logits;
|
| 694 |
+
}
|
| 695 |
+
|
| 696 |
+
float * llama_get_logits_ith(struct llama_context * ctx, int32_t i) {
|
| 697 |
+
int32_t j = -1;
|
| 698 |
+
|
| 699 |
+
llama_synchronize(ctx);
|
| 700 |
+
|
| 701 |
+
try {
|
| 702 |
+
if (ctx->logits == nullptr) {
|
| 703 |
+
throw std::runtime_error("no logits");
|
| 704 |
+
}
|
| 705 |
+
|
| 706 |
+
if (i < 0) {
|
| 707 |
+
j = ctx->n_outputs + i;
|
| 708 |
+
if (j < 0) {
|
| 709 |
+
throw std::runtime_error(format("negative index out of range [0, %d)", ctx->n_outputs));
|
| 710 |
+
}
|
| 711 |
+
} else if ((size_t) i >= ctx->output_ids.size()) {
|
| 712 |
+
throw std::runtime_error(format("out of range [0, %zu)", ctx->output_ids.size()));
|
| 713 |
+
} else {
|
| 714 |
+
j = ctx->output_ids[i];
|
| 715 |
+
}
|
| 716 |
+
|
| 717 |
+
if (j < 0) {
|
| 718 |
+
throw std::runtime_error(format("batch.logits[%d] != true", i));
|
| 719 |
+
}
|
| 720 |
+
if (j >= ctx->n_outputs) {
|
| 721 |
+
// This should not happen
|
| 722 |
+
throw std::runtime_error(format("corrupt output buffer (j=%d, n_outputs=%d)", j, ctx->n_outputs));
|
| 723 |
+
}
|
| 724 |
+
|
| 725 |
+
return ctx->logits + j*ctx->model.hparams.n_vocab;
|
| 726 |
+
} catch (const std::exception & err) {
|
| 727 |
+
LLAMA_LOG_ERROR("%s: invalid logits id %d, reason: %s\n", __func__, i, err.what());
|
| 728 |
+
#ifndef NDEBUG
|
| 729 |
+
GGML_ABORT("fatal error");
|
| 730 |
+
#else
|
| 731 |
+
return nullptr;
|
| 732 |
+
#endif
|
| 733 |
+
}
|
| 734 |
+
}
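Because of the negative-index branch above, llama_get_logits_ith(ctx, -1) returns the logits row of the most recent output. A small sketch of the usual greedy-sampling pattern built on top of it (illustrative, not part of the diff; it assumes ctx already holds the result of a llama_decode() call and uses the public llama.h accessors):

    #include "llama.h"

    // Greedy pick of the next token from the last output row of a decoded batch.
    static llama_token greedy_last_token(struct llama_context * ctx) {
        const float * logits  = llama_get_logits_ith(ctx, -1); // -1 == last output
        const int32_t n_vocab = llama_n_vocab(llama_get_model(ctx));

        llama_token best = 0;
        for (llama_token t = 1; t < n_vocab; ++t) {
            if (logits[t] > logits[best]) {
                best = t;
            }
        }
        return best;
    }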
|
| 735 |
+
|
| 736 |
+
float * llama_get_embeddings(struct llama_context * ctx) {
|
| 737 |
+
llama_synchronize(ctx);
|
| 738 |
+
|
| 739 |
+
// reorder embeddings for backward compatibility
|
| 740 |
+
// TODO: maybe deprecate this
|
| 741 |
+
llama_output_reorder(*ctx);
|
| 742 |
+
|
| 743 |
+
return ctx->embd;
|
| 744 |
+
}
|
| 745 |
+
|
| 746 |
+
float * llama_get_embeddings_ith(struct llama_context * ctx, int32_t i) {
|
| 747 |
+
int32_t j = -1;
|
| 748 |
+
|
| 749 |
+
llama_synchronize(ctx);
|
| 750 |
+
|
| 751 |
+
try {
|
| 752 |
+
if (ctx->embd == nullptr) {
|
| 753 |
+
throw std::runtime_error("no embeddings");
|
| 754 |
+
}
|
| 755 |
+
|
| 756 |
+
if (i < 0) {
|
| 757 |
+
j = ctx->n_outputs + i;
|
| 758 |
+
if (j < 0) {
|
| 759 |
+
throw std::runtime_error(format("negative index out of range [0, %d)", ctx->n_outputs));
|
| 760 |
+
}
|
| 761 |
+
} else if ((size_t) i >= ctx->output_ids.size()) {
|
| 762 |
+
throw std::runtime_error(format("out of range [0, %zu)", ctx->output_ids.size()));
|
| 763 |
+
} else {
|
| 764 |
+
j = ctx->output_ids[i];
|
| 765 |
+
}
|
| 766 |
+
|
| 767 |
+
if (j < 0) {
|
| 768 |
+
throw std::runtime_error(format("batch.logits[%d] != true", i));
|
| 769 |
+
}
|
| 770 |
+
if (j >= ctx->n_outputs) {
|
| 771 |
+
// This should not happen
|
| 772 |
+
throw std::runtime_error(format("corrupt output buffer (j=%d, n_outputs=%d)", j, ctx->n_outputs));
|
| 773 |
+
}
|
| 774 |
+
|
| 775 |
+
return ctx->embd + j*ctx->model.hparams.n_embd;
|
| 776 |
+
} catch (const std::exception & err) {
|
| 777 |
+
LLAMA_LOG_ERROR("%s: invalid embeddings id %d, reason: %s\n", __func__, i, err.what());
|
| 778 |
+
#ifndef NDEBUG
|
| 779 |
+
GGML_ABORT("fatal error");
|
| 780 |
+
#else
|
| 781 |
+
return nullptr;
|
| 782 |
+
#endif
|
| 783 |
+
}
|
| 784 |
+
}
|
| 785 |
+
|
| 786 |
+
float * llama_get_embeddings_seq(struct llama_context * ctx, llama_seq_id seq_id) {
|
| 787 |
+
llama_synchronize(ctx);
|
| 788 |
+
|
| 789 |
+
auto it = ctx->embd_seq.find(seq_id);
|
| 790 |
+
if (it == ctx->embd_seq.end()) {
|
| 791 |
+
return nullptr;
|
| 792 |
+
}
|
| 793 |
+
|
| 794 |
+
return it->second.data();
|
| 795 |
+
}
|
| 796 |
+
|
| 797 |
+
// llama state API
|
| 798 |
+
|
| 799 |
+
// deprecated
|
| 800 |
+
size_t llama_get_state_size(struct llama_context * ctx) {
|
| 801 |
+
return llama_state_get_size(ctx);
|
| 802 |
+
}
|
| 803 |
+
|
| 804 |
+
// deprecated
|
| 805 |
+
size_t llama_copy_state_data(struct llama_context * ctx, uint8_t * dst) {
|
| 806 |
+
return llama_state_get_data(ctx, dst, -1);
|
| 807 |
+
}
|
| 808 |
+
|
| 809 |
+
// deprecated
|
| 810 |
+
size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src) {
|
| 811 |
+
return llama_state_set_data(ctx, src, -1);
|
| 812 |
+
}
|
| 813 |
+
|
| 814 |
+
// deprecated
|
| 815 |
+
bool llama_load_session_file(struct llama_context * ctx, const char * path_session, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) {
|
| 816 |
+
return llama_state_load_file(ctx, path_session, tokens_out, n_token_capacity, n_token_count_out);
|
| 817 |
+
}
|
| 818 |
+
|
| 819 |
+
// deprecated
|
| 820 |
+
bool llama_save_session_file(struct llama_context * ctx, const char * path_session, const llama_token * tokens, size_t n_token_count) {
|
| 821 |
+
return llama_state_save_file(ctx, path_session, tokens, n_token_count);
|
| 822 |
+
}
|
| 823 |
+
|
| 824 |
+
// TODO: replace all non-fatal assertions with returned errors or exceptions
|
| 825 |
+
struct llama_data_write {
|
| 826 |
+
virtual void write(const void * src, size_t size) = 0;
|
| 827 |
+
virtual void write_tensor_data(const struct ggml_tensor * tensor, size_t offset, size_t size) = 0;
|
| 828 |
+
virtual size_t get_size_written() = 0;
|
| 829 |
+
virtual ~llama_data_write() = default;
|
| 830 |
+
|
| 831 |
+
void write_string(const std::string & str) {
|
| 832 |
+
uint32_t str_size = str.size();
|
| 833 |
+
|
| 834 |
+
write(&str_size, sizeof(str_size));
|
| 835 |
+
write(str.data(), str_size);
|
| 836 |
+
}
|
| 837 |
+
|
| 838 |
+
void write_model_info(const struct llama_context * ctx) {
|
| 839 |
+
const std::string arch_str = llm_arch_name(ctx->model.arch);
|
| 840 |
+
write_string(arch_str);
|
| 841 |
+
// TODO: add more model-specific info which should prevent loading the session file if not identical
|
| 842 |
+
}
|
| 843 |
+
|
| 844 |
+
//void write_rng(const std::mt19937 & rng) {
|
| 845 |
+
// std::ostringstream rng_ss;
|
| 846 |
+
// rng_ss << rng;
|
| 847 |
+
|
| 848 |
+
// const std::string & rng_str = rng_ss.str();
|
| 849 |
+
|
| 850 |
+
// write_string(rng_str);
|
| 851 |
+
//}
|
| 852 |
+
|
| 853 |
+
void write_output_ids(struct llama_context * ctx) {
|
| 854 |
+
llama_output_reorder(*ctx);
|
| 855 |
+
|
| 856 |
+
const uint32_t n_outputs = ctx->n_outputs;
|
| 857 |
+
|
| 858 |
+
std::vector<int32_t> output_pos;
|
| 859 |
+
|
| 860 |
+
const size_t n_batch = ctx->cparams.n_batch;
|
| 861 |
+
const auto & output_ids = ctx->output_ids;
|
| 862 |
+
|
| 863 |
+
GGML_ASSERT(n_outputs <= ctx->output_size);
|
| 864 |
+
|
| 865 |
+
output_pos.resize(n_outputs);
|
| 866 |
+
|
| 867 |
+
// build a more compact representation of the output ids
|
| 868 |
+
for (size_t i = 0; i < n_batch; ++i) {
|
| 869 |
+
// map an output id to a position in the batch
|
| 870 |
+
int32_t pos = output_ids[i];
|
| 871 |
+
if (pos >= 0) {
|
| 872 |
+
GGML_ASSERT((uint32_t) pos < n_outputs);
|
| 873 |
+
output_pos[pos] = i;
|
| 874 |
+
}
|
| 875 |
+
}
|
| 876 |
+
|
| 877 |
+
write(&n_outputs, sizeof(n_outputs));
|
| 878 |
+
|
| 879 |
+
if (n_outputs) {
|
| 880 |
+
write(output_pos.data(), n_outputs * sizeof(int32_t));
|
| 881 |
+
}
|
| 882 |
+
}
|
| 883 |
+
|
| 884 |
+
void write_logits(const struct llama_context * ctx) {
|
| 885 |
+
const uint64_t logits_size = std::min((uint64_t) ctx->logits_size, (uint64_t) ctx->n_outputs * ctx->model.hparams.n_vocab);
|
| 886 |
+
|
| 887 |
+
write(&logits_size, sizeof(logits_size));
|
| 888 |
+
|
| 889 |
+
if (logits_size) {
|
| 890 |
+
write(ctx->logits, logits_size * sizeof(float));
|
| 891 |
+
}
|
| 892 |
+
}
|
| 893 |
+
|
| 894 |
+
void write_embeddings(const struct llama_context * ctx) {
|
| 895 |
+
const uint64_t embeddings_size = std::min((uint64_t) ctx->embd_size, (uint64_t) ctx->n_outputs * ctx->model.hparams.n_embd);
|
| 896 |
+
|
| 897 |
+
write(&embeddings_size, sizeof(embeddings_size));
|
| 898 |
+
|
| 899 |
+
if (embeddings_size) {
|
| 900 |
+
write(ctx->embd, embeddings_size * sizeof(float));
|
| 901 |
+
}
|
| 902 |
+
}
|
| 903 |
+
|
| 904 |
+
void write_kv_cache_meta(const llama_kv_cache & kv_self, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id = -1) {
|
| 905 |
+
for (const auto & range : cell_ranges) {
|
| 906 |
+
for (uint32_t i = range.first; i < range.second; ++i) {
|
| 907 |
+
const auto & cell = kv_self.cells[i];
|
| 908 |
+
const llama_pos pos = cell.pos;
|
| 909 |
+
const uint32_t n_seq_id = seq_id == -1 ? cell.seq_id.size() : 0;
|
| 910 |
+
|
| 911 |
+
write(&pos, sizeof(pos));
|
| 912 |
+
write(&n_seq_id, sizeof(n_seq_id));
|
| 913 |
+
|
| 914 |
+
if (n_seq_id) {
|
| 915 |
+
for (auto seq_id : cell.seq_id) {
|
| 916 |
+
write(&seq_id, sizeof(seq_id));
|
| 917 |
+
}
|
| 918 |
+
}
|
| 919 |
+
}
|
| 920 |
+
}
|
| 921 |
+
}
|
| 922 |
+
|
| 923 |
+
void write_kv_cache_data(const struct llama_context * ctx, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges) {
|
| 924 |
+
const struct llama_kv_cache & kv_self = ctx->kv_self;
|
| 925 |
+
const struct llama_hparams & hparams = ctx->model.hparams;
|
| 926 |
+
|
| 927 |
+
const uint32_t v_trans = kv_self.v_trans ? 1 : 0;
|
| 928 |
+
const uint32_t n_layer = hparams.n_layer;
|
| 929 |
+
|
| 930 |
+
write(&v_trans, sizeof(v_trans));
|
| 931 |
+
write(&n_layer, sizeof(n_layer));
|
| 932 |
+
|
| 933 |
+
std::vector<uint8_t> tmp_buf;
|
| 934 |
+
|
| 935 |
+
// Iterate and write all the keys first, each row is a cell
|
| 936 |
+
// Get whole range at a time
|
| 937 |
+
for (uint32_t il = 0; il < n_layer; ++il) {
|
| 938 |
+
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
|
| 939 |
+
|
| 940 |
+
// Write key type
|
| 941 |
+
const int32_t k_type_i = (int32_t)kv_self.k_l[il]->type;
|
| 942 |
+
write(&k_type_i, sizeof(k_type_i));
|
| 943 |
+
|
| 944 |
+
// Write row size of key
|
| 945 |
+
const uint64_t k_size_row = ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa);
|
| 946 |
+
write(&k_size_row, sizeof(k_size_row));
|
| 947 |
+
|
| 948 |
+
// Read each range of cells of k_size length each into tmp_buf and write out
|
| 949 |
+
for (const auto & range : cell_ranges) {
|
| 950 |
+
const size_t range_size = range.second - range.first;
|
| 951 |
+
const size_t buf_size = range_size * k_size_row;
|
| 952 |
+
write_tensor_data(kv_self.k_l[il], range.first * k_size_row, buf_size);
|
| 953 |
+
}
|
| 954 |
+
}
|
| 955 |
+
|
| 956 |
+
if (!kv_self.v_trans) {
|
| 957 |
+
for (uint32_t il = 0; il < n_layer; ++il) {
|
| 958 |
+
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
|
| 959 |
+
|
| 960 |
+
// Write value type
|
| 961 |
+
const int32_t v_type_i = (int32_t)kv_self.v_l[il]->type;
|
| 962 |
+
write(&v_type_i, sizeof(v_type_i));
|
| 963 |
+
|
| 964 |
+
// Write row size of value
|
| 965 |
+
const uint64_t v_size_row = ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa);
|
| 966 |
+
write(&v_size_row, sizeof(v_size_row));
|
| 967 |
+
|
| 968 |
+
// Read each range of cells of v_size length each into tmp_buf and write out
|
| 969 |
+
for (const auto & range : cell_ranges) {
|
| 970 |
+
const size_t range_size = range.second - range.first;
|
| 971 |
+
const size_t buf_size = range_size * v_size_row;
|
| 972 |
+
write_tensor_data(kv_self.v_l[il], range.first * v_size_row, buf_size);
|
| 973 |
+
}
|
| 974 |
+
}
|
| 975 |
+
} else {
|
| 976 |
+
// When v is transposed, we also need the element size and get the element ranges from each row
|
| 977 |
+
const uint32_t kv_size = kv_self.size;
|
| 978 |
+
for (uint32_t il = 0; il < n_layer; ++il) {
|
| 979 |
+
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
|
| 980 |
+
|
| 981 |
+
// Write value type
|
| 982 |
+
const int32_t v_type_i = (int32_t)kv_self.v_l[il]->type;
|
| 983 |
+
write(&v_type_i, sizeof(v_type_i));
|
| 984 |
+
|
| 985 |
+
// Write element size
|
| 986 |
+
const uint32_t v_size_el = ggml_type_size(kv_self.v_l[il]->type);
|
| 987 |
+
write(&v_size_el, sizeof(v_size_el));
|
| 988 |
+
|
| 989 |
+
// Write GQA embedding size
|
| 990 |
+
write(&n_embd_v_gqa, sizeof(n_embd_v_gqa));
|
| 991 |
+
|
| 992 |
+
// For each row, we get the element values of each cell
|
| 993 |
+
for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
|
| 994 |
+
// Read each range of cells of v_size_el length each into tmp_buf and write out
|
| 995 |
+
for (const auto & range : cell_ranges) {
|
| 996 |
+
const size_t range_size = range.second - range.first;
|
| 997 |
+
const size_t src_offset = (range.first + j * kv_size) * v_size_el;
|
| 998 |
+
const size_t buf_size = range_size * v_size_el;
|
| 999 |
+
write_tensor_data(kv_self.v_l[il], src_offset, buf_size);
|
| 1000 |
+
}
|
| 1001 |
+
}
|
| 1002 |
+
}
|
| 1003 |
+
}
|
| 1004 |
+
}
|
| 1005 |
+
|
| 1006 |
+
void write_kv_cache(const struct llama_context * ctx, llama_seq_id seq_id = -1) {
|
| 1007 |
+
const struct llama_kv_cache & kv_self = ctx->kv_self;
|
| 1008 |
+
std::vector<std::pair<uint32_t, uint32_t>> cell_ranges; // ranges, from inclusive, to exclusive
|
| 1009 |
+
uint32_t cell_count = 0;
|
| 1010 |
+
|
| 1011 |
+
// Count the number of cells with the specified seq_id
|
| 1012 |
+
// Find all the ranges of cells with this seq id (or all, when -1)
|
| 1013 |
+
uint32_t cell_range_begin = kv_self.size;
|
| 1014 |
+
for (uint32_t i = 0; i < kv_self.size; ++i) {
|
| 1015 |
+
const auto & cell = kv_self.cells[i];
|
| 1016 |
+
if ((seq_id == -1 && !cell.is_empty()) || cell.has_seq_id(seq_id)) {
|
| 1017 |
+
++cell_count;
|
| 1018 |
+
if (cell_range_begin == kv_self.size) {
|
| 1019 |
+
cell_range_begin = i;
|
| 1020 |
+
}
|
| 1021 |
+
} else {
|
| 1022 |
+
if (cell_range_begin != kv_self.size) {
|
| 1023 |
+
cell_ranges.emplace_back(cell_range_begin, i);
|
| 1024 |
+
cell_range_begin = kv_self.size;
|
| 1025 |
+
}
|
| 1026 |
+
}
|
| 1027 |
+
}
|
| 1028 |
+
if (cell_range_begin != kv_self.size) {
|
| 1029 |
+
cell_ranges.emplace_back(cell_range_begin, kv_self.size);
|
| 1030 |
+
}
|
| 1031 |
+
|
| 1032 |
+
// DEBUG CHECK: Sum of cell counts in ranges should equal the total cell count
|
| 1033 |
+
uint32_t cell_count_check = 0;
|
| 1034 |
+
for (const auto & range : cell_ranges) {
|
| 1035 |
+
cell_count_check += range.second - range.first;
|
| 1036 |
+
}
|
| 1037 |
+
GGML_ASSERT(cell_count == cell_count_check);
|
| 1038 |
+
|
| 1039 |
+
write(&cell_count, sizeof(cell_count));
|
| 1040 |
+
|
| 1041 |
+
write_kv_cache_meta(kv_self, cell_ranges, seq_id);
|
| 1042 |
+
write_kv_cache_data(ctx, cell_ranges);
|
| 1043 |
+
}
|
| 1044 |
+
};
|
| 1045 |
+
|
| 1046 |
+
struct llama_data_read {
|
| 1047 |
+
virtual const uint8_t * read(size_t size) = 0;
|
| 1048 |
+
virtual void read_to(void * dst, size_t size) = 0;
|
| 1049 |
+
virtual size_t get_size_read() = 0;
|
| 1050 |
+
virtual ~llama_data_read() = default;
|
| 1051 |
+
|
| 1052 |
+
void read_string(std::string & str) {
|
| 1053 |
+
uint32_t str_size;
|
| 1054 |
+
read_to(&str_size, sizeof(str_size));
|
| 1055 |
+
|
| 1056 |
+
str.assign((const char *) read(str_size), str_size);
|
| 1057 |
+
}
|
| 1058 |
+
|
| 1059 |
+
// validate model information
|
| 1060 |
+
void read_model_info(const struct llama_context * ctx) {
|
| 1061 |
+
const std::string cur_arch_str = llm_arch_name(ctx->model.arch);
|
| 1062 |
+
|
| 1063 |
+
std::string arch_str;
|
| 1064 |
+
read_string(arch_str);
|
| 1065 |
+
if (cur_arch_str != arch_str) {
|
| 1066 |
+
throw std::runtime_error(format("wrong model arch: '%s' instead of '%s'", arch_str.c_str(), cur_arch_str.c_str()));
|
| 1067 |
+
}
|
| 1068 |
+
// TODO: add more info which needs to be identical but which is not verified otherwise
|
| 1069 |
+
}
|
| 1070 |
+
|
| 1071 |
+
//void read_rng(std::mt19937 & rng) {
|
| 1072 |
+
// std::string rng_str;
|
| 1073 |
+
// read_string(rng_str);
|
| 1074 |
+
|
| 1075 |
+
// std::istringstream rng_ss(rng_str);
|
| 1076 |
+
// rng_ss >> rng;
|
| 1077 |
+
|
| 1078 |
+
// if (rng_ss.fail()) {
|
| 1079 |
+
// throw std::runtime_error("failed to load RNG state");
|
| 1080 |
+
// }
|
| 1081 |
+
//}
|
| 1082 |
+
|
| 1083 |
+
void read_output_ids(struct llama_context * ctx) {
|
| 1084 |
+
std::vector<int32_t> output_pos;
|
| 1085 |
+
|
| 1086 |
+
uint32_t n_outputs;
|
| 1087 |
+
read_to(&n_outputs, sizeof(n_outputs));
|
| 1088 |
+
|
| 1089 |
+
if (n_outputs > llama_output_reserve(*ctx, n_outputs)) {
|
| 1090 |
+
throw std::runtime_error("could not reserve outputs");
|
| 1091 |
+
}
|
| 1092 |
+
|
| 1093 |
+
if (n_outputs) {
|
| 1094 |
+
output_pos.resize(n_outputs);
|
| 1095 |
+
read_to(output_pos.data(), n_outputs * sizeof(int32_t));
|
| 1096 |
+
|
| 1097 |
+
for (int32_t i = 0; i < (int32_t) output_pos.size(); ++i) {
|
| 1098 |
+
int32_t id = output_pos[i];
|
| 1099 |
+
if ((uint32_t) id >= ctx->cparams.n_batch) {
|
| 1100 |
+
throw std::runtime_error(format("invalid output id, %d does not fit in batch size of %u", id, ctx->cparams.n_batch));
|
| 1101 |
+
}
|
| 1102 |
+
ctx->output_ids[id] = i;
|
| 1103 |
+
}
|
| 1104 |
+
|
| 1105 |
+
ctx->n_outputs = n_outputs;
|
| 1106 |
+
}
|
| 1107 |
+
}
|
| 1108 |
+
|
| 1109 |
+
void read_logits(struct llama_context * ctx) {
|
| 1110 |
+
uint64_t logits_size;
|
| 1111 |
+
read_to(&logits_size, sizeof(logits_size));
|
| 1112 |
+
|
| 1113 |
+
if (ctx->logits_size < logits_size) {
|
| 1114 |
+
throw std::runtime_error("logits buffer too small");
|
| 1115 |
+
}
|
| 1116 |
+
|
| 1117 |
+
if (logits_size) {
|
| 1118 |
+
read_to(ctx->logits, logits_size * sizeof(float));
|
| 1119 |
+
}
|
| 1120 |
+
}
|
| 1121 |
+
|
| 1122 |
+
void read_embeddings(struct llama_context * ctx) {
|
| 1123 |
+
uint64_t embeddings_size;
|
| 1124 |
+
read_to(&embeddings_size, sizeof(embeddings_size));
|
| 1125 |
+
|
| 1126 |
+
if (ctx->embd_size < embeddings_size) {
|
| 1127 |
+
throw std::runtime_error("embeddings buffer too small");
|
| 1128 |
+
}
|
| 1129 |
+
|
| 1130 |
+
if (embeddings_size) {
|
| 1131 |
+
read_to(ctx->embd, embeddings_size * sizeof(float));
|
| 1132 |
+
}
|
| 1133 |
+
}
|
| 1134 |
+
|
| 1135 |
+
bool read_kv_cache_meta(struct llama_context * ctx, uint32_t cell_count, llama_seq_id dest_seq_id = -1) {
|
| 1136 |
+
struct llama_kv_cache & kv_self = ctx->kv_self;
|
| 1137 |
+
|
| 1138 |
+
if (dest_seq_id != -1) {
|
| 1139 |
+
// single sequence
|
| 1140 |
+
|
| 1141 |
+
llama_kv_cache_seq_rm(kv_self, dest_seq_id, -1, -1);
|
| 1142 |
+
|
| 1143 |
+
llama_ubatch batch = ctx->sbatch.reserve_ubatch(cell_count, /* has_embd */ false);
|
| 1144 |
+
batch.n_tokens = cell_count;
|
| 1145 |
+
batch.n_seq_tokens = cell_count;
|
| 1146 |
+
batch.n_seqs = 1;
|
| 1147 |
+
|
| 1148 |
+
for (uint32_t i = 0; i < cell_count; ++i) {
|
| 1149 |
+
llama_pos pos;
|
| 1150 |
+
uint32_t n_seq_id;
|
| 1151 |
+
|
| 1152 |
+
read_to(&pos, sizeof(pos));
|
| 1153 |
+
read_to(&n_seq_id, sizeof(n_seq_id));
|
| 1154 |
+
|
| 1155 |
+
if (n_seq_id != 0) {
|
| 1156 |
+
LLAMA_LOG_ERROR("%s: invalid seq_id-agnostic kv cell\n", __func__);
|
| 1157 |
+
return false;
|
| 1158 |
+
}
|
| 1159 |
+
|
| 1160 |
+
batch.pos[i] = pos;
|
| 1161 |
+
}
|
| 1162 |
+
batch.n_seq_id[0] = 1;
|
| 1163 |
+
batch.seq_id[0] = &dest_seq_id;
|
| 1164 |
+
if (!llama_kv_cache_find_slot(kv_self, batch)) {
|
| 1165 |
+
LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__);
|
| 1166 |
+
return false;
|
| 1167 |
+
}
|
| 1168 |
+
|
| 1169 |
+
// DEBUG CHECK: kv_self.head should be our first cell, kv_self.head + cell_count - 1 should be our last cell (verify seq_id and pos values)
|
| 1170 |
+
// Assume that this is one contiguous block of cells
|
| 1171 |
+
GGML_ASSERT(kv_self.head + cell_count <= kv_self.size);
|
| 1172 |
+
GGML_ASSERT(kv_self.cells[kv_self.head].pos == batch.pos[0]);
|
| 1173 |
+
GGML_ASSERT(kv_self.cells[kv_self.head + cell_count - 1].pos == batch.pos[cell_count - 1]);
|
| 1174 |
+
GGML_ASSERT(kv_self.cells[kv_self.head].has_seq_id(dest_seq_id));
|
| 1175 |
+
GGML_ASSERT(kv_self.cells[kv_self.head + cell_count - 1].has_seq_id(dest_seq_id));
|
| 1176 |
+
} else {
|
| 1177 |
+
// whole KV cache restore
|
| 1178 |
+
|
| 1179 |
+
if (cell_count > kv_self.size) {
|
| 1180 |
+
LLAMA_LOG_ERROR("%s: not enough cells in kv cache\n", __func__);
|
| 1181 |
+
return false;
|
| 1182 |
+
}
|
| 1183 |
+
|
| 1184 |
+
llama_kv_cache_clear(kv_self);
|
| 1185 |
+
|
| 1186 |
+
for (uint32_t i = 0; i < cell_count; ++i) {
|
| 1187 |
+
llama_kv_cell & cell = kv_self.cells[i];
|
| 1188 |
+
|
| 1189 |
+
llama_pos pos;
|
| 1190 |
+
uint32_t n_seq_id;
|
| 1191 |
+
|
| 1192 |
+
read_to(&pos, sizeof(pos));
|
| 1193 |
+
read_to(&n_seq_id, sizeof(n_seq_id));
|
| 1194 |
+
|
| 1195 |
+
cell.pos = pos;
|
| 1196 |
+
|
| 1197 |
+
for (uint32_t j = 0; j < n_seq_id; ++j) {
|
| 1198 |
+
llama_seq_id seq_id;
|
| 1199 |
+
read_to(&seq_id, sizeof(seq_id));
|
| 1200 |
+
|
| 1201 |
+
if (seq_id < 0 || (uint32_t) seq_id >= llama_n_seq_max(ctx)) {
|
| 1202 |
+
LLAMA_LOG_ERROR("%s: invalid seq_id, %d is out of range [0, %u)\n", __func__, seq_id, llama_n_seq_max(ctx));
|
| 1203 |
+
return false;
|
| 1204 |
+
}
|
| 1205 |
+
|
| 1206 |
+
cell.seq_id.insert(seq_id);
|
| 1207 |
+
|
| 1208 |
+
if (kv_self.recurrent) {
|
| 1209 |
+
int32_t & tail = kv_self.cells[seq_id].tail;
|
| 1210 |
+
if (tail != -1) {
|
| 1211 |
+
LLAMA_LOG_ERROR("%s: duplicate tail for seq_id %d in cell %d and %d\n", __func__, seq_id, i, tail);
|
| 1212 |
+
return false;
|
| 1213 |
+
}
|
| 1214 |
+
tail = i;
|
| 1215 |
+
}
|
| 1216 |
+
}
|
| 1217 |
+
}
|
| 1218 |
+
|
| 1219 |
+
kv_self.head = 0;
|
| 1220 |
+
kv_self.used = cell_count;
|
| 1221 |
+
}
|
| 1222 |
+
|
| 1223 |
+
if (kv_self.recurrent) {
|
| 1224 |
+
for (uint32_t i = 0; i < cell_count; ++i) {
|
| 1225 |
+
uint32_t cell_id = kv_self.head + i;
|
| 1226 |
+
// make sure the recurrent states will keep their restored state
|
| 1227 |
+
kv_self.cells[cell_id].src = cell_id;
|
| 1228 |
+
}
|
| 1229 |
+
}
|
| 1230 |
+
|
| 1231 |
+
return true;
|
| 1232 |
+
}
|
| 1233 |
+
|
| 1234 |
+
bool read_kv_cache_data(struct llama_context * ctx, uint32_t cell_count) {
|
| 1235 |
+
const struct llama_hparams & hparams = ctx->model.hparams;
|
| 1236 |
+
struct llama_kv_cache & kv_self = ctx->kv_self;
|
| 1237 |
+
uint32_t v_trans;
|
| 1238 |
+
uint32_t n_layer;
|
| 1239 |
+
read_to(&v_trans, sizeof(v_trans));
|
| 1240 |
+
read_to(&n_layer, sizeof(n_layer));
|
| 1241 |
+
|
| 1242 |
+
if (n_layer != hparams.n_layer) {
|
| 1243 |
+
LLAMA_LOG_ERROR("%s: mismatched layer count (%u instead of %u)\n", __func__, n_layer, hparams.n_layer);
|
| 1244 |
+
return false;
|
| 1245 |
+
}
|
| 1246 |
+
if (cell_count > kv_self.size) {
|
| 1247 |
+
LLAMA_LOG_ERROR("%s: not enough cells in kv cache to restore state (%u > %u)\n", __func__, cell_count, kv_self.size);
|
| 1248 |
+
return false;
|
| 1249 |
+
}
|
| 1250 |
+
if (kv_self.v_trans != (bool) v_trans) {
|
| 1251 |
+
LLAMA_LOG_ERROR("%s: incompatible V transposition\n", __func__);
|
| 1252 |
+
return false;
|
| 1253 |
+
}
|
| 1254 |
+
|
| 1255 |
+
// For each layer, read the keys for each cell, one row is one cell, read as one contiguous block
|
| 1256 |
+
for (uint32_t il = 0; il < n_layer; ++il) {
|
| 1257 |
+
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
|
| 1258 |
+
|
| 1259 |
+
// Read type of key
|
| 1260 |
+
int32_t k_type_i_ref;
|
| 1261 |
+
read_to(&k_type_i_ref, sizeof(k_type_i_ref));
|
| 1262 |
+
const int32_t k_type_i = (int32_t)kv_self.k_l[il]->type;
|
| 1263 |
+
if (k_type_i != k_type_i_ref) {
|
| 1264 |
+
LLAMA_LOG_ERROR("%s: mismatched key type (%d != %d, layer %d)\n", __func__, k_type_i, k_type_i_ref, il);
|
| 1265 |
+
return false;
|
| 1266 |
+
}
|
| 1267 |
+
|
| 1268 |
+
// Read row size of key
|
| 1269 |
+
uint64_t k_size_row_ref;
|
| 1270 |
+
read_to(&k_size_row_ref, sizeof(k_size_row_ref));
|
| 1271 |
+
const size_t k_size_row = ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa);
|
| 1272 |
+
if (k_size_row != k_size_row_ref) {
|
| 1273 |
+
LLAMA_LOG_ERROR("%s: mismatched key row size (%zu != %zu, layer %d)\n", __func__, k_size_row, (size_t) k_size_row_ref, il);
|
| 1274 |
+
return false;
|
| 1275 |
+
}
|
| 1276 |
+
|
| 1277 |
+
if (cell_count) {
|
| 1278 |
+
// Read and set the keys for the whole cell range
|
| 1279 |
+
ggml_backend_tensor_set(kv_self.k_l[il], read(cell_count * k_size_row), kv_self.head * k_size_row, cell_count * k_size_row);
|
| 1280 |
+
}
|
| 1281 |
+
}
|
| 1282 |
+
|
| 1283 |
+
if (!kv_self.v_trans) {
|
| 1284 |
+
for (uint32_t il = 0; il < n_layer; ++il) {
|
| 1285 |
+
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
|
| 1286 |
+
|
| 1287 |
+
// Read type of value
|
| 1288 |
+
int32_t v_type_i_ref;
|
| 1289 |
+
read_to(&v_type_i_ref, sizeof(v_type_i_ref));
|
| 1290 |
+
const int32_t v_type_i = (int32_t)kv_self.v_l[il]->type;
|
| 1291 |
+
if (v_type_i != v_type_i_ref) {
|
| 1292 |
+
LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il);
|
| 1293 |
+
return false;
|
| 1294 |
+
}
|
| 1295 |
+
|
| 1296 |
+
// Read row size of value
|
| 1297 |
+
uint64_t v_size_row_ref;
|
| 1298 |
+
read_to(&v_size_row_ref, sizeof(v_size_row_ref));
|
| 1299 |
+
const size_t v_size_row = ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa);
|
| 1300 |
+
if (v_size_row != v_size_row_ref) {
|
| 1301 |
+
LLAMA_LOG_ERROR("%s: mismatched value row size (%zu != %zu, layer %d)\n", __func__, v_size_row, (size_t) v_size_row_ref, il);
|
| 1302 |
+
return false;
|
| 1303 |
+
}
|
| 1304 |
+
|
| 1305 |
+
if (cell_count) {
|
| 1306 |
+
// Read and set the values for the whole cell range
|
| 1307 |
+
ggml_backend_tensor_set(kv_self.v_l[il], read(cell_count * v_size_row), kv_self.head * v_size_row, cell_count * v_size_row);
|
| 1308 |
+
}
|
| 1309 |
+
}
|
| 1310 |
+
} else {
|
| 1311 |
+
// For each layer, read the values for each cell (transposed)
|
| 1312 |
+
for (uint32_t il = 0; il < n_layer; ++il) {
|
| 1313 |
+
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
|
| 1314 |
+
|
| 1315 |
+
// Read type of value
|
| 1316 |
+
int32_t v_type_i_ref;
|
| 1317 |
+
read_to(&v_type_i_ref, sizeof(v_type_i_ref));
|
| 1318 |
+
const int32_t v_type_i = (int32_t)kv_self.v_l[il]->type;
|
| 1319 |
+
if (v_type_i != v_type_i_ref) {
|
| 1320 |
+
LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il);
|
| 1321 |
+
return false;
|
| 1322 |
+
}
|
| 1323 |
+
|
| 1324 |
+
// Read element size of value
|
| 1325 |
+
uint32_t v_size_el_ref;
|
| 1326 |
+
read_to(&v_size_el_ref, sizeof(v_size_el_ref));
|
| 1327 |
+
const size_t v_size_el = ggml_type_size(kv_self.v_l[il]->type);
|
| 1328 |
+
if (v_size_el != v_size_el_ref) {
|
| 1329 |
+
LLAMA_LOG_ERROR("%s: mismatched value element size (%zu != %zu, layer %d)\n", __func__, v_size_el, (size_t) v_size_el_ref, il);
|
| 1330 |
+
return false;
|
| 1331 |
+
}
|
| 1332 |
+
|
| 1333 |
+
// Read GQA embedding size
|
| 1334 |
+
uint32_t n_embd_v_gqa_ref;
|
| 1335 |
+
read_to(&n_embd_v_gqa_ref, sizeof(n_embd_v_gqa_ref));
|
| 1336 |
+
if (n_embd_v_gqa != n_embd_v_gqa_ref) {
|
| 1337 |
+
LLAMA_LOG_ERROR("%s: mismatched GQA embedding size (%u != %u, layer %d)\n", __func__, n_embd_v_gqa, n_embd_v_gqa_ref, il);
|
| 1338 |
+
return false;
|
| 1339 |
+
}
|
| 1340 |
+
|
| 1341 |
+
if (cell_count) {
|
| 1342 |
+
// For each row in the transposed matrix, read the values for the whole cell range
|
| 1343 |
+
for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
|
| 1344 |
+
const size_t dst_offset = (kv_self.head + j * kv_self.size) * v_size_el;
|
| 1345 |
+
ggml_backend_tensor_set(kv_self.v_l[il], read(cell_count * v_size_el), dst_offset, cell_count * v_size_el);
|
| 1346 |
+
}
|
| 1347 |
+
}
|
| 1348 |
+
}
|
| 1349 |
+
}
|
| 1350 |
+
return true;
|
| 1351 |
+
}
|
| 1352 |
+
|
| 1353 |
+
void read_kv_cache(struct llama_context * ctx, llama_seq_id seq_id = -1) {
|
| 1354 |
+
uint32_t cell_count;
|
| 1355 |
+
read_to(&cell_count, sizeof(cell_count));
|
| 1356 |
+
|
| 1357 |
+
bool res = read_kv_cache_meta(ctx, cell_count, seq_id) && read_kv_cache_data(ctx, cell_count);
|
| 1358 |
+
|
| 1359 |
+
if (!res) {
|
| 1360 |
+
if (seq_id == -1) {
|
| 1361 |
+
llama_kv_cache_clear(ctx);
|
| 1362 |
+
} else {
|
| 1363 |
+
llama_kv_cache_seq_rm(ctx, seq_id, -1, -1);
|
| 1364 |
+
}
|
| 1365 |
+
throw std::runtime_error("failed to restore kv cache");
|
| 1366 |
+
}
|
| 1367 |
+
}
|
| 1368 |
+
};
|
| 1369 |
+
|
| 1370 |
+
struct llama_data_write_dummy : llama_data_write {
|
| 1371 |
+
size_t size_written = 0;
|
| 1372 |
+
|
| 1373 |
+
llama_data_write_dummy() {}
|
| 1374 |
+
|
| 1375 |
+
void write(const void * /* src */, size_t size) override {
|
| 1376 |
+
size_written += size;
|
| 1377 |
+
}
|
| 1378 |
+
|
| 1379 |
+
void write_tensor_data(const struct ggml_tensor * /* tensor */, size_t /* offset */, size_t size) override {
|
| 1380 |
+
size_written += size;
|
| 1381 |
+
}
|
| 1382 |
+
|
| 1383 |
+
size_t get_size_written() override {
|
| 1384 |
+
return size_written;
|
| 1385 |
+
}
|
| 1386 |
+
};
|
| 1387 |
+
|
| 1388 |
+
struct llama_data_write_buffer : llama_data_write {
|
| 1389 |
+
uint8_t * ptr;
|
| 1390 |
+
size_t buf_size = 0;
|
| 1391 |
+
size_t size_written = 0;
|
| 1392 |
+
|
| 1393 |
+
llama_data_write_buffer(uint8_t * p, size_t len) : ptr(p), buf_size(len) {}
|
| 1394 |
+
|
| 1395 |
+
void write(const void * src, size_t size) override {
|
| 1396 |
+
if (size > buf_size) {
|
| 1397 |
+
throw std::runtime_error("unexpectedly reached end of buffer");
|
| 1398 |
+
}
|
| 1399 |
+
memcpy(ptr, src, size);
|
| 1400 |
+
ptr += size;
|
| 1401 |
+
size_written += size;
|
| 1402 |
+
buf_size -= size;
|
| 1403 |
+
}
|
| 1404 |
+
|
| 1405 |
+
void write_tensor_data(const struct ggml_tensor * tensor, size_t offset, size_t size) override {
|
| 1406 |
+
if (size > buf_size) {
|
| 1407 |
+
throw std::runtime_error("unexpectedly reached end of buffer");
|
| 1408 |
+
}
|
| 1409 |
+
ggml_backend_tensor_get(tensor, ptr, offset, size);
|
| 1410 |
+
ptr += size;
|
| 1411 |
+
size_written += size;
|
| 1412 |
+
buf_size -= size;
|
| 1413 |
+
}
|
| 1414 |
+
|
| 1415 |
+
size_t get_size_written() override {
|
| 1416 |
+
return size_written;
|
| 1417 |
+
}
|
| 1418 |
+
};
|
| 1419 |
+
|
| 1420 |
+
struct llama_data_read_buffer : llama_data_read {
|
| 1421 |
+
const uint8_t * ptr;
|
| 1422 |
+
size_t buf_size = 0;
|
| 1423 |
+
size_t size_read = 0;
|
| 1424 |
+
|
| 1425 |
+
llama_data_read_buffer(const uint8_t * p, size_t len) : ptr(p), buf_size(len) {}
|
| 1426 |
+
|
| 1427 |
+
const uint8_t * read(size_t size) override {
|
| 1428 |
+
const uint8_t * base_ptr = ptr;
|
| 1429 |
+
if (size > buf_size) {
|
| 1430 |
+
throw std::runtime_error("unexpectedly reached end of buffer");
|
| 1431 |
+
}
|
| 1432 |
+
ptr += size;
|
| 1433 |
+
size_read += size;
|
| 1434 |
+
buf_size -= size;
|
| 1435 |
+
return base_ptr;
|
| 1436 |
+
}
|
| 1437 |
+
|
| 1438 |
+
void read_to(void * dst, size_t size) override {
|
| 1439 |
+
memcpy(dst, read(size), size);
|
| 1440 |
+
}
|
| 1441 |
+
|
| 1442 |
+
size_t get_size_read() override {
|
| 1443 |
+
return size_read;
|
| 1444 |
+
}
|
| 1445 |
+
};
|
| 1446 |
+
|
| 1447 |
+
struct llama_data_write_file : llama_data_write {
|
| 1448 |
+
llama_file * file;
|
| 1449 |
+
size_t size_written = 0;
|
| 1450 |
+
std::vector<uint8_t> temp_buffer;
|
| 1451 |
+
|
| 1452 |
+
llama_data_write_file(llama_file * f) : file(f) {}
|
| 1453 |
+
|
| 1454 |
+
void write(const void * src, size_t size) override {
|
| 1455 |
+
file->write_raw(src, size);
|
| 1456 |
+
size_written += size;
|
| 1457 |
+
}
|
| 1458 |
+
|
| 1459 |
+
void write_tensor_data(const struct ggml_tensor * tensor, size_t offset, size_t size) override {
|
| 1460 |
+
temp_buffer.resize(size);
|
| 1461 |
+
ggml_backend_tensor_get(tensor, temp_buffer.data(), offset, size);
|
| 1462 |
+
write(temp_buffer.data(), temp_buffer.size());
|
| 1463 |
+
}
|
| 1464 |
+
|
| 1465 |
+
size_t get_size_written() override {
|
| 1466 |
+
return size_written;
|
| 1467 |
+
}
|
| 1468 |
+
};
|
| 1469 |
+
|
| 1470 |
+
struct llama_data_read_file : llama_data_read {
|
| 1471 |
+
llama_file * file;
|
| 1472 |
+
size_t size_read = 0;
|
| 1473 |
+
std::vector<uint8_t> temp_buffer;
|
| 1474 |
+
|
| 1475 |
+
llama_data_read_file(llama_file * f) : file(f) {}
|
| 1476 |
+
|
| 1477 |
+
void read_to(void * dst, size_t size) override {
|
| 1478 |
+
file->read_raw(dst, size);
|
| 1479 |
+
size_read += size;
|
| 1480 |
+
}
|
| 1481 |
+
|
| 1482 |
+
const uint8_t * read(size_t size) override {
|
| 1483 |
+
temp_buffer.resize(size);
|
| 1484 |
+
read_to(temp_buffer.data(), size);
|
| 1485 |
+
return temp_buffer.data();
|
| 1486 |
+
}
|
| 1487 |
+
|
| 1488 |
+
size_t get_size_read() override {
|
| 1489 |
+
return size_read;
|
| 1490 |
+
}
|
| 1491 |
+
};
|
| 1492 |
+
|
| 1493 |
+
/** copy state data into either a buffer or file depending on the passed in context
|
| 1494 |
+
*
|
| 1495 |
+
* file context:
|
| 1496 |
+
* llama_file file("/path", "wb");
|
| 1497 |
+
* llama_data_write_file data_ctx(&file);
|
| 1498 |
+
* llama_state_get_data_internal(ctx, data_ctx);
|
| 1499 |
+
*
|
| 1500 |
+
* buffer context:
|
| 1501 |
+
* std::vector<uint8_t> buf(max_size, 0);
|
| 1502 |
+
* llama_data_write_buffer data_ctx(buf.data(), max_size);
|
| 1503 |
+
* llama_state_get_data_internal(ctx, data_ctx);
|
| 1504 |
+
*
|
| 1505 |
+
*/
|
| 1506 |
+
static size_t llama_state_get_data_internal(struct llama_context * ctx, llama_data_write & data_ctx) {
|
| 1507 |
+
llama_synchronize(ctx);
|
| 1508 |
+
|
| 1509 |
+
data_ctx.write_model_info(ctx);
|
| 1510 |
+
|
| 1511 |
+
// copy outputs
|
| 1512 |
+
data_ctx.write_output_ids(ctx);
|
| 1513 |
+
data_ctx.write_logits(ctx);
|
| 1514 |
+
data_ctx.write_embeddings(ctx);
|
| 1515 |
+
|
| 1516 |
+
data_ctx.write_kv_cache(ctx);
|
| 1517 |
+
|
| 1518 |
+
return data_ctx.get_size_written();
|
| 1519 |
+
}
|
| 1520 |
+
|
| 1521 |
+
size_t llama_state_get_data(struct llama_context * ctx, uint8_t * dst, size_t size) {
|
| 1522 |
+
llama_data_write_buffer data_ctx(dst, size);
|
| 1523 |
+
try {
|
| 1524 |
+
return llama_state_get_data_internal(ctx, data_ctx);
|
| 1525 |
+
} catch (const std::exception & err) {
|
| 1526 |
+
LLAMA_LOG_ERROR("%s: error saving state: %s\n", __func__, err.what());
|
| 1527 |
+
return 0;
|
| 1528 |
+
}
|
| 1529 |
+
}
|
| 1530 |
+
|
| 1531 |
+
// Returns the *actual* size of the state.
|
| 1532 |
+
// Intended to be used when saving to state to a buffer.
|
| 1533 |
+
size_t llama_state_get_size(struct llama_context * ctx) {
|
| 1534 |
+
llama_data_write_dummy data_ctx;
|
| 1535 |
+
try {
|
| 1536 |
+
return llama_state_get_data_internal(ctx, data_ctx);
|
| 1537 |
+
} catch (const std::exception & err) {
|
| 1538 |
+
LLAMA_LOG_ERROR("%s: error getting state size: %s\n", __func__, err.what());
|
| 1539 |
+
return 0;
|
| 1540 |
+
}
|
| 1541 |
+
}
|
| 1542 |
+
|
| 1543 |
+
static size_t llama_state_set_data_internal(struct llama_context * ctx, llama_data_read & data_ctx) {
|
| 1544 |
+
llama_synchronize(ctx);
|
| 1545 |
+
|
| 1546 |
+
data_ctx.read_model_info(ctx);
|
| 1547 |
+
|
| 1548 |
+
// set outputs
|
| 1549 |
+
data_ctx.read_output_ids(ctx);
|
| 1550 |
+
data_ctx.read_logits(ctx);
|
| 1551 |
+
data_ctx.read_embeddings(ctx);
|
| 1552 |
+
|
| 1553 |
+
data_ctx.read_kv_cache(ctx);
|
| 1554 |
+
|
| 1555 |
+
return data_ctx.get_size_read();
|
| 1556 |
+
}
|
| 1557 |
+
|
| 1558 |
+
// Sets the state reading from the specified source address
|
| 1559 |
+
size_t llama_state_set_data(struct llama_context * ctx, const uint8_t * src, size_t size) {
|
| 1560 |
+
llama_data_read_buffer data_ctx(src, size);
|
| 1561 |
+
try {
|
| 1562 |
+
return llama_state_set_data_internal(ctx, data_ctx);
|
| 1563 |
+
} catch (const std::exception & err) {
|
| 1564 |
+
LLAMA_LOG_ERROR("%s: error loading state: %s\n", __func__, err.what());
|
| 1565 |
+
return 0;
|
| 1566 |
+
}
|
| 1567 |
+
}
|
| 1568 |
+
|
| 1569 |
+
static bool llama_state_load_file_internal(struct llama_context * ctx, const char * path_session, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) {
|
| 1570 |
+
llama_file file(path_session, "rb");
|
| 1571 |
+
|
| 1572 |
+
// sanity checks
|
| 1573 |
+
{
|
| 1574 |
+
const uint32_t magic = file.read_u32();
|
| 1575 |
+
const uint32_t version = file.read_u32();
|
| 1576 |
+
|
| 1577 |
+
if (magic != LLAMA_SESSION_MAGIC || version != LLAMA_SESSION_VERSION) {
|
| 1578 |
+
LLAMA_LOG_ERROR("%s: unknown (magic, version) for session file: %08x, %08x\n", __func__, magic, version);
|
| 1579 |
+
return false;
|
| 1580 |
+
}
|
| 1581 |
+
}
|
| 1582 |
+
|
| 1583 |
+
// load the prompt
|
| 1584 |
+
{
|
| 1585 |
+
const uint32_t n_token_count = file.read_u32();
|
| 1586 |
+
|
| 1587 |
+
if (n_token_count > n_token_capacity) {
|
| 1588 |
+
LLAMA_LOG_ERROR("%s: token count in session file exceeded capacity! %u > %zu\n", __func__, n_token_count, n_token_capacity);
|
| 1589 |
+
return false;
|
| 1590 |
+
}
|
| 1591 |
+
|
| 1592 |
+
file.read_raw(tokens_out, sizeof(llama_token) * n_token_count);
|
| 1593 |
+
*n_token_count_out = n_token_count;
|
| 1594 |
+
}
|
| 1595 |
+
|
| 1596 |
+
// restore the context state
|
| 1597 |
+
{
|
| 1598 |
+
const size_t n_state_size_cur = file.size() - file.tell();
|
| 1599 |
+
|
| 1600 |
+
llama_data_read_file data_ctx(&file);
|
| 1601 |
+
const size_t n_read = llama_state_set_data_internal(ctx, data_ctx);
|
| 1602 |
+
|
| 1603 |
+
if (n_read != n_state_size_cur) {
|
| 1604 |
+
LLAMA_LOG_ERROR("%s: did not read all of the session file data! size %zu, got %zu\n", __func__, n_state_size_cur, n_read);
|
| 1605 |
+
return false;
|
| 1606 |
+
}
|
| 1607 |
+
}
|
| 1608 |
+
return true;
|
| 1609 |
+
}
|
| 1610 |
+
|
| 1611 |
+
bool llama_state_load_file(struct llama_context * ctx, const char * path_session, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) {
|
| 1612 |
+
try {
|
| 1613 |
+
return llama_state_load_file_internal(ctx, path_session, tokens_out, n_token_capacity, n_token_count_out);
|
| 1614 |
+
} catch (const std::exception & err) {
|
| 1615 |
+
LLAMA_LOG_ERROR("%s: error loading session file: %s\n", __func__, err.what());
|
| 1616 |
+
return false;
|
| 1617 |
+
}
|
| 1618 |
+
}
|
| 1619 |
+
|
| 1620 |
+
static bool llama_state_save_file_internal(struct llama_context * ctx, const char * path_session, const llama_token * tokens, size_t n_token_count) {
|
| 1621 |
+
llama_file file(path_session, "wb");
|
| 1622 |
+
|
| 1623 |
+
file.write_u32(LLAMA_SESSION_MAGIC);
|
| 1624 |
+
file.write_u32(LLAMA_SESSION_VERSION);
|
| 1625 |
+
|
| 1626 |
+
// save the prompt
|
| 1627 |
+
file.write_u32((uint32_t) n_token_count);
|
| 1628 |
+
file.write_raw(tokens, sizeof(llama_token) * n_token_count);
|
| 1629 |
+
|
| 1630 |
+
// save the context state using stream saving
|
| 1631 |
+
llama_data_write_file data_ctx(&file);
|
| 1632 |
+
llama_state_get_data_internal(ctx, data_ctx);
|
| 1633 |
+
|
| 1634 |
+
return true;
|
| 1635 |
+
}
|
| 1636 |
+
|
| 1637 |
+
bool llama_state_save_file(struct llama_context * ctx, const char * path_session, const llama_token * tokens, size_t n_token_count) {
|
| 1638 |
+
try {
|
| 1639 |
+
return llama_state_save_file_internal(ctx, path_session, tokens, n_token_count);
|
| 1640 |
+
} catch (const std::exception & err) {
|
| 1641 |
+
LLAMA_LOG_ERROR("%s: error saving session file: %s\n", __func__, err.what());
|
| 1642 |
+
return false;
|
| 1643 |
+
}
|
| 1644 |
+
}
|
| 1645 |
+
|
| 1646 |
+
static size_t llama_state_seq_get_data_internal(struct llama_context * ctx, llama_data_write & data_ctx, llama_seq_id seq_id) {
|
| 1647 |
+
llama_synchronize(ctx);
|
| 1648 |
+
|
| 1649 |
+
data_ctx.write_kv_cache(ctx, seq_id);
|
| 1650 |
+
|
| 1651 |
+
return data_ctx.get_size_written();
|
| 1652 |
+
}
|
| 1653 |
+
|
| 1654 |
+
size_t llama_state_seq_get_size(struct llama_context * ctx, llama_seq_id seq_id) {
|
| 1655 |
+
llama_data_write_dummy data_ctx;
|
| 1656 |
+
return llama_state_seq_get_data_internal(ctx, data_ctx, seq_id);
|
| 1657 |
+
}
|
| 1658 |
+
|
| 1659 |
+
size_t llama_state_seq_get_data(struct llama_context * ctx, uint8_t * dst, size_t size, llama_seq_id seq_id) {
|
| 1660 |
+
llama_data_write_buffer data_ctx(dst, size);
|
| 1661 |
+
try {
|
| 1662 |
+
return llama_state_seq_get_data_internal(ctx, data_ctx, seq_id);
|
| 1663 |
+
} catch (const std::exception & err) {
|
| 1664 |
+
LLAMA_LOG_ERROR("%s: error saving sequence state: %s\n", __func__, err.what());
|
| 1665 |
+
return 0;
|
| 1666 |
+
}
|
| 1667 |
+
}
|
| 1668 |
+
|
| 1669 |
+
static size_t llama_state_seq_set_data_internal(struct llama_context * ctx, llama_data_read & data_ctx, llama_seq_id dest_seq_id) {
|
| 1670 |
+
llama_synchronize(ctx);
|
| 1671 |
+
|
| 1672 |
+
data_ctx.read_kv_cache(ctx, dest_seq_id);
|
| 1673 |
+
|
| 1674 |
+
return data_ctx.get_size_read();
|
| 1675 |
+
}
|
| 1676 |
+
|
| 1677 |
+
size_t llama_state_seq_set_data(struct llama_context * ctx, const uint8_t * src, size_t size, llama_seq_id dest_seq_id) {
|
| 1678 |
+
llama_data_read_buffer data_ctx(src, size);
|
| 1679 |
+
try {
|
| 1680 |
+
return llama_state_seq_set_data_internal(ctx, data_ctx, dest_seq_id);
|
| 1681 |
+
} catch (const std::exception & err) {
|
| 1682 |
+
LLAMA_LOG_ERROR("%s: error loading sequence state: %s\n", __func__, err.what());
|
| 1683 |
+
return 0;
|
| 1684 |
+
}
|
| 1685 |
+
}
|
| 1686 |
+
|
| 1687 |
+
static size_t llama_state_seq_save_file_internal(struct llama_context * ctx, const char * filepath, llama_seq_id seq_id, const llama_token * tokens, size_t n_token_count) {
|
| 1688 |
+
llama_file file(filepath, "wb");
|
| 1689 |
+
|
| 1690 |
+
file.write_u32(LLAMA_STATE_SEQ_MAGIC);
|
| 1691 |
+
file.write_u32(LLAMA_STATE_SEQ_VERSION);
|
| 1692 |
+
|
| 1693 |
+
// save the prompt
|
| 1694 |
+
file.write_u32((uint32_t) n_token_count);
|
| 1695 |
+
file.write_raw(tokens, sizeof(llama_token) * n_token_count);
|
| 1696 |
+
|
| 1697 |
+
// save the context state using stream saving
|
| 1698 |
+
llama_data_write_file data_ctx(&file);
|
| 1699 |
+
llama_state_seq_get_data_internal(ctx, data_ctx, seq_id);
|
| 1700 |
+
|
| 1701 |
+
const size_t res = file.tell();
|
| 1702 |
+
GGML_ASSERT(res == sizeof(uint32_t) * 3 + sizeof(llama_token) * n_token_count + data_ctx.get_size_written());
|
| 1703 |
+
return res;
|
| 1704 |
+
}
|
| 1705 |
+
|
| 1706 |
+
static size_t llama_state_seq_load_file_internal(struct llama_context * ctx, const char * filepath, llama_seq_id dest_seq_id, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) {
|
| 1707 |
+
llama_file file(filepath, "rb");
|
| 1708 |
+
|
| 1709 |
+
// version checks
|
| 1710 |
+
{
|
| 1711 |
+
const uint32_t magic = file.read_u32();
|
| 1712 |
+
const uint32_t version = file.read_u32();
|
| 1713 |
+
|
| 1714 |
+
if (magic != LLAMA_STATE_SEQ_MAGIC || version != LLAMA_STATE_SEQ_VERSION) {
|
| 1715 |
+
LLAMA_LOG_ERROR("%s: unknown (magic, version) for sequence state file: %08x, %08x\n", __func__, magic, version);
|
| 1716 |
+
return 0;
|
| 1717 |
+
}
|
| 1718 |
+
}
|
| 1719 |
+
|
| 1720 |
+
// load the prompt
|
| 1721 |
+
{
|
| 1722 |
+
const uint32_t n_token_count = file.read_u32();
|
| 1723 |
+
|
| 1724 |
+
if (n_token_count > n_token_capacity) {
|
| 1725 |
+
LLAMA_LOG_ERROR("%s: token count in sequence state file exceeded capacity! %u > %zu\n", __func__, n_token_count, n_token_capacity);
|
| 1726 |
+
return 0;
|
| 1727 |
+
}
|
| 1728 |
+
|
| 1729 |
+
file.read_raw(tokens_out, sizeof(llama_token) * n_token_count);
|
| 1730 |
+
*n_token_count_out = n_token_count;
|
| 1731 |
+
}
|
| 1732 |
+
|
| 1733 |
+
// restore the context state
|
| 1734 |
+
{
|
| 1735 |
+
const size_t state_size = file.size() - file.tell();
|
| 1736 |
+
llama_data_read_file data_ctx(&file);
|
| 1737 |
+
const size_t nread = llama_state_seq_set_data_internal(ctx, data_ctx, dest_seq_id);
|
| 1738 |
+
if (!nread) {
|
| 1739 |
+
LLAMA_LOG_ERROR("%s: failed to restore sequence state\n", __func__);
|
| 1740 |
+
return 0;
|
| 1741 |
+
}
|
| 1742 |
+
GGML_ASSERT(nread <= state_size);
|
| 1743 |
+
GGML_ASSERT(nread + sizeof(uint32_t) * 3 + sizeof(llama_token) * *n_token_count_out == file.tell());
|
| 1744 |
+
}
|
| 1745 |
+
|
| 1746 |
+
return file.tell();
|
| 1747 |
+
}
|
| 1748 |
+
|
| 1749 |
+
size_t llama_state_seq_save_file(struct llama_context * ctx, const char * filepath, llama_seq_id seq_id, const llama_token * tokens, size_t n_token_count) {
|
| 1750 |
+
try {
|
| 1751 |
+
return llama_state_seq_save_file_internal(ctx, filepath, seq_id, tokens, n_token_count);
|
| 1752 |
+
} catch (const std::exception & err) {
|
| 1753 |
+
LLAMA_LOG_ERROR("%s: error saving sequence state file: %s\n", __func__, err.what());
|
| 1754 |
+
return 0;
|
| 1755 |
+
}
|
| 1756 |
+
}
|
| 1757 |
+
|
| 1758 |
+
size_t llama_state_seq_load_file(struct llama_context * ctx, const char * filepath, llama_seq_id dest_seq_id, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) {
|
| 1759 |
+
try {
|
| 1760 |
+
return llama_state_seq_load_file_internal(ctx, filepath, dest_seq_id, tokens_out, n_token_capacity, n_token_count_out);
|
| 1761 |
+
} catch (const std::exception & err) {
|
| 1762 |
+
LLAMA_LOG_ERROR("%s: error loading sequence state file: %s\n", __func__, err.what());
|
| 1763 |
+
return 0;
|
| 1764 |
+
}
|
| 1765 |
+
}
|
| 1766 |
+
|
| 1767 |
+
const std::vector<std::pair<std::string, struct ggml_tensor *>> & llama_internal_get_tensor_map(
|
| 1768 |
+
struct llama_context * ctx
|
| 1769 |
+
) {
|
| 1770 |
+
return ctx->model.tensors_by_name;
|
| 1771 |
+
}
|
examples/talk-llama/llama-context.h
ADDED
|
@@ -0,0 +1,128 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include "llama.h"
|
| 4 |
+
#include "llama-batch.h"
|
| 5 |
+
#include "llama-cparams.h"
|
| 6 |
+
#include "llama-model.h"
|
| 7 |
+
#include "llama-kv-cache.h"
|
| 8 |
+
#include "llama-adapter.h"
|
| 9 |
+
|
| 10 |
+
#include "ggml-cpp.h"
|
| 11 |
+
|
| 12 |
+
#include <map>
|
| 13 |
+
#include <unordered_map>
|
| 14 |
+
#include <vector>
|
| 15 |
+
#include <set>
|
| 16 |
+
|
| 17 |
+
struct llama_context {
|
| 18 |
+
llama_context(const llama_model & model)
|
| 19 |
+
: model(model)
|
| 20 |
+
, t_start_us(model.t_start_us)
|
| 21 |
+
, t_load_us(model.t_load_us) {}
|
| 22 |
+
|
| 23 |
+
const struct llama_model & model;
|
| 24 |
+
|
| 25 |
+
struct llama_cparams cparams;
|
| 26 |
+
struct llama_sbatch sbatch; // TODO: revisit if needed
|
| 27 |
+
struct llama_kv_cache kv_self;
|
| 28 |
+
struct llama_control_vector cvec;
|
| 29 |
+
|
| 30 |
+
std::unordered_map<struct llama_lora_adapter *, float> lora_adapters;
|
| 31 |
+
|
| 32 |
+
std::vector<ggml_backend_ptr> backends;
|
| 33 |
+
std::vector<std::pair<ggml_backend_t, ggml_backend_set_n_threads_t>> set_n_threads_fns;
|
| 34 |
+
|
| 35 |
+
ggml_backend_t backend_cpu = nullptr;
|
| 36 |
+
|
| 37 |
+
ggml_threadpool_t threadpool = nullptr;
|
| 38 |
+
ggml_threadpool_t threadpool_batch = nullptr;
|
| 39 |
+
|
| 40 |
+
bool has_evaluated_once = false;
|
| 41 |
+
|
| 42 |
+
mutable int64_t t_start_us;
|
| 43 |
+
mutable int64_t t_load_us;
|
| 44 |
+
mutable int64_t t_p_eval_us = 0;
|
| 45 |
+
mutable int64_t t_eval_us = 0;
|
| 46 |
+
|
| 47 |
+
mutable int64_t t_compute_start_us = 0;
|
| 48 |
+
mutable int64_t n_queued_tokens = 0;
|
| 49 |
+
|
| 50 |
+
mutable int32_t n_p_eval = 0; // number of tokens in eval calls for the prompt (with batch size > 1)
|
| 51 |
+
mutable int32_t n_eval = 0; // number of eval calls
|
| 52 |
+
|
| 53 |
+
// host buffer for the model output (logits and embeddings)
|
| 54 |
+
ggml_backend_buffer_ptr buf_output;
|
| 55 |
+
|
| 56 |
+
// decode output (2-dimensional array: [n_outputs][n_vocab])
|
| 57 |
+
size_t logits_size = 0; // capacity (of floats) for logits
|
| 58 |
+
float * logits = nullptr;
|
| 59 |
+
|
| 60 |
+
std::vector<int32_t> output_ids; // map batch token positions to ids of the logits and embd buffers
|
| 61 |
+
size_t output_size = 0; // capacity (of tokens positions) for the output buffers
|
| 62 |
+
int32_t n_outputs = 0; // number of actually-used outputs in the current ubatch or last logical batch
|
| 63 |
+
|
| 64 |
+
bool logits_all = false;
|
| 65 |
+
|
| 66 |
+
// embeddings output (2-dimensional array: [n_outputs][n_embd])
|
| 67 |
+
// populated only when pooling_type == LLAMA_POOLING_TYPE_NONE
|
| 68 |
+
size_t embd_size = 0; // capacity (of floats) for embeddings
|
| 69 |
+
float * embd = nullptr;
|
| 70 |
+
|
| 71 |
+
// sequence embeddings output (map of [n_embd] vectors)
|
| 72 |
+
// populated only when pooling_type != LLAMA_POOLING_TYPE_NONE
|
| 73 |
+
std::map<llama_seq_id, std::vector<float>> embd_seq;
|
| 74 |
+
|
| 75 |
+
// whether we are computing encoder output or decoder output
|
| 76 |
+
bool is_encoding = false;
|
| 77 |
+
|
| 78 |
+
// TODO: find a better way to accommodate mutli-dimension position encoding methods
|
| 79 |
+
// number of position id each token get, 1 for each token in most cases.
|
| 80 |
+
// when using m-rope, it will be 3 position ids per token to representing 3 dimension coordinate.
|
| 81 |
+
int n_pos_per_token = 1;
|
| 82 |
+
|
| 83 |
+
// output of the encoder part of the encoder-decoder models
|
| 84 |
+
std::vector<float> embd_enc;
|
| 85 |
+
std::vector<std::set<llama_seq_id>> seq_ids_enc;
|
| 86 |
+
|
| 87 |
+
// memory buffers used to evaluate the model
|
| 88 |
+
std::vector<uint8_t> buf_compute_meta;
|
| 89 |
+
ggml_backend_sched_ptr sched;
|
| 90 |
+
|
| 91 |
+
ggml_abort_callback abort_callback = nullptr;
|
| 92 |
+
void * abort_callback_data = nullptr;
|
| 93 |
+
|
| 94 |
+
// input tensors
|
| 95 |
+
struct ggml_tensor * inp_tokens; // I32 [n_batch]
|
| 96 |
+
struct ggml_tensor * inp_embd; // F32 [n_embd, n_batch]
|
| 97 |
+
struct ggml_tensor * inp_pos; // I32 [n_batch]
|
| 98 |
+
struct ggml_tensor * inp_out_ids; // I32 [n_outputs]
|
| 99 |
+
struct ggml_tensor * inp_KQ_mask; // F32 [kv_size, n_batch]
|
| 100 |
+
struct ggml_tensor * inp_KQ_mask_swa; // F32 [kv_size, n_batch]
|
| 101 |
+
struct ggml_tensor * inp_K_shift; // I32 [kv_size]
|
| 102 |
+
struct ggml_tensor * inp_mean; // F32 [n_batch, n_batch]
|
| 103 |
+
struct ggml_tensor * inp_cls; // I32 [n_batch]
|
| 104 |
+
struct ggml_tensor * inp_s_copy; // I32 [kv_size]
|
| 105 |
+
struct ggml_tensor * inp_s_mask; // F32 [1, n_kv]
|
| 106 |
+
struct ggml_tensor * inp_s_seq; // I32 [n_kv, n_batch]
|
| 107 |
+
struct ggml_tensor * inp_pos_bucket; // I32 [n_batch|n_kv, n_batch]
|
| 108 |
+
struct ggml_tensor * inp_embd_enc; // F32 [n_embd, n_outputs_enc]
|
| 109 |
+
struct ggml_tensor * inp_KQ_mask_cross; // F32 [n_outputs_enc, n_batch]
|
| 110 |
+
};
|
| 111 |
+
|
| 112 |
+
// TODO: make these methods of llama_context
|
| 113 |
+
void llama_set_k_shift(struct llama_context & lctx);
|
| 114 |
+
|
| 115 |
+
void llama_set_s_copy(struct llama_context & lctx);
|
| 116 |
+
|
| 117 |
+
void llama_set_inputs(llama_context & lctx, const llama_ubatch & ubatch);
|
| 118 |
+
|
| 119 |
+
// Make sure enough space is available for outputs.
|
| 120 |
+
// Returns max number of outputs for which space was reserved.
|
| 121 |
+
size_t llama_output_reserve(struct llama_context & lctx, size_t n_outputs);
|
| 122 |
+
|
| 123 |
+
// make the outputs have the same order they had in the user-provided batch
|
| 124 |
+
void llama_output_reorder(struct llama_context & ctx);
|
| 125 |
+
|
| 126 |
+
// For internal test use
|
| 127 |
+
// TODO: remove
|
| 128 |
+
const std::vector<std::pair<std::string, struct ggml_tensor *>> & llama_internal_get_tensor_map(struct llama_context * ctx);
|
examples/talk-llama/llama-cparams.cpp
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
#include "llama-cparams.h"
|
examples/talk-llama/llama-cparams.h
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include "llama.h"
|
| 4 |
+
|
| 5 |
+
#include <cstdint>
|
| 6 |
+
|
| 7 |
+
struct llama_cparams {
|
| 8 |
+
uint32_t n_ctx; // context size used during inference
|
| 9 |
+
uint32_t n_batch;
|
| 10 |
+
uint32_t n_ubatch;
|
| 11 |
+
uint32_t n_seq_max;
|
| 12 |
+
int n_threads; // number of threads to use for generation
|
| 13 |
+
int n_threads_batch; // number of threads to use for batch processing
|
| 14 |
+
|
| 15 |
+
float rope_freq_base;
|
| 16 |
+
float rope_freq_scale;
|
| 17 |
+
|
| 18 |
+
uint32_t n_ctx_orig_yarn;
|
| 19 |
+
// These hyperparameters are not exposed in GGUF, because all
|
| 20 |
+
// existing YaRN models use the same values for them.
|
| 21 |
+
float yarn_ext_factor;
|
| 22 |
+
float yarn_attn_factor;
|
| 23 |
+
float yarn_beta_fast;
|
| 24 |
+
float yarn_beta_slow;
|
| 25 |
+
float defrag_thold;
|
| 26 |
+
|
| 27 |
+
bool embeddings;
|
| 28 |
+
bool causal_attn;
|
| 29 |
+
bool offload_kqv;
|
| 30 |
+
bool flash_attn;
|
| 31 |
+
bool no_perf;
|
| 32 |
+
|
| 33 |
+
enum llama_pooling_type pooling_type;
|
| 34 |
+
|
| 35 |
+
ggml_backend_sched_eval_callback cb_eval;
|
| 36 |
+
void * cb_eval_user_data;
|
| 37 |
+
};
|
examples/talk-llama/llama-grammar.cpp
CHANGED
|
@@ -1,5 +1,6 @@
|
|
| 1 |
#include "llama-grammar.h"
|
| 2 |
|
|
|
|
| 3 |
#include "llama-vocab.h"
|
| 4 |
#include "llama-sampling.h"
|
| 5 |
|
|
@@ -822,15 +823,11 @@ llama_grammar_stacks & llama_grammar_get_stacks(struct llama_grammar * grammar)
|
|
| 822 |
return grammar->stacks;
|
| 823 |
}
|
| 824 |
|
| 825 |
-
void llama_grammar_accept(
|
| 826 |
-
|
| 827 |
-
|
| 828 |
-
const uint32_t chr,
|
| 829 |
-
llama_grammar_stacks & stacks_new) {
|
| 830 |
-
stacks_new.clear();
|
| 831 |
-
stacks_new.reserve(stacks.size());
|
| 832 |
|
| 833 |
-
for (const auto & stack : stacks) {
|
| 834 |
if (stack.empty()) {
|
| 835 |
continue;
|
| 836 |
}
|
|
@@ -844,9 +841,11 @@ void llama_grammar_accept(
|
|
| 844 |
if (!llama_grammar_is_end_of_sequence(pos)) {
|
| 845 |
new_stack.push_back(pos);
|
| 846 |
}
|
| 847 |
-
llama_grammar_advance_stack(rules, new_stack, stacks_new);
|
| 848 |
}
|
| 849 |
}
|
|
|
|
|
|
|
| 850 |
}
|
| 851 |
|
| 852 |
llama_grammar_candidates llama_grammar_reject_candidates_for_stack(
|
|
@@ -1051,7 +1050,12 @@ void llama_grammar_free_impl(struct llama_grammar * grammar) {
|
|
| 1051 |
}
|
| 1052 |
|
| 1053 |
struct llama_grammar * llama_grammar_clone_impl(const struct llama_grammar & grammar) {
|
| 1054 |
-
llama_grammar * result = new llama_grammar {
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1055 |
|
| 1056 |
// redirect elements in stacks to point to new rules
|
| 1057 |
for (size_t is = 0; is < result->stacks.size(); is++) {
|
|
@@ -1059,7 +1063,7 @@ struct llama_grammar * llama_grammar_clone_impl(const struct llama_grammar & gra
|
|
| 1059 |
for (size_t ir0 = 0; ir0 < grammar.rules.size(); ir0++) {
|
| 1060 |
for (size_t ir1 = 0; ir1 < grammar.rules[ir0].size(); ir1++) {
|
| 1061 |
if (grammar.stacks[is][ie] == &grammar.rules[ir0][ir1]) {
|
| 1062 |
-
|
| 1063 |
}
|
| 1064 |
}
|
| 1065 |
}
|
|
@@ -1126,11 +1130,8 @@ void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token
|
|
| 1126 |
const auto decoded = decode_utf8(piece, grammar.partial_utf8);
|
| 1127 |
const auto & code_points = decoded.first;
|
| 1128 |
|
| 1129 |
-
llama_grammar_stacks stacks_new;
|
| 1130 |
-
|
| 1131 |
for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
|
| 1132 |
-
llama_grammar_accept(grammar
|
| 1133 |
-
grammar.stacks = std::move(stacks_new);
|
| 1134 |
}
|
| 1135 |
|
| 1136 |
grammar.partial_utf8 = decoded.second;
|
|
|
|
| 1 |
#include "llama-grammar.h"
|
| 2 |
|
| 3 |
+
#include "llama-impl.h"
|
| 4 |
#include "llama-vocab.h"
|
| 5 |
#include "llama-sampling.h"
|
| 6 |
|
|
|
|
| 823 |
return grammar->stacks;
|
| 824 |
}
|
| 825 |
|
| 826 |
+
void llama_grammar_accept(struct llama_grammar * grammar, uint32_t chr) {
|
| 827 |
+
llama_grammar_stacks stacks_new;
|
| 828 |
+
stacks_new.reserve(grammar->stacks.size());
|
|
|
|
|
|
|
|
|
|
|
|
|
| 829 |
|
| 830 |
+
for (const auto & stack : grammar->stacks) {
|
| 831 |
if (stack.empty()) {
|
| 832 |
continue;
|
| 833 |
}
|
|
|
|
| 841 |
if (!llama_grammar_is_end_of_sequence(pos)) {
|
| 842 |
new_stack.push_back(pos);
|
| 843 |
}
|
| 844 |
+
llama_grammar_advance_stack(grammar->rules, new_stack, stacks_new);
|
| 845 |
}
|
| 846 |
}
|
| 847 |
+
|
| 848 |
+
grammar->stacks = std::move(stacks_new);
|
| 849 |
}
|
| 850 |
|
| 851 |
llama_grammar_candidates llama_grammar_reject_candidates_for_stack(
|
|
|
|
| 1050 |
}
|
| 1051 |
|
| 1052 |
struct llama_grammar * llama_grammar_clone_impl(const struct llama_grammar & grammar) {
|
| 1053 |
+
llama_grammar * result = new llama_grammar {
|
| 1054 |
+
grammar.vocab,
|
| 1055 |
+
grammar.rules,
|
| 1056 |
+
grammar.stacks,
|
| 1057 |
+
grammar.partial_utf8,
|
| 1058 |
+
};
|
| 1059 |
|
| 1060 |
// redirect elements in stacks to point to new rules
|
| 1061 |
for (size_t is = 0; is < result->stacks.size(); is++) {
|
|
|
|
| 1063 |
for (size_t ir0 = 0; ir0 < grammar.rules.size(); ir0++) {
|
| 1064 |
for (size_t ir1 = 0; ir1 < grammar.rules[ir0].size(); ir1++) {
|
| 1065 |
if (grammar.stacks[is][ie] == &grammar.rules[ir0][ir1]) {
|
| 1066 |
+
result->stacks[is][ie] = &result->rules[ir0][ir1];
|
| 1067 |
}
|
| 1068 |
}
|
| 1069 |
}
|
|
|
|
| 1130 |
const auto decoded = decode_utf8(piece, grammar.partial_utf8);
|
| 1131 |
const auto & code_points = decoded.first;
|
| 1132 |
|
|
|
|
|
|
|
| 1133 |
for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
|
| 1134 |
+
llama_grammar_accept(&grammar, *it);
|
|
|
|
| 1135 |
}
|
| 1136 |
|
| 1137 |
grammar.partial_utf8 = decoded.second;
|
examples/talk-llama/llama-grammar.h
CHANGED
|
@@ -1,8 +1,10 @@
|
|
| 1 |
#pragma once
|
| 2 |
|
| 3 |
-
#include "llama
|
| 4 |
|
| 5 |
#include <map>
|
|
|
|
|
|
|
| 6 |
|
| 7 |
struct llama_vocab;
|
| 8 |
|
|
@@ -58,6 +60,7 @@ using llama_grammar_rules = std::vector<llama_grammar_rule>;
|
|
| 58 |
using llama_grammar_stacks = std::vector<llama_grammar_stack>;
|
| 59 |
using llama_grammar_candidates = std::vector<llama_grammar_candidate>;
|
| 60 |
|
|
|
|
| 61 |
const llama_grammar_rules & llama_grammar_get_rules (const struct llama_grammar * grammar);
|
| 62 |
llama_grammar_stacks & llama_grammar_get_stacks( struct llama_grammar * grammar);
|
| 63 |
|
|
@@ -65,11 +68,7 @@ const llama_grammar_rules & llama_grammar_get_rules (const struct llama_grammar
|
|
| 65 |
// be positioned at a character range (see `llama_grammar_advance_stack`), and
|
| 66 |
// produces the N possible stacks if the given char is accepted at those
|
| 67 |
// positions
|
| 68 |
-
void llama_grammar_accept(
|
| 69 |
-
const llama_grammar_rules & rules,
|
| 70 |
-
const llama_grammar_stacks & stacks,
|
| 71 |
-
uint32_t chr,
|
| 72 |
-
llama_grammar_stacks & stacks_new);
|
| 73 |
|
| 74 |
std::vector<llama_grammar_candidate> llama_grammar_reject_candidates_for_stack(
|
| 75 |
const llama_grammar_rules & rules,
|
|
|
|
| 1 |
#pragma once
|
| 2 |
|
| 3 |
+
#include "llama.h"
|
| 4 |
|
| 5 |
#include <map>
|
| 6 |
+
#include <string>
|
| 7 |
+
#include <vector>
|
| 8 |
|
| 9 |
struct llama_vocab;
|
| 10 |
|
|
|
|
| 60 |
using llama_grammar_stacks = std::vector<llama_grammar_stack>;
|
| 61 |
using llama_grammar_candidates = std::vector<llama_grammar_candidate>;
|
| 62 |
|
| 63 |
+
// TODO: remove, needed for tests atm
|
| 64 |
const llama_grammar_rules & llama_grammar_get_rules (const struct llama_grammar * grammar);
|
| 65 |
llama_grammar_stacks & llama_grammar_get_stacks( struct llama_grammar * grammar);
|
| 66 |
|
|
|
|
| 68 |
// be positioned at a character range (see `llama_grammar_advance_stack`), and
|
| 69 |
// produces the N possible stacks if the given char is accepted at those
|
| 70 |
// positions
|
| 71 |
+
void llama_grammar_accept(struct llama_grammar * grammar, uint32_t chr);
|
|
|
|
|
|
|
|
|
|
|
|
|
| 72 |
|
| 73 |
std::vector<llama_grammar_candidate> llama_grammar_reject_candidates_for_stack(
|
| 74 |
const llama_grammar_rules & rules,
|
examples/talk-llama/llama-hparams.cpp
ADDED
|
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#include "llama-hparams.h"
|
| 2 |
+
|
| 3 |
+
#include "ggml.h"
|
| 4 |
+
|
| 5 |
+
uint32_t llama_hparams::n_head(uint32_t il) const {
|
| 6 |
+
if (il < n_layer) {
|
| 7 |
+
return n_head_arr[il];
|
| 8 |
+
}
|
| 9 |
+
|
| 10 |
+
GGML_ABORT("fatal error");
|
| 11 |
+
}
|
| 12 |
+
|
| 13 |
+
uint32_t llama_hparams::n_head_kv(uint32_t il) const {
|
| 14 |
+
if (il < n_layer) {
|
| 15 |
+
return n_head_kv_arr[il];
|
| 16 |
+
}
|
| 17 |
+
|
| 18 |
+
GGML_ABORT("fatal error");
|
| 19 |
+
}
|
| 20 |
+
|
| 21 |
+
uint32_t llama_hparams::n_ff(uint32_t il) const {
|
| 22 |
+
if (il < n_layer) {
|
| 23 |
+
return n_ff_arr[il];
|
| 24 |
+
}
|
| 25 |
+
|
| 26 |
+
GGML_ABORT("fatal error");
|
| 27 |
+
}
|
| 28 |
+
|
| 29 |
+
uint32_t llama_hparams::n_gqa(uint32_t il) const {
|
| 30 |
+
const uint32_t n_head = this->n_head(il);
|
| 31 |
+
const uint32_t n_head_kv = this->n_head_kv(il);
|
| 32 |
+
|
| 33 |
+
if (n_head_kv == 0) {
|
| 34 |
+
return 0;
|
| 35 |
+
}
|
| 36 |
+
|
| 37 |
+
return n_head/n_head_kv;
|
| 38 |
+
}
|
| 39 |
+
|
| 40 |
+
uint32_t llama_hparams::n_embd_k_gqa(uint32_t il) const {
|
| 41 |
+
const uint32_t n_head_kv = this->n_head_kv(il);
|
| 42 |
+
|
| 43 |
+
return n_embd_head_k * n_head_kv;
|
| 44 |
+
}
|
| 45 |
+
|
| 46 |
+
uint32_t llama_hparams::n_embd_v_gqa(uint32_t il) const {
|
| 47 |
+
const uint32_t n_head_kv = this->n_head_kv(il);
|
| 48 |
+
|
| 49 |
+
return n_embd_head_v * n_head_kv;
|
| 50 |
+
}
|
| 51 |
+
|
| 52 |
+
uint32_t llama_hparams::n_embd_k_s() const {
|
| 53 |
+
if (wkv_head_size != 0) {
|
| 54 |
+
// for RWKV models
|
| 55 |
+
return 2 * n_embd;
|
| 56 |
+
}
|
| 57 |
+
|
| 58 |
+
// TODO: maybe support other convolution strides than 1
|
| 59 |
+
// NOTE: since the first column of the conv_state is shifted out each time, it's not actually needed
|
| 60 |
+
return (ssm_d_conv > 0 ? ssm_d_conv - 1 : 0) * ssm_d_inner;
|
| 61 |
+
}
|
| 62 |
+
|
| 63 |
+
uint32_t llama_hparams::n_embd_v_s() const {
|
| 64 |
+
if (wkv_head_size != 0) {
|
| 65 |
+
// corresponds to RWKV's wkv_states size
|
| 66 |
+
return n_embd * wkv_head_size;
|
| 67 |
+
}
|
| 68 |
+
|
| 69 |
+
// corresponds to Mamba's ssm_states size
|
| 70 |
+
return ssm_d_state * ssm_d_inner;
|
| 71 |
+
}
|
examples/talk-llama/llama-hparams.h
ADDED
|
@@ -0,0 +1,140 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include "llama.h"
|
| 4 |
+
|
| 5 |
+
#include <array>
|
| 6 |
+
|
| 7 |
+
// bump if necessary
|
| 8 |
+
#define LLAMA_MAX_LAYERS 512
|
| 9 |
+
#define LLAMA_MAX_EXPERTS 256 // DeepSeekV3
|
| 10 |
+
|
| 11 |
+
enum llama_expert_gating_func_type {
|
| 12 |
+
LLAMA_EXPERT_GATING_FUNC_TYPE_NONE = 0,
|
| 13 |
+
LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX = 1,
|
| 14 |
+
LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID = 2,
|
| 15 |
+
};
|
| 16 |
+
|
| 17 |
+
struct llama_hparams_posnet {
|
| 18 |
+
uint32_t n_embd;
|
| 19 |
+
uint32_t n_layer;
|
| 20 |
+
};
|
| 21 |
+
|
| 22 |
+
struct llama_hparams_convnext {
|
| 23 |
+
uint32_t n_embd;
|
| 24 |
+
uint32_t n_layer;
|
| 25 |
+
};
|
| 26 |
+
|
| 27 |
+
struct llama_hparams {
|
| 28 |
+
bool vocab_only;
|
| 29 |
+
bool rope_finetuned;
|
| 30 |
+
bool use_par_res;
|
| 31 |
+
bool swin_norm;
|
| 32 |
+
|
| 33 |
+
uint32_t n_vocab = 0;
|
| 34 |
+
uint32_t n_ctx_train; // context size the model was trained on
|
| 35 |
+
uint32_t n_embd;
|
| 36 |
+
uint32_t n_embd_features = 0;
|
| 37 |
+
uint32_t n_layer;
|
| 38 |
+
uint32_t n_rot;
|
| 39 |
+
uint32_t n_swa = 0; // sliding window attention (SWA)
|
| 40 |
+
uint32_t n_embd_head_k; // dimension of keys (d_k). d_q is assumed to be the same, but there are n_head q heads, and only n_head_kv k-v heads
|
| 41 |
+
uint32_t n_embd_head_v; // dimension of values (d_v) aka n_embd_head
|
| 42 |
+
uint32_t n_expert = 0;
|
| 43 |
+
uint32_t n_expert_used = 0;
|
| 44 |
+
uint32_t n_vocab_type = 0; // for BERT-style token types
|
| 45 |
+
uint32_t n_rel_attn_bkts = 0;
|
| 46 |
+
|
| 47 |
+
// for WavTokenizer
|
| 48 |
+
struct llama_hparams_posnet posnet;
|
| 49 |
+
struct llama_hparams_convnext convnext;
|
| 50 |
+
|
| 51 |
+
std::array<uint32_t, LLAMA_MAX_LAYERS> n_head_arr;
|
| 52 |
+
std::array<uint32_t, LLAMA_MAX_LAYERS> n_head_kv_arr;
|
| 53 |
+
std::array<uint32_t, LLAMA_MAX_LAYERS> n_ff_arr;
|
| 54 |
+
|
| 55 |
+
uint32_t n_layer_dense_lead = 0;
|
| 56 |
+
uint32_t n_lora_q = 0;
|
| 57 |
+
uint32_t n_lora_kv = 0;
|
| 58 |
+
uint32_t n_ff_exp = 0;
|
| 59 |
+
uint32_t n_ff_shexp = 0;
|
| 60 |
+
uint32_t n_expert_shared = 0;
|
| 61 |
+
uint32_t n_norm_groups = 0;
|
| 62 |
+
|
| 63 |
+
float expert_weights_scale = 0.0;
|
| 64 |
+
bool expert_weights_norm = false;
|
| 65 |
+
uint32_t expert_gating_func = LLAMA_EXPERT_GATING_FUNC_TYPE_NONE;
|
| 66 |
+
|
| 67 |
+
float f_norm_eps;
|
| 68 |
+
float f_norm_rms_eps;
|
| 69 |
+
float f_norm_group_eps;
|
| 70 |
+
|
| 71 |
+
float f_attn_logit_softcapping = 50.0f;
|
| 72 |
+
float f_final_logit_softcapping = 30.0f;
|
| 73 |
+
|
| 74 |
+
// for RWKV
|
| 75 |
+
uint32_t rescale_every_n_layers = 0;
|
| 76 |
+
uint32_t time_mix_extra_dim = 0;
|
| 77 |
+
uint32_t time_decay_extra_dim = 0;
|
| 78 |
+
uint32_t wkv_head_size = 0;
|
| 79 |
+
|
| 80 |
+
float rope_attn_factor = 1.0f;
|
| 81 |
+
float rope_freq_base_train;
|
| 82 |
+
float rope_freq_scale_train;
|
| 83 |
+
uint32_t n_ctx_orig_yarn;
|
| 84 |
+
float rope_yarn_log_mul;
|
| 85 |
+
|
| 86 |
+
std::array<int, 4> rope_sections;
|
| 87 |
+
|
| 88 |
+
// for State Space Models
|
| 89 |
+
uint32_t ssm_d_conv = 0;
|
| 90 |
+
uint32_t ssm_d_inner = 0;
|
| 91 |
+
uint32_t ssm_d_state = 0;
|
| 92 |
+
uint32_t ssm_dt_rank = 0;
|
| 93 |
+
|
| 94 |
+
bool ssm_dt_b_c_rms = false;
|
| 95 |
+
|
| 96 |
+
float f_clamp_kqv = 0.0f;
|
| 97 |
+
float f_max_alibi_bias = 0.0f;
|
| 98 |
+
float f_logit_scale = 0.0f;
|
| 99 |
+
|
| 100 |
+
// Additional scale factors (Granite/Granite MoE)
|
| 101 |
+
float f_residual_scale = 0.0f;
|
| 102 |
+
float f_embedding_scale = 0.0f;
|
| 103 |
+
float f_attention_scale = 0.0f;
|
| 104 |
+
|
| 105 |
+
bool causal_attn = true;
|
| 106 |
+
bool use_alibi = false;
|
| 107 |
+
bool attn_soft_cap = false;
|
| 108 |
+
|
| 109 |
+
// needed by encoder-decoder models (e.g. T5, FLAN-T5)
|
| 110 |
+
// ref: https://github.com/ggerganov/llama.cpp/pull/8141
|
| 111 |
+
llama_token dec_start_token_id = LLAMA_TOKEN_NULL;
|
| 112 |
+
|
| 113 |
+
enum llama_pooling_type pooling_type = LLAMA_POOLING_TYPE_NONE;
|
| 114 |
+
enum llama_rope_type rope_type = LLAMA_ROPE_TYPE_NONE;
|
| 115 |
+
enum llama_rope_scaling_type rope_scaling_type_train = LLAMA_ROPE_SCALING_TYPE_NONE;
|
| 116 |
+
|
| 117 |
+
uint32_t n_head(uint32_t il = 0) const;
|
| 118 |
+
|
| 119 |
+
uint32_t n_head_kv(uint32_t il = 0) const;
|
| 120 |
+
|
| 121 |
+
uint32_t n_ff(uint32_t il = 0) const;
|
| 122 |
+
|
| 123 |
+
uint32_t n_gqa(uint32_t il = 0) const;
|
| 124 |
+
|
| 125 |
+
// dimension of key embeddings across all k-v heads
|
| 126 |
+
uint32_t n_embd_k_gqa(uint32_t il = 0) const;
|
| 127 |
+
|
| 128 |
+
// dimension of value embeddings across all k-v heads
|
| 129 |
+
uint32_t n_embd_v_gqa(uint32_t il = 0) const;
|
| 130 |
+
|
| 131 |
+
// dimension of the rolling state embeddings
|
| 132 |
+
// corresponds to Mamba's conv_states size or RWKV's token_shift states size
|
| 133 |
+
uint32_t n_embd_k_s() const;
|
| 134 |
+
|
| 135 |
+
// dimension of the recurrent state embeddings
|
| 136 |
+
uint32_t n_embd_v_s() const;
|
| 137 |
+
};
|
| 138 |
+
|
| 139 |
+
static_assert(std::is_trivially_copyable<llama_hparams>::value, "llama_hparams must be trivially copyable");
|
| 140 |
+
|
examples/talk-llama/llama-impl.cpp
ADDED
|
@@ -0,0 +1,166 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#include "llama-impl.h"
|
| 2 |
+
|
| 3 |
+
#include "llama.h"
|
| 4 |
+
|
| 5 |
+
#include <cinttypes>
|
| 6 |
+
#include <climits>
|
| 7 |
+
#include <cstdarg>
|
| 8 |
+
#include <cstring>
|
| 9 |
+
#include <vector>
|
| 10 |
+
#include <sstream>
|
| 11 |
+
|
| 12 |
+
struct llama_logger_state {
|
| 13 |
+
ggml_log_callback log_callback = llama_log_callback_default;
|
| 14 |
+
void * log_callback_user_data = nullptr;
|
};

static llama_logger_state g_logger_state;

time_meas::time_meas(int64_t & t_acc, bool disable) : t_start_us(disable ? -1 : ggml_time_us()), t_acc(t_acc) {}

time_meas::~time_meas() {
    if (t_start_us >= 0) {
        t_acc += ggml_time_us() - t_start_us;
    }
}

void llama_log_set(ggml_log_callback log_callback, void * user_data) {
    ggml_log_set(log_callback, user_data);
    g_logger_state.log_callback = log_callback ? log_callback : llama_log_callback_default;
    g_logger_state.log_callback_user_data = user_data;
}

static void llama_log_internal_v(ggml_log_level level, const char * format, va_list args) {
    va_list args_copy;
    va_copy(args_copy, args);
    char buffer[128];
    int len = vsnprintf(buffer, 128, format, args);
    if (len < 128) {
        g_logger_state.log_callback(level, buffer, g_logger_state.log_callback_user_data);
    } else {
        char * buffer2 = new char[len + 1];
        vsnprintf(buffer2, len + 1, format, args_copy);
        buffer2[len] = 0;
        g_logger_state.log_callback(level, buffer2, g_logger_state.log_callback_user_data);
        delete[] buffer2;
    }
    va_end(args_copy);
}

void llama_log_internal(ggml_log_level level, const char * format, ...) {
    va_list args;
    va_start(args, format);
    llama_log_internal_v(level, format, args);
    va_end(args);
}

void llama_log_callback_default(ggml_log_level level, const char * text, void * user_data) {
    (void) level;
    (void) user_data;
    fputs(text, stderr);
    fflush(stderr);
}

void replace_all(std::string & s, const std::string & search, const std::string & replace) {
    if (search.empty()) {
        return;
    }
    std::string builder;
    builder.reserve(s.length());
    size_t pos = 0;
    size_t last_pos = 0;
    while ((pos = s.find(search, last_pos)) != std::string::npos) {
        builder.append(s, last_pos, pos - last_pos);
        builder.append(replace);
        last_pos = pos + search.length();
    }
    builder.append(s, last_pos, std::string::npos);
    s = std::move(builder);
}

std::string format(const char * fmt, ...) {
    va_list ap;
    va_list ap2;
    va_start(ap, fmt);
    va_copy(ap2, ap);
    int size = vsnprintf(NULL, 0, fmt, ap);
    GGML_ASSERT(size >= 0 && size < INT_MAX); // NOLINT
    std::vector<char> buf(size + 1);
    int size2 = vsnprintf(buf.data(), size + 1, fmt, ap2);
    GGML_ASSERT(size2 == size);
    va_end(ap2);
    va_end(ap);
    return std::string(buf.data(), size);
}

std::string llama_format_tensor_shape(const std::vector<int64_t> & ne) {
    char buf[256];
    snprintf(buf, sizeof(buf), "%5" PRId64, ne.at(0));
    for (size_t i = 1; i < ne.size(); i++) {
        snprintf(buf + strlen(buf), sizeof(buf) - strlen(buf), ", %5" PRId64, ne.at(i));
    }
    return buf;
}

std::string llama_format_tensor_shape(const struct ggml_tensor * t) {
    char buf[256];
    snprintf(buf, sizeof(buf), "%5" PRId64, t->ne[0]);
    for (int i = 1; i < GGML_MAX_DIMS; i++) {
        snprintf(buf + strlen(buf), sizeof(buf) - strlen(buf), ", %5" PRId64, t->ne[i]);
    }
    return buf;
}

static std::string gguf_data_to_str(enum gguf_type type, const void * data, int i) {
    switch (type) {
        case GGUF_TYPE_UINT8:   return std::to_string(((const uint8_t  *)data)[i]);
        case GGUF_TYPE_INT8:    return std::to_string(((const int8_t   *)data)[i]);
        case GGUF_TYPE_UINT16:  return std::to_string(((const uint16_t *)data)[i]);
        case GGUF_TYPE_INT16:   return std::to_string(((const int16_t  *)data)[i]);
        case GGUF_TYPE_UINT32:  return std::to_string(((const uint32_t *)data)[i]);
        case GGUF_TYPE_INT32:   return std::to_string(((const int32_t  *)data)[i]);
        case GGUF_TYPE_UINT64:  return std::to_string(((const uint64_t *)data)[i]);
        case GGUF_TYPE_INT64:   return std::to_string(((const int64_t  *)data)[i]);
        case GGUF_TYPE_FLOAT32: return std::to_string(((const float    *)data)[i]);
        case GGUF_TYPE_FLOAT64: return std::to_string(((const double   *)data)[i]);
        case GGUF_TYPE_BOOL:    return ((const bool *)data)[i] ? "true" : "false";
        default:                return format("unknown type %d", type);
    }
}

std::string gguf_kv_to_str(const struct gguf_context * ctx_gguf, int i) {
    const enum gguf_type type = gguf_get_kv_type(ctx_gguf, i);

    switch (type) {
        case GGUF_TYPE_STRING:
            return gguf_get_val_str(ctx_gguf, i);
        case GGUF_TYPE_ARRAY:
            {
                const enum gguf_type arr_type = gguf_get_arr_type(ctx_gguf, i);
                int arr_n = gguf_get_arr_n(ctx_gguf, i);
                const void * data = gguf_get_arr_data(ctx_gguf, i);
                std::stringstream ss;
                ss << "[";
                for (int j = 0; j < arr_n; j++) {
                    if (arr_type == GGUF_TYPE_STRING) {
                        std::string val = gguf_get_arr_str(ctx_gguf, i, j);
                        // escape quotes
                        replace_all(val, "\\", "\\\\");
                        replace_all(val, "\"", "\\\"");
                        ss << '"' << val << '"';
                    } else if (arr_type == GGUF_TYPE_ARRAY) {
                        ss << "???";
                    } else {
                        ss << gguf_data_to_str(arr_type, data, j);
                    }
                    if (j < arr_n - 1) {
                        ss << ", ";
                    }
                }
                ss << "]";
                return ss.str();
            }
        default:
            return gguf_data_to_str(type, gguf_get_val_data(ctx_gguf, i), 0);
    }
}
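Since llama_log_set above forwards the callback to ggml_log_set as well, one call covers logging from both libraries. A minimal sketch of installing a file-backed callback, assuming a caller-managed FILE handle; the handle name and file path are illustrative only and not part of this commit:

    #include <cstdio>

    // write all log text to the FILE* passed as user_data
    static void log_to_file(ggml_log_level level, const char * text, void * user_data) {
        (void) level;
        fputs(text, (FILE *) user_data);
    }

    // during initialization:
    //   FILE * f = fopen("talk-llama.log", "w");
    //   llama_log_set(log_to_file, f); // also routes ggml logging through the same callback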
examples/talk-llama/llama-impl.h
CHANGED
@@ -1,10 +1,9 @@
 #pragma once

-#include "
+#include "ggml.h" // for ggml_log_level

 #include <string>
 #include <vector>
-#include <stdexcept>

 #ifdef __GNUC__
 #ifdef __MINGW32__
@@ -35,147 +34,28 @@ void llama_log_callback_default(ggml_log_level level, const char * text, void *
 // helpers
 //

-        }
-    }
+template <typename T>
+struct no_init {
+    T value;
+    no_init() { /* do nothing */ }
+};
+
+struct time_meas {
+    time_meas(int64_t & t_acc, bool disable = false);
+    ~time_meas();

     const int64_t t_start_us;

     int64_t & t_acc;
 };

-    if (search.empty()) {
-        return;
-    }
-    std::string builder;
-    builder.reserve(s.length());
-    size_t pos = 0;
-    size_t last_pos = 0;
-    while ((pos = s.find(search, last_pos)) != std::string::npos) {
-        builder.append(s, last_pos, pos - last_pos);
-        builder.append(replace);
-        last_pos = pos + search.length();
-    }
-    builder.append(s, last_pos, std::string::npos);
-    s = std::move(builder);
-}
+void replace_all(std::string & s, const std::string & search, const std::string & replace);

-const std::vector<std::pair<std::string, struct ggml_tensor *>> & llama_internal_get_tensor_map(
-    struct llama_context * ctx
-);
+// TODO: rename to llama_format ?
+LLAMA_ATTRIBUTE_FORMAT(1, 2)
+std::string format(const char * fmt, ...);

-// the ring buffer works similarly to std::deque, but with a fixed capacity
-template<typename T>
-struct ring_buffer {
-    ring_buffer(size_t cap) : capacity(cap), data(cap) {}
-
-    T & front() {
-        if (sz == 0) {
-            throw std::runtime_error("ring buffer is empty");
-        }
-        return data[first];
-    }
-
-    const T & front() const {
-        if (sz == 0) {
-            throw std::runtime_error("ring buffer is empty");
-        }
-        return data[first];
-    }
-
-    T & back() {
-        if (sz == 0) {
-            throw std::runtime_error("ring buffer is empty");
-        }
-        return data[pos];
-    }
-
-    const T & back() const {
-        if (sz == 0) {
-            throw std::runtime_error("ring buffer is empty");
-        }
-        return data[pos];
-    }
-
-            first = (first + 1) % capacity;
-        } else {
-            sz++;
-        }
-        data[pos] = value;
-        pos = (pos + 1) % capacity;
-    }
-
-        if (sz == 0) {
-            throw std::runtime_error("ring buffer is empty");
-        }
-        T value = data[first];
-        first = (first + 1) % capacity;
-        sz--;
-        return value;
-    }
-
-    //T & operator[](size_t i) {
-    //    if (i >= sz) {
-    //        throw std::runtime_error("ring buffer: index out of bounds");
-    //    }
-    //    return data[(first + i) % capacity];
-    //}
-
-    //const T & at(size_t i) const {
-    //    if (i >= sz) {
-    //        throw std::runtime_error("ring buffer: index out of bounds");
-    //    }
-    //    return data[(first + i) % capacity];
-    //}
-
-    const T & rat(size_t i) const {
-        if (i >= sz) {
-            throw std::runtime_error("ring buffer: index out of bounds");
-        }
-        return data[(first + sz - i - 1) % capacity];
-    }
-
-    std::vector<T> to_vector() const {
-        std::vector<T> result;
-        result.reserve(sz);
-        for (size_t i = 0; i < sz; i++) {
-            result.push_back(data[(first + i) % capacity]);
-        }
-        return result;
-    }
-
-    void clear() {
-        // here only reset the status of the buffer
-        sz = 0;
-        first = 0;
-        pos = 0;
-    }
-
-    bool empty() const {
-        return sz == 0;
-    }
-
-    size_t size() const {
-        return sz;
-    }
-
-    size_t capacity = 0;
-    size_t sz = 0;
-    size_t first = 0;
-    size_t pos = 0;
-    std::vector<T> data;
-};
+std::string llama_format_tensor_shape(const std::vector<int64_t> & ne);
+std::string llama_format_tensor_shape(const struct ggml_tensor * t);
+
+std::string gguf_kv_to_str(const struct gguf_context * ctx_gguf, int i);
examples/talk-llama/llama-kv-cache.cpp
ADDED
@@ -0,0 +1,718 @@
#include "llama-kv-cache.h"

#include "llama-impl.h"
#include "llama-batch.h"
#include "llama-cparams.h"
#include "llama-model.h"

#include <algorithm>
#include <limits>
#include <map>

static const llama_kv_cache_slot_info llama_kv_cache_slot_info_failed{false};

uint32_t llama_kv_cache_get_padding(const struct llama_cparams & cparams) {
    // the FA kernels require padding to avoid extra runtime boundary checks
    return cparams.flash_attn ? 256u : 32u;
}

bool llama_kv_cache_init(
        struct llama_kv_cache & cache,
            const llama_model & model,
          const llama_cparams & cparams,
                    ggml_type   type_k,
                    ggml_type   type_v,
                     uint32_t   kv_size,
                         bool   offload) {
    const struct llama_hparams & hparams = model.hparams;

    const int32_t n_layer = hparams.n_layer;

    cache.has_shift = false;

    cache.recurrent = llama_model_is_recurrent(&model);
    cache.v_trans   = !cache.recurrent && !cparams.flash_attn;
    cache.can_shift = !cache.recurrent && model.arch != LLM_ARCH_DEEPSEEK2; // not supported due to MLA

    LLAMA_LOG_INFO("%s: kv_size = %d, offload = %d, type_k = '%s', type_v = '%s', n_layer = %d, can_shift = %d\n",
            __func__, kv_size, offload, ggml_type_name(type_k), ggml_type_name(type_v), n_layer, cache.can_shift);

    cache.head = 0;
    cache.size = kv_size;
    cache.used = 0;

    cache.type_k = type_k;
    cache.type_v = type_v;

    cache.cells.clear();
    cache.cells.resize(kv_size);

    // create a context for each buffer type
    std::map<ggml_backend_buffer_type_t, ggml_context *> ctx_map;
    auto ctx_for_buft = [&](ggml_backend_buffer_type_t buft) -> ggml_context * {
        auto it = ctx_map.find(buft);
        if (it == ctx_map.end()) {
            struct ggml_init_params params = {
                /*.mem_size   =*/ size_t(2u*n_layer*ggml_tensor_overhead()),
                /*.mem_buffer =*/ NULL,
                /*.no_alloc   =*/ true,
            };
            ggml_context * ctx = ggml_init(params);
            if (!ctx) {
                return nullptr;
            }
            ctx_map[buft] = ctx;
            cache.ctxs.emplace_back(ctx);
            return ctx;
        }
        return it->second;
    };

    cache.k_l.reserve(n_layer);
    cache.v_l.reserve(n_layer);

    for (int i = 0; i < n_layer; i++) {
        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s();
        const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s();

        LLAMA_LOG_DEBUG("%s: layer %d: n_embd_k_gqa = %d, n_embd_v_gqa = %d\n", __func__, i, n_embd_k_gqa, n_embd_v_gqa);

        ggml_backend_buffer_type_t buft;
        if (offload) {
            auto * dev = model.dev_layer.at(i).dev;
            buft = ggml_backend_dev_buffer_type(dev);
        } else {
            buft = ggml_backend_cpu_buffer_type();
        }
        ggml_context * ctx = ctx_for_buft(buft);

        if (!ctx) {
            LLAMA_LOG_ERROR("%s: failed to create ggml context for kv cache\n", __func__);
            return false;
        }

        ggml_tensor * k = ggml_new_tensor_1d(ctx, type_k, n_embd_k_gqa*kv_size);
        ggml_tensor * v = ggml_new_tensor_1d(ctx, type_v, n_embd_v_gqa*kv_size);
        ggml_format_name(k, "cache_k_l%d", i);
        ggml_format_name(v, "cache_v_l%d", i);
        cache.k_l.push_back(k);
        cache.v_l.push_back(v);
    }

    // allocate tensors and initialize the buffers to avoid NaNs in the padding
    for (auto it : ctx_map) {
        auto * buft = it.first;
        auto * ctx  = it.second;

        ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft);
        if (!buf) {
            LLAMA_LOG_ERROR("%s: failed to allocate buffer for kv cache\n", __func__);
            return false;
        }
        ggml_backend_buffer_clear(buf, 0);
        LLAMA_LOG_INFO("%s: %10s KV buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf), ggml_backend_buffer_get_size(buf)/1024.0/1024.0);
        cache.bufs.emplace_back(buf);
    }

    return true;
}

struct llama_kv_cache_slot_info llama_kv_cache_find_slot(
           struct llama_kv_cache & cache,
       const struct llama_ubatch & ubatch) {
    const uint32_t n_tokens     = ubatch.n_tokens;
    const uint32_t n_seqs       = ubatch.n_seqs;
    const uint32_t n_seq_tokens = ubatch.n_seq_tokens;

    if (cache.recurrent) {
        // For recurrent state architectures (like Mamba or RWKV),
        // each cache cell can store the state for a whole sequence.
        // A slot should be always be contiguous.

        // can only process batches with an equal number of new tokens in each sequence
        GGML_ASSERT(ubatch.equal_seqs);

        int32_t min = cache.size - 1;
        int32_t max = 0;

        // everything should fit if all seq_ids are smaller than the max
        for (uint32_t s = 0; s < n_seqs; ++s) {
            const uint32_t n_seq_id = ubatch.n_seq_id[s];
            for (uint32_t j = 0; j < n_seq_id; ++j) {
                const llama_seq_id seq_id = ubatch.seq_id[s][j];

                if (seq_id < 0 || (uint32_t) seq_id >= cache.size) {
                    // too big seq_id
                    // TODO: would it be possible to resize the cache instead?
                    LLAMA_LOG_ERROR("%s: seq_id=%d >= n_seq_max=%d Try using a bigger --parallel value\n", __func__, seq_id, cache.size);
                    return llama_kv_cache_slot_info_failed;
                }
                if (j > 0) {
                    llama_kv_cell & seq = cache.cells[seq_id];
                    if (seq.tail >= 0) {
                        llama_kv_cell & cell = cache.cells[seq.tail];
                        // clear cells from seq_ids that become shared
                        // (should not normally happen, but let's handle it anyway)
                        cell.seq_id.erase(seq_id);
                        seq.tail = -1;
                        if (cell.seq_id.empty()) {
                            cell.pos = -1;
                            cell.src = -1;
                            cache.used -= 1;
                        }
                    }
                }
            }
        }

#ifndef NDEBUG
        {
            std::vector<int32_t> tails_verif;
            tails_verif.assign(cache.size, -1);
            for (uint32_t i = 0; i < cache.size; ++i) {
                llama_kv_cell & cell = cache.cells[i];
                for (llama_seq_id seq_id : cell.seq_id) {
                    if (tails_verif[seq_id] != -1) {
                        LLAMA_LOG_ERROR("%s: duplicate tail for seq_id %d in cell %d and %d\n", __func__, seq_id, i, tails_verif[seq_id]);
                    }
                    tails_verif[seq_id] = i;
                }
            }
            for (uint32_t i = 0; i < cache.size; ++i) {
                if (tails_verif[i] != cache.cells[i].tail) {
                    LLAMA_LOG_ERROR("%s: wrong tail for seq_id %d, (%d instead of %d)\n", __func__, i, cache.cells[i].tail, tails_verif[i]);
                }
            }
        }
#endif

        // find next empty cell
        uint32_t next_empty_cell = cache.head;

        for (uint32_t i = 0; i < cache.size; ++i) {
            if (next_empty_cell >= cache.size) { next_empty_cell -= cache.size; }
            llama_kv_cell & cell = cache.cells[next_empty_cell];
            if (cell.is_empty()) { break; }
            next_empty_cell += 1;
        }

        // find usable cell range
        for (uint32_t s = 0; s < n_seqs; ++s) {
            const llama_seq_id seq_id = ubatch.seq_id[s][0];
            llama_kv_cell & seq_meta = cache.cells[seq_id];
            bool has_cell = false;
            if (seq_meta.tail >= 0) {
                llama_kv_cell & cell = cache.cells[seq_meta.tail];
                GGML_ASSERT(cell.has_seq_id(seq_id));
                // does this seq_id "own" the cell?
                if (cell.seq_id.size() == 1) { has_cell = true; }
            }
            if (!has_cell) {
                llama_kv_cell & empty_cell = cache.cells[next_empty_cell];
                GGML_ASSERT(empty_cell.is_empty());
                // copy old tail into the empty cell
                if (seq_meta.tail >= 0) {
                    llama_kv_cell & orig_cell = cache.cells[seq_meta.tail];
                    empty_cell.pos = orig_cell.pos;
                    empty_cell.src = orig_cell.src;
                    orig_cell.seq_id.erase(seq_id);
                    empty_cell.seq_id.insert(seq_id); // will be overwritten
                }
                seq_meta.tail = next_empty_cell;
                // find next empty cell
                if (s + 1 < n_seqs) {
                    next_empty_cell += 1;
                    for (uint32_t i = 0; i < cache.size; ++i) {
                        if (next_empty_cell >= cache.size) { next_empty_cell -= cache.size; }
                        llama_kv_cell & cell = cache.cells[next_empty_cell];
                        if (cell.is_empty()) { break; }
                        next_empty_cell += 1;
                    }
                }
            }
            if (min > seq_meta.tail) { min = seq_meta.tail; }
            if (max < seq_meta.tail) { max = seq_meta.tail; }
        }

        // gather and re-order
        for (uint32_t s = 0; s < n_seqs; ++s) {
            int32_t dst_id = s + min;
            int32_t src_id = cache.cells[ubatch.seq_id[s][0]].tail;
            if (dst_id != src_id) {
                llama_kv_cell & dst_cell = cache.cells[dst_id];
                llama_kv_cell & src_cell = cache.cells[src_id];

                std::swap(dst_cell.pos, src_cell.pos);
                std::swap(dst_cell.src, src_cell.src);
                std::swap(dst_cell.seq_id, src_cell.seq_id);

                // swap tails (assuming they NEVER overlap)
                for (const llama_seq_id seq_id : src_cell.seq_id) {
                    cache.cells[seq_id].tail = src_id;
                }
                for (const llama_seq_id seq_id : dst_cell.seq_id) {
                    cache.cells[seq_id].tail = dst_id;
                }
            }
        }

        // update the pos of the used seqs
        for (uint32_t s = 0; s < n_seqs; ++s) {
            const llama_pos last_pos = ubatch.pos[n_seq_tokens * s + n_seq_tokens - 1];
            int32_t cell_id = s + min;
            llama_kv_cell & cell = cache.cells[cell_id];

            if (cell.pos >= 0 && last_pos != cell.pos + (llama_pos) n_seq_tokens) {
                // What should happen when the pos backtracks or skips a value?
                // Clearing the state mid-batch would require special-casing which isn't done.
                LLAMA_LOG_WARN("%s: non-consecutive token position %d after %d for sequence %d with %u new tokens\n",
                    __func__, last_pos, cell.pos, ubatch.seq_id[s][0], n_seq_tokens);
            }
            cell.pos = last_pos;
            cell.seq_id.clear();
            for (int32_t j = 0; j < ubatch.n_seq_id[s]; ++j) {
                const llama_seq_id seq_id = ubatch.seq_id[s][j];
                cell.seq_id.insert(seq_id);
                cache.cells[seq_id].tail = cell_id;
            }
        }

        // allow getting the range of used cells, from head to head + n
        cache.head = min;
        cache.n    = max - min + 1;
        cache.used = std::count_if(cache.cells.begin(), cache.cells.end(),
            [](const llama_kv_cell& cell){ return !cell.is_empty(); });

        // sanity check
        return llama_kv_cache_slot_info(cache.n >= n_seqs);
    }
    // otherwise, one cell per token.

    if (n_tokens > cache.size) {
        LLAMA_LOG_ERROR("%s: n_tokens=%d > cache.size=%d\n", __func__, n_tokens, cache.size);
        return llama_kv_cache_slot_info_failed;
    }

    uint32_t n_tested = 0;

    while (true) {
        if (cache.head + n_tokens > cache.size) {
            n_tested += cache.size - cache.head;
            cache.head = 0;
            continue;
        }

        bool found = true;
        for (uint32_t i = 0; i < n_tokens; i++) {
            if (cache.cells[cache.head + i].pos >= 0) {
                found = false;
                cache.head += i + 1;
                n_tested   += i + 1;
                break;
            }
        }

        if (found) {
            break;
        }

        if (n_tested >= cache.size) {
            //LLAMA_LOG_ERROR("%s: failed to find a slot for %d tokens\n", __func__, n_tokens);
            return llama_kv_cache_slot_info_failed;
        }
    }

    for (uint32_t s = 0; s < n_seqs; s++) {
        for (uint32_t i = 0; i < n_seq_tokens; ++i) {
            uint32_t k = s*n_seq_tokens + i;
            cache.cells[cache.head + k].pos = ubatch.pos[k];

            for (int32_t j = 0; j < ubatch.n_seq_id[s]; j++) {
                cache.cells[cache.head + k].seq_id.insert(ubatch.seq_id[s][j]);
            }
        }
    }

    cache.used += n_tokens;

    return llama_kv_cache_slot_info(cache.head, cache.head + n_tokens);
}

uint32_t llama_kv_cache_cell_max(const struct llama_kv_cache & cache) {
    for (uint32_t i = cache.size; i > 0; --i) {
        const llama_kv_cell & cell = cache.cells[i - 1];

        if (cell.pos >= 0 && !cell.is_empty()) {
            return i;
        }
    }

    return 0;
}

void llama_kv_cache_clear(struct llama_kv_cache & cache) {
    for (int32_t i = 0; i < (int32_t) cache.size; ++i) {
        cache.cells[i].pos = -1;
        cache.cells[i].seq_id.clear();
        cache.cells[i].src = -1;
        cache.cells[i].tail = -1;
    }
    cache.head = 0;
    cache.used = 0;

    for (auto & buf : cache.bufs) {
        ggml_backend_buffer_clear(buf.get(), 0);
    }
}

bool llama_kv_cache_seq_rm(
        struct llama_kv_cache & cache,
                 llama_seq_id   seq_id,
                    llama_pos   p0,
                    llama_pos   p1) {
    uint32_t new_head = cache.size;

    if (p0 < 0) p0 = 0;
    if (p1 < 0) p1 = std::numeric_limits<llama_pos>::max();

    // models like Mamba or RWKV can't have a state partially erased
    if (cache.recurrent) {
        if (seq_id >= (int64_t) cache.size) {
            // could be fatal
            return false;
        }
        if (0 <= seq_id) {
            int32_t & tail_id = cache.cells[seq_id].tail;
            if (tail_id >= 0) {
                const llama_kv_cell & cell = cache.cells[tail_id];
                // partial intersection is invalid
                if ((0 < p0 && p0 <= cell.pos) || (0 < p1 && p1 <= cell.pos)) {
                    return false;
                }
                // invalidate tails which will be cleared
                if (p0 <= cell.pos && cell.pos < p1) {
                    tail_id = -1;
                }
            }
        } else {
            // seq_id is negative, then the range should include everything or nothing
            if (p0 != p1 && (p0 != 0 || p1 != std::numeric_limits<llama_pos>::max())) {
                return false;
            }
        }
    }

    for (uint32_t i = 0; i < cache.size; ++i) {
        if (cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) {
            if (seq_id < 0) {
                cache.cells[i].seq_id.clear();
            } else if (cache.cells[i].has_seq_id(seq_id)) {
                cache.cells[i].seq_id.erase(seq_id);
            } else {
                continue;
            }
            if (cache.cells[i].is_empty()) {
                // keep count of the number of used cells
                if (cache.cells[i].pos >= 0) cache.used--;

                cache.cells[i].pos = -1;
                cache.cells[i].src = -1;
                if (new_head == cache.size) new_head = i;
            }
        }
    }

    // If we freed up a slot, set head to it so searching can start there.
    if (new_head != cache.size && new_head < cache.head) cache.head = new_head;

    return true;
}

void llama_kv_cache_seq_cp(
        struct llama_kv_cache & cache,
                 llama_seq_id   seq_id_src,
                 llama_seq_id   seq_id_dst,
                    llama_pos   p0,
                    llama_pos   p1) {
    if (p0 < 0) p0 = 0;
    if (p1 < 0) p1 = std::numeric_limits<llama_pos>::max();

    if (cache.recurrent) {
        if ((uint32_t) seq_id_dst < cache.size && (uint32_t) seq_id_src < cache.size) {
            llama_kv_cell & tail_src = cache.cells[seq_id_src];
            llama_kv_cell & tail_dst = cache.cells[seq_id_dst];
            if (tail_dst.tail >= 0) {
                // clear destination seq_id if it wasn't empty
                llama_kv_cell & cell_dst = cache.cells[tail_dst.tail];

                cell_dst.seq_id.erase(seq_id_dst);
                tail_dst.tail = -1;
                if (cell_dst.seq_id.empty()) {
                    cell_dst.pos   = -1;
                    cell_dst.delta = -1;
                    cell_dst.src   = -1;
                    cache.used -= 1;
                }
            }
            if (tail_src.tail >= 0) {
                llama_kv_cell & cell_src = cache.cells[tail_src.tail];

                cell_src.seq_id.insert(seq_id_dst);
                tail_dst.tail = tail_src.tail;
            }
        }

        return;
    }
    // otherwise, this is the KV cache of a Transformer-like model

    cache.head = 0;

    for (uint32_t i = 0; i < cache.size; ++i) {
        if (cache.cells[i].has_seq_id(seq_id_src) && cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) {
            cache.cells[i].seq_id.insert(seq_id_dst);
        }
    }
}

void llama_kv_cache_seq_keep(struct llama_kv_cache & cache, llama_seq_id seq_id) {
    uint32_t new_head = cache.size;

    for (uint32_t i = 0; i < cache.size; ++i) {
        if (cache.recurrent && (llama_seq_id) i != seq_id) {
            cache.cells[i].tail = -1;
        }
        if (!cache.cells[i].has_seq_id(seq_id)) {
            if (cache.cells[i].pos >= 0) cache.used--;
            cache.cells[i].pos = -1;
            cache.cells[i].src = -1;
            cache.cells[i].seq_id.clear();
            if (new_head == cache.size) new_head = i;
        } else {
            cache.cells[i].seq_id.clear();
            cache.cells[i].seq_id.insert(seq_id);
        }
    }

    // If we freed up a slot, set head to it so searching can start there.
    if (new_head != cache.size && new_head < cache.head) cache.head = new_head;
}

void llama_kv_cache_seq_add(
        struct llama_kv_cache & cache,
                 llama_seq_id   seq_id,
                    llama_pos   p0,
                    llama_pos   p1,
                    llama_pos   delta) {
    uint32_t new_head = cache.size;

    if (p0 < 0) p0 = 0;
    if (p1 < 0) p1 = std::numeric_limits<llama_pos>::max();
    // If there is no range then return early to avoid looping over the cache.
    if (p0 == p1) return;

    if (cache.recurrent) {
        // for Mamba-like or RWKV models, only the pos needs to be shifted
        if (0 <= seq_id && seq_id < (int64_t) cache.size) {
            const int32_t tail_id = cache.cells[seq_id].tail;
            if (tail_id >= 0) {
                llama_kv_cell & cell = cache.cells[tail_id];
                if (cell.has_seq_id(seq_id) && p0 <= cell.pos && cell.pos < p1) {
                    cell.pos += delta;
                }
            }
        }
        return;
    }

    for (uint32_t i = 0; i < cache.size; ++i) {
        if (cache.cells[i].has_seq_id(seq_id) && cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) {
            cache.has_shift = true;
            cache.cells[i].pos   += delta;
            cache.cells[i].delta += delta;

            if (cache.cells[i].pos < 0) {
                if (!cache.cells[i].is_empty()) {
                    cache.used--;
                }
                cache.cells[i].pos = -1;
                cache.cells[i].seq_id.clear();
                if (new_head == cache.size) {
                    new_head = i;
                }
            }
        }
    }

    // If we freed up a slot, set head to it so searching can start there.
    // Otherwise we just start the next search from the beginning.
    cache.head = new_head != cache.size ? new_head : 0;
}

void llama_kv_cache_seq_div(
        struct llama_kv_cache & cache,
                 llama_seq_id   seq_id,
                    llama_pos   p0,
                    llama_pos   p1,
                          int   d) {
    if (p0 < 0) p0 = 0;
    if (p1 < 0) p1 = std::numeric_limits<llama_pos>::max();
    // If there is no range then return early to avoid looping over the cache.
    if (p0 == p1) return;

    if (cache.recurrent) {
        // for Mamba-like or RWKV models, only the pos needs to be changed
        if (0 <= seq_id && seq_id < (int64_t) cache.size) {
            const int32_t tail_id = cache.cells[seq_id].tail;
            if (tail_id >= 0) {
                llama_kv_cell & cell = cache.cells[tail_id];
                if (cell.has_seq_id(seq_id) && p0 <= cell.pos && cell.pos < p1) {
                    cell.pos /= d;
                }
            }
        }
        return;
    }

    for (uint32_t i = 0; i < cache.size; ++i) {
        if (cache.cells[i].has_seq_id(seq_id) && cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) {
            cache.has_shift = true;

            {
                llama_pos p_old = cache.cells[i].pos;
                cache.cells[i].pos   /= d;
                cache.cells[i].delta += cache.cells[i].pos - p_old;
            }
        }
    }
}

llama_pos llama_kv_cache_seq_pos_max(struct llama_kv_cache & cache, llama_seq_id seq_id) {
    llama_pos result = 0;

    for (uint32_t i = 0; i < cache.size; ++i) {
        if (cache.cells[i].has_seq_id(seq_id)) {
            result = std::max(result, cache.cells[i].pos);
        }
    }

    return result;
}

void llama_kv_cache_defrag(struct llama_kv_cache & cache) {
    if (!cache.recurrent) {
        cache.do_defrag = true;
    }
}

int32_t llama_get_kv_cache_token_count(const struct llama_kv_cache & kv) {
    int result = 0;

    for (uint32_t i = 0; i < kv.size; i++) {
        result += kv.cells[i].seq_id.size();
    }

    return result;
}

int32_t llama_get_kv_cache_used_cells(const struct llama_kv_cache & kv) {
    return kv.used;
}

bool llama_kv_cache_can_shift(const struct llama_kv_cache & kv) {
    return kv.can_shift;
}

//
// kv cache view
//

struct llama_kv_cache_view llama_kv_cache_view_init(const struct llama_kv_cache & kv, int32_t n_seq_max) {
    struct llama_kv_cache_view result = {
        /*.n_cells            = */ 0,
        /*.n_seq_max          = */ n_seq_max,
        /*.token_count        = */ 0,
        /*.used_cells         = */ llama_get_kv_cache_used_cells(kv),
        /*.max_contiguous     = */ 0,
        /*.max_contiguous_idx = */ -1,
        /*.cells              = */ nullptr,
        /*.cells_sequences    = */ nullptr,
    };

    return result;
}

void llama_kv_cache_view_free(struct llama_kv_cache_view * view) {
    if (view->cells != nullptr) {
        free(view->cells);
        view->cells = nullptr;
    }
    if (view->cells_sequences != nullptr) {
        free(view->cells_sequences);
        view->cells_sequences = nullptr;
    }
}

void llama_kv_cache_view_update(struct llama_kv_cache_view * view, const struct llama_kv_cache & kv) {
    if (uint32_t(view->n_cells) < kv.size || view->cells == nullptr) {
        view->n_cells = int32_t(kv.size);
        void * p = realloc(view->cells, sizeof(struct llama_kv_cache_view_cell) * view->n_cells);
        GGML_ASSERT(p != nullptr && "Failed to alloc kv_cache_view cells");
        view->cells = (struct llama_kv_cache_view_cell *)p;
        p = realloc(view->cells_sequences, sizeof(llama_seq_id) * view->n_seq_max * view->n_cells);
        GGML_ASSERT(p != nullptr && "Failed to alloc kv_cache_view cells sequences");
        view->cells_sequences = (llama_seq_id *)p;
    }

    const std::vector<llama_kv_cell> & kv_cells = kv.cells;
    llama_kv_cache_view_cell * c_curr = view->cells;
    llama_seq_id * cs_curr = view->cells_sequences;
    int32_t used_cells = 0;
    int32_t token_count = 0;
    int32_t curr_contig_idx = -1;
    uint32_t max_contig = 0;
    int32_t max_contig_idx = -1;

    for (int32_t i = 0; i < int32_t(kv.size); i++, c_curr++, cs_curr += view->n_seq_max) {
        const size_t curr_size = kv_cells[i].seq_id.size();
        token_count += curr_size;
        c_curr->pos = kv_cells[i].pos + kv_cells[i].delta;

        if (curr_size > 0) {
            if (curr_contig_idx >= 0 && uint32_t(i - curr_contig_idx) > max_contig) {
                max_contig = i - curr_contig_idx;
                max_contig_idx = curr_contig_idx;
            }
            curr_contig_idx = -1;
        } else if (curr_contig_idx < 0) {
            curr_contig_idx = i;
        }

        int seq_idx = 0;
        for (const llama_seq_id it : kv_cells[i].seq_id) {
            if (seq_idx >= view->n_seq_max) {
                break;
            }
            cs_curr[seq_idx] = it;
            seq_idx++;
        }
        if (seq_idx != 0) {
            used_cells++;
        }
        for (; seq_idx < view->n_seq_max; seq_idx++) {
            cs_curr[seq_idx] = -1;
        }
    }
    if (curr_contig_idx >= 0 && kv_cells.size() - curr_contig_idx > max_contig) {
        max_contig_idx = curr_contig_idx;
        max_contig = kv_cells.size() - curr_contig_idx;
    }
    view->max_contiguous = max_contig;
    view->max_contiguous_idx = max_contig_idx;
    view->token_count = token_count;
    view->used_cells = used_cells;
    if (uint32_t(used_cells) != kv.used) {
        LLAMA_LOG_ERROR("%s: used cells mismatch. kv_cache says %d but we calculated %d\n",
            __func__, kv.used, used_cells);
    }
}
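For context, llama_kv_cache_get_padding and llama_kv_cache_init above are meant to be called once when a context is created. A minimal sketch of such a call site, assuming a model and cparams object in scope and the GGML_PAD rounding macro from ggml.h; the actual wiring lives in the context code, not in this file:

    // round the context size up to the cache padding, then allocate the per-layer K/V tensors
    const uint32_t padding = llama_kv_cache_get_padding(cparams);   // 256 with flash attention, 32 otherwise
    const uint32_t kv_size = GGML_PAD(cparams.n_ctx, padding);

    llama_kv_cache kv_self;
    if (!llama_kv_cache_init(kv_self, model, cparams, GGML_TYPE_F16, GGML_TYPE_F16, kv_size, /*offload =*/ true)) {
        // allocation of the KV buffers failed
    }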
examples/talk-llama/llama-kv-cache.h
ADDED
@@ -0,0 +1,218 @@
#pragma once

#include "llama.h"

#include "ggml-cpp.h"

#include <set>
#include <vector>

struct llama_kv_cell {
    llama_pos pos   = -1;
    llama_pos delta = 0;
    int32_t   src   = -1; // used by recurrent state models to copy states
    int32_t   tail  = -1;

    std::set<llama_seq_id> seq_id;

    bool has_seq_id(const llama_seq_id & id) const {
        return seq_id.find(id) != seq_id.end();
    }

    bool is_empty() const {
        return seq_id.empty();
    }

    bool is_same_seq(const llama_kv_cell & other) const {
        return seq_id == other.seq_id;
    }
};

// ring-buffer of cached KV data
struct llama_kv_cache {
    bool has_shift = false;
    bool do_defrag = false;
    bool recurrent = false; // with recurrent state models, a cell can hold the state for more than one past token
    bool v_trans   = true;  // the value tensor is transposed
    bool can_shift = false;

    // Note: The value of head isn't only used to optimize searching
    // for a free KV slot. llama_decode_internal also uses it, so it
    // cannot be freely changed after a slot has been allocated.
    uint32_t head = 0;
    uint32_t size = 0;
    uint32_t used = 0; // used cells (i.e. at least one seq_id)

    // computed before each graph build
    uint32_t n = 0;

    ggml_type type_k = GGML_TYPE_F16;
    ggml_type type_v = GGML_TYPE_F16;

    std::vector<llama_kv_cell> cells;

    std::vector<struct ggml_tensor *> k_l; // per layer
    std::vector<struct ggml_tensor *> v_l;

    std::vector<ggml_context_ptr>        ctxs;
    std::vector<ggml_backend_buffer_ptr> bufs;

    size_t total_size() const {
        size_t size = 0;
        for (const auto & buf : bufs) {
            size += ggml_backend_buffer_get_size(buf.get());
        }

        return size;
    }

    // TODO: better data structures to reduce the cost of this operation
    llama_pos max_pos() const {
        llama_pos max_pos = -1;
        for (const auto & cell : cells) {
            max_pos = std::max(max_pos, cell.pos);
        }

        return max_pos;
    }
};

// a structure holds information about the slot found in llama_kv_cache_find_slot
struct llama_kv_cache_slot_info {
    std::pair<uint32_t, uint32_t> boundaries; // slot boundaries [begin, end)
    bool found = false;                       // the slot was found

    explicit llama_kv_cache_slot_info(bool found_) : found{found_} {}
    llama_kv_cache_slot_info(uint32_t begin, uint32_t end) : boundaries{begin, end}, found{true} {}

    operator bool() const { return found; }
};

// TODO: maybe not needed
uint32_t llama_kv_cache_get_padding(const struct llama_cparams & cparams);

bool llama_kv_cache_init(
        struct llama_kv_cache & cache,
            const llama_model & model,
          const llama_cparams & cparams,
                    ggml_type   type_k,
                    ggml_type   type_v,
                     uint32_t   kv_size,
                         bool   offload);

// find an empty slot of size "n_tokens" in the cache
// updates the cache head
// returns a structure holding information about the slot found
// Note: On success, it's important that cache.head points
// to the first cell of the slot.
struct llama_kv_cache_slot_info llama_kv_cache_find_slot(
           struct llama_kv_cache & cache,
       const struct llama_ubatch & batch);

// find how many cells are currently in use
uint32_t llama_kv_cache_cell_max(const struct llama_kv_cache & cache);

void llama_kv_cache_clear(struct llama_kv_cache & cache);

bool llama_kv_cache_seq_rm(
        struct llama_kv_cache & cache,
                 llama_seq_id   seq_id,
                    llama_pos   p0,
                    llama_pos   p1);

void llama_kv_cache_seq_cp(
        struct llama_kv_cache & cache,
                 llama_seq_id   seq_id_src,
                 llama_seq_id   seq_id_dst,
                    llama_pos   p0,
                    llama_pos   p1);

void llama_kv_cache_seq_keep(
        struct llama_kv_cache & cache,
                 llama_seq_id   seq_id);

void llama_kv_cache_seq_add(
        struct llama_kv_cache & cache,
                 llama_seq_id   seq_id,
                    llama_pos   p0,
                    llama_pos   p1,
                    llama_pos   delta);

void llama_kv_cache_seq_div(
        struct llama_kv_cache & cache,
                 llama_seq_id   seq_id,
                    llama_pos   p0,
                    llama_pos   p1,
                          int   d);

llama_pos llama_kv_cache_seq_pos_max(
        struct llama_kv_cache & cache,
                 llama_seq_id   seq_id);

void llama_kv_cache_defrag(struct llama_kv_cache & cache);

int32_t llama_get_kv_cache_token_count(const struct llama_kv_cache & kv);

int32_t llama_get_kv_cache_used_cells(const struct llama_kv_cache & kv);

bool llama_kv_cache_can_shift(const struct llama_kv_cache & kv);

//
// kv cache view
//

struct llama_kv_cache_view llama_kv_cache_view_init(const struct llama_kv_cache & kv, int32_t n_seq_max);

void llama_kv_cache_view_update(struct llama_kv_cache_view * view, const struct llama_kv_cache & kv);

//
// kv cache restore
//

// saves the kv_cache state for future recovery.
// used to rollback llama_kv_cache_find_slot changes.
struct llama_kv_slot_restorer {
    struct llama_kv_cache_state {
        uint32_t head = 0;
        uint32_t n    = 0;
    } old_state;

    // for non-recurrent models only
    // list of slots to restore
    std::vector<std::pair<uint32_t, uint32_t>> slot_boundaries;

    bool do_restore = false;

    explicit llama_kv_slot_restorer(const struct llama_kv_cache & cache) {
        old_state.head = cache.head;
        old_state.n    = cache.n;
    }

    // saves a slot information for future restoration
    void save(const struct llama_kv_cache_slot_info & slot) {
        if (slot) {
            do_restore = true;
            if (slot.boundaries.first != slot.boundaries.second) {
                slot_boundaries.push_back(slot.boundaries);
            }
        }
    }

    // must be explicitly called to restore the kv_cache state
    // and rollback changes from all llama_kv_cache_find_slot calls
    void restore(struct llama_kv_cache & cache) {
        if (do_restore) {
            cache.head = old_state.head;
            cache.n    = old_state.n;

            if (cache.recurrent) { // recurrent models like Mamba or RWKV can't have a state partially erased
                llama_kv_cache_seq_rm(cache, -1, -1, -1);
            } else {
                for (auto & slot : slot_boundaries) {
                    llama_kv_cache_seq_rm(cache, -1, slot.first, slot.second);
                }
            }
        }
    }
};
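The restorer above is meant to bracket llama_kv_cache_find_slot calls while a batch is decoded, so that a failed split does not leave half-claimed cells behind. A minimal sketch of that pattern, assuming a kv_self cache and a vector of llama_ubatch splits provided by the caller (neither is part of this header):

    // claim a slot per ubatch; on failure, roll back everything claimed so far
    static bool decode_ubatches(llama_kv_cache & kv_self, const std::vector<llama_ubatch> & ubatches) {
        llama_kv_slot_restorer restorer(kv_self);            // snapshot head/n before touching the cache

        for (const auto & ubatch : ubatches) {
            const auto slot = llama_kv_cache_find_slot(kv_self, ubatch);
            if (!slot) {                                      // no contiguous room for this ubatch
                restorer.restore(kv_self);                    // undo cells claimed by earlier ubatches
                return false;
            }
            restorer.save(slot);                              // remember the claimed range for a later rollback
            // ... build and compute the graph for this ubatch ...
        }
        return true;
    }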
examples/talk-llama/llama-mmap.cpp
ADDED
@@ -0,0 +1,589 @@
#include "llama-mmap.h"

#include "llama-impl.h"

#include "ggml.h"

#include <cstring>
#include <climits>
#include <stdexcept>

#ifdef __has_include
    #if __has_include(<unistd.h>)
        #include <unistd.h>
        #if defined(_POSIX_MAPPED_FILES)
            #include <sys/mman.h>
            #include <fcntl.h>
        #endif
        #if defined(_POSIX_MEMLOCK_RANGE)
            #include <sys/resource.h>
        #endif
    #endif
#endif

#if defined(_WIN32)
    #define WIN32_LEAN_AND_MEAN
    #ifndef NOMINMAX
        #define NOMINMAX
    #endif
    #include <windows.h>
    #ifndef PATH_MAX
        #define PATH_MAX MAX_PATH
    #endif
    #include <io.h>
#endif

// TODO: consider moving to llama-impl.h if needed in more places
#if defined(_WIN32)
std::string llama_format_win_err(DWORD err) {
    LPSTR buf;
    size_t size = FormatMessageA(FORMAT_MESSAGE_ALLOCATE_BUFFER | FORMAT_MESSAGE_FROM_SYSTEM | FORMAT_MESSAGE_IGNORE_INSERTS,
                                 NULL, err, MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT), (LPSTR)&buf, 0, NULL);
    if (!size) {
        return "FormatMessageA failed";
    }
    std::string ret(buf, size);
    LocalFree(buf);
    return ret;
}
#endif

// llama_file

struct llama_file::impl {
#if defined(_WIN32)
    HANDLE fp_win32;
    std::string GetErrorMessageWin32(DWORD error_code) const {
        std::string ret;
        LPSTR lpMsgBuf = NULL;
        DWORD bufLen = FormatMessageA(FORMAT_MESSAGE_ALLOCATE_BUFFER | FORMAT_MESSAGE_FROM_SYSTEM | FORMAT_MESSAGE_IGNORE_INSERTS,
                                      NULL, error_code, MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT), (LPSTR)&lpMsgBuf, 0, NULL);
        if (!bufLen) {
            ret = format("Win32 error code: %lx", error_code);
        } else {
            ret = lpMsgBuf;
            LocalFree(lpMsgBuf);
        }

        return ret;
    }

    impl(const char * fname, const char * mode) {
        fp = ggml_fopen(fname, mode);
        if (fp == NULL) {
            throw std::runtime_error(format("failed to open %s: %s", fname, strerror(errno)));
        }
        fp_win32 = (HANDLE) _get_osfhandle(_fileno(fp));
        seek(0, SEEK_END);
        size = tell();
        seek(0, SEEK_SET);
    }

    size_t tell() const {
        LARGE_INTEGER li;
        li.QuadPart = 0;
        BOOL ret = SetFilePointerEx(fp_win32, li, &li, FILE_CURRENT);
        if (!ret) {
            throw std::runtime_error(format("read error: %s", GetErrorMessageWin32(GetLastError()).c_str()));
        }

        return li.QuadPart;
    }

    void seek(size_t offset, int whence) const {
        static_assert(SEEK_SET == FILE_BEGIN, "SEEK_SET != FILE_BEGIN");
        static_assert(SEEK_CUR == FILE_CURRENT, "SEEK_CUR != FILE_CURRENT");
        static_assert(SEEK_END == FILE_END, "SEEK_END != FILE_END");

        LARGE_INTEGER li;
        li.QuadPart = offset;
        BOOL ret = SetFilePointerEx(fp_win32, li, NULL, whence);
        if (!ret) {
            throw std::runtime_error(format("read error: %s", GetErrorMessageWin32(GetLastError()).c_str()));
        }
    }

    void read_raw(void * ptr, size_t len) const {
        size_t bytes_read = 0;
        while (bytes_read < len) {
            size_t chunk_size = std::min<size_t>(len - bytes_read, 64*1024*1024);
            DWORD chunk_read = 0;
            BOOL result = ReadFile(fp_win32, reinterpret_cast<char*>(ptr) + bytes_read, chunk_size, &chunk_read, NULL);
            if (!result) {
                throw std::runtime_error(format("read error: %s", GetErrorMessageWin32(GetLastError()).c_str()));
            }
            if (chunk_read < chunk_size || chunk_read == 0) {
                throw std::runtime_error("unexpectedly reached end of file");
            }

            bytes_read += chunk_read;
        }
    }

    uint32_t read_u32() const {
        uint32_t val;
        read_raw(&val, sizeof(val));
        return val;
    }

    void write_raw(const void * ptr, size_t len) const {
        size_t bytes_written = 0;
        while (bytes_written < len) {
            size_t chunk_size = std::min<size_t>(len - bytes_written, 64*1024*1024);
            DWORD chunk_written = 0;
            BOOL result = WriteFile(fp_win32, reinterpret_cast<char const*>(ptr) + bytes_written, chunk_size, &chunk_written, NULL);
            if (!result) {
                throw std::runtime_error(format("write error: %s", GetErrorMessageWin32(GetLastError()).c_str()));
            }
            if (chunk_written < chunk_size || chunk_written == 0) {
                throw std::runtime_error("unexpectedly failed to write bytes");
            }

            bytes_written += chunk_written;
        }
    }

    void write_u32(uint32_t val) const {
        write_raw(&val, sizeof(val));
    }

    ~impl() {
        if (fp) {
            std::fclose(fp);
        }
    }
#else
    impl(const char * fname, const char * mode) {
        fp = ggml_fopen(fname, mode);
        if (fp == NULL) {
            throw std::runtime_error(format("failed to open %s: %s", fname, strerror(errno)));
        }
        seek(0, SEEK_END);
        size = tell();
        seek(0, SEEK_SET);
    }

    size_t tell() const {
        // TODO: this ifdef is never true?
#ifdef _WIN32
        __int64 ret = _ftelli64(fp);
#else
|
| 171 |
+
long ret = std::ftell(fp);
|
| 172 |
+
#endif
|
| 173 |
+
if (ret == -1) {
|
| 174 |
+
throw std::runtime_error(format("ftell error: %s", strerror(errno)));
|
| 175 |
+
}
|
| 176 |
+
|
| 177 |
+
return (size_t) ret;
|
| 178 |
+
}
|
| 179 |
+
|
| 180 |
+
void seek(size_t offset, int whence) const {
|
| 181 |
+
// TODO: this ifdef is never true?
|
| 182 |
+
#ifdef _WIN32
|
| 183 |
+
int ret = _fseeki64(fp, (__int64) offset, whence);
|
| 184 |
+
#else
|
| 185 |
+
int ret = std::fseek(fp, (long) offset, whence);
|
| 186 |
+
#endif
|
| 187 |
+
if (ret != 0) {
|
| 188 |
+
throw std::runtime_error(format("seek error: %s", strerror(errno)));
|
| 189 |
+
}
|
| 190 |
+
}
|
| 191 |
+
|
| 192 |
+
void read_raw(void * ptr, size_t len) const {
|
| 193 |
+
if (len == 0) {
|
| 194 |
+
return;
|
| 195 |
+
}
|
| 196 |
+
errno = 0;
|
| 197 |
+
std::size_t ret = std::fread(ptr, len, 1, fp);
|
| 198 |
+
if (ferror(fp)) {
|
| 199 |
+
throw std::runtime_error(format("read error: %s", strerror(errno)));
|
| 200 |
+
}
|
| 201 |
+
if (ret != 1) {
|
| 202 |
+
throw std::runtime_error("unexpectedly reached end of file");
|
| 203 |
+
}
|
| 204 |
+
}
|
| 205 |
+
|
| 206 |
+
uint32_t read_u32() const {
|
| 207 |
+
uint32_t ret;
|
| 208 |
+
read_raw(&ret, sizeof(ret));
|
| 209 |
+
return ret;
|
| 210 |
+
}
|
| 211 |
+
|
| 212 |
+
void write_raw(const void * ptr, size_t len) const {
|
| 213 |
+
if (len == 0) {
|
| 214 |
+
return;
|
| 215 |
+
}
|
| 216 |
+
errno = 0;
|
| 217 |
+
size_t ret = std::fwrite(ptr, len, 1, fp);
|
| 218 |
+
if (ret != 1) {
|
| 219 |
+
throw std::runtime_error(format("write error: %s", strerror(errno)));
|
| 220 |
+
}
|
| 221 |
+
}
|
| 222 |
+
|
| 223 |
+
void write_u32(uint32_t val) const {
|
| 224 |
+
write_raw(&val, sizeof(val));
|
| 225 |
+
}
|
| 226 |
+
|
| 227 |
+
~impl() {
|
| 228 |
+
if (fp) {
|
| 229 |
+
std::fclose(fp);
|
| 230 |
+
}
|
| 231 |
+
}
|
| 232 |
+
#endif
|
| 233 |
+
|
| 234 |
+
FILE * fp;
|
| 235 |
+
size_t size;
|
| 236 |
+
};
|
| 237 |
+
|
| 238 |
+
llama_file::llama_file(const char * fname, const char * mode) : pimpl(std::make_unique<impl>(fname, mode)) {}
|
| 239 |
+
llama_file::~llama_file() = default;
|
| 240 |
+
|
| 241 |
+
size_t llama_file::tell() const { return pimpl->tell(); }
|
| 242 |
+
size_t llama_file::size() const { return pimpl->size; }
|
| 243 |
+
|
| 244 |
+
int llama_file::file_id() const {
|
| 245 |
+
#ifdef _WIN32
|
| 246 |
+
return _fileno(pimpl->fp);
|
| 247 |
+
#else
|
| 248 |
+
#if defined(fileno)
|
| 249 |
+
return fileno(pimpl->fp);
|
| 250 |
+
#else
|
| 251 |
+
return ::fileno(pimpl->fp);
|
| 252 |
+
#endif
|
| 253 |
+
#endif
|
| 254 |
+
}
|
| 255 |
+
|
| 256 |
+
void llama_file::seek(size_t offset, int whence) const { pimpl->seek(offset, whence); }
|
| 257 |
+
void llama_file::read_raw(void * ptr, size_t len) const { pimpl->read_raw(ptr, len); }
|
| 258 |
+
|
| 259 |
+
uint32_t llama_file::read_u32() const { return pimpl->read_u32(); }
|
| 260 |
+
|
| 261 |
+
void llama_file::write_raw(const void * ptr, size_t len) const { pimpl->write_raw(ptr, len); }
|
| 262 |
+
void llama_file::write_u32(uint32_t val) const { pimpl->write_u32(val); }
|
| 263 |
+
|
| 264 |
+
// llama_mmap
|
| 265 |
+
|
| 266 |
+
struct llama_mmap::impl {
|
| 267 |
+
#ifdef _POSIX_MAPPED_FILES
|
| 268 |
+
std::vector<std::pair<size_t, size_t>> mapped_fragments;
|
| 269 |
+
|
| 270 |
+
impl(struct llama_file * file, size_t prefetch, bool numa) {
|
| 271 |
+
size = file->size();
|
| 272 |
+
int fd = file->file_id();
|
| 273 |
+
int flags = MAP_SHARED;
|
| 274 |
+
if (numa) { prefetch = 0; }
|
| 275 |
+
#ifdef __linux__
|
| 276 |
+
if (posix_fadvise(fd, 0, 0, POSIX_FADV_SEQUENTIAL)) {
|
| 277 |
+
LLAMA_LOG_WARN("warning: posix_fadvise(.., POSIX_FADV_SEQUENTIAL) failed: %s\n",
|
| 278 |
+
strerror(errno));
|
| 279 |
+
}
|
| 280 |
+
if (prefetch) { flags |= MAP_POPULATE; }
|
| 281 |
+
#endif
|
| 282 |
+
addr = mmap(NULL, file->size(), PROT_READ, flags, fd, 0);
|
| 283 |
+
if (addr == MAP_FAILED) {
|
| 284 |
+
throw std::runtime_error(format("mmap failed: %s", strerror(errno)));
|
| 285 |
+
}
|
| 286 |
+
|
| 287 |
+
if (prefetch > 0) {
|
| 288 |
+
if (posix_madvise(addr, std::min(file->size(), prefetch), POSIX_MADV_WILLNEED)) {
|
| 289 |
+
LLAMA_LOG_WARN("warning: posix_madvise(.., POSIX_MADV_WILLNEED) failed: %s\n",
|
| 290 |
+
strerror(errno));
|
| 291 |
+
}
|
| 292 |
+
}
|
| 293 |
+
if (numa) {
|
| 294 |
+
if (posix_madvise(addr, file->size(), POSIX_MADV_RANDOM)) {
|
| 295 |
+
LLAMA_LOG_WARN("warning: posix_madvise(.., POSIX_MADV_RANDOM) failed: %s\n",
|
| 296 |
+
strerror(errno));
|
| 297 |
+
}
|
| 298 |
+
}
|
| 299 |
+
|
| 300 |
+
mapped_fragments.emplace_back(0, file->size());
|
| 301 |
+
}
|
| 302 |
+
|
| 303 |
+
static void align_range(size_t * first, size_t * last, size_t page_size) {
|
| 304 |
+
size_t offset_in_page = *first & (page_size - 1);
|
| 305 |
+
size_t offset_to_page = offset_in_page == 0 ? 0 : page_size - offset_in_page;
|
| 306 |
+
*first += offset_to_page;
|
| 307 |
+
|
| 308 |
+
*last = *last & ~(page_size - 1);
|
| 309 |
+
|
| 310 |
+
if (*last <= *first) {
|
| 311 |
+
*last = *first;
|
| 312 |
+
}
|
| 313 |
+
}
|
| 314 |
+
|
| 315 |
+
void unmap_fragment(size_t first, size_t last) {
|
| 316 |
+
int page_size = sysconf(_SC_PAGESIZE);
|
| 317 |
+
align_range(&first, &last, page_size);
|
| 318 |
+
size_t len = last - first;
|
| 319 |
+
|
| 320 |
+
if (len == 0) {
|
| 321 |
+
return;
|
| 322 |
+
}
|
| 323 |
+
|
| 324 |
+
GGML_ASSERT(first % page_size == 0);
|
| 325 |
+
GGML_ASSERT(last % page_size == 0);
|
| 326 |
+
GGML_ASSERT(last > first);
|
| 327 |
+
|
| 328 |
+
void * next_page_start = (uint8_t *) addr + first;
|
| 329 |
+
|
| 330 |
+
if (munmap(next_page_start, len)) {
|
| 331 |
+
LLAMA_LOG_WARN("warning: munmap failed: %s\n", strerror(errno));
|
| 332 |
+
}
|
| 333 |
+
|
| 334 |
+
std::vector<std::pair<size_t, size_t>> new_mapped_fragments;
|
| 335 |
+
for (const auto & frag : mapped_fragments) {
|
| 336 |
+
if (frag.first < first && frag.second > last) {
|
| 337 |
+
new_mapped_fragments.emplace_back(frag.first, first);
|
| 338 |
+
new_mapped_fragments.emplace_back(last, frag.second);
|
| 339 |
+
} else if (frag.first < first && frag.second > first) {
|
| 340 |
+
new_mapped_fragments.emplace_back(frag.first, first);
|
| 341 |
+
} else if (frag.first < last && frag.second > last) {
|
| 342 |
+
new_mapped_fragments.emplace_back(last, frag.second);
|
| 343 |
+
} else if (frag.first >= first && frag.second <= last) {
|
| 344 |
+
} else {
|
| 345 |
+
new_mapped_fragments.push_back(frag);
|
| 346 |
+
}
|
| 347 |
+
}
|
| 348 |
+
mapped_fragments = std::move(new_mapped_fragments);
|
| 349 |
+
}
|
| 350 |
+
|
| 351 |
+
~impl() {
|
| 352 |
+
for (const auto & frag : mapped_fragments) {
|
| 353 |
+
if (munmap((char *) addr + frag.first, frag.second - frag.first)) {
|
| 354 |
+
LLAMA_LOG_WARN("warning: munmap failed: %s\n", strerror(errno));
|
| 355 |
+
}
|
| 356 |
+
}
|
| 357 |
+
}
|
| 358 |
+
#elif defined(_WIN32)
|
| 359 |
+
impl(struct llama_file * file, size_t prefetch, bool numa) {
|
| 360 |
+
GGML_UNUSED(numa);
|
| 361 |
+
|
| 362 |
+
size = file->size();
|
| 363 |
+
|
| 364 |
+
HANDLE hFile = (HANDLE) _get_osfhandle(file->file_id());
|
| 365 |
+
|
| 366 |
+
HANDLE hMapping = CreateFileMappingA(hFile, NULL, PAGE_READONLY, 0, 0, NULL);
|
| 367 |
+
|
| 368 |
+
if (hMapping == NULL) {
|
| 369 |
+
DWORD error = GetLastError();
|
| 370 |
+
throw std::runtime_error(format("CreateFileMappingA failed: %s", llama_format_win_err(error).c_str()));
|
| 371 |
+
}
|
| 372 |
+
|
| 373 |
+
addr = MapViewOfFile(hMapping, FILE_MAP_READ, 0, 0, 0);
|
| 374 |
+
DWORD error = GetLastError();
|
| 375 |
+
CloseHandle(hMapping);
|
| 376 |
+
|
| 377 |
+
if (addr == NULL) {
|
| 378 |
+
throw std::runtime_error(format("MapViewOfFile failed: %s", llama_format_win_err(error).c_str()));
|
| 379 |
+
}
|
| 380 |
+
|
| 381 |
+
if (prefetch > 0) {
|
| 382 |
+
#if _WIN32_WINNT >= 0x602
|
| 383 |
+
BOOL (WINAPI *pPrefetchVirtualMemory) (HANDLE, ULONG_PTR, PWIN32_MEMORY_RANGE_ENTRY, ULONG);
|
| 384 |
+
HMODULE hKernel32 = GetModuleHandleW(L"kernel32.dll");
|
| 385 |
+
|
| 386 |
+
pPrefetchVirtualMemory = (decltype(pPrefetchVirtualMemory))(void *) GetProcAddress(hKernel32, "PrefetchVirtualMemory");
|
| 387 |
+
|
| 388 |
+
if (pPrefetchVirtualMemory) {
|
| 389 |
+
WIN32_MEMORY_RANGE_ENTRY range;
|
| 390 |
+
range.VirtualAddress = addr;
|
| 391 |
+
range.NumberOfBytes = (SIZE_T) std::min(size, prefetch);
|
| 392 |
+
if (!pPrefetchVirtualMemory(GetCurrentProcess(), 1, &range, 0)) {
|
| 393 |
+
LLAMA_LOG_WARN("warning: PrefetchVirtualMemory failed: %s\n",
|
| 394 |
+
llama_format_win_err(GetLastError()).c_str());
|
| 395 |
+
}
|
| 396 |
+
}
|
| 397 |
+
#else
|
| 398 |
+
throw std::runtime_error("PrefetchVirtualMemory unavailable");
|
| 399 |
+
#endif
|
| 400 |
+
}
|
| 401 |
+
}
|
| 402 |
+
|
| 403 |
+
void unmap_fragment(size_t first, size_t last) {
|
| 404 |
+
GGML_UNUSED(first);
|
| 405 |
+
GGML_UNUSED(last);
|
| 406 |
+
}
|
| 407 |
+
|
| 408 |
+
~impl() {
|
| 409 |
+
if (!UnmapViewOfFile(addr)) {
|
| 410 |
+
LLAMA_LOG_WARN("warning: UnmapViewOfFile failed: %s\n",
|
| 411 |
+
llama_format_win_err(GetLastError()).c_str());
|
| 412 |
+
}
|
| 413 |
+
}
|
| 414 |
+
#else
|
| 415 |
+
impl(struct llama_file * file, size_t prefetch, bool numa) {
|
| 416 |
+
GGML_UNUSED(file);
|
| 417 |
+
GGML_UNUSED(prefetch);
|
| 418 |
+
GGML_UNUSED(numa);
|
| 419 |
+
|
| 420 |
+
throw std::runtime_error("mmap not supported");
|
| 421 |
+
}
|
| 422 |
+
|
| 423 |
+
void unmap_fragment(size_t first, size_t last) {
|
| 424 |
+
GGML_UNUSED(first);
|
| 425 |
+
GGML_UNUSED(last);
|
| 426 |
+
|
| 427 |
+
throw std::runtime_error("mmap not supported");
|
| 428 |
+
}
|
| 429 |
+
#endif
|
| 430 |
+
|
| 431 |
+
void * addr;
|
| 432 |
+
size_t size;
|
| 433 |
+
};
|
| 434 |
+
|
| 435 |
+
llama_mmap::llama_mmap(struct llama_file * file, size_t prefetch, bool numa) : pimpl(std::make_unique<impl>(file, prefetch, numa)) {}
|
| 436 |
+
llama_mmap::~llama_mmap() = default;
|
| 437 |
+
|
| 438 |
+
size_t llama_mmap::size() const { return pimpl->size; }
|
| 439 |
+
void * llama_mmap::addr() const { return pimpl->addr; }
|
| 440 |
+
|
| 441 |
+
void llama_mmap::unmap_fragment(size_t first, size_t last) { pimpl->unmap_fragment(first, last); }
|
| 442 |
+
|
| 443 |
+
#if defined(_POSIX_MEMLOCK_RANGE) || defined(_WIN32)
|
| 444 |
+
const bool llama_mmap::SUPPORTED = true;
|
| 445 |
+
#else
|
| 446 |
+
const bool llama_mmap::SUPPORTED = false;
|
| 447 |
+
#endif
|
| 448 |
+
|
| 449 |
+
// llama_mlock
|
| 450 |
+
|
| 451 |
+
struct llama_mlock::impl {
|
| 452 |
+
#ifdef _POSIX_MEMLOCK_RANGE
|
| 453 |
+
static size_t lock_granularity() {
|
| 454 |
+
return (size_t) sysconf(_SC_PAGESIZE);
|
| 455 |
+
}
|
| 456 |
+
|
| 457 |
+
bool raw_lock(const void * addr, size_t size) const {
|
| 458 |
+
if (!mlock(addr, size)) {
|
| 459 |
+
return true;
|
| 460 |
+
}
|
| 461 |
+
|
| 462 |
+
#ifdef __APPLE__
|
| 463 |
+
#define MLOCK_SUGGESTION \
|
| 464 |
+
"Try increasing the sysctl values 'vm.user_wire_limit' and 'vm.global_user_wire_limit' and/or " \
|
| 465 |
+
"decreasing 'vm.global_no_user_wire_amount'. Also try increasing RLIMIT_MEMLOCK (ulimit -l).\n"
|
| 466 |
+
#else
|
| 467 |
+
#define MLOCK_SUGGESTION \
|
| 468 |
+
"Try increasing RLIMIT_MEMLOCK ('ulimit -l' as root).\n"
|
| 469 |
+
#endif
|
| 470 |
+
|
| 471 |
+
char* errmsg = std::strerror(errno);
|
| 472 |
+
bool suggest = (errno == ENOMEM);
|
| 473 |
+
|
| 474 |
+
struct rlimit lock_limit;
|
| 475 |
+
if (suggest && getrlimit(RLIMIT_MEMLOCK, &lock_limit)) {
|
| 476 |
+
suggest = false;
|
| 477 |
+
}
|
| 478 |
+
if (suggest && (lock_limit.rlim_max > lock_limit.rlim_cur + size)) {
|
| 479 |
+
suggest = false;
|
| 480 |
+
}
|
| 481 |
+
|
| 482 |
+
LLAMA_LOG_WARN("warning: failed to mlock %zu-byte buffer (after previously locking %zu bytes): %s\n%s",
|
| 483 |
+
size, this->size, errmsg, suggest ? MLOCK_SUGGESTION : "");
|
| 484 |
+
return false;
|
| 485 |
+
}
|
| 486 |
+
|
| 487 |
+
static void raw_unlock(void * addr, size_t size) {
|
| 488 |
+
if (munlock(addr, size)) {
|
| 489 |
+
LLAMA_LOG_WARN("warning: failed to munlock buffer: %s\n", std::strerror(errno));
|
| 490 |
+
}
|
| 491 |
+
}
|
| 492 |
+
#elif defined(_WIN32)
|
| 493 |
+
static size_t lock_granularity() {
|
| 494 |
+
SYSTEM_INFO si;
|
| 495 |
+
GetSystemInfo(&si);
|
| 496 |
+
return (size_t) si.dwPageSize;
|
| 497 |
+
}
|
| 498 |
+
|
| 499 |
+
bool raw_lock(void * ptr, size_t len) const {
|
| 500 |
+
for (int tries = 1; ; tries++) {
|
| 501 |
+
if (VirtualLock(ptr, len)) {
|
| 502 |
+
return true;
|
| 503 |
+
}
|
| 504 |
+
if (tries == 2) {
|
| 505 |
+
LLAMA_LOG_WARN("warning: failed to VirtualLock %zu-byte buffer (after previously locking %zu bytes): %s\n",
|
| 506 |
+
len, size, llama_format_win_err(GetLastError()).c_str());
|
| 507 |
+
return false;
|
| 508 |
+
}
|
| 509 |
+
|
| 510 |
+
SIZE_T min_ws_size, max_ws_size;
|
| 511 |
+
if (!GetProcessWorkingSetSize(GetCurrentProcess(), &min_ws_size, &max_ws_size)) {
|
| 512 |
+
LLAMA_LOG_WARN("warning: GetProcessWorkingSetSize failed: %s\n",
|
| 513 |
+
llama_format_win_err(GetLastError()).c_str());
|
| 514 |
+
return false;
|
| 515 |
+
}
|
| 516 |
+
size_t increment = len + 1048576;
|
| 517 |
+
min_ws_size += increment;
|
| 518 |
+
max_ws_size += increment;
|
| 519 |
+
if (!SetProcessWorkingSetSize(GetCurrentProcess(), min_ws_size, max_ws_size)) {
|
| 520 |
+
LLAMA_LOG_WARN("warning: SetProcessWorkingSetSize failed: %s\n",
|
| 521 |
+
llama_format_win_err(GetLastError()).c_str());
|
| 522 |
+
return false;
|
| 523 |
+
}
|
| 524 |
+
}
|
| 525 |
+
}
|
| 526 |
+
|
| 527 |
+
static void raw_unlock(void * ptr, size_t len) {
|
| 528 |
+
if (!VirtualUnlock(ptr, len)) {
|
| 529 |
+
LLAMA_LOG_WARN("warning: failed to VirtualUnlock buffer: %s\n",
|
| 530 |
+
llama_format_win_err(GetLastError()).c_str());
|
| 531 |
+
}
|
| 532 |
+
}
|
| 533 |
+
#else
|
| 534 |
+
static size_t lock_granularity() {
|
| 535 |
+
return (size_t) 65536;
|
| 536 |
+
}
|
| 537 |
+
|
| 538 |
+
bool raw_lock(const void * addr, size_t len) const {
|
| 539 |
+
LLAMA_LOG_WARN("warning: mlock not supported on this system\n");
|
| 540 |
+
return false;
|
| 541 |
+
}
|
| 542 |
+
|
| 543 |
+
static void raw_unlock(const void * addr, size_t len) {}
|
| 544 |
+
#endif
|
| 545 |
+
|
| 546 |
+
impl() : addr(NULL), size(0), failed_already(false) {}
|
| 547 |
+
|
| 548 |
+
void init(void * ptr) {
|
| 549 |
+
GGML_ASSERT(addr == NULL && size == 0);
|
| 550 |
+
addr = ptr;
|
| 551 |
+
}
|
| 552 |
+
|
| 553 |
+
void grow_to(size_t target_size) {
|
| 554 |
+
GGML_ASSERT(addr);
|
| 555 |
+
if (failed_already) {
|
| 556 |
+
return;
|
| 557 |
+
}
|
| 558 |
+
size_t granularity = lock_granularity();
|
| 559 |
+
target_size = (target_size + granularity - 1) & ~(granularity - 1);
|
| 560 |
+
if (target_size > size) {
|
| 561 |
+
if (raw_lock((uint8_t *) addr + size, target_size - size)) {
|
| 562 |
+
size = target_size;
|
| 563 |
+
} else {
|
| 564 |
+
failed_already = true;
|
| 565 |
+
}
|
| 566 |
+
}
|
| 567 |
+
}
|
| 568 |
+
|
| 569 |
+
void * addr;
|
| 570 |
+
size_t size;
|
| 571 |
+
|
| 572 |
+
bool failed_already;
|
| 573 |
+
};
|
| 574 |
+
|
| 575 |
+
llama_mlock::llama_mlock() : pimpl(std::make_unique<impl>()) {}
|
| 576 |
+
llama_mlock::~llama_mlock() = default;
|
| 577 |
+
|
| 578 |
+
void llama_mlock::init(void * ptr) { pimpl->init(ptr); }
|
| 579 |
+
void llama_mlock::grow_to(size_t target_size) { pimpl->grow_to(target_size); }
|
| 580 |
+
|
| 581 |
+
#if defined(_POSIX_MEMLOCK_RANGE) || defined(_WIN32)
|
| 582 |
+
const bool llama_mlock::SUPPORTED = true;
|
| 583 |
+
#else
|
| 584 |
+
const bool llama_mlock::SUPPORTED = false;
|
| 585 |
+
#endif
|
| 586 |
+
|
| 587 |
+
size_t llama_path_max() {
|
| 588 |
+
return PATH_MAX;
|
| 589 |
+
}
|
examples/talk-llama/llama-mmap.h
ADDED
|
@@ -0,0 +1,67 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <memory>
|
| 4 |
+
#include <vector>
|
| 5 |
+
|
| 6 |
+
struct llama_file;
|
| 7 |
+
struct llama_mmap;
|
| 8 |
+
struct llama_mlock;
|
| 9 |
+
|
| 10 |
+
using llama_files = std::vector<std::unique_ptr<llama_file>>;
|
| 11 |
+
using llama_mmaps = std::vector<std::unique_ptr<llama_mmap>>;
|
| 12 |
+
using llama_mlocks = std::vector<std::unique_ptr<llama_mlock>>;
|
| 13 |
+
|
| 14 |
+
struct llama_file {
|
| 15 |
+
llama_file(const char * fname, const char * mode);
|
| 16 |
+
~llama_file();
|
| 17 |
+
|
| 18 |
+
size_t tell() const;
|
| 19 |
+
size_t size() const;
|
| 20 |
+
|
| 21 |
+
int file_id() const; // fileno overload
|
| 22 |
+
|
| 23 |
+
void seek(size_t offset, int whence) const;
|
| 24 |
+
|
| 25 |
+
void read_raw(void * ptr, size_t len) const;
|
| 26 |
+
uint32_t read_u32() const;
|
| 27 |
+
|
| 28 |
+
void write_raw(const void * ptr, size_t len) const;
|
| 29 |
+
void write_u32(uint32_t val) const;
|
| 30 |
+
|
| 31 |
+
private:
|
| 32 |
+
struct impl;
|
| 33 |
+
std::unique_ptr<impl> pimpl;
|
| 34 |
+
};
|
| 35 |
+
|
| 36 |
+
struct llama_mmap {
|
| 37 |
+
llama_mmap(const llama_mmap &) = delete;
|
| 38 |
+
llama_mmap(struct llama_file * file, size_t prefetch = (size_t) -1, bool numa = false);
|
| 39 |
+
~llama_mmap();
|
| 40 |
+
|
| 41 |
+
size_t size() const;
|
| 42 |
+
void * addr() const;
|
| 43 |
+
|
| 44 |
+
void unmap_fragment(size_t first, size_t last);
|
| 45 |
+
|
| 46 |
+
static const bool SUPPORTED;
|
| 47 |
+
|
| 48 |
+
private:
|
| 49 |
+
struct impl;
|
| 50 |
+
std::unique_ptr<impl> pimpl;
|
| 51 |
+
};
|
| 52 |
+
|
| 53 |
+
struct llama_mlock {
|
| 54 |
+
llama_mlock();
|
| 55 |
+
~llama_mlock();
|
| 56 |
+
|
| 57 |
+
void init(void * ptr);
|
| 58 |
+
void grow_to(size_t target_size);
|
| 59 |
+
|
| 60 |
+
static const bool SUPPORTED;
|
| 61 |
+
|
| 62 |
+
private:
|
| 63 |
+
struct impl;
|
| 64 |
+
std::unique_ptr<impl> pimpl;
|
| 65 |
+
};
|
| 66 |
+
|
| 67 |
+
size_t llama_path_max();
|
examples/talk-llama/llama-model-loader.cpp
ADDED
|
@@ -0,0 +1,1010 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#include "llama-model-loader.h"
|
| 2 |
+
|
| 3 |
+
#include "ggml.h"
|
| 4 |
+
|
| 5 |
+
#include <array>
|
| 6 |
+
#include <cinttypes>
|
| 7 |
+
#include <cstring>
|
| 8 |
+
#include <future>
|
| 9 |
+
|
| 10 |
+
const char * llama_file_version_name(llama_fver version) {
|
| 11 |
+
switch (version) {
|
| 12 |
+
case GGUF_FILE_VERSION_V1: return "GGUF V1 (support until nov 2023)";
|
| 13 |
+
case GGUF_FILE_VERSION_V2: return "GGUF V2";
|
| 14 |
+
case GGUF_FILE_VERSION_V3: return "GGUF V3 (latest)";
|
| 15 |
+
}
|
| 16 |
+
|
| 17 |
+
return "unknown";
|
| 18 |
+
}
|
| 19 |
+
|
| 20 |
+
namespace GGUFMeta {
|
| 21 |
+
template <typename T, gguf_type gt_, T (*gfun)(const gguf_context *, const int)>
|
| 22 |
+
struct GKV_Base_Type {
|
| 23 |
+
static constexpr gguf_type gt = gt_;
|
| 24 |
+
|
| 25 |
+
static T getter(const gguf_context * ctx, const int kid) {
|
| 26 |
+
return gfun(ctx, kid);
|
| 27 |
+
}
|
| 28 |
+
};
|
| 29 |
+
|
| 30 |
+
template<typename T> struct GKV_Base;
|
| 31 |
+
|
| 32 |
+
template<> struct GKV_Base<bool >: GKV_Base_Type<bool, GGUF_TYPE_BOOL, gguf_get_val_bool> {};
|
| 33 |
+
template<> struct GKV_Base<uint8_t >: GKV_Base_Type<uint8_t, GGUF_TYPE_UINT8, gguf_get_val_u8 > {};
|
| 34 |
+
template<> struct GKV_Base<uint16_t >: GKV_Base_Type<uint16_t, GGUF_TYPE_UINT16, gguf_get_val_u16 > {};
|
| 35 |
+
template<> struct GKV_Base<uint32_t >: GKV_Base_Type<uint32_t, GGUF_TYPE_UINT32, gguf_get_val_u32 > {};
|
| 36 |
+
template<> struct GKV_Base<uint64_t >: GKV_Base_Type<uint64_t, GGUF_TYPE_UINT64, gguf_get_val_u64 > {};
|
| 37 |
+
template<> struct GKV_Base<int8_t >: GKV_Base_Type<int8_t, GGUF_TYPE_INT8, gguf_get_val_i8 > {};
|
| 38 |
+
template<> struct GKV_Base<int16_t >: GKV_Base_Type<int16_t, GGUF_TYPE_INT16, gguf_get_val_i16 > {};
|
| 39 |
+
template<> struct GKV_Base<int32_t >: GKV_Base_Type<int32_t, GGUF_TYPE_INT32, gguf_get_val_i32 > {};
|
| 40 |
+
template<> struct GKV_Base<int64_t >: GKV_Base_Type<int64_t, GGUF_TYPE_INT64, gguf_get_val_i64 > {};
|
| 41 |
+
template<> struct GKV_Base<float >: GKV_Base_Type<float, GGUF_TYPE_FLOAT32, gguf_get_val_f32 > {};
|
| 42 |
+
template<> struct GKV_Base<double >: GKV_Base_Type<double, GGUF_TYPE_FLOAT64, gguf_get_val_f64 > {};
|
| 43 |
+
template<> struct GKV_Base<const char *>: GKV_Base_Type<const char *, GGUF_TYPE_STRING, gguf_get_val_str > {};
|
| 44 |
+
|
| 45 |
+
template<> struct GKV_Base<std::string> {
|
| 46 |
+
static constexpr gguf_type gt = GGUF_TYPE_STRING;
|
| 47 |
+
|
| 48 |
+
static std::string getter(const gguf_context * ctx, const int kid) {
|
| 49 |
+
return gguf_get_val_str(ctx, kid);
|
| 50 |
+
}
|
| 51 |
+
};
|
| 52 |
+
|
| 53 |
+
struct ArrayInfo {
|
| 54 |
+
const gguf_type gt;
|
| 55 |
+
const size_t length;
|
| 56 |
+
const void * data;
|
| 57 |
+
};
|
| 58 |
+
|
| 59 |
+
template<> struct GKV_Base<ArrayInfo> {
|
| 60 |
+
public:
|
| 61 |
+
static constexpr gguf_type gt = GGUF_TYPE_ARRAY;
|
| 62 |
+
static ArrayInfo getter(const gguf_context *ctx, const int k) {
|
| 63 |
+
return ArrayInfo {
|
| 64 |
+
gguf_get_arr_type(ctx, k),
|
| 65 |
+
size_t(gguf_get_arr_n(ctx, k)),
|
| 66 |
+
gguf_get_arr_data(ctx, k),
|
| 67 |
+
};
|
| 68 |
+
}
|
| 69 |
+
};
|
| 70 |
+
|
| 71 |
+
template<typename T>
|
| 72 |
+
class GKV : public GKV_Base<T> {
|
| 73 |
+
GKV() = delete;
|
| 74 |
+
|
| 75 |
+
public:
|
| 76 |
+
static T get_kv(const gguf_context * ctx, const int k) {
|
| 77 |
+
const enum gguf_type kt = gguf_get_kv_type(ctx, k);
|
| 78 |
+
|
| 79 |
+
if (kt != GKV::gt) {
|
| 80 |
+
throw std::runtime_error(format("key %s has wrong type %s but expected type %s",
|
| 81 |
+
gguf_get_key(ctx, k), gguf_type_name(kt), gguf_type_name(GKV::gt)));
|
| 82 |
+
}
|
| 83 |
+
return GKV::getter(ctx, k);
|
| 84 |
+
}
|
| 85 |
+
|
| 86 |
+
static const char * override_type_to_str(const llama_model_kv_override_type ty) {
|
| 87 |
+
switch (ty) {
|
| 88 |
+
case LLAMA_KV_OVERRIDE_TYPE_BOOL: return "bool";
|
| 89 |
+
case LLAMA_KV_OVERRIDE_TYPE_INT: return "int";
|
| 90 |
+
case LLAMA_KV_OVERRIDE_TYPE_FLOAT: return "float";
|
| 91 |
+
case LLAMA_KV_OVERRIDE_TYPE_STR: return "str";
|
| 92 |
+
}
|
| 93 |
+
return "unknown";
|
| 94 |
+
}
|
| 95 |
+
|
| 96 |
+
static bool validate_override(const llama_model_kv_override_type expected_type, const struct llama_model_kv_override * ovrd) {
|
| 97 |
+
if (!ovrd) { return false; }
|
| 98 |
+
if (ovrd->tag == expected_type) {
|
| 99 |
+
LLAMA_LOG_INFO("%s: Using metadata override (%5s) '%s' = ",
|
| 100 |
+
__func__, override_type_to_str(ovrd->tag), ovrd->key);
|
| 101 |
+
switch (ovrd->tag) {
|
| 102 |
+
case LLAMA_KV_OVERRIDE_TYPE_BOOL: {
|
| 103 |
+
LLAMA_LOG_INFO("%s\n", ovrd->val_bool ? "true" : "false");
|
| 104 |
+
} break;
|
| 105 |
+
case LLAMA_KV_OVERRIDE_TYPE_INT: {
|
| 106 |
+
LLAMA_LOG_INFO("%" PRId64 "\n", ovrd->val_i64);
|
| 107 |
+
} break;
|
| 108 |
+
case LLAMA_KV_OVERRIDE_TYPE_FLOAT: {
|
| 109 |
+
LLAMA_LOG_INFO("%.6f\n", ovrd->val_f64);
|
| 110 |
+
} break;
|
| 111 |
+
case LLAMA_KV_OVERRIDE_TYPE_STR: {
|
| 112 |
+
LLAMA_LOG_INFO("%s\n", ovrd->val_str);
|
| 113 |
+
} break;
|
| 114 |
+
default:
|
| 115 |
+
// Shouldn't be possible to end up here, but just in case...
|
| 116 |
+
throw std::runtime_error(
|
| 117 |
+
format("Unsupported attempt to override %s type for metadata key %s\n",
|
| 118 |
+
override_type_to_str(ovrd->tag), ovrd->key));
|
| 119 |
+
}
|
| 120 |
+
return true;
|
| 121 |
+
}
|
| 122 |
+
LLAMA_LOG_WARN("%s: Warning: Bad metadata override type for key '%s', expected %s but got %s\n",
|
| 123 |
+
__func__, ovrd->key, override_type_to_str(expected_type), override_type_to_str(ovrd->tag));
|
| 124 |
+
return false;
|
| 125 |
+
}
|
| 126 |
+
|
| 127 |
+
template<typename OT>
|
| 128 |
+
static typename std::enable_if<std::is_same<OT, bool>::value, bool>::type
|
| 129 |
+
try_override(OT & target, const struct llama_model_kv_override * ovrd) {
|
| 130 |
+
if (validate_override(LLAMA_KV_OVERRIDE_TYPE_BOOL, ovrd)) {
|
| 131 |
+
target = ovrd->val_bool;
|
| 132 |
+
return true;
|
| 133 |
+
}
|
| 134 |
+
return false;
|
| 135 |
+
}
|
| 136 |
+
|
| 137 |
+
template<typename OT>
|
| 138 |
+
static typename std::enable_if<!std::is_same<OT, bool>::value && std::is_integral<OT>::value, bool>::type
|
| 139 |
+
try_override(OT & target, const struct llama_model_kv_override * ovrd) {
|
| 140 |
+
if (validate_override(LLAMA_KV_OVERRIDE_TYPE_INT, ovrd)) {
|
| 141 |
+
target = ovrd->val_i64;
|
| 142 |
+
return true;
|
| 143 |
+
}
|
| 144 |
+
return false;
|
| 145 |
+
}
|
| 146 |
+
|
| 147 |
+
template<typename OT>
|
| 148 |
+
static typename std::enable_if<std::is_floating_point<OT>::value, bool>::type
|
| 149 |
+
try_override(T & target, const struct llama_model_kv_override * ovrd) {
|
| 150 |
+
if (validate_override(LLAMA_KV_OVERRIDE_TYPE_FLOAT, ovrd)) {
|
| 151 |
+
target = ovrd->val_f64;
|
| 152 |
+
return true;
|
| 153 |
+
}
|
| 154 |
+
return false;
|
| 155 |
+
}
|
| 156 |
+
|
| 157 |
+
template<typename OT>
|
| 158 |
+
static typename std::enable_if<std::is_same<OT, std::string>::value, bool>::type
|
| 159 |
+
try_override(T & target, const struct llama_model_kv_override * ovrd) {
|
| 160 |
+
if (validate_override(LLAMA_KV_OVERRIDE_TYPE_STR, ovrd)) {
|
| 161 |
+
target = ovrd->val_str;
|
| 162 |
+
return true;
|
| 163 |
+
}
|
| 164 |
+
return false;
|
| 165 |
+
}
|
| 166 |
+
|
| 167 |
+
static bool set(const gguf_context * ctx, const int k, T & target, const struct llama_model_kv_override * ovrd = nullptr) {
|
| 168 |
+
if (try_override<T>(target, ovrd)) {
|
| 169 |
+
return true;
|
| 170 |
+
}
|
| 171 |
+
if (k < 0) { return false; }
|
| 172 |
+
target = get_kv(ctx, k);
|
| 173 |
+
return true;
|
| 174 |
+
}
|
| 175 |
+
|
| 176 |
+
static bool set(const gguf_context * ctx, const char * key, T & target, const struct llama_model_kv_override * ovrd = nullptr) {
|
| 177 |
+
return set(ctx, gguf_find_key(ctx, key), target, ovrd);
|
| 178 |
+
}
|
| 179 |
+
|
| 180 |
+
static bool set(const gguf_context * ctx, const std::string & key, T & target, const struct llama_model_kv_override * ovrd = nullptr) {
|
| 181 |
+
return set(ctx, key.c_str(), target, ovrd);
|
| 182 |
+
}
|
| 183 |
+
};
|
| 184 |
+
}
|
| 185 |
+
|
| 186 |
+
template<typename T>
|
| 187 |
+
typename std::enable_if<std::is_integral<T>::value, bool>::type
|
| 188 |
+
llama_model_loader::get_arr_n(const std::string & key, T & result, bool required) {
|
| 189 |
+
const int kid = gguf_find_key(meta.get(), key.c_str());
|
| 190 |
+
|
| 191 |
+
if (kid < 0) {
|
| 192 |
+
if (required) {
|
| 193 |
+
throw std::runtime_error(format("key not found in model: %s", key.c_str()));
|
| 194 |
+
}
|
| 195 |
+
return false;
|
| 196 |
+
}
|
| 197 |
+
|
| 198 |
+
struct GGUFMeta::ArrayInfo arr_info =
|
| 199 |
+
GGUFMeta::GKV<GGUFMeta::ArrayInfo>::get_kv(meta.get(), kid);
|
| 200 |
+
|
| 201 |
+
|
| 202 |
+
result = arr_info.length;
|
| 203 |
+
return true;
|
| 204 |
+
}
|
| 205 |
+
|
| 206 |
+
template<typename T>
|
| 207 |
+
typename std::enable_if<std::is_integral<T>::value, bool>::type
|
| 208 |
+
llama_model_loader::get_arr_n(enum llm_kv kid, T & result, bool required) {
|
| 209 |
+
return get_arr_n(llm_kv(kid), result, required);
|
| 210 |
+
}
|
| 211 |
+
|
| 212 |
+
template bool llama_model_loader::get_arr_n(enum llm_kv kid, uint32_t & result, bool required);
|
| 213 |
+
|
| 214 |
+
template<typename T>
|
| 215 |
+
bool llama_model_loader::get_arr(const std::string & key, std::vector<T> & result, bool required) {
|
| 216 |
+
const int kid = gguf_find_key(meta.get(), key.c_str());
|
| 217 |
+
|
| 218 |
+
if (kid < 0 || gguf_get_kv_type(meta.get(), kid) != GGUF_TYPE_ARRAY) {
|
| 219 |
+
if (required) {
|
| 220 |
+
throw std::runtime_error(format("array key not found in model: %s", key.c_str()));
|
| 221 |
+
}
|
| 222 |
+
return false;
|
| 223 |
+
}
|
| 224 |
+
|
| 225 |
+
struct GGUFMeta::ArrayInfo arr_info =
|
| 226 |
+
GGUFMeta::GKV<GGUFMeta::ArrayInfo>::get_kv(meta.get(), kid);
|
| 227 |
+
|
| 228 |
+
switch (arr_info.gt) {
|
| 229 |
+
case GGUF_TYPE_FLOAT32: GGML_ASSERT((std::is_same<T, float>::value)); break;
|
| 230 |
+
case GGUF_TYPE_INT32: GGML_ASSERT(
|
| 231 |
+
(std::is_same<T, int32_t>::value) ||
|
| 232 |
+
(std::is_same<T, uint32_t>::value)); break;
|
| 233 |
+
default:
|
| 234 |
+
throw std::runtime_error(format("%s is not a float32, int32 array", key.c_str()));
|
| 235 |
+
}
|
| 236 |
+
|
| 237 |
+
result.resize(arr_info.length);
|
| 238 |
+
result.assign((const T*)arr_info.data, (const T *)arr_info.data + arr_info.length);
|
| 239 |
+
|
| 240 |
+
return true;
|
| 241 |
+
}
|
| 242 |
+
|
| 243 |
+
template<typename T, size_t N_MAX>
|
| 244 |
+
bool llama_model_loader::get_arr(const std::string & key, std::array<T, N_MAX> & result, bool required) {
|
| 245 |
+
const int kid = gguf_find_key(meta.get(), key.c_str());
|
| 246 |
+
|
| 247 |
+
if (kid < 0 || gguf_get_kv_type(meta.get(), kid) != GGUF_TYPE_ARRAY) {
|
| 248 |
+
if (required) {
|
| 249 |
+
throw std::runtime_error(format("array key not found in model: %s", key.c_str()));
|
| 250 |
+
}
|
| 251 |
+
return false;
|
| 252 |
+
}
|
| 253 |
+
|
| 254 |
+
struct GGUFMeta::ArrayInfo arr_info =
|
| 255 |
+
GGUFMeta::GKV<GGUFMeta::ArrayInfo>::get_kv(meta.get(), kid);
|
| 256 |
+
|
| 257 |
+
switch (arr_info.gt) {
|
| 258 |
+
case GGUF_TYPE_FLOAT32: GGML_ASSERT((std::is_same<T, float>::value)); break;
|
| 259 |
+
case GGUF_TYPE_INT32: GGML_ASSERT(
|
| 260 |
+
(std::is_same<T, int32_t>::value) ||
|
| 261 |
+
(std::is_same<T, uint32_t>::value)); break;
|
| 262 |
+
default:
|
| 263 |
+
throw std::runtime_error(format("%s is not a float32, int32 array", key.c_str()));
|
| 264 |
+
}
|
| 265 |
+
|
| 266 |
+
if (arr_info.length > N_MAX) {
|
| 267 |
+
throw std::runtime_error(format("array length %u for key %s exceeds max %u", (uint32_t) arr_info.length, key.c_str(), (uint32_t) N_MAX));
|
| 268 |
+
}
|
| 269 |
+
|
| 270 |
+
std::copy((const T*)arr_info.data, (const T *)arr_info.data + arr_info.length, result.begin());
|
| 271 |
+
|
| 272 |
+
return true;
|
| 273 |
+
}
|
| 274 |
+
|
| 275 |
+
template<typename T>
|
| 276 |
+
bool llama_model_loader::get_arr(enum llm_kv kid, T & result, bool required) {
|
| 277 |
+
return get_arr(llm_kv(kid), result, required);
|
| 278 |
+
}
|
| 279 |
+
|
| 280 |
+
template<typename T>
|
| 281 |
+
bool llama_model_loader::get_key(const std::string & key, T & result, bool required) {
|
| 282 |
+
auto it = kv_overrides.find(key);
|
| 283 |
+
|
| 284 |
+
const struct llama_model_kv_override * override =
|
| 285 |
+
it != kv_overrides.end() ? &it->second : nullptr;
|
| 286 |
+
|
| 287 |
+
const bool found = GGUFMeta::GKV<T>::set(meta.get(), key, result, override);
|
| 288 |
+
|
| 289 |
+
if (required && !found) {
|
| 290 |
+
throw std::runtime_error(format("key not found in model: %s", key.c_str()));
|
| 291 |
+
}
|
| 292 |
+
|
| 293 |
+
return found;
|
| 294 |
+
}
|
| 295 |
+
|
| 296 |
+
template<typename T>
|
| 297 |
+
bool llama_model_loader::get_key(enum llm_kv kid, T & result, bool required) {
|
| 298 |
+
return get_key(llm_kv(kid), result, required);
|
| 299 |
+
}
|
| 300 |
+
|
| 301 |
+
template bool llama_model_loader::get_key<bool> (enum llm_kv kid, bool & result, bool required);
|
| 302 |
+
template bool llama_model_loader::get_key<float> (enum llm_kv kid, float & result, bool required);
|
| 303 |
+
template bool llama_model_loader::get_key<uint32_t> (enum llm_kv kid, uint32_t & result, bool required);
|
| 304 |
+
template bool llama_model_loader::get_key<std::string>(enum llm_kv kid, std::string & result, bool required);
|
| 305 |
+
|
| 306 |
+
template<>
|
| 307 |
+
bool llama_model_loader::get_key(enum llm_kv kid, enum llama_pooling_type & result, bool required) {
|
| 308 |
+
uint32_t tmp;
|
| 309 |
+
const bool found = get_key(kid, tmp, required);
|
| 310 |
+
if (found) {
|
| 311 |
+
result = (enum llama_pooling_type) tmp;
|
| 312 |
+
} else {
|
| 313 |
+
result = LLAMA_POOLING_TYPE_UNSPECIFIED;
|
| 314 |
+
}
|
| 315 |
+
return found;
|
| 316 |
+
}
|
| 317 |
+
|
| 318 |
+
// get array of n <= N_MAX elements, or a single element repeated n times
|
| 319 |
+
template<typename T, size_t N_MAX>
|
| 320 |
+
bool llama_model_loader::get_key_or_arr(const std::string & key, std::array<T, N_MAX> & result, uint32_t n, bool required) {
|
| 321 |
+
const int kid = gguf_find_key(meta.get(), key.c_str());
|
| 322 |
+
|
| 323 |
+
if (kid < 0) {
|
| 324 |
+
if (required) {
|
| 325 |
+
throw std::runtime_error(format("key not found in model: %s", key.c_str()));
|
| 326 |
+
}
|
| 327 |
+
return false;
|
| 328 |
+
}
|
| 329 |
+
|
| 330 |
+
if (n > N_MAX) {
|
| 331 |
+
throw std::runtime_error(format("n > N_MAX: %u > %u for key %s", (uint32_t) n, (uint32_t) N_MAX, key.c_str()));
|
| 332 |
+
}
|
| 333 |
+
|
| 334 |
+
if (gguf_get_kv_type(meta.get(), kid) == GGUF_TYPE_ARRAY) {
|
| 335 |
+
struct GGUFMeta::ArrayInfo arr_info =
|
| 336 |
+
GGUFMeta::GKV<GGUFMeta::ArrayInfo>::get_kv(meta.get(), kid);
|
| 337 |
+
|
| 338 |
+
if (n != arr_info.length) {
|
| 339 |
+
throw std::runtime_error(format("key %s has wrong array length; expected %u, got %u", key.c_str(), n, (uint32_t) arr_info.length));
|
| 340 |
+
}
|
| 341 |
+
|
| 342 |
+
return get_arr(key, result, required);
|
| 343 |
+
}
|
| 344 |
+
|
| 345 |
+
T value;
|
| 346 |
+
|
| 347 |
+
bool ok = get_key(key, value, required);
|
| 348 |
+
if (!ok) {
|
| 349 |
+
return false;
|
| 350 |
+
}
|
| 351 |
+
|
| 352 |
+
for (uint32_t i = 0; i < n; i++) {
|
| 353 |
+
result[i] = value;
|
| 354 |
+
}
|
| 355 |
+
|
| 356 |
+
return true;
|
| 357 |
+
}
|
| 358 |
+
|
| 359 |
+
template<typename T>
|
| 360 |
+
bool llama_model_loader::get_key_or_arr(enum llm_kv kid, T & result, uint32_t n, bool required) {
|
| 361 |
+
return get_key_or_arr(llm_kv(kid), result, n, required);
|
| 362 |
+
}
|
| 363 |
+
|
| 364 |
+
// TODO: this is not very clever - figure out something better
|
| 365 |
+
template bool llama_model_loader::get_key_or_arr<std::array<int, 4>>(enum llm_kv kid, std::array<int, 4> & result, uint32_t n, bool required);
|
| 366 |
+
template bool llama_model_loader::get_key_or_arr<std::array<uint32_t, 512>>(enum llm_kv kid, std::array<uint32_t, 512> & result, uint32_t n, bool required);
|
| 367 |
+
|
| 368 |
+
llama_model_loader::llama_model_loader(const std::string & fname, bool use_mmap, bool check_tensors, const struct llama_model_kv_override * param_overrides_p) {
|
| 369 |
+
int trace = 0;
|
| 370 |
+
if (getenv("LLAMA_TRACE")) {
|
| 371 |
+
trace = atoi(getenv("LLAMA_TRACE"));
|
| 372 |
+
}
|
| 373 |
+
|
| 374 |
+
if (param_overrides_p != nullptr) {
|
| 375 |
+
for (const struct llama_model_kv_override * p = param_overrides_p; p->key[0] != 0; p++) {
|
| 376 |
+
kv_overrides.insert({std::string(p->key), *p});
|
| 377 |
+
}
|
| 378 |
+
}
|
| 379 |
+
|
| 380 |
+
struct ggml_context * ctx = NULL;
|
| 381 |
+
struct gguf_init_params params = {
|
| 382 |
+
/*.no_alloc = */ true,
|
| 383 |
+
/*.ctx = */ &ctx,
|
| 384 |
+
};
|
| 385 |
+
|
| 386 |
+
meta.reset(gguf_init_from_file(fname.c_str(), params));
|
| 387 |
+
if (!meta) {
|
| 388 |
+
throw std::runtime_error(format("%s: failed to load model from %s\n", __func__, fname.c_str()));
|
| 389 |
+
}
|
| 390 |
+
|
| 391 |
+
get_key(llm_kv(LLM_KV_GENERAL_ARCHITECTURE), arch_name, false);
|
| 392 |
+
llm_kv = LLM_KV(llm_arch_from_string(arch_name));
|
| 393 |
+
|
| 394 |
+
files.emplace_back(new llama_file(fname.c_str(), "rb"));
|
| 395 |
+
contexts.emplace_back(ctx);
|
| 396 |
+
|
| 397 |
+
// Save tensors data offset of the main file.
|
| 398 |
+
// For subsidiary files, `meta` tensor data offset must not be used,
|
| 399 |
+
// so we build a unified tensors index for weights.
|
| 400 |
+
for (ggml_tensor * cur = ggml_get_first_tensor(ctx); cur; cur = ggml_get_next_tensor(ctx, cur)) {
|
| 401 |
+
std::string tensor_name = std::string(cur->name);
|
| 402 |
+
// make sure there is no duplicated tensor names
|
| 403 |
+
if (weights_map.find(tensor_name) != weights_map.end()) {
|
| 404 |
+
throw std::runtime_error(format("invalid model: tensor '%s' is duplicated", ggml_get_name(cur)));
|
| 405 |
+
}
|
| 406 |
+
n_elements += ggml_nelements(cur);
|
| 407 |
+
n_bytes += ggml_nbytes(cur);
|
| 408 |
+
weights_map.emplace(tensor_name, llama_tensor_weight(files.back().get(), 0, meta.get(), cur));
|
| 409 |
+
}
|
| 410 |
+
uint16_t n_split = 0;
|
| 411 |
+
get_key(llm_kv(LLM_KV_SPLIT_COUNT), n_split, false);
|
| 412 |
+
|
| 413 |
+
// Load additional GGML contexts
|
| 414 |
+
if (n_split > 1) {
|
| 415 |
+
uint16_t idx = 0;
|
| 416 |
+
get_key(llm_kv(LLM_KV_SPLIT_NO), idx);
|
| 417 |
+
if (idx != 0) {
|
| 418 |
+
throw std::runtime_error(format("illegal split file: %d, model must be loaded with the first split", idx));
|
| 419 |
+
}
|
| 420 |
+
|
| 421 |
+
std::vector<char> split_prefix(llama_path_max(), 0);
|
| 422 |
+
if (!llama_split_prefix(split_prefix.data(), split_prefix.size(), fname.c_str(), idx, n_split)) {
|
| 423 |
+
throw std::runtime_error(format("invalid split file: %s", fname.c_str()));
|
| 424 |
+
}
|
| 425 |
+
|
| 426 |
+
if (trace > 0) {
|
| 427 |
+
LLAMA_LOG_INFO("%s: loading additional %d GGUFs\n", __func__, n_split);
|
| 428 |
+
}
|
| 429 |
+
|
| 430 |
+
std::vector<char> split_path(llama_path_max(), 0);
|
| 431 |
+
for (idx = 1; idx < n_split; idx++) {
|
| 432 |
+
llama_split_path(split_path.data(), split_path.size(), split_prefix.data(), idx, n_split);
|
| 433 |
+
|
| 434 |
+
struct gguf_init_params split_params = {
|
| 435 |
+
/*.no_alloc = */ true,
|
| 436 |
+
/*.ctx = */ &ctx,
|
| 437 |
+
};
|
| 438 |
+
gguf_context_ptr ctx_gguf { gguf_init_from_file(split_path.data(), split_params) };
|
| 439 |
+
if (!ctx_gguf) {
|
| 440 |
+
throw std::runtime_error(format("%s: failed to load GGUF split from %s\n", __func__, split_path.data()));
|
| 441 |
+
}
|
| 442 |
+
|
| 443 |
+
files.emplace_back(new llama_file(split_path.data(), "rb"));
|
| 444 |
+
contexts.emplace_back(ctx);
|
| 445 |
+
|
| 446 |
+
// Save tensors data offset info of the shard.
|
| 447 |
+
for (ggml_tensor * cur = ggml_get_first_tensor(ctx); cur; cur = ggml_get_next_tensor(ctx, cur)) {
|
| 448 |
+
std::string tensor_name = std::string(cur->name);
|
| 449 |
+
// make sure there is no duplicated tensor names
|
| 450 |
+
if (weights_map.find(tensor_name) != weights_map.end()) {
|
| 451 |
+
throw std::runtime_error(format("invalid model: tensor '%s' is duplicated", ggml_get_name(cur)));
|
| 452 |
+
}
|
| 453 |
+
n_elements += ggml_nelements(cur);
|
| 454 |
+
n_bytes += ggml_nbytes(cur);
|
| 455 |
+
weights_map.emplace(tensor_name, llama_tensor_weight(files.back().get(), idx, ctx_gguf.get(), cur));
|
| 456 |
+
}
|
| 457 |
+
}
|
| 458 |
+
|
| 459 |
+
get_key(llm_kv(LLM_KV_SPLIT_TENSORS_COUNT), n_tensors);
|
| 460 |
+
|
| 461 |
+
// sanity check
|
| 462 |
+
{
|
| 463 |
+
const int n_tensors_loaded = (int) weights_map.size();
|
| 464 |
+
if (n_tensors != n_tensors_loaded) {
|
| 465 |
+
throw std::runtime_error(format("corrupted model: %d tensors expected but %d found", n_tensors, n_tensors_loaded));
|
| 466 |
+
}
|
| 467 |
+
}
|
| 468 |
+
|
| 469 |
+
LLAMA_LOG_INFO("%s: additional %d GGUFs metadata loaded.\n", __func__, n_split - 1);
|
| 470 |
+
}
|
| 471 |
+
|
| 472 |
+
n_kv = gguf_get_n_kv(meta.get());
|
| 473 |
+
n_tensors = weights_map.size();
|
| 474 |
+
|
| 475 |
+
fver = (enum llama_fver) gguf_get_version(meta.get());
|
| 476 |
+
|
| 477 |
+
LLAMA_LOG_INFO("%s: loaded meta data with %d key-value pairs and %d tensors from %s (version %s)\n",
|
| 478 |
+
__func__, n_kv, n_tensors, fname.c_str(), llama_file_version_name(fver));
|
| 479 |
+
|
| 480 |
+
// determine file type based on the number of tensors for each quantization and print meta data
|
| 481 |
+
// TODO: make optional
|
| 482 |
+
{
|
| 483 |
+
std::map<enum ggml_type, uint32_t> n_type;
|
| 484 |
+
|
| 485 |
+
uint32_t n_type_max = 0;
|
| 486 |
+
enum ggml_type type_max = GGML_TYPE_F32;
|
| 487 |
+
|
| 488 |
+
for (const auto & it : weights_map) {
|
| 489 |
+
const llama_tensor_weight & w = it.second;
|
| 490 |
+
const ggml_tensor * tensor = w.tensor;
|
| 491 |
+
|
| 492 |
+
enum ggml_type type = tensor->type;
|
| 493 |
+
|
| 494 |
+
n_type[type]++;
|
| 495 |
+
|
| 496 |
+
if (n_type_max < n_type[type]) {
|
| 497 |
+
n_type_max = n_type[type];
|
| 498 |
+
type_max = type;
|
| 499 |
+
}
|
| 500 |
+
|
| 501 |
+
if (trace > 0) {
|
| 502 |
+
const uint16_t sid = w.idx;
|
| 503 |
+
LLAMA_LOG_INFO("%s: - tensor split %2d: %32s %-8s [ %s ]\n", __func__, sid, ggml_get_name(tensor), ggml_type_name(type), llama_format_tensor_shape(tensor).c_str());
|
| 504 |
+
}
|
| 505 |
+
}
|
| 506 |
+
|
| 507 |
+
switch (type_max) {
|
| 508 |
+
case GGML_TYPE_F32: ftype = LLAMA_FTYPE_ALL_F32; break;
|
| 509 |
+
case GGML_TYPE_F16: ftype = LLAMA_FTYPE_MOSTLY_F16; break;
|
| 510 |
+
case GGML_TYPE_BF16: ftype = LLAMA_FTYPE_MOSTLY_BF16; break;
|
| 511 |
+
case GGML_TYPE_Q4_0: ftype = LLAMA_FTYPE_MOSTLY_Q4_0; break;
|
| 512 |
+
case GGML_TYPE_Q4_1: ftype = LLAMA_FTYPE_MOSTLY_Q4_1; break;
|
| 513 |
+
case GGML_TYPE_Q5_0: ftype = LLAMA_FTYPE_MOSTLY_Q5_0; break;
|
| 514 |
+
case GGML_TYPE_Q5_1: ftype = LLAMA_FTYPE_MOSTLY_Q5_1; break;
|
| 515 |
+
case GGML_TYPE_Q8_0: ftype = LLAMA_FTYPE_MOSTLY_Q8_0; break;
|
| 516 |
+
case GGML_TYPE_Q2_K: ftype = LLAMA_FTYPE_MOSTLY_Q2_K; break;
|
| 517 |
+
case GGML_TYPE_Q3_K: ftype = LLAMA_FTYPE_MOSTLY_Q3_K_M; break;
|
| 518 |
+
case GGML_TYPE_Q4_K: ftype = LLAMA_FTYPE_MOSTLY_Q4_K_M; break;
|
| 519 |
+
case GGML_TYPE_Q5_K: ftype = LLAMA_FTYPE_MOSTLY_Q5_K_M; break;
|
| 520 |
+
case GGML_TYPE_Q6_K: ftype = LLAMA_FTYPE_MOSTLY_Q6_K; break;
|
| 521 |
+
case GGML_TYPE_TQ1_0: ftype = LLAMA_FTYPE_MOSTLY_TQ1_0; break;
|
| 522 |
+
case GGML_TYPE_TQ2_0: ftype = LLAMA_FTYPE_MOSTLY_TQ2_0; break;
|
| 523 |
+
case GGML_TYPE_IQ2_XXS: ftype = LLAMA_FTYPE_MOSTLY_IQ2_XXS; break;
|
| 524 |
+
case GGML_TYPE_IQ2_XS: ftype = LLAMA_FTYPE_MOSTLY_IQ2_XS; break;
|
| 525 |
+
case GGML_TYPE_IQ2_S: ftype = LLAMA_FTYPE_MOSTLY_IQ2_S; break;
|
| 526 |
+
case GGML_TYPE_IQ3_XXS: ftype = LLAMA_FTYPE_MOSTLY_IQ3_XXS; break;
|
| 527 |
+
case GGML_TYPE_IQ1_S: ftype = LLAMA_FTYPE_MOSTLY_IQ1_S; break;
|
| 528 |
+
case GGML_TYPE_IQ1_M: ftype = LLAMA_FTYPE_MOSTLY_IQ1_M; break;
|
| 529 |
+
case GGML_TYPE_IQ4_NL: ftype = LLAMA_FTYPE_MOSTLY_IQ4_NL; break;
|
| 530 |
+
case GGML_TYPE_IQ4_XS: ftype = LLAMA_FTYPE_MOSTLY_IQ4_XS; break;
|
| 531 |
+
case GGML_TYPE_IQ3_S: ftype = LLAMA_FTYPE_MOSTLY_IQ3_S; break;
|
| 532 |
+
default:
|
| 533 |
+
{
|
| 534 |
+
LLAMA_LOG_WARN("%s: unknown type %s\n", __func__, ggml_type_name(type_max));
|
| 535 |
+
ftype = LLAMA_FTYPE_ALL_F32;
|
| 536 |
+
} break;
|
| 537 |
+
}
|
| 538 |
+
|
| 539 |
+
// this is a way to mark that we have "guessed" the file type
|
| 540 |
+
ftype = (llama_ftype) (ftype | LLAMA_FTYPE_GUESSED);
|
| 541 |
+
|
| 542 |
+
{
|
| 543 |
+
const int kid = gguf_find_key(meta.get(), "general.file_type"); // TODO: use LLM_KV
|
| 544 |
+
if (kid >= 0) {
|
| 545 |
+
ftype = (llama_ftype) gguf_get_val_u32(meta.get(), kid);
|
| 546 |
+
}
|
| 547 |
+
}
|
| 548 |
+
|
| 549 |
+
LLAMA_LOG_INFO("%s: Dumping metadata keys/values. Note: KV overrides do not apply in this output.\n", __func__);
|
| 550 |
+
|
| 551 |
+
for (int i = 0; i < n_kv; i++) {
|
| 552 |
+
const char * name = gguf_get_key(meta.get(), i);
|
| 553 |
+
const enum gguf_type type = gguf_get_kv_type(meta.get(), i);
|
| 554 |
+
            const std::string type_name =
                type == GGUF_TYPE_ARRAY
                ? format("%s[%s,%d]", gguf_type_name(type), gguf_type_name(gguf_get_arr_type(meta.get(), i)), gguf_get_arr_n(meta.get(), i))
                : gguf_type_name(type);

            std::string value = gguf_kv_to_str(meta.get(), i);
            const size_t MAX_VALUE_LEN = 40;
            if (value.size() > MAX_VALUE_LEN) {
                value = format("%s...", value.substr(0, MAX_VALUE_LEN - 3).c_str());
            }
            replace_all(value, "\n", "\\n");

            LLAMA_LOG_INFO("%s: - kv %3d: %42s %-16s = %s\n", __func__, i, name, type_name.c_str(), value.c_str());
        }

        // print type counts
        for (auto & kv : n_type) {
            if (kv.second == 0) {
                continue;
            }

            LLAMA_LOG_INFO("%s: - type %4s: %4d tensors\n", __func__, ggml_type_name(kv.first), kv.second);
        }
    }

    if (!llama_mmap::SUPPORTED) {
        LLAMA_LOG_WARN("%s: mmap is not supported on this platform\n", __func__);
        use_mmap = false;
    }

    this->use_mmap = use_mmap;
    this->check_tensors = check_tensors;
}

std::string llama_model_loader::get_arch_name() const {
    return arch_name;
}

enum llm_arch llama_model_loader::get_arch() const {
    return llm_kv.arch;
}

const llama_model_loader::llama_tensor_weight * llama_model_loader::get_weight(const char * name) const {
    auto pos = weights_map.find(name);
    if (pos != weights_map.end()) {
        return &pos->second;
    }

    return nullptr;
}

const llama_model_loader::llama_tensor_weight & llama_model_loader::require_weight(const char * name) const {
    const llama_tensor_weight * weight = get_weight(name);
    if (!weight) {
        throw std::runtime_error(format("%s: tensor '%s' not found", __func__, name));
    }
    return *weight;
}

struct ggml_tensor * llama_model_loader::get_tensor_meta(const char * name) const {
    const auto * weight = get_weight(name);
    if (!weight) {
        return nullptr;
    }
    return weight->tensor;
}

struct ggml_tensor * llama_model_loader::require_tensor_meta(const std::string & name) const {
    struct ggml_tensor * tensor = get_tensor_meta(name.c_str());
    if (!tensor) {
        throw std::runtime_error(format("%s: tensor '%s' not found", __func__, name.c_str()));
    }
    return tensor;
}

const struct ggml_tensor * llama_model_loader::check_tensor_dims(const std::string & name, const std::vector<int64_t> & ne, bool required) const {
    const struct ggml_tensor * cur = get_tensor_meta(name.c_str());

    if (cur == NULL) {
        if (!required) {
            return NULL;
        }
        throw std::runtime_error(format("%s: tensor '%s' not found", __func__, name.c_str()));
    }

    {
        bool is_ok = true;
        for (size_t i = 0; i < GGML_MAX_DIMS; ++i) {
            if ((i < ne.size() && ne[i] != cur->ne[i]) || (i >= ne.size() && cur->ne[i] != 1)) {
                is_ok = false;
                break;
            }
        }
        if (!is_ok) {
            throw std::runtime_error(
                    format("%s: tensor '%s' has wrong shape; expected %s, got %s",
                        __func__, name.c_str(),
                        llama_format_tensor_shape(ne).c_str(),
                        llama_format_tensor_shape(cur).c_str()));
        }
    }

    return cur;
}

struct ggml_tensor * llama_model_loader::create_tensor(struct ggml_context * ctx, const std::string & name, const std::initializer_list<int64_t> & ne, int flags) {
    const struct ggml_tensor * cur = check_tensor_dims(name, ne, !(flags & TENSOR_NOT_REQUIRED));

    if (cur == NULL) {
        return NULL;
    }

    bool duplicated = flags & TENSOR_DUPLICATED;

    struct ggml_tensor * tensor = ggml_dup_tensor(ctx, cur);
    ggml_set_name(tensor, ggml_get_name(cur));

    if (duplicated) {
        size_data += ggml_nbytes(cur);
    } else {
        n_created++;
    }

    return tensor;

}

struct ggml_tensor * llama_model_loader::create_tensor_as_view(struct ggml_context * ctx, struct ggml_tensor * base, const std::string & name, const std::initializer_list<int64_t> & ne, size_t offset, bool required) {
    const struct ggml_tensor * cur = check_tensor_dims(name, ne, required);

    if (cur == NULL) {
        return NULL;
    }

    if (cur->type != base->type) {
        throw std::runtime_error(format("%s: tensor '%s' has wrong type; expected %s, got %s", __func__, name.c_str(), ggml_type_name(base->type), ggml_type_name(cur->type)));
    }

    std::array<int64_t, GGML_MAX_DIMS> dims;
    for (size_t i = 0; i < GGML_MAX_DIMS; ++i) {
        dims[i] = i < ne.size() ? ne.begin()[i] : 1;
    }

    struct ggml_tensor * tensor = ggml_view_4d(ctx, base,
                                        dims[0], dims[1], dims[2], dims[3],
                                        cur->nb[1], cur->nb[2], cur->nb[3],
                                        offset);

    ggml_set_name(tensor, name.c_str());

    n_created++;

    return tensor;
}

void llama_model_loader::done_getting_tensors() const {
    if (n_created != n_tensors) {
        throw std::runtime_error(format("%s: wrong number of tensors; expected %d, got %d", __func__, n_tensors, n_created));
    }
}

void llama_model_loader::init_mappings(bool prefetch, llama_mlocks * mlock_mmaps) {
    if (use_mmap) {
        mappings.reserve(files.size());
        mmaps_used.reserve(files.size());
        for (const auto & file : files) {
            auto * reg = ggml_backend_dev_backend_reg(ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU));
            auto * is_numa_fn = (decltype(ggml_is_numa) *) ggml_backend_reg_get_proc_address(reg, "ggml_backend_cpu_is_numa");
            std::unique_ptr<llama_mmap> mapping(new llama_mmap(file.get(), prefetch ? -1 : 0, is_numa_fn()));
            mmaps_used.emplace_back(mapping->size(), 0);
            if (mlock_mmaps) {
                std::unique_ptr<llama_mlock> mlock_mmap(new llama_mlock());
                mlock_mmap->init(mapping->addr());
                mlock_mmaps->emplace_back(std::move(mlock_mmap));
            }
            mappings.emplace_back(std::move(mapping));
        }
    }

    // compute the total size of all tensors for progress reporting
    for (const auto & it : weights_map) {
        size_data += ggml_nbytes(it.second.tensor);
    }
}

void llama_model_loader::get_mapping_range(size_t * first, size_t * last, void ** addr, int idx, ggml_context * ctx) const {
    GGML_ASSERT(!mappings.empty());
    const auto & mapping = mappings.at(idx);

    *first = mapping->size();
    *last  = 0;
    *addr = mapping->addr();
    for (ggml_tensor * tensor = ggml_get_first_tensor(ctx); tensor; tensor = ggml_get_next_tensor(ctx, tensor)) {
        const auto * weight = get_weight(ggml_get_name(tensor));
        if (!weight || weight->idx != idx) {
            continue;
        }
        *first = std::min(*first, weight->offs);
        *last  = std::max(*last,  weight->offs + ggml_nbytes(tensor));
    }
}

void llama_model_loader::load_data_for(struct ggml_tensor * cur) const {
    const auto & w = require_weight(ggml_get_name(cur));

    if (use_mmap) {
        const auto & mapping = mappings.at(w.idx);
        if (cur->data == nullptr) {
            cur->data = (uint8_t *)mapping->addr() + w.offs;
        } else {
            memcpy(cur->data, (uint8_t *)mapping->addr() + w.offs, ggml_nbytes(cur));
        }
    } else {
        GGML_ASSERT(cur->data != nullptr);
        GGML_ASSERT(w.idx < files.size());
        const auto & file = files.at(w.idx);
        file->seek(w.offs, SEEK_SET);
        file->read_raw(cur->data, ggml_nbytes(cur));
    }

    if (check_tensors && !ggml_validate_row_data(cur->type, cur->data, ggml_nbytes(cur))) {
        throw std::runtime_error(format("tensor '%s' has invalid data", ggml_get_name(cur)));
    }
}

bool llama_model_loader::load_all_data(
        struct ggml_context * ctx,
        llama_buf_map & bufs,
        llama_mlocks * lmlocks,
        llama_progress_callback progress_callback,
        void * progress_callback_user_data) {
    GGML_ASSERT(size_data != 0 && "call init_mappings() first");

    std::vector<no_init<uint8_t>> read_buf;
    std::vector<std::future<std::pair<ggml_tensor *, bool>>> validation_result;

    // 4 staging buffers for async uploads, each sized 1MB seems to be a good default for single NVMe drives.
    // NVMe raid configurations might require more / larger buffers.
    constexpr size_t n_buffers = 4;
    constexpr size_t buffer_size = 1 * 1024 * 1024; // 1MB

    std::vector<ggml_backend_buffer_t> host_buffers;
    std::vector<ggml_backend_event_t> events;
    std::vector<void *> host_ptrs;
    size_t buffer_idx = 0; // buffer to use for async loads
    ggml_backend_t upload_backend = [&](const char * func) -> ggml_backend_t {
        if (use_mmap || check_tensors) {
            return nullptr;
        }
        // When not using mmaped io use async uploads from pinned memory to GPU memory.
        // First determine if the backend supports the necessary features for async uploads.
        auto * buf = bufs.count(0) ? bufs.at(0) : nullptr;
        if (!buf) {
            LLAMA_LOG_DEBUG("%s: no buffer found for async uploads\n", func);
            return nullptr;
        }

        auto * buft = ggml_backend_buffer_get_type(buf);
        auto * dev = ggml_backend_buft_get_device(buft);
        if (!dev) {
            LLAMA_LOG_DEBUG("%s: no device found for buffer type %s for async uploads\n", func,
                ggml_backend_buft_name(buft));
            return nullptr;
        }

        if (buft != ggml_backend_dev_buffer_type(dev)) {
            LLAMA_LOG_DEBUG("%s: buffer type %s is not the default buffer type for device %s for async uploads\n", func,
                ggml_backend_buft_name(buft), ggml_backend_dev_name(dev));
            return nullptr;
        }

        ggml_backend_dev_props props;
        ggml_backend_dev_get_props(dev, &props);
        if (!props.caps.async || !props.caps.host_buffer || !props.caps.events) {
            LLAMA_LOG_DEBUG("%s: device %s does not support async, host buffers or events\n", func,
                ggml_backend_dev_name(dev));
            return nullptr;
        }

        auto * host_buft = ggml_backend_dev_host_buffer_type(dev);
        if (!host_buft) {
            LLAMA_LOG_DEBUG("%s: no host buffer type found for device %s\n", func,
                ggml_backend_dev_name(dev));
            return nullptr;
        }

        // If the backend is supported, create pinned memory buffers and events for synchronisation.
        for (size_t idx = 0; idx < n_buffers; ++idx) {
            auto * buf = ggml_backend_buft_alloc_buffer(host_buft, buffer_size);
            if (!buf) {
                LLAMA_LOG_DEBUG("%s: failed to allocate host buffer for async uploads for device %s\n", func,
                    ggml_backend_dev_name(dev));
                return nullptr;
            }

            host_buffers.emplace_back(buf);
            host_ptrs.emplace_back(ggml_backend_buffer_get_base(buf));

            auto * event = ggml_backend_event_new(dev);
            if (!event) {
                LLAMA_LOG_DEBUG("%s: failed to create event for async uploads for device %s\n", func,
                    ggml_backend_dev_name(dev));
                return nullptr;
            }

            events.emplace_back(event);
        }

        ggml_backend_t backend = ggml_backend_dev_init(dev, nullptr);
        if (!backend) {
            LLAMA_LOG_DEBUG("%s: failed to initialize backend for device %s for async uploads\n", func,
                ggml_backend_dev_name(dev));
            return nullptr;
        }

        return backend;
    }(__func__);

    if (upload_backend) {
        LLAMA_LOG_DEBUG("%s: using async uploads for device %s, buffer type %s, backend %s\n", __func__,
            ggml_backend_dev_name(ggml_backend_get_device(upload_backend)),
            ggml_backend_buft_name(ggml_backend_buffer_get_type(bufs.at(0))),
            ggml_backend_name(upload_backend));
    }

    for (struct ggml_tensor * cur = ggml_get_first_tensor(ctx); cur != NULL; cur = ggml_get_next_tensor(ctx, cur)) {
        const auto * weight = get_weight(ggml_get_name(cur));
        if (weight == nullptr) {
            // this can happen with split experts models
            continue;
        }

        if (progress_callback) {
            if (!progress_callback((float) size_done / size_data, progress_callback_user_data)) {
                return false;
            }
        }

        size_t n_size = ggml_nbytes(cur);

        if (use_mmap) {
            const auto & mapping = mappings.at(weight->idx);
            ggml_backend_buffer_t buf_mmap = nullptr;
            if (bufs.count(weight->idx)) {
                buf_mmap = bufs.at(weight->idx);
            }
            uint8_t * data = (uint8_t *) mapping->addr() + weight->offs;

            if (check_tensors) {
                validation_result.emplace_back(std::async(std::launch::async, [cur, data, n_size] {
                    return std::make_pair(cur, ggml_validate_row_data(cur->type, data, n_size));
                }));
            }

            GGML_ASSERT(buf_mmap || cur->data); // either we have a buffer to allocate the tensor in, or it is already allocated
            if (buf_mmap && cur->data == nullptr) {
                ggml_backend_tensor_alloc(buf_mmap, cur, data);
                if (lmlocks) {
                    const auto & lmlock = lmlocks->at(weight->idx);
                    lmlock->grow_to(weight->offs + n_size);
                }

                auto & mmap_used = mmaps_used[weight->idx];
                mmap_used.first  = std::min(mmap_used.first,  weight->offs);
                mmap_used.second = std::max(mmap_used.second, weight->offs + n_size);
            } else {
                ggml_backend_tensor_set(cur, data, 0, n_size);
            }
        } else {
            const auto & file = files.at(weight->idx);
            if (ggml_backend_buffer_is_host(cur->buffer)) {
                file->seek(weight->offs, SEEK_SET);
                file->read_raw(cur->data, n_size);
                if (check_tensors) {
                    validation_result.emplace_back(std::async(std::launch::async, [cur, n_size] {
                        return std::make_pair(cur, ggml_validate_row_data(cur->type, cur->data, n_size));
                    }));
                }
            } else {
                // If upload_backend is valid load the tensor in chunks to pinned memory and upload the buffers asynchronously to the GPU.
                if (upload_backend) {
                    file->seek(weight->offs, SEEK_SET);

                    size_t bytes_read = 0;

                    while (bytes_read < n_size) {
                        size_t read_iteration = std::min<size_t>(buffer_size, n_size - bytes_read);

                        ggml_backend_event_synchronize(events[buffer_idx]);
                        file->read_raw(host_ptrs[buffer_idx], read_iteration);
                        ggml_backend_tensor_set_async(upload_backend, cur, host_ptrs[buffer_idx], bytes_read, read_iteration);
                        ggml_backend_event_record(events[buffer_idx], upload_backend);

                        bytes_read += read_iteration;
                        ++buffer_idx;
                        buffer_idx %= n_buffers;
                    }
                } else {
                    read_buf.resize(n_size);
                    file->seek(weight->offs, SEEK_SET);
                    file->read_raw(read_buf.data(), n_size);
                    ggml_backend_tensor_set(cur, read_buf.data(), 0, n_size);
                    if (check_tensors && !ggml_validate_row_data(cur->type, read_buf.data(), n_size)) {
                        throw std::runtime_error(format("tensor '%s' has invalid data", ggml_get_name(cur)));
                    }
                }
            }
        }

        size_done += n_size;
    }

    // free temporary resources used for async uploads
    for (auto * event : events) {
        ggml_backend_event_synchronize(event);
        ggml_backend_event_free(event);
    }
    for (auto * buf : host_buffers) {
        ggml_backend_buffer_free(buf);
    }
    ggml_backend_free(upload_backend);

    // check validation results
    bool validation_failed = false;
    for (auto & future : validation_result) {
        auto result = future.get();
        if (!result.second) {
            LLAMA_LOG_ERROR("%s: tensor '%s' has invalid data\n", __func__, ggml_get_name(result.first));
            validation_failed = true;
        }
    }
    if (validation_failed) {
        throw std::runtime_error("found tensors with invalid data");
    }

    // check if this is the last call and do final cleanup
    if (size_done >= size_data) {
        // unmap offloaded tensors and metadata
        if (use_mmap) {
            for (uint32_t idx = 0; idx < mappings.size(); idx++) {
                const auto & mmap_used = mmaps_used.at(idx);
                auto & mapping = mappings.at(idx);
                mapping->unmap_fragment(0, mmap_used.first);
                if (mmap_used.second != 0) {
                    mapping->unmap_fragment(mmap_used.second, mapping->size());
                }
            }
        }
        if (progress_callback) {
            // Even though the model is done loading, we still honor
            // cancellation since we need to free allocations.
            return progress_callback(1.0f, progress_callback_user_data);
        }
    }

    return true;
}
examples/talk-llama/llama-model-loader.h
ADDED
@@ -0,0 +1,158 @@
#pragma once

#include "llama.h"

#include "llama-impl.h"
#include "llama-arch.h"
#include "llama-mmap.h"

#include "ggml-cpp.h"

#include <cstddef>
#include <map>
#include <stdexcept>
#include <unordered_map>

using llama_buf_map = std::unordered_map<uint32_t, ggml_backend_buffer_t>;

enum llama_fver {
    GGUF_FILE_VERSION_V1 = 1,
    GGUF_FILE_VERSION_V2 = 2,
    GGUF_FILE_VERSION_V3 = 3,
};

const char * llama_file_version_name(llama_fver version);

struct llama_model_loader {
    // Holds information on a model weight
    struct llama_tensor_weight {
        uint16_t idx; // source file index
        size_t offs; // tensor data offset in the original file

        ggml_tensor * tensor;

        llama_tensor_weight(const llama_file * file, uint16_t idx, const struct gguf_context * gguf_ctx, ggml_tensor * tensor) : idx(idx), tensor(tensor) {
            const int tensor_idx = gguf_find_tensor(gguf_ctx, ggml_get_name(tensor));
            if (tensor_idx < 0) {
                throw std::runtime_error(format("tensor '%s' not found in the model", ggml_get_name(tensor)));
            }

            offs = gguf_get_data_offset(gguf_ctx) + gguf_get_tensor_offset(gguf_ctx, tensor_idx);
            if (offs + ggml_nbytes(tensor) < offs || offs + ggml_nbytes(tensor) > file->size()) {
                throw std::runtime_error(format("tensor '%s' data is not within the file bounds, model is corrupted or incomplete", ggml_get_name(tensor)));
            }
        }
    };

    // custom comparator to sort weights more nicely by layer
    struct weight_name_comparer {
        bool operator()(const std::string & a, const std::string & b) const {
            int a_layer = -1;
            int b_layer = -1;
            sscanf(a.c_str(), "blk.%d.", &a_layer);
            sscanf(b.c_str(), "blk.%d.", &b_layer);
            if (a_layer != b_layer) {
                return a_layer < b_layer;
            }
            return a < b;
        }
    };

    static const int TENSOR_NOT_REQUIRED = 1;
    static const int TENSOR_DUPLICATED   = 2;

    int n_kv      = 0;
    int n_tensors = 0;
    int n_created = 0;

    uint64_t n_elements = 0;
    size_t   n_bytes    = 0;

    bool use_mmap = false;
    bool check_tensors;

    llama_files files;
    llama_ftype ftype;
    llama_fver  fver;

    llama_mmaps mappings;

    std::map<std::string, struct llama_tensor_weight, weight_name_comparer> weights_map;
    std::unordered_map<std::string, struct llama_model_kv_override> kv_overrides;

    gguf_context_ptr meta;
    std::vector<ggml_context_ptr> contexts;

    std::string arch_name;
    LLM_KV llm_kv = LLM_KV(LLM_ARCH_UNKNOWN);

    size_t size_done = 0;
    size_t size_data = 0;
    std::vector<std::pair<size_t, size_t>> mmaps_used;

    llama_model_loader(const std::string & fname, bool use_mmap, bool check_tensors, const struct llama_model_kv_override * param_overrides_p);

    template<typename T>
    typename std::enable_if<std::is_integral<T>::value, bool>::type
    get_arr_n(const std::string & key, T & result, bool required = true);

    template<typename T>
    typename std::enable_if<std::is_integral<T>::value, bool>::type
    get_arr_n(enum llm_kv kid, T & result, bool required = true);

    template<typename T>
    bool get_arr(const std::string & key, std::vector<T> & result, bool required = true);

    template<typename T, size_t N_MAX>
    bool get_arr(const std::string & key, std::array<T, N_MAX> & result, bool required = true);

    template<typename T>
    bool get_arr(enum llm_kv kid, T & result, bool required = true);

    template<typename T>
    bool get_key(const std::string & key, T & result, bool required = true);

    template<typename T>
    bool get_key(enum llm_kv kid, T & result, bool required = true);

    template<typename T, size_t N_MAX>
    bool get_key_or_arr(const std::string & key, std::array<T, N_MAX> & result, uint32_t n, bool required = true);

    template<typename T>
    bool get_key_or_arr(enum llm_kv kid, T & result, uint32_t n, bool required = true);

    std::string get_arch_name() const;

    enum llm_arch get_arch() const;

    const llama_tensor_weight * get_weight(const char * name) const;

    const llama_tensor_weight & require_weight(const char * name) const;

    struct ggml_tensor * get_tensor_meta(const char * name) const;

    struct ggml_tensor * require_tensor_meta(const std::string & name) const;

    const struct ggml_tensor * check_tensor_dims(const std::string & name, const std::vector<int64_t> & ne, bool required) const;

    struct ggml_tensor * create_tensor(struct ggml_context * ctx, const std::string & name, const std::initializer_list<int64_t> & ne, int flags = 0);

    struct ggml_tensor * create_tensor_as_view(struct ggml_context * ctx, struct ggml_tensor * base, const std::string & name, const std::initializer_list<int64_t> & ne, size_t offset, bool required = true);

    void done_getting_tensors() const;

    void init_mappings(bool prefetch = true, llama_mlocks * mlock_mmaps = nullptr);

    void get_mapping_range(size_t * first, size_t * last, void ** addr, int idx, ggml_context * ctx) const;

    // for backwards compatibility, does not support ggml-backend
    void load_data_for(struct ggml_tensor * cur) const;

    // Returns false if cancelled by progress_callback
    bool load_all_data(
            struct ggml_context * ctx,
            llama_buf_map & bufs,
            llama_mlocks * lmlocks,
            llama_progress_callback progress_callback,
            void * progress_callback_user_data);
};
examples/talk-llama/llama-model.cpp
ADDED
The diff for this file is too large to render.
See raw diff
examples/talk-llama/llama-model.h
ADDED
@@ -0,0 +1,391 @@
#pragma once

#include "llama.h"
#include "llama-arch.h"
#include "llama-hparams.h"
#include "llama-vocab.h"
#include "llama-mmap.h"

#include "ggml-cpp.h"

#include <vector>

// available models
// TODO: this enum does not follow the enum naming convention
enum llm_type {
    MODEL_UNKNOWN,
    MODEL_14M,
    MODEL_17M,
    MODEL_22M,
    MODEL_33M,
    MODEL_60M,
    MODEL_70M,
    MODEL_80M,
    MODEL_109M,
    MODEL_137M,
    MODEL_160M,
    MODEL_220M,
    MODEL_250M,
    MODEL_270M,
    MODEL_335M,
    MODEL_410M,
    MODEL_450M,
    MODEL_770M,
    MODEL_780M,
    MODEL_0_5B,
    MODEL_1B,
    MODEL_1_3B,
    MODEL_1_4B,
    MODEL_1_5B,
    MODEL_1_6B,
    MODEL_2B,
    MODEL_2_8B,
    MODEL_3B,
    MODEL_4B,
    MODEL_6B,
    MODEL_6_9B,
    MODEL_7B,
    MODEL_8B,
    MODEL_9B,
    MODEL_11B,
    MODEL_12B,
    MODEL_13B,
    MODEL_14B,
    MODEL_15B,
    MODEL_16B,
    MODEL_20B,
    MODEL_30B,
    MODEL_32B,
    MODEL_34B,
    MODEL_35B,
    MODEL_40B,
    MODEL_65B,
    MODEL_70B,
    MODEL_236B,
    MODEL_314B,
    MODEL_671B,
    MODEL_SMALL,
    MODEL_MEDIUM,
    MODEL_LARGE,
    MODEL_XL,
    MODEL_A1_7B,
    MODEL_A2_7B,
    MODEL_8x7B,
    MODEL_8x22B,
    MODEL_16x12B,
    MODEL_10B_128x3_66B,
    MODEL_57B_A14B,
    MODEL_27B,
};

struct llama_layer_posnet {
    // resnet
    struct ggml_tensor * norm1   = nullptr;
    struct ggml_tensor * norm1_b = nullptr;

    struct ggml_tensor * conv1   = nullptr;
    struct ggml_tensor * conv1_b = nullptr;

    struct ggml_tensor * norm2   = nullptr;
    struct ggml_tensor * norm2_b = nullptr;

    struct ggml_tensor * conv2   = nullptr;
    struct ggml_tensor * conv2_b = nullptr;

    // attention
    struct ggml_tensor * attn_norm   = nullptr;
    struct ggml_tensor * attn_norm_b = nullptr;

    struct ggml_tensor * attn_q   = nullptr;
    struct ggml_tensor * attn_q_b = nullptr;

    struct ggml_tensor * attn_k   = nullptr;
    struct ggml_tensor * attn_k_b = nullptr;

    struct ggml_tensor * attn_v   = nullptr;
    struct ggml_tensor * attn_v_b = nullptr;

    struct ggml_tensor * attn_o   = nullptr;
    struct ggml_tensor * attn_o_b = nullptr;

    // normalize
    struct ggml_tensor * norm   = nullptr;
    struct ggml_tensor * norm_b = nullptr;
};

struct llama_layer_convnext {
    struct ggml_tensor * dw   = nullptr;
    struct ggml_tensor * dw_b = nullptr;

    struct ggml_tensor * norm   = nullptr;
    struct ggml_tensor * norm_b = nullptr;

    struct ggml_tensor * pw1   = nullptr;
    struct ggml_tensor * pw1_b = nullptr;

    struct ggml_tensor * pw2   = nullptr;
    struct ggml_tensor * pw2_b = nullptr;

    struct ggml_tensor * gamma = nullptr;
};

struct llama_layer {
    // normalization
    struct ggml_tensor * attn_norm       = nullptr;
    struct ggml_tensor * attn_norm_b     = nullptr;
    struct ggml_tensor * attn_norm_2     = nullptr;
    struct ggml_tensor * attn_norm_2_b   = nullptr;
    struct ggml_tensor * attn_q_norm     = nullptr;
    struct ggml_tensor * attn_q_norm_b   = nullptr;
    struct ggml_tensor * attn_k_norm     = nullptr;
    struct ggml_tensor * attn_k_norm_b   = nullptr;
    struct ggml_tensor * attn_out_norm   = nullptr;
    struct ggml_tensor * attn_out_norm_b = nullptr;
    struct ggml_tensor * attn_q_a_norm   = nullptr;
    struct ggml_tensor * attn_kv_a_norm  = nullptr;
    struct ggml_tensor * attn_sub_norm   = nullptr;
    struct ggml_tensor * attn_post_norm  = nullptr;
    struct ggml_tensor * ffn_sub_norm    = nullptr;
    struct ggml_tensor * attn_norm_cross = nullptr;
    struct ggml_tensor * attn_norm_enc   = nullptr;

    // attention
    struct ggml_tensor * wq        = nullptr;
    struct ggml_tensor * wk        = nullptr;
    struct ggml_tensor * wv        = nullptr;
    struct ggml_tensor * wo        = nullptr;
    struct ggml_tensor * wqkv      = nullptr;
    struct ggml_tensor * wq_a      = nullptr;
    struct ggml_tensor * wq_b      = nullptr;
    struct ggml_tensor * wkv_a_mqa = nullptr;
    struct ggml_tensor * wkv_b     = nullptr;
    struct ggml_tensor * wq_cross  = nullptr;
    struct ggml_tensor * wk_cross  = nullptr;
    struct ggml_tensor * wv_cross  = nullptr;
    struct ggml_tensor * wo_cross  = nullptr;
    struct ggml_tensor * wq_enc    = nullptr;
    struct ggml_tensor * wk_enc    = nullptr;
    struct ggml_tensor * wv_enc    = nullptr;
    struct ggml_tensor * wo_enc    = nullptr;

    // attention bias
    struct ggml_tensor * bq   = nullptr;
    struct ggml_tensor * bk   = nullptr;
    struct ggml_tensor * bv   = nullptr;
    struct ggml_tensor * bo   = nullptr;
    struct ggml_tensor * bqkv = nullptr;

    // relative position bias
    struct ggml_tensor * attn_rel_b       = nullptr;
    struct ggml_tensor * attn_rel_b_enc   = nullptr;
    struct ggml_tensor * attn_rel_b_cross = nullptr;

    // normalization
    struct ggml_tensor * ffn_norm         = nullptr;
    struct ggml_tensor * ffn_norm_b       = nullptr;
    struct ggml_tensor * ffn_post_norm    = nullptr;
    struct ggml_tensor * layer_out_norm   = nullptr;
    struct ggml_tensor * layer_out_norm_b = nullptr;
    struct ggml_tensor * ffn_norm_exps    = nullptr;
    struct ggml_tensor * ffn_norm_enc     = nullptr;

    // ff
    struct ggml_tensor * ffn_gate     = nullptr; // w1
    struct ggml_tensor * ffn_down     = nullptr; // w2
    struct ggml_tensor * ffn_up       = nullptr; // w3
    struct ggml_tensor * ffn_gate_enc = nullptr;
    struct ggml_tensor * ffn_down_enc = nullptr;
    struct ggml_tensor * ffn_up_enc   = nullptr;

    // ff MoE
    struct ggml_tensor * ffn_gate_inp  = nullptr;
    struct ggml_tensor * ffn_gate_exps = nullptr;
    struct ggml_tensor * ffn_down_exps = nullptr;
    struct ggml_tensor * ffn_up_exps   = nullptr;

    // ff shared expert (shexp)
    struct ggml_tensor * ffn_gate_inp_shexp = nullptr;
    struct ggml_tensor * ffn_gate_shexp     = nullptr;
    struct ggml_tensor * ffn_down_shexp     = nullptr;
    struct ggml_tensor * ffn_up_shexp       = nullptr;

    // ff bias
    struct ggml_tensor * ffn_gate_b      = nullptr;
    struct ggml_tensor * ffn_down_b      = nullptr; // b2
    struct ggml_tensor * ffn_up_b        = nullptr; // b3
    struct ggml_tensor * ffn_act         = nullptr;
    struct ggml_tensor * ffn_exp_probs_b = nullptr;

    // mamba proj
    struct ggml_tensor * ssm_in  = nullptr;
    struct ggml_tensor * ssm_x   = nullptr;
    struct ggml_tensor * ssm_dt  = nullptr;
    struct ggml_tensor * ssm_out = nullptr;

    // mamba
    struct ggml_tensor * ssm_conv1d = nullptr;
    struct ggml_tensor * ssm_a      = nullptr;
    struct ggml_tensor * ssm_d      = nullptr;

    // mamba bias
    struct ggml_tensor * ssm_conv1d_b = nullptr;
    struct ggml_tensor * ssm_dt_b     = nullptr;

    // rwkv
    struct ggml_tensor * time_mix_w1     = nullptr;
    struct ggml_tensor * time_mix_w2     = nullptr;
    struct ggml_tensor * time_mix_lerp_x = nullptr;
    struct ggml_tensor * time_mix_lerp_w = nullptr;
    struct ggml_tensor * time_mix_lerp_k = nullptr;
    struct ggml_tensor * time_mix_lerp_v = nullptr;
    struct ggml_tensor * time_mix_lerp_r = nullptr;
    struct ggml_tensor * time_mix_lerp_g = nullptr;

    struct ggml_tensor * time_mix_first      = nullptr;
    struct ggml_tensor * time_mix_decay      = nullptr;
    struct ggml_tensor * time_mix_decay_w1   = nullptr;
    struct ggml_tensor * time_mix_decay_w2   = nullptr;
    struct ggml_tensor * time_mix_key        = nullptr;
    struct ggml_tensor * time_mix_value      = nullptr;
    struct ggml_tensor * time_mix_receptance = nullptr;
    struct ggml_tensor * time_mix_gate       = nullptr;

    struct ggml_tensor * time_mix_ln     = nullptr;
    struct ggml_tensor * time_mix_ln_b   = nullptr;
    struct ggml_tensor * time_mix_output = nullptr;

    struct ggml_tensor * channel_mix_lerp_k = nullptr;
    struct ggml_tensor * channel_mix_lerp_r = nullptr;

    struct ggml_tensor * channel_mix_key        = nullptr;
    struct ggml_tensor * channel_mix_receptance = nullptr;
    struct ggml_tensor * channel_mix_value      = nullptr;

    // long rope factors
    struct ggml_tensor * rope_long  = nullptr;
    struct ggml_tensor * rope_short = nullptr;
    struct ggml_tensor * rope_freqs = nullptr;

    // bitnet scale
    struct ggml_tensor * wq_scale       = nullptr;
    struct ggml_tensor * wk_scale       = nullptr;
    struct ggml_tensor * wv_scale       = nullptr;
    struct ggml_tensor * wo_scale       = nullptr;
    struct ggml_tensor * ffn_gate_scale = nullptr;
    struct ggml_tensor * ffn_up_scale   = nullptr;
    struct ggml_tensor * ffn_down_scale = nullptr;

    struct llama_layer_posnet posnet;

    struct llama_layer_convnext convnext;
};

struct llama_model {
    llm_type type = MODEL_UNKNOWN;
    llm_arch arch = LLM_ARCH_UNKNOWN;

    llama_ftype ftype = LLAMA_FTYPE_ALL_F32;

    std::string name = "n/a";

    llama_hparams hparams = {};
    llama_vocab   vocab;

    struct ggml_tensor * tok_embd   = nullptr;
    struct ggml_tensor * type_embd  = nullptr;
    struct ggml_tensor * pos_embd   = nullptr;
    struct ggml_tensor * tok_norm   = nullptr;
    struct ggml_tensor * tok_norm_b = nullptr;

    struct ggml_tensor * output_norm     = nullptr;
    struct ggml_tensor * output_norm_b   = nullptr;
    struct ggml_tensor * output          = nullptr;
    struct ggml_tensor * output_b        = nullptr;
    struct ggml_tensor * output_norm_enc = nullptr;

    // classifier
    struct ggml_tensor * cls       = nullptr;
    struct ggml_tensor * cls_b     = nullptr;
    struct ggml_tensor * cls_out   = nullptr;
    struct ggml_tensor * cls_out_b = nullptr;

    struct ggml_tensor * conv1d   = nullptr;
    struct ggml_tensor * conv1d_b = nullptr;

    std::vector<llama_layer> layers;

    // gguf metadata
    std::unordered_map<std::string, std::string> gguf_kv;

    llama_split_mode split_mode;
    int main_gpu;
    int n_gpu_layers;

    std::vector<std::string> rpc_servers;

    // list of devices used in this model
    std::vector<ggml_backend_dev_t> devices;


    // lists of buffer types used for each layer
    using buft_list_t = std::vector<std::pair<ggml_backend_dev_t, ggml_backend_buffer_type_t>>;
    buft_list_t cpu_buft_list;
    std::map<ggml_backend_dev_t, buft_list_t> gpu_buft_list;

    struct layer_dev {
        ggml_backend_dev_t dev;
        buft_list_t * buft_list;
    };

    layer_dev dev_input = {};
    layer_dev dev_output = {};
    std::vector<layer_dev> dev_layer;

    // contexts where the model tensors metadata is stored
    std::vector<ggml_context_ptr> ctxs;

    // the model memory buffers for the tensor data
    std::vector<ggml_backend_buffer_ptr> bufs;

    // model memory mapped files
    llama_mmaps mappings;

    // objects representing data potentially being locked in memory
    llama_mlocks mlock_bufs;
    llama_mlocks mlock_mmaps;

    // for quantize-stats only
    std::vector<std::pair<std::string, struct ggml_tensor *>> tensors_by_name;

    int64_t t_load_us  = 0;
    int64_t t_start_us = 0;

    // total number of parameters in the model
    uint64_t n_elements = 0;

    // total size of all the tensors in the model in bytes
    size_t n_bytes = 0;
};

const char * llm_type_name(llm_type type);

std::string llama_model_arch_name (const llama_model & model);
std::string llama_model_type_name (const llama_model & model);
std::string llama_model_ftype_name(const llama_model & model);

// used by llama_adapter_cvec
ggml_backend_buffer_type_t llama_model_select_buft(const llama_model & model, int il);

// used by llama_adapter_lora
struct ggml_tensor * llama_model_get_tensor(const struct llama_model & model, const char * name);

size_t llama_model_max_nodes(const llama_model & model);

struct llama_model_loader;

// TODO: become llama_model methods
void llm_load_stats     (llama_model_loader & ml, llama_model & model);
void llm_load_arch      (llama_model_loader & ml, llama_model & model);
void llm_load_hparams   (llama_model_loader & ml, llama_model & model);
void llm_load_vocab     (llama_model_loader & ml, llama_model & model);
void llm_load_print_meta(llama_model_loader & ml, llama_model & model);
examples/talk-llama/llama-quant.cpp
ADDED
@@ -0,0 +1,929 @@
#include "llama-quant.h"

#include "llama-impl.h"
#include "llama-model.h"
#include "llama-model-loader.h"

#include <algorithm>
#include <cmath>
#include <cstring>
#include <fstream>
#include <mutex>
#include <thread>
#include <unordered_map>

// TODO: replace with ggml API call
#define QK_K 256

static void zeros(std::ofstream & file, size_t n) {
    char zero = 0;
    for (size_t i = 0; i < n; ++i) {
        file.write(&zero, 1);
    }
}

struct quantize_state_impl {
    const llama_model                 & model;
    const llama_model_quantize_params * params;

    int n_attention_wv = 0;
    int n_ffn_down     = 0;
    int n_ffn_gate     = 0;
    int n_ffn_up       = 0;
    int i_attention_wv = 0;
    int i_ffn_down     = 0;
    int i_ffn_gate     = 0;
    int i_ffn_up       = 0;

    int n_k_quantized = 0;
    int n_fallback    = 0;

    bool has_imatrix = false;

    // used to figure out if a model shares tok_embd with the output weight
    bool has_output = false;

    quantize_state_impl(const llama_model & model, const llama_model_quantize_params * params)
        : model(model)
        , params(params)
        {}
};

static void llama_tensor_dequantize_impl(
    struct ggml_tensor * tensor, std::vector<no_init<float>> & output, std::vector<std::thread> & workers,
    const size_t nelements, const int nthread
) {
    if (output.size() < nelements) {
        output.resize(nelements);
    }
    float * f32_output = (float *) output.data();

    const ggml_type_traits * qtype = ggml_get_type_traits(tensor->type);
    if (ggml_is_quantized(tensor->type)) {
        if (qtype->to_float == NULL) {
            throw std::runtime_error(format("type %s unsupported for integer quantization: no dequantization available", ggml_type_name(tensor->type)));
        }
    } else if (tensor->type != GGML_TYPE_F16 &&
               tensor->type != GGML_TYPE_BF16) {
        throw std::runtime_error(format("cannot dequantize/convert tensor type %s", ggml_type_name(tensor->type)));
    }

    if (nthread < 2) {
        if (tensor->type == GGML_TYPE_F16) {
            ggml_fp16_to_fp32_row((ggml_fp16_t *)tensor->data, f32_output, nelements);
        } else if (tensor->type == GGML_TYPE_BF16) {
            ggml_bf16_to_fp32_row((ggml_bf16_t *)tensor->data, f32_output, nelements);
        } else if (ggml_is_quantized(tensor->type)) {
            qtype->to_float(tensor->data, f32_output, nelements);
        } else {
            GGML_ABORT("fatal error"); // unreachable
        }
        return;
    }

    size_t block_size;
    if (tensor->type == GGML_TYPE_F16 ||
        tensor->type == GGML_TYPE_BF16) {
        block_size = 1;
    } else {
        block_size = (size_t)ggml_blck_size(tensor->type);
    }

    size_t block_size_bytes = ggml_type_size(tensor->type);

    GGML_ASSERT(nelements % block_size == 0);
    size_t nblocks = nelements / block_size;
    size_t blocks_per_thread = nblocks / nthread;
    size_t spare_blocks = nblocks - (blocks_per_thread * nthread); // if blocks aren't divisible by thread count

    size_t in_buff_offs = 0;
    size_t out_buff_offs = 0;

    for (int tnum = 0; tnum < nthread; tnum++) {
        size_t thr_blocks = blocks_per_thread + (tnum == nthread - 1 ? spare_blocks : 0); // num blocks for this thread
        size_t thr_elems = thr_blocks * block_size; // number of elements for this thread
        size_t thr_block_bytes = thr_blocks * block_size_bytes; // number of input bytes for this thread

        auto compute = [qtype] (ggml_type typ, uint8_t * inbuf, float * outbuf, int nels) {
            if (typ == GGML_TYPE_F16) {
                ggml_fp16_to_fp32_row((ggml_fp16_t *)inbuf, outbuf, nels);
            } else if (typ == GGML_TYPE_BF16) {
                ggml_bf16_to_fp32_row((ggml_bf16_t *)inbuf, outbuf, nels);
            } else {
                qtype->to_float(inbuf, outbuf, nels);
            }
        };
        workers.emplace_back(compute, tensor->type, (uint8_t *) tensor->data + in_buff_offs, f32_output + out_buff_offs, thr_elems);
        in_buff_offs += thr_block_bytes;
        out_buff_offs += thr_elems;
    }
    for (auto & w : workers) { w.join(); }
    workers.clear();
}

static ggml_type llama_tensor_get_type(quantize_state_impl & qs, ggml_type new_type, const ggml_tensor * tensor, llama_ftype ftype) {
    const std::string name = ggml_get_name(tensor);

    // TODO: avoid hardcoded tensor names - use the TN_* constants
    const llm_arch arch = qs.model.arch;
    const auto       tn = LLM_TN(arch);

    auto use_more_bits = [](int i_layer, int n_layers) -> bool {
        return i_layer < n_layers/8 || i_layer >= 7*n_layers/8 || (i_layer - n_layers/8)%3 == 2;
    };
    const int n_expert = std::max(1, (int)qs.model.hparams.n_expert);
    auto layer_info = [n_expert] (int i_layer, int n_layer, const char * name) {
        if (n_expert > 1) {
            // Believe it or not, "experts" in the FFN of Mixtral-8x7B are not consecutive, but occasionally randomly
            // sprinkled in the model. Hence, simply dividing i_ffn_down by n_expert does not work
            // for getting the current layer as I initially thought, and we need to resort to parsing the
            // tensor name.
            if (sscanf(name, "blk.%d.", &i_layer) != 1) {
                throw std::runtime_error(format("Failed to determine layer for tensor %s", name));
            }
            if (i_layer < 0 || i_layer >= n_layer) {
                throw std::runtime_error(format("Bad layer %d for tensor %s. Must be in [0, %d)", i_layer, name, n_layer));
            }
        }
        return std::make_pair(i_layer, n_layer);
    };

    // for arches that share the same tensor between the token embeddings and the output, we quantize the token embeddings
    // with the quantization of the output tensor
    if (name == tn(LLM_TENSOR_OUTPUT, "weight") || (!qs.has_output && name == tn(LLM_TENSOR_TOKEN_EMBD, "weight"))) {
        if (qs.params->output_tensor_type < GGML_TYPE_COUNT) {
            new_type = qs.params->output_tensor_type;
        } else {
            int nx = tensor->ne[0];
            if (arch == LLM_ARCH_FALCON || nx % QK_K != 0) {
                new_type = GGML_TYPE_Q8_0;
            }
            else if (ftype == LLAMA_FTYPE_MOSTLY_IQ2_XXS || ftype == LLAMA_FTYPE_MOSTLY_IQ2_XS || ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS ||
                     ftype == LLAMA_FTYPE_MOSTLY_IQ1_S   || ftype == LLAMA_FTYPE_MOSTLY_IQ2_S  || ftype == LLAMA_FTYPE_MOSTLY_IQ2_M   ||
                     ftype == LLAMA_FTYPE_MOSTLY_IQ1_M) {
                new_type = GGML_TYPE_Q5_K;
            }
            else if (new_type != GGML_TYPE_Q8_0) {
                new_type = GGML_TYPE_Q6_K;
            }
        }
    } else if (name == "token_embd.weight") {
        if (qs.params->token_embedding_type < GGML_TYPE_COUNT) {
            new_type = qs.params->token_embedding_type;
        } else {
            if (ftype == LLAMA_FTYPE_MOSTLY_IQ2_XXS || ftype == LLAMA_FTYPE_MOSTLY_IQ2_XS ||
                ftype == LLAMA_FTYPE_MOSTLY_IQ1_S   || ftype == LLAMA_FTYPE_MOSTLY_IQ1_M) {
                new_type = GGML_TYPE_Q2_K;
            }
            else if (ftype == LLAMA_FTYPE_MOSTLY_IQ2_S || ftype == LLAMA_FTYPE_MOSTLY_IQ2_M) {
                new_type = GGML_TYPE_IQ3_S;
            }
            else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS) {
                new_type = GGML_TYPE_IQ3_S;
            }
            else if (ftype == LLAMA_FTYPE_MOSTLY_TQ1_0 || ftype == LLAMA_FTYPE_MOSTLY_TQ2_0) {
                new_type = GGML_TYPE_Q4_K;
            }
        }
    } else if (ftype == LLAMA_FTYPE_MOSTLY_IQ2_XXS || ftype == LLAMA_FTYPE_MOSTLY_IQ2_XS || ftype == LLAMA_FTYPE_MOSTLY_IQ1_S ||
               ftype == LLAMA_FTYPE_MOSTLY_IQ2_S  || ftype == LLAMA_FTYPE_MOSTLY_IQ2_M  || ftype == LLAMA_FTYPE_MOSTLY_IQ1_M) {
        if (name.find("attn_v.weight") != std::string::npos) {
            if (qs.model.hparams.n_gqa() >= 4 || qs.model.hparams.n_expert >= 4) new_type = GGML_TYPE_Q4_K;
            else new_type = ftype == LLAMA_FTYPE_MOSTLY_IQ2_S || ftype == LLAMA_FTYPE_MOSTLY_IQ2_M ? GGML_TYPE_IQ3_S : GGML_TYPE_Q2_K;
            ++qs.i_attention_wv;
        }
        else if (qs.model.hparams.n_expert == 8 && name.find("attn_k.weight") != std::string::npos) {
            new_type = GGML_TYPE_Q4_K;
        }
        else if (name.find("ffn_down") != std::string::npos) {
            if (qs.i_ffn_down < qs.n_ffn_down/8) {
                new_type = ftype == LLAMA_FTYPE_MOSTLY_IQ2_S || ftype == LLAMA_FTYPE_MOSTLY_IQ2_M ? GGML_TYPE_IQ3_S : GGML_TYPE_Q2_K;
            }
            ++qs.i_ffn_down;
        }
        else if (name.find("attn_output.weight") != std::string::npos) {
            if (qs.model.hparams.n_expert == 8) {
                new_type = GGML_TYPE_Q5_K;
            } else {
                if (ftype == LLAMA_FTYPE_MOSTLY_IQ1_S || ftype == LLAMA_FTYPE_MOSTLY_IQ1_M) new_type = GGML_TYPE_IQ2_XXS;
                else if (ftype == LLAMA_FTYPE_MOSTLY_IQ2_S || ftype == LLAMA_FTYPE_MOSTLY_IQ2_M) new_type = GGML_TYPE_IQ3_S;
            }
        }
    } else if (name.find("attn_v.weight") != std::string::npos) {
        if      (ftype == LLAMA_FTYPE_MOSTLY_Q2_K) {
            new_type = qs.model.hparams.n_gqa() >= 4 ? GGML_TYPE_Q4_K : GGML_TYPE_Q3_K;
        }
        else if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K_S && qs.model.hparams.n_gqa() >= 4) {
            new_type = GGML_TYPE_Q4_K;
        }
        else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS) {
            new_type = qs.model.hparams.n_gqa() >= 4 ? GGML_TYPE_Q4_K : !qs.has_imatrix ? GGML_TYPE_IQ3_S : GGML_TYPE_IQ3_XXS;
        }
        else if ((ftype == LLAMA_FTYPE_MOSTLY_IQ3_XS || ftype == LLAMA_FTYPE_MOSTLY_IQ3_S) && qs.model.hparams.n_gqa() >= 4) {
            new_type = GGML_TYPE_Q4_K;
        }
        else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_M) {
            new_type = GGML_TYPE_Q4_K;
        }
        else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M) {
            new_type = qs.i_attention_wv < 2 ? GGML_TYPE_Q5_K : GGML_TYPE_Q4_K;
        }
        else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_L) new_type = GGML_TYPE_Q5_K;
        else if ((ftype == LLAMA_FTYPE_MOSTLY_IQ4_NL || ftype == LLAMA_FTYPE_MOSTLY_IQ4_XS) && qs.model.hparams.n_gqa() >= 4) {
            new_type = GGML_TYPE_Q5_K;
        }
        else if ((ftype == LLAMA_FTYPE_MOSTLY_Q4_K_M || ftype == LLAMA_FTYPE_MOSTLY_Q5_K_M) &&
                use_more_bits(qs.i_attention_wv, qs.n_attention_wv)) new_type = GGML_TYPE_Q6_K;
        else if (ftype == LLAMA_FTYPE_MOSTLY_Q4_K_S && qs.i_attention_wv < 4) new_type = GGML_TYPE_Q5_K;
        if (qs.model.type == MODEL_70B) {
            // In the 70B model we have 8 heads sharing the same attn_v weights. As a result, the attn_v.weight tensor is
            // 8x smaller compared to attn_q.weight. Hence, we can get a nice boost in quantization accuracy with
            // nearly negligible increase in model size by quantizing this tensor with more bits:
            if (new_type == GGML_TYPE_Q3_K || new_type == GGML_TYPE_Q4_K) new_type = GGML_TYPE_Q5_K;
        }
        if (qs.model.hparams.n_expert == 8) {
            // for the 8-expert model, bumping this to Q8_0 trades just ~128MB
            // TODO: explore better strategies
            new_type = GGML_TYPE_Q8_0;
        }
        ++qs.i_attention_wv;
    } else if (name.find("attn_k.weight") != std::string::npos) {
        if (qs.model.hparams.n_expert == 8) {
            // for the 8-expert model, bumping this to Q8_0 trades just ~128MB
            // TODO: explore better strategies
            new_type = GGML_TYPE_Q8_0;
        }
        else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XS) {
            new_type = GGML_TYPE_IQ3_XXS;
        }
        else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS) {
            new_type = GGML_TYPE_IQ2_S;
        }
    } else if (name.find("attn_q.weight") != std::string::npos) {
        if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XS) {
            new_type = GGML_TYPE_IQ3_XXS;
        }
        else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS) {
            new_type = GGML_TYPE_IQ2_S;
        }
    } else if (name.find("ffn_down") != std::string::npos) {
        auto info = layer_info(qs.i_ffn_down, qs.n_ffn_down, name.c_str());
        int i_layer = info.first, n_layer = info.second;
        if      (ftype == LLAMA_FTYPE_MOSTLY_Q2_K) new_type = GGML_TYPE_Q3_K;
        else if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K_S) {
+
if (i_layer < n_layer/8) new_type = GGML_TYPE_Q4_K;
|
| 275 |
+
}
|
| 276 |
+
else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS && !qs.has_imatrix) {
|
| 277 |
+
new_type = i_layer < n_layer/8 ? GGML_TYPE_Q4_K : GGML_TYPE_Q3_K;
|
| 278 |
+
}
|
| 279 |
+
else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M) {
|
| 280 |
+
new_type = i_layer < n_layer/16 ? GGML_TYPE_Q5_K
|
| 281 |
+
: arch != LLM_ARCH_FALCON || use_more_bits(i_layer, n_layer) ? GGML_TYPE_Q4_K
|
| 282 |
+
: GGML_TYPE_Q3_K;
|
| 283 |
+
}
|
| 284 |
+
else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_M && (i_layer < n_layer/8 ||
|
| 285 |
+
(qs.model.hparams.n_expert == 8 && use_more_bits(i_layer, n_layer)))) {
|
| 286 |
+
new_type = GGML_TYPE_Q4_K;
|
| 287 |
+
}
|
| 288 |
+
else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_L) {
|
| 289 |
+
new_type = arch == LLM_ARCH_FALCON ? GGML_TYPE_Q4_K : GGML_TYPE_Q5_K;
|
| 290 |
+
}
|
| 291 |
+
else if (ftype == LLAMA_FTYPE_MOSTLY_Q4_K_M) {
|
| 292 |
+
if (arch == LLM_ARCH_FALCON) {
|
| 293 |
+
new_type = i_layer < n_layer/16 ? GGML_TYPE_Q6_K :
|
| 294 |
+
use_more_bits(i_layer, n_layer) ? GGML_TYPE_Q5_K : GGML_TYPE_Q4_K;
|
| 295 |
+
} else {
|
| 296 |
+
if (use_more_bits(i_layer, n_layer)) new_type = GGML_TYPE_Q6_K;
|
| 297 |
+
}
|
| 298 |
+
}
|
| 299 |
+
else if (i_layer < n_layer/8 && (ftype == LLAMA_FTYPE_MOSTLY_IQ4_NL || ftype == LLAMA_FTYPE_MOSTLY_IQ4_XS) && !qs.has_imatrix) {
|
| 300 |
+
new_type = GGML_TYPE_Q5_K;
|
| 301 |
+
}
|
| 302 |
+
else if (ftype == LLAMA_FTYPE_MOSTLY_Q5_K_M && use_more_bits(i_layer, n_layer)) new_type = GGML_TYPE_Q6_K;
|
| 303 |
+
else if (ftype == LLAMA_FTYPE_MOSTLY_Q4_K_S && arch != LLM_ARCH_FALCON && i_layer < n_layer/8) {
|
| 304 |
+
new_type = GGML_TYPE_Q5_K;
|
| 305 |
+
}
|
| 306 |
+
else if ((ftype == LLAMA_FTYPE_MOSTLY_Q4_0 || ftype == LLAMA_FTYPE_MOSTLY_Q5_0)
|
| 307 |
+
&& qs.has_imatrix && i_layer < n_layer/8) {
|
| 308 |
+
// Guard against craziness in the first few ffn_down layers that can happen even with imatrix for Q4_0/Q5_0.
|
| 309 |
+
// We only do it when an imatrix is provided because a) we want to make sure that one can always get the
|
| 310 |
+
// same quantization as before imatrix stuff, and b) Q4_1/Q5_1 do go crazy on ffn_down without an imatrix.
|
| 311 |
+
new_type = ftype == LLAMA_FTYPE_MOSTLY_Q4_0 ? GGML_TYPE_Q4_1 : GGML_TYPE_Q5_1;
|
| 312 |
+
}
|
| 313 |
+
++qs.i_ffn_down;
|
| 314 |
+
} else if (name.find("attn_output.weight") != std::string::npos) {
|
| 315 |
+
if (arch != LLM_ARCH_FALCON) {
|
| 316 |
+
if (qs.model.hparams.n_expert == 8) {
|
| 317 |
+
if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K || ftype == LLAMA_FTYPE_MOSTLY_IQ3_XS || ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS ||
|
| 318 |
+
ftype == LLAMA_FTYPE_MOSTLY_Q3_K_S || ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M || ftype == LLAMA_FTYPE_MOSTLY_IQ4_NL ||
|
| 319 |
+
ftype == LLAMA_FTYPE_MOSTLY_Q4_K_S || ftype == LLAMA_FTYPE_MOSTLY_Q4_K_M || ftype == LLAMA_FTYPE_MOSTLY_IQ3_S ||
|
| 320 |
+
ftype == LLAMA_FTYPE_MOSTLY_IQ3_M || ftype == LLAMA_FTYPE_MOSTLY_IQ4_XS) {
|
| 321 |
+
new_type = GGML_TYPE_Q5_K;
|
| 322 |
+
}
|
| 323 |
+
} else {
|
| 324 |
+
if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K ) new_type = GGML_TYPE_Q3_K;
|
| 325 |
+
else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS) new_type = GGML_TYPE_IQ3_S;
|
| 326 |
+
else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M ) new_type = GGML_TYPE_Q4_K;
|
| 327 |
+
else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_L ) new_type = GGML_TYPE_Q5_K;
|
| 328 |
+
else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_M ) new_type = GGML_TYPE_Q4_K;
|
| 329 |
+
}
|
| 330 |
+
} else {
|
| 331 |
+
if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_L) new_type = GGML_TYPE_Q4_K;
|
| 332 |
+
}
|
| 333 |
+
}
|
| 334 |
+
else if (name.find("attn_qkv.weight") != std::string::npos) {
|
| 335 |
+
if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M || ftype == LLAMA_FTYPE_MOSTLY_Q3_K_L || ftype == LLAMA_FTYPE_MOSTLY_IQ3_M) {
|
| 336 |
+
new_type = GGML_TYPE_Q4_K;
|
| 337 |
+
}
|
| 338 |
+
else if (ftype == LLAMA_FTYPE_MOSTLY_Q4_K_M) new_type = GGML_TYPE_Q5_K;
|
| 339 |
+
else if (ftype == LLAMA_FTYPE_MOSTLY_Q5_K_M) new_type = GGML_TYPE_Q6_K;
|
| 340 |
+
}
|
| 341 |
+
else if (name.find("ffn_gate") != std::string::npos) {
|
| 342 |
+
auto info = layer_info(qs.i_ffn_gate, qs.n_ffn_gate, name.c_str());
|
| 343 |
+
int i_layer = info.first, n_layer = info.second;
|
| 344 |
+
if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XS && (i_layer >= n_layer/8 && i_layer < 7*n_layer/8)) {
|
| 345 |
+
new_type = GGML_TYPE_IQ3_XXS;
|
| 346 |
+
}
|
| 347 |
+
++qs.i_ffn_gate;
|
| 348 |
+
}
|
| 349 |
+
else if (name.find("ffn_up") != std::string::npos) {
|
| 350 |
+
auto info = layer_info(qs.i_ffn_up, qs.n_ffn_up, name.c_str());
|
| 351 |
+
int i_layer = info.first, n_layer = info.second;
|
| 352 |
+
if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XS && (i_layer >= n_layer/8 && i_layer < 7*n_layer/8)) {
|
| 353 |
+
new_type = GGML_TYPE_IQ3_XXS;
|
| 354 |
+
}
|
| 355 |
+
++qs.i_ffn_up;
|
| 356 |
+
}
|
| 357 |
+
|
| 358 |
+
// if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K) new_type = GGML_TYPE_Q3_K;
|
| 359 |
+
//}
|
| 360 |
+
// IK: let's remove this, else Q2_K is almost the same as Q3_K_S
|
| 361 |
+
//else if (name.find("ffn_gate") != std::string::npos || name.find("ffn_up") != std::string::npos) {
|
| 362 |
+
// if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K) new_type = GGML_TYPE_Q3_K;
|
| 363 |
+
//}
|
| 364 |
+
// This can be used to reduce the size of the Q5_K_S model.
|
| 365 |
+
// The associated PPL increase is fully in line with the size reduction
|
| 366 |
+
//else {
|
| 367 |
+
// if (ftype == LLAMA_FTYPE_MOSTLY_Q5_K_S) new_type = GGML_TYPE_Q4_K;
|
| 368 |
+
//}
|
| 369 |
+
bool convert_incompatible_tensor = false;
|
| 370 |
+
if (new_type == GGML_TYPE_Q2_K || new_type == GGML_TYPE_Q3_K || new_type == GGML_TYPE_Q4_K ||
|
| 371 |
+
new_type == GGML_TYPE_Q5_K || new_type == GGML_TYPE_Q6_K || new_type == GGML_TYPE_IQ4_XS ||
|
| 372 |
+
new_type == GGML_TYPE_IQ2_XS || new_type == GGML_TYPE_IQ2_XXS || new_type == GGML_TYPE_IQ2_S ||
|
| 373 |
+
new_type == GGML_TYPE_IQ3_XXS || new_type == GGML_TYPE_IQ1_S || new_type == GGML_TYPE_IQ3_S ||
|
| 374 |
+
new_type == GGML_TYPE_IQ1_M) {
|
| 375 |
+
int nx = tensor->ne[0];
|
| 376 |
+
int ny = tensor->ne[1];
|
| 377 |
+
if (nx % QK_K != 0) {
|
| 378 |
+
LLAMA_LOG_WARN("\n\n%s : tensor cols %d x %d are not divisible by %d, required for %s", __func__, nx, ny, QK_K, ggml_type_name(new_type));
|
| 379 |
+
convert_incompatible_tensor = true;
|
| 380 |
+
} else {
|
| 381 |
+
++qs.n_k_quantized;
|
| 382 |
+
}
|
| 383 |
+
}
|
| 384 |
+
if (convert_incompatible_tensor) {
|
| 385 |
+
switch (new_type) {
|
| 386 |
+
case GGML_TYPE_TQ1_0:
|
| 387 |
+
case GGML_TYPE_TQ2_0: new_type = GGML_TYPE_Q4_0; break; // TODO: use a symmetric type instead
|
| 388 |
+
case GGML_TYPE_IQ2_XXS:
|
| 389 |
+
case GGML_TYPE_IQ2_XS:
|
| 390 |
+
case GGML_TYPE_IQ2_S:
|
| 391 |
+
case GGML_TYPE_IQ3_XXS:
|
| 392 |
+
case GGML_TYPE_IQ3_S:
|
| 393 |
+
case GGML_TYPE_IQ1_S:
|
| 394 |
+
case GGML_TYPE_IQ1_M:
|
| 395 |
+
case GGML_TYPE_Q2_K:
|
| 396 |
+
case GGML_TYPE_Q3_K:
|
| 397 |
+
case GGML_TYPE_IQ4_XS: new_type = GGML_TYPE_IQ4_NL; break;
|
| 398 |
+
case GGML_TYPE_Q4_K: new_type = GGML_TYPE_Q5_0; break;
|
| 399 |
+
case GGML_TYPE_Q5_K: new_type = GGML_TYPE_Q5_1; break;
|
| 400 |
+
case GGML_TYPE_Q6_K: new_type = GGML_TYPE_Q8_0; break;
|
| 401 |
+
default: throw std::runtime_error("\nUnsupported tensor size encountered\n");
|
| 402 |
+
}
|
| 403 |
+
if (tensor->ne[0] % ggml_blck_size(new_type) != 0) {
|
| 404 |
+
new_type = GGML_TYPE_F16;
|
| 405 |
+
}
|
| 406 |
+
LLAMA_LOG_WARN(" - using fallback quantization %s\n", ggml_type_name(new_type));
|
| 407 |
+
++qs.n_fallback;
|
| 408 |
+
}
|
| 409 |
+
|
| 410 |
+
return new_type;
|
| 411 |
+
}
|
| 412 |
+
|
| 413 |
+
static size_t llama_tensor_quantize_impl(enum ggml_type new_type, const float * f32_data, void * new_data, const int64_t chunk_size, int64_t nrows, int64_t n_per_row, const float * imatrix, std::vector<std::thread> & workers, const int nthread) {
|
| 414 |
+
if (nthread < 2) {
|
| 415 |
+
// single-thread
|
| 416 |
+
size_t new_size = ggml_quantize_chunk(new_type, f32_data, new_data, 0, nrows, n_per_row, imatrix);
|
| 417 |
+
if (!ggml_validate_row_data(new_type, new_data, new_size)) {
|
| 418 |
+
throw std::runtime_error("quantized data validation failed");
|
| 419 |
+
}
|
| 420 |
+
return new_size;
|
| 421 |
+
}
|
| 422 |
+
|
| 423 |
+
std::mutex mutex;
|
| 424 |
+
int64_t counter = 0;
|
| 425 |
+
size_t new_size = 0;
|
| 426 |
+
bool valid = true;
|
| 427 |
+
auto compute = [&mutex, &counter, &new_size, &valid, new_type, f32_data, new_data, chunk_size,
|
| 428 |
+
nrows, n_per_row, imatrix]() {
|
| 429 |
+
const int64_t nrows_per_chunk = chunk_size / n_per_row;
|
| 430 |
+
size_t local_size = 0;
|
| 431 |
+
while (true) {
|
| 432 |
+
std::unique_lock<std::mutex> lock(mutex);
|
| 433 |
+
int64_t first_row = counter; counter += nrows_per_chunk;
|
| 434 |
+
if (first_row >= nrows) {
|
| 435 |
+
if (local_size > 0) {
|
| 436 |
+
new_size += local_size;
|
| 437 |
+
}
|
| 438 |
+
break;
|
| 439 |
+
}
|
| 440 |
+
lock.unlock();
|
| 441 |
+
const int64_t this_nrow = std::min(nrows - first_row, nrows_per_chunk);
|
| 442 |
+
size_t this_size = ggml_quantize_chunk(new_type, f32_data, new_data, first_row * n_per_row, this_nrow, n_per_row, imatrix);
|
| 443 |
+
local_size += this_size;
|
| 444 |
+
|
| 445 |
+
// validate the quantized data
|
| 446 |
+
const size_t row_size = ggml_row_size(new_type, n_per_row);
|
| 447 |
+
void * this_data = (char *) new_data + first_row * row_size;
|
| 448 |
+
if (!ggml_validate_row_data(new_type, this_data, this_size)) {
|
| 449 |
+
std::unique_lock<std::mutex> lock(mutex);
|
| 450 |
+
valid = false;
|
| 451 |
+
break;
|
| 452 |
+
}
|
| 453 |
+
}
|
| 454 |
+
};
|
| 455 |
+
for (int it = 0; it < nthread - 1; ++it) {
|
| 456 |
+
workers.emplace_back(compute);
|
| 457 |
+
}
|
| 458 |
+
compute();
|
| 459 |
+
for (auto & w : workers) { w.join(); }
|
| 460 |
+
workers.clear();
|
| 461 |
+
if (!valid) {
|
| 462 |
+
throw std::runtime_error("quantized data validation failed");
|
| 463 |
+
}
|
| 464 |
+
return new_size;
|
| 465 |
+
}
|
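The threaded path in llama_tensor_quantize_impl above hands out work through a mutex-protected chunk counter: each worker repeatedly claims the next run of rows, quantizes it, validates the produced rows, and folds its local byte count into the shared total. A minimal sketch of the same work-distribution pattern, with a placeholder process_chunk() standing in for ggml_quantize_chunk() (the helper name is illustrative, not part of the library):

#include <algorithm>
#include <cstdint>
#include <mutex>
#include <thread>
#include <vector>

static void parallel_chunks(int64_t nrows, int64_t nrows_per_chunk, int nthread) {
    std::mutex mutex;
    int64_t counter = 0;

    auto worker = [&]() {
        while (true) {
            int64_t first_row;
            {
                // claim the next chunk under the lock
                std::lock_guard<std::mutex> lock(mutex);
                first_row = counter;
                counter  += nrows_per_chunk;
            }
            if (first_row >= nrows) {
                break; // no chunks left
            }
            const int64_t n = std::min(nrows - first_row, nrows_per_chunk);
            // process_chunk(first_row, n); // placeholder for the per-chunk work
            (void) n;
        }
    };

    std::vector<std::thread> workers;
    for (int i = 0; i < nthread - 1; ++i) {
        workers.emplace_back(worker);
    }
    worker(); // the calling thread participates too, matching the code above
    for (auto & w : workers) {
        w.join();
    }
}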
| 466 |
+
|
| 467 |
+
static void llama_model_quantize_impl(const std::string & fname_inp, const std::string & fname_out, const llama_model_quantize_params * params) {
|
| 468 |
+
ggml_type default_type;
|
| 469 |
+
llama_ftype ftype = params->ftype;
|
| 470 |
+
|
| 471 |
+
switch (params->ftype) {
|
| 472 |
+
case LLAMA_FTYPE_MOSTLY_Q4_0: default_type = GGML_TYPE_Q4_0; break;
|
| 473 |
+
case LLAMA_FTYPE_MOSTLY_Q4_1: default_type = GGML_TYPE_Q4_1; break;
|
| 474 |
+
case LLAMA_FTYPE_MOSTLY_Q5_0: default_type = GGML_TYPE_Q5_0; break;
|
| 475 |
+
case LLAMA_FTYPE_MOSTLY_Q5_1: default_type = GGML_TYPE_Q5_1; break;
|
| 476 |
+
case LLAMA_FTYPE_MOSTLY_Q8_0: default_type = GGML_TYPE_Q8_0; break;
|
| 477 |
+
case LLAMA_FTYPE_MOSTLY_F16: default_type = GGML_TYPE_F16; break;
|
| 478 |
+
case LLAMA_FTYPE_MOSTLY_BF16: default_type = GGML_TYPE_BF16; break;
|
| 479 |
+
case LLAMA_FTYPE_ALL_F32: default_type = GGML_TYPE_F32; break;
|
| 480 |
+
|
| 481 |
+
// K-quants
|
| 482 |
+
case LLAMA_FTYPE_MOSTLY_Q2_K_S:
|
| 483 |
+
case LLAMA_FTYPE_MOSTLY_Q2_K: default_type = GGML_TYPE_Q2_K; break;
|
| 484 |
+
case LLAMA_FTYPE_MOSTLY_IQ3_XS: default_type = GGML_TYPE_IQ3_S; break;
|
| 485 |
+
case LLAMA_FTYPE_MOSTLY_Q3_K_S:
|
| 486 |
+
case LLAMA_FTYPE_MOSTLY_Q3_K_M:
|
| 487 |
+
case LLAMA_FTYPE_MOSTLY_Q3_K_L: default_type = GGML_TYPE_Q3_K; break;
|
| 488 |
+
case LLAMA_FTYPE_MOSTLY_Q4_K_S:
|
| 489 |
+
case LLAMA_FTYPE_MOSTLY_Q4_K_M: default_type = GGML_TYPE_Q4_K; break;
|
| 490 |
+
case LLAMA_FTYPE_MOSTLY_Q5_K_S:
|
| 491 |
+
case LLAMA_FTYPE_MOSTLY_Q5_K_M: default_type = GGML_TYPE_Q5_K; break;
|
| 492 |
+
case LLAMA_FTYPE_MOSTLY_Q6_K: default_type = GGML_TYPE_Q6_K; break;
|
| 493 |
+
case LLAMA_FTYPE_MOSTLY_TQ1_0: default_type = GGML_TYPE_TQ1_0; break;
|
| 494 |
+
case LLAMA_FTYPE_MOSTLY_TQ2_0: default_type = GGML_TYPE_TQ2_0; break;
|
| 495 |
+
case LLAMA_FTYPE_MOSTLY_IQ2_XXS: default_type = GGML_TYPE_IQ2_XXS; break;
|
| 496 |
+
case LLAMA_FTYPE_MOSTLY_IQ2_XS: default_type = GGML_TYPE_IQ2_XS; break;
|
| 497 |
+
case LLAMA_FTYPE_MOSTLY_IQ2_S: default_type = GGML_TYPE_IQ2_XS; break;
|
| 498 |
+
case LLAMA_FTYPE_MOSTLY_IQ2_M: default_type = GGML_TYPE_IQ2_S; break;
|
| 499 |
+
case LLAMA_FTYPE_MOSTLY_IQ3_XXS: default_type = GGML_TYPE_IQ3_XXS; break;
|
| 500 |
+
case LLAMA_FTYPE_MOSTLY_IQ1_S: default_type = GGML_TYPE_IQ1_S; break;
|
| 501 |
+
case LLAMA_FTYPE_MOSTLY_IQ1_M: default_type = GGML_TYPE_IQ1_M; break;
|
| 502 |
+
case LLAMA_FTYPE_MOSTLY_IQ4_NL: default_type = GGML_TYPE_IQ4_NL; break;
|
| 503 |
+
case LLAMA_FTYPE_MOSTLY_IQ4_XS: default_type = GGML_TYPE_IQ4_XS; break;
|
| 504 |
+
case LLAMA_FTYPE_MOSTLY_IQ3_S: default_type = GGML_TYPE_IQ3_S; break;
|
| 505 |
+
case LLAMA_FTYPE_MOSTLY_IQ3_M: default_type = GGML_TYPE_IQ3_S; break;
|
| 506 |
+
|
| 507 |
+
default: throw std::runtime_error(format("invalid output file type %d\n", ftype));
|
| 508 |
+
}
|
| 509 |
+
|
| 510 |
+
int nthread = params->nthread;
|
| 511 |
+
|
| 512 |
+
if (nthread <= 0) {
|
| 513 |
+
nthread = std::thread::hardware_concurrency();
|
| 514 |
+
}
|
| 515 |
+
|
| 516 |
+
// mmap consistently increases speed Linux, and also increases speed on Windows with
|
| 517 |
+
// hot cache. It may cause a slowdown on macOS, possibly related to free memory.
|
| 518 |
+
#if defined(__linux__) || defined(_WIN32)
|
| 519 |
+
constexpr bool use_mmap = true;
|
| 520 |
+
#else
|
| 521 |
+
constexpr bool use_mmap = false;
|
| 522 |
+
#endif
|
| 523 |
+
|
| 524 |
+
llama_model_kv_override * kv_overrides = nullptr;
|
| 525 |
+
if (params->kv_overrides) {
|
| 526 |
+
auto v = (std::vector<llama_model_kv_override>*)params->kv_overrides;
|
| 527 |
+
kv_overrides = v->data();
|
| 528 |
+
}
|
| 529 |
+
llama_model_loader ml(fname_inp, use_mmap, /*check_tensors*/ true, kv_overrides);
|
| 530 |
+
ml.init_mappings(false); // no prefetching
|
| 531 |
+
|
| 532 |
+
llama_model model;
|
| 533 |
+
llm_load_arch (ml, model);
|
| 534 |
+
llm_load_hparams(ml, model);
|
| 535 |
+
llm_load_stats (ml, model);
|
| 536 |
+
|
| 537 |
+
struct quantize_state_impl qs(model, params);
|
| 538 |
+
|
| 539 |
+
if (params->only_copy) {
|
| 540 |
+
ftype = model.ftype;
|
| 541 |
+
}
|
| 542 |
+
const std::unordered_map<std::string, std::vector<float>> * imatrix_data = nullptr;
|
| 543 |
+
if (params->imatrix) {
|
| 544 |
+
imatrix_data = static_cast<const std::unordered_map<std::string, std::vector<float>>*>(params->imatrix);
|
| 545 |
+
if (imatrix_data) {
|
| 546 |
+
LLAMA_LOG_INFO("================================ Have weights data with %d entries\n",int(imatrix_data->size()));
|
| 547 |
+
qs.has_imatrix = true;
|
| 548 |
+
// check imatrix for nans or infs
|
| 549 |
+
for (const auto & kv : *imatrix_data) {
|
| 550 |
+
for (float f : kv.second) {
|
| 551 |
+
if (!std::isfinite(f)) {
|
| 552 |
+
throw std::runtime_error(format("imatrix contains non-finite value %f\n", f));
|
| 553 |
+
}
|
| 554 |
+
}
|
| 555 |
+
}
|
| 556 |
+
}
|
| 557 |
+
}
|
| 558 |
+
|
| 559 |
+
const size_t align = GGUF_DEFAULT_ALIGNMENT;
|
| 560 |
+
gguf_context_ptr ctx_out { gguf_init_empty() };
|
| 561 |
+
|
| 562 |
+
// copy the KV pairs from the input file
|
| 563 |
+
gguf_set_kv (ctx_out.get(), ml.meta.get());
|
| 564 |
+
gguf_set_val_u32(ctx_out.get(), "general.quantization_version", GGML_QNT_VERSION); // TODO: use LLM_KV
|
| 565 |
+
gguf_set_val_u32(ctx_out.get(), "general.file_type", ftype); // TODO: use LLM_KV
|
| 566 |
+
|
| 567 |
+
// Remove split metadata
|
| 568 |
+
gguf_remove_key(ctx_out.get(), ml.llm_kv(LLM_KV_SPLIT_NO).c_str());
|
| 569 |
+
gguf_remove_key(ctx_out.get(), ml.llm_kv(LLM_KV_SPLIT_COUNT).c_str());
|
| 570 |
+
gguf_remove_key(ctx_out.get(), ml.llm_kv(LLM_KV_SPLIT_TENSORS_COUNT).c_str());
|
| 571 |
+
|
| 572 |
+
if (params->kv_overrides) {
|
| 573 |
+
const std::vector<llama_model_kv_override> & overrides = *(const std::vector<llama_model_kv_override> *)params->kv_overrides;
|
| 574 |
+
for (const auto & o : overrides) {
|
| 575 |
+
if (o.key[0] == 0) break;
|
| 576 |
+
if (o.tag == LLAMA_KV_OVERRIDE_TYPE_FLOAT) {
|
| 577 |
+
gguf_set_val_f32(ctx_out.get(), o.key, o.val_f64);
|
| 578 |
+
} else if (o.tag == LLAMA_KV_OVERRIDE_TYPE_INT) {
|
| 579 |
+
gguf_set_val_i32(ctx_out.get(), o.key, o.val_i64);
|
| 580 |
+
} else if (o.tag == LLAMA_KV_OVERRIDE_TYPE_BOOL) {
|
| 581 |
+
gguf_set_val_bool(ctx_out.get(), o.key, o.val_bool);
|
| 582 |
+
} else if (o.tag == LLAMA_KV_OVERRIDE_TYPE_STR) {
|
| 583 |
+
gguf_set_val_str(ctx_out.get(), o.key, o.val_str);
|
| 584 |
+
} else {
|
| 585 |
+
LLAMA_LOG_WARN("%s: unknown KV override type for key %s\n", __func__, o.key);
|
| 586 |
+
}
|
| 587 |
+
}
|
| 588 |
+
}
|
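For reference, the loop above consumes params->kv_overrides as a std::vector<llama_model_kv_override> terminated by an entry with an empty key. A hedged sketch of building such a list, assuming the usual llama_model_kv_override declaration in llama.h (a tag enum, a fixed-size key buffer, and a value union); the key name itself is made up for illustration:

#include <cstring>
#include <vector>
#include "llama.h"

static std::vector<llama_model_kv_override> make_overrides() {
    std::vector<llama_model_kv_override> overrides;

    llama_model_kv_override o{};
    o.tag = LLAMA_KV_OVERRIDE_TYPE_INT;
    std::strncpy(o.key, "general.example_flag", sizeof(o.key) - 1); // hypothetical key
    o.val_i64 = 1;
    overrides.push_back(o);

    overrides.push_back(llama_model_kv_override{}); // empty key terminates the list

    return overrides;
}

The resulting vector is what the quantize params' kv_overrides pointer is expected to reference; the quantizer casts it back to std::vector<llama_model_kv_override> as shown earlier.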
| 589 |
+
|
| 590 |
+
// make a list of weights
|
| 591 |
+
std::vector<const llama_model_loader::llama_tensor_weight *> tensors;
|
| 592 |
+
tensors.reserve(ml.weights_map.size());
|
| 593 |
+
for (const auto & it : ml.weights_map) {
|
| 594 |
+
tensors.push_back(&it.second);
|
| 595 |
+
}
|
| 596 |
+
|
| 597 |
+
// keep_split requires that the weights are sorted by split index
|
| 598 |
+
if (params->keep_split) {
|
| 599 |
+
std::sort(tensors.begin(), tensors.end(), [](const llama_model_loader::llama_tensor_weight * a, const llama_model_loader::llama_tensor_weight * b) {
|
| 600 |
+
if (a->idx == b->idx) {
|
| 601 |
+
return a->offs < b->offs;
|
| 602 |
+
}
|
| 603 |
+
return a->idx < b->idx;
|
| 604 |
+
});
|
| 605 |
+
}
|
| 606 |
+
|
| 607 |
+
for (const auto * it : tensors) {
|
| 608 |
+
const struct ggml_tensor * tensor = it->tensor;
|
| 609 |
+
|
| 610 |
+
const std::string name = ggml_get_name(tensor);
|
| 611 |
+
|
| 612 |
+
// TODO: avoid hardcoded tensor names - use the TN_* constants
|
| 613 |
+
if (name.find("attn_v.weight") != std::string::npos ||
|
| 614 |
+
name.find("attn_qkv.weight") != std::string::npos ||
|
| 615 |
+
name.find("attn_kv_b.weight")!= std::string::npos) {
|
| 616 |
+
++qs.n_attention_wv;
|
| 617 |
+
} else if (name == LLM_TN(model.arch)(LLM_TENSOR_OUTPUT, "weight")) {
|
| 618 |
+
qs.has_output = true;
|
| 619 |
+
}
|
| 620 |
+
}
|
| 621 |
+
|
| 622 |
+
qs.n_ffn_down = qs.n_ffn_gate = qs.n_ffn_up = (int)model.hparams.n_layer;
|
| 623 |
+
|
| 624 |
+
// sanity checks
|
| 625 |
+
{
|
| 626 |
+
const auto & n_head_kv_iter = model.hparams.n_head_kv_arr.begin();
|
| 627 |
+
// attention layers have a non-zero number of kv heads
|
| 628 |
+
int32_t n_attn_layer = model.hparams.n_layer - std::count(n_head_kv_iter, n_head_kv_iter + model.hparams.n_layer, 0);
|
| 629 |
+
if (llama_model_has_encoder(&model)) {
|
| 630 |
+
n_attn_layer *= 3;
|
| 631 |
+
}
|
| 632 |
+
GGML_ASSERT((qs.n_attention_wv == n_attn_layer) && "n_attention_wv is unexpected");
|
| 633 |
+
}
|
| 634 |
+
|
| 635 |
+
size_t total_size_org = 0;
|
| 636 |
+
size_t total_size_new = 0;
|
| 637 |
+
|
| 638 |
+
std::vector<std::thread> workers;
|
| 639 |
+
workers.reserve(nthread);
|
| 640 |
+
|
| 641 |
+
int idx = 0;
|
| 642 |
+
|
| 643 |
+
std::vector<no_init<uint8_t>> read_data;
|
| 644 |
+
std::vector<no_init<uint8_t>> work;
|
| 645 |
+
std::vector<no_init<float>> f32_conv_buf;
|
| 646 |
+
|
| 647 |
+
uint16_t n_split = 1;
|
| 648 |
+
|
| 649 |
+
// Assume split index is continuous
|
| 650 |
+
if (params->keep_split) {
|
| 651 |
+
for (const auto * it : tensors) {
|
| 652 |
+
n_split = std::max(uint16_t(it->idx + 1), n_split);
|
| 653 |
+
}
|
| 654 |
+
}
|
| 655 |
+
std::vector<gguf_context_ptr> ctx_outs(n_split);
|
| 656 |
+
ctx_outs[0] = std::move(ctx_out);
|
| 657 |
+
|
| 658 |
+
// populate the original tensors so we get an initial meta data
|
| 659 |
+
for (const auto * it : tensors) {
|
| 660 |
+
uint16_t i_split = params->keep_split ? it->idx : 0;
|
| 661 |
+
struct ggml_tensor * tensor = it->tensor;
|
| 662 |
+
if (!ctx_outs[i_split]) {
|
| 663 |
+
ctx_outs[i_split].reset(gguf_init_empty());
|
| 664 |
+
}
|
| 665 |
+
gguf_add_tensor(ctx_outs[i_split].get(), tensor);
|
| 666 |
+
}
|
| 667 |
+
|
| 668 |
+
// Set split info if needed
|
| 669 |
+
if (n_split > 1) {
|
| 670 |
+
for (size_t i = 0; i < ctx_outs.size(); ++i) {
|
| 671 |
+
gguf_set_val_u16(ctx_outs[i].get(), ml.llm_kv(LLM_KV_SPLIT_NO).c_str(), i);
|
| 672 |
+
gguf_set_val_u16(ctx_outs[i].get(), ml.llm_kv(LLM_KV_SPLIT_COUNT).c_str(), n_split);
|
| 673 |
+
gguf_set_val_i32(ctx_outs[i].get(), ml.llm_kv(LLM_KV_SPLIT_TENSORS_COUNT).c_str(), ml.n_tensors);
|
| 674 |
+
}
|
| 675 |
+
}
|
| 676 |
+
|
| 677 |
+
int cur_split = -1;
|
| 678 |
+
std::ofstream fout;
|
| 679 |
+
auto close_ofstream = [&]() {
|
| 680 |
+
// Write metadata and close file handler
|
| 681 |
+
if (fout.is_open()) {
|
| 682 |
+
fout.seekp(0);
|
| 683 |
+
std::vector<uint8_t> data(gguf_get_meta_size(ctx_outs[cur_split].get()));
|
| 684 |
+
gguf_get_meta_data(ctx_outs[cur_split].get(), data.data());
|
| 685 |
+
fout.write((const char *) data.data(), data.size());
|
| 686 |
+
fout.close();
|
| 687 |
+
}
|
| 688 |
+
};
|
| 689 |
+
auto new_ofstream = [&](int index) {
|
| 690 |
+
cur_split = index;
|
| 691 |
+
GGML_ASSERT(ctx_outs[cur_split] && "Find uninitialized gguf_context");
|
| 692 |
+
std::string fname = fname_out;
|
| 693 |
+
if (params->keep_split) {
|
| 694 |
+
std::vector<char> split_path(llama_path_max(), 0);
|
| 695 |
+
llama_split_path(split_path.data(), split_path.size(), fname_out.c_str(), cur_split, n_split);
|
| 696 |
+
fname = std::string(split_path.data());
|
| 697 |
+
}
|
| 698 |
+
|
| 699 |
+
fout = std::ofstream(fname, std::ios::binary);
|
| 700 |
+
fout.exceptions(std::ofstream::failbit); // fail fast on write errors
|
| 701 |
+
const size_t meta_size = gguf_get_meta_size(ctx_outs[cur_split].get());
|
| 702 |
+
// placeholder for the meta data
|
| 703 |
+
::zeros(fout, meta_size);
|
| 704 |
+
};
|
| 705 |
+
|
| 706 |
+
const auto tn = LLM_TN(model.arch);
|
| 707 |
+
new_ofstream(0);
|
| 708 |
+
for (const auto * it : tensors) {
|
| 709 |
+
const auto & weight = *it;
|
| 710 |
+
struct ggml_tensor * tensor = weight.tensor;
|
| 711 |
+
if (weight.idx != cur_split && params->keep_split) {
|
| 712 |
+
close_ofstream();
|
| 713 |
+
new_ofstream(weight.idx);
|
| 714 |
+
}
|
| 715 |
+
|
| 716 |
+
const std::string name = ggml_get_name(tensor);
|
| 717 |
+
|
| 718 |
+
if (!ml.use_mmap) {
|
| 719 |
+
if (read_data.size() < ggml_nbytes(tensor)) {
|
| 720 |
+
read_data.resize(ggml_nbytes(tensor));
|
| 721 |
+
}
|
| 722 |
+
tensor->data = read_data.data();
|
| 723 |
+
}
|
| 724 |
+
ml.load_data_for(tensor);
|
| 725 |
+
|
| 726 |
+
LLAMA_LOG_INFO("[%4d/%4d] %36s - [%s], type = %6s, ",
|
| 727 |
+
++idx, ml.n_tensors,
|
| 728 |
+
ggml_get_name(tensor),
|
| 729 |
+
llama_format_tensor_shape(tensor).c_str(),
|
| 730 |
+
ggml_type_name(tensor->type));
|
| 731 |
+
|
| 732 |
+
// This used to be a regex, but <regex> has an extreme cost to compile times.
|
| 733 |
+
bool quantize = name.rfind("weight") == name.size() - 6; // ends with 'weight'?
|
| 734 |
+
|
| 735 |
+
// quantize only 2D and 3D tensors (experts)
|
| 736 |
+
quantize &= (ggml_n_dims(tensor) >= 2);
|
| 737 |
+
|
| 738 |
+
// do not quantize norm tensors
|
| 739 |
+
quantize &= name.find("_norm.weight") == std::string::npos;
|
| 740 |
+
|
| 741 |
+
quantize &= params->quantize_output_tensor || name != "output.weight";
|
| 742 |
+
quantize &= !params->only_copy;
|
| 743 |
+
|
| 744 |
+
// do not quantize expert gating tensors
|
| 745 |
+
// NOTE: can't use LLM_TN here because the layer number is not known
|
| 746 |
+
quantize &= name.find("ffn_gate_inp.weight") == std::string::npos;
|
| 747 |
+
|
| 748 |
+
// do not quantize positional embeddings and token types (BERT)
|
| 749 |
+
quantize &= name != LLM_TN(model.arch)(LLM_TENSOR_POS_EMBD, "weight");
|
| 750 |
+
quantize &= name != LLM_TN(model.arch)(LLM_TENSOR_TOKEN_TYPES, "weight");
|
| 751 |
+
|
| 752 |
+
// do not quantize Mamba's small yet 2D weights
|
| 753 |
+
// NOTE: can't use LLM_TN here because the layer number is not known
|
| 754 |
+
quantize &= name.find("ssm_conv1d.weight") == std::string::npos;
|
| 755 |
+
|
| 756 |
+
// do not quantize RWKV's time_mix_first tensors
|
| 757 |
+
quantize &= name.find("time_mix_first.weight") == std::string::npos;
|
| 758 |
+
quantize &= name.find("time_mix_w1.weight") == std::string::npos;
|
| 759 |
+
quantize &= name.find("time_mix_w2.weight") == std::string::npos;
|
| 760 |
+
quantize &= name.find("time_mix_decay_w1.weight") == std::string::npos;
|
| 761 |
+
quantize &= name.find("time_mix_decay_w2.weight") == std::string::npos;
|
| 762 |
+
|
| 763 |
+
// do not quantize relative position bias (T5)
|
| 764 |
+
quantize &= name.find("attn_rel_b.weight") == std::string::npos;
|
| 765 |
+
|
| 766 |
+
enum ggml_type new_type;
|
| 767 |
+
void * new_data;
|
| 768 |
+
size_t new_size;
|
| 769 |
+
|
| 770 |
+
if (quantize) {
|
| 771 |
+
new_type = default_type;
|
| 772 |
+
|
| 773 |
+
// get more optimal quantization type based on the tensor shape, layer, etc.
|
| 774 |
+
if (!params->pure && ggml_is_quantized(default_type)) {
|
| 775 |
+
new_type = llama_tensor_get_type(qs, new_type, tensor, ftype);
|
| 776 |
+
}
|
| 777 |
+
if (params->token_embedding_type < GGML_TYPE_COUNT && strcmp(tensor->name, "token_embd.weight") == 0) {
|
| 778 |
+
new_type = params->token_embedding_type;
|
| 779 |
+
}
|
| 780 |
+
if (params->output_tensor_type < GGML_TYPE_COUNT && strcmp(tensor->name, "output.weight") == 0) {
|
| 781 |
+
new_type = params->output_tensor_type;
|
| 782 |
+
}
|
| 783 |
+
|
| 784 |
+
// If we've decided to quantize to the same type the tensor is already
|
| 785 |
+
// in then there's nothing to do.
|
| 786 |
+
quantize = tensor->type != new_type;
|
| 787 |
+
}
|
| 788 |
+
|
| 789 |
+
if (!quantize) {
|
| 790 |
+
new_type = tensor->type;
|
| 791 |
+
new_data = tensor->data;
|
| 792 |
+
new_size = ggml_nbytes(tensor);
|
| 793 |
+
LLAMA_LOG_INFO("size = %8.3f MB\n", ggml_nbytes(tensor)/1024.0/1024.0);
|
| 794 |
+
} else {
|
| 795 |
+
const int64_t nelements = ggml_nelements(tensor);
|
| 796 |
+
|
| 797 |
+
const float * imatrix = nullptr;
|
| 798 |
+
if (imatrix_data) {
|
| 799 |
+
auto it = imatrix_data->find(tensor->name);
|
| 800 |
+
if (it == imatrix_data->end()) {
|
| 801 |
+
LLAMA_LOG_INFO("\n====== %s: did not find weights for %s\n", __func__, tensor->name);
|
| 802 |
+
} else {
|
| 803 |
+
if (it->second.size() == (size_t)tensor->ne[0]*tensor->ne[2]) {
|
| 804 |
+
imatrix = it->second.data();
|
| 805 |
+
} else {
|
| 806 |
+
LLAMA_LOG_INFO("\n====== %s: imatrix size %d is different from tensor size %d for %s\n", __func__,
|
| 807 |
+
int(it->second.size()), int(tensor->ne[0]*tensor->ne[2]), tensor->name);
|
| 808 |
+
|
| 809 |
+
// this can happen when quantizing an old mixtral model with split tensors with a new incompatible imatrix
|
| 810 |
+
// this is a significant error and it may be good idea to abort the process if this happens,
|
| 811 |
+
// since many people will miss the error and not realize that most of the model is being quantized without an imatrix
|
| 812 |
+
// tok_embd should be ignored in this case, since it always causes this warning
|
| 813 |
+
if (name != tn(LLM_TENSOR_TOKEN_EMBD, "weight")) {
|
| 814 |
+
throw std::runtime_error(format("imatrix size %d is different from tensor size %d for %s",
|
| 815 |
+
int(it->second.size()), int(tensor->ne[0]*tensor->ne[2]), tensor->name));
|
| 816 |
+
}
|
| 817 |
+
}
|
| 818 |
+
}
|
| 819 |
+
}
|
| 820 |
+
if ((new_type == GGML_TYPE_IQ2_XXS ||
|
| 821 |
+
new_type == GGML_TYPE_IQ2_XS ||
|
| 822 |
+
new_type == GGML_TYPE_IQ2_S ||
|
| 823 |
+
new_type == GGML_TYPE_IQ1_S ||
|
| 824 |
+
(new_type == GGML_TYPE_IQ1_M && strcmp(tensor->name, "token_embd.weight") && strcmp(tensor->name, "output.weight")) ||
|
| 825 |
+
(new_type == GGML_TYPE_Q2_K && params->ftype == LLAMA_FTYPE_MOSTLY_Q2_K_S && strcmp(tensor->name, "token_embd.weight") != 0)) && !imatrix) {
|
| 826 |
+
LLAMA_LOG_ERROR("\n\n============================================================\n");
|
| 827 |
+
LLAMA_LOG_ERROR("Missing importance matrix for tensor %s in a very low-bit quantization\n", tensor->name);
|
| 828 |
+
LLAMA_LOG_ERROR("The result will be garbage, so bailing out\n");
|
| 829 |
+
LLAMA_LOG_ERROR("============================================================\n\n");
|
| 830 |
+
throw std::runtime_error(format("Missing importance matrix for tensor %s in a very low-bit quantization", tensor->name));
|
| 831 |
+
}
|
| 832 |
+
|
| 833 |
+
float * f32_data;
|
| 834 |
+
|
| 835 |
+
if (tensor->type == GGML_TYPE_F32) {
|
| 836 |
+
f32_data = (float *) tensor->data;
|
| 837 |
+
} else if (ggml_is_quantized(tensor->type) && !params->allow_requantize) {
|
| 838 |
+
throw std::runtime_error(format("requantizing from type %s is disabled", ggml_type_name(tensor->type)));
|
| 839 |
+
} else {
|
| 840 |
+
llama_tensor_dequantize_impl(tensor, f32_conv_buf, workers, nelements, nthread);
|
| 841 |
+
f32_data = (float *) f32_conv_buf.data();
|
| 842 |
+
}
|
| 843 |
+
|
| 844 |
+
LLAMA_LOG_INFO("converting to %s .. ", ggml_type_name(new_type));
|
| 845 |
+
fflush(stdout);
|
| 846 |
+
|
| 847 |
+
if (work.size() < (size_t)nelements * 4) {
|
| 848 |
+
work.resize(nelements * 4); // upper bound on size
|
| 849 |
+
}
|
| 850 |
+
new_data = work.data();
|
| 851 |
+
|
| 852 |
+
const int64_t n_per_row = tensor->ne[0];
|
| 853 |
+
const int64_t nrows = tensor->ne[1];
|
| 854 |
+
|
| 855 |
+
static const int64_t min_chunk_size = 32 * 512;
|
| 856 |
+
const int64_t chunk_size = (n_per_row >= min_chunk_size ? n_per_row : n_per_row * ((min_chunk_size + n_per_row - 1)/n_per_row));
|
| 857 |
+
|
| 858 |
+
const int64_t nelements_matrix = tensor->ne[0] * tensor->ne[1];
|
| 859 |
+
const int64_t nchunk = (nelements_matrix + chunk_size - 1)/chunk_size;
|
| 860 |
+
const int64_t nthread_use = nthread > 1 ? std::max((int64_t)1, std::min((int64_t)nthread, nchunk)) : 1;
|
| 861 |
+
|
| 862 |
+
// quantize each expert separately since they have different importance matrices
|
| 863 |
+
new_size = 0;
|
| 864 |
+
for (int64_t i03 = 0; i03 < tensor->ne[2]; ++i03) {
|
| 865 |
+
const float * f32_data_03 = f32_data + i03 * nelements_matrix;
|
| 866 |
+
void * new_data_03 = (char *)new_data + ggml_row_size(new_type, n_per_row) * i03 * nrows;
|
| 867 |
+
const float * imatrix_03 = imatrix ? imatrix + i03 * n_per_row : nullptr;
|
| 868 |
+
|
| 869 |
+
new_size += llama_tensor_quantize_impl(new_type, f32_data_03, new_data_03, chunk_size, nrows, n_per_row, imatrix_03, workers, nthread_use);
|
| 870 |
+
}
|
| 871 |
+
LLAMA_LOG_INFO("size = %8.2f MiB -> %8.2f MiB\n", ggml_nbytes(tensor)/1024.0/1024.0, new_size/1024.0/1024.0);
|
| 872 |
+
}
|
| 873 |
+
total_size_org += ggml_nbytes(tensor);
|
| 874 |
+
total_size_new += new_size;
|
| 875 |
+
|
| 876 |
+
// update the gguf meta data as we go
|
| 877 |
+
gguf_set_tensor_type(ctx_outs[cur_split].get(), name.c_str(), new_type);
|
| 878 |
+
gguf_set_tensor_data(ctx_outs[cur_split].get(), name.c_str(), new_data, new_size);
|
| 879 |
+
|
| 880 |
+
// write tensor data + padding
|
| 881 |
+
fout.write((const char *) new_data, new_size);
|
| 882 |
+
zeros(fout, GGML_PAD(new_size, align) - new_size);
|
| 883 |
+
}
|
| 884 |
+
close_ofstream();
|
| 885 |
+
|
| 886 |
+
LLAMA_LOG_INFO("%s: model size = %8.2f MB\n", __func__, total_size_org/1024.0/1024.0);
|
| 887 |
+
LLAMA_LOG_INFO("%s: quant size = %8.2f MB\n", __func__, total_size_new/1024.0/1024.0);
|
| 888 |
+
|
| 889 |
+
if (qs.n_fallback > 0) {
|
| 890 |
+
LLAMA_LOG_WARN("%s: WARNING: %d of %d tensor(s) required fallback quantization\n",
|
| 891 |
+
__func__, qs.n_fallback, qs.n_k_quantized + qs.n_fallback);
|
| 892 |
+
}
|
| 893 |
+
}
|
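One step inside llama_model_quantize_impl above sizes the per-thread chunks from min_chunk_size = 32 * 512 elements. A small self-contained sketch that reproduces the arithmetic for an illustrative 4096 x 4096 matrix:

#include <cstdint>
#include <cstdio>

int main() {
    const int64_t n_per_row      = 4096;      // illustrative row length
    const int64_t min_chunk_size = 32 * 512;  // 16384, as in the code above

    const int64_t chunk_size = n_per_row >= min_chunk_size
        ? n_per_row
        : n_per_row * ((min_chunk_size + n_per_row - 1) / n_per_row);

    const int64_t nelements = n_per_row * 4096;                     // 4096 x 4096 matrix
    const int64_t nchunk    = (nelements + chunk_size - 1) / chunk_size;

    // prints: chunk_size = 16384 (4 rows per chunk), nchunk = 1024
    std::printf("chunk_size = %lld (%lld rows per chunk), nchunk = %lld\n",
                (long long) chunk_size, (long long) (chunk_size / n_per_row), (long long) nchunk);
    return 0;
}

The thread count actually used for such a tensor is then capped at min(nthread, nchunk), as in the code above.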
| 894 |
+
|
| 895 |
+
//
|
| 896 |
+
// interface implementation
|
| 897 |
+
//
|
| 898 |
+
|
| 899 |
+
struct llama_model_quantize_params llama_model_quantize_default_params() {
|
| 900 |
+
struct llama_model_quantize_params result = {
|
| 901 |
+
/*.nthread =*/ 0,
|
| 902 |
+
/*.ftype =*/ LLAMA_FTYPE_MOSTLY_Q5_1,
|
| 903 |
+
/*.output_tensor_type =*/ GGML_TYPE_COUNT,
|
| 904 |
+
/*.token_embedding_type =*/ GGML_TYPE_COUNT,
|
| 905 |
+
/*.allow_requantize =*/ false,
|
| 906 |
+
/*.quantize_output_tensor =*/ true,
|
| 907 |
+
/*.only_copy =*/ false,
|
| 908 |
+
/*.pure =*/ false,
|
| 909 |
+
/*.keep_split =*/ false,
|
| 910 |
+
/*.imatrix =*/ nullptr,
|
| 911 |
+
/*.kv_overrides =*/ nullptr,
|
| 912 |
+
};
|
| 913 |
+
|
| 914 |
+
return result;
|
| 915 |
+
}
|
| 916 |
+
|
| 917 |
+
uint32_t llama_model_quantize(
|
| 918 |
+
const char * fname_inp,
|
| 919 |
+
const char * fname_out,
|
| 920 |
+
const llama_model_quantize_params * params) {
|
| 921 |
+
try {
|
| 922 |
+
llama_model_quantize_impl(fname_inp, fname_out, params);
|
| 923 |
+
} catch (const std::exception & err) {
|
| 924 |
+
LLAMA_LOG_ERROR("%s: failed to quantize: %s\n", __func__, err.what());
|
| 925 |
+
return 1;
|
| 926 |
+
}
|
| 927 |
+
|
| 928 |
+
return 0;
|
| 929 |
+
}
|
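Putting the pieces together, a caller quantizes a model by filling in the default params and invoking llama_model_quantize declared above. A hedged usage sketch; the file names are placeholders:

#include <cstdio>
#include "llama.h"

int main() {
    llama_model_quantize_params qparams = llama_model_quantize_default_params();
    qparams.ftype   = LLAMA_FTYPE_MOSTLY_Q4_K_M;
    qparams.nthread = 8; // 0 lets the implementation pick std::thread::hardware_concurrency()

    // returns 0 on success, 1 on failure (see the try/catch wrapper above)
    if (llama_model_quantize("model-f16.gguf", "model-q4_k_m.gguf", &qparams) != 0) {
        std::fprintf(stderr, "quantization failed\n");
        return 1;
    }
    return 0;
}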
examples/talk-llama/llama-quant.h
ADDED
|
@@ -0,0 +1 @@
|
| 1 |
+
#pragma once
|
examples/talk-llama/llama-sampling.cpp
CHANGED
|
@@ -1,5 +1,6 @@
|
|
| 1 |
#include "llama-sampling.h"
|
| 2 |
|
|
|
|
| 3 |
#include "llama-vocab.h"
|
| 4 |
#include "llama-grammar.h"
|
| 5 |
|
|
@@ -14,6 +15,118 @@
|
|
| 14 |
#include <numeric>
|
| 15 |
#include <random>
|
| 16 |
#include <unordered_map>
|
| 17 |
|
| 18 |
static int llama_sample_dist(llama_token_data_array * cur_p, std::mt19937 & rng) {
|
| 19 |
// iterator for the probabilities
|
|
@@ -144,7 +257,7 @@ static void llama_sampler_top_k_impl(llama_token_data_array * cur_p, int32_t k)
|
|
| 144 |
for (int i = 0; i < (int)cur_p->size; ++i) {
|
| 145 |
const float val = cur_p->data[i].logit;
|
| 146 |
int ib = int(bucket_scale * val + bucket_inter); //nbuckets * (val - bucket_low) / (bucket_high - bucket_low);
|
| 147 |
-
ib = std::max(0, std::min(nbuckets-1, ib));
|
| 148 |
bucket_idx[i] = ib;
|
| 149 |
++histo[ib];
|
| 150 |
}
|
|
@@ -167,13 +280,13 @@ static void llama_sampler_top_k_impl(llama_token_data_array * cur_p, int32_t k)
|
|
| 167 |
for (int i = 0; i < (int)cur_p->size; ++i) {
|
| 168 |
int j = bucket_idx[i];
|
| 169 |
if (j >= ib) {
|
| 170 |
-
*bucket_ptrs[nbuckets-1-j]++ = cur_p->data[i];
|
| 171 |
}
|
| 172 |
}
|
| 173 |
|
| 174 |
ptr = tmp_tokens.data();
|
| 175 |
int ndone = 0;
|
| 176 |
-
for (int j = nbuckets-1; j > ib; --j) {
|
| 177 |
std::sort(ptr, ptr + histo[j], comp);
|
| 178 |
ptr += histo[j];
|
| 179 |
ndone += histo[j];
|
|
@@ -1719,7 +1832,7 @@ static void llama_sampler_dry_apply(struct llama_sampler * smpl, llama_token_dat
|
|
| 1719 |
ctx->dry_repeat_count[last - k] = std::min(n, rep_limit);
|
| 1720 |
if (n > 0) {
|
| 1721 |
lt = k;
|
| 1722 |
-
rt = k+n-1;
|
| 1723 |
}
|
| 1724 |
} else {
|
| 1725 |
// If k is inside the current Z-box, consider two cases.
|
|
|
|
| 1 |
#include "llama-sampling.h"
|
| 2 |
|
| 3 |
+
#include "llama-impl.h"
|
| 4 |
#include "llama-vocab.h"
|
| 5 |
#include "llama-grammar.h"
|
| 6 |
|
|
|
|
| 15 |
#include <numeric>
|
| 16 |
#include <random>
|
| 17 |
#include <unordered_map>
|
| 18 |
+
#include <stdexcept>
|
| 19 |
+
|
| 20 |
+
// the ring buffer works similarly to std::deque, but with a fixed capacity
|
| 21 |
+
template<typename T>
|
| 22 |
+
struct ring_buffer {
|
| 23 |
+
ring_buffer(size_t cap) : capacity(cap), data(cap) {}
|
| 24 |
+
|
| 25 |
+
T & front() {
|
| 26 |
+
if (sz == 0) {
|
| 27 |
+
throw std::runtime_error("ring buffer is empty");
|
| 28 |
+
}
|
| 29 |
+
return data[first];
|
| 30 |
+
}
|
| 31 |
+
|
| 32 |
+
const T & front() const {
|
| 33 |
+
if (sz == 0) {
|
| 34 |
+
throw std::runtime_error("ring buffer is empty");
|
| 35 |
+
}
|
| 36 |
+
return data[first];
|
| 37 |
+
}
|
| 38 |
+
|
| 39 |
+
T & back() {
|
| 40 |
+
if (sz == 0) {
|
| 41 |
+
throw std::runtime_error("ring buffer is empty");
|
| 42 |
+
}
|
| 43 |
+
return data[pos];
|
| 44 |
+
}
|
| 45 |
+
|
| 46 |
+
const T & back() const {
|
| 47 |
+
if (sz == 0) {
|
| 48 |
+
throw std::runtime_error("ring buffer is empty");
|
| 49 |
+
}
|
| 50 |
+
return data[pos];
|
| 51 |
+
}
|
| 52 |
+
|
| 53 |
+
void push_back(const T & value) {
|
| 54 |
+
if (capacity == 0) {
|
| 55 |
+
throw std::runtime_error("ring buffer: capacity is zero");
|
| 56 |
+
}
|
| 57 |
+
|
| 58 |
+
if (sz == capacity) {
|
| 59 |
+
// advance the start when buffer is full
|
| 60 |
+
first = (first + 1) % capacity;
|
| 61 |
+
} else {
|
| 62 |
+
sz++;
|
| 63 |
+
}
|
| 64 |
+
data[pos] = value;
|
| 65 |
+
pos = (pos + 1) % capacity;
|
| 66 |
+
}
|
| 67 |
+
|
| 68 |
+
T pop_front() {
|
| 69 |
+
if (sz == 0) {
|
| 70 |
+
throw std::runtime_error("ring buffer is empty");
|
| 71 |
+
}
|
| 72 |
+
T value = data[first];
|
| 73 |
+
first = (first + 1) % capacity;
|
| 74 |
+
sz--;
|
| 75 |
+
return value;
|
| 76 |
+
}
|
| 77 |
+
|
| 78 |
+
//T & operator[](size_t i) {
|
| 79 |
+
// if (i >= sz) {
|
| 80 |
+
// throw std::runtime_error("ring buffer: index out of bounds");
|
| 81 |
+
// }
|
| 82 |
+
// return data[(first + i) % capacity];
|
| 83 |
+
//}
|
| 84 |
+
|
| 85 |
+
//const T & at(size_t i) const {
|
| 86 |
+
// if (i >= sz) {
|
| 87 |
+
// throw std::runtime_error("ring buffer: index out of bounds");
|
| 88 |
+
// }
|
| 89 |
+
// return data[(first + i) % capacity];
|
| 90 |
+
//}
|
| 91 |
+
|
| 92 |
+
const T & rat(size_t i) const {
|
| 93 |
+
if (i >= sz) {
|
| 94 |
+
throw std::runtime_error("ring buffer: index out of bounds");
|
| 95 |
+
}
|
| 96 |
+
return data[(first + sz - i - 1) % capacity];
|
| 97 |
+
}
|
| 98 |
+
|
| 99 |
+
std::vector<T> to_vector() const {
|
| 100 |
+
std::vector<T> result;
|
| 101 |
+
result.reserve(sz);
|
| 102 |
+
for (size_t i = 0; i < sz; i++) {
|
| 103 |
+
result.push_back(data[(first + i) % capacity]);
|
| 104 |
+
}
|
| 105 |
+
return result;
|
| 106 |
+
}
|
| 107 |
+
|
| 108 |
+
void clear() {
|
| 109 |
+
// here only reset the status of the buffer
|
| 110 |
+
sz = 0;
|
| 111 |
+
first = 0;
|
| 112 |
+
pos = 0;
|
| 113 |
+
}
|
| 114 |
+
|
| 115 |
+
bool empty() const {
|
| 116 |
+
return sz == 0;
|
| 117 |
+
}
|
| 118 |
+
|
| 119 |
+
size_t size() const {
|
| 120 |
+
return sz;
|
| 121 |
+
}
|
| 122 |
+
|
| 123 |
+
size_t capacity = 0;
|
| 124 |
+
size_t sz = 0;
|
| 125 |
+
size_t first = 0;
|
| 126 |
+
size_t pos = 0;
|
| 127 |
+
|
| 128 |
+
std::vector<T> data;
|
| 129 |
+
};
|
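A short usage sketch of the fixed-capacity ring buffer defined above (values are arbitrary); rat(i) indexes from the newest element backwards:

// capacity-3 buffer: pushing a fourth value evicts the oldest one
ring_buffer<int> prev(3);
prev.push_back(10);
prev.push_back(11);
prev.push_back(12);
prev.push_back(13);                        // 10 is overwritten

int newest = prev.rat(0);                  // 13
int oldest = prev.rat(prev.size() - 1);    // 11
std::vector<int> all = prev.to_vector();   // {11, 12, 13}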
| 130 |
|
| 131 |
static int llama_sample_dist(llama_token_data_array * cur_p, std::mt19937 & rng) {
|
| 132 |
// iterator for the probabilities
|
|
|
|
| 257 |
for (int i = 0; i < (int)cur_p->size; ++i) {
|
| 258 |
const float val = cur_p->data[i].logit;
|
| 259 |
int ib = int(bucket_scale * val + bucket_inter); //nbuckets * (val - bucket_low) / (bucket_high - bucket_low);
|
| 260 |
+
ib = std::max(0, std::min(nbuckets - 1, ib));
|
| 261 |
bucket_idx[i] = ib;
|
| 262 |
++histo[ib];
|
| 263 |
}
|
|
|
|
| 280 |
for (int i = 0; i < (int)cur_p->size; ++i) {
|
| 281 |
int j = bucket_idx[i];
|
| 282 |
if (j >= ib) {
|
| 283 |
+
*bucket_ptrs[nbuckets - 1 - j]++ = cur_p->data[i];
|
| 284 |
}
|
| 285 |
}
|
| 286 |
|
| 287 |
ptr = tmp_tokens.data();
|
| 288 |
int ndone = 0;
|
| 289 |
+
for (int j = nbuckets - 1; j > ib; --j) {
|
| 290 |
std::sort(ptr, ptr + histo[j], comp);
|
| 291 |
ptr += histo[j];
|
| 292 |
ndone += histo[j];
|
|
|
|
| 1832 |
ctx->dry_repeat_count[last - k] = std::min(n, rep_limit);
|
| 1833 |
if (n > 0) {
|
| 1834 |
lt = k;
|
| 1835 |
+
rt = k + n - 1;
|
| 1836 |
}
|
| 1837 |
} else {
|
| 1838 |
// If k is inside the current Z-box, consider two cases.
|
examples/talk-llama/llama-vocab.cpp
CHANGED
|
@@ -1,5 +1,7 @@
|
|
| 1 |
#include "llama-vocab.h"
|
| 2 |
|
|
|
|
|
|
|
| 3 |
#include "unicode.h"
|
| 4 |
|
| 5 |
#include <algorithm>
|
|
@@ -16,22 +18,6 @@
|
|
| 16 |
// helpers
|
| 17 |
//
|
| 18 |
|
| 19 |
-
LLAMA_ATTRIBUTE_FORMAT(1, 2)
|
| 20 |
-
static std::string format(const char * fmt, ...) {
|
| 21 |
-
va_list ap;
|
| 22 |
-
va_list ap2;
|
| 23 |
-
va_start(ap, fmt);
|
| 24 |
-
va_copy(ap2, ap);
|
| 25 |
-
int size = vsnprintf(NULL, 0, fmt, ap);
|
| 26 |
-
GGML_ASSERT(size >= 0 && size < INT_MAX); // NOLINT
|
| 27 |
-
std::vector<char> buf(size + 1);
|
| 28 |
-
int size2 = vsnprintf(buf.data(), size + 1, fmt, ap2);
|
| 29 |
-
GGML_ASSERT(size2 == size);
|
| 30 |
-
va_end(ap2);
|
| 31 |
-
va_end(ap);
|
| 32 |
-
return std::string(buf.data(), size);
|
| 33 |
-
}
|
| 34 |
-
|
| 35 |
struct naive_trie {
|
| 36 |
naive_trie() : has_value(false), value(0) {
|
| 37 |
}
|
|
@@ -396,6 +382,13 @@ struct llm_tokenizer_bpe : llm_tokenizer {
|
|
| 396 |
"\\p{N}+",
|
| 397 |
};
|
| 398 |
break;
|
| 399 |
case LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_CODER:
|
| 400 |
regex_exprs = {
|
| 401 |
"[\r\n]",
|
|
@@ -504,7 +497,7 @@ struct llm_tokenizer_bpe_session {
|
|
| 504 |
|
| 505 |
bool append_bos(std::vector<llama_vocab::id> & output) const {
|
| 506 |
if (vocab.tokenizer_add_bos) {
|
| 507 |
-
GGML_ASSERT(vocab.special_bos_id !=
|
| 508 |
output.push_back(vocab.special_bos_id);
|
| 509 |
return true;
|
| 510 |
}
|
|
@@ -513,7 +506,7 @@ struct llm_tokenizer_bpe_session {
|
|
| 513 |
|
| 514 |
bool append_eos(std::vector<llama_vocab::id> & output) const {
|
| 515 |
if (vocab.tokenizer_add_eos) {
|
| 516 |
-
GGML_ASSERT(vocab.special_eos_id !=
|
| 517 |
output.push_back(vocab.special_eos_id);
|
| 518 |
return true;
|
| 519 |
}
|
|
@@ -1410,7 +1403,7 @@ static void tokenizer_st_partition(const llama_vocab & vocab, std::forward_list<
|
|
| 1410 |
if (source == 0) {
|
| 1411 |
buffer.erase_after(buffer.before_begin());
|
| 1412 |
} else {
|
| 1413 |
-
buffer.erase_after(std::next(buffer.begin(), (source-1)));
|
| 1414 |
}
|
| 1415 |
|
| 1416 |
// repeat for the right side
|
|
@@ -1424,7 +1417,7 @@ static void tokenizer_st_partition(const llama_vocab & vocab, std::forward_list<
|
|
| 1424 |
if (source == 0) {
|
| 1425 |
buffer.erase_after(buffer.before_begin());
|
| 1426 |
} else {
|
| 1427 |
-
buffer.erase_after(std::next(buffer.begin(), (source-1)));
|
| 1428 |
}
|
| 1429 |
break;
|
| 1430 |
}
|
|
@@ -1461,7 +1454,7 @@ std::vector<llama_vocab::id> llama_tokenize_internal(
|
|
| 1461 |
bool is_prev_special = true; // prefix with space if first token
|
| 1462 |
|
| 1463 |
if (add_special && vocab.tokenizer_add_bos) {
|
| 1464 |
-
GGML_ASSERT(vocab.special_bos_id !=
|
| 1465 |
output.push_back(vocab.special_bos_id);
|
| 1466 |
is_prev_special = true;
|
| 1467 |
}
|
|
@@ -1496,7 +1489,7 @@ std::vector<llama_vocab::id> llama_tokenize_internal(
|
|
| 1496 |
}
|
| 1497 |
|
| 1498 |
if (add_special && vocab.tokenizer_add_eos) {
|
| 1499 |
-
GGML_ASSERT(vocab.special_eos_id !=
|
| 1500 |
output.push_back(vocab.special_eos_id);
|
| 1501 |
}
|
| 1502 |
} break;
|
|
@@ -1529,7 +1522,7 @@ std::vector<llama_vocab::id> llama_tokenize_internal(
|
|
| 1529 |
case LLAMA_VOCAB_TYPE_WPM:
|
| 1530 |
{
|
| 1531 |
if (add_special) {
|
| 1532 |
-
GGML_ASSERT(vocab.special_cls_id !=
|
| 1533 |
output.push_back(vocab.special_cls_id);
|
| 1534 |
}
|
| 1535 |
|
|
@@ -1549,14 +1542,14 @@ std::vector<llama_vocab::id> llama_tokenize_internal(
|
|
| 1549 |
}
|
| 1550 |
|
| 1551 |
if (add_special) {
|
| 1552 |
-
GGML_ASSERT(vocab.special_sep_id !=
|
| 1553 |
output.push_back(vocab.special_sep_id);
|
| 1554 |
}
|
| 1555 |
} break;
|
| 1556 |
case LLAMA_VOCAB_TYPE_UGM:
|
| 1557 |
{
|
| 1558 |
if (add_special && vocab.tokenizer_add_bos) {
|
| 1559 |
-
GGML_ASSERT(vocab.special_bos_id !=
|
| 1560 |
output.push_back(vocab.special_bos_id);
|
| 1561 |
}
|
| 1562 |
llm_tokenizer_ugm_session session(vocab);
|
|
@@ -1581,7 +1574,7 @@ std::vector<llama_vocab::id> llama_tokenize_internal(
|
|
| 1581 |
}
|
| 1582 |
|
| 1583 |
if (add_special && vocab.tokenizer_add_eos) {
|
| 1584 |
-
GGML_ASSERT(vocab.special_eos_id !=
|
| 1585 |
output.push_back(vocab.special_eos_id);
|
| 1586 |
}
|
| 1587 |
} break;
|
|
@@ -1649,7 +1642,7 @@ llama_token_attr llama_token_get_attr_impl(const struct llama_vocab & vocab, lla
|
|
| 1649 |
}
|
| 1650 |
|
| 1651 |
bool llama_token_is_eog_impl(const struct llama_vocab & vocab, llama_token token) {
|
| 1652 |
-
return token !=
|
| 1653 |
}
|
| 1654 |
|
| 1655 |
bool llama_token_is_control_impl(const struct llama_vocab & vocab, llama_token token) {
|
|
@@ -1657,7 +1650,7 @@ bool llama_token_is_control_impl(const struct llama_vocab & vocab, llama_token t
|
|
| 1657 |
}
|
| 1658 |
|
| 1659 |
llama_token llama_token_bos_impl(const struct llama_vocab & vocab) {
|
| 1660 |
-
return vocab.special_bos_id;
|
| 1661 |
}
|
| 1662 |
|
| 1663 |
llama_token llama_token_eos_impl(const struct llama_vocab & vocab) {
|
|
@@ -1867,6 +1860,10 @@ int32_t llama_detokenize_impl(
|
|
| 1867 |
int32_t text_len_max,
|
| 1868 |
bool remove_special,
|
| 1869 |
bool unparse_special) {
|
|
|
| 1870 |
GGML_ASSERT(vocab.tokenizer && "Tokenizer not initialized. Call llama_vocab::init_tokenizer() first.");
|
| 1871 |
|
| 1872 |
int32_t avail = text_len_max;
|
|
@@ -1884,7 +1881,7 @@ int32_t llama_detokenize_impl(
|
|
| 1884 |
}
|
| 1885 |
|
| 1886 |
if (remove_special && vocab.tokenizer_add_eos) {
|
| 1887 |
-
if (n_tokens > 0 && tokens[n_tokens-1] == vocab.special_eos_id) {
|
| 1888 |
n_tokens--;
|
| 1889 |
}
|
| 1890 |
}
|
|
|
|
| 1 |
#include "llama-vocab.h"
|
| 2 |
|
| 3 |
+
#include "llama-impl.h"
|
| 4 |
+
|
| 5 |
#include "unicode.h"
|
| 6 |
|
| 7 |
#include <algorithm>
|
|
|
|
| 18 |
// helpers
|
| 19 |
//
|
| 20 |
|
| 21 |
struct naive_trie {
|
| 22 |
naive_trie() : has_value(false), value(0) {
|
| 23 |
}
|
|
|
|
| 382 |
"\\p{N}+",
|
| 383 |
};
|
| 384 |
break;
|
| 385 |
+
case LLAMA_VOCAB_PRE_TYPE_DEEPSEEK3_LLM:
|
| 386 |
+
regex_exprs = {
|
| 387 |
+
"\\p{N}{1,3}",
|
| 388 |
+
"[一-龥-ゟ゠-ヿ]+",
|
| 389 |
+
"[!\"#$%&'()*+,\\-./:;<=>?@\\[\\\\\\]^_`{|}~][A-Za-z]+|[^\r\n\\p{L}\\p{P}\\p{S}]?[\\p{L}\\p{M}]+| ?[\\p{P}\\p{S}]+[\r\n]*|\\s*[\r\n]+|\\s+(?!\\S)|\\s+",
|
| 390 |
+
};
|
| 391 |
+
break;
|
| 392 |
case LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_CODER:
|
| 393 |
regex_exprs = {
|
| 394 |
"[\r\n]",
|
|
|
|
| 497 |
|
| 498 |
bool append_bos(std::vector<llama_vocab::id> & output) const {
|
| 499 |
if (vocab.tokenizer_add_bos) {
|
| 500 |
+
GGML_ASSERT(vocab.special_bos_id != LLAMA_TOKEN_NULL);
|
| 501 |
output.push_back(vocab.special_bos_id);
|
| 502 |
return true;
|
| 503 |
}
|
|
|
|
| 506 |
|
| 507 |
bool append_eos(std::vector<llama_vocab::id> & output) const {
|
| 508 |
if (vocab.tokenizer_add_eos) {
|
| 509 |
+
GGML_ASSERT(vocab.special_eos_id != LLAMA_TOKEN_NULL);
|
| 510 |
output.push_back(vocab.special_eos_id);
|
| 511 |
return true;
|
| 512 |
}
|
|
|
|
| 1403 |
if (source == 0) {
|
| 1404 |
buffer.erase_after(buffer.before_begin());
|
| 1405 |
} else {
|
| 1406 |
+
buffer.erase_after(std::next(buffer.begin(), (source - 1)));
|
| 1407 |
}
|
| 1408 |
|
| 1409 |
// repeat for the right side
|
|
|
|
| 1417 |
if (source == 0) {
|
| 1418 |
buffer.erase_after(buffer.before_begin());
|
| 1419 |
} else {
|
| 1420 |
+
buffer.erase_after(std::next(buffer.begin(), (source - 1)));
|
| 1421 |
}
|
| 1422 |
break;
|
| 1423 |
}
|
|
|
|
| 1454 |
bool is_prev_special = true; // prefix with space if first token
|
| 1455 |
|
| 1456 |
if (add_special && vocab.tokenizer_add_bos) {
|
| 1457 |
+
GGML_ASSERT(vocab.special_bos_id != LLAMA_TOKEN_NULL);
|
| 1458 |
output.push_back(vocab.special_bos_id);
|
| 1459 |
is_prev_special = true;
|
| 1460 |
}
|
|
|
|
| 1489 |
}
|
| 1490 |
|
| 1491 |
if (add_special && vocab.tokenizer_add_eos) {
|
| 1492 |
+
GGML_ASSERT(vocab.special_eos_id != LLAMA_TOKEN_NULL);
|
| 1493 |
output.push_back(vocab.special_eos_id);
|
| 1494 |
}
|
| 1495 |
} break;
|
|
|
|
| 1522 |
case LLAMA_VOCAB_TYPE_WPM:
|
| 1523 |
{
|
| 1524 |
if (add_special) {
|
| 1525 |
+
GGML_ASSERT(vocab.special_cls_id != LLAMA_TOKEN_NULL);
|
| 1526 |
output.push_back(vocab.special_cls_id);
|
| 1527 |
}
|
| 1528 |
|
|
|
|
| 1542 |
}
|
| 1543 |
|
| 1544 |
if (add_special) {
|
| 1545 |
+
GGML_ASSERT(vocab.special_sep_id != LLAMA_TOKEN_NULL);
|
| 1546 |
output.push_back(vocab.special_sep_id);
|
| 1547 |
}
|
| 1548 |
} break;
|
| 1549 |
case LLAMA_VOCAB_TYPE_UGM:
|
| 1550 |
{
|
| 1551 |
if (add_special && vocab.tokenizer_add_bos) {
|
| 1552 |
+
GGML_ASSERT(vocab.special_bos_id != LLAMA_TOKEN_NULL);
|
| 1553 |
output.push_back(vocab.special_bos_id);
|
| 1554 |
}
|
| 1555 |
llm_tokenizer_ugm_session session(vocab);
|
|
|
|
| 1574 |
}
|
| 1575 |
|
| 1576 |
if (add_special && vocab.tokenizer_add_eos) {
|
| 1577 |
+
GGML_ASSERT(vocab.special_eos_id != LLAMA_TOKEN_NULL);
|
| 1578 |
output.push_back(vocab.special_eos_id);
|
| 1579 |
}
|
| 1580 |
} break;
|
|
|
|
| 1642 |
}
|
| 1643 |
|
| 1644 |
bool llama_token_is_eog_impl(const struct llama_vocab & vocab, llama_token token) {
|
| 1645 |
+
return token != LLAMA_TOKEN_NULL && vocab.special_eog_ids.count(token) > 0;
|
| 1646 |
}
|
| 1647 |
|
| 1648 |
bool llama_token_is_control_impl(const struct llama_vocab & vocab, llama_token token) {
|
|
|
|
| 1650 |
}
|
| 1651 |
|
| 1652 |
llama_token llama_token_bos_impl(const struct llama_vocab & vocab) {
|
| 1653 |
+
return vocab.type != LLAMA_VOCAB_TYPE_WPM ? vocab.special_bos_id : vocab.special_cls_id;
|
| 1654 |
}
|
| 1655 |
|
| 1656 |
llama_token llama_token_eos_impl(const struct llama_vocab & vocab) {
|
|
|
|
| 1860 |
int32_t text_len_max,
|
| 1861 |
bool remove_special,
|
| 1862 |
bool unparse_special) {
|
| 1863 |
+
if (vocab.type == LLAMA_VOCAB_TYPE_NONE) {
|
| 1864 |
+
return 0;
|
| 1865 |
+
}
|
| 1866 |
+
|
| 1867 |
GGML_ASSERT(vocab.tokenizer && "Tokenizer not initialized. Call llama_vocab::init_tokenizer() first.");
|
| 1868 |
|
| 1869 |
int32_t avail = text_len_max;
|
|
|
|
| 1881 |
}
|
| 1882 |
|
| 1883 |
if (remove_special && vocab.tokenizer_add_eos) {
|
| 1884 |
+
if (n_tokens > 0 && tokens[n_tokens - 1] == vocab.special_eos_id) {
|
| 1885 |
n_tokens--;
|
| 1886 |
}
|
| 1887 |
}
|
examples/talk-llama/llama-vocab.h
CHANGED
|
@@ -1,6 +1,6 @@
|
|
| 1 |
#pragma once
|
| 2 |
|
| 3 |
-
#include "llama
|
| 4 |
|
| 5 |
#include <string>
|
| 6 |
#include <vector>
|
|
@@ -8,6 +8,18 @@
|
|
| 8 |
#include <map>
|
| 9 |
#include <set>
|
| 11 |
struct llm_tokenizer;
|
| 12 |
|
| 13 |
struct llama_vocab {
|
|
@@ -45,7 +57,7 @@ struct llama_vocab {
|
|
| 45 |
id special_unk_id = 0;
|
| 46 |
id special_sep_id = LLAMA_TOKEN_NULL;
|
| 47 |
id special_pad_id = LLAMA_TOKEN_NULL;
|
| 48 |
-
id special_cls_id = LLAMA_TOKEN_NULL;
|
| 49 |
id special_mask_id = LLAMA_TOKEN_NULL;
|
| 50 |
|
| 51 |
id linefeed_id = 13;
|
|
|
|
| 1 |
#pragma once
|
| 2 |
|
| 3 |
+
#include "llama.h"
|
| 4 |
|
| 5 |
#include <string>
|
| 6 |
#include <vector>
|
|
|
|
| 8 |
#include <map>
|
| 9 |
#include <set>
|
| 10 |
|
| 11 |
+
static const char * llama_model_vocab_type_name(enum llama_vocab_type type){
|
| 12 |
+
switch (type) {
|
| 13 |
+
case LLAMA_VOCAB_TYPE_NONE: return "no vocab";
|
| 14 |
+
case LLAMA_VOCAB_TYPE_SPM: return "SPM";
|
| 15 |
+
case LLAMA_VOCAB_TYPE_BPE: return "BPE";
|
| 16 |
+
case LLAMA_VOCAB_TYPE_WPM: return "WPM";
|
| 17 |
+
case LLAMA_VOCAB_TYPE_UGM: return "UGM";
|
| 18 |
+
case LLAMA_VOCAB_TYPE_RWKV: return "RWKV";
|
| 19 |
+
default: return "unknown";
|
| 20 |
+
}
|
| 21 |
+
}
|
| 22 |
+
|
| 23 |
struct llm_tokenizer;
|
| 24 |
|
| 25 |
struct llama_vocab {
|
|
|
|
| 57 |
id special_unk_id = 0;
|
| 58 |
id special_sep_id = LLAMA_TOKEN_NULL;
|
| 59 |
id special_pad_id = LLAMA_TOKEN_NULL;
|
| 60 |
+
id special_cls_id = LLAMA_TOKEN_NULL; // TODO: revisit if this is really needed https://github.com/ggerganov/llama.cpp/pull/10930
|
| 61 |
id special_mask_id = LLAMA_TOKEN_NULL;
|
| 62 |
|
| 63 |
id linefeed_id = 13;
|
examples/talk-llama/llama.cpp
CHANGED
|
The diff for this file is too large to render.
See raw diff
|
|
|
examples/talk-llama/llama.h
CHANGED

@@ -34,7 +34,6 @@

 #define LLAMA_DEFAULT_SEED 0xFFFFFFFF

-// TODO: use everywhere in the implementation
 #define LLAMA_TOKEN_NULL -1

 #define LLAMA_FILE_MAGIC_GGLA 0x67676c61u // 'ggla'
@@ -105,6 +104,7 @@ extern "C" {
         LLAMA_VOCAB_PRE_TYPE_EXAONE         = 25,
         LLAMA_VOCAB_PRE_TYPE_CHAMELEON      = 26,
         LLAMA_VOCAB_PRE_TYPE_MINERVA        = 27,
+        LLAMA_VOCAB_PRE_TYPE_DEEPSEEK3_LLM  = 28,
     };

     enum llama_rope_type {
@@ -385,6 +385,7 @@ extern "C" {
     } llama_chat_message;

     // lora adapter
+    // TODO: rename to llama_adapter_lora
     struct llama_lora_adapter;

     // Helpers for getting default parameters
@@ -412,11 +413,19 @@ extern "C" {
     // Call once at the end of the program - currently only used for MPI
     LLAMA_API void llama_backend_free(void);

-    LLAMA_API struct llama_model * llama_load_model_from_file(
+    DEPRECATED(LLAMA_API struct llama_model * llama_load_model_from_file(
+                             const char * path_model,
+              struct llama_model_params   params),
+            "use llama_model_load_from_file instead");
+
+    LLAMA_API struct llama_model * llama_model_load_from_file(
                              const char * path_model,
               struct llama_model_params   params);

-    LLAMA_API void llama_free_model(struct llama_model * model)
+    DEPRECATED(LLAMA_API void llama_free_model(struct llama_model * model),
+            "use llama_model_free instead");
+
+    LLAMA_API void llama_model_free(struct llama_model * model);

     // TODO: rename to llama_init_from_model
     LLAMA_API struct llama_context * llama_new_context_with_model(
@@ -482,9 +491,6 @@ extern "C" {
     // Returns the total number of parameters in the model
     LLAMA_API uint64_t llama_model_n_params(const struct llama_model * model);

-    // Get a llama model tensor
-    LLAMA_API struct ggml_tensor * llama_get_model_tensor(struct llama_model * model, const char * name);
-
     // Returns true if the model contains an encoder that requires llama_encode() call
     LLAMA_API bool llama_model_has_encoder(const struct llama_model * model);

@@ -504,14 +510,19 @@ extern "C" {
             const char * fname_out,
             const llama_model_quantize_params * params);

+    //
+    // Adapters
+    //
+
     // Load a LoRA adapter from file
-    //
+    // TODO: rename to llama_adapter_lora_init
     LLAMA_API struct llama_lora_adapter * llama_lora_adapter_init(
             struct llama_model * model,
             const char * path_lora);

     // Add a loaded LoRA adapter to given context
     // This will not modify model's weight
+    // TODO: rename to llama_set_adapter_lora
     LLAMA_API int32_t llama_lora_adapter_set(
             struct llama_context * ctx,
             struct llama_lora_adapter * adapter,
@@ -519,16 +530,18 @@ extern "C" {

     // Remove a specific LoRA adapter from given context
     // Return -1 if the adapter is not present in the context
+    // TODO: rename to llama_rm_adapter_lora
     LLAMA_API int32_t llama_lora_adapter_remove(
             struct llama_context * ctx,
             struct llama_lora_adapter * adapter);

     // Remove all LoRA adapters from given context
-
-
+    // TODO: rename to llama_clear_adapter_lora
+    LLAMA_API void llama_lora_adapter_clear(struct llama_context * ctx);

     // Manually free a LoRA adapter
     // Note: loaded adapters will be free when the associated model is deleted
+    // TODO: rename to llama_adapter_lora_free
     LLAMA_API void llama_lora_adapter_free(struct llama_lora_adapter * adapter);

     // Apply a loaded control vector to a llama_context, or if data is NULL, clear
@@ -537,6 +550,7 @@ extern "C" {
     // to an n_embd x n_layers buffer starting from layer 1.
     // il_start and il_end are the layer range the vector should apply to (both inclusive)
     // See llama_control_vector_load in common to load a control vector.
+    // TODO: rename to llama_adapter_cvec_apply
     LLAMA_API int32_t llama_control_vector_apply(
             struct llama_context * lctx,
                      const float * data,
@@ -549,6 +563,8 @@ extern "C" {
     // KV cache
     //

+    // TODO: remove llama_kv_cache_view_* API
+
     // Information associated with an individual cell in the KV cache view.
     struct llama_kv_cache_view_cell {
         // The position for this cell. Takes KV cache shifts into account.
@@ -595,8 +611,11 @@ extern "C" {
     LLAMA_API void llama_kv_cache_view_free(struct llama_kv_cache_view * view);

     // Update the KV cache view structure with the current state of the KV cache. (use only for debugging purposes)
+    // TODO: change signature to llama_kv_cache_view_update(struct llama_kv_cache_view * view, const struct llama_context * ctx)
     LLAMA_API void llama_kv_cache_view_update(const struct llama_context * ctx, struct llama_kv_cache_view * view);

+    ///
+
     // Returns the number of tokens in the KV cache (slow, use only for debug)
     // If a KV cell has multiple sequences assigned to it, it will be counted multiple times
     LLAMA_API int32_t llama_get_kv_cache_token_count(const struct llama_context * ctx);
@@ -666,6 +685,9 @@ extern "C" {
             struct llama_context * ctx,
                     llama_seq_id   seq_id);

+    // TODO: the llama_kv_cache_defrag and llama_kv_cache_update API tightly couples llama_context with llama_kv_cache
+    // how to avoid this?
+
     // Defragment the KV cache
     // This will be applied:
     //   - lazily on next llama_decode()
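To see how the header changes above fit together in application code, here is a hedged usage sketch (placeholder file names, default parameters, not from the diff): it loads a model with the new llama_model_load_from_file, attaches a LoRA adapter through the current llama_lora_adapter_* names (the TODO comments only announce planned renames), and tears everything down with llama_model_free instead of the deprecated llama_free_model.

// usage sketch only; "base-model.gguf" and "adapter.gguf" are placeholder paths
#include "llama.h"

#include <cstdio>

int main() {
    llama_backend_init();

    // old (now deprecated): llama_load_model_from_file(...)
    struct llama_model * model = llama_model_load_from_file("base-model.gguf", llama_model_default_params());
    if (!model) {
        std::fprintf(stderr, "failed to load model\n");
        return 1;
    }

    struct llama_context * ctx = llama_new_context_with_model(model, llama_context_default_params());

    // adapters stay owned by the model; llama_lora_adapter_free is only needed
    // when releasing one before the model itself is freed
    struct llama_lora_adapter * lora = llama_lora_adapter_init(model, "adapter.gguf");
    if (lora) {
        llama_lora_adapter_set(ctx, lora, /*scale =*/ 1.0f);
    }

    // ... run inference ...

    llama_lora_adapter_clear(ctx); // detach all adapters from the context

    llama_free(ctx);
    llama_model_free(model);       // replaces the deprecated llama_free_model
    llama_backend_free();
    return 0;
}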
examples/talk-llama/talk-llama.cpp
CHANGED

@@ -304,7 +304,7 @@ int main(int argc, char ** argv) {
         lmparams.n_gpu_layers = params.n_gpu_layers;
     }

-    struct llama_model * model_llama =
+    struct llama_model * model_llama = llama_model_load_from_file(params.model_llama.c_str(), lmparams);
     if (!model_llama) {
         fprintf(stderr, "No llama.cpp model specified. Please provide using -ml <modelfile>\n");
         return 1;
examples/talk-llama/unicode.cpp
CHANGED

@@ -667,18 +667,24 @@ std::vector<std::string> unicode_regex_split(const std::string & text, const std
     { "\\p{N}", unicode_cpt_flags::NUMBER },
     { "\\p{L}", unicode_cpt_flags::LETTER },
     { "\\p{P}", unicode_cpt_flags::PUNCTUATION },
+    { "\\p{M}", unicode_cpt_flags::ACCENT_MARK },
+    { "\\p{S}", unicode_cpt_flags::SYMBOL },
 };

 static const std::map<int, int> k_ucat_cpt = {
     { unicode_cpt_flags::NUMBER,      0xD1 },
     { unicode_cpt_flags::LETTER,      0xD2 },
     { unicode_cpt_flags::PUNCTUATION, 0xD3 },
+    { unicode_cpt_flags::ACCENT_MARK, 0xD4 },
+    { unicode_cpt_flags::SYMBOL,      0xD5 },
 };

 static const std::map<int, std::string> k_ucat_map = {
     { unicode_cpt_flags::NUMBER,      "\x30-\x39" }, // 0-9
     { unicode_cpt_flags::LETTER,      "\x41-\x5A\x61-\x7A" }, // A-Za-z
     { unicode_cpt_flags::PUNCTUATION, "\x21-\x23\x25-\x2A\x2C-\x2F\x3A-\x3B\x3F-\x40\\\x5B-\\\x5D\x5F\\\x7B\\\x7D" }, // !-#%-*,-/:-;?-@\[-\]_\{\}
+    { unicode_cpt_flags::ACCENT_MARK, "" }, // no sub-128 codepoints
+    { unicode_cpt_flags::SYMBOL,      "\\\x24\\\x2B\x3C-\x3E\x5E\x60\\\x7C" }, // $+<=>^`|
 };

 // compute collapsed codepoints only if needed by at least one regex
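The two new categories plug into the existing "collapsed codepoint" trick: every codepoint in a category is replaced by a single placeholder value (0xD4 for \p{M}, 0xD5 for \p{S}) so the byte-oriented std::regex fallback can still match the category. Below is a minimal standalone sketch of that idea, not the actual unicode.cpp code; classify() is a toy stand-in for the real codepoint-flags lookup.

// standalone sketch of the collapsed-codepoint idea; classify() is a toy classifier
#include <cstdint>
#include <iostream>
#include <map>
#include <string>

enum category : int { OTHER = 0, ACCENT_MARK = 1, SYMBOL = 2 };

static category classify(uint32_t cpt) {
    if (cpt == 0x0301) return ACCENT_MARK; // COMBINING ACUTE ACCENT
    if (cpt == 0x20AC) return SYMBOL;      // EURO SIGN
    return OTHER;
}

int main() {
    // same collapsed values the diff adds to k_ucat_cpt (\p{M} -> 0xD4, \p{S} -> 0xD5)
    const std::map<int, char> collapsed = {
        { ACCENT_MARK, (char) 0xD4 },
        { SYMBOL,      (char) 0xD5 },
    };

    const uint32_t input[] = { 'e', 0x0301, ' ', 0x20AC };

    std::string collapsed_text;
    for (uint32_t cpt : input) {
        const auto it = collapsed.find(classify(cpt));
        collapsed_text += (it != collapsed.end()) ? it->second : (char) cpt;
    }

    // prints: 65 d4 20 d5 -- a byte pattern like "\xD4" now matches what \p{M} would
    for (unsigned char c : collapsed_text) {
        std::cout << std::hex << (int) c << ' ';
    }
    std::cout << '\n';
    return 0;
}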