KitaitiMakoto commited on
Commit
855927b
·
unverified ·
1 Parent(s): 469f43c

ruby : add encoder begin callback related methods (#3076)

Browse files

* Lazy run TestBase.whisper

* Fix indentation

* Remove disused GGML_HIP_UMA from Ruby

* Add encoder_begin_callback

* Comment out existing abort mechanism

* Add test for encoder_begin_callback

* Add signatures for encoder_begin_callback related methods

* Update gem date

bindings/ruby/ext/options.rb CHANGED
@@ -114,7 +114,6 @@ class Options
114
  bool "GGML_HIP_GRAPHS"
115
  bool "GGML_HIP_NO_VMM"
116
  bool "GGML_HIP_ROCWMMA_FATTN"
117
- bool "GGML_HIP_UMA"
118
  ignored "GGML_INCLUDE_INSTALL_DIR"
119
  bool "GGML_KOMPUTE"
120
  bool "GGML_LASX"
 
114
  bool "GGML_HIP_GRAPHS"
115
  bool "GGML_HIP_NO_VMM"
116
  bool "GGML_HIP_ROCWMMA_FATTN"
 
117
  ignored "GGML_INCLUDE_INSTALL_DIR"
118
  bool "GGML_KOMPUTE"
119
  bool "GGML_LASX"
bindings/ruby/ext/ruby_whisper.h CHANGED
@@ -19,6 +19,7 @@ typedef struct {
19
  bool diarize;
20
  ruby_whisper_callback_container *new_segment_callback_container;
21
  ruby_whisper_callback_container *progress_callback_container;
 
22
  ruby_whisper_callback_container *abort_callback_container;
23
  } ruby_whisper_params;
24
 
 
19
  bool diarize;
20
  ruby_whisper_callback_container *new_segment_callback_container;
21
  ruby_whisper_callback_container *progress_callback_container;
22
+ ruby_whisper_callback_container *encoder_begin_callback_container;
23
  ruby_whisper_callback_container *abort_callback_container;
24
  } ruby_whisper_params;
25
 
