whisper : whisper_state/backend fixes (#2217)

* whisper : fixes
* ci : WHISPER_CUBLAS -> WHISPER_CUDA

In short: the mel calculator and the mel buffers move from whisper_context into whisper_state, the KV-cache helpers get whisper_-prefixed names, whisper_mel drops its RAII members in favor of explicit whisper_mel_init()/whisper_mel_free() helpers, and the CI workflow switches to the renamed WHISPER_CUDA build option.

Files changed:
- .github/workflows/build.yml (+1 -1)
- whisper-mel-cuda.cu (+2 -2)
- whisper-mel.hpp (+4 -15)
- whisper.cpp (+53 -72)
.github/workflows/build.yml CHANGED

@@ -498,7 +498,7 @@ jobs:
       run: >
         cmake -S . -B ./build -A ${{ matrix.arch }}
         -DCMAKE_BUILD_TYPE=${{ matrix.build }}
-        -DWHISPER_CUBLAS=${{ matrix.cublas }}
+        -DWHISPER_CUDA=${{ matrix.cublas }}
         -DWHISPER_SDL2=${{ matrix.sdl2 }}
 
     - name: Build ${{ matrix.cuda-toolkit }}
whisper-mel-cuda.cu CHANGED

@@ -194,7 +194,7 @@ class mel_calc_cuda : public whisper_mel_calc {
     size_t m_log_mel_temp_storage_size = 0;
     void * m_log_mel_temp_storage = nullptr;
 public:
-    mel_calc_cuda(ggml_backend_t backend, const whisper_filters& filters)
+    mel_calc_cuda(ggml_backend_t backend, const whisper_filters & filters)
         : m_n_mel(filters.n_mel)
         , m_backend(backend)
     {
@@ -305,7 +305,7 @@ public:
         whisper_mel ret;
         // Calculate semi-padded sample length to ensure compatibility
         int n_len_org = 1 + int(samples.len + mirror_pad - WHISPER_N_FFT) / WHISPER_HOP_LENGTH;
-        ret.init(m_backend, int(n_mag_frames), n_len_org, m_n_mel);
+        whisper_mel_init(ret, m_backend, int(n_mag_frames), n_len_org, m_n_mel);
         assert(ggml_nbytes(ret.tensor) == m_n_mel * n_mag_frames * sizeof(float));
 
         float* log_mels = reinterpret_cast<float*>(ret.tensor->data);
whisper-mel.hpp CHANGED

@@ -5,22 +5,14 @@
 struct whisper_mel {
     int n_len_org = 0;
 
-    ggml_tensor * tensor = nullptr;
     ggml_context * ctx = nullptr;
+    ggml_tensor * tensor = nullptr;
     ggml_backend_buffer_t buffer = nullptr;
+};
 
-    whisper_mel() = default;
-    ~whisper_mel();
-
-    whisper_mel(const whisper_mel &) = delete;
-    whisper_mel & operator=(const whisper_mel &) = delete;
-    whisper_mel(whisper_mel &&) noexcept;
-    whisper_mel & operator=(whisper_mel &&) noexcept;
+void whisper_mel_init(whisper_mel & mel, ggml_backend_t backend, int n_len, int n_len_org, int n_mel);
 
-    void init(ggml_backend_t backend, int n_len, int n_len_org, int n_mel);
-    void reset();
-    void take(whisper_mel & other) noexcept;
-};
+void whisper_mel_free(whisper_mel & mel);
 
 struct whisper_filters {
     int32_t n_mel;
@@ -40,6 +32,3 @@ struct whisper_mel_calc {
     virtual whisper_mel calculate(whisper_span<const float> samples, int n_threads) const = 0;
     static whisper_span<const float> hann_window();
 };
-
-// returns a new pointer which needs to be freed with delete
-whisper_mel_calc * whisper_mel_calc_create(ggml_backend_t backend, const whisper_filters & filters);
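With the destructor, copy/move operations and reset()/take() gone, whisper_mel is a plain struct again and its buffer lifetime is managed explicitly through the two new free functions. A minimal lifecycle sketch (hypothetical caller code, not part of this PR; assumes a CPU ggml backend and only the declarations above):

    #include "whisper-mel.hpp"
    #include "ggml-backend.h"

    static void mel_lifecycle_example() {
        ggml_backend_t backend = ggml_backend_cpu_init();

        // allocates mel.ctx (tensor metadata only), mel.tensor and mel.buffer
        whisper_mel mel;
        whisper_mel_init(mel, backend, /*n_len=*/3000, /*n_len_org=*/3000, /*n_mel=*/80);

        // ... write frames into mel.tensor via ggml_backend_tensor_set() ...

        // the struct no longer frees itself: every init must be paired with a free
        whisper_mel_free(mel);
        ggml_backend_free(backend);
    }

The trade-off is explicit: the struct becomes trivially copyable/movable, but every owner has to pair init with free, which the whisper.cpp changes below do in whisper_free_state() and whisper_set_mel_with_state().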
whisper.cpp CHANGED

@@ -801,6 +801,7 @@ struct whisper_state {
     whisper_kv_cache kv_pad;
 
     whisper_mel mel;
+    whisper_mel_calc * mel_calc = nullptr;
 
     whisper_batch batch;
 
@@ -870,8 +871,6 @@ struct whisper_context {
     whisper_model model;
     whisper_vocab vocab;
 
-    whisper_mel_calc * mel_calc = nullptr;
-
     whisper_state * state = nullptr;
 
     ggml_backend_t backend = nullptr;
@@ -893,7 +892,7 @@ static void read_safe(whisper_model_loader * loader, T & dest) {
     BYTESWAP_VALUE(dest);
 }
 
-static bool kv_cache_init(
+static bool whisper_kv_cache_init(
         struct whisper_kv_cache & cache,
         ggml_backend_t backend,
         ggml_type wtype,
@@ -936,7 +935,7 @@ static bool kv_cache_init(
     return true;
 }
 
-static void kv_cache_free(struct whisper_kv_cache & cache) {
+static void whisper_kv_cache_free(struct whisper_kv_cache & cache) {
     ggml_free(cache.ctx);
     ggml_backend_buffer_free(cache.buffer);
     cache.ctx = nullptr;
@@ -1250,9 +1249,12 @@ static ggml_backend_t whisper_backend_init(const whisper_context_params & params
     }
 #endif
 
+    GGML_UNUSED(params);
+
     if (backend_gpu) {
         return backend_gpu;
     }
+
     return ggml_backend_cpu_init();
 }
 
@@ -2885,52 +2887,25 @@ struct whisper_global_cache {
 
 // Mel spectrogram
 
-whisper_mel::~whisper_mel() {
-    reset();
-}
-
-whisper_mel::whisper_mel(whisper_mel && other) noexcept {
-    take(other);
-}
-
-whisper_mel & whisper_mel::operator=(whisper_mel && other) noexcept {
-    if (this != &other) {
-        reset();
-        take(other);
-    }
-    return *this;
-}
-
-void whisper_mel::init(ggml_backend_t backend, int n_len, int n_len_org, int n_mel) {
-    this->n_len_org = n_len_org;
-    assert(!ctx);
-    ctx = ggml_init({ggml_tensor_overhead(), nullptr, true});
-    tensor = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_len, n_mel);
-    buffer = ggml_backend_alloc_buffer(backend, ggml_nbytes(tensor) + ggml_backend_get_alignment(backend));
-    auto alloc = ggml_tallocr_new(buffer);
-    ggml_tallocr_alloc(&alloc, tensor);
-}
-
-void whisper_mel::reset() {
-    ggml_free(ctx);
-    ggml_backend_buffer_free(buffer);
-
-    n_len_org = 0;
-    tensor = nullptr;
-    ctx = nullptr;
-    buffer = nullptr;
+void whisper_mel_init(whisper_mel & mel, ggml_backend_t backend, int n_len, int n_len_org, int n_mel) {
+    WHISPER_LOG_INFO("%s: n_len = %d, n_len_org = %d, n_mel = %d\n", __func__, n_len, n_len_org, n_mel);
+    mel.n_len_org = n_len_org;
+    assert(!mel.ctx);
+    mel.ctx = ggml_init({ggml_tensor_overhead(), nullptr, true});
+    mel.tensor = ggml_new_tensor_2d(mel.ctx, GGML_TYPE_F32, n_len, n_mel);
+    mel.buffer = ggml_backend_alloc_buffer(backend, ggml_nbytes(mel.tensor) + ggml_backend_get_alignment(backend));
+    auto alloc = ggml_tallocr_new(mel.buffer);
+    ggml_tallocr_alloc(&alloc, mel.tensor);
 }
 
-void whisper_mel::take(whisper_mel & other) noexcept {
-    n_len_org = other.n_len_org;
-    tensor = other.tensor;
-    ctx = other.ctx;
-    buffer = other.buffer;
+void whisper_mel_free(whisper_mel & mel) {
+    ggml_free(mel.ctx);
+    ggml_backend_buffer_free(mel.buffer);
 
-    other.n_len_org = 0;
-    other.tensor = nullptr;
-    other.ctx = nullptr;
-    other.buffer = nullptr;
+    mel.n_len_org = 0;
+    mel.ctx = nullptr;
+    mel.tensor = nullptr;
+    mel.buffer = nullptr;
 }
 
 whisper_mel_calc::~whisper_mel_calc() = default; // export vtable
@@ -3026,7 +3001,7 @@ struct whisper_mel_data {
     int n_len;
     int n_len_org;
    int n_mel;
-    float* data;
+    float * data;
 };
 
 void log_mel_spectrogram_worker_thread(int ith, const float * hann, const std::vector<float> & samples,
@@ -3100,7 +3075,7 @@ void log_mel_spectrogram_worker_thread(int ith, const float * hann, const std::v
 
 struct mel_calc_cpu : public whisper_mel_calc {
     ggml_backend_t m_backend;
-    const whisper_filters& m_filters;
+    const whisper_filters & m_filters;
     mel_calc_cpu(ggml_backend_t backend, const whisper_filters & filters) : m_backend(backend), m_filters(filters) {}
 
     // ref: https://github.com/openai/whisper/blob/main/whisper/audio.py#L110-L157
@@ -3137,7 +3112,7 @@ struct mel_calc_cpu : public whisper_mel_calc {
     std::vector<float> host_mel_data;
 
     whisper_mel ret;
-    ret.init(m_backend, mel.n_len, mel.n_len_org, mel.n_mel);
+    whisper_mel_init(ret, m_backend, mel.n_len, mel.n_len_org, mel.n_mel);
     if (ggml_backend_buffer_is_host(ret.buffer)) {
         mel.data = reinterpret_cast<float*>(ret.tensor->data);
     } else {
@@ -3325,15 +3300,17 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
         return nullptr;
     }
 
+    state->mel_calc = whisper_mel_calc_create(state->backend, ctx->model.filters);
+
     // at this point, we don't know yet how many decoders will be used, so we overallocate 3x ctx
     // in theory, there can be a case where this is not enough, but in practice it should always be enough
     const int factor = 3;
 
-    if (!kv_cache_init(state->kv_self, ctx->backend, ctx->itype,
+    if (!whisper_kv_cache_init(state->kv_self, state->backend, ctx->itype,
                 ctx->model.hparams.n_text_state,
                 ctx->model.hparams.n_text_layer,
                 GGML_PAD(ctx->model.hparams.n_text_ctx, 256)*factor)) {
-        WHISPER_LOG_ERROR("%s: kv_cache_init() failed for self-attention cache\n", __func__);
+        WHISPER_LOG_ERROR("%s: whisper_kv_cache_init() failed for self-attention cache\n", __func__);
         whisper_free_state(state);
         return nullptr;
     }
@@ -3343,11 +3320,11 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
         WHISPER_LOG_INFO("%s: kv self size = %7.2f MB\n", __func__, memory_size / 1e6);
     }
 
-    if (!kv_cache_init(state->kv_cross, ctx->backend, ctx->itype,
+    if (!whisper_kv_cache_init(state->kv_cross, state->backend, ctx->itype,
                 ctx->model.hparams.n_text_state,
                 ctx->model.hparams.n_text_layer,
                 GGML_PAD(ctx->model.hparams.n_audio_ctx, 256))) {
-        WHISPER_LOG_ERROR("%s: kv_cache_init() failed for cross-attention cache\n", __func__);
+        WHISPER_LOG_ERROR("%s: whisper_kv_cache_init() failed for cross-attention cache\n", __func__);
         whisper_free_state(state);
         return nullptr;
     }
@@ -3357,11 +3334,11 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
         WHISPER_LOG_INFO("%s: kv cross size = %7.2f MB\n", __func__, memory_size / 1e6);
     }
 
-    if (!kv_cache_init(state->kv_pad, ctx->backend, ctx->itype,
+    if (!whisper_kv_cache_init(state->kv_pad, state->backend, ctx->itype,
                 ctx->model.hparams.n_audio_state,
                 1,
                 GGML_PAD(ctx->model.hparams.n_audio_ctx, 256))) {
-        WHISPER_LOG_ERROR("%s: kv_cache_init() failed for self-attention cache\n", __func__);
+        WHISPER_LOG_ERROR("%s: whisper_kv_cache_init() failed for self-attention cache\n", __func__);
         whisper_free_state(state);
         return nullptr;
     }
@@ -3373,7 +3350,7 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
 
     // [EXPERIMENTAL] Token-level timestamps with DTW
     if (ctx->params.dtw_token_timestamps) {
-        if (!aheads_masks_init(ctx->params, ctx->model.hparams, state->aheads_masks, ctx->backend)) {
+        if (!aheads_masks_init(ctx->params, ctx->model.hparams, state->aheads_masks, state->backend)) {
             WHISPER_LOG_ERROR("%s: aheads_masks_init() failed for alignment heads masks\n", __func__);
             whisper_free_state(state);
             return nullptr;
@@ -3416,7 +3393,7 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
 
     // conv allocator
     {
-        bool ok = whisper_allocr_graph_init(state->alloc_conv, ctx->backend,
+        bool ok = whisper_allocr_graph_init(state->alloc_conv, state->backend,
                 [&]() {
                     return whisper_build_graph_conv(*ctx, *state, 0);
                 });
@@ -3432,7 +3409,7 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
 
     // encoder allocator
     if (!whisper_encode_external(*state)) {
-        bool ok = whisper_allocr_graph_init(state->alloc_encode, ctx->backend,
+        bool ok = whisper_allocr_graph_init(state->alloc_encode, state->backend,
                 [&]() {
                     return whisper_build_graph_encoder(*ctx, *state);
                 });
@@ -3448,7 +3425,7 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
 
     // cross allocator
     {
-        bool ok = whisper_allocr_graph_init(state->alloc_cross, ctx->backend,
+        bool ok = whisper_allocr_graph_init(state->alloc_cross, state->backend,
                 [&]() {
                     return whisper_build_graph_cross(*ctx, *state);
                 });
@@ -3464,7 +3441,7 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
 
     // decoder allocator
     {
-        bool ok = whisper_allocr_graph_init(state->alloc_decode, ctx->backend,
+        bool ok = whisper_allocr_graph_init(state->alloc_decode, state->backend,
                 [&]() {
                     const auto & hparams = ctx->model.hparams;
 
@@ -3660,8 +3637,6 @@ struct whisper_context * whisper_init_with_params_no_state(struct whisper_model_
         return nullptr;
     }
 
-    ctx->mel_calc = whisper_mel_calc_create(ctx->backend, ctx->model.filters);
-
     loader->close(loader->context);
 
     return ctx;
@@ -3738,9 +3713,14 @@ struct whisper_context * whisper_init_no_state(struct whisper_model_loader * loa
 
 void whisper_free_state(struct whisper_state * state) {
     if (state) {
-        kv_cache_free(state->kv_self);
-        kv_cache_free(state->kv_cross);
-        kv_cache_free(state->kv_pad);
+        whisper_kv_cache_free(state->kv_self);
+        whisper_kv_cache_free(state->kv_cross);
+        whisper_kv_cache_free(state->kv_pad);
+
+        whisper_mel_free(state->mel);
+
+        delete state->mel_calc;
+        state->mel_calc = nullptr;
 
 #ifdef WHISPER_USE_COREML
     if (state->ctx_coreml != nullptr) {
@@ -3782,8 +3762,6 @@ void whisper_free(struct whisper_context * ctx) {
 
         ggml_backend_free(ctx->backend);
 
-        delete ctx->mel_calc;
-        ctx->mel_calc = nullptr;
         delete ctx;
     }
 }
@@ -3800,9 +3778,11 @@ void whisper_free_params(struct whisper_full_params * params) {
     }
 }
 
-int whisper_pcm_to_mel_with_state(struct whisper_context * ctx, struct whisper_state * state, const float * samples, int n_samples, int n_threads) {
+int whisper_pcm_to_mel_with_state(struct whisper_context * /*ctx*/, struct whisper_state * state, const float * samples, int n_samples, int n_threads) {
     const int64_t t_start_us = ggml_time_us();
-    state->mel = ctx->mel_calc->calculate({samples, n_samples}, n_threads);
+
+    state->mel = state->mel_calc->calculate({samples, n_samples}, n_threads);
+
     state->t_mel_us += ggml_time_us() - t_start_us;
 
     // Dump log_mel_spectrogram
@@ -3834,8 +3814,9 @@ int whisper_set_mel_with_state(
         return -1;
     }
 
-    state->mel.reset();
-    state->mel.init(ctx->backend, n_len, n_len, n_mel);
+    whisper_mel_free(state->mel);
+    whisper_mel_init(state->mel, ctx->backend, n_len, n_len, n_mel);
+
     ggml_backend_tensor_set(state->mel.tensor, data, 0, ggml_nbytes(state->mel.tensor));
 
     return 0;
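For readers unfamiliar with ggml's split between tensor metadata and tensor storage, the allocation pattern inside whisper_mel_init() above can be restated as a standalone, commented sketch. It uses the same ggml calls as the diff; the wrapper function name is invented for illustration:

    #include "ggml.h"
    #include "ggml-alloc.h"
    #include "ggml-backend.h"

    // A context created with no_alloc = true holds only tensor descriptors
    // (ggml_tensor_overhead() bytes); the data lives in a backend buffer.
    static ggml_tensor * alloc_backend_tensor_2d(ggml_context ** out_ctx,
                                                 ggml_backend_buffer_t * out_buf,
                                                 ggml_backend_t backend, int n_len, int n_mel) {
        ggml_context * ctx = ggml_init({ggml_tensor_overhead(), nullptr, /*no_alloc=*/true});
        ggml_tensor * t = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_len, n_mel); // no data yet

        // over-allocate by one alignment unit so the tensor can start at an aligned offset
        ggml_backend_buffer_t buf = ggml_backend_alloc_buffer(
            backend, ggml_nbytes(t) + ggml_backend_get_alignment(backend));

        ggml_tallocr alloc = ggml_tallocr_new(buf); // linear allocator over buf
        ggml_tallocr_alloc(&alloc, t);              // binds t->data into buf

        *out_ctx = ctx; // caller releases with ggml_free()
        *out_buf = buf; // caller releases with ggml_backend_buffer_free()
        return t;
    }

whisper_mel_free() releases exactly these two resources, which is why it can be a plain function rather than a destructor.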
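A practical consequence of moving mel_calc and the mel buffers into whisper_state: states created from the same context no longer share spectrogram storage or a calculator, so mel conversion is independent per state. A hypothetical caller using the public whisper.h API (model path and PCM buffers are placeholders; error checking omitted):

    #include "whisper.h"

    void two_states_example(const float * pcm_a, int n_a, const float * pcm_b, int n_b) {
        whisper_context * wctx = whisper_init_from_file_with_params(
            "models/ggml-base.en.bin", whisper_context_default_params());

        // each state now creates its own mel_calc against the state's backend
        whisper_state * sa = whisper_init_state(wctx);
        whisper_state * sb = whisper_init_state(wctx);

        whisper_pcm_to_mel_with_state(wctx, sa, pcm_a, n_a, /*n_threads=*/4);
        whisper_pcm_to_mel_with_state(wctx, sb, pcm_b, n_b, /*n_threads=*/4);

        whisper_free_state(sa); // also frees the state's kv caches, mel and mel_calc
        whisper_free_state(sb);
        whisper_free(wctx);
    }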