Sigbjørn Skjæret committed
Commit 332bcaf · 1 Parent(s): d41b883

CUDA: don't convert BF16 weights to FP32 (ggml/1174)

* add bf16 support

* use convert_from_bf16_cuda instead of convert_unary_cuda for f32

* revert 7ec5085

* move functionality into convert_unary with constexpr

ggml/src/ggml-cuda/convert.cu CHANGED
@@ -579,7 +579,13 @@ static __global__ void convert_unary(const void * __restrict__ vx, dst_t * __res
 
     const src_t * x = (const src_t *) vx;
 
-    y[i] = x[i];
+    if constexpr (std::is_same_v<src_t, nv_bfloat16>) {
+        y[i] = __bfloat162float(x[i]);
+    } else if constexpr (std::is_same_v<dst_t, nv_bfloat16> && std::is_same_v<src_t, half>) {
+        y[i] = (float)x[i];
+    } else {
+        y[i] = x[i];
+    }
 }
 
 template <typename src_t, typename dst_t>
@@ -588,6 +594,17 @@ static void convert_unary_cuda(const void * __restrict__ vx, dst_t * __restrict_
     convert_unary<src_t><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
 }
 
+to_bf16_cuda_t ggml_get_to_bf16_cuda(ggml_type type) {
+    switch (type) {
+        case GGML_TYPE_F32:
+            return convert_unary_cuda<float>;
+        case GGML_TYPE_F16:
+            return convert_unary_cuda<half>;
+        default:
+            return nullptr;
+    }
+}
+
 to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
     switch (type) {
         case GGML_TYPE_Q4_0:
@@ -633,6 +650,8 @@ to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
             return dequantize_row_iq3_s_cuda;
         case GGML_TYPE_F32:
             return convert_unary_cuda<float>;
+        case GGML_TYPE_BF16:
+            return convert_unary_cuda<nv_bfloat16>;
         default:
             return nullptr;
     }
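
The `if constexpr` branches keep `convert_unary` a single template serving every src/dst pair: BF16 sources are widened explicitly with `__bfloat162float`, FP16 → BF16 detours through `float` (CUDA has no direct `half` to `nv_bfloat16` assignment), and every other pair keeps the plain assignment. Below is a minimal standalone sketch of the same dispatch pattern, assuming CUDA 11+ and compilation with nvcc; `convert_unary_demo` and the launch shape are illustrative, not part of ggml.

#include <cstdint>
#include <cstdio>
#include <type_traits>
#include <cuda_bf16.h>

template <typename src_t, typename dst_t>
static __global__ void convert_unary_demo(const void * vx, dst_t * y, const int64_t k) {
    const int64_t i = (int64_t)blockDim.x*blockIdx.x + threadIdx.x;
    if (i >= k) {
        return;
    }

    const src_t * x = (const src_t *) vx;

    // same compile-time dispatch as the patch: BF16 is widened explicitly,
    // everything else falls through to the plain assignment
    if constexpr (std::is_same_v<src_t, nv_bfloat16>) {
        y[i] = __bfloat162float(x[i]);
    } else {
        y[i] = x[i];
    }
}

int main() {
    const int64_t k = 256;
    nv_bfloat16 * x; float * y;
    cudaMallocManaged(&x, k*sizeof(*x));
    cudaMallocManaged(&y, k*sizeof(*y));
    for (int64_t i = 0; i < k; ++i) {
        x[i] = __float2bfloat16((float) i);
    }

    convert_unary_demo<nv_bfloat16><<<(k + 255)/256, 256>>>(x, y, k);
    cudaDeviceSynchronize();

    printf("y[42] = %.1f\n", y[42]); // expect 42.0: BF16 holds small integers exactly
    cudaFree(x);
    cudaFree(y);
    return 0;
}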
ggml/src/ggml-cuda/convert.cuh CHANGED
@@ -7,7 +7,10 @@ using to_t_cuda_t = void (*)(const void * __restrict__ x, T * __restrict__ y, in
 
 typedef to_t_cuda_t<float> to_fp32_cuda_t;
 typedef to_t_cuda_t<half> to_fp16_cuda_t;
+typedef to_t_cuda_t<nv_bfloat16> to_bf16_cuda_t;
 
 to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type);
 
+to_bf16_cuda_t ggml_get_to_bf16_cuda(ggml_type type);
+
 to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type);
ggml/src/ggml-cuda/ggml-cuda.cu CHANGED
@@ -1194,7 +1194,35 @@ static void ggml_cuda_op_mul_mat_cublas(
 
     const bool use_fp16 = (src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) && ggml_is_contiguous(src0) && row_diff == src0->ne[1] && dst->op_params[0] == GGML_PREC_DEFAULT;
 
-    if (((GGML_CUDA_CC_IS_NVIDIA(cc) && cc >= GGML_CUDA_CC_VOLTA) || GGML_CUDA_CC_IS_AMD(cc)) && use_fp16) {
+    if (src0->type == GGML_TYPE_BF16 && ggml_is_contiguous(src0) && row_diff == src0->ne[1]) {
+        ggml_cuda_pool_alloc<nv_bfloat16> src1_as_bf16(ctx.pool(id));
+        if (src1->type != GGML_TYPE_BF16) {
+            const to_bf16_cuda_t to_bf16_cuda = ggml_get_to_bf16_cuda(src1->type);
+            GGML_ASSERT(to_bf16_cuda != nullptr);
+            size_t ne = src1_ncols*ne10;
+            src1_as_bf16.alloc(ne);
+            to_bf16_cuda(src1_ddf_i, src1_as_bf16.get(), ne, stream);
+        }
+        const nv_bfloat16 * src1_ptr = src1->type == GGML_TYPE_BF16 ? (const nv_bfloat16 *) src1_ddf_i : src1_as_bf16.get();
+        const nv_bfloat16 * src0_ptr = (const nv_bfloat16 *)src0_dd_i;
+        ggml_cuda_pool_alloc<nv_bfloat16> dst_bf16(ctx.pool(id), row_diff*src1_ncols);
+
+        const float alpha_f32 = 1.0f;
+        const float beta_f32  = 0.0f;
+
+        CUBLAS_CHECK(cublasSetStream(ctx.cublas_handle(id), stream));
+        CUBLAS_CHECK(
+            cublasGemmEx(ctx.cublas_handle(id), CUBLAS_OP_T, CUBLAS_OP_N,
+                    row_diff, src1_ncols, ne10,
+                    &alpha_f32, src0_ptr,       CUDA_R_16BF, ne00,
+                                src1_ptr,       CUDA_R_16BF, ne10,
+                    &beta_f32,  dst_bf16.get(), CUDA_R_16BF, ldc,
+                    CUBLAS_COMPUTE_32F,
+                    CUBLAS_GEMM_DEFAULT_TENSOR_OP));
+
+        const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_BF16);
+        to_fp32_cuda(dst_bf16.get(), dst_dd_i, row_diff*src1_ncols, stream);
+    } else if (((GGML_CUDA_CC_IS_NVIDIA(cc) && cc >= GGML_CUDA_CC_VOLTA) || GGML_CUDA_CC_IS_AMD(cc)) && use_fp16) {
         // convert src0 and src1 to fp16, multiply as fp16, convert dst to fp32
         ggml_cuda_pool_alloc<half> src0_as_f16(ctx.pool(id));
         if (src0->type != GGML_TYPE_F16) {
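
This branch is the heart of the change: contiguous BF16 weights go straight into `cublasGemmEx` with `CUDA_R_16BF` storage and `CUBLAS_COMPUTE_32F` accumulation, and only the GEMM result is converted back to FP32, instead of widening the whole weight matrix first. A self-contained sketch of that cuBLAS call, assuming CUDA 11+ (link with -lcublas); the sizes, constant fill values, and column-major layout here are illustrative only, not ggml's buffers.

#include <cstdio>
#include <cublas_v2.h>
#include <cuda_bf16.h>

int main() {
    // C (m x n) = A^T (m x k) * B (k x n), all three stored as BF16 in
    // column-major order; accumulation happens in FP32, as in the patch.
    const int m = 4, n = 4, k = 8;

    nv_bfloat16 *A, *B, *C;
    cudaMallocManaged(&A, k*m*sizeof(*A)); // stored k x m, used transposed
    cudaMallocManaged(&B, k*n*sizeof(*B));
    cudaMallocManaged(&C, m*n*sizeof(*C));
    for (int i = 0; i < k*m; ++i) { A[i] = __float2bfloat16(1.0f); }
    for (int i = 0; i < k*n; ++i) { B[i] = __float2bfloat16(2.0f); }

    cublasHandle_t handle;
    cublasCreate(&handle);

    // FP32 scalars are required when the compute type is CUBLAS_COMPUTE_32F
    const float alpha = 1.0f;
    const float beta  = 0.0f;

    cublasGemmEx(handle, CUBLAS_OP_T, CUBLAS_OP_N,
                 m, n, k,
                 &alpha, A, CUDA_R_16BF, k,
                         B, CUDA_R_16BF, k,
                 &beta,  C, CUDA_R_16BF, m,
                 CUBLAS_COMPUTE_32F,
                 CUBLAS_GEMM_DEFAULT_TENSOR_OP);
    cudaDeviceSynchronize();

    printf("C[0] = %.1f (expect %d)\n", __bfloat162float(C[0]), 2*k); // 8 * 1 * 2 = 16

    cublasDestroy(handle);
    cudaFree(A); cudaFree(B); cudaFree(C);
    return 0;
}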