Gaurav Garg committed on
Commit
1e69b8c
·
1 Parent(s): 40652de

cuda : enable CUDA Graph on CUDA Toolkit < 12.x (llama/12394)

Browse files

* Enable CUDA Graph on CTK < 12.x

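The `cudaGraphExecUpdate` API signature changed in CUDA Toolkit (CTK) 12.x. For this reason, CUDA Graph support was previously disabled on older CUDA toolkits. This change enables CUDA Graph support on CTK versions < 12.x by calling the older form of the API (which takes a separate error-node argument) when building against CTK < 12.x.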

* Fix compilation errors with MUSA

* Disable CUDA Graph for MUSA

ggml/src/ggml-cuda/common.cuh CHANGED
@@ -678,7 +678,7 @@ struct ggml_tensor_extra_gpu {
678
  };
679
 
680
 
681
- #if ((CUDART_VERSION >= 12000) && defined(GGML_CUDA_USE_GRAPHS)) || defined(GGML_HIP_GRAPHS)
682
  #define USE_CUDA_GRAPH
683
  #endif
684
 
 
678
  };
679
 
680
 
681
+ #if (defined(GGML_CUDA_USE_GRAPHS) || defined(GGML_HIP_GRAPHS))
682
  #define USE_CUDA_GRAPH
683
  #endif
684
 
ggml/src/ggml-cuda/ggml-cuda.cu CHANGED
@@ -2610,13 +2610,15 @@ static bool is_cuda_graph_update_required(ggml_backend_cuda_context * cuda_ctx,
2610
 
2611
  static void update_cuda_graph_executable(ggml_backend_cuda_context * cuda_ctx) {
2612
 
 
2613
  cudaGraphExecUpdateResultInfo result_info;
2614
- #ifdef __HIP_PLATFORM_AMD__
2615
- hipGraphNode_t errorNode;
2616
- hipError_t stat = hipGraphExecUpdate(cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, &errorNode, &result_info);
2617
- #else
2618
  cudaError_t stat = cudaGraphExecUpdate(cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, &result_info);
2619
- #endif
 
 
 
 
 
2620
  if (stat == cudaErrorGraphExecUpdateFailure) {
2621
  #ifndef NDEBUG
2622
  GGML_LOG_DEBUG("%s: CUDA graph update failed\n", __func__);
 
2610
 
2611
  static void update_cuda_graph_executable(ggml_backend_cuda_context * cuda_ctx) {
2612
 
2613
+ #if CUDART_VERSION >= 12000
2614
  cudaGraphExecUpdateResultInfo result_info;
 
 
 
 
2615
  cudaError_t stat = cudaGraphExecUpdate(cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, &result_info);
2616
+ #else
2617
+ cudaGraphNode_t errorNode;
2618
+ cudaGraphExecUpdateResult result_info;
2619
+ cudaError_t stat = cudaGraphExecUpdate(cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, &errorNode, &result_info);
2620
+ #endif // CUDART_VERSION >= 12000
2621
+
2622
  if (stat == cudaErrorGraphExecUpdateFailure) {
2623
  #ifndef NDEBUG
2624
  GGML_LOG_DEBUG("%s: CUDA graph update failed\n", __func__);
ggml/src/ggml-cuda/vendors/hip.h CHANGED
@@ -112,7 +112,7 @@
112
  #define cudaGraphExecDestroy hipGraphExecDestroy
113
  #define cudaGraphLaunch hipGraphLaunch
114
  #define cudaErrorGraphExecUpdateFailure hipErrorGraphExecUpdateFailure
115
- #define cudaGraphExecUpdateResultInfo hipGraphExecUpdateResult
116
  #define cudaGraphNodeType hipGraphNodeType
117
  #define cudaGraphNodeTypeKernel hipGraphNodeTypeKernel
118
  #define cudaGraphInstantiate hipGraphInstantiate
 
112
  #define cudaGraphExecDestroy hipGraphExecDestroy
113
  #define cudaGraphLaunch hipGraphLaunch
114
  #define cudaErrorGraphExecUpdateFailure hipErrorGraphExecUpdateFailure
115
+ #define cudaGraphExecUpdateResult hipGraphExecUpdateResult
116
  #define cudaGraphNodeType hipGraphNodeType
117
  #define cudaGraphNodeTypeKernel hipGraphNodeTypeKernel
118
  #define cudaGraphInstantiate hipGraphInstantiate
ggml/src/ggml-cuda/vendors/musa.h CHANGED
@@ -119,7 +119,7 @@
119
  #define cudaGraphExecDestroy musaGraphExecDestroy
120
  #define cudaGraphExec_t musaGraphExec_t
121
  #define cudaGraphExecUpdate musaGraphExecUpdate
122
- #define cudaGraphExecUpdateResultInfo musaGraphExecUpdateResult
123
  #define cudaGraphGetNodes musaGraphGetNodes
124
  #define cudaGraphInstantiate musaGraphInstantiate
125
  #define cudaGraphKernelNodeGetParams musaGraphKernelNodeGetParams
@@ -132,6 +132,7 @@
132
  #define cudaGraph_t musaGraph_t
133
  #define cudaKernelNodeParams musaKernelNodeParams
134
  #define cudaStreamCaptureModeRelaxed musaStreamCaptureModeRelaxed
 
135
  #define cudaStreamEndCapture musaStreamEndCapture
136
 
137
  typedef mt_bfloat16 nv_bfloat16;
 
119
  #define cudaGraphExecDestroy musaGraphExecDestroy
120
  #define cudaGraphExec_t musaGraphExec_t
121
  #define cudaGraphExecUpdate musaGraphExecUpdate
122
+ #define cudaGraphExecUpdateResult musaGraphExecUpdateResult
123
  #define cudaGraphGetNodes musaGraphGetNodes
124
  #define cudaGraphInstantiate musaGraphInstantiate
125
  #define cudaGraphKernelNodeGetParams musaGraphKernelNodeGetParams
 
132
  #define cudaGraph_t musaGraph_t
133
  #define cudaKernelNodeParams musaKernelNodeParams
134
  #define cudaStreamCaptureModeRelaxed musaStreamCaptureModeRelaxed
135
+ #define cudaStreamBeginCapture musaStreamBeginCapture
136
  #define cudaStreamEndCapture musaStreamEndCapture
137
 
138
  typedef mt_bfloat16 nv_bfloat16;
ggml/src/ggml-musa/CMakeLists.txt CHANGED
@@ -67,10 +67,6 @@ if (MUSAToolkit_FOUND)
67
  add_compile_definitions(GGML_USE_MUSA)
68
  add_compile_definitions(GGML_CUDA_PEER_MAX_BATCH_SIZE=${GGML_CUDA_PEER_MAX_BATCH_SIZE})
69
 
70
- if (GGML_CUDA_GRAPHS)
71
- add_compile_definitions(GGML_CUDA_USE_GRAPHS)
72
- endif()
73
-
74
  if (GGML_CUDA_FORCE_MMQ)
75
  add_compile_definitions(GGML_CUDA_FORCE_MMQ)
76
  endif()
 
67
  add_compile_definitions(GGML_USE_MUSA)
68
  add_compile_definitions(GGML_CUDA_PEER_MAX_BATCH_SIZE=${GGML_CUDA_PEER_MAX_BATCH_SIZE})
69
 
 
 
 
 
70
  if (GGML_CUDA_FORCE_MMQ)
71
  add_compile_definitions(GGML_CUDA_FORCE_MMQ)
72
  endif()