Spaces:
Running
Running
mky_coder
Mike Fan
commited on
whisper : optimize fft() function (#2242)
Browse filesCo-authored-by: Mike Fan <[email protected]>
- whisper.cpp +22 -34
whisper.cpp
CHANGED
|
@@ -2974,10 +2974,7 @@ whisper_span<const float> whisper_mel_calc::hann_window() {
|
|
| 2974 |
// naive Discrete Fourier Transform
|
| 2975 |
// input is real-valued
|
| 2976 |
// output is complex-valued
|
| 2977 |
-
static void dft(const
|
| 2978 |
-
int N = in.size();
|
| 2979 |
-
|
| 2980 |
-
out.resize(N*2);
|
| 2981 |
const int sin_cos_step = SIN_COS_N_COUNT / N;
|
| 2982 |
|
| 2983 |
for (int k = 0; k < N; k++) {
|
|
@@ -2999,44 +2996,35 @@ static void dft(const std::vector<float> & in, std::vector<float> & out) {
|
|
| 2999 |
// poor man's implementation - use something better
|
| 3000 |
// input is real-valued
|
| 3001 |
// output is complex-valued
|
| 3002 |
-
static void fft(
|
| 3003 |
-
out.resize(in.size()*2);
|
| 3004 |
-
|
| 3005 |
-
int N = in.size();
|
| 3006 |
-
|
| 3007 |
if (N == 1) {
|
| 3008 |
out[0] = in[0];
|
| 3009 |
out[1] = 0;
|
| 3010 |
return;
|
| 3011 |
}
|
| 3012 |
|
| 3013 |
-
|
| 3014 |
-
|
|
|
|
| 3015 |
return;
|
| 3016 |
}
|
| 3017 |
|
| 3018 |
-
|
| 3019 |
-
|
| 3020 |
-
|
| 3021 |
-
even.reserve(N/2);
|
| 3022 |
-
odd.reserve(N/2);
|
| 3023 |
-
|
| 3024 |
-
for (int i = 0; i < N; i++) {
|
| 3025 |
-
if (i % 2 == 0) {
|
| 3026 |
-
even.push_back(in[i]);
|
| 3027 |
-
} else {
|
| 3028 |
-
odd.push_back(in[i]);
|
| 3029 |
-
}
|
| 3030 |
}
|
|
|
|
|
|
|
| 3031 |
|
| 3032 |
-
|
| 3033 |
-
|
| 3034 |
-
|
| 3035 |
-
|
| 3036 |
-
|
|
|
|
| 3037 |
|
| 3038 |
const int sin_cos_step = SIN_COS_N_COUNT / N;
|
| 3039 |
-
for (int k = 0; k <
|
| 3040 |
int idx = k * sin_cos_step; // t = 2*M_PI*k/N
|
| 3041 |
float re = global_cache.cos_vals[idx]; // cos(t)
|
| 3042 |
float im = -global_cache.sin_vals[idx]; // sin(t)
|
|
@@ -3047,8 +3035,8 @@ static void fft(const std::vector<float> & in, std::vector<float> & out) {
|
|
| 3047 |
out[2*k + 0] = even_fft[2*k + 0] + re*re_odd - im*im_odd;
|
| 3048 |
out[2*k + 1] = even_fft[2*k + 1] + re*im_odd + im*re_odd;
|
| 3049 |
|
| 3050 |
-
out[2*(k +
|
| 3051 |
-
out[2*(k +
|
| 3052 |
}
|
| 3053 |
}
|
| 3054 |
|
|
@@ -3066,8 +3054,8 @@ void log_mel_spectrogram_worker_thread(int ith, const float * hann, const std::v
|
|
| 3066 |
const whisper_filters & filters, whisper_mel_data & mel) {
|
| 3067 |
const auto frame_size = WHISPER_N_FFT;
|
| 3068 |
const auto frame_step = WHISPER_HOP_LENGTH;
|
| 3069 |
-
std::vector<float> fft_in(frame_size, 0.0);
|
| 3070 |
-
std::vector<float> fft_out(2 *
|
| 3071 |
int n_fft = filters.n_fft;
|
| 3072 |
int i = ith;
|
| 3073 |
|
|
@@ -3088,7 +3076,7 @@ void log_mel_spectrogram_worker_thread(int ith, const float * hann, const std::v
|
|
| 3088 |
}
|
| 3089 |
|
| 3090 |
// FFT
|
| 3091 |
-
fft(fft_in, fft_out);
|
| 3092 |
|
| 3093 |
// Calculate modulus^2 of complex numbers
|
| 3094 |
// Use pow(fft_out[2 * j + 0], 2) + pow(fft_out[2 * j + 1], 2) causes inference quality problem? Interesting.
|
|
|
|
| 2974 |
// naive Discrete Fourier Transform
|
| 2975 |
// input is real-valued
|
| 2976 |
// output is complex-valued
|
| 2977 |
+
static void dft(const float* in, int N, float* out) {
|
|
|
|
|
|
|
|
|
|
| 2978 |
const int sin_cos_step = SIN_COS_N_COUNT / N;
|
| 2979 |
|
| 2980 |
for (int k = 0; k < N; k++) {
|
|
|
|
| 2996 |
// poor man's implementation - use something better
|
| 2997 |
// input is real-valued
|
| 2998 |
// output is complex-valued
|
| 2999 |
+
static void fft(float* in, int N, float* out) {
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3000 |
if (N == 1) {
|
| 3001 |
out[0] = in[0];
|
| 3002 |
out[1] = 0;
|
| 3003 |
return;
|
| 3004 |
}
|
| 3005 |
|
| 3006 |
+
const int half_N = N / 2;
|
| 3007 |
+
if (N - half_N*2 == 1) {
|
| 3008 |
+
dft(in, N, out);
|
| 3009 |
return;
|
| 3010 |
}
|
| 3011 |
|
| 3012 |
+
float* even = in + N;
|
| 3013 |
+
for (int i = 0; i < half_N; ++i) {
|
| 3014 |
+
even[i]= in[2*i];
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3015 |
}
|
| 3016 |
+
float* even_fft = out + 2 * N;
|
| 3017 |
+
fft(even, half_N, even_fft);
|
| 3018 |
|
| 3019 |
+
float* odd = even;
|
| 3020 |
+
for (int i = 0; i < half_N; ++i) {
|
| 3021 |
+
odd[i] = in[2*i + 1];
|
| 3022 |
+
}
|
| 3023 |
+
float* odd_fft = even_fft + N;
|
| 3024 |
+
fft(odd, half_N, odd_fft);
|
| 3025 |
|
| 3026 |
const int sin_cos_step = SIN_COS_N_COUNT / N;
|
| 3027 |
+
for (int k = 0; k < half_N; k++) {
|
| 3028 |
int idx = k * sin_cos_step; // t = 2*M_PI*k/N
|
| 3029 |
float re = global_cache.cos_vals[idx]; // cos(t)
|
| 3030 |
float im = -global_cache.sin_vals[idx]; // sin(t)
|
|
|
|
| 3035 |
out[2*k + 0] = even_fft[2*k + 0] + re*re_odd - im*im_odd;
|
| 3036 |
out[2*k + 1] = even_fft[2*k + 1] + re*im_odd + im*re_odd;
|
| 3037 |
|
| 3038 |
+
out[2*(k + half_N) + 0] = even_fft[2*k + 0] - re*re_odd + im*im_odd;
|
| 3039 |
+
out[2*(k + half_N) + 1] = even_fft[2*k + 1] - re*im_odd - im*re_odd;
|
| 3040 |
}
|
| 3041 |
}
|
| 3042 |
|
|
|
|
| 3054 |
const whisper_filters & filters, whisper_mel_data & mel) {
|
| 3055 |
const auto frame_size = WHISPER_N_FFT;
|
| 3056 |
const auto frame_step = WHISPER_HOP_LENGTH;
|
| 3057 |
+
std::vector<float> fft_in(frame_size * 2, 0.0);
|
| 3058 |
+
std::vector<float> fft_out(frame_size * 2 * 2 * 2);
|
| 3059 |
int n_fft = filters.n_fft;
|
| 3060 |
int i = ith;
|
| 3061 |
|
|
|
|
| 3076 |
}
|
| 3077 |
|
| 3078 |
// FFT
|
| 3079 |
+
fft(fft_in.data(), frame_size, fft_out.data());
|
| 3080 |
|
| 3081 |
// Calculate modulus^2 of complex numbers
|
| 3082 |
// Use pow(fft_out[2 * j + 0], 2) + pow(fft_out[2 * j + 1], 2) causes inference quality problem? Interesting.
|