bindings/ruby/ext/ruby_whisper_params.c CHANGED
@@ -26,7 +26,7 @@
26
  rb_define_method(cParams, #param_name, ruby_whisper_params_get_ ## param_name, 0); \
27
  rb_define_method(cParams, #param_name "=", ruby_whisper_params_set_ ## param_name, 1);
28
 
29
- #define RUBY_WHISPER_PARAMS_PARAM_NAMES_COUNT 30
30
 
31
  extern VALUE cParams;
32
 
@@ -63,6 +63,8 @@ static ID id_new_segment_callback;
63
  static ID id_new_segment_callback_user_data;
64
  static ID id_progress_callback;
65
  static ID id_progress_callback_user_data;
 
 
66
  static ID id_abort_callback;
67
  static ID id_abort_callback_user_data;
68
 
@@ -126,6 +128,33 @@ static void progress_callback(struct whisper_context *ctx, struct whisper_state
126
  }
127
  }
128
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
129
  static bool abort_callback(void * user_data) {
130
  const ruby_whisper_callback_container *container = (ruby_whisper_callback_container *)user_data;
131
  if (!NIL_P(container->callback)) {
@@ -161,6 +190,12 @@ void register_callbacks(ruby_whisper_params * rwp, VALUE * context) {
161
  rwp->params.progress_callback_user_data = rwp->progress_callback_container;
162
  }
163
 
 
 
 
 
 
 
164
  if (!NIL_P(rwp->abort_callback_container->callback) || 0 != RARRAY_LEN(rwp->abort_callback_container->callbacks)) {
165
  rwp->abort_callback_container->context = context;
166
  rwp->params.abort_callback = abort_callback;
@@ -173,6 +208,7 @@ rb_whisper_params_mark(ruby_whisper_params *rwp)
173
  {
174
  rb_whisper_callbcack_container_mark(rwp->new_segment_callback_container);
175
  rb_whisper_callbcack_container_mark(rwp->progress_callback_container);
 
176
  rb_whisper_callbcack_container_mark(rwp->abort_callback_container);
177
  }
178
 
@@ -198,6 +234,7 @@ ruby_whisper_params_allocate(VALUE klass)
198
  rwp->diarize = false;
199
  rwp->new_segment_callback_container = rb_whisper_callback_container_allocate();
200
  rwp->progress_callback_container = rb_whisper_callback_container_allocate();
 
201
  rwp->abort_callback_container = rb_whisper_callback_container_allocate();
202
  return Data_Wrap_Struct(klass, rb_whisper_params_mark, rb_whisper_params_free, rwp);
203
  }
@@ -849,6 +886,57 @@ ruby_whisper_params_set_progress_callback_user_data(VALUE self, VALUE value)
849
  rwp->progress_callback_container->user_data = value;
850
  return value;
851
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
852
  static VALUE
853
  ruby_whisper_params_get_abort_callback(VALUE self)
854
  {
@@ -958,6 +1046,8 @@ ruby_whisper_params_initialize(int argc, VALUE *argv, VALUE self)
958
  SET_PARAM_IF_SAME(new_segment_callback_user_data)
959
  SET_PARAM_IF_SAME(progress_callback)
960
  SET_PARAM_IF_SAME(progress_callback_user_data)
 
 
961
  SET_PARAM_IF_SAME(abort_callback)
962
  SET_PARAM_IF_SAME(abort_callback_user_data)
963
  }
@@ -1008,6 +1098,26 @@ ruby_whisper_params_on_progress(VALUE self)
1008
  return Qnil;
1009
  }
1010
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1011
  /*
1012
  * Call block to determine whether abort or not. Return +true+ when you want to abort.
1013
  *
@@ -1068,10 +1178,13 @@ init_ruby_whisper_params(VALUE *mWhisper)
1068
  DEFINE_PARAM(new_segment_callback_user_data, 25)
1069
  DEFINE_PARAM(progress_callback, 26)
1070
  DEFINE_PARAM(progress_callback_user_data, 27)
1071
- DEFINE_PARAM(abort_callback, 28)
1072
- DEFINE_PARAM(abort_callback_user_data, 29)
 
 
1073
 
1074
  rb_define_method(cParams, "on_new_segment", ruby_whisper_params_on_new_segment, 0);
1075
  rb_define_method(cParams, "on_progress", ruby_whisper_params_on_progress, 0);
 
1076
  rb_define_method(cParams, "abort_on", ruby_whisper_params_abort_on, 0);
1077
  }
 
26
  rb_define_method(cParams, #param_name, ruby_whisper_params_get_ ## param_name, 0); \
27
  rb_define_method(cParams, #param_name "=", ruby_whisper_params_set_ ## param_name, 1);
28
 
29
+ #define RUBY_WHISPER_PARAMS_PARAM_NAMES_COUNT 32
30
 
31
  extern VALUE cParams;
32
 
 
63
  static ID id_new_segment_callback_user_data;
64
  static ID id_progress_callback;
65
  static ID id_progress_callback_user_data;
66
+ static ID id_encoder_begin_callback;
67
+ static ID id_encoder_begin_callback_user_data;
68
  static ID id_abort_callback;
69
  static ID id_abort_callback_user_data;
70
 
 
128
  }
129
  }
130
 
131
+ static bool encoder_begin_callback(struct whisper_context *ctx, struct whisper_state *state, void *user_data) {
132
+ const ruby_whisper_callback_container *container = (ruby_whisper_callback_container *)user_data;
133
+ bool is_aborted = false;
134
+ VALUE result;
135
+
136
+ // Currently, doesn't support state because
137
+ // those require to resolve GC-related problems.
138
+ if (!NIL_P(container->callback)) {
139
+ result = rb_funcall(container->callback, id_call, 3, *container->context, Qnil, container->user_data);
140
+ if (result == Qfalse) {
141
+ is_aborted = true;
142
+ }
143
+ }
144
+ const long callbacks_len = RARRAY_LEN(container->callbacks);
145
+ if (0 == callbacks_len) {
146
+ return !is_aborted;
147
+ }
148
+ for (int j = 0; j < callbacks_len; j++) {
149
+ VALUE cb = rb_ary_entry(container->callbacks, j);
150
+ result = rb_funcall(cb, id_call, 0);
151
+ if (result == Qfalse) {
152
+ is_aborted = true;
153
+ }
154
+ }
155
+ return !is_aborted;
156
+ }
157
+
158
  static bool abort_callback(void * user_data) {
159
  const ruby_whisper_callback_container *container = (ruby_whisper_callback_container *)user_data;
160
  if (!NIL_P(container->callback)) {
 
190
  rwp->params.progress_callback_user_data = rwp->progress_callback_container;
191
  }
192
 
193
+ if (!NIL_P(rwp->encoder_begin_callback_container->callback) || 0 != RARRAY_LEN(rwp->encoder_begin_callback_container->callbacks)) {
194
+ rwp->encoder_begin_callback_container->context = context;
195
+ rwp->params.encoder_begin_callback = encoder_begin_callback;
196
+ rwp->params.encoder_begin_callback_user_data = rwp->encoder_begin_callback_container;
197
+ }
198
+
199
  if (!NIL_P(rwp->abort_callback_container->callback) || 0 != RARRAY_LEN(rwp->abort_callback_container->callbacks)) {
200
  rwp->abort_callback_container->context = context;
201
  rwp->params.abort_callback = abort_callback;
 
208
  {
209
  rb_whisper_callbcack_container_mark(rwp->new_segment_callback_container);
210
  rb_whisper_callbcack_container_mark(rwp->progress_callback_container);
211
+ rb_whisper_callbcack_container_mark(rwp->encoder_begin_callback_container);
212
  rb_whisper_callbcack_container_mark(rwp->abort_callback_container);
213
  }
214
 
 
234
  rwp->diarize = false;
235
  rwp->new_segment_callback_container = rb_whisper_callback_container_allocate();
236
  rwp->progress_callback_container = rb_whisper_callback_container_allocate();
237
+ rwp->encoder_begin_callback_container = rb_whisper_callback_container_allocate();
238
  rwp->abort_callback_container = rb_whisper_callback_container_allocate();
239
  return Data_Wrap_Struct(klass, rb_whisper_params_mark, rb_whisper_params_free, rwp);
240
  }
 
886
  rwp->progress_callback_container->user_data = value;
887
  return value;
888
  }
889
+
890
+ static VALUE
891
+ ruby_whisper_params_get_encoder_begin_callback(VALUE self)
892
+ {
893
+ ruby_whisper_params *rwp;
894
+ Data_Get_Struct(self, ruby_whisper_params, rwp);
895
+ return rwp->encoder_begin_callback_container->callback;
896
+ }
897
+
898
+ /*
899
+ * Sets encoder begin callback, called when the encoder starts.
900
+ *
901
+ * params.encoder_begin_callback = ->(context, _, user_data) {
902
+ * # ...
903
+ * }
904
+ *
905
+ * call-seq:
906
+ * encoder_begin_callback = callback -> callback
907
+ */
908
+ static VALUE
909
+ ruby_whisper_params_set_encoder_begin_callback(VALUE self, VALUE value)
910
+ {
911
+ ruby_whisper_params *rwp;
912
+ Data_Get_Struct(self, ruby_whisper_params, rwp);
913
+ rwp->encoder_begin_callback_container->callback = value;
914
+ return value;
915
+ }
916
+
917
+ static VALUE
918
+ ruby_whisper_params_get_encoder_begin_callback_user_data(VALUE self)
919
+ {
920
+ ruby_whisper_params *rwp;
921
+ Data_Get_Struct(self, ruby_whisper_params, rwp);
922
+ return rwp->encoder_begin_callback_container->user_data;
923
+ }
924
+
925
+ /*
926
+ * Sets user data passed to the last argument of encoder begin callback.
927
+ *
928
+ * call-seq:
929
+ * encoder_begin_callback_user_data = user_data -> use_data
930
+ */
931
+ static VALUE
932
+ ruby_whisper_params_set_encoder_begin_callback_user_data(VALUE self, VALUE value)
933
+ {
934
+ ruby_whisper_params *rwp;
935
+ Data_Get_Struct(self, ruby_whisper_params, rwp);
936
+ rwp->encoder_begin_callback_container->user_data = value;
937
+ return value;
938
+ }
939
+
940
  static VALUE
941
  ruby_whisper_params_get_abort_callback(VALUE self)
942
  {
 
1046
  SET_PARAM_IF_SAME(new_segment_callback_user_data)
1047
  SET_PARAM_IF_SAME(progress_callback)
1048
  SET_PARAM_IF_SAME(progress_callback_user_data)
1049
+ SET_PARAM_IF_SAME(encoder_begin_callback)
1050
+ SET_PARAM_IF_SAME(encoder_begin_callback_user_data)
1051
  SET_PARAM_IF_SAME(abort_callback)
1052
  SET_PARAM_IF_SAME(abort_callback_user_data)
1053
  }
 
1098
  return Qnil;
1099
  }
1100
 
1101
+ /*
1102
+ * Hook called when the encoder starts.
1103
+ *
1104
+ * whisper.on_encoder_begin do
1105
+ * # ...
1106
+ * end
1107
+ *
1108
+ * call-seq:
1109
+ * on_encoder_begin { ... }
1110
+ */
1111
+ static VALUE
1112
+ ruby_whisper_params_on_encoder_begin(VALUE self)
1113
+ {
1114
+ ruby_whisper_params *rws;
1115
+ Data_Get_Struct(self, ruby_whisper_params, rws);
1116
+ const VALUE blk = rb_block_proc();
1117
+ rb_ary_push(rws->encoder_begin_callback_container->callbacks, blk);
1118
+ return Qnil;
1119
+ }
1120
+
1121
  /*
1122
  * Call block to determine whether abort or not. Return +true+ when you want to abort.
1123
  *
 
1178
  DEFINE_PARAM(new_segment_callback_user_data, 25)
1179
  DEFINE_PARAM(progress_callback, 26)
1180
  DEFINE_PARAM(progress_callback_user_data, 27)
1181
+ DEFINE_PARAM(encoder_begin_callback, 28)
1182
+ DEFINE_PARAM(encoder_begin_callback_user_data, 29)
1183
+ DEFINE_PARAM(abort_callback, 30)
1184
+ DEFINE_PARAM(abort_callback_user_data, 31)
1185
 
1186
  rb_define_method(cParams, "on_new_segment", ruby_whisper_params_on_new_segment, 0);
1187
  rb_define_method(cParams, "on_progress", ruby_whisper_params_on_progress, 0);
1188
+ rb_define_method(cParams, "on_encoder_begin", ruby_whisper_params_on_encoder_begin, 0);
1189
  rb_define_method(cParams, "abort_on", ruby_whisper_params_abort_on, 0);
1190
  }
bindings/ruby/ext/ruby_whisper_transcribe.cpp CHANGED
@@ -50,15 +50,16 @@ ruby_whisper_transcribe(int argc, VALUE *argv, VALUE self) {
50
  fprintf(stderr, "error: failed to open '%s' as WAV file\n", fname_inp.c_str());
51
  return self;
52
  }
53
- {
54
- static bool is_aborted = false; // NOTE: this should be atomic to avoid data race
 
55
 
56
- rwp->params.encoder_begin_callback = [](struct whisper_context * /*ctx*/, struct whisper_state * /*state*/, void * user_data) {
57
- bool is_aborted = *(bool*)user_data;
58
- return !is_aborted;
59
- };
60
- rwp->params.encoder_begin_callback_user_data = &is_aborted;
61
- }
62
 
