ggerganov commited on
Commit
3b7b90c
·
unverified ·
1 Parent(s): 03fb680

whisper : switch back to F32 mask (#0)

Browse files
Files changed (1) hide show
  1. whisper.cpp +3 -5
whisper.cpp CHANGED
@@ -2294,8 +2294,6 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
2294
  ggml_set_name(KQ_mask, "KQ_mask");
2295
  ggml_set_input(KQ_mask);
2296
 
2297
- struct ggml_tensor * KQ_mask_f16 = ggml_cast(ctx0, KQ_mask, GGML_TYPE_F16);
2298
-
2299
  // token encoding + position encoding
2300
  struct ggml_tensor * cur =
2301
  ggml_add(ctx0,
@@ -2379,7 +2377,7 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
2379
  // K * Q
2380
  struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
2381
 
2382
- struct ggml_tensor * KQ_soft_max = ggml_soft_max_ext(ctx0, KQ, KQ_mask_f16, 1.0f, 0.0f);
2383
 
2384
  struct ggml_tensor * V =
2385
  ggml_view_3d(ctx0, kv_self.v,
@@ -2873,8 +2871,8 @@ static void log_mel_spectrogram_worker_thread(int ith, const std::vector<float>
2873
  int i = ith;
2874
 
2875
  // make sure n_fft == 1 + (WHISPER_N_FFT / 2), bin_0 to bin_nyquist
2876
- assert( n_fft == 1 + (frame_size / 2) );
2877
-
2878
  // calculate FFT only when fft_in are not all zero
2879
  for (; i < std::min(n_samples / frame_step + 1, mel.n_len); i += n_threads) {
2880
  const int offset = i * frame_step;
 
2294
  ggml_set_name(KQ_mask, "KQ_mask");
2295
  ggml_set_input(KQ_mask);
2296
 
 
 
2297
  // token encoding + position encoding
2298
  struct ggml_tensor * cur =
2299
  ggml_add(ctx0,
 
2377
  // K * Q
2378
  struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
2379
 
2380
+ struct ggml_tensor * KQ_soft_max = ggml_soft_max_ext(ctx0, KQ, KQ_mask, 1.0f, 0.0f);
2381
 
2382
  struct ggml_tensor * V =
2383
  ggml_view_3d(ctx0, kv_self.v,
 
2871
  int i = ith;
2872
 
2873
  // make sure n_fft == 1 + (WHISPER_N_FFT / 2), bin_0 to bin_nyquist
2874
+ assert(n_fft == 1 + (frame_size / 2));
2875
+
2876
  // calculate FFT only when fft_in are not all zero
2877
  for (; i < std::min(n_samples / frame_step + 1, mel.n_len); i += n_threads) {
2878
  const int offset = i * frame_step;