John Balis and slaren committed on
Commit 75d438c · unverified · 1 Parent(s): 5bf1614

`ggml_cuda_cpy` support for 4d tensors and float16->float32 upcasting (ggml/686)


* added CUDA float16->float32 upcasting to ggml_cuda_cpy

* added ability to copy 4d tensors with the CUDA backend

* added tests for float16->float32 upcast and 4d tensor CUDA copies

* added 4d copy test for float32->float16 copy

* applied patch suggested by @iamlemec

* simplify cpy tests

---------

Co-authored-by: slaren <[email protected]>

Files changed (1)
  1. ggml-cuda.cu +85 -48
ggml-cuda.cu CHANGED
@@ -5357,27 +5357,37 @@ static __device__ void cpy_1_f16_f16(const char * cxi, char * cdsti) {
     *dsti = *xi;
 }
 
+static __device__ void cpy_1_f16_f32(const char * cxi, char * cdsti) {
+    const half * xi = (const half *) cxi;
+    float * dsti = (float *) cdsti;
+
+    *dsti = *xi;
+}
+
 template <cpy_kernel_t cpy_1>
 static __global__ void cpy_f32_f16(const char * cx, char * cdst, const int ne,
-                                   const int ne00, const int ne01, const int nb00, const int nb01, const int nb02,
-                                   const int ne10, const int ne11, const int nb10, const int nb11, const int nb12) {
+                                   const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
+                                   const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
+                                   const int nb12, const int nb13) {
     const int i = blockDim.x*blockIdx.x + threadIdx.x;
 
     if (i >= ne) {
         return;
     }
 
-    // determine indices i02/i12, i01/i11, i00/i10 as a function of index i of flattened tensor
+    // determine indices i03/i13, i02/i12, i01/i11, i00/i10 as a function of index i of flattened tensor
     // then combine those indices with the corresponding byte offsets to get the total offsets
-    const int i02 = i / (ne00*ne01);
-    const int i01 = (i - i02*ne01*ne00) / ne00;
-    const int i00 = i - i02*ne01*ne00 - i01*ne00;
-    const int x_offset = i00*nb00 + i01*nb01 + i02*nb02;
-
-    const int i12 = i / (ne10*ne11);
-    const int i11 = (i - i12*ne10*ne11) / ne10;
-    const int i10 = i - i12*ne10*ne11 - i11*ne10;
-    const int dst_offset = i10*nb10 + i11*nb11 + i12*nb12;
+    const int i03 = i/(ne00*ne01*ne02);
+    const int i02 = (i - i03*ne00*ne01*ne02) / (ne00*ne01);
+    const int i01 = (i - i03*ne00*ne01*ne02 - i02*ne01*ne00) / ne00;
+    const int i00 = i - i03*ne00*ne01*ne02 - i02*ne01*ne00 - i01*ne00;
+    const int x_offset = i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03;
+
+    const int i13 = i/(ne10*ne11*ne12);
+    const int i12 = (i - i13*ne10*ne11*ne12) / (ne10*ne11);
+    const int i11 = (i - i13*ne10*ne11*ne12 - i12*ne10*ne11) / ne10;
+    const int i10 = i - i13*ne10*ne11*ne12 - i12*ne10*ne11 - i11*ne10;
+    const int dst_offset = i10*nb10 + i11*nb11 + i12*nb12 + i13*nb13;
 
     cpy_1(cx + x_offset, cdst + dst_offset);
 }
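The core of the change is extending the flattened-index decomposition from three dimensions to four. As a sanity check, here is a minimal standalone host-side sketch (not part of the commit; the shape and names are illustrative) that reproduces the kernel's arithmetic and confirms that, for a contiguous tensor, the recovered byte offset is simply i times the element size:

// Standalone sanity check (not ggml code): reproduces the kernel's 4d index
// decomposition on the host for a hypothetical contiguous float tensor.
#include <cassert>
#include <cstdio>

int main() {
    const int ne00 = 4, ne01 = 3, ne02 = 2, ne03 = 2;  // illustrative shape
    const int nb00 = sizeof(float);        // byte stride, dim 0
    const int nb01 = nb00*ne00;            // byte stride, dim 1
    const int nb02 = nb01*ne01;            // byte stride, dim 2
    const int nb03 = nb02*ne02;            // byte stride, dim 3

    const int ne = ne00*ne01*ne02*ne03;
    for (int i = 0; i < ne; ++i) {
        // same arithmetic as cpy_f32_f16 above
        const int i03 = i/(ne00*ne01*ne02);
        const int i02 = (i - i03*ne00*ne01*ne02) / (ne00*ne01);
        const int i01 = (i - i03*ne00*ne01*ne02 - i02*ne01*ne00) / ne00;
        const int i00 = i - i03*ne00*ne01*ne02 - i02*ne01*ne00 - i01*ne00;
        // for a contiguous tensor the byte offset is just i * sizeof(float)
        assert(i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03 == i*(int)sizeof(float));
    }
    printf("4d index decomposition matches the contiguous layout\n");
    return 0;
}

The nb strides are in bytes rather than elements, which is what lets the same kernel handle non-contiguous views: a permuted or sliced tensor just passes different nb values.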
@@ -5471,23 +5481,26 @@ static __device__ void cpy_blck_f32_q4_1(const char * cxi, char * cdsti) {
 
 template <cpy_kernel_t cpy_blck, int qk>
 static __global__ void cpy_f32_q(const char * cx, char * cdst, const int ne,
-                                 const int ne00, const int ne01, const int nb00, const int nb01, const int nb02,
-                                 const int ne10, const int ne11, const int nb10, const int nb11, const int nb12) {
+                                 const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
+                                 const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
+                                 const int nb12, const int nb13) {
     const int i = (blockDim.x*blockIdx.x + threadIdx.x)*qk;
 
     if (i >= ne) {
         return;
     }
 
-    const int i02 = i / (ne00*ne01);
-    const int i01 = (i - i02*ne01*ne00) / ne00;
-    const int i00 = (i - i02*ne01*ne00 - i01*ne00);
-    const int x_offset = i00*nb00 + i01*nb01 + i02*nb02;
+    const int i03 = i/(ne00*ne01*ne02);
+    const int i02 = (i - i03*ne00*ne01*ne02) / (ne00*ne01);
+    const int i01 = (i - i03*ne00*ne01*ne02 - i02*ne01*ne00) / ne00;
+    const int i00 = i - i03*ne00*ne01*ne02 - i02*ne01*ne00 - i01*ne00;
+    const int x_offset = i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03;
 
-    const int i12 = i / (ne10*ne11);
-    const int i11 = (i - i12*ne10*ne11) / ne10;
-    const int i10 = (i - i12*ne10*ne11 - i11*ne10)/qk;
-    const int dst_offset = i10*nb10 + i11*nb11 + i12*nb12;
+    const int i13 = i/(ne10*ne11*ne12);
+    const int i12 = (i - i13*ne10*ne11*ne12) / (ne10*ne11);
+    const int i11 = (i - i13*ne10*ne11*ne12 - i12*ne10*ne11) / ne10;
+    const int i10 = i - i13*ne10*ne11*ne12 - i12*ne10*ne11 - i11*ne10;
+    const int dst_offset = (i10/qk)*nb10 + i11*nb11 + i12*nb12 + i13*nb13;
 
     cpy_blck(cx + x_offset, cdst + dst_offset);
 }
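Note the one asymmetry versus the float kernel: each thread here consumes qk consecutive source floats that quantize into a single destination block, so the destination row index is divided by qk before applying the block stride nb10. A small worked example (assuming QK8_0 == 32 and a 34-byte block_q8_0, i.e. one fp16 scale plus 32 int8 quants — values to be checked against the actual block definitions):

#include <cstdio>

int main() {
    const int qk   = 32;   // QK8_0: source floats per quantized block (assumed)
    const int nb10 = 34;   // bytes per block_q8_0: 2-byte scale + 32 quants (assumed)
    const int i    = 96;   // flattened index of the first float handled by a thread
    const int i10  = i;    // single contiguous row, so i10 == i
    // (i10/qk)*nb10 = (96/32)*34 = 102: byte offset of the 4th q8_0 block
    printf("dst block offset = %d bytes\n", (i10/qk)*nb10);
    return 0;
}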
@@ -7135,69 +7148,82 @@ static void ggml_mul_mat_vec_nc_f16_f32_cuda(
         (vx, y, dst, ncols_x, nrows_x, row_stride_x, channel_stride_x, nchannels_y/nchannels_x);
 }
 
+
+static void ggml_cpy_f16_f32_cuda(
+    const char * cx, char * cdst, const int ne,
+    const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
+    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
+
+    const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
+    cpy_f32_f16<cpy_1_f16_f32><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
+        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
+}
+
 static void ggml_cpy_f32_f32_cuda(
     const char * cx, char * cdst, const int ne,
-    const int ne00, const int ne01, const int nb00, const int nb01, const int nb02,
-    const int ne10, const int ne11, const int nb10, const int nb11, const int nb12, cudaStream_t stream) {
+    const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
+    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
 
     const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
     cpy_f32_f16<cpy_1_f32_f32><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
-        (cx, cdst, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12);
+        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
 }
 
 static void ggml_cpy_f32_f16_cuda(
     const char * cx, char * cdst, const int ne,
-    const int ne00, const int ne01, const int nb00, const int nb01, const int nb02,
-    const int ne10, const int ne11, const int nb10, const int nb11, const int nb12, cudaStream_t stream) {
+    const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
+    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
 
     const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
     cpy_f32_f16<cpy_1_f32_f16><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
-        (cx, cdst, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12);
+        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
 }
 
 static void ggml_cpy_f32_q8_0_cuda(
     const char * cx, char * cdst, const int ne,
-    const int ne00, const int ne01, const int nb00, const int nb01, const int nb02,
-    const int ne10, const int ne11, const int nb10, const int nb11, const int nb12, cudaStream_t stream) {
+    const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
+    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
 
     GGML_ASSERT(ne % QK8_0 == 0);
     const int num_blocks = ne / QK8_0;
     cpy_f32_q<cpy_blck_f32_q8_0, QK8_0><<<num_blocks, 1, 0, stream>>>
-        (cx, cdst, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12);
+        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
 }
 
 static void ggml_cpy_f32_q4_0_cuda(
     const char * cx, char * cdst, const int ne,
-    const int ne00, const int ne01, const int nb00, const int nb01, const int nb02,
-    const int ne10, const int ne11, const int nb10, const int nb11, const int nb12, cudaStream_t stream) {
+    const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
+    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
 
     GGML_ASSERT(ne % QK4_0 == 0);
     const int num_blocks = ne / QK4_0;
     cpy_f32_q<cpy_blck_f32_q4_0, QK4_0><<<num_blocks, 1, 0, stream>>>
-        (cx, cdst, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12);
+        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
 }
 
 static void ggml_cpy_f32_q4_1_cuda(
     const char * cx, char * cdst, const int ne,
-    const int ne00, const int ne01, const int nb00, const int nb01, const int nb02,
-    const int ne10, const int ne11, const int nb10, const int nb11, const int nb12, cudaStream_t stream) {
+    const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
+    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
 
     GGML_ASSERT(ne % QK4_1 == 0);
     const int num_blocks = ne / QK4_1;
     cpy_f32_q<cpy_blck_f32_q4_1, QK4_1><<<num_blocks, 1, 0, stream>>>
-        (cx, cdst, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12);
+        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
 }
 
 static void ggml_cpy_f16_f16_cuda(
     const char * cx, char * cdst, const int ne,
-    const int ne00, const int ne01, const int nb00, const int nb01, const int nb02,
-    const int ne10, const int ne11, const int nb10, const int nb11, const int nb12, cudaStream_t stream) {
+    const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
+    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
 
     const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
     cpy_f32_f16<cpy_1_f16_f16><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
-        (cx, cdst, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12);
+        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
 }
 
+
+
 static void scale_f32_cuda(const float * x, float * dst, const float scale, const int k, cudaStream_t stream) {
     const int num_blocks = (k + CUDA_SCALE_BLOCK_SIZE - 1) / CUDA_SCALE_BLOCK_SIZE;
     scale_f32<<<num_blocks, CUDA_SCALE_BLOCK_SIZE, 0, stream>>>(x, dst, scale, k);
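To see the new upcast path in isolation, here is a self-contained CUDA sketch (not ggml code; BLOCK_SIZE stands in for CUDA_CPY_BLOCK_SIZE) whose kernel does per element what cpy_1_f16_f32 does, using the same launch-geometry rounding as the wrappers above:

// Standalone f16 -> f32 upcasting copy demo. Compile with: nvcc -o upcast upcast.cu
#include <cuda_fp16.h>
#include <cstdio>

#define BLOCK_SIZE 256   // stands in for CUDA_CPY_BLOCK_SIZE

static __global__ void upcast_f16_f32(const half * x, float * dst, const int ne) {
    const int i = blockDim.x*blockIdx.x + threadIdx.x;
    if (i >= ne) {
        return;
    }
    dst[i] = __half2float(x[i]);  // the actual upcast
}

int main() {
    const int ne = 8*4*2*2;  // a small 4d shape, flattened

    half  * x;
    float * dst;
    cudaMallocManaged(&x,   ne*sizeof(half));
    cudaMallocManaged(&dst, ne*sizeof(float));
    for (int i = 0; i < ne; ++i) {
        x[i] = __float2half((float) i);
    }

    const int num_blocks = (ne + BLOCK_SIZE - 1) / BLOCK_SIZE;  // same rounding as above
    upcast_f16_f32<<<num_blocks, BLOCK_SIZE>>>(x, dst, ne);
    cudaDeviceSynchronize();

    printf("dst[5] = %f (expected 5.0)\n", dst[5]);
    cudaFree(x);
    cudaFree(dst);
    return 0;
}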
@@ -9941,19 +9967,25 @@ static void ggml_cuda_cpy(const ggml_tensor * src0, const ggml_tensor * src1, gg
 
     const int64_t ne00 = src0->ne[0];
     const int64_t ne01 = src0->ne[1];
-    GGML_ASSERT(src0->ne[3] == 1);
+    const int64_t ne02 = src0->ne[2];
+
+    //GGML_ASSERT(src0->ne[3] == 1);
 
     const int64_t nb00 = src0->nb[0];
     const int64_t nb01 = src0->nb[1];
     const int64_t nb02 = src0->nb[2];
+    const int64_t nb03 = src0->nb[3];
 
     const int64_t ne10 = src1->ne[0];
     const int64_t ne11 = src1->ne[1];
-    GGML_ASSERT(src1->ne[3] == 1);
+    const int64_t ne12 = src1->ne[2];
+
+    //GGML_ASSERT(src1->ne[3] == 1);
 
     const int64_t nb10 = src1->nb[0];
     const int64_t nb11 = src1->nb[1];
     const int64_t nb12 = src1->nb[2];
+    const int64_t nb13 = src1->nb[3];
 
     ggml_cuda_set_device(g_main_device);
     cudaStream_t main_stream = g_cudaStreams[g_main_device][0];
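The new ne02/nb03 (and ne12/nb13) reads give the kernels the third dimension's extent and the fourth dimension's byte stride, replacing the old assertions that restricted copies to three dimensions. For reference, a sketch of how the nb fields relate on a contiguous tensor (illustrative values; in ggml they are filled in by the tensor constructor, and views or permutations only change nb):

#include <cstdio>

int main() {
    const long ne[4] = {8, 4, 2, 2};   // elements per dimension (illustrative)
    long       nb[4];
    nb[0] = 2;                         // bytes per f16 element
    for (int d = 1; d < 4; ++d) {
        nb[d] = nb[d-1]*ne[d-1];       // contiguous byte strides
    }
    // prints nb = {2, 16, 64, 128}
    printf("nb = {%ld, %ld, %ld, %ld}\n", nb[0], nb[1], nb[2], nb[3]);
    return 0;
}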
@@ -9965,17 +9997,19 @@ static void ggml_cuda_cpy(const ggml_tensor * src0, const ggml_tensor * src1, gg
     char * src1_ddc = (char *) src1_extra->data_device[g_main_device];
 
     if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
-        ggml_cpy_f32_f32_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12, main_stream);
+        ggml_cpy_f32_f32_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) {
-        ggml_cpy_f32_f16_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12, main_stream);
+        ggml_cpy_f32_f16_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q8_0) {
-        ggml_cpy_f32_q8_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12, main_stream);
+        ggml_cpy_f32_q8_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_0) {
-        ggml_cpy_f32_q4_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12, main_stream);
+        ggml_cpy_f32_q4_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_1) {
-        ggml_cpy_f32_q4_1_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12, main_stream);
+        ggml_cpy_f32_q4_1_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
     } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
-        ggml_cpy_f16_f16_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12, main_stream);
+        ggml_cpy_f16_f16_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+    } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
+        ggml_cpy_f16_f32_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
     } else {
         fprintf(stderr, "%s: unsupported type combination (%s to %s)\n", __func__,
             ggml_type_name(src0->type), ggml_type_name(src1->type));
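At the graph level, the new branch is reached whenever a ggml_cpy node has an F16 source and an F32 destination. A minimal sketch of building such a graph through the public API, under the assumption of a ggml version that provides ggml_graph_compute_with_ctx (CPU compute is shown for brevity; on the CUDA backend the same node dispatches into ggml_cuda_cpy above):

#include "ggml.h"
#include <stdio.h>

int main(void) {
    struct ggml_init_params params = { 16*1024*1024, NULL, false };
    struct ggml_context * ctx = ggml_init(params);

    // 4d f16 source, 4d f32 destination: the pair this commit enables on CUDA
    struct ggml_tensor * a = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, 8, 4, 2, 2);
    struct ggml_tensor * b = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 8, 4, 2, 2);
    ggml_set_f32(a, 2.0f);

    struct ggml_tensor * c = ggml_cpy(ctx, a, b);   // copy-with-cast node

    struct ggml_cgraph * gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, c);
    ggml_graph_compute_with_ctx(ctx, gf, 1);

    printf("b[0] = %f (expected 2.0)\n", ggml_get_f32_1d(c, 0));
    ggml_free(ctx);
    return 0;
}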
@@ -10978,6 +11012,9 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
                 if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F16) {
                     return true;
                 }
+                if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F32) {
+                    return true;
+                }
                 return false;
             } break;
         case GGML_OP_DUP:
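With the supports_op table updated, a caller can confirm the new type pair is reported as supported before scheduling. A hedged sketch (assumes ggml-backend.h/ggml-cuda.h, a CUDA device 0, and linking against the CUDA build; no_alloc keeps the context metadata-only):

#include "ggml.h"
#include "ggml-backend.h"
#include "ggml-cuda.h"
#include <stdio.h>

int main(void) {
    struct ggml_init_params params = { 1024*1024, NULL, true };  // no_alloc: metadata only
    struct ggml_context * ctx = ggml_init(params);

    struct ggml_tensor * a  = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, 8, 4, 2, 2);
    struct ggml_tensor * b  = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 8, 4, 2, 2);
    struct ggml_tensor * op = ggml_cpy(ctx, a, b);

    ggml_backend_t cuda = ggml_backend_cuda_init(0);
    printf("f16 -> f32 cpy supported: %s\n",
           ggml_backend_supports_op(cuda, op) ? "yes" : "no");

    ggml_backend_free(cuda);
    ggml_free(ctx);
    return 0;
}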
 