63
  register_callbacks(rwp, &self);
64
 
 
50
  fprintf(stderr, "error: failed to open '%s' as WAV file\n", fname_inp.c_str());
51
  return self;
52
  }
53
+ // Commented out because it is work in progress
54
+ // {
55
+ // static bool is_aborted = false; // NOTE: this should be atomic to avoid data race
56
 
57
+ // rwp->params.encoder_begin_callback = [](struct whisper_context * /*ctx*/, struct whisper_state * /*state*/, void * user_data) {
58
+ // bool is_aborted = *(bool*)user_data;
59
+ // return !is_aborted;
60
+ // };
61
+ // rwp->params.encoder_begin_callback_user_data = &is_aborted;
62
+ // }
63
 
64
  register_callbacks(rwp, &self);
65
 
bindings/ruby/lib/whisper/model/uri.rb CHANGED
@@ -53,7 +53,7 @@ module Whisper
53
  http.request request do |response|
54
  case response
55
  when Net::HTTPNotModified
56
- # noop
57
  when Net::HTTPOK
58
  download response
59
  when Net::HTTPRedirection
@@ -68,7 +68,7 @@ module Whisper
68
  rescue => err
69
  if cache_path.exist?
70
  warn err
71
- # Use cache file
72
  else
73
  raise
