danbev committed on
Commit
9994342
·
unverified ·
1 Parent(s): 58220b6

node : add language detection support (#3190)

Browse files

This commit adds support for language detection in the Whisper Node.js
addon example. It also updates the node addon to return an object
instead of an array as the result.

The motivation for this change is to enable the inclusion of the
detected language in the result, in addition to the transcription
segments.

For example, when using the `detect_language` option, the result will
now be:
```console
{ language: 'en' }
```

And if the `language` option is set to "auto", the result will include the
detected language alongside the transcription:
```console
{
language: 'en',
transcription: [
[
'00:00:00.000',
'00:00:07.600',
' And so my fellow Americans, ask not what your country can do for you,'
],
[
'00:00:07.600',
'00:00:10.600',
' ask what you can do for your country.'
]
]
}
```

examples/addon.node/__test__/whisper.spec.js CHANGED
@@ -17,6 +17,7 @@ const whisperParamsMock = {
17
  comma_in_time: false,
18
  translate: true,
19
  no_timestamps: false,
 
20
  audio_ctx: 0,
21
  max_len: 0,
22
  prompt: "",
@@ -30,8 +31,9 @@ const whisperParamsMock = {
30
  describe("Run whisper.node", () => {
31
  test("it should receive a non-empty value", async () => {
32
  let result = await whisperAsync(whisperParamsMock);
 
33
 
34
- expect(result.length).toBeGreaterThan(0);
35
  }, 10000);
36
  });
37
 
 
17
  comma_in_time: false,
18
  translate: true,
19
  no_timestamps: false,
20
+ detect_language: false,
21
  audio_ctx: 0,
22
  max_len: 0,
23
  prompt: "",
 
31
  describe("Run whisper.node", () => {
32
  test("it should receive a non-empty value", async () => {
33
  let result = await whisperAsync(whisperParamsMock);
34
+ console.log(result);
35
 
36
+ expect(result['transcription'].length).toBeGreaterThan(0);
37
  }, 10000);
38
  });
39
 
examples/addon.node/addon.cpp CHANGED
@@ -38,6 +38,7 @@ struct whisper_params {
38
  bool print_progress = false;
39
  bool no_timestamps = false;
40
  bool no_prints = false;
 
41
  bool use_gpu = true;
42
  bool flash_attn = false;
43
  bool comma_in_time = true;
@@ -130,6 +131,11 @@ void whisper_print_segment_callback(struct whisper_context * ctx, struct whisper
130
 
131
  void cb_log_disable(enum ggml_log_level, const char *, void *) {}
132
 
 
 
 
 
 
133
  class ProgressWorker : public Napi::AsyncWorker {
134
  public:
135
  ProgressWorker(Napi::Function& callback, whisper_params params, Napi::Function progress_callback, Napi::Env env)
@@ -160,15 +166,27 @@ class ProgressWorker : public Napi::AsyncWorker {
160
 
161
  void OnOK() override {
162
  Napi::HandleScope scope(Env());
163
- Napi::Object res = Napi::Array::New(Env(), result.size());
164
- for (uint64_t i = 0; i < result.size(); ++i) {
 
 
 
 
 
 
 
 
 
 
 
165
  Napi::Object tmp = Napi::Array::New(Env(), 3);
166
  for (uint64_t j = 0; j < 3; ++j) {
167
- tmp[j] = Napi::String::New(Env(), result[i][j]);
168
  }
169
- res[i] = tmp;
170
- }
171
- Callback().Call({Env().Null(), res});
 
172
  }
173
 
174
  // Progress callback function - using thread-safe function
@@ -185,12 +203,12 @@ class ProgressWorker : public Napi::AsyncWorker {
185
 
186
  private:
187
  whisper_params params;
188
- std::vector<std::vector<std::string>> result;
189
  Napi::Env env;
190
  Napi::ThreadSafeFunction tsfn;
191
 
192
  // Custom run function with progress callback support
193
- int run_with_progress(whisper_params &params, std::vector<std::vector<std::string>> &result) {
194
  if (params.no_prints) {
195
  whisper_log_set(cb_log_disable, NULL);
196
  }
@@ -279,7 +297,8 @@ class ProgressWorker : public Napi::AsyncWorker {
279
  wparams.print_timestamps = !params.no_timestamps;
280
  wparams.print_special = params.print_special;
281
  wparams.translate = params.translate;
282
- wparams.language = params.language.c_str();
 
283
  wparams.n_threads = params.n_threads;
284
  wparams.n_max_text_ctx = params.max_context >= 0 ? params.max_context : wparams.n_max_text_ctx;
285
  wparams.offset_ms = params.offset_t_ms;
@@ -330,18 +349,22 @@ class ProgressWorker : public Napi::AsyncWorker {
330
  return 10;
331
  }
332
  }
333
- }
334
 
 
 
 
335
  const int n_segments = whisper_full_n_segments(ctx);
336
- result.resize(n_segments);
 
337
  for (int i = 0; i < n_segments; ++i) {
338
  const char * text = whisper_full_get_segment_text(ctx, i);
339
  const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
340
  const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
341
 
342
- result[i].emplace_back(to_timestamp(t0, params.comma_in_time));
343
- result[i].emplace_back(to_timestamp(t1, params.comma_in_time));
344
- result[i].emplace_back(text);
345
  }
346
 
347
  whisper_print_timings(ctx);
@@ -366,6 +389,7 @@ Napi::Value whisper(const Napi::CallbackInfo& info) {
366
  bool flash_attn = whisper_params.Get("flash_attn").As<Napi::Boolean>();
367
  bool no_prints = whisper_params.Get("no_prints").As<Napi::Boolean>();
368
  bool no_timestamps = whisper_params.Get("no_timestamps").As<Napi::Boolean>();
 
369
  int32_t audio_ctx = whisper_params.Get("audio_ctx").As<Napi::Number>();
370
  bool comma_in_time = whisper_params.Get("comma_in_time").As<Napi::Boolean>();
371
  int32_t max_len = whisper_params.Get("max_len").As<Napi::Number>();
@@ -418,6 +442,7 @@ Napi::Value whisper(const Napi::CallbackInfo& info) {
418
  params.max_context = max_context;
419
  params.print_progress = print_progress;
420
  params.prompt = prompt;
 
421
 
422
  Napi::Function callback = info[1].As<Napi::Function>();
423
  // Create a new Worker class with progress callback support
 
38
  bool print_progress = false;
39
  bool no_timestamps = false;
40
  bool no_prints = false;
41
+ bool detect_language= false;
42
  bool use_gpu = true;
43
  bool flash_attn = false;
44
  bool comma_in_time = true;
 
131
 
132
  void cb_log_disable(enum ggml_log_level, const char *, void *) {}
133
 
134
// Output of one run_with_progress() call; OnOK() later converts it into the JS result object.
+ struct whisper_result {
135
// One inner vector per segment, filled in this order: start timestamp, end timestamp, segment text.
+ std::vector<std::vector<std::string>> segments;
136
// Language string obtained from whisper_lang_str(); only assigned when detect_language is set
// or the requested language is "auto", otherwise it stays empty.
+ std::string language;
137
+ };
138
+
139
  class ProgressWorker : public Napi::AsyncWorker {
140
  public:
141
  ProgressWorker(Napi::Function& callback, whisper_params params, Napi::Function progress_callback, Napi::Env env)
 
166
 
167
  void OnOK() override {
168
  Napi::HandleScope scope(Env());
169
+
170
+ if (params.detect_language) {
171
+ Napi::Object resultObj = Napi::Object::New(Env());
172
+ resultObj.Set("language", Napi::String::New(Env(), result.language));
173
+ Callback().Call({Env().Null(), resultObj});
174
+ }
175
+
176
+ Napi::Object returnObj = Napi::Object::New(Env());
177
+ if (!result.language.empty()) {
178
+ returnObj.Set("language", Napi::String::New(Env(), result.language));
179
+ }
180
+ Napi::Array transcriptionArray = Napi::Array::New(Env(), result.segments.size());
181
+ for (uint64_t i = 0; i < result.segments.size(); ++i) {
182
  Napi::Object tmp = Napi::Array::New(Env(), 3);
183
  for (uint64_t j = 0; j < 3; ++j) {
184
+ tmp[j] = Napi::String::New(Env(), result.segments[i][j]);
185
  }
186
+ transcriptionArray[i] = tmp;
187
+ }
188
+ returnObj.Set("transcription", transcriptionArray);
189
+ Callback().Call({Env().Null(), returnObj});
190
  }
191
 
192
  // Progress callback function - using thread-safe function
 
203
 
204
  private:
205
  whisper_params params;
206
+ whisper_result result;
207
  Napi::Env env;
208
  Napi::ThreadSafeFunction tsfn;
209
 
210
  // Custom run function with progress callback support
211
+ int run_with_progress(whisper_params &params, whisper_result & result) {
212
  if (params.no_prints) {
213
  whisper_log_set(cb_log_disable, NULL);
214
  }
 
297
  wparams.print_timestamps = !params.no_timestamps;
298
  wparams.print_special = params.print_special;
299
  wparams.translate = params.translate;
300
+ wparams.language = params.detect_language ? "auto" : params.language.c_str();
301
+ wparams.detect_language = params.detect_language;
302
  wparams.n_threads = params.n_threads;
303
  wparams.n_max_text_ctx = params.max_context >= 0 ? params.max_context : wparams.n_max_text_ctx;
304
  wparams.offset_ms = params.offset_t_ms;
 
349
  return 10;
350
  }
351
  }
352
+ }
353
 
354
+ if (params.detect_language || params.language == "auto") {
355
+ result.language = whisper_lang_str(whisper_full_lang_id(ctx));
356
+ }
357
  const int n_segments = whisper_full_n_segments(ctx);
358
+ result.segments.resize(n_segments);
359
+
360
  for (int i = 0; i < n_segments; ++i) {
361
  const char * text = whisper_full_get_segment_text(ctx, i);
362
  const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
363
  const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
364
 
365
+ result.segments[i].emplace_back(to_timestamp(t0, params.comma_in_time));
366
+ result.segments[i].emplace_back(to_timestamp(t1, params.comma_in_time));
367
+ result.segments[i].emplace_back(text);
368
  }
369
 
370
  whisper_print_timings(ctx);
 
389
  bool flash_attn = whisper_params.Get("flash_attn").As<Napi::Boolean>();
390
  bool no_prints = whisper_params.Get("no_prints").As<Napi::Boolean>();
391
  bool no_timestamps = whisper_params.Get("no_timestamps").As<Napi::Boolean>();
392
+ bool detect_language = whisper_params.Get("detect_language").As<Napi::Boolean>();
393
  int32_t audio_ctx = whisper_params.Get("audio_ctx").As<Napi::Number>();
394
  bool comma_in_time = whisper_params.Get("comma_in_time").As<Napi::Boolean>();
395
  int32_t max_len = whisper_params.Get("max_len").As<Napi::Number>();
 
442
  params.max_context = max_context;
443
  params.print_progress = print_progress;
444
  params.prompt = prompt;
445
+ params.detect_language = detect_language;
446
 
447
  Napi::Function callback = info[1].As<Napi::Function>();
448
  // Create a new Worker class with progress callback support
examples/addon.node/index.js CHANGED
@@ -17,6 +17,7 @@ const whisperParams = {
17
  comma_in_time: false,
18
  translate: true,
19
  no_timestamps: false,
 
20
  audio_ctx: 0,
21
  max_len: 0,
22
  progress_callback: (progress) => {
@@ -31,6 +32,8 @@ const params = Object.fromEntries(
31
  const [key, value] = item.slice(2).split("=");
32
  if (key === "audio_ctx") {
33
  whisperParams[key] = parseInt(value);
 
 
34
  } else {
35
  whisperParams[key] = value;
36
  }
 
17
  comma_in_time: false,
18
  translate: true,
19
  no_timestamps: false,
20
+ detect_language: false,
21
  audio_ctx: 0,
22
  max_len: 0,
23
  progress_callback: (progress) => {
 
32
  const [key, value] = item.slice(2).split("=");
33
  if (key === "audio_ctx") {
34
  whisperParams[key] = parseInt(value);
35
+ } else if (key === "detect_language") {
36
+ whisperParams[key] = value === "true";
37
  } else {
38
  whisperParams[key] = value;
39
  }