Spaces:
Running
Running
Gaurav Garg
commited on
Commit
·
1e69b8c
1
Parent(s):
40652de
cuda : enable CUDA Graph on CUDA Toolkit < 12.x (llama/12394)
Browse files* Enable CUDA Graph on CTK < 12.x
`cudaGraphExecUpdate` API was changed on 12.x. For this reason CUDA graph support was disabled on older CUDA toolkit. This change enables CUDA support in CTK version < 12.x by using older API if CTK < 12.x.
* Fix compilation errors with MUSA
* Disable CUDA Graph for MUSA
ggml/src/ggml-cuda/common.cuh
CHANGED
|
@@ -678,7 +678,7 @@ struct ggml_tensor_extra_gpu {
|
|
| 678 |
};
|
| 679 |
|
| 680 |
|
| 681 |
-
#if (
|
| 682 |
#define USE_CUDA_GRAPH
|
| 683 |
#endif
|
| 684 |
|
|
|
|
| 678 |
};
|
| 679 |
|
| 680 |
|
| 681 |
+
#if (defined(GGML_CUDA_USE_GRAPHS) || defined(GGML_HIP_GRAPHS))
|
| 682 |
#define USE_CUDA_GRAPH
|
| 683 |
#endif
|
| 684 |
|
ggml/src/ggml-cuda/ggml-cuda.cu
CHANGED
|
@@ -2610,13 +2610,15 @@ static bool is_cuda_graph_update_required(ggml_backend_cuda_context * cuda_ctx,
|
|
| 2610 |
|
| 2611 |
static void update_cuda_graph_executable(ggml_backend_cuda_context * cuda_ctx) {
|
| 2612 |
|
|
|
|
| 2613 |
cudaGraphExecUpdateResultInfo result_info;
|
| 2614 |
-
#ifdef __HIP_PLATFORM_AMD__
|
| 2615 |
-
hipGraphNode_t errorNode;
|
| 2616 |
-
hipError_t stat = hipGraphExecUpdate(cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, &errorNode, &result_info);
|
| 2617 |
-
#else
|
| 2618 |
cudaError_t stat = cudaGraphExecUpdate(cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, &result_info);
|
| 2619 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2620 |
if (stat == cudaErrorGraphExecUpdateFailure) {
|
| 2621 |
#ifndef NDEBUG
|
| 2622 |
GGML_LOG_DEBUG("%s: CUDA graph update failed\n", __func__);
|
|
|
|
| 2610 |
|
| 2611 |
static void update_cuda_graph_executable(ggml_backend_cuda_context * cuda_ctx) {
|
| 2612 |
|
| 2613 |
+
#if CUDART_VERSION >= 12000
|
| 2614 |
cudaGraphExecUpdateResultInfo result_info;
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2615 |
cudaError_t stat = cudaGraphExecUpdate(cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, &result_info);
|
| 2616 |
+
#else
|
| 2617 |
+
cudaGraphNode_t errorNode;
|
| 2618 |
+
cudaGraphExecUpdateResult result_info;
|
| 2619 |
+
cudaError_t stat = cudaGraphExecUpdate(cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, &errorNode, &result_info);
|
| 2620 |
+
#endif // CUDART_VERSION >= 12000
|
| 2621 |
+
|
| 2622 |
if (stat == cudaErrorGraphExecUpdateFailure) {
|
| 2623 |
#ifndef NDEBUG
|
| 2624 |
GGML_LOG_DEBUG("%s: CUDA graph update failed\n", __func__);
|
ggml/src/ggml-cuda/vendors/hip.h
CHANGED
|
@@ -112,7 +112,7 @@
|
|
| 112 |
#define cudaGraphExecDestroy hipGraphExecDestroy
|
| 113 |
#define cudaGraphLaunch hipGraphLaunch
|
| 114 |
#define cudaErrorGraphExecUpdateFailure hipErrorGraphExecUpdateFailure
|
| 115 |
-
#define
|
| 116 |
#define cudaGraphNodeType hipGraphNodeType
|
| 117 |
#define cudaGraphNodeTypeKernel hipGraphNodeTypeKernel
|
| 118 |
#define cudaGraphInstantiate hipGraphInstantiate
|
|
|
|
| 112 |
#define cudaGraphExecDestroy hipGraphExecDestroy
|
| 113 |
#define cudaGraphLaunch hipGraphLaunch
|
| 114 |
#define cudaErrorGraphExecUpdateFailure hipErrorGraphExecUpdateFailure
|
| 115 |
+
#define cudaGraphExecUpdateResult hipGraphExecUpdateResult
|
| 116 |
#define cudaGraphNodeType hipGraphNodeType
|
| 117 |
#define cudaGraphNodeTypeKernel hipGraphNodeTypeKernel
|
| 118 |
#define cudaGraphInstantiate hipGraphInstantiate
|
ggml/src/ggml-cuda/vendors/musa.h
CHANGED
|
@@ -119,7 +119,7 @@
|
|
| 119 |
#define cudaGraphExecDestroy musaGraphExecDestroy
|
| 120 |
#define cudaGraphExec_t musaGraphExec_t
|
| 121 |
#define cudaGraphExecUpdate musaGraphExecUpdate
|
| 122 |
-
#define
|
| 123 |
#define cudaGraphGetNodes musaGraphGetNodes
|
| 124 |
#define cudaGraphInstantiate musaGraphInstantiate
|
| 125 |
#define cudaGraphKernelNodeGetParams musaGraphKernelNodeGetParams
|
|
@@ -132,6 +132,7 @@
|
|
| 132 |
#define cudaGraph_t musaGraph_t
|
| 133 |
#define cudaKernelNodeParams musaKernelNodeParams
|
| 134 |
#define cudaStreamCaptureModeRelaxed musaStreamCaptureModeRelaxed
|
|
|
|
| 135 |
#define cudaStreamEndCapture musaStreamEndCapture
|
| 136 |
|
| 137 |
typedef mt_bfloat16 nv_bfloat16;
|
|
|
|
| 119 |
#define cudaGraphExecDestroy musaGraphExecDestroy
|
| 120 |
#define cudaGraphExec_t musaGraphExec_t
|
| 121 |
#define cudaGraphExecUpdate musaGraphExecUpdate
|
| 122 |
+
#define cudaGraphExecUpdateResult musaGraphExecUpdateResult
|
| 123 |
#define cudaGraphGetNodes musaGraphGetNodes
|
| 124 |
#define cudaGraphInstantiate musaGraphInstantiate
|
| 125 |
#define cudaGraphKernelNodeGetParams musaGraphKernelNodeGetParams
|
|
|
|
| 132 |
#define cudaGraph_t musaGraph_t
|
| 133 |
#define cudaKernelNodeParams musaKernelNodeParams
|
| 134 |
#define cudaStreamCaptureModeRelaxed musaStreamCaptureModeRelaxed
|
| 135 |
+
#define cudaStreamBeginCapture musaStreamBeginCapture
|
| 136 |
#define cudaStreamEndCapture musaStreamEndCapture
|
| 137 |
|
| 138 |
typedef mt_bfloat16 nv_bfloat16;
|
ggml/src/ggml-musa/CMakeLists.txt
CHANGED
|
@@ -67,10 +67,6 @@ if (MUSAToolkit_FOUND)
|
|
| 67 |
add_compile_definitions(GGML_USE_MUSA)
|
| 68 |
add_compile_definitions(GGML_CUDA_PEER_MAX_BATCH_SIZE=${GGML_CUDA_PEER_MAX_BATCH_SIZE})
|
| 69 |
|
| 70 |
-
if (GGML_CUDA_GRAPHS)
|
| 71 |
-
add_compile_definitions(GGML_CUDA_USE_GRAPHS)
|
| 72 |
-
endif()
|
| 73 |
-
|
| 74 |
if (GGML_CUDA_FORCE_MMQ)
|
| 75 |
add_compile_definitions(GGML_CUDA_FORCE_MMQ)
|
| 76 |
endif()
|
|
|
|
| 67 |
add_compile_definitions(GGML_USE_MUSA)
|
| 68 |
add_compile_definitions(GGML_CUDA_PEER_MAX_BATCH_SIZE=${GGML_CUDA_PEER_MAX_BATCH_SIZE})
|
| 69 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 70 |
if (GGML_CUDA_FORCE_MMQ)
|
| 71 |
add_compile_definitions(GGML_CUDA_FORCE_MMQ)
|
| 72 |
endif()
|