KitaitiMakoto committed on
Commit
8aaba9a
·
unverified ·
1 Parent(s): ae07b89

ruby : add more APIs (#2518)

Browse files

* Add test for built package existence

* Add more tests for Whisper::Params

* Add more Whisper::Params attributes

* Add tests for callbacks

* Add progress and abort callback features

* [skip ci] Add prompt usage in README

* Change prompt text in example

bindings/ruby/README.md CHANGED
@@ -31,6 +31,7 @@ params.duration = 60_000
31
  params.max_text_tokens = 300
32
  params.translate = true
33
  params.print_timestamps = false
 
34
 
35
  whisper.transcribe("path/to/audio.wav", params) do |whole_text|
36
  puts whole_text
 
31
  params.max_text_tokens = 300
32
  params.translate = true
33
  params.print_timestamps = false
34
+ params.initial_prompt = "Initial prompt here."
35
 
36
  whisper.transcribe("path/to/audio.wav", params) do |whole_text|
37
  puts whole_text
bindings/ruby/ext/ruby_whisper.cpp CHANGED
@@ -107,10 +107,16 @@ void rb_whisper_free(ruby_whisper *rw) {
107
  free(rw);
108
  }
109
 
 
 
 
 
 
 
110
  void rb_whisper_params_mark(ruby_whisper_params *rwp) {
111
- rb_gc_mark(rwp->new_segment_callback_container->user_data);
112
- rb_gc_mark(rwp->new_segment_callback_container->callback);
113
- rb_gc_mark(rwp->new_segment_callback_container->callbacks);
114
  }
115
 
116
  void rb_whisper_params_free(ruby_whisper_params *rwp) {
@@ -141,6 +147,8 @@ static VALUE ruby_whisper_params_allocate(VALUE klass) {
141
  rwp = ALLOC(ruby_whisper_params);
142
  rwp->params = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
143
  rwp->new_segment_callback_container = rb_whisper_callback_container_allocate();
 
 
144
  return Data_Wrap_Struct(klass, rb_whisper_params_mark, rb_whisper_params_free, rwp);
145
  }
146
 
@@ -316,6 +324,54 @@ static VALUE ruby_whisper_transcribe(int argc, VALUE *argv, VALUE self) {
316
  rwp->params.new_segment_callback_user_data = rwp->new_segment_callback_container;
317
  }
318
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
319
  if (whisper_full_parallel(rw->context, rwp->params, pcmf32.data(), pcmf32.size(), 1) != 0) {
320
  fprintf(stderr, "failed to process audio\n");
321
  return self;
@@ -631,6 +687,30 @@ static VALUE ruby_whisper_params_get_split_on_word(VALUE self) {
631
  static VALUE ruby_whisper_params_set_split_on_word(VALUE self, VALUE value) {
632
  BOOL_PARAMS_SETTER(self, split_on_word, value)
633
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
634
  /*
635
  * If true, enables diarization.
636
  *
@@ -725,6 +805,124 @@ static VALUE ruby_whisper_params_set_max_text_tokens(VALUE self, VALUE value) {
725
  rwp->params.n_max_text_ctx = NUM2INT(value);
726
  return value;
727
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
728
  /*
729
  * Sets new segment callback, called for every newly generated text segment.
730
  *
@@ -753,6 +951,62 @@ static VALUE ruby_whisper_params_set_new_segment_callback_user_data(VALUE self,
753
  rwp->new_segment_callback_container->user_data = value;
754
  return value;
755
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
756
 
757
  // High level API
758
 
@@ -835,6 +1089,46 @@ static VALUE ruby_whisper_params_on_new_segment(VALUE self) {
835
  return Qnil;
836
  }
837
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
838
  /*
839
  * Start time in milliseconds.
840
  *
@@ -946,6 +1240,8 @@ void Init_whisper() {
946
  rb_define_method(cParams, "token_timestamps=", ruby_whisper_params_set_token_timestamps, 1);
947
  rb_define_method(cParams, "split_on_word", ruby_whisper_params_get_split_on_word, 0);
948
  rb_define_method(cParams, "split_on_word=", ruby_whisper_params_set_split_on_word, 1);
 
 
949
  rb_define_method(cParams, "diarize", ruby_whisper_params_get_diarize, 0);
950
  rb_define_method(cParams, "diarize=", ruby_whisper_params_set_diarize, 1);
951
 
@@ -956,9 +1252,25 @@ void Init_whisper() {
956
 
957
  rb_define_method(cParams, "max_text_tokens", ruby_whisper_params_get_max_text_tokens, 0);
958
  rb_define_method(cParams, "max_text_tokens=", ruby_whisper_params_set_max_text_tokens, 1);
 
 
 
 
 
 
 
 
 
 
 
 
959
 
960
  rb_define_method(cParams, "new_segment_callback=", ruby_whisper_params_set_new_segment_callback, 1);
961
  rb_define_method(cParams, "new_segment_callback_user_data=", ruby_whisper_params_set_new_segment_callback_user_data, 1);
 
 
 
 
962
 
963
  // High level
964
  cSegment = rb_define_class_under(mWhisper, "Segment", rb_cObject);
@@ -966,6 +1278,8 @@ void Init_whisper() {
966
  rb_define_alloc_func(cSegment, ruby_whisper_segment_allocate);
967
  rb_define_method(cContext, "each_segment", ruby_whisper_each_segment, 0);
968
  rb_define_method(cParams, "on_new_segment", ruby_whisper_params_on_new_segment, 0);
 
 
969
  rb_define_method(cSegment, "start_time", ruby_whisper_segment_get_start_time, 0);
970
  rb_define_method(cSegment, "end_time", ruby_whisper_segment_get_end_time, 0);
971
  rb_define_method(cSegment, "speaker_next_turn?", ruby_whisper_segment_get_speaker_turn_next, 0);
 
107
  free(rw);
108
  }
109
 
110
+ void rb_whisper_callbcack_container_mark(ruby_whisper_callback_container *rwc) {
111
+ rb_gc_mark(rwc->user_data);
112
+ rb_gc_mark(rwc->callback);
113
+ rb_gc_mark(rwc->callbacks);
114
+ }
115
+
116
  void rb_whisper_params_mark(ruby_whisper_params *rwp) {
117
+ rb_whisper_callbcack_container_mark(rwp->new_segment_callback_container);
118
+ rb_whisper_callbcack_container_mark(rwp->progress_callback_container);
119
+ rb_whisper_callbcack_container_mark(rwp->abort_callback_container);
120
  }
121
 
122
  void rb_whisper_params_free(ruby_whisper_params *rwp) {
 
147
  rwp = ALLOC(ruby_whisper_params);
148
  rwp->params = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
149
  rwp->new_segment_callback_container = rb_whisper_callback_container_allocate();
150
+ rwp->progress_callback_container = rb_whisper_callback_container_allocate();
151
+ rwp->abort_callback_container = rb_whisper_callback_container_allocate();
152
  return Data_Wrap_Struct(klass, rb_whisper_params_mark, rb_whisper_params_free, rwp);
153
  }
154
 
 
324
  rwp->params.new_segment_callback_user_data = rwp->new_segment_callback_container;
325
  }
326
 
327
+ if (!NIL_P(rwp->progress_callback_container->callback) || 0 != RARRAY_LEN(rwp->progress_callback_container->callbacks)) {
328
+ rwp->params.progress_callback = [](struct whisper_context *ctx, struct whisper_state * /*state*/, int progress_cur, void *user_data) {
329
+ const ruby_whisper_callback_container *container = (ruby_whisper_callback_container *)user_data;
330
+ const VALUE progress = INT2NUM(progress_cur);
331
+ // The state argument is not supported yet (nil is passed in its place)
332
+ // because supporting it requires resolving GC-related problems.
333
+ if (!NIL_P(container->callback)) {
334
+ rb_funcall(container->callback, id_call, 4, *container->context, Qnil, progress, container->user_data);
335
+ }
336
+ const long callbacks_len = RARRAY_LEN(container->callbacks);
337
+ if (0 == callbacks_len) {
338
+ return;
339
+ }
340
+ for (int j = 0; j < callbacks_len; j++) {
341
+ VALUE cb = rb_ary_entry(container->callbacks, j);
342
+ rb_funcall(cb, id_call, 1, progress);
343
+ }
344
+ };
345
+ rwp->progress_callback_container->context = &self;
346
+ rwp->params.progress_callback_user_data = rwp->progress_callback_container;
347
+ }
348
+
349
+ if (!NIL_P(rwp->abort_callback_container->callback) || 0 != RARRAY_LEN(rwp->abort_callback_container->callbacks)) {
350
+ rwp->params.abort_callback = [](void * user_data) {
351
+ const ruby_whisper_callback_container *container = (ruby_whisper_callback_container *)user_data;
352
+ if (!NIL_P(container->callback)) {
353
+ VALUE result = rb_funcall(container->callback, id_call, 1, container->user_data);
354
+ if (!NIL_P(result) && Qfalse != result) {
355
+ return true;
356
+ }
357
+ }
358
+ const long callbacks_len = RARRAY_LEN(container->callbacks);
359
+ if (0 == callbacks_len) {
360
+ return false;
361
+ }
362
+ for (int j = 0; j < callbacks_len; j++) {
363
+ VALUE cb = rb_ary_entry(container->callbacks, j);
364
+ VALUE result = rb_funcall(cb, id_call, 1, container->user_data);
365
+ if (!NIL_P(result) && Qfalse != result) {
366
+ return true;
367
+ }
368
+ }
369
+ return false;
370
+ };
371
+ rwp->abort_callback_container->context = &self;
372
+ rwp->params.abort_callback_user_data = rwp->abort_callback_container;
373
+ }
374
+
375
  if (whisper_full_parallel(rw->context, rwp->params, pcmf32.data(), pcmf32.size(), 1) != 0) {
376
  fprintf(stderr, "failed to process audio\n");
377
  return self;
 
687
  static VALUE ruby_whisper_params_set_split_on_word(VALUE self, VALUE value) {
688
  BOOL_PARAMS_SETTER(self, split_on_word, value)
689
  }
690
+ /*
691
+ * Tokens to provide to the whisper decoder as initial prompt
692
+ * these are prepended to any existing text context from a previous call
693
+ * use whisper_tokenize() to convert text to tokens.
694
+ * Maximum of whisper_n_text_ctx()/2 tokens are used (typically 224).
695
+ *
696
+ * call-seq:
697
+ * initial_prompt -> String
698
+ */
699
+ static VALUE ruby_whisper_params_get_initial_prompt(VALUE self) {
700
+ ruby_whisper_params *rwp;
701
+ Data_Get_Struct(self, ruby_whisper_params, rwp);
702
+ return rwp->params.initial_prompt == nullptr ? Qnil : rb_str_new2(rwp->params.initial_prompt);
703
+ }
704
+ /*
705
+ * call-seq:
706
+ * initial_prompt = prompt -> prompt
707
+ */
708
+ static VALUE ruby_whisper_params_set_initial_prompt(VALUE self, VALUE value) {
709
+ ruby_whisper_params *rwp;
710
+ Data_Get_Struct(self, ruby_whisper_params, rwp);
711
+ rwp->params.initial_prompt = StringValueCStr(value);
712
+ return value;
713
+ }
714
  /*
715
  * If true, enables diarization.
716
  *
 
805
  rwp->params.n_max_text_ctx = NUM2INT(value);
806
  return value;
807
  }
808
+ /*
809
+ * call-seq:
810
+ * temperature -> Float
811
+ */
812
+ static VALUE ruby_whisper_params_get_temperature(VALUE self) {
813
+ ruby_whisper_params *rwp;
814
+ Data_Get_Struct(self, ruby_whisper_params, rwp);
815
+ return DBL2NUM(rwp->params.temperature);
816
+ }
817
+ /*
818
+ * call-seq:
819
+ * temperature = temp -> temp
820
+ */
821
+ static VALUE ruby_whisper_params_set_temperature(VALUE self, VALUE value) {
822
+ ruby_whisper_params *rwp;
823
+ Data_Get_Struct(self, ruby_whisper_params, rwp);
824
+ rwp->params.temperature = RFLOAT_VALUE(value);
825
+ return value;
826
+ }
827
+ /*
828
+ * See https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/decoding.py#L97
829
+ *
830
+ * call-seq:
831
+ * max_initial_ts -> Float
832
+ */
833
+ static VALUE ruby_whisper_params_get_max_initial_ts(VALUE self) {
834
+ ruby_whisper_params *rwp;
835
+ Data_Get_Struct(self, ruby_whisper_params, rwp);
836
+ return DBL2NUM(rwp->params.max_initial_ts);
837
+ }
838
+ /*
839
+ * call-seq:
840
+ * max_initial_ts = timestamp -> timestamp
841
+ */
842
+ static VALUE ruby_whisper_params_set_max_initial_ts(VALUE self, VALUE value) {
843
+ ruby_whisper_params *rwp;
844
+ Data_Get_Struct(self, ruby_whisper_params, rwp);
845
+ rwp->params.max_initial_ts = RFLOAT_VALUE(value);
846
+ return value;
847
+ }
848
+ /*
849
+ * call-seq:
850
+ * length_penalty -> Float
851
+ */
852
+ static VALUE ruby_whisper_params_get_length_penalty(VALUE self) {
853
+ ruby_whisper_params *rwp;
854
+ Data_Get_Struct(self, ruby_whisper_params, rwp);
855
+ return DBL2NUM(rwp->params.length_penalty);
856
+ }
857
+ /*
858
+ * call-seq:
859
+ * length_penalty = penalty -> penalty
860
+ */
861
+ static VALUE ruby_whisper_params_set_length_penalty(VALUE self, VALUE value) {
862
+ ruby_whisper_params *rwp;
863
+ Data_Get_Struct(self, ruby_whisper_params, rwp);
864
+ rwp->params.length_penalty = RFLOAT_VALUE(value);
865
+ return value;
866
+ }
867
+ /*
868
+ * call-seq:
869
+ * temperature_inc -> Float
870
+ */
871
+ static VALUE ruby_whisper_params_get_temperature_inc(VALUE self) {
872
+ ruby_whisper_params *rwp;
873
+ Data_Get_Struct(self, ruby_whisper_params, rwp);
874
+ return DBL2NUM(rwp->params.temperature_inc);
875
+ }
876
+ /*
877
+ * call-seq:
878
+ * temperature_inc = inc -> inc
879
+ */
880
+ static VALUE ruby_whisper_params_set_temperature_inc(VALUE self, VALUE value) {
881
+ ruby_whisper_params *rwp;
882
+ Data_Get_Struct(self, ruby_whisper_params, rwp);
883
+ rwp->params.temperature_inc = RFLOAT_VALUE(value);
884
+ return value;
885
+ }
886
+ /*
887
+ * Similar to OpenAI's "compression_ratio_threshold"
888
+ *
889
+ * call-seq:
890
+ * entropy_thold -> Float
891
+ */
892
+ static VALUE ruby_whisper_params_get_entropy_thold(VALUE self) {
893
+ ruby_whisper_params *rwp;
894
+ Data_Get_Struct(self, ruby_whisper_params, rwp);
895
+ return DBL2NUM(rwp->params.entropy_thold);
896
+ }
897
+ /*
898
+ * call-seq:
899
+ * entropy_thold = threshold -> threshold
900
+ */
901
+ static VALUE ruby_whisper_params_set_entropy_thold(VALUE self, VALUE value) {
902
+ ruby_whisper_params *rwp;
903
+ Data_Get_Struct(self, ruby_whisper_params, rwp);
904
+ rwp->params.entropy_thold = RFLOAT_VALUE(value);
905
+ return value;
906
+ }
907
+ /*
908
+ * call-seq:
909
+ * logprob_thold -> Float
910
+ */
911
+ static VALUE ruby_whisper_params_get_logprob_thold(VALUE self) {
912
+ ruby_whisper_params *rwp;
913
+ Data_Get_Struct(self, ruby_whisper_params, rwp);
914
+ return DBL2NUM(rwp->params.logprob_thold);
915
+ }
916
+ /*
917
+ * call-seq:
918
+ * logprob_thold = threshold -> threshold
919
+ */
920
+ static VALUE ruby_whisper_params_set_logprob_thold(VALUE self, VALUE value) {
921
+ ruby_whisper_params *rwp;
922
+ Data_Get_Struct(self, ruby_whisper_params, rwp);
923
+ rwp->params.logprob_thold = RFLOAT_VALUE(value);
924
+ return value;
925
+ }
926
  /*
927
  * Sets new segment callback, called for every newly generated text segment.
928
  *
 
951
  rwp->new_segment_callback_container->user_data = value;
952
  return value;
953
  }
954
+ /*
955
+ * Sets progress callback, called on each progress update.
956
+ *
957
+ * params.progress_callback = ->(context, _, progress, user_data) {
958
+ * # ...
959
+ * }
960
+ *
961
+ * call-seq:
962
+ * progress_callback = callback -> callback
963
+ */
964
+ static VALUE ruby_whisper_params_set_progress_callback(VALUE self, VALUE value) {
965
+ ruby_whisper_params *rwp;
966
+ Data_Get_Struct(self, ruby_whisper_params, rwp);
967
+ rwp->progress_callback_container->callback = value;
968
+ return value;
969
+ }
970
+ /*
971
+ * Sets user data passed to the last argument of progress callback.
972
+ *
973
+ * call-seq:
974
+ * progress_callback_user_data = user_data -> user_data
975
+ */
976
+ static VALUE ruby_whisper_params_set_progress_callback_user_data(VALUE self, VALUE value) {
977
+ ruby_whisper_params *rwp;
978
+ Data_Get_Struct(self, ruby_whisper_params, rwp);
979
+ rwp->progress_callback_container->user_data = value;
980
+ return value;
981
+ }
982
+ /*
983
+ * Sets abort callback, called to check if the process should be aborted.
984
+ *
985
+ * params.abort_callback = ->(user_data) {
986
+ * # ...
987
+ * }
988
+ *
989
+ * call-seq:
990
+ * abort_callback = callback -> callback
991
+ */
992
+ static VALUE ruby_whisper_params_set_abort_callback(VALUE self, VALUE value) {
993
+ ruby_whisper_params *rwp;
994
+ Data_Get_Struct(self, ruby_whisper_params, rwp);
995
+ rwp->abort_callback_container->callback = value;
996
+ return value;
997
+ }
998
+ /*
999
+ * Sets user data passed to the last argument of abort callback.
1000
+ *
1001
+ * call-seq:
1002
+ * abort_callback_user_data = user_data -> user_data
1003
+ */
1004
+ static VALUE ruby_whisper_params_set_abort_callback_user_data(VALUE self, VALUE value) {
1005
+ ruby_whisper_params *rwp;
1006
+ Data_Get_Struct(self, ruby_whisper_params, rwp);
1007
+ rwp->abort_callback_container->user_data = value;
1008
+ return value;
1009
+ }
1010
 
1011
  // High level API
1012
 
 
1089
  return Qnil;
1090
  }
1091
 
1092
+ /*
1093
+ * Hook called on progress update. Yields each progress Integer between 0 and 100.
1094
+ *
1095
+ * params.on_progress do |progress|
1096
+ * # ...
1097
+ * end
1098
+ *
1099
+ * call-seq:
1100
+ * on_progress {|progress| ... }
1101
+ */
1102
+ static VALUE ruby_whisper_params_on_progress(VALUE self) {
1103
+ ruby_whisper_params *rws;
1104
+ Data_Get_Struct(self, ruby_whisper_params, rws);
1105
+ const VALUE blk = rb_block_proc();
1106
+ rb_ary_push(rws->progress_callback_container->callbacks, blk);
1107
+ return Qnil;
1108
+ }
1109
+
1110
+ /*
1111
+ * Calls the block to determine whether to abort. Return +true+ from the block when you want to abort.
1112
+ *
1113
+ * params.abort_on do
1114
+ * if some_condition
1115
+ * true # abort
1116
+ * else
1117
+ * false # continue
1118
+ * end
1119
+ * end
1120
+ *
1121
+ * call-seq:
1122
+ * abort_on { ... }
1123
+ */
1124
+ static VALUE ruby_whisper_params_abort_on(VALUE self) {
1125
+ ruby_whisper_params *rws;
1126
+ Data_Get_Struct(self, ruby_whisper_params, rws);
1127
+ const VALUE blk = rb_block_proc();
1128
+ rb_ary_push(rws->abort_callback_container->callbacks, blk);
1129
+ return Qnil;
1130
+ }
1131
+
1132
  /*
1133
  * Start time in milliseconds.
1134
  *
 
1240
  rb_define_method(cParams, "token_timestamps=", ruby_whisper_params_set_token_timestamps, 1);
1241
  rb_define_method(cParams, "split_on_word", ruby_whisper_params_get_split_on_word, 0);
1242
  rb_define_method(cParams, "split_on_word=", ruby_whisper_params_set_split_on_word, 1);
1243
+ rb_define_method(cParams, "initial_prompt", ruby_whisper_params_get_initial_prompt, 0);
1244
+ rb_define_method(cParams, "initial_prompt=", ruby_whisper_params_set_initial_prompt, 1);
1245
  rb_define_method(cParams, "diarize", ruby_whisper_params_get_diarize, 0);
1246
  rb_define_method(cParams, "diarize=", ruby_whisper_params_set_diarize, 1);
1247
 
 
1252
 
1253
  rb_define_method(cParams, "max_text_tokens", ruby_whisper_params_get_max_text_tokens, 0);
1254
  rb_define_method(cParams, "max_text_tokens=", ruby_whisper_params_set_max_text_tokens, 1);
1255
+ rb_define_method(cParams, "temperature", ruby_whisper_params_get_temperature, 0);
1256
+ rb_define_method(cParams, "temperature=", ruby_whisper_params_set_temperature, 1);
1257
+ rb_define_method(cParams, "max_initial_ts", ruby_whisper_params_get_max_initial_ts, 0);
1258
+ rb_define_method(cParams, "max_initial_ts=", ruby_whisper_params_set_max_initial_ts, 1);
1259
+ rb_define_method(cParams, "length_penalty", ruby_whisper_params_get_length_penalty, 0);
1260
+ rb_define_method(cParams, "length_penalty=", ruby_whisper_params_set_length_penalty, 1);
1261
+ rb_define_method(cParams, "temperature_inc", ruby_whisper_params_get_temperature_inc, 0);
1262
+ rb_define_method(cParams, "temperature_inc=", ruby_whisper_params_set_temperature_inc, 1);
1263
+ rb_define_method(cParams, "entropy_thold", ruby_whisper_params_get_entropy_thold, 0);
1264
+ rb_define_method(cParams, "entropy_thold=", ruby_whisper_params_set_entropy_thold, 1);
1265
+ rb_define_method(cParams, "logprob_thold", ruby_whisper_params_get_logprob_thold, 0);
1266
+ rb_define_method(cParams, "logprob_thold=", ruby_whisper_params_set_logprob_thold, 1);
1267
 
1268
  rb_define_method(cParams, "new_segment_callback=", ruby_whisper_params_set_new_segment_callback, 1);
1269
  rb_define_method(cParams, "new_segment_callback_user_data=", ruby_whisper_params_set_new_segment_callback_user_data, 1);
1270
+ rb_define_method(cParams, "progress_callback=", ruby_whisper_params_set_progress_callback, 1);
1271
+ rb_define_method(cParams, "progress_callback_user_data=", ruby_whisper_params_set_progress_callback_user_data, 1);
1272
+ rb_define_method(cParams, "abort_callback=", ruby_whisper_params_set_abort_callback, 1);
1273
+ rb_define_method(cParams, "abort_callback_user_data=", ruby_whisper_params_set_abort_callback_user_data, 1);
1274
 
1275
  // High level
1276
  cSegment = rb_define_class_under(mWhisper, "Segment", rb_cObject);
 
1278
  rb_define_alloc_func(cSegment, ruby_whisper_segment_allocate);
1279
  rb_define_method(cContext, "each_segment", ruby_whisper_each_segment, 0);
1280
  rb_define_method(cParams, "on_new_segment", ruby_whisper_params_on_new_segment, 0);
1281
+ rb_define_method(cParams, "on_progress", ruby_whisper_params_on_progress, 0);
1282
+ rb_define_method(cParams, "abort_on", ruby_whisper_params_abort_on, 0);
1283
  rb_define_method(cSegment, "start_time", ruby_whisper_segment_get_start_time, 0);
1284
  rb_define_method(cSegment, "end_time", ruby_whisper_segment_get_end_time, 0);
1285
  rb_define_method(cSegment, "speaker_next_turn?", ruby_whisper_segment_get_speaker_turn_next, 0);
bindings/ruby/ext/ruby_whisper.h CHANGED
@@ -18,6 +18,8 @@ typedef struct {
18
  struct whisper_full_params params;
19
  bool diarize;
20
  ruby_whisper_callback_container *new_segment_callback_container;
 
 
21
  } ruby_whisper_params;
22
 
23
  #endif
 
18
  struct whisper_full_params params;
19
  bool diarize;
20
  ruby_whisper_callback_container *new_segment_callback_container;
21
+ ruby_whisper_callback_container *progress_callback_container;
22
+ ruby_whisper_callback_container *abort_callback_container;
23
  } ruby_whisper_params;
24
 
25
  #endif
bindings/ruby/tests/test_callback.rb CHANGED
@@ -5,6 +5,7 @@ class TestCallback < Test::Unit::TestCase
5
  TOPDIR = File.expand_path(File.join(File.dirname(__FILE__), '..'))
6
 
7
  def setup
 
8
  @params = Whisper::Params.new
9
  @whisper = Whisper::Context.new(File.join(TOPDIR, '..', '..', 'models', 'ggml-base.en.bin'))
10
  @audio = File.join(TOPDIR, '..', '..', 'samples', 'jfk.wav')
@@ -73,4 +74,90 @@ class TestCallback < Test::Unit::TestCase
73
 
74
  assert_same @whisper, @whisper.transcribe(@audio, @params)
75
  end
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
76
  end
 
5
  TOPDIR = File.expand_path(File.join(File.dirname(__FILE__), '..'))
6
 
7
  def setup
8
+ GC.start
9
  @params = Whisper::Params.new
10
  @whisper = Whisper::Context.new(File.join(TOPDIR, '..', '..', 'models', 'ggml-base.en.bin'))
11
  @audio = File.join(TOPDIR, '..', '..', 'samples', 'jfk.wav')
 
74
 
75
  assert_same @whisper, @whisper.transcribe(@audio, @params)
76
  end
77
+
78
+ def test_progress_callback
79
+ first = nil
80
+ last = nil
81
+ @params.progress_callback = ->(context, state, progress, user_data) {
82
+ assert_kind_of Integer, progress
83
+ assert 0 <= progress && progress <= 100
84
+ assert_same @whisper, context
85
+ first = progress if first.nil?
86
+ last = progress
87
+ }
88
+ @whisper.transcribe(@audio, @params)
89
+ assert_equal 0, first
90
+ assert_equal 100, last
91
+ end
92
+
93
+ def test_progress_callback_user_data
94
+ udata = Object.new
95
+ @params.progress_callback_user_data = udata
96
+ @params.progress_callback = ->(context, state, n_new, user_data) {
97
+ assert_same udata, user_data
98
+ }
99
+
100
+ @whisper.transcribe(@audio, @params)
101
+ end
102
+
103
+ def test_on_progress
104
+ first = nil
105
+ last = nil
106
+ @params.on_progress do |progress|
107
+ assert_kind_of Integer, progress
108
+ assert 0 <= progress && progress <= 100
109
+ first = progress if first.nil?
110
+ last = progress
111
+ end
112
+ @whisper.transcribe(@audio, @params)
113
+ assert_equal 0, first
114
+ assert_equal 100, last
115
+ end
116
+
117
+ def test_abort_callback
118
+ i = 0
119
+ @params.abort_callback = ->(user_data) {
120
+ assert_nil user_data
121
+ i += 1
122
+ return false
123
+ }
124
+ @whisper.transcribe(@audio, @params)
125
+ assert i > 0
126
+ end
127
+
128
+ def test_abort_callback_abort
129
+ i = 0
130
+ @params.abort_callback = ->(user_data) {
131
+ i += 1
132
+ return i == 3
133
+ }
134
+ @whisper.transcribe(@audio, @params)
135
+ assert_equal 3, i
136
+ end
137
+
138
+ def test_abort_callback_user_data
139
+ udata = Object.new
140
+ @params.abort_callback_user_data = udata
141
+ yielded = nil
142
+ @params.abort_callback = ->(user_data) {
143
+ yielded = user_data
144
+ }
145
+ @whisper.transcribe(@audio, @params)
146
+ assert_same udata, yielded
147
+ end
148
+
149
+ def test_abort_on
150
+ do_abort = false
151
+ aborted_from_callback = false
152
+ @params.on_new_segment do |segment|
153
+ do_abort = true if segment.text.match? /ask/
154
+ end
155
+ i = 0
156
+ @params.abort_on do
157
+ i += 1
158
+ do_abort
159
+ end
160
+ @whisper.transcribe(@audio, @params)
161
+ assert i > 0
162
+ end
163
  end
bindings/ruby/tests/test_package.rb CHANGED
@@ -8,6 +8,7 @@ class TestPackage < Test::Unit::TestCase
8
  Tempfile.create do |file|
9
  assert system("gem", "build", "whispercpp.gemspec", "--output", file.to_path.shellescape, exception: true)
10
  assert file.size > 0
 
11
  end
12
  end
13
 
 
8
  Tempfile.create do |file|
9
  assert system("gem", "build", "whispercpp.gemspec", "--output", file.to_path.shellescape, exception: true)
10
  assert file.size > 0
11
+ assert_path_exist file.to_path
12
  end
13
  end
14
 
bindings/ruby/tests/test_params.rb CHANGED
@@ -1,3 +1,4 @@
 
1
  require 'whisper'
2
 
3
  class TestParams < Test::Unit::TestCase
@@ -109,4 +110,46 @@ class TestParams < Test::Unit::TestCase
109
  @params.split_on_word = false
110
  assert [email protected]_on_word
111
  end
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
112
  end
 
1
+ require 'test/unit'
2
  require 'whisper'
3
 
4
  class TestParams < Test::Unit::TestCase
 
110
  @params.split_on_word = false
111
  assert [email protected]_on_word
112
  end
113
+
114
+ def test_initial_prompt
115
+ assert_nil @params.initial_prompt
116
+ @params.initial_prompt = "You are a polite person."
117
+ assert_equal "You are a polite person.", @params.initial_prompt
118
+ end
119
+
120
+ def test_temperature
121
+ assert_equal 0.0, @params.temperature
122
+ @params.temperature = 0.5
123
+ assert_equal 0.5, @params.temperature
124
+ end
125
+
126
+ def test_max_initial_ts
127
+ assert_equal 1.0, @params.max_initial_ts
128
+ @params.max_initial_ts = 600.0
129
+ assert_equal 600.0, @params.max_initial_ts
130
+ end
131
+
132
+ def test_length_penalty
133
+ assert_equal -1.0, @params.length_penalty
134
+ @params.length_penalty = 0.5
135
+ assert_equal 0.5, @params.length_penalty
136
+ end
137
+
138
+ def test_temperature_inc
139
+ assert_in_delta 0.2, @params.temperature_inc
140
+ @params.temperature_inc = 0.5
141
+ assert_in_delta 0.5, @params.temperature_inc
142
+ end
143
+
144
+ def test_entropy_thold
145
+ assert_in_delta 2.4, @params.entropy_thold
146
+ @params.entropy_thold = 3.0
147
+ assert_in_delta 3.0, @params.entropy_thold
148
+ end
149
+
150
+ def test_logprob_thold
151
+ assert_in_delta -1.0, @params.logprob_thold
152
+ @params.logprob_thold = -0.5
153
+ assert_in_delta -0.5, @params.logprob_thold
154
+ end
155
  end