Commit 521186a (unverified) · committed by stanimirovb · Parent: c6894d3

whisper : calculate mel spectrogram directly into a ggml_tensor (#2208)


* whisper : calculate mel spectrogram directly into a ggml_tensor

* whisper : remove unused temp buffer from state

* whisper : fix not initializing wstate.embd_enc

Files changed (3):
  1. whisper-mel-cuda.cu  +9 -12
  2. whisper-mel.hpp      +16 -4
  3. whisper.cpp          +119 -51
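
At a glance, the change replaces the host-side std::vector<float> mel buffer with a ggml tensor allocated on the compute backend. A condensed sketch of the resulting data flow, with hypothetical driver code around the real entry points whisper_mel_calc_create(), whisper_mel_calc::calculate(), and whisper_build_graph_conv():

    // Sketch only: `backend`, `filters`, `samples`, `n_samples`, `n_threads`
    // are assumed to exist in scope; whisper_span is assumed to
    // aggregate-initialize from a pointer and a length.
    whisper_mel_calc * calc = whisper_mel_calc_create(backend, filters);

    // calculate() now returns a whisper_mel whose values live in a ggml
    // tensor allocated on `backend`, not in a std::vector<float>
    whisper_mel mel = calc->calculate({samples, n_samples}, n_threads);

    // later, whisper_build_graph_conv() slices a 2*n_ctx frame window
    // directly out of mel.tensor with ggml_view_2d(), so the old
    // wstate.inp_mel staging buffer and its per-chunk upload are gone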
whisper-mel-cuda.cu CHANGED
@@ -8,6 +8,7 @@
 #include <cublas_v2.h>
 #include <cuComplex.h>
 #include <cub/device/device_reduce.cuh>
+#include <device_launch_parameters.h>
 
 #include <algorithm>
 
@@ -301,27 +302,23 @@ public:
            &fzero,
            mel_data, int(n_mag_frames)));
 
-        float * log_mels = nullptr;
-        CUDA_CHECK(cudaMallocAsync(&log_mels, m_n_mel * n_mag_frames * sizeof(float), m_stream));
+        whisper_mel ret;
+        // Calculate semi-padded sample length to ensure compatibility
+        int n_len_org = 1 + int(samples.len + mirror_pad - WHISPER_N_FFT) / WHISPER_HOP_LENGTH;
+        ret.init(m_backend, int(n_mag_frames), n_len_org, m_n_mel);
+        assert(ggml_nbytes(ret.tensor) == m_n_mel * n_mag_frames * sizeof(float));
+
+        float* log_mels = reinterpret_cast<float*>(ret.tensor->data);
 
         calc_log_mel(
             mel_data, int(m_n_mel * n_mag_frames),
-            m_log_mel_temp_storage, int(m_log_mel_temp_storage_size),
+            m_log_mel_temp_storage , int(m_log_mel_temp_storage_size),
             log_mels, m_stream);
 
-        whisper_mel ret;
-        ret.n_mel = m_n_mel;
-        ret.n_len = int(n_mag_frames);
-        // Calculate semi-padded sample length to ensure compatibility
-        ret.n_len_org = 1 + int(samples.len + mirror_pad - WHISPER_N_FFT) / WHISPER_HOP_LENGTH;
-        ret.data.resize(m_n_mel * n_mag_frames);
-        CUDA_CHECK(cudaMemcpyAsync(ret.data.data(), log_mels, ret.data.size() * sizeof(float), cudaMemcpyDeviceToHost, m_stream));
-
        CUDA_CHECK(cudaStreamSynchronize(m_stream));
 
        // cleanup
        CUFFT_CHECK(cufftDestroy(plan));
-       CUDA_CHECK(cudaFreeAsync(log_mels, m_stream));
        CUDA_CHECK(cudaFreeAsync(mel_data, m_stream));
        CUDA_CHECK(cudaFreeAsync(magnitudes, m_stream));
        CUDA_CHECK(cudaFreeAsync(stft_out, m_stream));
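
The net effect of the second hunk: the scratch device allocation and the device-to-host copy disappear. Because ret.init() allocates the tensor in a buffer of m_backend (the CUDA backend here), ret.tensor->data is already device memory and calc_log_mel() writes its result in place. A minimal sketch of the pattern; run_log_mel_kernel() is a hypothetical stand-in for calc_log_mel():

    // Sketch, assuming the whisper_mel from whisper-mel.hpp and a CUDA backend.
    whisper_mel ret;
    ret.init(m_backend, n_len, n_len_org, n_mel);                   // tensor lives in backend (device) memory
    float * log_mels = reinterpret_cast<float *>(ret.tensor->data); // raw device pointer

    run_log_mel_kernel(log_mels, n_len * n_mel, m_stream);          // hypothetical kernel writes in place

    CUDA_CHECK(cudaStreamSynchronize(m_stream));                    // results are final, nothing to copy back
    return ret;                                                     // ownership moves out with the return value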
whisper-mel.hpp CHANGED
@@ -3,11 +3,23 @@
 #include <vector>
 
 struct whisper_mel {
-    int n_len;
-    int n_len_org;
-    int n_mel;
+    int n_len_org = 0;
 
-    std::vector<float> data;
+    ggml_tensor * tensor = nullptr;
+    ggml_context * ctx = nullptr;
+    ggml_backend_buffer_t buffer = nullptr;
+
+    whisper_mel() = default;
+    ~whisper_mel();
+
+    whisper_mel(const whisper_mel &) = delete;
+    whisper_mel & operator=(const whisper_mel &) = delete;
+    whisper_mel(whisper_mel &&) noexcept;
+    whisper_mel & operator=(whisper_mel &&) noexcept;
+
+    void init(ggml_backend_t backend, int n_len, int n_len_org, int n_mel);
+    void reset();
+    void take(whisper_mel & other) noexcept;
 };
 
 struct whisper_filters {
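
whisper_mel is thus a move-only RAII owner of three handles: the tensor, its ggml context, and its backend buffer. A usage sketch, assuming a CPU backend and hypothetical dimensions:

    #include "whisper-mel.hpp"
    #include "ggml-backend.h"
    #include <utility>

    int main() {
        ggml_backend_t backend = ggml_backend_cpu_init();

        whisper_mel mel;
        mel.init(backend, /*n_len=*/3000, /*n_len_org=*/3000, /*n_mel=*/80); // F32 tensor [3000 x 80]

        whisper_mel other = std::move(mel);  // take() leaves mel.tensor/ctx/buffer as nullptr
        // whisper_mel copy = other;         // ill-formed: copy construction is deleted

        other.reset();                       // frees the ggml context and the backend buffer
        ggml_backend_free(backend);
        return 0;
    }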
whisper.cpp CHANGED
@@ -821,7 +821,6 @@ struct whisper_state {
     struct ggml_tensor * embd_enc = nullptr;
 
     // helpers for GPU offloading
-    std::vector<float> inp_mel;
     std::vector<float> inp_mask;
 
     // decode output (2-dimensional array: [n_tokens][n_vocab])
@@ -1815,7 +1814,8 @@ static bool whisper_encode_external(const whisper_state & wstate) {
 
 static struct ggml_cgraph * whisper_build_graph_conv(
         whisper_context & wctx,
-        whisper_state & wstate) {
+        whisper_state & wstate,
+        const int mel_offset) {
     const auto & model = wctx.model;
     const auto & hparams = model.hparams;
 
@@ -1834,9 +1834,32 @@ static struct ggml_cgraph * whisper_build_graph_conv(
 
     ggml_cgraph * gf = ggml_new_graph(ctx0);
 
-    struct ggml_tensor * mel = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, 2*n_ctx, n_mels);
-    ggml_set_name(mel, "mel");
-    ggml_set_input(mel);
+    ggml_tensor * mel_inp = wstate.mel.tensor;
+    ggml_tensor * mel;
+    if (mel_inp) {
+        const int n_len = int(mel_inp->ne[0]);
+        const int out_s = 2 * n_ctx;
+        const int i0 = std::min(mel_offset, n_len);
+        const int i1 = std::min(mel_offset + out_s, n_len);
+        const int mel_s = i1 - i0;
+
+        assert(mel_inp->type == GGML_TYPE_F32);
+        assert(mel_inp->ne[1] == n_mels);
+
+        ggml_tensor * cur = ggml_view_2d(ctx0, mel_inp, out_s, n_mels, mel_inp->nb[1], ggml_row_size(mel_inp->type, i0));
+
+        if (mel_s < out_s) {
+            mel = ggml_pad(ctx0, cur, out_s - mel_s, 0, 0, 0);
+        }
+        else {
+            mel = ggml_cont(ctx0, cur);
+        }
+    }
+    else {
+        // just create some tensor so that the graph/buffer size estimation is correct
+        mel = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, 2 * n_ctx, n_mels);
+    }
+    ggml_set_name(mel, "mel"); // used with external encoding
 
     struct ggml_tensor * cur = nullptr;
 
@@ -2218,45 +2241,21 @@ static bool whisper_encode_internal(
     {
         auto & alloc = wstate.alloc_conv.alloc;
 
-        ggml_cgraph * gf = whisper_build_graph_conv(wctx, wstate);
+        ggml_cgraph * gf = whisper_build_graph_conv(wctx, wstate, mel_offset);
 
         if (!ggml_gallocr_alloc_graph(alloc, gf)) {
             // should never happen as we pre-allocate the memory
             return false;
         }
 
-        struct ggml_tensor * mel = ggml_graph_get_tensor(gf, "mel");
-
-        // set the input
-        {
-            const auto & mel_inp = wstate.mel;
-            const int n_ctx = wstate.exp_n_audio_ctx > 0 ? wstate.exp_n_audio_ctx : wctx.model.hparams.n_audio_ctx;
-
-            assert(mel->type == GGML_TYPE_F32);
-            assert(mel_inp.n_mel == wctx.model.hparams.n_mels);
-
-            wstate.inp_mel.resize(ggml_nelements(mel));
-
-            float * dst = wstate.inp_mel.data();
-            memset(dst, 0, ggml_nbytes(mel));
-
-            const int i0 = std::min(mel_offset, mel_inp.n_len);
-            const int i1 = std::min(mel_offset + 2*n_ctx, mel_inp.n_len);
-
-            for (int j = 0; j < mel_inp.n_mel; ++j) {
-                for (int i = i0; i < i1; ++i) {
-                    dst[j*2*n_ctx + (i - i0)] = mel_inp.data[j*mel_inp.n_len + i];
-                }
-            }
-
-            ggml_backend_tensor_set(mel, wstate.inp_mel.data(), 0, ggml_nelements(mel)*sizeof(float));
+        if (!ggml_graph_compute_helper(wstate.backend, gf, n_threads)) {
+            return false;
         }
 
-        if (!whisper_encode_external(wstate)) {
-            if (!ggml_graph_compute_helper(wstate.backend, gf, n_threads)) {
-                return false;
-            }
-        } else {
+        if (whisper_encode_external(wstate)) {
+            ggml_tensor * mel = ggml_graph_get_tensor(gf, "mel");
+            assert(mel->ne[1] == wctx.model.hparams.n_mels);
+            GGML_UNUSED(mel);
 #if defined(WHISPER_USE_COREML)
             whisper_coreml_encode(wstate.ctx_coreml, mel->ne[0], mel->ne[1], (float *) mel->data, (float *) wstate.embd_enc->data);
 #elif defined(WHISPER_USE_OPENVINO)
@@ -2886,6 +2885,54 @@ struct whisper_global_cache {
 
 // Mel spectrogram
 
+whisper_mel::~whisper_mel() {
+    reset();
+}
+
+whisper_mel::whisper_mel(whisper_mel && other) noexcept {
+    take(other);
+}
+
+whisper_mel & whisper_mel::operator=(whisper_mel && other) noexcept {
+    if (this != &other) {
+        reset();
+        take(other);
+    }
+    return *this;
+}
+
+void whisper_mel::init(ggml_backend_t backend, int n_len, int n_len_org, int n_mel) {
+    this->n_len_org = n_len_org;
+    assert(!ctx);
+    ctx = ggml_init({ggml_tensor_overhead(), nullptr, true});
+    tensor = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_len, n_mel);
+    buffer = ggml_backend_alloc_buffer(backend, ggml_nbytes(tensor) + ggml_backend_get_alignment(backend));
+    auto alloc = ggml_tallocr_new(buffer);
+    ggml_tallocr_alloc(&alloc, tensor);
+}
+
+void whisper_mel::reset() {
+    ggml_free(ctx);
+    ggml_backend_buffer_free(buffer);
+
+    n_len_org = 0;
+    tensor = nullptr;
+    ctx = nullptr;
+    buffer = nullptr;
+}
+
+void whisper_mel::take(whisper_mel & other) noexcept {
+    n_len_org = other.n_len_org;
+    tensor = other.tensor;
+    ctx = other.ctx;
+    buffer = other.buffer;
+
+    other.n_len_org = 0;
+    other.tensor = nullptr;
+    other.ctx = nullptr;
+    other.buffer = nullptr;
+}
+
 whisper_mel_calc::~whisper_mel_calc() = default; // export vtable
 
 whisper_span<const float> whisper_mel_calc::hann_window() {
@@ -2973,9 +3020,18 @@ static void fft(const std::vector<float> & in, std::vector<float> & out) {
     }
 }
 
-static void log_mel_spectrogram_worker_thread(int ith, const float * hann, const std::vector<float> & samples,
+namespace {
+
+struct whisper_mel_data {
+    int n_len;
+    int n_len_org;
+    int n_mel;
+    float* data;
+};
+
+void log_mel_spectrogram_worker_thread(int ith, const float * hann, const std::vector<float> & samples,
                                       int n_samples, int n_threads,
-                                      const whisper_filters & filters, whisper_mel & mel) {
+                                      const whisper_filters & filters, whisper_mel_data & mel) {
     const auto frame_size = WHISPER_N_FFT;
     const auto frame_step = WHISPER_HOP_LENGTH;
     std::vector<float> fft_in(frame_size, 0.0);
@@ -3041,10 +3097,11 @@ static void log_mel_spectrogram_worker_thread(int ith, const float * hann, const
         }
     }
 }
-namespace {
+
 struct mel_calc_cpu : public whisper_mel_calc {
+    ggml_backend_t m_backend;
     const whisper_filters& m_filters;
-    mel_calc_cpu(const whisper_filters & filters) : m_filters(filters) {}
+    mel_calc_cpu(ggml_backend_t backend, const whisper_filters & filters) : m_backend(backend), m_filters(filters) {}
 
     // ref: https://github.com/openai/whisper/blob/main/whisper/audio.py#L110-L157
     whisper_mel calculate(whisper_span<const float> ssamples, int n_threads) const override {
@@ -3069,15 +3126,24 @@ struct mel_calc_cpu : public whisper_mel_calc {
         // reflective pad 200 samples at the beginning of audio
        std::reverse_copy(samples + 1, samples + 1 + stage_2_pad, samples_padded.begin());
 
-        whisper_mel mel;
+        whisper_mel_data mel;
         mel.n_mel = m_filters.n_mel;
         // https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/SpectralOps.cpp#L936
         // Calculate number of frames + remove the last frame
         mel.n_len = (samples_padded.size() - WHISPER_N_FFT) / WHISPER_HOP_LENGTH;
         // Calculate semi-padded sample length to ensure compatibility
         mel.n_len_org = 1 + (n_samples + stage_2_pad - WHISPER_N_FFT) / WHISPER_HOP_LENGTH;
-        mel.data.resize(mel.n_mel * mel.n_len);
+
+        std::vector<float> host_mel_data;
+
+        whisper_mel ret;
+        ret.init(m_backend, mel.n_len, mel.n_len_org, mel.n_mel);
+        if (ggml_backend_buffer_is_host(ret.buffer)) {
+            mel.data = reinterpret_cast<float*>(ret.tensor->data);
+        } else {
+            host_mel_data.resize(mel.n_len * mel.n_mel);
+            mel.data = host_mel_data.data();
+        }
 
         {
             std::vector<std::thread> workers(n_threads - 1);
@@ -3114,7 +3180,12 @@ struct mel_calc_cpu : public whisper_mel_calc {
             mel.data[i] = (mel.data[i] + 4.0)/4.0;
         }
 
-        return mel;
+        if (!host_mel_data.empty()) {
+            // the ret buffer is not host-accessible so we used this temporary buffer and now we need to upload it
+            ggml_backend_tensor_set(ret.tensor, host_mel_data.data(), 0, ggml_nbytes(ret.tensor));
+        }
+
+        return ret;
     }
 };
 }
@@ -3129,7 +3200,7 @@ whisper_mel_calc * whisper_mel_calc_create(ggml_backend_t backend, const whisper
         return ret;
     } else
 #endif
-    return new mel_calc_cpu(filters);
+    return new mel_calc_cpu(backend, filters);
 }
 
 // split text into tokens
@@ -3347,7 +3418,7 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
     {
         bool ok = whisper_allocr_graph_init(state->alloc_conv, ctx->backend,
             [&]() {
-                return whisper_build_graph_conv(*ctx, *state);
+                return whisper_build_graph_conv(*ctx, *state, 0);
             });
 
         if (!ok) {
@@ -3763,12 +3834,9 @@ int whisper_set_mel_with_state(
         return -1;
     }
 
-    state->mel.n_len = n_len;
-    state->mel.n_len_org = n_len;
-    state->mel.n_mel = n_mel;
-
-    state->mel.data.resize(n_len*n_mel);
-    memcpy(state->mel.data.data(), data, n_len*n_mel*sizeof(float));
+    state->mel.reset();
+    state->mel.init(ctx->backend, n_len, n_len, n_mel);
+    ggml_backend_tensor_set(state->mel.tensor, data, 0, ggml_nbytes(state->mel.tensor));
 
     return 0;
 }
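
Two details of the new whisper.cpp code deserve a closer look. First, the window arithmetic in whisper_build_graph_conv: the encoder always consumes out_s = 2*n_ctx frames, i0 and i1 clamp the requested window to the frames that actually exist, and any shortfall is filled with zeros via ggml_pad. A worked example with hypothetical numbers:

    #include <algorithm>

    int main() {
        // hypothetical values: a ~50 s recording near its end
        const int n_len      = 5000; // frames stored in wstate.mel.tensor (ne[0])
        const int n_ctx      = 1500; // model audio context
        const int mel_offset = 4000; // start of the chunk being encoded

        const int out_s = 2 * n_ctx;                           // 3000 frames expected by the conv
        const int i0    = std::min(mel_offset, n_len);         // 4000
        const int i1    = std::min(mel_offset + out_s, n_len); // 5000 (clamped to the end)
        const int mel_s = i1 - i0;                             // 1000 frames of real data
        // mel_s < out_s, so ggml_pad() appends out_s - mel_s = 2000 zero frames
        // to the view; otherwise ggml_cont() just materializes the view
        return 0;
    }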
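
Second, the CPU mel path now handles backends whose buffers are not host-accessible: the FFT workers fill a temporary host vector, which is uploaded to the tensor once at the end. The pattern from mel_calc_cpu::calculate(), reduced to its core (surrounding variables are assumed to be in scope):

    // Sketch of the host-fallback pattern.
    whisper_mel ret;
    ret.init(m_backend, n_len, n_len_org, n_mel);

    std::vector<float> host_mel_data;
    float * dst;
    if (ggml_backend_buffer_is_host(ret.buffer)) {
        dst = reinterpret_cast<float *>(ret.tensor->data); // write results in place
    } else {
        host_mel_data.resize(size_t(n_len) * n_mel);       // stage on the host first
        dst = host_mel_data.data();
    }

    // ... worker threads fill dst ...

    if (!host_mel_data.empty()) {
        // one bulk upload instead of a per-chunk copy on every encode
        ggml_backend_tensor_set(ret.tensor, host_mel_data.data(), 0, ggml_nbytes(ret.tensor));
    }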