mky_coder Mike Fan commited on
Commit
cc603fa
·
unverified ·
1 Parent(s): e8e18fb

whisper : optimize fft() function (#2242)

Browse files

Co-authored-by: Mike Fan <[email protected]>

Files changed (1) hide show
  1. whisper.cpp +22 -34
whisper.cpp CHANGED
@@ -2974,10 +2974,7 @@ whisper_span<const float> whisper_mel_calc::hann_window() {
2974
  // naive Discrete Fourier Transform
2975
  // input is real-valued
2976
  // output is complex-valued
2977
- static void dft(const std::vector<float> & in, std::vector<float> & out) {
2978
- int N = in.size();
2979
-
2980
- out.resize(N*2);
2981
  const int sin_cos_step = SIN_COS_N_COUNT / N;
2982
 
2983
  for (int k = 0; k < N; k++) {
@@ -2999,44 +2996,35 @@ static void dft(const std::vector<float> & in, std::vector<float> & out) {
2999
  // poor man's implementation - use something better
3000
  // input is real-valued
3001
  // output is complex-valued
3002
- static void fft(const std::vector<float> & in, std::vector<float> & out) {
3003
- out.resize(in.size()*2);
3004
-
3005
- int N = in.size();
3006
-
3007
  if (N == 1) {
3008
  out[0] = in[0];
3009
  out[1] = 0;
3010
  return;
3011
  }
3012
 
3013
- if (N%2 == 1) {
3014
- dft(in, out);
 
3015
  return;
3016
  }
3017
 
3018
- std::vector<float> even;
3019
- std::vector<float> odd;
3020
-
3021
- even.reserve(N/2);
3022
- odd.reserve(N/2);
3023
-
3024
- for (int i = 0; i < N; i++) {
3025
- if (i % 2 == 0) {
3026
- even.push_back(in[i]);
3027
- } else {
3028
- odd.push_back(in[i]);
3029
- }
3030
  }
 
 
3031
 
3032
- std::vector<float> even_fft;
3033
- std::vector<float> odd_fft;
3034
-
3035
- fft(even, even_fft);
3036
- fft(odd, odd_fft);
 
3037
 
3038
  const int sin_cos_step = SIN_COS_N_COUNT / N;
3039
- for (int k = 0; k < N/2; k++) {
3040
  int idx = k * sin_cos_step; // t = 2*M_PI*k/N
3041
  float re = global_cache.cos_vals[idx]; // cos(t)
3042
  float im = -global_cache.sin_vals[idx]; // sin(t)
@@ -3047,8 +3035,8 @@ static void fft(const std::vector<float> & in, std::vector<float> & out) {
3047
  out[2*k + 0] = even_fft[2*k + 0] + re*re_odd - im*im_odd;
3048
  out[2*k + 1] = even_fft[2*k + 1] + re*im_odd + im*re_odd;
3049
 
3050
- out[2*(k + N/2) + 0] = even_fft[2*k + 0] - re*re_odd + im*im_odd;
3051
- out[2*(k + N/2) + 1] = even_fft[2*k + 1] - re*im_odd - im*re_odd;
3052
  }
3053
  }
3054
 
@@ -3066,8 +3054,8 @@ void log_mel_spectrogram_worker_thread(int ith, const float * hann, const std::v
3066
  const whisper_filters & filters, whisper_mel_data & mel) {
3067
  const auto frame_size = WHISPER_N_FFT;
3068
  const auto frame_step = WHISPER_HOP_LENGTH;
3069
- std::vector<float> fft_in(frame_size, 0.0);
3070
- std::vector<float> fft_out(2 * frame_size);
3071
  int n_fft = filters.n_fft;
3072
  int i = ith;
3073
 
@@ -3088,7 +3076,7 @@ void log_mel_spectrogram_worker_thread(int ith, const float * hann, const std::v
3088
  }
3089
 
3090
  // FFT
3091
- fft(fft_in, fft_out);
3092
 
3093
  // Calculate modulus^2 of complex numbers
3094
  // Use pow(fft_out[2 * j + 0], 2) + pow(fft_out[2 * j + 1], 2) causes inference quality problem? Interesting.
 
2974
  // naive Discrete Fourier Transform
2975
  // input is real-valued
2976
  // output is complex-valued
2977
+ static void dft(const float* in, int N, float* out) {
 
 
 
2978
  const int sin_cos_step = SIN_COS_N_COUNT / N;
2979
 
2980
  for (int k = 0; k < N; k++) {
 
2996
  // poor man's implementation - use something better
2997
  // input is real-valued
2998
  // output is complex-valued
2999
+ static void fft(float* in, int N, float* out) {
 
 
 
 
3000
  if (N == 1) {
3001
  out[0] = in[0];
3002
  out[1] = 0;
3003
  return;
3004
  }
3005
 
3006
+ const int half_N = N / 2;
3007
+ if (N - half_N*2 == 1) {
3008
+ dft(in, N, out);
3009
  return;
3010
  }
3011
 
3012
+ float* even = in + N;
3013
+ for (int i = 0; i < half_N; ++i) {
3014
+ even[i]= in[2*i];
 
 
 
 
 
 
 
 
 
3015
  }
3016
+ float* even_fft = out + 2 * N;
3017
+ fft(even, half_N, even_fft);
3018
 
3019
+ float* odd = even;
3020
+ for (int i = 0; i < half_N; ++i) {
3021
+ odd[i] = in[2*i + 1];
3022
+ }
3023
+ float* odd_fft = even_fft + N;
3024
+ fft(odd, half_N, odd_fft);
3025
 
3026
  const int sin_cos_step = SIN_COS_N_COUNT / N;
3027
+ for (int k = 0; k < half_N; k++) {
3028
  int idx = k * sin_cos_step; // t = 2*M_PI*k/N
3029
  float re = global_cache.cos_vals[idx]; // cos(t)
3030
  float im = -global_cache.sin_vals[idx]; // sin(t)
 
3035
  out[2*k + 0] = even_fft[2*k + 0] + re*re_odd - im*im_odd;
3036
  out[2*k + 1] = even_fft[2*k + 1] + re*im_odd + im*re_odd;
3037
 
3038
+ out[2*(k + half_N) + 0] = even_fft[2*k + 0] - re*re_odd + im*im_odd;
3039
+ out[2*(k + half_N) + 1] = even_fft[2*k + 1] - re*im_odd - im*re_odd;
3040
  }
3041
  }
3042
 
 
3054
  const whisper_filters & filters, whisper_mel_data & mel) {
3055
  const auto frame_size = WHISPER_N_FFT;
3056
  const auto frame_step = WHISPER_HOP_LENGTH;
3057
+ std::vector<float> fft_in(frame_size * 2, 0.0);
3058
+ std::vector<float> fft_out(frame_size * 2 * 2 * 2);
3059
  int n_fft = filters.n_fft;
3060
  int i = ith;
3061
 
 
3076
  }
3077
 
3078
  // FFT
3079
+ fft(fft_in.data(), frame_size, fft_out.data());
3080
 
3081
  // Calculate modulus^2 of complex numbers
3082
  // Use pow(fft_out[2 * j + 0], 2) + pow(fft_out[2 * j + 1], 2) causes inference quality problem? Interesting.