74
  end
 
53
  http.request request do |response|
54
  case response
55
  when Net::HTTPNotModified
56
+ # noop
57
  when Net::HTTPOK
58
  download response
59
  when Net::HTTPRedirection
 
68
  rescue => err
69
  if cache_path.exist?
70
  warn err
71
+ # Use cache file
72
  else
73
  raise
74
  end
bindings/ruby/sig/whisper.rbs CHANGED
@@ -7,6 +7,7 @@ module Whisper
7
  type log_callback = ^(Integer level, String message, Object user_data) -> void
8
  type new_segment_callback = ^(Whisper::Context, void, Integer n_new, Object user_data) -> void
9
  type progress_callback = ^(Whisper::Context, void, Integer progress, Object user_data) -> void
 
10
  type abort_callback = ^(Whisper::Context, void, Object user_data) -> boolish
11
 
12
  LOG_LEVEL_NONE: Integer
@@ -146,6 +147,8 @@ module Whisper
146
  ?new_segment_callback_user_data: Object,
147
  ?progress_callback: progress_callback,
148
  ?progress_callback_user_data: Object,
 
 
149
  ?abort_callback: abort_callback,
150
  ?abort_callback_user_data: Object
151
  ) -> instance
@@ -306,6 +309,18 @@ module Whisper
306
 
