Spaces:
Running
Running
metal : add abort callback (ggml/905)
Browse files- ggml/include/ggml-metal.h +2 -0
- ggml/src/ggml-metal.m +38 -3
ggml/include/ggml-metal.h
CHANGED
|
@@ -50,6 +50,8 @@ GGML_API GGML_CALL ggml_backend_buffer_t ggml_backend_metal_buffer_from_ptr(void
|
|
| 50 |
|
| 51 |
GGML_API void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb);
|
| 52 |
|
|
|
|
|
|
|
| 53 |
GGML_API GGML_CALL ggml_backend_buffer_type_t ggml_backend_metal_buffer_type(void);
|
| 54 |
|
| 55 |
// helper to check if the device supports a specific family
|
|
|
|
| 50 |
|
| 51 |
GGML_API void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb);
|
| 52 |
|
| 53 |
+
GGML_API void ggml_backend_metal_set_abort_callback(ggml_backend_t backend, ggml_abort_callback abort_callback, void * user_data);
|
| 54 |
+
|
| 55 |
GGML_API GGML_CALL ggml_backend_buffer_type_t ggml_backend_metal_buffer_type(void);
|
| 56 |
|
| 57 |
// helper to check if the device supports a specific family
|
ggml/src/ggml-metal.m
CHANGED
|
@@ -224,6 +224,10 @@ struct ggml_metal_context {
|
|
| 224 |
bool support_simdgroup_mm;
|
| 225 |
|
| 226 |
bool should_capture_next_compute;
|
|
|
|
|
|
|
|
|
|
|
|
|
| 227 |
};
|
| 228 |
|
| 229 |
// MSL code
|
|
@@ -878,8 +882,11 @@ static enum ggml_status ggml_metal_graph_compute(
|
|
| 878 |
id<MTLCommandBuffer> command_buffer = [ctx->queue commandBufferWithUnretainedReferences];
|
| 879 |
command_buffer_builder[cb_idx] = command_buffer;
|
| 880 |
|
| 881 |
-
// enqueue the command buffers
|
| 882 |
-
|
|
|
|
|
|
|
|
|
|
| 883 |
}
|
| 884 |
|
| 885 |
const id<MTLCommandBuffer> *command_buffers = command_buffer_builder;
|
|
@@ -2829,7 +2836,9 @@ static enum ggml_status ggml_metal_graph_compute(
|
|
| 2829 |
|
| 2830 |
[encoder endEncoding];
|
| 2831 |
|
| 2832 |
-
|
|
|
|
|
|
|
| 2833 |
});
|
| 2834 |
|
| 2835 |
// Wait for completion and check status of each command buffer
|
|
@@ -2849,6 +2858,23 @@ static enum ggml_status ggml_metal_graph_compute(
|
|
| 2849 |
|
| 2850 |
return GGML_STATUS_FAILED;
|
| 2851 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2852 |
}
|
| 2853 |
|
| 2854 |
if (should_capture) {
|
|
@@ -3244,6 +3270,15 @@ void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {
|
|
| 3244 |
ctx->n_cb = MIN(n_cb, GGML_METAL_MAX_BUFFERS);
|
| 3245 |
}
|
| 3246 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3247 |
bool ggml_backend_metal_supports_family(ggml_backend_t backend, int family) {
|
| 3248 |
GGML_ASSERT(ggml_backend_is_metal(backend));
|
| 3249 |
|
|
|
|
| 224 |
bool support_simdgroup_mm;
|
| 225 |
|
| 226 |
bool should_capture_next_compute;
|
| 227 |
+
|
| 228 |
+
// abort ggml_metal_graph_compute if callback returns true
|
| 229 |
+
ggml_abort_callback abort_callback;
|
| 230 |
+
void * abort_callback_data;
|
| 231 |
};
|
| 232 |
|
| 233 |
// MSL code
|
|
|
|
| 882 |
id<MTLCommandBuffer> command_buffer = [ctx->queue commandBufferWithUnretainedReferences];
|
| 883 |
command_buffer_builder[cb_idx] = command_buffer;
|
| 884 |
|
| 885 |
+
// always enqueue the first two command buffers
|
| 886 |
+
// enqueue all of the command buffers if we don't need to abort
|
| 887 |
+
if (cb_idx < 2 || ctx->abort_callback == NULL) {
|
| 888 |
+
[command_buffer enqueue];
|
| 889 |
+
}
|
| 890 |
}
|
| 891 |
|
| 892 |
const id<MTLCommandBuffer> *command_buffers = command_buffer_builder;
|
|
|
|
| 2836 |
|
| 2837 |
[encoder endEncoding];
|
| 2838 |
|
| 2839 |
+
if (cb_idx < 2 || ctx->abort_callback == NULL) {
|
| 2840 |
+
[command_buffer commit];
|
| 2841 |
+
}
|
| 2842 |
});
|
| 2843 |
|
| 2844 |
// Wait for completion and check status of each command buffer
|
|
|
|
| 2858 |
|
| 2859 |
return GGML_STATUS_FAILED;
|
| 2860 |
}
|
| 2861 |
+
|
| 2862 |
+
id<MTLCommandBuffer> next_buffer = (i + 1 < n_cb ? command_buffers[i + 1] : nil);
|
| 2863 |
+
if (!next_buffer) {
|
| 2864 |
+
continue;
|
| 2865 |
+
}
|
| 2866 |
+
|
| 2867 |
+
bool next_queued = ([next_buffer status] != MTLCommandBufferStatusNotEnqueued);
|
| 2868 |
+
if (next_queued) {
|
| 2869 |
+
continue;
|
| 2870 |
+
}
|
| 2871 |
+
|
| 2872 |
+
if (ctx->abort_callback && ctx->abort_callback(ctx->abort_callback_data)) {
|
| 2873 |
+
GGML_METAL_LOG_INFO("%s: command buffer %d aborted", __func__, i);
|
| 2874 |
+
return GGML_STATUS_ABORTED;
|
| 2875 |
+
}
|
| 2876 |
+
|
| 2877 |
+
[next_buffer commit];
|
| 2878 |
}
|
| 2879 |
|
| 2880 |
if (should_capture) {
|
|
|
|
| 3270 |
ctx->n_cb = MIN(n_cb, GGML_METAL_MAX_BUFFERS);
|
| 3271 |
}
|
| 3272 |
|
| 3273 |
+
void ggml_backend_metal_set_abort_callback(ggml_backend_t backend, ggml_abort_callback abort_callback, void * user_data) {
|
| 3274 |
+
GGML_ASSERT(ggml_backend_is_metal(backend));
|
| 3275 |
+
|
| 3276 |
+
struct ggml_metal_context * ctx = (struct ggml_metal_context *)backend->context;
|
| 3277 |
+
|
| 3278 |
+
ctx->abort_callback = abort_callback;
|
| 3279 |
+
ctx->abort_callback_data = user_data;
|
| 3280 |
+
}
|
| 3281 |
+
|
| 3282 |
bool ggml_backend_metal_supports_family(ggml_backend_t backend, int family) {
|
| 3283 |
GGML_ASSERT(ggml_backend_is_metal(backend));
|
| 3284 |
|