conradev committed on
Commit
b822172
·
1 Parent(s): 124c156

metal : add abort callback (ggml/905)

Browse files
Files changed (2) hide show
  1. ggml/include/ggml-metal.h +2 -0
  2. ggml/src/ggml-metal.m +38 -3
ggml/include/ggml-metal.h CHANGED
@@ -50,6 +50,8 @@ GGML_API GGML_CALL ggml_backend_buffer_t ggml_backend_metal_buffer_from_ptr(void
50
 
51
  GGML_API void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb);
52
 
 
 
53
  GGML_API GGML_CALL ggml_backend_buffer_type_t ggml_backend_metal_buffer_type(void);
54
 
55
  // helper to check if the device supports a specific family
 
50
 
51
  GGML_API void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb);
52
 
53
+ GGML_API void ggml_backend_metal_set_abort_callback(ggml_backend_t backend, ggml_abort_callback abort_callback, void * user_data);
54
+
55
  GGML_API GGML_CALL ggml_backend_buffer_type_t ggml_backend_metal_buffer_type(void);
56
 
57
  // helper to check if the device supports a specific family
ggml/src/ggml-metal.m CHANGED
@@ -224,6 +224,10 @@ struct ggml_metal_context {
224
  bool support_simdgroup_mm;
225
 
226
  bool should_capture_next_compute;
 
 
 
 
227
  };
228
 
229
  // MSL code
@@ -878,8 +882,11 @@ static enum ggml_status ggml_metal_graph_compute(
878
  id<MTLCommandBuffer> command_buffer = [ctx->queue commandBufferWithUnretainedReferences];
879
  command_buffer_builder[cb_idx] = command_buffer;
880
 
881
- // enqueue the command buffers in order to specify their execution order
882
- [command_buffer enqueue];
 
 
 
883
  }
884
 
885
  const id<MTLCommandBuffer> *command_buffers = command_buffer_builder;
@@ -2829,7 +2836,9 @@ static enum ggml_status ggml_metal_graph_compute(
2829
 
2830
  [encoder endEncoding];
2831
 
2832
- [command_buffer commit];
 
 
2833
  });
2834
 
2835
  // Wait for completion and check status of each command buffer
@@ -2849,6 +2858,23 @@ static enum ggml_status ggml_metal_graph_compute(
2849
 
2850
  return GGML_STATUS_FAILED;
2851
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2852
  }
2853
 
2854
  if (should_capture) {
@@ -3244,6 +3270,15 @@ void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {
3244
  ctx->n_cb = MIN(n_cb, GGML_METAL_MAX_BUFFERS);
3245
  }
3246
 
 
 
 
 
 
 
 
 
 
3247
  bool ggml_backend_metal_supports_family(ggml_backend_t backend, int family) {
3248
  GGML_ASSERT(ggml_backend_is_metal(backend));
3249
 
 
224
  bool support_simdgroup_mm;
225
 
226
  bool should_capture_next_compute;
227
+
228
+ // abort ggml_metal_graph_compute if callback returns true
229
+ ggml_abort_callback abort_callback;
230
+ void * abort_callback_data;
231
  };
232
 
233
  // MSL code
 
882
  id<MTLCommandBuffer> command_buffer = [ctx->queue commandBufferWithUnretainedReferences];
883
  command_buffer_builder[cb_idx] = command_buffer;
884
 
885
+ // always enqueue the first two command buffers
886
+ // enqueue all of the command buffers if we don't need to abort
887
+ if (cb_idx < 2 || ctx->abort_callback == NULL) {
888
+ [command_buffer enqueue];
889
+ }
890
  }
891
 
892
  const id<MTLCommandBuffer> *command_buffers = command_buffer_builder;
 
2836
 
2837
  [encoder endEncoding];
2838
 
2839
+ if (cb_idx < 2 || ctx->abort_callback == NULL) {
2840
+ [command_buffer commit];
2841
+ }
2842
  });
2843
 
2844
  // Wait for completion and check status of each command buffer
 
2858
 
2859
  return GGML_STATUS_FAILED;
2860
  }
2861
+
2862
+ id<MTLCommandBuffer> next_buffer = (i + 1 < n_cb ? command_buffers[i + 1] : nil);
2863
+ if (!next_buffer) {
2864
+ continue;
2865
+ }
2866
+
2867
+ bool next_queued = ([next_buffer status] != MTLCommandBufferStatusNotEnqueued);
2868
+ if (next_queued) {
2869
+ continue;
2870
+ }
2871
+
2872
+ if (ctx->abort_callback && ctx->abort_callback(ctx->abort_callback_data)) {
2873
+ GGML_METAL_LOG_INFO("%s: command buffer %d aborted", __func__, i);
2874
+ return GGML_STATUS_ABORTED;
2875
+ }
2876
+
2877
+ [next_buffer commit];
2878
  }
2879
 
2880
  if (should_capture) {
 
3270
  ctx->n_cb = MIN(n_cb, GGML_METAL_MAX_BUFFERS);
3271
  }
3272
 
3273
// Registers an abort callback on a Metal backend.
//   backend        - must be a Metal backend (asserted below)
//   abort_callback - stored on the backend's ggml_metal_context; ggml_metal_graph_compute
//                    calls it before committing a deferred command buffer and returns
//                    GGML_STATUS_ABORTED when it returns true. With NULL, graph compute
//                    enqueues and commits every command buffer up front.
//   user_data      - opaque pointer stored alongside the callback and handed to it on each call
+ void ggml_backend_metal_set_abort_callback(ggml_backend_t backend, ggml_abort_callback abort_callback, void * user_data) {
3274
+ GGML_ASSERT(ggml_backend_is_metal(backend)); // reject non-Metal backends before touching the context
3275
+
3276
+ struct ggml_metal_context * ctx = (struct ggml_metal_context *)backend->context;
3277
+
3278
+ ctx->abort_callback = abort_callback;
3279
+ ctx->abort_callback_data = user_data; // passed back as the callback's only argument
3280
+ }
3281
+
3282
  bool ggml_backend_metal_supports_family(ggml_backend_t backend, int family) {
3283
  GGML_ASSERT(ggml_backend_is_metal(backend));
3284