JohannesGaessler committed
Commit 2fbcec1 · 1 Parent(s): bf3dc93

CUDA: backwards pass for misc. ops, add tests (llama/11257)


* CUDA: backwards pass for misc. ops, add tests

* remove restrict from pointers
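
Note on the second bullet: the comment added in ggml-alloc.c below states that ops for which ggml_op_can_inplace() returns true must not use restrict pointers in their backend implementations, because an in-place execution makes dst alias a source buffer. A minimal C sketch of the hazard (function names are hypothetical, not ggml code):

    // With restrict the compiler may assume dst and src never alias and
    // reorder or vectorize accordingly; if the op runs in place (dst == src),
    // that assumption is violated and the behavior is undefined.
    static void scale_f32_bad(float * restrict dst, const float * restrict src, int n, float s) {
        for (int i = 0; i < n; ++i) {
            dst[i] = s*src[i]; // undefined behavior when dst == src
        }
    }

    // Safe for in-place execution: no restrict, aliasing is allowed.
    static void scale_f32_ok(float * dst, const float * src, int n, float s) {
        for (int i = 0; i < n; ++i) {
            dst[i] = s*src[i];
        }
    }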

ggml/include/ggml.h CHANGED
@@ -1384,16 +1384,20 @@ extern "C" {
1384
  float scale,
1385
  float max_bias);
1386
 
1387
- GGML_API struct ggml_tensor * ggml_soft_max_back(
1388
  struct ggml_context * ctx,
1389
  struct ggml_tensor * a,
1390
- struct ggml_tensor * b);
 
 
1391
 
1392
  // in-place, returns view(a)
1393
- GGML_API struct ggml_tensor * ggml_soft_max_back_inplace(
1394
  struct ggml_context * ctx,
1395
  struct ggml_tensor * a,
1396
- struct ggml_tensor * b);
 
 
1397
 
1398
  // rotary position embedding
1399
  // if (mode & 1) - skip n_past elements (NOT SUPPORTED)
 
1384
  float scale,
1385
  float max_bias);
1386
 
1387
+ GGML_API struct ggml_tensor * ggml_soft_max_ext_back(
1388
  struct ggml_context * ctx,
1389
  struct ggml_tensor * a,
1390
+ struct ggml_tensor * b,
1391
+ float scale,
1392
+ float max_bias);
1393
 
1394
  // in-place, returns view(a)
1395
+ GGML_API struct ggml_tensor * ggml_soft_max_ext_back_inplace(
1396
  struct ggml_context * ctx,
1397
  struct ggml_tensor * a,
1398
+ struct ggml_tensor * b,
1399
+ float scale,
1400
+ float max_bias);
1401
 
1402
  // rotary position embedding
1403
  // if (mode & 1) - skip n_past elements (NOT SUPPORTED)
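
Note: ggml_soft_max_back/_inplace are replaced by ggml_soft_max_ext_back/_inplace, which carry the same scale and max_bias parameters as ggml_soft_max_ext so the backward call can mirror the forward one. A sketch of pairing the two when building a graph by hand; the argument order of the back call (gradient first, forward output second) is an assumption based on the CPU implementation below, and ctx, a, grad_out and scale are assumed to already exist:

    // forward: y = softmax(scale*a), no mask, no ALiBi bias
    struct ggml_tensor * y = ggml_soft_max_ext(ctx, a, /*mask=*/NULL, scale, /*max_bias=*/0.0f);

    // backward: gradient w.r.t. a given the gradient of y;
    // scale and max_bias must match the forward call
    struct ggml_tensor * grad_a = ggml_soft_max_ext_back(ctx, grad_out, y, scale, /*max_bias=*/0.0f);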
ggml/src/ggml-alloc.c CHANGED
@@ -37,6 +37,7 @@ static bool ggml_are_same_layout(const struct ggml_tensor * a, const struct ggml
37
  return true;
38
  }
39
 
 
40
  static bool ggml_op_can_inplace(enum ggml_op op) {
41
  switch (op) {
42
  case GGML_OP_SCALE:
@@ -52,8 +53,12 @@ static bool ggml_op_can_inplace(enum ggml_op op) {
52
  case GGML_OP_LOG:
53
  case GGML_OP_UNARY:
54
  case GGML_OP_ROPE:
 
 
55
  case GGML_OP_RMS_NORM:
 
56
  case GGML_OP_SOFT_MAX:
 
57
  return true;
58
 
59
  default:
 
37
  return true;
38
  }
39
 
40
+ // ops that return true for this function must not use restrict pointers for their backend implementations
41
  static bool ggml_op_can_inplace(enum ggml_op op) {
42
  switch (op) {
43
  case GGML_OP_SCALE:
 
53
  case GGML_OP_LOG:
54
  case GGML_OP_UNARY:
55
  case GGML_OP_ROPE:
56
+ case GGML_OP_ROPE_BACK:
57
+ case GGML_OP_SILU_BACK:
58
  case GGML_OP_RMS_NORM:
59
+ case GGML_OP_RMS_NORM_BACK:
60
  case GGML_OP_SOFT_MAX:
61
+ case GGML_OP_SOFT_MAX_BACK:
62
  return true;
63
 
64
  default:
ggml/src/ggml-cpu/ggml-cpu.c CHANGED
@@ -6691,20 +6691,20 @@ static void ggml_compute_forward_silu_back_f32(
6691
  const struct ggml_compute_params * params,
6692
  struct ggml_tensor * dst) {
6693
 
6694
- const struct ggml_tensor * src0 = dst->src[0];
6695
- const struct ggml_tensor * grad = dst->src[1];
6696
 
6697
  assert(ggml_is_contiguous_1(grad));
6698
- assert(ggml_is_contiguous_1(src0));
6699
  assert(ggml_is_contiguous_1(dst));
6700
- assert(ggml_are_same_shape(src0, dst));
6701
- assert(ggml_are_same_shape(src0, grad));
6702
 
6703
  const int ith = params->ith;
6704
  const int nth = params->nth;
6705
 
6706
- const int nc = src0->ne[0];
6707
- const int nr = ggml_nrows(src0);
6708
 
6709
  // rows per thread
6710
  const int dr = (nr + nth - 1)/nth;
@@ -6716,7 +6716,7 @@ static void ggml_compute_forward_silu_back_f32(
6716
  for (int i1 = ir0; i1 < ir1; i1++) {
6717
  ggml_vec_silu_backward_f32(nc,
6718
  (float *) ((char *) dst->data + i1*( dst->nb[1])),
6719
- (float *) ((char *) src0->data + i1*(src0->nb[1])),
6720
  (float *) ((char *) grad->data + i1*(grad->nb[1])));
6721
 
6722
  #ifndef NDEBUG
@@ -6895,7 +6895,7 @@ static void ggml_compute_forward_norm_f32(
6895
  float eps;
6896
  memcpy(&eps, dst->op_params, sizeof(float));
6897
 
6898
- GGML_ASSERT(eps > 0.0f);
6899
 
6900
  // TODO: optimize
6901
  for (int64_t i03 = 0; i03 < ne03; i03++) {
@@ -6966,7 +6966,7 @@ static void ggml_compute_forward_rms_norm_f32(
6966
  float eps;
6967
  memcpy(&eps, dst->op_params, sizeof(float));
6968
 
6969
- GGML_ASSERT(eps > 0.0f);
6970
 
6971
  // TODO: optimize
6972
  for (int64_t i03 = 0; i03 < ne03; i03++) {
@@ -7018,12 +7018,13 @@ static void ggml_compute_forward_rms_norm_back_f32(
7018
  const struct ggml_compute_params * params,
7019
  struct ggml_tensor * dst) {
7020
 
7021
- const struct ggml_tensor * src0 = dst->src[0];
7022
- const struct ggml_tensor * src1 = dst->src[1];
7023
 
7024
  GGML_ASSERT(ggml_are_same_shape(src0, dst) && ggml_are_same_shape(src0, src1));
7025
 
7026
  GGML_ASSERT(src0->nb[0] == sizeof(float));
 
7027
 
7028
  const int ith = params->ith;
7029
  const int nth = params->nth;
@@ -7042,8 +7043,8 @@ static void ggml_compute_forward_rms_norm_back_f32(
7042
  const int64_t i12 = i02;
7043
  const int64_t i13 = i03;
7044
 
7045
- const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
7046
- const float * dz = (float *) ((char *) src1->data + i11*nb11 + i12*nb12 + i13*nb13);
7047
 
7048
  ggml_float sum_xx = 0.0;
7049
  ggml_float sum_xdz = 0.0;
@@ -7066,9 +7067,9 @@ static void ggml_compute_forward_rms_norm_back_f32(
7066
  {
7067
  // z = rms_norm(x)
7068
  //
7069
- // rms_norm(src0) =
7070
  // scale(
7071
- // src0,
7072
  // div(
7073
  // 1,
7074
  // sqrt(
@@ -7076,13 +7077,13 @@ static void ggml_compute_forward_rms_norm_back_f32(
7076
  // scale(
7077
  // sum(
7078
  // sqr(
7079
- // src0)),
7080
  // (1.0/N)),
7081
  // eps))));
7082
 
7083
  // postorder:
7084
  // ## op args grad
7085
- // 00 param src0 grad[#00]
7086
  // 01 const 1
7087
  // 02 sqr (#00) grad[#02]
7088
  // 03 sum (#02) grad[#03]
@@ -7159,6 +7160,7 @@ static void ggml_compute_forward_rms_norm_back_f32(
7159
  // dx := scale(dx, rrms)
7160
  float * dx = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
7161
 
 
7162
  ggml_vec_cpy_f32 (ne00, dx, x);
7163
  // ggml_vec_scale_f32(ne00, dx, -mean_xdz/mean_eps);
7164
  ggml_vec_scale_f32(ne00, dx, (float)(-sum_xdz)/sum_eps);
@@ -7750,12 +7752,13 @@ static void ggml_compute_forward_out_prod_f32(
7750
  const int ith = params->ith;
7751
  const int nth = params->nth;
7752
 
7753
- GGML_ASSERT(ne0 == ne00);
7754
- GGML_ASSERT(ne1 == ne10);
7755
- GGML_ASSERT(ne2 == ne02);
7756
- GGML_ASSERT(ne02 == ne12);
7757
- GGML_ASSERT(ne3 == ne13);
7758
- GGML_ASSERT(ne03 == ne13);
 
7759
 
7760
  // we don't support permuted src0 or src1
7761
  GGML_ASSERT(nb00 == sizeof(float));
@@ -7797,6 +7800,10 @@ static void ggml_compute_forward_out_prod_f32(
7797
  const int64_t blck_0 = MAX(GGML_VEC_MAD_UNROLL, 32);
7798
  const int64_t blck_1 = 16;
7799
 
 
 
 
 
7800
  for (int64_t bir = ir0; bir < ir1; bir += blck_1) {
7801
  const int64_t bir1 = MIN(bir + blck_1, ir1);
7802
  for (int64_t bi01 = 0; bi01 < ne01; bi01 += blck_0) {
@@ -7807,8 +7814,8 @@ static void ggml_compute_forward_out_prod_f32(
7807
  const int64_t i2 = (ir - i3*ne2*ne1)/ne1;
7808
  const int64_t i1 = (ir - i3*ne2*ne1 - i2*ne1);
7809
 
7810
- const int64_t i02 = i2;
7811
- const int64_t i03 = i3;
7812
 
7813
  //const int64_t i10 = i1;
7814
  const int64_t i12 = i2;
@@ -8906,9 +8913,9 @@ static void ggml_compute_forward_soft_max(
8906
  }
8907
 
8908
 
8909
- // ggml_compute_forward_soft_max_back
8910
 
8911
- static void ggml_compute_forward_soft_max_back_f32(
8912
  const struct ggml_compute_params * params,
8913
  struct ggml_tensor * dst) {
8914
 
@@ -8921,6 +8928,14 @@ static void ggml_compute_forward_soft_max_back_f32(
8921
  GGML_ASSERT(ggml_are_same_shape(src0, dst));
8922
  GGML_ASSERT(ggml_are_same_shape(src1, dst));
8923
 
 
 
 
 
 
 
 
 
8924
  // TODO: handle transposed/permuted matrices
8925
 
8926
  const int ith = params->ith;
@@ -8969,10 +8984,11 @@ static void ggml_compute_forward_soft_max_back_f32(
8969
 
8970
  // linear runtime, no additional memory
8971
  float dot_y_dy = 0;
8972
- ggml_vec_dot_f32 (nc, &dot_y_dy, 0, y, 0, dy, 0, 1);
8973
- ggml_vec_cpy_f32 (nc, dx, dy);
8974
- ggml_vec_acc1_f32(nc, dx, -dot_y_dy);
8975
- ggml_vec_mul_f32 (nc, dx, dx, y);
 
8976
 
8977
  #ifndef NDEBUG
8978
  for (int i = 0; i < nc; ++i) {
@@ -8983,7 +8999,7 @@ static void ggml_compute_forward_soft_max_back_f32(
8983
  }
8984
  }
8985
 
8986
- static void ggml_compute_forward_soft_max_back(
8987
  const struct ggml_compute_params * params,
8988
  struct ggml_tensor * dst) {
8989
 
@@ -8992,7 +9008,7 @@ static void ggml_compute_forward_soft_max_back(
8992
  switch (src0->type) {
8993
  case GGML_TYPE_F32:
8994
  {
8995
- ggml_compute_forward_soft_max_back_f32(params, dst);
8996
  } break;
8997
  default:
8998
  {
@@ -9985,9 +10001,10 @@ static void ggml_compute_forward_im2col_back_f32(
9985
  const struct ggml_compute_params * params,
9986
  struct ggml_tensor * dst) {
9987
 
9988
- const struct ggml_tensor * src0 = dst->src[0];
9989
- const struct ggml_tensor * src1 = dst->src[1];
9990
 
 
9991
  GGML_ASSERT(src1->type == GGML_TYPE_F32);
9992
  GGML_ASSERT( dst->type == GGML_TYPE_F32);
9993
 
@@ -10009,11 +10026,11 @@ static void ggml_compute_forward_im2col_back_f32(
10009
  const int64_t IH = is_2D ? ne1 : 1;
10010
  const int64_t IW = ne0;
10011
 
10012
- const int64_t KH = is_2D ? ne01 : 1;
10013
- const int64_t KW = ne00;
10014
 
10015
- const int64_t OH = is_2D ? ne12 : 1;
10016
- const int64_t OW = ne11;
10017
 
10018
  int ofs0 = is_2D ? nb3 : nb2;
10019
  int ofs1 = is_2D ? nb2 : nb1;
@@ -10059,9 +10076,9 @@ static void ggml_compute_forward_im2col_back_f32(
10059
  continue;
10060
  }
10061
 
10062
- const float * const src_data = (const float *) src1->data
10063
  + (in*OH*OW + ioh*OW + iow)*(IC*KH*KW); // [IC, KH, KW]
10064
- grad += src_data[iic*(KH*KW) + ikh*KW + ikw];
10065
  }
10066
  }
10067
  float * dst_data = (float *)((char *) wdata + (in*ofs0 + iic*ofs1)); // [IH, IW]
@@ -12484,22 +12501,22 @@ static void ggml_compute_forward_cross_entropy_loss_back_f32(
12484
  const struct ggml_compute_params * params,
12485
  struct ggml_tensor * dst) {
12486
 
12487
- const struct ggml_tensor * src0 = dst->src[0];
12488
- const struct ggml_tensor * src1 = dst->src[1];
12489
- const struct ggml_tensor * opt0 = dst->src[2];
12490
 
12491
  GGML_ASSERT(ggml_is_contiguous(dst));
12492
- GGML_ASSERT(ggml_is_contiguous(src0));
12493
- GGML_ASSERT(ggml_is_contiguous(src1));
12494
- GGML_ASSERT(ggml_is_contiguous(opt0));
12495
- GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
12496
 
12497
  const int64_t ith = params->ith;
12498
  const int64_t nth = params->nth;
12499
 
12500
  // TODO: handle transposed/permuted matrices
12501
- const int64_t nc = src0->ne[0];
12502
- const int64_t nr = ggml_nrows(src0);
12503
 
12504
  // rows per thread
12505
  const int64_t dr = (nr + nth - 1)/nth;
@@ -12508,12 +12525,12 @@ static void ggml_compute_forward_cross_entropy_loss_back_f32(
12508
  const int64_t ir0 = dr*ith;
12509
  const int64_t ir1 = MIN(ir0 + dr, nr);
12510
 
12511
- const float d_by_nr = ((const float *) opt0->data)[0] / (float) nr;
12512
 
12513
  for (int64_t i1 = ir0; i1 < ir1; i1++) {
12514
- float * ds0 = (float *)((char *) dst->data + i1*dst->nb[1]);
12515
- float * s0 = (float *)((char *) src0->data + i1*src0->nb[1]);
12516
- float * s1 = (float *)((char *) src1->data + i1*src1->nb[1]);
12517
 
12518
  #ifndef NDEBUG
12519
  for (int64_t i = 0; i < nc; ++i) {
@@ -12526,11 +12543,11 @@ static void ggml_compute_forward_cross_entropy_loss_back_f32(
12526
  // soft_max
12527
  float max = -INFINITY;
12528
  ggml_vec_max_f32(nc, &max, s0);
12529
- ggml_float sum = ggml_vec_soft_max_f32(nc, ds0, s0, max);
12530
  assert(sum > 0.0);
12531
  ggml_vec_scale_f32(nc, ds0, 1.0/sum);
12532
 
12533
- // grad(src0) = (softmax(src0) - src1) * grad(cross_entropy_loss(src0, src1)) / nr
12534
  ggml_vec_sub_f32(nc, ds0, ds0, s1);
12535
  ggml_vec_scale_f32(nc, ds0, d_by_nr);
12536
 
@@ -12827,7 +12844,7 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
12827
  } break;
12828
  case GGML_OP_SOFT_MAX_BACK:
12829
  {
12830
- ggml_compute_forward_soft_max_back(params, tensor);
12831
  } break;
12832
  case GGML_OP_ROPE:
12833
  {
 
6691
  const struct ggml_compute_params * params,
6692
  struct ggml_tensor * dst) {
6693
 
6694
+ const struct ggml_tensor * grad = dst->src[0];
6695
+ const struct ggml_tensor * src1 = dst->src[1];
6696
 
6697
  assert(ggml_is_contiguous_1(grad));
6698
+ assert(ggml_is_contiguous_1(src1));
6699
  assert(ggml_is_contiguous_1(dst));
6700
+ assert(ggml_are_same_shape(src1, dst));
6701
+ assert(ggml_are_same_shape(src1, grad));
6702
 
6703
  const int ith = params->ith;
6704
  const int nth = params->nth;
6705
 
6706
+ const int nc = src1->ne[0];
6707
+ const int nr = ggml_nrows(src1);
6708
 
6709
  // rows per thread
6710
  const int dr = (nr + nth - 1)/nth;
 
6716
  for (int i1 = ir0; i1 < ir1; i1++) {
6717
  ggml_vec_silu_backward_f32(nc,
6718
  (float *) ((char *) dst->data + i1*( dst->nb[1])),
6719
+ (float *) ((char *) src1->data + i1*(src1->nb[1])),
6720
  (float *) ((char *) grad->data + i1*(grad->nb[1])));
6721
 
6722
  #ifndef NDEBUG
 
6895
  float eps;
6896
  memcpy(&eps, dst->op_params, sizeof(float));
6897
 
6898
+ GGML_ASSERT(eps >= 0.0f);
6899
 
6900
  // TODO: optimize
6901
  for (int64_t i03 = 0; i03 < ne03; i03++) {
 
6966
  float eps;
6967
  memcpy(&eps, dst->op_params, sizeof(float));
6968
 
6969
+ GGML_ASSERT(eps >= 0.0f);
6970
 
6971
  // TODO: optimize
6972
  for (int64_t i03 = 0; i03 < ne03; i03++) {
 
7018
  const struct ggml_compute_params * params,
7019
  struct ggml_tensor * dst) {
7020
 
7021
+ const struct ggml_tensor * src0 = dst->src[0]; // gradients from forward pass output
7022
+ const struct ggml_tensor * src1 = dst->src[1]; // src1 from forward pass
7023
 
7024
  GGML_ASSERT(ggml_are_same_shape(src0, dst) && ggml_are_same_shape(src0, src1));
7025
 
7026
  GGML_ASSERT(src0->nb[0] == sizeof(float));
7027
+ GGML_ASSERT(src1->nb[0] == sizeof(float));
7028
 
7029
  const int ith = params->ith;
7030
  const int nth = params->nth;
 
7043
  const int64_t i12 = i02;
7044
  const int64_t i13 = i03;
7045
 
7046
+ const float * dz = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
7047
+ const float * x = (float *) ((char *) src1->data + i11*nb11 + i12*nb12 + i13*nb13);
7048
 
7049
  ggml_float sum_xx = 0.0;
7050
  ggml_float sum_xdz = 0.0;
 
7067
  {
7068
  // z = rms_norm(x)
7069
  //
7070
+ // rms_norm(src1) =
7071
  // scale(
7072
+ // src1,
7073
  // div(
7074
  // 1,
7075
  // sqrt(
 
7077
  // scale(
7078
  // sum(
7079
  // sqr(
7080
+ // src1)),
7081
  // (1.0/N)),
7082
  // eps))));
7083
 
7084
  // postorder:
7085
  // ## op args grad
7086
+ // 00 param src1 grad[#00]
7087
  // 01 const 1
7088
  // 02 sqr (#00) grad[#02]
7089
  // 03 sum (#02) grad[#03]
 
7160
  // dx := scale(dx, rrms)
7161
  float * dx = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
7162
 
7163
+ // dx[i00] = (x*(-sum_xdz/sum_eps) + dz) / sqrtf(mean_eps)
7164
  ggml_vec_cpy_f32 (ne00, dx, x);
7165
  // ggml_vec_scale_f32(ne00, dx, -mean_xdz/mean_eps);
7166
  ggml_vec_scale_f32(ne00, dx, (float)(-sum_xdz)/sum_eps);
 
7752
  const int ith = params->ith;
7753
  const int nth = params->nth;
7754
 
7755
+ GGML_ASSERT(ne0 == ne00);
7756
+ GGML_ASSERT(ne1 == ne10);
7757
+ GGML_ASSERT(ne2 == ne12);
7758
+ GGML_ASSERT(ne3 == ne13);
7759
+
7760
+ GGML_ASSERT(ne2 % ne02 == 0);
7761
+ GGML_ASSERT(ne3 % ne03 == 0);
7762
 
7763
  // we don't support permuted src0 or src1
7764
  GGML_ASSERT(nb00 == sizeof(float));
 
7800
  const int64_t blck_0 = MAX(GGML_VEC_MAD_UNROLL, 32);
7801
  const int64_t blck_1 = 16;
7802
 
7803
+ // dps == dst per src0, used for group query attention
7804
+ const int64_t dps2 = ne2 / ne02;
7805
+ const int64_t dps3 = ne3 / ne03;
7806
+
7807
  for (int64_t bir = ir0; bir < ir1; bir += blck_1) {
7808
  const int64_t bir1 = MIN(bir + blck_1, ir1);
7809
  for (int64_t bi01 = 0; bi01 < ne01; bi01 += blck_0) {
 
7814
  const int64_t i2 = (ir - i3*ne2*ne1)/ne1;
7815
  const int64_t i1 = (ir - i3*ne2*ne1 - i2*ne1);
7816
 
7817
+ const int64_t i02 = i2 / dps2;
7818
+ const int64_t i03 = i3 / dps3;
7819
 
7820
  //const int64_t i10 = i1;
7821
  const int64_t i12 = i2;
 
8913
  }
8914
 
8915
 
8916
+ // ggml_compute_forward_soft_max_ext_back
8917
 
8918
+ static void ggml_compute_forward_soft_max_ext_back_f32(
8919
  const struct ggml_compute_params * params,
8920
  struct ggml_tensor * dst) {
8921
 
 
8928
  GGML_ASSERT(ggml_are_same_shape(src0, dst));
8929
  GGML_ASSERT(ggml_are_same_shape(src1, dst));
8930
 
8931
+ float scale = 1.0f;
8932
+ float max_bias = 0.0f;
8933
+
8934
+ memcpy(&scale, (const float *) dst->op_params + 0, sizeof(float));
8935
+ memcpy(&max_bias, (const float *) dst->op_params + 1, sizeof(float));
8936
+
8937
+ GGML_ASSERT(max_bias == 0.0f);
8938
+
8939
  // TODO: handle transposed/permuted matrices
8940
 
8941
  const int ith = params->ith;
 
8984
 
8985
  // linear runtime, no additional memory
8986
  float dot_y_dy = 0;
8987
+ ggml_vec_dot_f32 (nc, &dot_y_dy, 0, y, 0, dy, 0, 1);
8988
+ ggml_vec_cpy_f32 (nc, dx, dy);
8989
+ ggml_vec_acc1_f32 (nc, dx, -dot_y_dy);
8990
+ ggml_vec_mul_f32 (nc, dx, dx, y);
8991
+ ggml_vec_scale_f32(nc, dx, scale);
8992
 
8993
  #ifndef NDEBUG
8994
  for (int i = 0; i < nc; ++i) {
 
8999
  }
9000
  }
9001
 
9002
+ static void ggml_compute_forward_soft_max_ext_back(
9003
  const struct ggml_compute_params * params,
9004
  struct ggml_tensor * dst) {
9005
 
 
9008
  switch (src0->type) {
9009
  case GGML_TYPE_F32:
9010
  {
9011
+ ggml_compute_forward_soft_max_ext_back_f32(params, dst);
9012
  } break;
9013
  default:
9014
  {
 
10001
  const struct ggml_compute_params * params,
10002
  struct ggml_tensor * dst) {
10003
 
10004
+ const struct ggml_tensor * src0 = dst->src[0]; // gradients of forward pass output
10005
+ const struct ggml_tensor * src1 = dst->src[1]; // convolution kernel
10006
 
10007
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
10008
  GGML_ASSERT(src1->type == GGML_TYPE_F32);
10009
  GGML_ASSERT( dst->type == GGML_TYPE_F32);
10010
 
 
10026
  const int64_t IH = is_2D ? ne1 : 1;
10027
  const int64_t IW = ne0;
10028
 
10029
+ const int64_t KH = is_2D ? ne11 : 1;
10030
+ const int64_t KW = ne10;
10031
 
10032
+ const int64_t OH = is_2D ? ne02 : 1;
10033
+ const int64_t OW = ne01;
10034
 
10035
  int ofs0 = is_2D ? nb3 : nb2;
10036
  int ofs1 = is_2D ? nb2 : nb1;
 
10076
  continue;
10077
  }
10078
 
10079
+ const float * const grad_in = (const float *) src0->data
10080
  + (in*OH*OW + ioh*OW + iow)*(IC*KH*KW); // [IC, KH, KW]
10081
+ grad += grad_in[iic*(KH*KW) + ikh*KW + ikw];
10082
  }
10083
  }
10084
  float * dst_data = (float *)((char *) wdata + (in*ofs0 + iic*ofs1)); // [IH, IW]
 
12501
  const struct ggml_compute_params * params,
12502
  struct ggml_tensor * dst) {
12503
 
12504
+ const struct ggml_tensor * grad = dst->src[0]; // gradient of forward pass output
12505
+ const struct ggml_tensor * src0f = dst->src[1]; // src0 of forward pass
12506
+ const struct ggml_tensor * src1f = dst->src[2]; // src1 of forward pass
12507
 
12508
  GGML_ASSERT(ggml_is_contiguous(dst));
12509
+ GGML_ASSERT(ggml_is_contiguous(src0f));
12510
+ GGML_ASSERT(ggml_is_contiguous(src1f));
12511
+ GGML_ASSERT(ggml_is_contiguous(grad));
12512
+ GGML_ASSERT(ggml_are_same_shape(src0f, src1f) && ggml_are_same_shape(src0f, dst));
12513
 
12514
  const int64_t ith = params->ith;
12515
  const int64_t nth = params->nth;
12516
 
12517
  // TODO: handle transposed/permuted matrices
12518
+ const int64_t nc = src0f->ne[0];
12519
+ const int64_t nr = ggml_nrows(src0f);
12520
 
12521
  // rows per thread
12522
  const int64_t dr = (nr + nth - 1)/nth;
 
12525
  const int64_t ir0 = dr*ith;
12526
  const int64_t ir1 = MIN(ir0 + dr, nr);
12527
 
12528
+ const float d_by_nr = ((const float *) grad->data)[0] / (float) nr;
12529
 
12530
  for (int64_t i1 = ir0; i1 < ir1; i1++) {
12531
+ float * ds0 = (float *)((char *) dst->data + i1*dst->nb[1]);
12532
+ const float * s0 = (const float *)((const char *) src0f->data + i1*src0f->nb[1]);
12533
+ const float * s1 = (const float *)((const char *) src1f->data + i1*src1f->nb[1]);
12534
 
12535
  #ifndef NDEBUG
12536
  for (int64_t i = 0; i < nc; ++i) {
 
12543
  // soft_max
12544
  float max = -INFINITY;
12545
  ggml_vec_max_f32(nc, &max, s0);
12546
+ const ggml_float sum = ggml_vec_soft_max_f32(nc, ds0, s0, max);
12547
  assert(sum > 0.0);
12548
  ggml_vec_scale_f32(nc, ds0, 1.0/sum);
12549
 
12550
+ // grad(src0f) = (softmax(src0f) - src1f) * grad(cross_entropy_loss(src0f, src1f)) / nr
12551
  ggml_vec_sub_f32(nc, ds0, ds0, s1);
12552
  ggml_vec_scale_f32(nc, ds0, d_by_nr);
12553
 
 
12844
  } break;
12845
  case GGML_OP_SOFT_MAX_BACK:
12846
  {
12847
+ ggml_compute_forward_soft_max_ext_back(params, tensor);
12848
  } break;
12849
  case GGML_OP_ROPE:
12850
  {
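
Note: per row, the vector-op sequence in ggml_compute_forward_soft_max_ext_back_f32 above (vec_dot, vec_cpy, vec_acc1, vec_mul, vec_scale) implements the backward of a scaled softmax. With y = softmax(scale * x) and incoming gradient dy, the value written to dx is

    \[
    \frac{\partial L}{\partial x_j} \;=\; \mathrm{scale}\cdot y_j\Bigl(dy_j \;-\; \sum_i y_i\, dy_i\Bigr),
    \]

and max_bias is asserted to be 0, i.e. the ALiBi bias term of the forward op has no backward here.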
ggml/src/ggml-cpu/ggml-cpu.cpp CHANGED
@@ -403,6 +403,16 @@ static bool ggml_backend_cpu_device_supports_op(ggml_backend_dev_t dev, const st
403
  op->type != GGML_TYPE_IQ1_M; // missing type_traits.from_float
404
  case GGML_OP_MUL_MAT:
405
  return src1->type == GGML_TYPE_F32 || src1->type == ggml_get_type_traits_cpu(src0->type)->vec_dot_type;
 
 
 
 
 
 
 
 
 
 
406
  case GGML_OP_IM2COL_BACK:
407
  return src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32;
408
  case GGML_OP_OUT_PROD:
 
403
  op->type != GGML_TYPE_IQ1_M; // missing type_traits.from_float
404
  case GGML_OP_MUL_MAT:
405
  return src1->type == GGML_TYPE_F32 || src1->type == ggml_get_type_traits_cpu(src0->type)->vec_dot_type;
406
+ case GGML_OP_SOFT_MAX_BACK: {
407
+ if (op->src[0]->type != GGML_TYPE_F32 || op->src[1]->type != GGML_TYPE_F32) {
408
+ return false;
409
+ }
410
+ float max_bias = 0.0f;
411
+
412
+ memcpy(&max_bias, (const float *) op->op_params + 1, sizeof(float));
413
+
414
+ return max_bias == 0.0f;
415
+ }
416
  case GGML_OP_IM2COL_BACK:
417
  return src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32;
418
  case GGML_OP_OUT_PROD:
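
Note: the new SOFT_MAX_BACK case reads max_bias from float offset 1 of op_params, the same layout the soft_max ops use (scale at offset 0, max_bias at offset 1, as in the CPU backward above). A standalone sketch of that memcpy-based packing convention (not ggml code, layout assumed as stated):

    #include <stdint.h>
    #include <string.h>

    static int32_t op_params[16]; // stand-in for ggml_tensor::op_params

    static void pack_soft_max_params(float scale, float max_bias) {
        memcpy((float *) op_params + 0, &scale,    sizeof(float));
        memcpy((float *) op_params + 1, &max_bias, sizeof(float));
    }

    static float read_max_bias(void) {
        float max_bias = 0.0f;
        memcpy(&max_bias, (const float *) op_params + 1, sizeof(float));
        return max_bias; // the backends only accept max_bias == 0.0f for the backward
    }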
ggml/src/ggml-cuda/cross-entropy-loss.cu CHANGED
@@ -5,95 +5,89 @@
5
  #include <cmath>
6
  #include <cstdint>
7
 
8
- static __global__ void cross_entropy_loss_f32(const float * logits, const float * labels, float * dst, const int nclasses, const int k) {
9
- const int warp_id = threadIdx.x / WARP_SIZE;
10
- const int lane_id = threadIdx.x % WARP_SIZE;
11
- const int i0 = blockDim.x*blockIdx.x + warp_id*WARP_SIZE;
12
-
13
- const int ne_tmp = WARP_SIZE*nclasses;
14
-
15
- extern __shared__ float tmp_all[];
16
- float * tmp_logits = tmp_all + (2*warp_id + 0)*ne_tmp;
17
- float * tmp_labels = tmp_all + (2*warp_id + 1)*ne_tmp;
18
-
19
- // Each warp first loads ne_tmp logits/labels into shared memory:
20
- for (int i = lane_id; i < ne_tmp; i += WARP_SIZE) {
21
- const int ig = i0*nclasses + i; // ig == i global
22
-
23
- tmp_logits[i] = ig < k*nclasses ? logits[ig] : 0.0f;
24
- tmp_labels[i] = ig < k*nclasses ? labels[ig] : 0.0f;
25
- }
26
 
27
- // Each thread in the warp then calculates the cross entropy loss for a single row.
28
- // TODO: pad in order to avoid shared memory bank conflicts.
29
 
30
  // Find maximum for softmax:
31
- float max = -INFINITY;
32
- for (int i = 0; i < nclasses; ++i) {
33
- max = fmaxf(max, tmp_logits[lane_id*nclasses + i]);
 
 
 
 
 
34
  }
 
35
 
36
  // Calculate log(softmax(logits)) which is just logits - max:
37
  float sum = 0.0f;
38
- for (int i = 0; i < nclasses; ++i) {
39
- float val = tmp_logits[lane_id*nclasses + i] - max;
40
- sum += expf(val);
41
- tmp_logits[lane_id*nclasses + i] = val;
42
  }
 
43
  sum = logf(sum);
44
 
45
  // log(exp(logits - max) / sum) = (logits - max) - log(sum)
46
  float loss = 0.0f;
47
- for (int i = 0; i < nclasses; ++i) {
48
- loss += (tmp_logits[lane_id*nclasses + i] - sum) * tmp_labels[lane_id*nclasses + i];
 
49
  }
50
  loss = -warp_reduce_sum(loss) / (float)k;
51
 
52
- __syncthreads();
53
-
54
- if (lane_id == 0) {
55
- tmp_all[warp_id] = loss;
56
- }
57
-
58
- __syncthreads();
59
-
60
- if (warp_id != 0) {
61
- return;
62
- }
63
-
64
- loss = lane_id < CUDA_CROSS_ENTROPY_LOSS_BLOCK_SIZE/WARP_SIZE ? tmp_all[lane_id] : 0.0f;
65
- loss = warp_reduce_sum(loss);
66
-
67
- if (lane_id != 0) {
68
  return;
69
  }
70
 
71
  dst[blockIdx.x] = loss;
72
  }
73
 
74
- static __global__ void cross_entropy_loss_back_f32(const float * logits, const float * labels, const float * loss, float * dst, const int nclasses) {
 
 
 
75
  extern __shared__ float tmp[];
76
 
 
 
 
 
77
  float maxval = -INFINITY;
78
  for (int i = threadIdx.x; i < nclasses; i += WARP_SIZE) {
79
- const float val = logits[blockIdx.x*nclasses + i];
80
  maxval = fmaxf(maxval, val);
81
- tmp[i] = val;
 
 
 
82
  }
83
  maxval = warp_reduce_max(maxval);
84
 
85
  float sum = 0.0f;
86
  for (int i = threadIdx.x; i < nclasses; i += WARP_SIZE) {
87
- const float val = expf(tmp[i] - maxval);
88
  sum += val;
89
- tmp[i] = val;
 
 
 
 
 
90
  }
91
  sum = warp_reduce_sum(sum);
92
  const float sm_scale = 1.0f/sum;
93
 
94
- const float d_by_nrows = *loss/gridDim.x;
95
  for (int i = threadIdx.x; i < nclasses; i += WARP_SIZE) {
96
- dst[blockIdx.x*nclasses + i] = (tmp[i]*sm_scale - labels[blockIdx.x*nclasses + i])*d_by_nrows;
 
97
  }
98
  }
99
 
@@ -119,48 +113,77 @@ void ggml_cuda_cross_entropy_loss(ggml_backend_cuda_context & ctx, ggml_tensor *
119
  ggml_cuda_pool & pool = ctx.pool();
120
  cudaStream_t stream = ctx.stream();
121
 
122
- const dim3 blocks_dim(CUDA_CROSS_ENTROPY_LOSS_BLOCK_SIZE, 1, 1);
123
- const dim3 blocks_num((nrows + CUDA_CROSS_ENTROPY_LOSS_BLOCK_SIZE - 1) / CUDA_CROSS_ENTROPY_LOSS_BLOCK_SIZE, 1, 1);
124
- const int shmem = 2*CUDA_CROSS_ENTROPY_LOSS_BLOCK_SIZE*ne00*sizeof(float);
 
 
 
125
 
126
  ggml_cuda_pool_alloc<float> dst_tmp(pool, blocks_num.x);
127
 
128
- cross_entropy_loss_f32<<<blocks_num, blocks_dim, shmem, stream>>>(src0_d, src1_d, dst_tmp.ptr, ne00, nrows);
 
 
 
 
 
 
 
 
 
 
 
 
129
 
130
  // Combine results from individual blocks:
131
  sum_f32_cuda(pool, dst_tmp.ptr, dst_d, blocks_num.x, stream);
132
  }
133
 
134
  void ggml_cuda_cross_entropy_loss_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
135
- const ggml_tensor * src0 = dst->src[0];
136
- const ggml_tensor * src1 = dst->src[1];
137
- const ggml_tensor * opt0 = dst->src[2];
138
-
139
- GGML_ASSERT(src0->type == GGML_TYPE_F32);
140
- GGML_ASSERT(src1->type == GGML_TYPE_F32);
141
- GGML_ASSERT(opt0->type == GGML_TYPE_F32);
142
- GGML_ASSERT( dst->type == GGML_TYPE_F32);
143
-
144
- GGML_ASSERT(ggml_is_contiguous(src0));
145
- GGML_ASSERT(ggml_is_contiguous(src1));
146
- GGML_ASSERT(ggml_is_contiguous(opt0));
147
  GGML_ASSERT(ggml_is_contiguous(dst));
148
- GGML_ASSERT(ggml_are_same_shape(src0, src1));
149
- GGML_ASSERT(ggml_are_same_shape(src0, dst));
150
 
151
- const int64_t ne00 = src0->ne[0];
152
- const int64_t nrows = ggml_nrows(src0);
153
 
154
- const float * src0_d = (const float *) src0->data;
155
- const float * src1_d = (const float *) src1->data;
156
- const float * opt0_d = (const float *) opt0->data;
157
- float * dst_d = (float *) dst->data;
158
 
159
  cudaStream_t stream = ctx.stream();
160
 
161
  const dim3 blocks_dim(WARP_SIZE, 1, 1);
162
  const dim3 blocks_num(nrows, 1, 1);
163
- const int shmem = ne00*sizeof(float);
164
-
165
- cross_entropy_loss_back_f32<<<blocks_num, blocks_dim, shmem, stream>>>(src0_d, src1_d, opt0_d, dst_d, ne00);
 
 
 
 
 
 
 
 
 
 
 
 
 
 
166
  }
 
5
  #include <cmath>
6
  #include <cstdint>
7
 
8
+ template <bool use_shared>
9
+ static __global__ void cross_entropy_loss_f32(
10
+ const float * __restrict__ logits, const float * __restrict__ labels, float * __restrict__ dst, const int nclasses, const int k) {
11
+ extern __shared__ float tmp[];
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
 
13
+ logits += int64_t(blockIdx.x)*nclasses;
14
+ labels += int64_t(blockIdx.x)*nclasses;
15
 
16
  // Find maximum for softmax:
17
+ float max_logit = -INFINITY;
18
+ for (int i = threadIdx.x; i < nclasses; i += WARP_SIZE) {
19
+ const float val = logits[i];
20
+ max_logit = fmaxf(max_logit, val);
21
+
22
+ if (use_shared) {
23
+ tmp[i] = val;
24
+ }
25
  }
26
+ max_logit = warp_reduce_max(max_logit);
27
 
28
  // Calculate log(softmax(logits)) which is just logits - max:
29
  float sum = 0.0f;
30
+ for (int i = threadIdx.x; i < nclasses; i += WARP_SIZE) {
31
+ const float logit_i = use_shared ? tmp[i] : logits[i];
32
+ sum += expf(logit_i - max_logit);
 
33
  }
34
+ sum = warp_reduce_sum(sum);
35
  sum = logf(sum);
36
 
37
  // log(exp(logits - max) / sum) = (logits - max) - log(sum)
38
  float loss = 0.0f;
39
+ for (int i = threadIdx.x; i < nclasses; i += WARP_SIZE) {
40
+ const float logit_i = use_shared ? tmp[i] : logits[i];
41
+ loss += (logit_i - max_logit - sum) * labels[i];
42
  }
43
  loss = -warp_reduce_sum(loss) / (float)k;
44
 
45
+ if (threadIdx.x != 0) {
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
  return;
47
  }
48
 
49
  dst[blockIdx.x] = loss;
50
  }
51
 
52
+ template <bool use_shared>
53
+ static __global__ void cross_entropy_loss_back_f32(
54
+ const float * __restrict__ grad, const float * __restrict__ logits, const float * __restrict__ labels,
55
+ float * __restrict__ dst, const int nclasses) {
56
  extern __shared__ float tmp[];
57
 
58
+ logits += int64_t(blockIdx.x)*nclasses;
59
+ labels += int64_t(blockIdx.x)*nclasses;
60
+ dst += int64_t(blockIdx.x)*nclasses;
61
+
62
  float maxval = -INFINITY;
63
  for (int i = threadIdx.x; i < nclasses; i += WARP_SIZE) {
64
+ const float val = logits[i];
65
  maxval = fmaxf(maxval, val);
66
+
67
+ if (use_shared) {
68
+ tmp[i] = val;
69
+ }
70
  }
71
  maxval = warp_reduce_max(maxval);
72
 
73
  float sum = 0.0f;
74
  for (int i = threadIdx.x; i < nclasses; i += WARP_SIZE) {
75
+ const float val = expf((use_shared ? tmp[i] : logits[i]) - maxval);
76
  sum += val;
77
+
78
+ if (use_shared) {
79
+ tmp[i] = val;
80
+ } else {
81
+ dst[i] = val;
82
+ }
83
  }
84
  sum = warp_reduce_sum(sum);
85
  const float sm_scale = 1.0f/sum;
86
 
87
+ const float d_by_nrows = *grad/gridDim.x;
88
  for (int i = threadIdx.x; i < nclasses; i += WARP_SIZE) {
89
+ const float val = use_shared ? tmp[i] : dst[i];
90
+ dst[i] = (val*sm_scale - labels[i])*d_by_nrows;
91
  }
92
  }
93
 
 
113
  ggml_cuda_pool & pool = ctx.pool();
114
  cudaStream_t stream = ctx.stream();
115
 
116
+ const dim3 blocks_dim(WARP_SIZE, 1, 1);
117
+ const dim3 blocks_num(nrows, 1, 1);
118
+ const size_t nbytes_shared = ne00*sizeof(float);
119
+
120
+ const int id = ggml_cuda_get_device();
121
+ const size_t smpbo = ggml_cuda_info().devices[id].smpbo;
122
 
123
  ggml_cuda_pool_alloc<float> dst_tmp(pool, blocks_num.x);
124
 
125
+ if (nbytes_shared <= smpbo) {
126
+ #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
127
+ static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false};
128
+ if (!shared_memory_limit_raised[id]) {
129
+ CUDA_CHECK(cudaFuncSetAttribute(cross_entropy_loss_back_f32<true>, cudaFuncAttributeMaxDynamicSharedMemorySize, smpbo));
130
+ shared_memory_limit_raised[id] = true;
131
+ }
132
+ #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
133
+ cross_entropy_loss_f32<true><<<blocks_num, blocks_dim, nbytes_shared, stream>>>(src0_d, src1_d, dst_tmp.ptr, ne00, nrows);
134
+ } else {
135
+ cross_entropy_loss_f32<false><<<blocks_num, blocks_dim, 0, stream>>>(src0_d, src1_d, dst_tmp.ptr, ne00, nrows);
136
+ }
137
+ CUDA_CHECK(cudaGetLastError());
138
 
139
  // Combine results from individual blocks:
140
  sum_f32_cuda(pool, dst_tmp.ptr, dst_d, blocks_num.x, stream);
141
  }
142
 
143
  void ggml_cuda_cross_entropy_loss_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
144
+ const ggml_tensor * grad = dst->src[0];
145
+ const ggml_tensor * src0f = dst->src[1];
146
+ const ggml_tensor * src1f = dst->src[2];
147
+
148
+ GGML_ASSERT(src0f->type == GGML_TYPE_F32);
149
+ GGML_ASSERT(src1f->type == GGML_TYPE_F32);
150
+ GGML_ASSERT( grad->type == GGML_TYPE_F32);
151
+ GGML_ASSERT( dst->type == GGML_TYPE_F32);
152
+
153
+ GGML_ASSERT(ggml_is_scalar(grad));
154
+ GGML_ASSERT(ggml_is_contiguous(src0f));
155
+ GGML_ASSERT(ggml_is_contiguous(src1f));
156
  GGML_ASSERT(ggml_is_contiguous(dst));
157
+ GGML_ASSERT(ggml_are_same_shape(src0f, src1f));
158
+ GGML_ASSERT(ggml_are_same_shape(src0f, dst));
159
 
160
+ const int64_t ne00 = src0f->ne[0];
161
+ const int64_t nrows = ggml_nrows(src0f);
162
 
163
+ const float * grad_d = (const float *) grad->data;
164
+ const float * src0f_d = (const float *) src0f->data;
165
+ const float * src1f_d = (const float *) src1f->data;
166
+ float * dst_d = (float *) dst->data;
167
 
168
  cudaStream_t stream = ctx.stream();
169
 
170
  const dim3 blocks_dim(WARP_SIZE, 1, 1);
171
  const dim3 blocks_num(nrows, 1, 1);
172
+ const size_t nbytes_shared = ne00*sizeof(float);
173
+
174
+ const int id = ggml_cuda_get_device();
175
+ const size_t smpbo = ggml_cuda_info().devices[id].smpbo;
176
+
177
+ if (nbytes_shared <= smpbo) {
178
+ #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
179
+ static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false};
180
+ if (!shared_memory_limit_raised[id]) {
181
+ CUDA_CHECK(cudaFuncSetAttribute(cross_entropy_loss_back_f32<true>, cudaFuncAttributeMaxDynamicSharedMemorySize, smpbo));
182
+ shared_memory_limit_raised[id] = true;
183
+ }
184
+ #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
185
+ cross_entropy_loss_back_f32<true><<<blocks_num, blocks_dim, nbytes_shared, stream>>>(grad_d, src0f_d, src1f_d, dst_d, ne00);
186
+ } else {
187
+ cross_entropy_loss_back_f32<false><<<blocks_num, blocks_dim, 0, stream>>>(grad_d, src0f_d, src1f_d, dst_d, ne00);
188
+ }
189
  }
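
Note: the rewritten kernels take the scalar loss gradient as src[0] (grad) and the forward inputs as src[1]/src[2] (src0f = logits, src1f = labels), and gain a use_shared template parameter so rows whose logits do not fit into shared memory fall back to global memory. Per row of logits s and labels l, the backward matches the CPU comment:

    \[
    \frac{\partial L}{\partial s} \;=\; \bigl(\mathrm{softmax}(s) - l\bigr)\cdot \frac{\mathrm{grad}}{n_{\mathrm{rows}}},
    \]

with grad asserted to be a scalar (ggml_is_scalar).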
ggml/src/ggml-cuda/getrows.cu CHANGED
@@ -3,15 +3,15 @@
3
 
4
  template<int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
5
  static __global__ void k_get_rows(
6
- const void * src0, const int32_t * src1, dst_t * dst,
7
- int64_t ne00, /*int64_t ne01, int64_t ne02, int64_t ne03,*/
8
- /*int64_t ne10, int64_t ne11,*/ int64_t ne12, /*int64_t ne13,*/
9
- /*size_t s0,*/ size_t s1, size_t s2, size_t s3,
10
- /*size_t nb00,*/ size_t nb01, size_t nb02, size_t nb03,
11
- size_t s10, size_t s11, size_t s12/*, size_t s13*/) {
12
 
13
  const int i00 = (blockIdx.x*blockDim.x + threadIdx.x)*2;
14
- const int i10 = blockDim.y*blockIdx.y + threadIdx.y;
15
  const int i11 = (blockIdx.z*blockDim.z + threadIdx.z)/ne12;
16
  const int i12 = (blockIdx.z*blockDim.z + threadIdx.z)%ne12;
17
 
@@ -22,10 +22,10 @@ static __global__ void k_get_rows(
22
  const int i01 = src1[i10*s10 + i11*s11 + i12*s12];
23
 
24
  dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3;
25
- const void * src0_row = (const char *)src0 + i01*nb01 + i11*nb02 + i12*nb03;
26
 
27
- const int ib = i00/qk; // block index
28
- const int iqs = (i00%qk)/qr; // quant index
29
  const int iybs = i00 - i00%qk; // dst block start index
30
  const int y_offset = qr == 1 ? 1 : qk/2;
31
 
@@ -39,15 +39,15 @@ static __global__ void k_get_rows(
39
 
40
  template<typename src0_t, typename dst_t>
41
  static __global__ void k_get_rows_float(
42
- const src0_t * src0, const int32_t * src1, dst_t * dst,
43
- int64_t ne00, /*int64_t ne01, int64_t ne02, int64_t ne03,*/
44
- /*int64_t ne10, int64_t ne11,*/ int64_t ne12, /*int64_t ne13,*/
45
- /*size_t s0,*/ size_t s1, size_t s2, size_t s3,
46
- /*size_t nb00,*/ size_t nb01, size_t nb02, size_t nb03,
47
- size_t s10, size_t s11, size_t s12/*, size_t s13*/) {
48
-
49
- const int i00 = blockIdx.x*blockDim.x + threadIdx.x;
50
- const int i10 = blockDim.y*blockIdx.y + threadIdx.y;
51
  const int i11 = (blockIdx.z*blockDim.z + threadIdx.z)/ne12;
52
  const int i12 = (blockIdx.z*blockDim.z + threadIdx.z)%ne12;
53
 
@@ -58,14 +58,38 @@ static __global__ void k_get_rows_float(
58
  const int i01 = src1[i10*s10 + i11*s11 + i12*s12];
59
 
60
  dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3;
61
- const src0_t * src0_row = (const src0_t *)((const char *)src0 + i01*nb01 + i11*nb02 + i12*nb03);
62
 
63
  dst_row[i00] = src0_row[i00];
64
  }
65
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66
  template<int qk, int qr, dequantize_kernel_t dq>
67
- static void get_rows_cuda(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
68
- const void * src0_dd, const int32_t * src1_dd, float * dst_dd, cudaStream_t stream) {
 
69
 
70
  GGML_TENSOR_BINARY_OP_LOCALS
71
 
@@ -87,22 +111,25 @@ static void get_rows_cuda(const ggml_tensor * src0, const ggml_tensor * src1, gg
87
  GGML_ASSERT(ne00 % 2 == 0);
88
 
89
  k_get_rows<qk, qr, dq><<<block_nums, block_dims, 0, stream>>>(
90
- src0_dd, src1_dd, dst_dd,
91
- ne00, /*ne01, ne02, ne03,*/
92
- /*ne10, ne11,*/ ne12, /*ne13,*/
93
- /* s0,*/ s1, s2, s3,
94
- /* nb00,*/ nb01, nb02, nb03,
95
- s10, s11, s12/*, s13*/);
96
 
97
  GGML_UNUSED(dst);
98
  }
99
 
100
  template<typename src0_t>
101
- static void get_rows_cuda_float(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
102
- const src0_t * src0_dd, const int32_t * src1_dd, float * dst_dd, cudaStream_t stream) {
 
103
 
104
  GGML_TENSOR_BINARY_OP_LOCALS
105
 
 
 
106
  const dim3 block_dims(CUDA_GET_ROWS_BLOCK_SIZE, 1, 1);
107
  const int block_num_x = (ne00 + CUDA_GET_ROWS_BLOCK_SIZE - 1) / CUDA_GET_ROWS_BLOCK_SIZE;
108
  const dim3 block_nums(block_num_x, ne10, ne11*ne12);
@@ -119,12 +146,12 @@ static void get_rows_cuda_float(const ggml_tensor * src0, const ggml_tensor * sr
119
  //const size_t s13 = nb13 / ggml_element_size(src1);
120
 
121
  k_get_rows_float<<<block_nums, block_dims, 0, stream>>>(
122
- src0_dd, src1_dd, dst_dd,
123
- ne00, /*ne01, ne02, ne03,*/
124
- /*ne10, ne11,*/ ne12, /*ne13,*/
125
- /* s0,*/ s1, s2, s3,
126
- /* nb00,*/ nb01, nb02, nb03,
127
- s10, s11, s12/*, s13*/);
128
 
129
  GGML_UNUSED(dst);
130
  }
@@ -132,42 +159,41 @@ static void get_rows_cuda_float(const ggml_tensor * src0, const ggml_tensor * sr
132
  void ggml_cuda_op_get_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
133
  const ggml_tensor * src0 = dst->src[0];
134
  const ggml_tensor * src1 = dst->src[1];
135
- const float * src0_d = (const float *)src0->data;
136
- const float * src1_d = (const float *)src1->data;
137
- float * dst_d = (float *)dst->data;
138
- cudaStream_t stream = ctx.stream();
139
 
 
 
 
 
 
140
 
141
  GGML_ASSERT(src1->type == GGML_TYPE_I32);
142
- GGML_ASSERT(dst->type == GGML_TYPE_F32);
143
 
144
  GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type));
145
  GGML_ASSERT(src1->nb[0] == ggml_type_size(src1->type));
146
- GGML_ASSERT(dst->nb[0] == ggml_type_size(dst->type));
147
-
148
- const int32_t * src1_i32 = (const int32_t *) src1_d;
149
 
150
  switch (src0->type) {
151
  case GGML_TYPE_F16:
152
- get_rows_cuda_float(src0, src1, dst, (const half *)src0_d, src1_i32, dst_d, stream);
153
  break;
154
  case GGML_TYPE_F32:
155
- get_rows_cuda_float(src0, src1, dst, src0_d, src1_i32, dst_d, stream);
156
  break;
157
  case GGML_TYPE_Q4_0:
158
- get_rows_cuda<QK4_0, QR4_0, dequantize_q4_0>(src0, src1, dst, src0_d, src1_i32, dst_d, stream);
159
  break;
160
  case GGML_TYPE_Q4_1:
161
- get_rows_cuda<QK4_1, QR4_1, dequantize_q4_1>(src0, src1, dst, src0_d, src1_i32, dst_d, stream);
162
  break;
163
  case GGML_TYPE_Q5_0:
164
- get_rows_cuda<QK5_0, QR5_0, dequantize_q5_0>(src0, src1, dst, src0_d, src1_i32, dst_d, stream);
165
  break;
166
  case GGML_TYPE_Q5_1:
167
- get_rows_cuda<QK5_1, QR5_1, dequantize_q5_1>(src0, src1, dst, src0_d, src1_i32, dst_d, stream);
168
  break;
169
  case GGML_TYPE_Q8_0:
170
- get_rows_cuda<QK8_0, QR8_0, dequantize_q8_0>(src0, src1, dst, src0_d, src1_i32, dst_d, stream);
171
  break;
172
  default:
173
  // TODO: k-quants
@@ -175,3 +201,34 @@ void ggml_cuda_op_get_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
175
  break;
176
  }
177
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
 
4
  template<int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
5
  static __global__ void k_get_rows(
6
+ const void * __restrict__ src0, const int32_t * __restrict__ src1, dst_t * __restrict__ dst,
7
+ const int64_t ne00, /*const int64_t ne01, const int64_t ne02, const int64_t ne03,*/
8
+ /*const int64_t ne10, const int64_t ne11,*/ const int64_t ne12, /*const int64_t ne13,*/
9
+ /*const size_t s0,*/ const size_t s1, const size_t s2, const size_t s3,
10
+ /*const size_t nb00,*/ const size_t nb01, const size_t nb02, const size_t nb03,
11
+ const size_t s10, const size_t s11, const size_t s12/*, const size_t s13*/) {
12
 
13
  const int i00 = (blockIdx.x*blockDim.x + threadIdx.x)*2;
14
+ const int i10 = blockDim.y*blockIdx.y + threadIdx.y;
15
  const int i11 = (blockIdx.z*blockDim.z + threadIdx.z)/ne12;
16
  const int i12 = (blockIdx.z*blockDim.z + threadIdx.z)%ne12;
17
 
 
22
  const int i01 = src1[i10*s10 + i11*s11 + i12*s12];
23
 
24
  dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3;
25
+ const void * src0_row = (const char *) src0 + i01*nb01 + i11*nb02 + i12*nb03;
26
 
27
+ const int ib = i00/qk; // block index
28
+ const int iqs = (i00%qk)/qr; // quant index
29
  const int iybs = i00 - i00%qk; // dst block start index
30
  const int y_offset = qr == 1 ? 1 : qk/2;
31
 
 
39
 
40
  template<typename src0_t, typename dst_t>
41
  static __global__ void k_get_rows_float(
42
+ const src0_t * __restrict__ src0, const int32_t * __restrict__ src1, dst_t * __restrict__ dst,
43
+ const int64_t ne00, /*const int64_t ne01, const int64_t ne02, const int64_t ne03,*/
44
+ /*const int64_t ne10, const int64_t ne11,*/ const int64_t ne12, /*const int64_t ne13,*/
45
+ /*const size_t s0,*/ const size_t s1, const size_t s2, const size_t s3,
46
+ /*const size_t nb00,*/ const size_t nb01, const size_t nb02, const size_t nb03,
47
+ const size_t s10, const size_t s11, const size_t s12/*, const size_t s13*/) {
48
+
49
+ const int i00 = blockIdx.x*blockDim.x + threadIdx.x;
50
+ const int i10 = blockDim.y*blockIdx.y + threadIdx.y;
51
  const int i11 = (blockIdx.z*blockDim.z + threadIdx.z)/ne12;
52
  const int i12 = (blockIdx.z*blockDim.z + threadIdx.z)%ne12;
53
 
 
58
  const int i01 = src1[i10*s10 + i11*s11 + i12*s12];
59
 
60
  dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3;
61
+ const src0_t * src0_row = (const src0_t *)((const char *) src0 + i01*nb01 + i11*nb02 + i12*nb03);
62
 
63
  dst_row[i00] = src0_row[i00];
64
  }
65
 
66
+ template<typename grad_t, typename dst_t>
67
+ static __global__ void k_get_rows_back_float(
68
+ const grad_t * __restrict__ grad, const int32_t * __restrict__ rows, dst_t * __restrict__ dst, const int64_t ncols, const int64_t nrows_grad) {
69
+ const int col = blockIdx.x*blockDim.x + threadIdx.x;
70
+
71
+ if (col >= ncols) {
72
+ return;
73
+ }
74
+
75
+ const int dst_row = blockIdx.y*blockDim.y + threadIdx.y;
76
+
77
+ float sum = 0.0f;
78
+
79
+ for (int64_t i = 0; i < nrows_grad; ++i) {
80
+ if (rows[i] != dst_row) {
81
+ continue;
82
+ }
83
+ sum += grad[i*ncols + col];
84
+ }
85
+
86
+ dst[dst_row*ncols + col] = sum;
87
+ }
88
+
89
  template<int qk, int qr, dequantize_kernel_t dq>
90
+ static void get_rows_cuda(
91
+ const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
92
+ const void * src0_dd, const int32_t * src1_dd, float * dst_dd, cudaStream_t stream) {
93
 
94
  GGML_TENSOR_BINARY_OP_LOCALS
95
 
 
111
  GGML_ASSERT(ne00 % 2 == 0);
112
 
113
  k_get_rows<qk, qr, dq><<<block_nums, block_dims, 0, stream>>>(
114
+ src0_dd, src1_dd, dst_dd,
115
+ ne00, /*ne01, ne02, ne03,*/
116
+ /*ne10, ne11,*/ ne12, /*ne13,*/
117
+ /* s0,*/ s1, s2, s3,
118
+ /* nb00,*/ nb01, nb02, nb03,
119
+ s10, s11, s12/*, s13*/);
120
 
121
  GGML_UNUSED(dst);
122
  }
123
 
124
  template<typename src0_t>
125
+ static void get_rows_cuda_float(
126
+ const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
127
+ const src0_t * src0_dd, const int32_t * src1_dd, float * dst_dd, cudaStream_t stream) {
128
 
129
  GGML_TENSOR_BINARY_OP_LOCALS
130
 
131
+ GGML_ASSERT(ne13 == 1);
132
+
133
  const dim3 block_dims(CUDA_GET_ROWS_BLOCK_SIZE, 1, 1);
134
  const int block_num_x = (ne00 + CUDA_GET_ROWS_BLOCK_SIZE - 1) / CUDA_GET_ROWS_BLOCK_SIZE;
135
  const dim3 block_nums(block_num_x, ne10, ne11*ne12);
 
146
  //const size_t s13 = nb13 / ggml_element_size(src1);
147
 
148
  k_get_rows_float<<<block_nums, block_dims, 0, stream>>>(
149
+ src0_dd, src1_dd, dst_dd,
150
+ ne00, /*ne01, ne02, ne03,*/
151
+ /*ne10, ne11,*/ ne12, /*ne13,*/
152
+ /* s0,*/ s1, s2, s3,
153
+ /* nb00,*/ nb01, nb02, nb03,
154
+ s10, s11, s12/*, s13*/);
155
 
156
  GGML_UNUSED(dst);
157
  }
 
159
  void ggml_cuda_op_get_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
160
  const ggml_tensor * src0 = dst->src[0];
161
  const ggml_tensor * src1 = dst->src[1];
 
 
 
 
162
 
163
+ const void * src0_d = (const void *) src0->data;
164
+ const int32_t * src1_d = (const int32_t *) src1->data;
165
+ float * dst_d = (float *) dst->data;
166
+
167
+ cudaStream_t stream = ctx.stream();
168
 
169
  GGML_ASSERT(src1->type == GGML_TYPE_I32);
170
+ GGML_ASSERT(dst->type == GGML_TYPE_F32);
171
 
172
  GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type));
173
  GGML_ASSERT(src1->nb[0] == ggml_type_size(src1->type));
174
+ GGML_ASSERT(dst->nb[0] == ggml_type_size(dst->type));
 
 
175
 
176
  switch (src0->type) {
177
  case GGML_TYPE_F16:
178
+ get_rows_cuda_float(src0, src1, dst, (const half *) src0_d, src1_d, dst_d, stream);
179
  break;
180
  case GGML_TYPE_F32:
181
+ get_rows_cuda_float(src0, src1, dst, (const float *) src0_d, src1_d, dst_d, stream);
182
  break;
183
  case GGML_TYPE_Q4_0:
184
+ get_rows_cuda<QK4_0, QR4_0, dequantize_q4_0>(src0, src1, dst, src0_d, src1_d, dst_d, stream);
185
  break;
186
  case GGML_TYPE_Q4_1:
187
+ get_rows_cuda<QK4_1, QR4_1, dequantize_q4_1>(src0, src1, dst, src0_d, src1_d, dst_d, stream);
188
  break;
189
  case GGML_TYPE_Q5_0:
190
+ get_rows_cuda<QK5_0, QR5_0, dequantize_q5_0>(src0, src1, dst, src0_d, src1_d, dst_d, stream);
191
  break;
192
  case GGML_TYPE_Q5_1:
193
+ get_rows_cuda<QK5_1, QR5_1, dequantize_q5_1>(src0, src1, dst, src0_d, src1_d, dst_d, stream);
194
  break;
195
  case GGML_TYPE_Q8_0:
196
+ get_rows_cuda<QK8_0, QR8_0, dequantize_q8_0>(src0, src1, dst, src0_d, src1_d, dst_d, stream);
197
  break;
198
  default:
199
  // TODO: k-quants
 
201
  break;
202
  }
203
  }
204
+
205
+ void ggml_cuda_op_get_rows_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
206
+ const ggml_tensor * src0 = dst->src[0]; // gradients of forward pass output
207
+ const ggml_tensor * src1 = dst->src[1]; // src1 in forward pass
208
+
209
+ GGML_TENSOR_BINARY_OP_LOCALS
210
+
211
+ const float * src0_d = (const float *) src0->data;
212
+ const int32_t * src1_d = (const int32_t *) src1->data;
213
+ float * dst_d = (float *) dst->data;
214
+
215
+ cudaStream_t stream = ctx.stream();
216
+
217
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
218
+ GGML_ASSERT(src1->type == GGML_TYPE_I32);
219
+ GGML_ASSERT(dst->type == GGML_TYPE_F32);
220
+
221
+ GGML_ASSERT(ggml_is_contiguous(src0));
222
+ GGML_ASSERT(ggml_is_contiguous(src1));
223
+ GGML_ASSERT(ggml_is_contiguous(dst));
224
+
225
+ GGML_ASSERT(ne02*ne03 == 1);
226
+ GGML_ASSERT(ne12*ne13 == 1);
227
+ GGML_ASSERT(ne2*ne3 == 1);
228
+
229
+ const dim3 block_dims(CUDA_GET_ROWS_BACK_BLOCK_SIZE, 1, 1);
230
+ const int block_num_x = (ne00 + CUDA_GET_ROWS_BACK_BLOCK_SIZE - 1) / CUDA_GET_ROWS_BACK_BLOCK_SIZE;
231
+ const dim3 block_nums(block_num_x, ne1, 1);
232
+
233
+ k_get_rows_back_float<<<block_nums, block_dims, 0, stream>>>(src0_d, src1_d, dst_d, ne00, ne10);
234
+ }
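
Note: k_get_rows_back_float computes the backward of GET_ROWS as a per-row sum: each destination row r receives the sum of all gradient rows i whose forward index rows[i] selected r. A plain C reference for the same computation (function name hypothetical), useful as a CPU cross-check:

    #include <stdint.h>

    // grad: [nrows_grad, ncols] gradients of the forward output (row-major)
    // rows: [nrows_grad]        indices used by the forward get_rows
    // dst:  [nrows_dst,  ncols] gradient w.r.t. the gathered-from tensor
    static void get_rows_back_ref(float * dst, const float * grad, const int32_t * rows,
                                  int64_t nrows_dst, int64_t nrows_grad, int64_t ncols) {
        for (int64_t r = 0; r < nrows_dst; ++r) {
            for (int64_t c = 0; c < ncols; ++c) {
                float sum = 0.0f;
                for (int64_t i = 0; i < nrows_grad; ++i) {
                    if (rows[i] == r) {
                        sum += grad[i*ncols + c];
                    }
                }
                dst[r*ncols + c] = sum; // rows never gathered end up with 0
            }
        }
    }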
ggml/src/ggml-cuda/getrows.cuh CHANGED
@@ -1,5 +1,8 @@
1
  #include "common.cuh"
2
 
3
  #define CUDA_GET_ROWS_BLOCK_SIZE 256
 
4
 
5
  void ggml_cuda_op_get_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 
 
 
1
  #include "common.cuh"
2
 
3
  #define CUDA_GET_ROWS_BLOCK_SIZE 256
4
+ #define CUDA_GET_ROWS_BACK_BLOCK_SIZE 256
5
 
6
  void ggml_cuda_op_get_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
7
+
8
+ void ggml_cuda_op_get_rows_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
ggml/src/ggml-cuda/ggml-cuda.cu CHANGED
@@ -2003,6 +2003,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
2003
  case GGML_OP_GET_ROWS:
2004
  ggml_cuda_op_get_rows(ctx, dst);
2005
  break;
 
 
 
2006
  case GGML_OP_DUP:
2007
  ggml_cuda_dup(ctx, dst);
2008
  break;
@@ -2091,9 +2094,15 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
2091
  case GGML_OP_LEAKY_RELU:
2092
  ggml_cuda_op_leaky_relu(ctx, dst);
2093
  break;
 
 
 
2094
  case GGML_OP_RMS_NORM:
2095
  ggml_cuda_op_rms_norm(ctx, dst);
2096
  break;
 
 
 
2097
  case GGML_OP_MUL_MAT:
2098
  if (dst->src[0]->ne[3] != dst->src[1]->ne[3]) {
2099
  GGML_LOG_ERROR("%s: cannot compute %s: src0->ne[3] = %" PRId64 ", src1->ne[3] = %" PRId64 " - fallback to CPU\n", __func__, dst->name, dst->src[0]->ne[3], dst->src[1]->ne[3]);
@@ -2138,6 +2147,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
2138
  case GGML_OP_SOFT_MAX:
2139
  ggml_cuda_op_soft_max(ctx, dst);
2140
  break;
 
 
 
2141
  case GGML_OP_ROPE:
2142
  ggml_cuda_op_rope(ctx, dst);
2143
  break;
@@ -2912,7 +2924,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
2912
  }
2913
  } break;
2914
  case GGML_OP_OUT_PROD:
2915
- return op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32 && op->ne[2] == 1 && op->ne[3] == 1;
2916
  case GGML_OP_GET_ROWS:
2917
  {
2918
  switch (op->src[0]->type) {
@@ -2928,6 +2940,10 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
2928
  return false;
2929
  }
2930
  } break;
 
 
 
 
2931
  case GGML_OP_CPY:
2932
  {
2933
  ggml_type src0_type = op->src[0]->type;
@@ -3001,8 +3017,12 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
3001
  }
3002
  return false;
3003
  } break;
 
 
 
3004
  case GGML_OP_NORM:
3005
  case GGML_OP_RMS_NORM:
 
3006
  return ggml_is_contiguous(op->src[0]) && op->ne[0] % WARP_SIZE == 0;
3007
  break;
3008
  case GGML_OP_NONE:
@@ -3027,6 +3047,11 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
3027
  case GGML_OP_DIAG_MASK_INF:
3028
  case GGML_OP_SOFT_MAX:
3029
  return true;
 
 
 
 
 
3030
  case GGML_OP_ROPE:
3031
  case GGML_OP_ROPE_BACK: {
3032
  const size_t ts = ggml_type_size(op->src[0]->type);
 
2003
  case GGML_OP_GET_ROWS:
2004
  ggml_cuda_op_get_rows(ctx, dst);
2005
  break;
2006
+ case GGML_OP_GET_ROWS_BACK:
2007
+ ggml_cuda_op_get_rows_back(ctx, dst);
2008
+ break;
2009
  case GGML_OP_DUP:
2010
  ggml_cuda_dup(ctx, dst);
2011
  break;
 
2094
  case GGML_OP_LEAKY_RELU:
2095
  ggml_cuda_op_leaky_relu(ctx, dst);
2096
  break;
2097
+ case GGML_OP_SILU_BACK:
2098
+ ggml_cuda_op_silu_back(ctx, dst);
2099
+ break;
2100
  case GGML_OP_RMS_NORM:
2101
  ggml_cuda_op_rms_norm(ctx, dst);
2102
  break;
2103
+ case GGML_OP_RMS_NORM_BACK:
2104
+ ggml_cuda_op_rms_norm_back(ctx, dst);
2105
+ break;
2106
  case GGML_OP_MUL_MAT:
2107
  if (dst->src[0]->ne[3] != dst->src[1]->ne[3]) {
2108
  GGML_LOG_ERROR("%s: cannot compute %s: src0->ne[3] = %" PRId64 ", src1->ne[3] = %" PRId64 " - fallback to CPU\n", __func__, dst->name, dst->src[0]->ne[3], dst->src[1]->ne[3]);
 
2147
  case GGML_OP_SOFT_MAX:
2148
  ggml_cuda_op_soft_max(ctx, dst);
2149
  break;
2150
+ case GGML_OP_SOFT_MAX_BACK:
2151
+ ggml_cuda_op_soft_max_back(ctx, dst);
2152
+ break;
2153
  case GGML_OP_ROPE:
2154
  ggml_cuda_op_rope(ctx, dst);
2155
  break;
 
2924
  }
2925
  } break;
2926
  case GGML_OP_OUT_PROD:
2927
+ return op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32;
2928
  case GGML_OP_GET_ROWS:
2929
  {
2930
  switch (op->src[0]->type) {
 
2940
  return false;
2941
  }
2942
  } break;
2943
+ case GGML_OP_GET_ROWS_BACK:
2944
+ {
2945
+ return op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32 && op->ne[2] == 1 && op->ne[3] == 1;
2946
+ } break;
2947
  case GGML_OP_CPY:
2948
  {
2949
  ggml_type src0_type = op->src[0]->type;
 
3017
  }
3018
  return false;
3019
  } break;
3020
+ case GGML_OP_SILU_BACK:
3021
+ return ggml_is_contiguous(op->src[0]);
3022
+ break;
3023
  case GGML_OP_NORM:
3024
  case GGML_OP_RMS_NORM:
3025
+ case GGML_OP_RMS_NORM_BACK:
3026
  return ggml_is_contiguous(op->src[0]) && op->ne[0] % WARP_SIZE == 0;
3027
  break;
3028
  case GGML_OP_NONE:
 
3047
  case GGML_OP_DIAG_MASK_INF:
3048
  case GGML_OP_SOFT_MAX:
3049
  return true;
3050
+ case GGML_OP_SOFT_MAX_BACK: {
3051
+ float max_bias = 0.0f;
3052
+ memcpy(&max_bias, (const float *) op->op_params + 1, sizeof(float));
3053
+ return max_bias == 0.0f;
3054
+ }
3055
  case GGML_OP_ROPE:
3056
  case GGML_OP_ROPE_BACK: {
3057
  const size_t ts = ggml_type_size(op->src[0]->type);
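
Note: GGML_OP_OUT_PROD no longer requires op->ne[2] == 1 && op->ne[3] == 1 here; this corresponds to the batch broadcasting added to the CPU out_prod above, where with dps2 = ne2/ne02 and dps3 = ne3/ne03 (both divisions asserted exact) the src0 slice used for dst/src1 batch indices (i2, i3) is

    \[
    i_{02} = \lfloor i_2 / \mathrm{dps2} \rfloor, \qquad i_{03} = \lfloor i_3 / \mathrm{dps3} \rfloor,
    \]

so several dst batches reuse the same src0 batch, as the comment notes is needed for grouped-query attention.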
ggml/src/ggml-cuda/norm.cu CHANGED
@@ -5,20 +5,24 @@ static __global__ void norm_f32(const float * x, float * dst, const int ncols, c
5
  const int row = blockIdx.x*blockDim.y + threadIdx.y;
6
  const int tid = threadIdx.x;
7
 
8
- float2 mean_var = make_float2(0.f, 0.f);
 
 
 
9
 
10
  for (int col = tid; col < ncols; col += block_size) {
11
- const float xi = x[row*ncols + col];
12
  mean_var.x += xi;
13
  mean_var.y += xi * xi;
14
  }
15
 
16
  // sum up partial sums
17
  mean_var = warp_reduce_sum(mean_var);
18
- if (block_size > WARP_SIZE) {
 
19
  __shared__ float2 s_sum[32];
20
- int warp_id = threadIdx.x / WARP_SIZE;
21
- int lane_id = threadIdx.x % WARP_SIZE;
22
  if (lane_id == 0) {
23
  s_sum[warp_id] = mean_var;
24
  }
@@ -32,7 +36,7 @@ static __global__ void norm_f32(const float * x, float * dst, const int ncols, c
32
  const float inv_std = rsqrtf(var + eps);
33
 
34
  for (int col = tid; col < ncols; col += block_size) {
35
- dst[row*ncols + col] = (x[row*ncols + col] - mean) * inv_std;
36
  }
37
  }
38
 
@@ -40,14 +44,8 @@ template <int block_size>
40
  static __global__ void group_norm_f32(const float * x, float * dst, const int group_size, const int ne_elements, const float eps) {
41
  // blockIdx.x: num_groups idx
42
  // threadIdx.x: block_size idx
43
- int start = blockIdx.x * group_size;
44
- int end = start + group_size;
45
-
46
- start += threadIdx.x;
47
-
48
- if (end >= ne_elements) {
49
- end = ne_elements;
50
- }
51
 
52
  float tmp = 0.0f; // partial sum for thread in warp
53
 
@@ -56,10 +54,11 @@ static __global__ void group_norm_f32(const float * x, float * dst, const int gr
56
  }
57
 
58
  tmp = warp_reduce_sum(tmp);
59
- if (block_size > WARP_SIZE) {
 
60
  __shared__ float s_sum[32];
61
- int warp_id = threadIdx.x / WARP_SIZE;
62
- int lane_id = threadIdx.x % WARP_SIZE;
63
  if (lane_id == 0) {
64
  s_sum[warp_id] = tmp;
65
  }
@@ -68,11 +67,11 @@ static __global__ void group_norm_f32(const float * x, float * dst, const int gr
68
  tmp = warp_reduce_sum(tmp);
69
  }
70
 
71
- float mean = tmp / group_size;
72
  tmp = 0.0f;
73
 
74
  for (int j = start; j < end; j += block_size) {
75
- float xi = x[j] - mean;
76
  dst[j] = xi;
77
  tmp += xi * xi;
78
  }
@@ -80,8 +79,8 @@ static __global__ void group_norm_f32(const float * x, float * dst, const int gr
80
  tmp = warp_reduce_sum(tmp);
81
  if (block_size > WARP_SIZE) {
82
  __shared__ float s_sum[32];
83
- int warp_id = threadIdx.x / WARP_SIZE;
84
- int lane_id = threadIdx.x % WARP_SIZE;
85
  if (lane_id == 0) {
86
  s_sum[warp_id] = tmp;
87
  }
@@ -90,8 +89,8 @@ static __global__ void group_norm_f32(const float * x, float * dst, const int gr
90
  tmp = warp_reduce_sum(tmp);
91
  }
92
 
93
- float variance = tmp / group_size;
94
- float scale = rsqrtf(variance + eps);
95
  for (int j = start; j < end; j += block_size) {
96
  dst[j] *= scale;
97
  }
@@ -102,19 +101,23 @@ static __global__ void rms_norm_f32(const float * x, float * dst, const int ncol
102
  const int row = blockIdx.x*blockDim.y + threadIdx.y;
103
  const int tid = threadIdx.x;
104
 
 
 
 
105
  float tmp = 0.0f; // partial sum for thread in warp
106
 
107
  for (int col = tid; col < ncols; col += block_size) {
108
- const float xi = x[row*ncols + col];
109
  tmp += xi * xi;
110
  }
111
 
112
  // sum up partial sums
113
  tmp = warp_reduce_sum(tmp);
114
- if (block_size > WARP_SIZE) {
 
115
  __shared__ float s_sum[32];
116
- int warp_id = threadIdx.x / WARP_SIZE;
117
- int lane_id = threadIdx.x % WARP_SIZE;
118
  if (lane_id == 0) {
119
  s_sum[warp_id] = tmp;
120
  }
@@ -127,12 +130,63 @@ static __global__ void rms_norm_f32(const float * x, float * dst, const int ncol
127
  const float scale = rsqrtf(mean + eps);
128
 
129
  for (int col = tid; col < ncols; col += block_size) {
130
- dst[row*ncols + col] = scale * x[row*ncols + col];
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
131
  }
132
  }
133
 
134
  static void norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float eps, cudaStream_t stream) {
135
- GGML_ASSERT(ncols % WARP_SIZE == 0);
136
  if (ncols < 1024) {
137
  const dim3 block_dims(WARP_SIZE, 1, 1);
138
  norm_f32<WARP_SIZE><<<nrows, block_dims, 0, stream>>>(x, dst, ncols, eps);
@@ -142,7 +196,8 @@ static void norm_f32_cuda(const float * x, float * dst, const int ncols, const i
142
  }
143
  }
144
 
145
- static void group_norm_f32_cuda(const float * x, float * dst, const int num_groups, const float eps, const int group_size, const int ne_elements, cudaStream_t stream) {
 
146
  if (group_size < 1024) {
147
  const dim3 block_dims(WARP_SIZE, 1, 1);
148
  group_norm_f32<WARP_SIZE><<<num_groups, block_dims, 0, stream>>>(x, dst, group_size, ne_elements, eps);
@@ -153,7 +208,6 @@ static void group_norm_f32_cuda(const float * x, float * dst, const int num_grou
153
  }
154
 
155
  static void rms_norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float eps, cudaStream_t stream) {
156
- GGML_ASSERT(ncols % WARP_SIZE == 0);
157
  if (ncols < 1024) {
158
  const dim3 block_dims(WARP_SIZE, 1, 1);
159
  rms_norm_f32<WARP_SIZE><<<nrows, block_dims, 0, stream>>>(x, dst, ncols, eps);
@@ -163,6 +217,16 @@ static void rms_norm_f32_cuda(const float * x, float * dst, const int ncols, con
163
  }
164
  }
165
 
166
  void ggml_cuda_op_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
167
  const ggml_tensor * src0 = dst->src[0];
168
  const float * src0_d = (const float *)src0->data;
@@ -179,6 +243,7 @@ void ggml_cuda_op_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
179
 
180
  float eps;
181
  memcpy(&eps, dst->op_params, sizeof(float));
 
182
 
183
  norm_f32_cuda(src0_d, dst_d, ne00, nrows, eps, stream);
184
  }
@@ -198,6 +263,7 @@ void ggml_cuda_op_group_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
198
 
199
  float eps;
200
  memcpy(&eps, dst->op_params + 1, sizeof(float));
 
201
 
202
  int group_size = src0->ne[0] * src0->ne[1] * ((src0->ne[2] + num_groups - 1) / num_groups);
203
  group_norm_f32_cuda(src0_d, dst_d, num_groups * src0->ne[3], eps, group_size, ggml_nelements(src0), stream);
@@ -219,6 +285,33 @@ void ggml_cuda_op_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
219
 
220
  float eps;
221
  memcpy(&eps, dst->op_params, sizeof(float));
 
222
 
223
  rms_norm_f32_cuda(src0_d, dst_d, ne00, nrows, eps, stream);
224
  }
 
5
  const int row = blockIdx.x*blockDim.y + threadIdx.y;
6
  const int tid = threadIdx.x;
7
 
8
+ x += int64_t(row)*ncols;
9
+ dst += int64_t(row)*ncols;
10
+
11
+ float2 mean_var = make_float2(0.0f, 0.0f);
12
 
13
  for (int col = tid; col < ncols; col += block_size) {
14
+ const float xi = x[col];
15
  mean_var.x += xi;
16
  mean_var.y += xi * xi;
17
  }
18
 
19
  // sum up partial sums
20
  mean_var = warp_reduce_sum(mean_var);
21
+ if constexpr (block_size > WARP_SIZE) {
22
+ static_assert(block_size == 1024, "unexpected block_size");
23
  __shared__ float2 s_sum[32];
24
+ const int warp_id = threadIdx.x / WARP_SIZE;
25
+ const int lane_id = threadIdx.x % WARP_SIZE;
26
  if (lane_id == 0) {
27
  s_sum[warp_id] = mean_var;
28
  }
 
36
  const float inv_std = rsqrtf(var + eps);
37
 
38
  for (int col = tid; col < ncols; col += block_size) {
39
+ dst[col] = (x[col] - mean) * inv_std;
40
  }
41
  }
42
 
 
44
  static __global__ void group_norm_f32(const float * x, float * dst, const int group_size, const int ne_elements, const float eps) {
45
  // blockIdx.x: num_groups idx
46
  // threadIdx.x: block_size idx
47
+ const int start = blockIdx.x*group_size + threadIdx.x;
48
+ const int end = min(blockIdx.x*group_size + group_size, ne_elements);
 
 
 
 
 
 
49
 
50
  float tmp = 0.0f; // partial sum for thread in warp
51
 
 
54
  }
55
 
56
  tmp = warp_reduce_sum(tmp);
57
+ if constexpr (block_size > WARP_SIZE) {
58
+ static_assert(block_size == 1024, "unexpected block_size");
59
  __shared__ float s_sum[32];
60
+ const int warp_id = threadIdx.x / WARP_SIZE;
61
+ const int lane_id = threadIdx.x % WARP_SIZE;
62
  if (lane_id == 0) {
63
  s_sum[warp_id] = tmp;
64
  }
 
67
  tmp = warp_reduce_sum(tmp);
68
  }
69
 
70
+ const float mean = tmp / group_size;
71
  tmp = 0.0f;
72
 
73
  for (int j = start; j < end; j += block_size) {
74
+ const float xi = x[j] - mean;
75
  dst[j] = xi;
76
  tmp += xi * xi;
77
  }
 
79
  tmp = warp_reduce_sum(tmp);
80
  if (block_size > WARP_SIZE) {
81
  __shared__ float s_sum[32];
82
+ const int warp_id = threadIdx.x / WARP_SIZE;
83
+ const int lane_id = threadIdx.x % WARP_SIZE;
84
  if (lane_id == 0) {
85
  s_sum[warp_id] = tmp;
86
  }
 
89
  tmp = warp_reduce_sum(tmp);
90
  }
91
 
92
+ const float variance = tmp / group_size;
93
+ const float scale = rsqrtf(variance + eps);
94
  for (int j = start; j < end; j += block_size) {
95
  dst[j] *= scale;
96
  }
 
101
  const int row = blockIdx.x*blockDim.y + threadIdx.y;
102
  const int tid = threadIdx.x;
103
 
104
+ x += int64_t(row)*ncols;
105
+ dst += int64_t(row)*ncols;
106
+
107
  float tmp = 0.0f; // partial sum for thread in warp
108
 
109
  for (int col = tid; col < ncols; col += block_size) {
110
+ const float xi = x[col];
111
  tmp += xi * xi;
112
  }
113
 
114
  // sum up partial sums
115
  tmp = warp_reduce_sum(tmp);
116
+ if constexpr (block_size > WARP_SIZE) {
117
+ static_assert(block_size == 1024, "unexpected block_size");
118
  __shared__ float s_sum[32];
119
+ const int warp_id = threadIdx.x / WARP_SIZE;
120
+ const int lane_id = threadIdx.x % WARP_SIZE;
121
  if (lane_id == 0) {
122
  s_sum[warp_id] = tmp;
123
  }
 
130
  const float scale = rsqrtf(mean + eps);
131
 
132
  for (int col = tid; col < ncols; col += block_size) {
133
+ dst[col] = scale * x[col];
134
+ }
135
+ }
136
+
137
+ template <int block_size>
138
+ static __global__ void rms_norm_back_f32(
139
+ const float * grad, const float * xf, float * dst, const int ncols, const float eps) {
140
+ const int row = blockIdx.x*blockDim.y + threadIdx.y;
141
+ const int tid = threadIdx.x;
142
+
143
+ grad += int64_t(row)*ncols;
144
+ xf += int64_t(row)*ncols;
145
+ dst += int64_t(row)*ncols;
146
+
147
+ float sum_xx = 0.0f; // sum of squares of x, same as in the forward pass
148
+ float sum_xg = 0.0f; // sum for x * gradient, needed because RMS norm mixes inputs
149
+
150
+ for (int col = tid; col < ncols; col += block_size) {
151
+ const float xfi = xf[col];
152
+ sum_xx += xfi * xfi;
153
+ sum_xg += xfi * grad[col];
154
+ }
155
+
156
+ // sum up partial sums
157
+ sum_xx = warp_reduce_sum(sum_xx);
158
+ sum_xg = warp_reduce_sum(sum_xg);
159
+ if constexpr (block_size > WARP_SIZE) {
160
+ static_assert(block_size == 1024, "unexpected block_size");
161
+ __shared__ float s_sum_xx[32];
162
+ __shared__ float s_sum_xg[32];
163
+ const int warp_id = threadIdx.x / WARP_SIZE;
164
+ const int lane_id = threadIdx.x % WARP_SIZE;
165
+ if (lane_id == 0) {
166
+ s_sum_xx[warp_id] = sum_xx;
167
+ s_sum_xg[warp_id] = sum_xg;
168
+ }
169
+ __syncthreads();
170
+
171
+ sum_xx = s_sum_xx[lane_id];
172
+ sum_xx = warp_reduce_sum(sum_xx);
173
+
174
+ sum_xg = s_sum_xg[lane_id];
175
+ sum_xg = warp_reduce_sum(sum_xg);
176
+ }
177
+
178
+ const float mean_eps = sum_xx / ncols + eps;
179
+ const float sum_eps = sum_xx + ncols*eps;
180
+
181
+ const float scale_grad = rsqrtf(mean_eps);
182
+ const float scale_x = -scale_grad * sum_xg/sum_eps;
183
+
184
+ for (int col = tid; col < ncols; col += block_size) {
185
+ dst[col] = scale_grad*grad[col] + scale_x*xf[col];
186
  }
187
  }
188
 
189
  static void norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float eps, cudaStream_t stream) {
 
190
  if (ncols < 1024) {
191
  const dim3 block_dims(WARP_SIZE, 1, 1);
192
  norm_f32<WARP_SIZE><<<nrows, block_dims, 0, stream>>>(x, dst, ncols, eps);
 
196
  }
197
  }
198
 
199
+ static void group_norm_f32_cuda(
200
+ const float * x, float * dst, const int num_groups, const float eps, const int group_size, const int ne_elements, cudaStream_t stream) {
201
  if (group_size < 1024) {
202
  const dim3 block_dims(WARP_SIZE, 1, 1);
203
  group_norm_f32<WARP_SIZE><<<num_groups, block_dims, 0, stream>>>(x, dst, group_size, ne_elements, eps);
 
208
  }
209
 
210
  static void rms_norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float eps, cudaStream_t stream) {
 
211
  if (ncols < 1024) {
212
  const dim3 block_dims(WARP_SIZE, 1, 1);
213
  rms_norm_f32<WARP_SIZE><<<nrows, block_dims, 0, stream>>>(x, dst, ncols, eps);
 
217
  }
218
  }
219
 
220
+ static void rms_norm_back_f32_cuda(const float * grad, const float * xf, float * dst, const int ncols, const int nrows, const float eps, cudaStream_t stream) {
221
+ if (ncols < 1024) {
222
+ const dim3 block_dims(WARP_SIZE, 1, 1);
223
+ rms_norm_back_f32<WARP_SIZE><<<nrows, block_dims, 0, stream>>>(grad, xf, dst, ncols, eps);
224
+ } else {
225
+ const dim3 block_dims(1024, 1, 1);
226
+ rms_norm_back_f32<1024><<<nrows, block_dims, 0, stream>>>(grad, xf, dst, ncols, eps);
227
+ }
228
+ }
229
+
230
  void ggml_cuda_op_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
231
  const ggml_tensor * src0 = dst->src[0];
232
  const float * src0_d = (const float *)src0->data;
 
243
 
244
  float eps;
245
  memcpy(&eps, dst->op_params, sizeof(float));
246
+ GGML_ASSERT(eps >= 0.0f);
247
 
248
  norm_f32_cuda(src0_d, dst_d, ne00, nrows, eps, stream);
249
  }
 
263
 
264
  float eps;
265
  memcpy(&eps, dst->op_params + 1, sizeof(float));
266
+ GGML_ASSERT(eps >= 0.0f);
267
 
268
  int group_size = src0->ne[0] * src0->ne[1] * ((src0->ne[2] + num_groups - 1) / num_groups);
269
  group_norm_f32_cuda(src0_d, dst_d, num_groups * src0->ne[3], eps, group_size, ggml_nelements(src0), stream);
 
285
 
286
  float eps;
287
  memcpy(&eps, dst->op_params, sizeof(float));
288
+ GGML_ASSERT(eps >= 0.0f);
289
 
290
  rms_norm_f32_cuda(src0_d, dst_d, ne00, nrows, eps, stream);
291
  }
292
+
293
+ void ggml_cuda_op_rms_norm_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
294
+ const ggml_tensor * grad = dst->src[0]; // gradients
295
+ const ggml_tensor * src0f = dst->src[1]; // src0 from forward pass
296
+
297
+ const float * grad_d = (const float *) grad->data;
298
+ const float * src0f_d = (const float *) src0f->data;
299
+ float * dst_d = (float *) dst->data;
300
+
301
+ cudaStream_t stream = ctx.stream();
302
+
303
+ GGML_ASSERT(ggml_is_contiguous(grad));
304
+
305
+ GGML_ASSERT( grad->type == GGML_TYPE_F32);
306
+ GGML_ASSERT(src0f->type == GGML_TYPE_F32);
307
+ GGML_ASSERT( dst->type == GGML_TYPE_F32);
308
+
309
+ const int64_t ne00 = src0f->ne[0];
310
+ const int64_t nrows = ggml_nrows(src0f);
311
+
312
+ float eps;
313
+ memcpy(&eps, dst->op_params, sizeof(float));
314
+ GGML_ASSERT(eps >= 0.0f);
315
+
316
+ rms_norm_back_f32_cuda(grad_d, src0f_d, dst_d, ne00, nrows, eps, stream);
317
+ }
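For reference, the per-row update computed by rms_norm_back_f32 is the usual RMSNorm gradient. Writing n for ncols, m = sum_xx / n, and g for the incoming gradient (symbols chosen here for exposition, not taken from the patch), the kernel evaluates

\[
\frac{\partial L}{\partial x_i} \;=\; \frac{g_i}{\sqrt{m+\varepsilon}} \;-\; \frac{x_i \sum_j g_j x_j}{n\,(m+\varepsilon)^{3/2}},
\]

which matches scale_grad = (m + eps)^{-1/2} and scale_x = -scale_grad * sum_xg / (n*(m + eps)) in the code above.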
ggml/src/ggml-cuda/norm.cuh CHANGED
@@ -5,3 +5,5 @@ void ggml_cuda_op_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
5
  void ggml_cuda_op_group_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
6
 
7
  void ggml_cuda_op_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 
 
 
5
  void ggml_cuda_op_group_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
6
 
7
  void ggml_cuda_op_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
8
+
9
+ void ggml_cuda_op_rms_norm_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
ggml/src/ggml-cuda/out-prod.cu CHANGED
@@ -11,16 +11,15 @@ void ggml_cuda_out_prod(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
11
  GGML_ASSERT(src0->type == GGML_TYPE_F32);
12
  GGML_ASSERT(src1->type == GGML_TYPE_F32);
13
  GGML_ASSERT(dst->type == GGML_TYPE_F32);
14
- GGML_ASSERT(ggml_is_contiguous(src0));
15
- GGML_ASSERT(ggml_is_contiguous(dst));
16
 
17
  GGML_ASSERT(ne01 == ne11);
18
  GGML_ASSERT(ne0 == ne00);
19
  GGML_ASSERT(ne1 == ne10);
20
 
21
- GGML_ASSERT(ne2 == src0->ne[2]);
 
 
22
  GGML_ASSERT(ne2 == src1->ne[2]);
23
- GGML_ASSERT(ne3 == src0->ne[3]);
24
  GGML_ASSERT(ne3 == src1->ne[3]);
25
 
26
  const float * src0_d = (const float *) src0->data;
@@ -33,8 +32,6 @@ void ggml_cuda_out_prod(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
33
  const float alpha = 1.0f;
34
  const float beta = 0.0f;
35
 
36
- GGML_ASSERT(ne2 == 1);
37
- GGML_ASSERT(ne3 == 1);
38
  CUBLAS_CHECK(cublasSetStream(handle, stream));
39
 
40
  const bool src1_T = ggml_is_transposed(src1);
@@ -42,10 +39,27 @@ void ggml_cuda_out_prod(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
42
  const int64_t ldb = (src1_T ? nb10 : nb11) / sizeof(float);
43
  GGML_ASSERT( (src1_T ? nb11 : nb10) == sizeof(float));
44
 
45
- CUBLAS_CHECK(
46
- cublasSgemm(handle, CUBLAS_OP_N, src1_cublas_op,
47
- ne0, ne1, ne01,
48
- &alpha, src0_d, ne00,
49
- src1_d, ldb,
50
- &beta, dst_d, ne0));
51
  }
 
11
  GGML_ASSERT(src0->type == GGML_TYPE_F32);
12
  GGML_ASSERT(src1->type == GGML_TYPE_F32);
13
  GGML_ASSERT(dst->type == GGML_TYPE_F32);
 
 
14
 
15
  GGML_ASSERT(ne01 == ne11);
16
  GGML_ASSERT(ne0 == ne00);
17
  GGML_ASSERT(ne1 == ne10);
18
 
19
+ GGML_ASSERT(ne2 % src0->ne[2] == 0);
20
+ GGML_ASSERT(ne3 % src0->ne[3] == 0);
21
+
22
  GGML_ASSERT(ne2 == src1->ne[2]);
 
23
  GGML_ASSERT(ne3 == src1->ne[3]);
24
 
25
  const float * src0_d = (const float *) src0->data;
 
32
  const float alpha = 1.0f;
33
  const float beta = 0.0f;
34
 
 
 
35
  CUBLAS_CHECK(cublasSetStream(handle, stream));
36
 
37
  const bool src1_T = ggml_is_transposed(src1);
 
39
  const int64_t ldb = (src1_T ? nb10 : nb11) / sizeof(float);
40
  GGML_ASSERT( (src1_T ? nb11 : nb10) == sizeof(float));
41
 
42
+ // data strides in dimensions 2/3
43
+ const size_t s02 = nb02 / sizeof(float);
44
+ const size_t s03 = nb03 / sizeof(float);
45
+ const size_t s12 = nb12 / sizeof(float);
46
+ const size_t s13 = nb13 / sizeof(float);
47
+ const size_t s2 = nb2 / sizeof(float);
48
+ const size_t s3 = nb3 / sizeof(float);
49
+
50
+ // dps == dst per src0, used for group query attention
51
+ const int64_t dps2 = ne2 / ne02;
52
+ const int64_t dps3 = ne3 / ne03;
53
+
54
+ // TODO batched matrix multiplication
55
+ for (int64_t i3 = 0; i3 < ne3; ++i3) {
56
+ for (int64_t i2 = 0; i2 < ne2; ++i2) {
57
+ CUBLAS_CHECK(
58
+ cublasSgemm(handle, CUBLAS_OP_N, src1_cublas_op,
59
+ ne0, ne1, ne01,
60
+ &alpha, src0_d + (i3/dps3)*s03 + (i2/dps2)*s02, ne00,
61
+ src1_d + i3 *s13 + i2 *s12, ldb,
62
+ &beta, dst_d + i3 *s3 + i2 *s2, ne0));
63
+ }
64
+ }
65
  }
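The dps2/dps3 factors broadcast src0 across the outer dimensions, matching the "dst per src0" comment: each src0 slice is reused for dps consecutive dst/src1 slices, which is what grouped-query attention needs. A minimal sketch of the index mapping, with illustrative names not taken from the patch:

    #include <stdint.h>

    // With ne2 % ne02 == 0, dst slices [k*dps, (k+1)*dps) all read src0 slice k.
    static int64_t src0_slice(int64_t i2, int64_t ne2, int64_t ne02) {
        const int64_t dps = ne2 / ne02; // dst slices per src0 slice
        return i2 / dps;
    }

For example, with ne02 = 8 and ne2 = 32, dps = 4 and dst slices 0-3 all use src0 slice 0.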
ggml/src/ggml-cuda/rope.cu CHANGED
@@ -39,9 +39,9 @@ static __device__ void rope_yarn(
39
 
40
  template<bool forward, bool has_ff, typename T>
41
  static __global__ void rope_norm(
42
- const T * __restrict__ x, T * __restrict__ dst, const int ne0, const int ne1, const int s1, const int s2, const int n_dims,
43
- const int32_t * __restrict__ pos, const float freq_scale, const float ext_factor, const float attn_factor,
44
- const rope_corr_dims corr_dims, const float theta_scale, const float * __restrict__ freq_factors) {
45
  const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);
46
 
47
  if (i0 >= ne0) {
@@ -83,9 +83,9 @@ static __global__ void rope_norm(
83
 
84
  template<bool forward, bool has_ff, typename T>
85
  static __global__ void rope_neox(
86
- const T * __restrict__ x, T * __restrict__ dst, const int ne0, const int ne1, const int s1, const int s2, const int n_dims,
87
- const int32_t * __restrict__ pos, const float freq_scale, const float ext_factor, const float attn_factor,
88
- const rope_corr_dims corr_dims, const float theta_scale, const float * __restrict__ freq_factors) {
89
  const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);
90
 
91
  if (i0 >= ne0) {
@@ -127,9 +127,9 @@ static __global__ void rope_neox(
127
 
128
  template<bool forward, bool has_ff, typename T>
129
  static __global__ void rope_multi(
130
- const T * __restrict__ x, T * __restrict__ dst, const int ne0, const int ne1, const int ne2, const int s1, const int s2,
131
- const int n_dims, const int32_t * __restrict__ pos, const float freq_scale, const float ext_factor, const float attn_factor,
132
- const rope_corr_dims corr_dims, const float theta_scale, const float * __restrict__ freq_factors, const mrope_sections sections) {
133
  const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);
134
 
135
  if (i0 >= ne0) {
@@ -187,9 +187,9 @@ static __global__ void rope_multi(
187
 
188
  template<bool forward, bool has_ff, typename T>
189
  static __global__ void rope_vision(
190
- const T * __restrict__ x, T * __restrict__ dst, const int ne0, const int ne1, const int ne2, const int s1, const int s2, const int n_dims,
191
- const int32_t * __restrict__ pos, const float freq_scale, const float ext_factor, const float attn_factor, const rope_corr_dims corr_dims,
192
- const float theta_scale, const float * __restrict__ freq_factors, const mrope_sections sections) {
193
  const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);
194
 
195
  if (i0 >= ne0) {
@@ -234,9 +234,9 @@ static __global__ void rope_vision(
234
 
235
  template<bool forward, typename T>
236
  static void rope_norm_cuda(
237
- const T * __restrict__ x, T * __restrict__ dst, const int ne0, const int ne1, const int s1, const int s2, const int n_dims, const int nr,
238
- const int32_t * __restrict__ pos, const float freq_scale, const float freq_base, const float ext_factor, const float attn_factor,
239
- const rope_corr_dims corr_dims, const float * __restrict__ freq_factors, cudaStream_t stream) {
240
  GGML_ASSERT(ne0 % 2 == 0);
241
  const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
242
  const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
@@ -257,9 +257,9 @@ static void rope_norm_cuda(
257
 
258
  template<bool forward, typename T>
259
  static void rope_neox_cuda(
260
- const T * __restrict__ x, T * __restrict__ dst, const int ne0, const int ne1, const int s1, const int s2, const int n_dims, const int nr,
261
- const int32_t * __restrict__ pos, const float freq_scale, const float freq_base, const float ext_factor, const float attn_factor,
262
- const rope_corr_dims corr_dims, const float * __restrict__ freq_factors, cudaStream_t stream) {
263
  GGML_ASSERT(ne0 % 2 == 0);
264
  const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
265
  const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
@@ -280,9 +280,9 @@ static void rope_neox_cuda(
280
 
281
  template<bool forward, typename T>
282
  static void rope_multi_cuda(
283
- const T * __restrict__ x, T * __restrict__ dst, const int ne0, const int ne1, const int ne2, const int s1, const int s2, const int n_dims, const int nr,
284
- const int32_t * __restrict__ pos, const float freq_scale, const float freq_base, const float ext_factor, const float attn_factor,
285
- const rope_corr_dims corr_dims, const float * __restrict__ freq_factors, const mrope_sections sections, cudaStream_t stream) {
286
  GGML_ASSERT(ne0 % 2 == 0);
287
  const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
288
  const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
@@ -303,9 +303,9 @@ static void rope_multi_cuda(
303
 
304
  template<bool forward, typename T>
305
  static void rope_vision_cuda(
306
- const T * __restrict__ x, T * __restrict__ dst, const int ne0, const int ne1, const int ne2, const int s1, const int s2, const int n_dims, const int nr,
307
- const int32_t * __restrict__ pos, const float freq_scale, const float freq_base, const float ext_factor, const float attn_factor,
308
- const rope_corr_dims corr_dims, const float * __restrict__ freq_factors, const mrope_sections sections, cudaStream_t stream) {
309
  GGML_ASSERT(ne0 % 2 == 0);
310
  const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
311
  const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
 
39
 
40
  template<bool forward, bool has_ff, typename T>
41
  static __global__ void rope_norm(
42
+ const T * x, T * dst, const int ne0, const int ne1, const int s1, const int s2, const int n_dims,
43
+ const int32_t * pos, const float freq_scale, const float ext_factor, const float attn_factor,
44
+ const rope_corr_dims corr_dims, const float theta_scale, const float * freq_factors) {
45
  const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);
46
 
47
  if (i0 >= ne0) {
 
83
 
84
  template<bool forward, bool has_ff, typename T>
85
  static __global__ void rope_neox(
86
+ const T * x, T * dst, const int ne0, const int ne1, const int s1, const int s2, const int n_dims,
87
+ const int32_t * pos, const float freq_scale, const float ext_factor, const float attn_factor,
88
+ const rope_corr_dims corr_dims, const float theta_scale, const float * freq_factors) {
89
  const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);
90
 
91
  if (i0 >= ne0) {
 
127
 
128
  template<bool forward, bool has_ff, typename T>
129
  static __global__ void rope_multi(
130
+ const T * x, T * dst, const int ne0, const int ne1, const int ne2, const int s1, const int s2,
131
+ const int n_dims, const int32_t * pos, const float freq_scale, const float ext_factor, const float attn_factor,
132
+ const rope_corr_dims corr_dims, const float theta_scale, const float * freq_factors, const mrope_sections sections) {
133
  const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);
134
 
135
  if (i0 >= ne0) {
 
187
 
188
  template<bool forward, bool has_ff, typename T>
189
  static __global__ void rope_vision(
190
+ const T * x, T * dst, const int ne0, const int ne1, const int ne2, const int s1, const int s2, const int n_dims,
191
+ const int32_t * pos, const float freq_scale, const float ext_factor, const float attn_factor, const rope_corr_dims corr_dims,
192
+ const float theta_scale, const float * freq_factors, const mrope_sections sections) {
193
  const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);
194
 
195
  if (i0 >= ne0) {
 
234
 
235
  template<bool forward, typename T>
236
  static void rope_norm_cuda(
237
+ const T * x, T * dst, const int ne0, const int ne1, const int s1, const int s2, const int n_dims, const int nr,
238
+ const int32_t * pos, const float freq_scale, const float freq_base, const float ext_factor, const float attn_factor,
239
+ const rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) {
240
  GGML_ASSERT(ne0 % 2 == 0);
241
  const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
242
  const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
 
257
 
258
  template<bool forward, typename T>
259
  static void rope_neox_cuda(
260
+ const T * x, T * dst, const int ne0, const int ne1, const int s1, const int s2, const int n_dims, const int nr,
261
+ const int32_t * pos, const float freq_scale, const float freq_base, const float ext_factor, const float attn_factor,
262
+ const rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) {
263
  GGML_ASSERT(ne0 % 2 == 0);
264
  const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
265
  const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
 
280
 
281
  template<bool forward, typename T>
282
  static void rope_multi_cuda(
283
+ const T * x, T * dst, const int ne0, const int ne1, const int ne2, const int s1, const int s2, const int n_dims, const int nr,
284
+ const int32_t * pos, const float freq_scale, const float freq_base, const float ext_factor, const float attn_factor,
285
+ const rope_corr_dims corr_dims, const float * freq_factors, const mrope_sections sections, cudaStream_t stream) {
286
  GGML_ASSERT(ne0 % 2 == 0);
287
  const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
288
  const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
 
303
 
304
  template<bool forward, typename T>
305
  static void rope_vision_cuda(
306
+ const T * x, T * dst, const int ne0, const int ne1, const int ne2, const int s1, const int s2, const int n_dims, const int nr,
307
+ const int32_t * pos, const float freq_scale, const float freq_base, const float ext_factor, const float attn_factor,
308
+ const rope_corr_dims corr_dims, const float * freq_factors, const mrope_sections sections, cudaStream_t stream) {
309
  GGML_ASSERT(ne0 % 2 == 0);
310
  const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
311
  const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
ggml/src/ggml-cuda/softmax.cu CHANGED
@@ -1,5 +1,7 @@
1
  #include "common.cuh"
 
2
  #include "softmax.cuh"
 
3
 
4
  template <typename T>
5
  static __device__ __forceinline__ float t2f32(T val) {
@@ -11,14 +13,20 @@ __device__ float __forceinline__ t2f32<half>(half val) {
11
  return __half2float(val);
12
  }
13
 
14
- template <bool vals_smem, int ncols_template, int block_size_template, typename T>
15
- static __global__ void soft_max_f32(const float * x, const T * mask, float * dst, const int ncols_par, const int nrows_y, const float scale, const float max_bias, const float m0, const float m1, uint32_t n_head_log2) {
 
 
16
  const int ncols = ncols_template == 0 ? ncols_par : ncols_template;
17
 
18
  const int tid = threadIdx.x;
19
  const int rowx = blockIdx.x;
20
  const int rowy = rowx % nrows_y; // broadcast the mask in the row dimension
21
 
 
 
 
 
22
  const int block_size = block_size_template == 0 ? blockDim.x : block_size_template;
23
 
24
  const int warp_id = threadIdx.x / WARP_SIZE;
@@ -29,7 +37,7 @@ static __global__ void soft_max_f32(const float * x, const T * mask, float * dst
29
  extern __shared__ float data_soft_max_f32[];
30
  float * buf_iw = data_soft_max_f32; // shared memory buffer for inter-warp communication
31
  // shared memory buffer to cache values between iterations:
32
- float * vals = vals_smem ? buf_iw + WARP_SIZE : dst + (int64_t)rowx*ncols;
33
 
34
  float max_val = -INFINITY;
35
 
@@ -41,10 +49,7 @@ static __global__ void soft_max_f32(const float * x, const T * mask, float * dst
41
  break;
42
  }
43
 
44
- const int64_t ix = (int64_t)rowx*ncols + col;
45
- const int64_t iy = (int64_t)rowy*ncols + col;
46
-
47
- const float val = x[ix]*scale + (mask ? slope*t2f32(mask[iy]) : 0.0f);
48
 
49
  vals[col] = val;
50
  max_val = max(max_val, val);
@@ -110,8 +115,29 @@ static __global__ void soft_max_f32(const float * x, const T * mask, float * dst
110
  return;
111
  }
112
 
113
- const int64_t idst = (int64_t)rowx*ncols + col;
114
- dst[idst] = vals[col] * inv_sum;
115
  }
116
  }
117
 
@@ -121,7 +147,7 @@ static void soft_max_f32_cuda(const float * x, const T * mask, float * dst, cons
121
  while (nth < ncols_x && nth < CUDA_SOFT_MAX_BLOCK_SIZE) nth *= 2;
122
  const dim3 block_dims(nth, 1, 1);
123
  const dim3 block_nums(nrows_x, 1, 1);
124
- const size_t shmem = (GGML_PAD(ncols_x, WARP_SIZE) + WARP_SIZE)*sizeof(float);
125
  static_assert(CUDA_SOFT_MAX_BLOCK_SIZE == 1024, "These values need to be adjusted.");
126
 
127
  const uint32_t n_head = nrows_x/nrows_y;
@@ -131,50 +157,68 @@ static void soft_max_f32_cuda(const float * x, const T * mask, float * dst, cons
131
  const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
132
 
133
  // FIXME: this limit could be raised by ~2-4x on Ampere or newer
134
- if (shmem < ggml_cuda_info().devices[ggml_cuda_get_device()].smpb) {
135
  switch (ncols_x) {
136
  case 32:
137
- soft_max_f32<true, 32, 32><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
 
138
  break;
139
  case 64:
140
- soft_max_f32<true, 64, 64><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
 
141
  break;
142
  case 128:
143
- soft_max_f32<true, 128, 128><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
 
144
  break;
145
  case 256:
146
- soft_max_f32<true, 256, 256><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
 
147
  break;
148
  case 512:
149
- soft_max_f32<true, 512, 512><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
 
150
  break;
151
  case 1024:
152
- soft_max_f32<true, 1024, 1024><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
 
153
  break;
154
  case 2048:
155
- soft_max_f32<true, 2048, 1024><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
 
156
  break;
157
  case 4096:
158
- soft_max_f32<true, 4096, 1024><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
 
159
  break;
160
  default:
161
- soft_max_f32<true, 0, 0><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
 
162
  break;
163
  }
164
  } else {
165
- const size_t shmem_low = WARP_SIZE*sizeof(float);
166
- soft_max_f32<false, 0, 0><<<block_nums, block_dims, shmem_low, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
167
  }
168
  }
169
 
170
  void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
171
  const ggml_tensor * src0 = dst->src[0];
172
  const ggml_tensor * src1 = dst->src[1];
173
 
174
- const float * src0_d = (const float *)src0->data;
175
- const void * src1_d = src1 ? (const void *)src1->data : nullptr;
 
176
 
177
- float * dst_d = (float *)dst->data;
178
  cudaStream_t stream = ctx.stream();
179
 
180
  GGML_ASSERT(src0->type == GGML_TYPE_F32);
@@ -189,18 +233,42 @@ void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
189
  float scale = 1.0f;
190
  float max_bias = 0.0f;
191
 
192
- memcpy(&scale, (float *) dst->op_params + 0, sizeof(float));
193
- memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float));
194
 
195
  const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16);
196
 
197
  if (use_f16) {
198
- const half * src1_dd = (const half *)src1_d;
199
-
200
- soft_max_f32_cuda(src0_d, src1_dd, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, stream);
201
  } else {
202
- const float * src1_dd = (const float *)src1_d;
203
-
204
- soft_max_f32_cuda(src0_d, src1_dd, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, stream);
205
  }
206
  }
 
1
  #include "common.cuh"
2
+ #include "ggml.h"
3
  #include "softmax.cuh"
4
+ #include <cstdint>
5
 
6
  template <typename T>
7
  static __device__ __forceinline__ float t2f32(T val) {
 
13
  return __half2float(val);
14
  }
15
 
16
+ template <bool use_shared, int ncols_template, int block_size_template, typename T>
17
+ static __global__ void soft_max_f32(
18
+ const float * x, const T * mask, float * dst, const int ncols_par, const int nrows_y,
19
+ const float scale, const float max_bias, const float m0, const float m1, uint32_t n_head_log2) {
20
  const int ncols = ncols_template == 0 ? ncols_par : ncols_template;
21
 
22
  const int tid = threadIdx.x;
23
  const int rowx = blockIdx.x;
24
  const int rowy = rowx % nrows_y; // broadcast the mask in the row dimension
25
 
26
+ x += int64_t(rowx)*ncols;
27
+ mask += int64_t(rowy)*ncols * (mask != nullptr);
28
+ dst += int64_t(rowx)*ncols;
29
+
30
  const int block_size = block_size_template == 0 ? blockDim.x : block_size_template;
31
 
32
  const int warp_id = threadIdx.x / WARP_SIZE;
 
37
  extern __shared__ float data_soft_max_f32[];
38
  float * buf_iw = data_soft_max_f32; // shared memory buffer for inter-warp communication
39
  // shared memory buffer to cache values between iterations:
40
+ float * vals = use_shared ? buf_iw + WARP_SIZE : dst;
41
 
42
  float max_val = -INFINITY;
43
 
 
49
  break;
50
  }
51
 
52
+ const float val = x[col]*scale + (mask ? slope*t2f32(mask[col]) : 0.0f);
 
 
 
53
 
54
  vals[col] = val;
55
  max_val = max(max_val, val);
 
115
  return;
116
  }
117
 
118
+ dst[col] = vals[col] * inv_sum;
119
+ }
120
+ }
121
+
122
+ static __global__ void soft_max_back_f32(
123
+ const float * grad, const float * dstf, float * dst, const int ncols, const float scale) {
124
+ const int tid = threadIdx.x;
125
+ const int rowx = blockIdx.x;
126
+
127
+ grad += int64_t(rowx)*ncols;
128
+ dstf += int64_t(rowx)*ncols;
129
+ dst += int64_t(rowx)*ncols;
130
+
131
+ float dgf_dot = 0.0f; // dot product of dst from forward pass and gradients
132
+
133
+ for (int col = tid; col < ncols; col += WARP_SIZE) {
134
+ dgf_dot += dstf[col]*grad[col];
135
+ }
136
+
137
+ dgf_dot = warp_reduce_sum(dgf_dot);
138
+
139
+ for (int col = tid; col < ncols; col += WARP_SIZE) {
140
+ dst[col] = scale * (grad[col] - dgf_dot) * dstf[col];
141
  }
142
  }
143
 
 
147
  while (nth < ncols_x && nth < CUDA_SOFT_MAX_BLOCK_SIZE) nth *= 2;
148
  const dim3 block_dims(nth, 1, 1);
149
  const dim3 block_nums(nrows_x, 1, 1);
150
+ const size_t nbytes_shared = (GGML_PAD(ncols_x, WARP_SIZE) + WARP_SIZE)*sizeof(float);
151
  static_assert(CUDA_SOFT_MAX_BLOCK_SIZE == 1024, "These values need to be adjusted.");
152
 
153
  const uint32_t n_head = nrows_x/nrows_y;
 
157
  const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
158
 
159
  // FIXME: this limit could be raised by ~2-4x on Ampere or newer
160
+ if (nbytes_shared < ggml_cuda_info().devices[ggml_cuda_get_device()].smpb) {
161
  switch (ncols_x) {
162
  case 32:
163
+ soft_max_f32<true, 32, 32><<<block_nums, block_dims, nbytes_shared, stream>>>
164
+ (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
165
  break;
166
  case 64:
167
+ soft_max_f32<true, 64, 64><<<block_nums, block_dims, nbytes_shared, stream>>>
168
+ (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
169
  break;
170
  case 128:
171
+ soft_max_f32<true, 128, 128><<<block_nums, block_dims, nbytes_shared, stream>>>
172
+ (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
173
  break;
174
  case 256:
175
+ soft_max_f32<true, 256, 256><<<block_nums, block_dims, nbytes_shared, stream>>>
176
+ (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
177
  break;
178
  case 512:
179
+ soft_max_f32<true, 512, 512><<<block_nums, block_dims, nbytes_shared, stream>>>
180
+ (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
181
  break;
182
  case 1024:
183
+ soft_max_f32<true, 1024, 1024><<<block_nums, block_dims, nbytes_shared, stream>>>
184
+ (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
185
  break;
186
  case 2048:
187
+ soft_max_f32<true, 2048, 1024><<<block_nums, block_dims, nbytes_shared, stream>>>
188
+ (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
189
  break;
190
  case 4096:
191
+ soft_max_f32<true, 4096, 1024><<<block_nums, block_dims, nbytes_shared, stream>>>
192
+ (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
193
  break;
194
  default:
195
+ soft_max_f32<true, 0, 0><<<block_nums, block_dims, nbytes_shared, stream>>>
196
+ (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
197
  break;
198
  }
199
  } else {
200
+ const size_t nbytes_shared_low = WARP_SIZE*sizeof(float);
201
+ soft_max_f32<false, 0, 0><<<block_nums, block_dims, nbytes_shared_low, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
202
  }
203
  }
204
 
205
+ static void soft_max_back_f32_cuda(
206
+ const float * grad, const float * dstf, float * dst,
207
+ const int ncols, const int nrows, const float scale, cudaStream_t stream) {
208
+ const dim3 block_dims(WARP_SIZE, 1, 1);
209
+ const dim3 block_nums(nrows, 1, 1);
210
+
211
+ soft_max_back_f32<<<block_nums, block_dims, 0, stream>>>(grad, dstf, dst, ncols, scale);
212
+ }
213
+
214
  void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
215
  const ggml_tensor * src0 = dst->src[0];
216
  const ggml_tensor * src1 = dst->src[1];
217
 
218
+ const float * src0_d = (const float *) src0->data;
219
+ const void * src1_d = src1 ? (const void *) src1->data : nullptr;
220
+ float * dst_d = (float *) dst->data;
221
 
 
222
  cudaStream_t stream = ctx.stream();
223
 
224
  GGML_ASSERT(src0->type == GGML_TYPE_F32);
 
233
  float scale = 1.0f;
234
  float max_bias = 0.0f;
235
 
236
+ memcpy(&scale, (const float *) dst->op_params + 0, sizeof(float));
237
+ memcpy(&max_bias, (const float *) dst->op_params + 1, sizeof(float));
238
 
239
  const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16);
240
 
241
  if (use_f16) {
242
+ soft_max_f32_cuda(src0_d, (const half *) src1_d, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, stream);
 
 
243
  } else {
244
+ soft_max_f32_cuda(src0_d, (const float *) src1_d, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, stream);
 
 
245
  }
246
  }
247
+
248
+ void ggml_cuda_op_soft_max_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
249
+ const ggml_tensor * src0 = dst->src[0]; // grad
250
+ const ggml_tensor * src1 = dst->src[1]; // forward pass output
251
+
252
+ const float * src0_d = (const float *) src0->data;
253
+ const float * src1_d = (const float *) src1->data;
254
+ float * dst_d = (float *) dst->data;
255
+
256
+ cudaStream_t stream = ctx.stream();
257
+
258
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
259
+ GGML_ASSERT(src1->type == GGML_TYPE_F32);
260
+ GGML_ASSERT( dst->type == GGML_TYPE_F32);
261
+
262
+ const int64_t ncols = src0->ne[0];
263
+ const int64_t nrows = ggml_nrows(src0);
264
+
265
+ float scale = 1.0f;
266
+ float max_bias = 0.0f;
267
+
268
+ memcpy(&scale, (const float *) dst->op_params + 0, sizeof(float));
269
+ memcpy(&max_bias, (const float *) dst->op_params + 1, sizeof(float));
270
+
271
+ GGML_ASSERT(max_bias == 0.0f);
272
+
273
+ soft_max_back_f32_cuda(src0_d, src1_d, dst_d, ncols, nrows, scale, stream);
274
+ }
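soft_max_back_f32 computes the row-wise Jacobian-vector product of softmax. With y = softmax(scale * x) taken from the forward output (dstf) and g the incoming gradient (names used here for exposition),

\[
\frac{\partial L}{\partial x_i} \;=\; \mathrm{scale}\; y_i \left( g_i - \sum_j g_j\, y_j \right),
\]

which is exactly dst[col] = scale * (grad[col] - dgf_dot) * dstf[col]. The GGML_ASSERT(max_bias == 0.0f) in ggml_cuda_op_soft_max_back restricts the backward pass to the case without an ALiBi bias, which this kernel does not handle.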
ggml/src/ggml-cuda/softmax.cuh CHANGED
@@ -3,3 +3,5 @@
3
  #define CUDA_SOFT_MAX_BLOCK_SIZE 1024
4
 
5
  void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 
 
 
3
  #define CUDA_SOFT_MAX_BLOCK_SIZE 1024
4
 
5
  void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
6
+
7
+ void ggml_cuda_op_soft_max_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
ggml/src/ggml-cuda/unary.cu CHANGED
@@ -51,6 +51,19 @@ static __global__ void silu_f32(const float * x, float * dst, const int k) {
51
  dst[i] = x[i] / (1.0f + expf(-x[i]));
52
  }
53
 
 
54
  static __global__ void tanh_f32(const float * x, float * dst, int k) {
55
  const int i = blockDim.x*blockIdx.x + threadIdx.x;
56
  if (i >= k) {
@@ -173,6 +186,11 @@ static void silu_f32_cuda(const float * x, float * dst, const int k, cudaStream_
173
  silu_f32<<<num_blocks, CUDA_SILU_BLOCK_SIZE, 0, stream>>>(x, dst, k);
174
  }
175
 
 
 
 
 
 
176
  static void tanh_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
177
  const int num_blocks = (k + CUDA_TANH_BLOCK_SIZE - 1) / CUDA_TANH_BLOCK_SIZE;
178
  tanh_f32<<<num_blocks, CUDA_TANH_BLOCK_SIZE, 0, stream>>>(x, dst, k);
@@ -284,6 +302,24 @@ void ggml_cuda_op_silu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
284
  silu_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
285
  }
286
 
 
 
287
  void ggml_cuda_op_gelu_quick(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
288
  const ggml_tensor * src0 = dst->src[0];
289
  const float * src0_d = (const float *)src0->data;
 
51
  dst[i] = x[i] / (1.0f + expf(-x[i]));
52
  }
53
 
54
+ static __global__ void silu_back_f32(
55
+ const float * grad, const float * xf, float * dst, const int k) {
56
+ const int i = blockDim.x*blockIdx.x + threadIdx.x;
57
+
58
+ if (i >= k) {
59
+ return;
60
+ }
61
+
62
+ const float xfi = xf[i];
63
+ const float s = 1.0f / (1.0f + expf(-xfi));
64
+ dst[i] = grad[i] * s * (1.0f + xfi * (1.0f - s));
65
+ }
66
+
67
  static __global__ void tanh_f32(const float * x, float * dst, int k) {
68
  const int i = blockDim.x*blockIdx.x + threadIdx.x;
69
  if (i >= k) {
 
186
  silu_f32<<<num_blocks, CUDA_SILU_BLOCK_SIZE, 0, stream>>>(x, dst, k);
187
  }
188
 
189
+ static void silu_back_f32_cuda(const float * grad, const float * x, float * dst, const int k, cudaStream_t stream) {
190
+ const int num_blocks = (k + CUDA_SILU_BACK_BLOCK_SIZE - 1) / CUDA_SILU_BACK_BLOCK_SIZE;
191
+ silu_back_f32<<<num_blocks, CUDA_SILU_BACK_BLOCK_SIZE, 0, stream>>>(grad, x, dst, k);
192
+ }
193
+
194
  static void tanh_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
195
  const int num_blocks = (k + CUDA_TANH_BLOCK_SIZE - 1) / CUDA_TANH_BLOCK_SIZE;
196
  tanh_f32<<<num_blocks, CUDA_TANH_BLOCK_SIZE, 0, stream>>>(x, dst, k);
 
302
  silu_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
303
  }
304
 
305
+ void ggml_cuda_op_silu_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
306
+ const ggml_tensor * src0 = dst->src[0]; // grads of forward pass output
307
+ const ggml_tensor * src1 = dst->src[1]; // input from forward pass
308
+
309
+ const float * src0_d = (const float *) src0->data;
310
+ const float * src1_d = (const float *) src1->data;
311
+ float * dst_d = (float *) dst->data;
312
+
313
+ cudaStream_t stream = ctx.stream();
314
+
315
+ GGML_ASSERT(ggml_is_contiguous(src0));
316
+
317
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
318
+ GGML_ASSERT( dst->type == GGML_TYPE_F32);
319
+
320
+ silu_back_f32_cuda(src0_d, src1_d, dst_d, ggml_nelements(src0), stream);
321
+ }
322
+
323
  void ggml_cuda_op_gelu_quick(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
324
  const ggml_tensor * src0 = dst->src[0];
325
  const float * src0_d = (const float *)src0->data;
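silu_back_f32 applies the SiLU derivative to the incoming gradient. With s = sigma(x) = 1 / (1 + e^{-x}) and silu(x) = x * s,

\[
\frac{d}{dx}\,\big(x\,\sigma(x)\big) \;=\; \sigma(x)\,\big(1 + x\,(1 - \sigma(x))\big),
\]

so the kernel's dst[i] = grad[i] * s * (1.0f + xfi * (1.0f - s)) is the chain rule applied element-wise.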
ggml/src/ggml-cuda/unary.cuh CHANGED
@@ -4,6 +4,7 @@
4
  #define CUDA_STEP_BLOCK_SIZE 256
5
  #define CUDA_GELU_BLOCK_SIZE 256
6
  #define CUDA_SILU_BLOCK_SIZE 256
 
7
  #define CUDA_TANH_BLOCK_SIZE 256
8
  #define CUDA_RELU_BLOCK_SIZE 256
9
  #define CUDA_SIGMOID_BLOCK_SIZE 256
@@ -23,6 +24,8 @@ void ggml_cuda_op_gelu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
23
 
24
  void ggml_cuda_op_silu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
25
 
 
 
26
  void ggml_cuda_op_gelu_quick(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
27
 
28
  void ggml_cuda_op_tanh(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 
4
  #define CUDA_STEP_BLOCK_SIZE 256
5
  #define CUDA_GELU_BLOCK_SIZE 256
6
  #define CUDA_SILU_BLOCK_SIZE 256
7
+ #define CUDA_SILU_BACK_BLOCK_SIZE 256
8
  #define CUDA_TANH_BLOCK_SIZE 256
9
  #define CUDA_RELU_BLOCK_SIZE 256
10
  #define CUDA_SIGMOID_BLOCK_SIZE 256
 
24
 
25
  void ggml_cuda_op_silu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
26
 
27
+ void ggml_cuda_op_silu_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
28
+
29
  void ggml_cuda_op_gelu_quick(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
30
 
31
  void ggml_cuda_op_tanh(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
ggml/src/ggml.c CHANGED
@@ -3454,12 +3454,14 @@ struct ggml_tensor * ggml_soft_max_ext(
3454
  return ggml_soft_max_impl(ctx, a, mask, scale, max_bias, false);
3455
  }
3456
 
3457
- // ggml_soft_max_back
3458
 
3459
- static struct ggml_tensor * ggml_soft_max_back_impl(
3460
  struct ggml_context * ctx,
3461
  struct ggml_tensor * a,
3462
  struct ggml_tensor * b,
 
 
3463
  bool inplace) {
3464
  struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
3465
 
@@ -3467,21 +3469,28 @@ static struct ggml_tensor * ggml_soft_max_back_impl(
3467
  result->src[0] = a;
3468
  result->src[1] = b;
3469
 
 
 
 
3470
  return result;
3471
  }
3472
 
3473
- struct ggml_tensor * ggml_soft_max_back(
3474
  struct ggml_context * ctx,
3475
  struct ggml_tensor * a,
3476
- struct ggml_tensor * b) {
3477
- return ggml_soft_max_back_impl(ctx, a, b, false);
 
 
3478
  }
3479
 
3480
- struct ggml_tensor * ggml_soft_max_back_inplace(
3481
  struct ggml_context * ctx,
3482
  struct ggml_tensor * a,
3483
- struct ggml_tensor * b) {
3484
- return ggml_soft_max_back_impl(ctx, a, b, true);
 
 
3485
  }
3486
 
3487
  // ggml_rope
@@ -5080,10 +5089,10 @@ struct ggml_tensor * ggml_cross_entropy_loss_back(
5080
  struct ggml_tensor * a,
5081
  struct ggml_tensor * b,
5082
  struct ggml_tensor * c) {
5083
- GGML_ASSERT(ggml_are_same_shape(a, b));
5084
- GGML_ASSERT(ggml_is_scalar(c));
5085
 
5086
- struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
5087
 
5088
  result->op = GGML_OP_CROSS_ENTROPY_LOSS_BACK;
5089
  result->src[0] = a;
@@ -5262,7 +5271,7 @@ static void ggml_sub_or_set(
5262
  }
5263
 
5264
  static void ggml_compute_backward(
5265
- struct ggml_context * ctx, struct ggml_cgraph * cgraph, int i, bool * grads_needed) {
5266
  struct ggml_tensor * tensor = cgraph->nodes[i];
5267
  struct ggml_tensor * grad = ggml_graph_get_grad(cgraph, tensor);
5268
 
@@ -5406,7 +5415,7 @@ static void ggml_compute_backward(
5406
  if (src0_needs_grads) {
5407
  float eps;
5408
  memcpy(&eps, tensor->op_params, sizeof(float));
5409
- ggml_add_or_set(ctx, cgraph, isrc0, ggml_rms_norm_back(ctx, src0, grad, eps));
5410
  }
5411
  } break;
5412
  case GGML_OP_MUL_MAT: {
@@ -5589,7 +5598,13 @@ static void ggml_compute_backward(
5589
  } break;
5590
  case GGML_OP_SOFT_MAX: {
5591
  if (src0_needs_grads) {
5592
- ggml_add_or_set(ctx, cgraph, isrc0, ggml_soft_max_back(ctx, grad, tensor));
 
 
 
 
 
 
5593
  }
5594
  GGML_ASSERT((!src1 || !src1_needs_grads) && "backward pass for softmax mask not implemented");
5595
  } break;
@@ -5630,7 +5645,7 @@ static void ggml_compute_backward(
5630
  const int32_t d1 = ggml_get_op_params_i32(tensor, 5);
5631
  const bool is_2D = ggml_get_op_params_i32(tensor, 6) == 1;
5632
 
5633
- ggml_add_or_set(ctx, cgraph, isrc1, ggml_im2col_back(ctx, src0, grad, src1->ne, s0, s1, p0, p1, d0, d1, is_2D));
5634
  }
5635
  } break;
5636
  case GGML_OP_POOL_2D: {
@@ -5673,7 +5688,7 @@ static void ggml_compute_backward(
5673
  } break;
5674
  case GGML_UNARY_OP_SILU: {
5675
  if (src0_needs_grads) {
5676
- ggml_add_or_set(ctx, cgraph, isrc0, ggml_silu_back(ctx, src0, grad));
5677
  }
5678
  } break;
5679
  case GGML_UNARY_OP_EXP: {
@@ -5690,7 +5705,7 @@ static void ggml_compute_backward(
5690
  } break;
5691
  case GGML_OP_CROSS_ENTROPY_LOSS: {
5692
  if (src0_needs_grads) {
5693
- ggml_add_or_set(ctx, cgraph, isrc0, ggml_cross_entropy_loss_back(ctx, src0, src1, grad));
5694
  }
5695
  GGML_ASSERT(!src1_needs_grads && "backward pass for labels not implemented");
5696
  } break;
 
3454
  return ggml_soft_max_impl(ctx, a, mask, scale, max_bias, false);
3455
  }
3456
 
3457
+ // ggml_soft_max_ext_back
3458
 
3459
+ static struct ggml_tensor * ggml_soft_max_ext_back_impl(
3460
  struct ggml_context * ctx,
3461
  struct ggml_tensor * a,
3462
  struct ggml_tensor * b,
3463
+ float scale,
3464
+ float max_bias,
3465
  bool inplace) {
3466
  struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
3467
 
 
3469
  result->src[0] = a;
3470
  result->src[1] = b;
3471
 
3472
+ memcpy((float *) result->op_params + 0, &scale, sizeof(float));
3473
+ memcpy((float *) result->op_params + 1, &max_bias, sizeof(float));
3474
+
3475
  return result;
3476
  }
3477
 
3478
+ struct ggml_tensor * ggml_soft_max_ext_back(
3479
  struct ggml_context * ctx,
3480
  struct ggml_tensor * a,
3481
+ struct ggml_tensor * b,
3482
+ float scale,
3483
+ float max_bias) {
3484
+ return ggml_soft_max_ext_back_impl(ctx, a, b, scale, max_bias, false);
3485
  }
3486
 
3487
+ struct ggml_tensor * ggml_soft_max_ext_back_inplace(
3488
  struct ggml_context * ctx,
3489
  struct ggml_tensor * a,
3490
+ struct ggml_tensor * b,
3491
+ float scale,
3492
+ float max_bias) {
3493
+ return ggml_soft_max_ext_back_impl(ctx, a, b, scale, max_bias, true);
3494
  }
3495
 
3496
  // ggml_rope
 
5089
  struct ggml_tensor * a,
5090
  struct ggml_tensor * b,
5091
  struct ggml_tensor * c) {
5092
+ GGML_ASSERT(ggml_is_scalar(a));
5093
+ GGML_ASSERT(ggml_are_same_shape(b, c));
5094
 
5095
+ struct ggml_tensor * result = ggml_dup_tensor(ctx, b);
5096
 
5097
  result->op = GGML_OP_CROSS_ENTROPY_LOSS_BACK;
5098
  result->src[0] = a;
 
5271
  }
5272
 
5273
  static void ggml_compute_backward(
5274
+ struct ggml_context * ctx, struct ggml_cgraph * cgraph, int i, const bool * grads_needed) {
5275
  struct ggml_tensor * tensor = cgraph->nodes[i];
5276
  struct ggml_tensor * grad = ggml_graph_get_grad(cgraph, tensor);
5277
 
 
5415
  if (src0_needs_grads) {
5416
  float eps;
5417
  memcpy(&eps, tensor->op_params, sizeof(float));
5418
+ ggml_add_or_set(ctx, cgraph, isrc0, ggml_rms_norm_back(ctx, grad, src0, eps));
5419
  }
5420
  } break;
5421
  case GGML_OP_MUL_MAT: {
 
5598
  } break;
5599
  case GGML_OP_SOFT_MAX: {
5600
  if (src0_needs_grads) {
5601
+ float scale = 1.0f;
5602
+ float max_bias = 0.0f;
5603
+
5604
+ memcpy(&scale, (const float *) tensor->op_params + 0, sizeof(float));
5605
+ memcpy(&max_bias, (const float *) tensor->op_params + 1, sizeof(float));
5606
+
5607
+ ggml_add_or_set(ctx, cgraph, isrc0, ggml_soft_max_ext_back(ctx, grad, tensor, scale, max_bias));
5608
  }
5609
  GGML_ASSERT((!src1 || !src1_needs_grads) && "backward pass for softmax mask not implemented");
5610
  } break;
 
5645
  const int32_t d1 = ggml_get_op_params_i32(tensor, 5);
5646
  const bool is_2D = ggml_get_op_params_i32(tensor, 6) == 1;
5647
 
5648
+ ggml_add_or_set(ctx, cgraph, isrc1, ggml_im2col_back(ctx, grad, src0, src1->ne, s0, s1, p0, p1, d0, d1, is_2D));
5649
  }
5650
  } break;
5651
  case GGML_OP_POOL_2D: {
 
5688
  } break;
5689
  case GGML_UNARY_OP_SILU: {
5690
  if (src0_needs_grads) {
5691
+ ggml_add_or_set(ctx, cgraph, isrc0, ggml_silu_back(ctx, grad, src0));
5692
  }
5693
  } break;
5694
  case GGML_UNARY_OP_EXP: {
 
5705
  } break;
5706
  case GGML_OP_CROSS_ENTROPY_LOSS: {
5707
  if (src0_needs_grads) {
5708
+ ggml_add_or_set(ctx, cgraph, isrc0, ggml_cross_entropy_loss_back(ctx, grad, src0, src1));
5709
  }
5710
  GGML_ASSERT(!src1_needs_grads && "backward pass for labels not implemented");
5711
  } break;
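Taken together, the backward wiring in ggml_compute_backward now passes the incoming gradient as the first tensor argument of every *_back constructor touched here (ggml_rms_norm_back, ggml_soft_max_ext_back, ggml_im2col_back, ggml_silu_back, ggml_cross_entropy_loss_back), and the softmax backward re-reads the scale/max_bias that the forward op stored in op_params. A minimal usage sketch of the renamed softmax backward op, assuming a ggml_context ctx, a gradient tensor grad, the forward softmax output out, and the scale/max_bias used in the forward call are already in scope:

    // d(loss)/d(softmax input); same shape as `out`
    struct ggml_tensor * dx = ggml_soft_max_ext_back(ctx, grad, out, scale, max_bias);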