307
  def progress_callback_user_data: () -> Object
308
 
 
 
 
 
 
 
 
 
 
 
 
 
309
  # Sets abort callback, called to check if the process should be aborted.
310
  #
311
  # params.abort_callback = ->(user_data) {
@@ -335,6 +350,10 @@ module Whisper
335
  #
336
  def on_progress: { (Integer progress) -> void } -> void
337
 
 
 
 
 
338
  # Call block to determine whether abort or not. Return +true+ when you want to abort.
339
  #
340
  # params.abort_on do
 
7
  type log_callback = ^(Integer level, String message, Object user_data) -> void
8
  type new_segment_callback = ^(Whisper::Context, void, Integer n_new, Object user_data) -> void
9
  type progress_callback = ^(Whisper::Context, void, Integer progress, Object user_data) -> void
10
+ type encoder_begin_callback = ^(Whisper::Context, void, Object user_data) -> void
11
  type abort_callback = ^(Whisper::Context, void, Object user_data) -> boolish
12
 
13
  LOG_LEVEL_NONE: Integer
 
147
  ?new_segment_callback_user_data: Object,
148
  ?progress_callback: progress_callback,
149
  ?progress_callback_user_data: Object,
150
+ ?encoder_begin_callback: encoder_begin_callback,
151
+ ?encoder_begin_callback_user_data: Object,
152
  ?abort_callback: abort_callback,
153
  ?abort_callback_user_data: Object
154
  ) -> instance
 
309
 
310
  def progress_callback_user_data: () -> Object
311
 
312
+ # Sets encoder begin callback, called when the encoder starts.
313
+ #
314
+ def encoder_begin_callback=: (encoder_begin_callback) -> encoder_begin_callback
315
+
316
+ def encoder_begin_callback: () -> (encoder_begin_callback | nil)
317
+
318
+ # Sets user data passed to the last argument of encoder begin callback.
319
+ #
320
+ def encoder_begin_callback_user_data=: (Object) -> Object
321
+
322
+ def encoder_begin_callback_user_data: () -> Object
323
+
324
  # Sets abort callback, called to check if the process should be aborted.
325
  #
326
  # params.abort_callback = ->(user_data) {
 
350
  #
351
  def on_progress: { (Integer progress) -> void } -> void
352
 
353
+ # Hook called on encoder starts.
354
+ #
355
+ def on_encoder_begin: { () -> void } -> void
356
+
357
  # Call block to determine whether abort or not. Return +true+ when you want to abort.
358
  #
359
  # params.abort_on do
bindings/ruby/tests/helper.rb CHANGED
@@ -6,9 +6,9 @@ class TestBase < Test::Unit::TestCase
6
  AUDIO = File.join(__dir__, "..", "..", "..", "samples", "jfk.wav")
7
 
8
  class << self
9
- attr_reader :whisper
 
10
 
11
- def startup
12
  @whisper = Whisper::Context.new("base.en")
13
  params = Whisper::Params.new
14
  params.print_timestamps = false
 
6
  AUDIO = File.join(__dir__, "..", "..", "..", "samples", "jfk.wav")
7
 
8
  class << self
9
+ def whisper
10
+ return @whisper if @whisper
11
 
 
12
  @whisper = Whisper::Context.new("base.en")
13
  params = Whisper::Params.new
14
  params.print_timestamps = false
bindings/ruby/tests/test_callback.rb CHANGED
@@ -111,6 +111,48 @@ class TestCallback < TestBase
111
  assert_equal 100, last
112
  end
113
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
114
  def test_abort_callback
115
  i = 0
116
  @params.abort_callback = ->(user_data) {
 
111
  assert_equal 100, last
112
  end
113
 
114
+ def test_encoder_begin_callback
115
+ i = 0
116
+ @params.encoder_begin_callback = ->(context, state, user_data) {
117
+ i += 1
118
+ }
119
+ @whisper.transcribe(@audio, @params)
120
+ assert i > 0
121
+ end
122
+
123
+ def test_encoder_begin_callback_abort
124
+ logs = []
125
+ Whisper.log_set -> (level, buffer, user_data) {
126
+ logs << buffer if level == Whisper::LOG_LEVEL_ERROR
127
+ }, logs
128
+ @params.encoder_begin_callback = ->(context, state, user_data) {
129
+ return false
130
+ }
131
+ @whisper.transcribe(@audio, @params)
132
+ assert_match(/encoder_begin_callback returned false - aborting/, logs.join)
133
+ Whisper.log_set ->(level, buffer, user_data) {}, nil
134
+ end
135
+
136
+ def test_encoder_begin_callback_user_data
137
+ udata = Object.new
138
+ @params.encoder_begin_callback_user_data = udata
139
+ yielded = nil
140
+ @params.encoder_begin_callback = ->(context, state, user_data) {
141
+ yielded = user_data
142
+ }
143
+ @whisper.transcribe(@audio, @params)
144
+ assert_same udata, yielded
145
+ end
146
+
147
+ def test_on_encoder_begin
148
+ i = 0
149
+ @params.on_encoder_begin do
150
+ i += 1
151
+ end
152
+ @whisper.transcribe(@audio, @params)
153
+ assert i > 0
154
+ end
155
+
156
  def test_abort_callback
157
  i = 0
158
  @params.abort_callback = ->(user_data) {
bindings/ruby/whispercpp.gemspec CHANGED
@@ -4,7 +4,7 @@ Gem::Specification.new do |s|
4
  s.name = "whispercpp"
5
  s.authors = ["Georgi Gerganov", "Todd A. Fisher"]
6
  s.version = '1.3.2'
7
- s.date = '2025-04-17'
8
  s.description = %q{High-performance inference of OpenAI's Whisper automatic speech recognition (ASR) model via Ruby}
9
  s.email = '[email protected]'
10
  s.extra_rdoc_files = ['LICENSE', 'README.md']
 
4
  s.name = "whispercpp"
5
  s.authors = ["Georgi Gerganov", "Todd A. Fisher"]
6
  s.version = '1.3.2'
7
+ s.date = '2025-04-25'
8
  s.description = %q{High-performance inference of OpenAI's Whisper automatic speech recognition (ASR) model via Ruby}
9
  s.email = '[email protected]'
10
  s.extra_rdoc_files = ['LICENSE', 'README.md']