JohannesGaessler ggerganov slaren committed
Commit 5c178b0 · 1 Parent(s): 0bb7364

ggml/examples: add backend support for numerical optimization (ggml/949)


* CUDA eval works

* stochastic gradient descent op

* Adam except decay

* CUDA CROSS_ENTROPY_LOSS_BACK

* CUDA mnist-fc training works

* backend CLI arg

* refactor gguf load

* remove sched from opt_step_adam

* implement weight decay (decoupled regularization)

* extra call to add optimizer

* initialize gradients with ggml_graph_reset

* gradient accumulation

* increment iter per eval instead of epoch

* adjust backend interfaces

* fix ggml_graph_reset without backend

* fix ggml graph export/import

* fixup

* rename

* revert ggml_opt changes

* more general CUDA repeat_back

* update documentation, fix CNN

* validation split

* add clarifying comment

* optimize PyTorch training

* adjust buffer size, thread count

* fix 0.0f validation split

* Update examples/mnist/mnist-common.cpp

Co-authored-by: Georgi Gerganov <[email protected]>

* fix gradient accumulation

* tensor flag for accumulators -> tensor hash set

* Update include/ggml.h

Co-authored-by: slaren <[email protected]>

* Update tests/test-backend-ops.cpp

Co-authored-by: slaren <[email protected]>

* Update tests/test-backend-ops.cpp

Co-authored-by: slaren <[email protected]>

* fix test prints

* Update src/ggml-backend.c

Co-authored-by: Georgi Gerganov <[email protected]>

* better CUDA support for noncontiguous out_prod

* add comment

---------

Co-authored-by: Georgi Gerganov <[email protected]>
Co-authored-by: slaren <[email protected]>

ggml/include/ggml-backend.h CHANGED
@@ -66,6 +66,7 @@ extern "C" {
     // "offset" refers to the offset of the tensor data for setting/getting data
     GGML_API GGML_CALL void ggml_backend_tensor_set( struct ggml_tensor * tensor, const void * data, size_t offset, size_t size);
     GGML_API GGML_CALL void ggml_backend_tensor_get(const struct ggml_tensor * tensor, void * data, size_t offset, size_t size);
+    GGML_API GGML_CALL void ggml_backend_tensor_memset( struct ggml_tensor * tensor, uint8_t value, size_t offset, size_t size);

     GGML_API void ggml_backend_synchronize(ggml_backend_t backend);

@@ -122,7 +123,7 @@ extern "C" {
     // The backend registry is a registry of all the available backends, and allows initializing backends in a generic way

     GGML_API size_t ggml_backend_reg_get_count(void);
-    GGML_API size_t ggml_backend_reg_find_by_name(const char * name);
+    GGML_API size_t ggml_backend_reg_find_by_name(const char * name); // returns index of backend with name, or SIZE_MAX if not found
     GGML_API ggml_backend_t ggml_backend_reg_init_backend_from_str(const char * backend_str); // str is backend_name:params (params is optional)
     GGML_API const char * ggml_backend_reg_get_name(size_t i);
     GGML_API ggml_backend_t ggml_backend_reg_init_backend(size_t i, const char * params); // params is backend-specific
ggml/include/ggml.h CHANGED
@@ -533,6 +533,7 @@ extern "C" {

         GGML_OP_CROSS_ENTROPY_LOSS,
         GGML_OP_CROSS_ENTROPY_LOSS_BACK,
+        GGML_OP_OPT_STEP_ADAMW,

         GGML_OP_COUNT,
     };
@@ -569,10 +570,12 @@ extern "C" {
         GGML_LOG_LEVEL_DEBUG = 5
     };

+    // this tensor...
     enum ggml_tensor_flag {
-        GGML_TENSOR_FLAG_INPUT  = 1,
-        GGML_TENSOR_FLAG_OUTPUT = 2,
-        GGML_TENSOR_FLAG_PARAM  = 4,
+        GGML_TENSOR_FLAG_INPUT  = 1, // ...is an input for the GGML compute graph
+        GGML_TENSOR_FLAG_OUTPUT = 2, // ...is an output for the GGML compute graph
+        GGML_TENSOR_FLAG_PARAM  = 4, // ...contains trainable parameters
+        GGML_TENSOR_FLAG_LOSS   = 8, // ...defines loss for numerical optimization (multiple loss tensors add up)
     };

     // ggml object
@@ -2080,17 +2083,38 @@ extern "C" {
             struct ggml_tensor * b,
             struct ggml_tensor * c);

+    // AdamW optimizer step
+    // Paper: https://arxiv.org/pdf/1711.05101v3.pdf
+    // PyTorch: https://pytorch.org/docs/stable/generated/torch.optim.AdamW.html
+    GGML_API struct ggml_tensor * ggml_opt_step_adamw(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            float                 alpha,
+            float                 beta1,
+            float                 beta2,
+            float                 eps,
+            float                 wd); // weight decay
+
     //
     // automatic differentiation
     //

-    GGML_API void ggml_set_param(
-            struct ggml_context * ctx,
-            struct ggml_tensor  * tensor);
+    GGML_API void ggml_set_param(struct ggml_context * ctx, struct ggml_tensor * tensor);
+    GGML_API void ggml_set_loss(struct ggml_tensor * tensor);


     GGML_API void ggml_build_forward_expand (struct ggml_cgraph * cgraph, struct ggml_tensor * tensor);
-    GGML_API void ggml_build_backward_expand(struct ggml_context * ctx, struct ggml_cgraph * gf, struct ggml_cgraph * gb, bool keep);
+    GGML_API void ggml_build_backward_expand(struct ggml_context * ctx, struct ggml_cgraph * gf, struct ggml_cgraph * gb, bool accumulate, bool keep);
+
+    GGML_API void ggml_build_opt_adamw(
+            struct ggml_context * ctx,
+            struct ggml_cgraph  * gf,
+            struct ggml_cgraph  * gb,
+            float                 alpha,
+            float                 beta1,
+            float                 beta2,
+            float                 eps,
+            float                 wd); // weight decay

     // graph allocation in a context
     GGML_API struct ggml_cgraph * ggml_new_graph  (struct ggml_context * ctx); // size = GGML_DEFAULT_GRAPH_SIZE, grads = false
@@ -2098,7 +2122,7 @@ extern "C" {
     GGML_API struct ggml_cgraph * ggml_graph_dup  (struct ggml_context * ctx, struct ggml_cgraph * cgraph);
     GGML_API struct ggml_cgraph   ggml_graph_view (struct ggml_cgraph * cgraph, int i0, int i1);
     GGML_API void                 ggml_graph_cpy  (struct ggml_cgraph * src, struct ggml_cgraph * dst);
-    GGML_API void                 ggml_graph_reset(struct ggml_cgraph * cgraph); // zero grads
+    GGML_API void                 ggml_graph_reset(struct ggml_cgraph * cgraph); // set regular grads + optimizer momenta to 0, set loss grad to 1
     GGML_API void                 ggml_graph_clear(struct ggml_cgraph * cgraph);

     GGML_API size_t ggml_graph_overhead(void);
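
A minimal CPU-only sketch of how the declarations above might be wired together for one training step. This is hedged: the shapes, hyperparameters, single-context allocation, and the use of ggml_graph_compute_with_ctx are illustrative assumptions, not part of this commit; the updated MNIST example in ggml/examples is the authoritative usage.

    #include "ggml.h"

    int main(void) {
        struct ggml_init_params params = {
            /*.mem_size   =*/ 128*1024*1024,
            /*.mem_buffer =*/ NULL,
            /*.no_alloc   =*/ false,
        };
        struct ggml_context * ctx = ggml_init(params);

        // toy model: a single trainable weight matrix (shapes are placeholders)
        struct ggml_tensor * x      = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 784, 16); // batch of inputs
        struct ggml_tensor * labels = ggml_new_tensor_2d(ctx, GGML_TYPE_F32,  10, 16); // one-hot labels
        struct ggml_tensor * w      = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 784, 10);
        ggml_set_param(ctx, w); // mark as trainable (GGML_TENSOR_FLAG_PARAM)

        struct ggml_tensor * logits = ggml_mul_mat(ctx, w, x);                      // [10, 16]
        struct ggml_tensor * loss   = ggml_cross_entropy_loss(ctx, logits, labels);
        ggml_set_loss(loss); // mark as loss (GGML_TENSOR_FLAG_LOSS)

        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx, GGML_DEFAULT_GRAPH_SIZE, /*grads =*/ true);
        ggml_build_forward_expand(gf, loss);

        struct ggml_cgraph * gb = ggml_graph_dup(ctx, gf);
        ggml_build_backward_expand(ctx, gf, gb, /*accumulate =*/ false, /*keep =*/ false);
        ggml_build_opt_adamw(ctx, gf, gb, /*alpha =*/ 1e-3f, /*beta1 =*/ 0.9f, /*beta2 =*/ 0.999f, /*eps =*/ 1e-8f, /*wd =*/ 0.0f);

        ggml_graph_reset(gb); // zero grads + optimizer momenta, set the loss gradient to 1

        // fill x, labels, and w with real data here, then evaluate once per batch:
        ggml_graph_compute_with_ctx(ctx, gb, /*n_threads =*/ 4);

        ggml_free(ctx);
        return 0;
    }

Evaluating gb runs the forward pass, the backward pass, and the appended GGML_OP_OPT_STEP_ADAMW nodes in one graph computation.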
ggml/src/ggml-backend-impl.h CHANGED
@@ -38,15 +38,16 @@ extern "C" {
     typedef void * ggml_backend_buffer_context_t;

     struct ggml_backend_buffer_i {
-        const char * (*GGML_CALL get_name)     (ggml_backend_buffer_t buffer);
-        void         (*GGML_CALL free_buffer)  (ggml_backend_buffer_t buffer);
-        void *       (*GGML_CALL get_base)     (ggml_backend_buffer_t buffer);
-        void         (*GGML_CALL init_tensor)  (ggml_backend_buffer_t buffer, struct ggml_tensor * tensor);
-        void         (*GGML_CALL set_tensor)   (ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size);
-        void         (*GGML_CALL get_tensor)   (ggml_backend_buffer_t buffer, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size);
-        bool         (*GGML_CALL cpy_tensor)   (ggml_backend_buffer_t buffer, const struct ggml_tensor * src, struct ggml_tensor * dst); // dst is in the buffer, src may be in any buffer
-        void         (*GGML_CALL clear)        (ggml_backend_buffer_t buffer, uint8_t value);
-        void         (*GGML_CALL reset)        (ggml_backend_buffer_t buffer); // reset any internal state due to tensor initialization, such as tensor extras
+        const char * (*GGML_CALL get_name)      (ggml_backend_buffer_t buffer);
+        void         (*GGML_CALL free_buffer)   (ggml_backend_buffer_t buffer);
+        void *       (*GGML_CALL get_base)      (ggml_backend_buffer_t buffer);
+        void         (*GGML_CALL init_tensor)   (ggml_backend_buffer_t buffer, struct ggml_tensor * tensor);
+        void         (*GGML_CALL memset_tensor) (ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, uint8_t value, size_t offset, size_t size);
+        void         (*GGML_CALL set_tensor)    (ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size);
+        void         (*GGML_CALL get_tensor)    (ggml_backend_buffer_t buffer, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size);
+        bool         (*GGML_CALL cpy_tensor)    (ggml_backend_buffer_t buffer, const struct ggml_tensor * src, struct ggml_tensor * dst); // dst is in the buffer, src may be in any buffer
+        void         (*GGML_CALL clear)         (ggml_backend_buffer_t buffer, uint8_t value);
+        void         (*GGML_CALL reset)         (ggml_backend_buffer_t buffer); // reset any internal state due to tensor initialization, such as tensor extras
     };

     struct ggml_backend_buffer {
ggml/src/ggml-backend.c CHANGED
@@ -246,6 +246,22 @@ GGML_CALL void ggml_backend_tensor_get(const struct ggml_tensor * tensor, void *
     buf->iface.get_tensor(buf, tensor, data, offset, size);
 }

+GGML_API GGML_CALL void ggml_backend_tensor_memset(struct ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) {
+    ggml_backend_buffer_t buf = tensor->view_src ? tensor->view_src->buffer : tensor->buffer;
+
+    GGML_ASSERT(buf != NULL && "tensor buffer not set");
+    GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
+    GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor write out of bounds");
+
+    if (!size) {
+        return;
+    }
+
+    GGML_ASSERT(buf->iface.memset_tensor != NULL && "memset not supported by backend buffer");
+
+    buf->iface.memset_tensor(buf, tensor, value, offset, size);
+}
+
 void ggml_backend_synchronize(ggml_backend_t backend) {
     if (backend->iface.synchronize == NULL) {
         return;
@@ -569,6 +585,12 @@ GGML_CALL static void ggml_backend_cpu_buffer_free_buffer(ggml_backend_buffer_t
     free(buffer->context);
 }

+GGML_CALL static void ggml_backend_cpu_buffer_memset_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) {
+    memset((char *)tensor->data + offset, value, size);
+
+    GGML_UNUSED(buffer);
+}
+
 GGML_CALL static void ggml_backend_cpu_buffer_set_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
     memcpy((char *)tensor->data + offset, data, size);

@@ -600,6 +622,7 @@ static struct ggml_backend_buffer_i cpu_backend_buffer_i = {
     /* .free_buffer   = */ ggml_backend_cpu_buffer_free_buffer,
     /* .get_base      = */ ggml_backend_cpu_buffer_get_base,
     /* .init_tensor   = */ NULL, // no initialization required
+    /* .memset_tensor = */ ggml_backend_cpu_buffer_memset_tensor,
     /* .set_tensor    = */ ggml_backend_cpu_buffer_set_tensor,
     /* .get_tensor    = */ ggml_backend_cpu_buffer_get_tensor,
     /* .cpy_tensor    = */ ggml_backend_cpu_buffer_cpy_tensor,
@@ -613,6 +636,7 @@ static struct ggml_backend_buffer_i cpu_backend_buffer_i_from_ptr = {
     /* .free_buffer   = */ NULL, // ptr is not owned by the buffer, so it does not need to be freed
     /* .get_base      = */ ggml_backend_cpu_buffer_get_base,
     /* .init_tensor   = */ NULL, // no initialization required
+    /* .memset_tensor = */ ggml_backend_cpu_buffer_memset_tensor,
     /* .set_tensor    = */ ggml_backend_cpu_buffer_set_tensor,
     /* .get_tensor    = */ ggml_backend_cpu_buffer_get_tensor,
     /* .cpy_tensor    = */ ggml_backend_cpu_buffer_cpy_tensor,
@@ -980,6 +1004,7 @@ static struct ggml_backend_buffer_i ggml_backend_multi_buffer_context_interface(
     /* .free_buffer   = */ ggml_backend_multi_buffer_free_buffer,
     /* .get_base      = */ NULL,
     /* .init_tensor   = */ NULL,
+    /* .memset_tensor = */ NULL,
     /* .set_tensor    = */ NULL,
     /* .get_tensor    = */ NULL,
     /* .cpy_tensor    = */ NULL,
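
With the tensor-level memset in place, a backend-resident tensor (for example a gradient accumulator) can be cleared without staging a zero-filled host buffer. A hedged helper sketch; the tensor t is assumed to already be allocated in a backend buffer whose interface implements memset_tensor (CPU and CUDA after this change):

    #include "ggml.h"
    #include "ggml-backend.h"

    // Zero an entire backend tensor in place, e.g. before accumulating gradients.
    static void clear_tensor(struct ggml_tensor * t) {
        ggml_backend_tensor_memset(t, 0, 0, ggml_nbytes(t));
    }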
ggml/src/ggml-cann.cpp CHANGED
@@ -1036,6 +1036,7 @@ static ggml_backend_buffer_i ggml_backend_cann_buffer_interface = {
     /* .free_buffer   = */ ggml_backend_cann_buffer_free_buffer,
     /* .get_base      = */ ggml_backend_cann_buffer_get_base,
     /* .init_tensor   = */ ggml_backend_cann_buffer_init_tensor,
+    /* .memset_tensor = */ NULL,
     /* .set_tensor    = */ ggml_backend_cann_buffer_set_tensor,
     /* .get_tensor    = */ ggml_backend_cann_buffer_get_tensor,
     /* .cpy_tensor    = */ ggml_backend_cann_buffer_cpy_tensor,
ggml/src/ggml-cuda.cu CHANGED
@@ -21,6 +21,8 @@
 #include "ggml-cuda/mmq.cuh"
 #include "ggml-cuda/mmvq.cuh"
 #include "ggml-cuda/norm.cuh"
+#include "ggml-cuda/opt-step-adamw.cuh"
+#include "ggml-cuda/out-prod.cuh"
 #include "ggml-cuda/pad.cuh"
 #include "ggml-cuda/pool2d.cuh"
 #include "ggml-cuda/quantize.cuh"
@@ -493,6 +495,14 @@ GGML_CALL static void ggml_backend_cuda_buffer_init_tensor(ggml_backend_buffer_t
     }
 }

+GGML_CALL static void ggml_backend_cuda_buffer_memset_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) {
+    ggml_backend_cuda_buffer_context * ctx = (ggml_backend_cuda_buffer_context *)buffer->context;
+
+    ggml_cuda_set_device(ctx->device);
+    CUDA_CHECK(cudaMemsetAsync((char *)tensor->data + offset, value, size, cudaStreamPerThread));
+    CUDA_CHECK(cudaStreamSynchronize(cudaStreamPerThread));
+}
+
 GGML_CALL static void ggml_backend_cuda_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
     ggml_backend_cuda_buffer_context * ctx = (ggml_backend_cuda_buffer_context *)buffer->context;

@@ -544,6 +554,7 @@ static ggml_backend_buffer_i ggml_backend_cuda_buffer_interface = {
     /* .free_buffer   = */ ggml_backend_cuda_buffer_free_buffer,
     /* .get_base      = */ ggml_backend_cuda_buffer_get_base,
     /* .init_tensor   = */ ggml_backend_cuda_buffer_init_tensor,
+    /* .memset_tensor = */ ggml_backend_cuda_buffer_memset_tensor,
     /* .set_tensor    = */ ggml_backend_cuda_buffer_set_tensor,
     /* .get_tensor    = */ ggml_backend_cuda_buffer_get_tensor,
     /* .cpy_tensor    = */ ggml_backend_cuda_buffer_cpy_tensor,
@@ -860,6 +871,7 @@ static struct ggml_backend_buffer_i ggml_backend_cuda_split_buffer_interface = {
     /* .free_buffer   = */ ggml_backend_cuda_split_buffer_free_buffer,
     /* .get_base      = */ ggml_backend_cuda_split_buffer_get_base,
     /* .init_tensor   = */ ggml_backend_cuda_split_buffer_init_tensor,
+    /* .memset_tensor = */ NULL,
     /* .set_tensor    = */ ggml_backend_cuda_split_buffer_set_tensor,
     /* .get_tensor    = */ ggml_backend_cuda_split_buffer_get_tensor,
     /* .cpy_tensor    = */ NULL,
@@ -2168,6 +2180,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
         case GGML_OP_REPEAT:
            ggml_cuda_op_repeat(ctx, dst);
            break;
+        case GGML_OP_REPEAT_BACK:
+            ggml_cuda_op_repeat_back(ctx, dst);
+            break;
         case GGML_OP_GET_ROWS:
            ggml_cuda_op_get_rows(ctx, dst);
            break;
@@ -2201,6 +2216,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
                case GGML_UNARY_OP_NEG:
                    ggml_cuda_op_neg(ctx, dst);
                    break;
+                case GGML_UNARY_OP_STEP:
+                    ggml_cuda_op_step(ctx, dst);
+                    break;
                case GGML_UNARY_OP_GELU:
                    ggml_cuda_op_gelu(ctx, dst);
                    break;
@@ -2267,6 +2285,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
         case GGML_OP_MUL_MAT_ID:
            ggml_cuda_mul_mat_id(ctx, dst);
            break;
+        case GGML_OP_OUT_PROD:
+            ggml_cuda_out_prod(ctx, dst);
+            break;
         case GGML_OP_SCALE:
            ggml_cuda_op_scale(ctx, dst);
            break;
@@ -2324,6 +2345,12 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
         case GGML_OP_CROSS_ENTROPY_LOSS:
            ggml_cuda_cross_entropy_loss(ctx, dst);
            break;
+        case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
+            ggml_cuda_cross_entropy_loss_back(ctx, dst);
+            break;
+        case GGML_OP_OPT_STEP_ADAMW:
+            ggml_cuda_opt_step_adamw(ctx, dst);
+            break;
         default:
            return false;
     }
@@ -2757,6 +2784,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
         case GGML_OP_UNARY:
            switch (ggml_get_unary_op(op)) {
                case GGML_UNARY_OP_NEG:
+                case GGML_UNARY_OP_STEP:
                case GGML_UNARY_OP_GELU:
                case GGML_UNARY_OP_SILU:
                case GGML_UNARY_OP_RELU:
@@ -2809,6 +2837,8 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
                    return false;
                }
            } break;
+        case GGML_OP_OUT_PROD:
+            return op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32 && op->ne[2] == 1 && op->ne[3] == 1;
         case GGML_OP_GET_ROWS:
            {
                switch (op->src[0]->type) {
@@ -2865,6 +2895,12 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
            } break;
         case GGML_OP_DUP:
         case GGML_OP_REPEAT:
+            {
+                ggml_type src0_type = op->src[0]->type;
+                return src0_type != GGML_TYPE_I32 && src0_type != GGML_TYPE_I16;
+            } break;
+        case GGML_OP_REPEAT_BACK:
+            return op->type == GGML_TYPE_F32 && op->src[0]->ne[3] == 1;
         case GGML_OP_CONCAT:
            {
                ggml_type src0_type = op->src[0]->type;
@@ -2931,9 +2967,11 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
            }
            return ggml_cuda_info().devices[cuda_ctx->device].cc >= CC_VOLTA &&
                op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16;
+#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
         case GGML_OP_CROSS_ENTROPY_LOSS:
+        case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
+        case GGML_OP_OPT_STEP_ADAMW:
            return true;
-#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
         default:
            return false;
     }
ggml/src/ggml-cuda/binbcast.cu CHANGED
@@ -1,4 +1,5 @@
 #include "binbcast.cuh"
+#include <cstdint>

 static __device__ __forceinline__ float op_repeat(const float a, const float b) {
     return b;
@@ -90,6 +91,30 @@ static __global__ void k_bin_bcast_unravel(const src0_t * src0, const src1_t * s
     dst_row[i0] = (dst_t)bin_op(src0 ? (float)src0_row[i0] : 0.0f, (float)src1_row[i10]);
 }

+template <typename T>
+static __global__ void k_repeat_back(
+    const T * __restrict__ src, T * __restrict__ dst, const int64_t ne00, const int64_t ne01, const int64_t ne02,
+    const int64_t ne0, const int64_t ne1, const int64_t ne2) {
+
+    const int64_t tid0 = (int64_t) blockIdx.x*blockDim.x + threadIdx.x;
+    const int64_t tid1 = (int64_t) blockIdx.y*blockDim.y + threadIdx.y;
+    const int64_t tid2 = (int64_t) blockIdx.z*blockDim.z + threadIdx.z;
+
+    if (tid0 >= ne0) {
+        return;
+    }
+
+    T sum = 0;
+    for (int64_t i2 = tid2; i2 < ne02; i2 += ne2) {
+        for (int64_t i1 = tid1; i1 < ne01; i1 += ne1) {
+            for (int64_t i0 = tid0; i0 < ne00; i0 += ne0) {
+                sum += src[i2*ne01*ne00 + i1*ne00 + i0];
+            }
+        }
+    }
+    dst[tid2*ne1*ne0 + tid1*ne0 + tid0] = sum;
+}
+
 template<float (*bin_op)(const float, const float)>
 struct bin_bcast_cuda {
     template<typename src0_t, typename src1_t, typename dst_t>
@@ -247,6 +272,16 @@ struct bin_bcast_cuda {
     }
 };

+template <typename T>
+static void repeat_back_cuda(
+    const T * src, T * dst, const int64_t ne00, const int64_t ne01, const int64_t ne02,
+    const int64_t ne0, const int64_t ne1, const int64_t ne2, cudaStream_t stream) {
+
+    const dim3 block_dims(WARP_SIZE, 1, 1);
+    const dim3 block_nums((ne0 + WARP_SIZE - 1) / WARP_SIZE, ne1, ne2);
+    k_repeat_back<T><<<block_nums, block_dims, 0, stream>>>(src, dst, ne00, ne01, ne02, ne0, ne1, ne2);
+}
+
 template<class op>
 static void ggml_cuda_op_bin_bcast(
     const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
@@ -286,3 +321,35 @@ void ggml_cuda_op_mul(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
 void ggml_cuda_op_div(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     ggml_cuda_op_bin_bcast<bin_bcast_cuda<op_div>>(dst->src[0], dst->src[1], dst, dst->src[0]->data, dst->src[1]->data, dst->data, ctx.stream());
 }
+
+void ggml_cuda_op_repeat_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+
+    GGML_ASSERT(src0->type == dst->type);
+    GGML_ASSERT(ggml_is_contiguous(src0));
+    GGML_ASSERT(ggml_is_contiguous(dst));
+    GGML_ASSERT(ggml_can_repeat(dst, src0));
+
+    cudaStream_t stream = ctx.stream();
+
+    const int64_t ne00 = src0->ne[0];
+    const int64_t ne01 = src0->ne[1];
+    const int64_t ne02 = src0->ne[2];
+    GGML_ASSERT(src0->ne[3] == 1);
+
+    const int64_t ne0 = dst->ne[0];
+    const int64_t ne1 = dst->ne[1];
+    const int64_t ne2 = dst->ne[2];
+    GGML_ASSERT(dst->ne[3] == 1);
+
+    switch (dst->type) {
+        case GGML_TYPE_F32: {
+            const float * src0_d = (const float *) src0->data;
+            float       * dst_d  = (float *) dst->data;
+            repeat_back_cuda<float>(src0_d, dst_d, ne00, ne01, ne02, ne0, ne1, ne2, stream);
+        } break;
+        default: {
+            GGML_ASSERT(false);
+        } break;
+    }
+}
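
In index form, k_repeat_back computes the adjoint of repeat: every source element that repeat would broadcast onto a given destination position is summed back into it. With dst of shape (ne0, ne1, ne2) and src of shape (ne00, ne01, ne02), this is just a restatement of the loop structure above (notation mine):

    \mathrm{dst}[t_0, t_1, t_2] =
        \sum_{\substack{i_2 < ne_{02} \\ i_2 \equiv t_2 \,(\mathrm{mod}\, ne_2)}}
        \sum_{\substack{i_1 < ne_{01} \\ i_1 \equiv t_1 \,(\mathrm{mod}\, ne_1)}}
        \sum_{\substack{i_0 < ne_{00} \\ i_0 \equiv t_0 \,(\mathrm{mod}\, ne_0)}}
        \mathrm{src}[i_0, i_1, i_2]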
ggml/src/ggml-cuda/binbcast.cuh CHANGED
@@ -5,3 +5,5 @@ void ggml_cuda_op_add(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 void ggml_cuda_op_sub(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 void ggml_cuda_op_mul(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 void ggml_cuda_op_div(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+void ggml_cuda_op_repeat_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
ggml/src/ggml-cuda/cross-entropy-loss.cu CHANGED
@@ -71,6 +71,32 @@ static __global__ void cross_entropy_loss_f32(const float * logits, const float
     dst[blockIdx.x] = loss;
 }

+static __global__ void cross_entropy_loss_back_f32(const float * logits, const float * labels, const float * loss, float * dst, const int nclasses) {
+    extern __shared__ float tmp[];
+
+    float maxval = -INFINITY;
+    for (int i = threadIdx.x; i < nclasses; i += WARP_SIZE) {
+        const float val = logits[blockIdx.x*nclasses + i];
+        maxval = fmaxf(maxval, val);
+        tmp[i] = val;
+    }
+    maxval = warp_reduce_max(maxval);
+
+    float sum = 0.0f;
+    for (int i = threadIdx.x; i < nclasses; i += WARP_SIZE) {
+        const float val = expf(tmp[i] - maxval);
+        sum += val;
+        tmp[i] = val;
+    }
+    sum = warp_reduce_sum(sum);
+    const float sm_scale = 1.0f/sum;
+
+    const float d_by_nrows = *loss/gridDim.x;
+    for (int i = threadIdx.x; i < nclasses; i += WARP_SIZE) {
+        dst[blockIdx.x*nclasses + i] = (tmp[i]*sm_scale - labels[blockIdx.x*nclasses + i])*d_by_nrows;
+    }
+}
+
 void ggml_cuda_cross_entropy_loss(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * src0 = dst->src[0];
     const ggml_tensor * src1 = dst->src[1];
@@ -104,3 +130,37 @@ void ggml_cuda_cross_entropy_loss(ggml_backend_cuda_context & ctx, ggml_tensor *
     // Combine results from individual blocks:
     sum_f32_cuda(pool, dst_tmp.ptr, dst_d, blocks_num.x, stream);
 }
+
+void ggml_cuda_cross_entropy_loss_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+    const ggml_tensor * opt0 = dst->src[2];
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+    GGML_ASSERT(opt0->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+    GGML_ASSERT(ggml_is_contiguous(src0));
+    GGML_ASSERT(ggml_is_contiguous(src1));
+    GGML_ASSERT(ggml_is_contiguous(opt0));
+    GGML_ASSERT(ggml_is_contiguous(dst));
+    GGML_ASSERT(ggml_are_same_shape(src0, src1));
+    GGML_ASSERT(ggml_are_same_shape(src0, dst));
+
+    const int64_t ne00  = src0->ne[0];
+    const int64_t nrows = ggml_nrows(src0);
+
+    const float * src0_d = (const float *) src0->data;
+    const float * src1_d = (const float *) src1->data;
+    const float * opt0_d = (const float *) opt0->data;
+    float       * dst_d  = (float *) dst->data;
+
+    cudaStream_t stream = ctx.stream();
+
+    const dim3 blocks_dim(WARP_SIZE, 1, 1);
+    const dim3 blocks_num(nrows, 1, 1);
+    const int shmem = ne00*sizeof(float);
+
+    cross_entropy_loss_back_f32<<<blocks_num, blocks_dim, shmem, stream>>>(src0_d, src1_d, opt0_d, dst_d, ne00);
+}
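
The backward kernel applies the usual softmax cross-entropy gradient, scaled by the incoming loss gradient d = *loss and averaged over the nrows = gridDim.x rows; this mirrors the comment in the CPU implementation (notation mine):

    \frac{\partial L}{\partial \mathrm{logits}_{r,i}} =
        \bigl(\mathrm{softmax}(\mathrm{logits}_{r})_i - \mathrm{labels}_{r,i}\bigr)\,\frac{d}{n_{\mathrm{rows}}}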
ggml/src/ggml-cuda/cross-entropy-loss.cuh CHANGED
@@ -3,3 +3,5 @@
 #define CUDA_CROSS_ENTROPY_LOSS_BLOCK_SIZE 256

 void ggml_cuda_cross_entropy_loss(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+void ggml_cuda_cross_entropy_loss_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
ggml/src/ggml-cuda/opt-step-adamw.cu ADDED
@@ -0,0 +1,80 @@
+#include "opt-step-adamw.cuh"
+
+#include <cstdint>
+
+static __global__ void opt_step_adamw_f32(
+    float * __restrict__ x, const float * __restrict__ g, float * __restrict__ g_m, float * __restrict__ g_v, const int64_t k,
+    const float alpha, const float beta1, const float beta2, const float eps, const float wd,
+    const float beta1h, const float beta2h) {
+
+    const int64_t i = (int64_t) blockIdx.x*blockDim.x + threadIdx.x;
+
+    if (i >= k) {
+        return;
+    }
+
+    const float gi  = g[i];
+    const float gmi = g_m[i]*beta1 +    gi*(1.0f - beta1);
+    const float gvi = g_v[i]*beta2 + gi*gi*(1.0f - beta2);
+
+    g_m[i] = gmi;
+    g_v[i] = gvi;
+
+    const float mh = gmi*beta1h;
+    const float vh = sqrtf(gvi*beta2h) + eps;
+
+    x[i] = x[i]*(1.0f - alpha*wd) - mh/vh;
+}
+
+static void opt_step_adamw_f32_cuda(
+    float * x, const float * g, float * g_m, float * g_v, const int64_t k,
+    const float alpha, const float beta1, const float beta2, const float eps, const float wd,
+    const float beta1h, const float beta2h, cudaStream_t stream) {
+
+    const dim3 block_dims(CUDA_OPT_STEP_ADAMW_BLOCK_SIZE, 1, 1);
+    const dim3 block_nums((k + CUDA_OPT_STEP_ADAMW_BLOCK_SIZE - 1) / CUDA_OPT_STEP_ADAMW_BLOCK_SIZE, 1, 1);
+    opt_step_adamw_f32<<<block_nums, block_dims, 0, stream>>>(x, g, g_m, g_v, k, alpha, beta1, beta2, eps, wd, beta1h, beta2h);
+}
+
+void ggml_cuda_opt_step_adamw(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0        = dst->src[0];
+    const ggml_tensor * src0_grad   = dst->src[1];
+    const ggml_tensor * src0_grad_m = dst->src[2];
+    const ggml_tensor * src0_grad_v = dst->src[3];
+
+    GGML_ASSERT(src0->type        == GGML_TYPE_F32);
+    GGML_ASSERT(src0_grad->type   == GGML_TYPE_F32);
+    GGML_ASSERT(src0_grad_m->type == GGML_TYPE_F32);
+    GGML_ASSERT(src0_grad_v->type == GGML_TYPE_F32);
+    GGML_ASSERT(ggml_is_contiguous(src0));
+    GGML_ASSERT(ggml_is_contiguous(src0_grad));
+    GGML_ASSERT(ggml_is_contiguous(src0_grad_m));
+    GGML_ASSERT(ggml_is_contiguous(src0_grad_v));
+    GGML_ASSERT(ggml_are_same_shape(src0, src0_grad));
+    GGML_ASSERT(ggml_are_same_shape(src0, src0_grad_m));
+    GGML_ASSERT(ggml_are_same_shape(src0, src0_grad_v));
+
+    float       * src0_d        = (float       *) src0->data;
+    const float * src0_grad_d   = (const float *) src0_grad->data;
+    float       * src0_grad_m_d = (float       *) src0_grad_m->data;
+    float       * src0_grad_v_d = (float       *) src0_grad_v->data;
+
+    cudaStream_t stream = ctx.stream();
+
+    const int64_t ne = ggml_nelements(src0);
+
+    int64_t iter;  memcpy(&iter,  &dst->op_params[0], sizeof(int64_t));
+    float   alpha; memcpy(&alpha, &dst->op_params[2], sizeof(float));
+    float   beta1; memcpy(&beta1, &dst->op_params[3], sizeof(float));
+    float   beta2; memcpy(&beta2, &dst->op_params[4], sizeof(float));
+    float   eps;   memcpy(&eps,   &dst->op_params[5], sizeof(float));
+    float   wd;    memcpy(&wd,    &dst->op_params[6], sizeof(float));
+
+    const float beta1h = alpha/(1.0f - powf(beta1, iter));
+    const float beta2h =  1.0f/(1.0f - powf(beta2, iter));
+
+    opt_step_adamw_f32_cuda(src0_d, src0_grad_d, src0_grad_m_d, src0_grad_v_d, ne, alpha, beta1, beta2, eps, wd, beta1h, beta2h, stream);
+
+    iter++;
+    memcpy(&dst->op_params[0], &iter, sizeof(int64_t));
+}
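
For reference, the kernel performs the standard AdamW update with the bias corrections folded into beta1h and beta2h; t is the iteration counter stored in op_params[0] and incremented after each step:

    \begin{aligned}
    m_t &= \beta_1\, m_{t-1} + (1-\beta_1)\, g_t \\
    v_t &= \beta_2\, v_{t-1} + (1-\beta_2)\, g_t^2 \\
    \hat m_t &= m_t / (1-\beta_1^t), \qquad \hat v_t = v_t / (1-\beta_2^t) \\
    x_t &= x_{t-1}\,(1 - \alpha\,\mathrm{wd}) - \alpha\,\hat m_t / (\sqrt{\hat v_t} + \epsilon)
    \end{aligned}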
ggml/src/ggml-cuda/opt-step-adamw.cuh ADDED
@@ -0,0 +1,5 @@
+#include "common.cuh"
+
+#define CUDA_OPT_STEP_ADAMW_BLOCK_SIZE 256
+
+void ggml_cuda_opt_step_adamw(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
ggml/src/ggml-cuda/out-prod.cu ADDED
@@ -0,0 +1,52 @@
+#include "out-prod.cuh"
+#include "vendors/cuda.h"
+
+#include <cstdint>
+
+void ggml_cuda_out_prod(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+
+    GGML_TENSOR_BINARY_OP_LOCALS
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+    GGML_ASSERT(dst->type  == GGML_TYPE_F32);
+    GGML_ASSERT(ggml_is_contiguous(src0));
+    GGML_ASSERT(ggml_is_contiguous(dst));
+
+    GGML_ASSERT(ne01 == ne11);
+    GGML_ASSERT(ne0  == ne00);
+    GGML_ASSERT(ne1  == ne10);
+
+    GGML_ASSERT(ne2 == src0->ne[2]);
+    GGML_ASSERT(ne2 == src1->ne[2]);
+    GGML_ASSERT(ne3 == src0->ne[3]);
+    GGML_ASSERT(ne3 == src1->ne[3]);
+
+    const float * src0_d = (const float *) src0->data;
+    const float * src1_d = (const float *) src1->data;
+    float       * dst_d  = (float *) dst->data;
+
+    cudaStream_t   stream = ctx.stream();
+    cublasHandle_t handle = ctx.cublas_handle();
+
+    const float alpha = 1.0f;
+    const float beta  = 0.0f;
+
+    GGML_ASSERT(ne2 == 1);
+    GGML_ASSERT(ne3 == 1);
+    CUBLAS_CHECK(cublasSetStream(handle, stream));
+
+    const bool src1_T = ggml_is_transposed(src1);
+    const cublasOperation_t src1_cublas_op =  src1_T ? CUBLAS_OP_N : CUBLAS_OP_T;
+    const int64_t           ldb            = (src1_T ?        nb10 :        nb11) /  sizeof(float);
+    GGML_ASSERT(                             (src1_T ?        nb11 :        nb10) == sizeof(float));
+
+    CUBLAS_CHECK(
+        cublasSgemm(handle, CUBLAS_OP_N, src1_cublas_op,
+                ne0, ne1, ne01,
+                &alpha, src0_d, ne00,
+                        src1_d, ldb,
+                &beta,  dst_d,  ne0));
+}
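
In matrix terms, the single SGEMM above computes the following for the currently supported 2D case (ne2 == ne3 == 1); when src1 is stored transposed, the CUBLAS_OP_N branch simply reads it with the matching leading dimension (a restatement of the call, notation mine):

    \mathrm{dst}_{i,j} = \sum_{k=0}^{ne_{01}-1} \mathrm{src0}_{i,k}\,\mathrm{src1}_{j,k},
        \qquad 0 \le i < ne_0,\; 0 \le j < ne_1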
ggml/src/ggml-cuda/out-prod.cuh ADDED
@@ -0,0 +1,3 @@
+#include "common.cuh"
+
+void ggml_cuda_out_prod(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
ggml/src/ggml-cuda/unary.cu CHANGED
@@ -10,6 +10,16 @@ static __global__ void neg_f32(const float * x, float * dst, const int k) {
     dst[i] = -x[i];
 }

+static __global__ void step_f32(const float * x, float * dst, const int k) {
+    const int i = blockDim.x*blockIdx.x + threadIdx.x;
+
+    if (i >= k) {
+        return;
+    }
+
+    dst[i] = x[i] > 0.0f;
+}
+
 static __global__ void gelu_f32(const float * x, float * dst, const int k) {
     const float GELU_COEF_A    = 0.044715f;
     const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
@@ -134,6 +144,11 @@ static void neg_f32_cuda(const float * x, float * dst, const int k, cudaStream_t
     neg_f32<<<num_blocks, CUDA_NEG_BLOCK_SIZE, 0, stream>>>(x, dst, k);
 }

+static void step_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
+    const int num_blocks = (k + CUDA_STEP_BLOCK_SIZE - 1) / CUDA_STEP_BLOCK_SIZE;
+    step_f32<<<num_blocks, CUDA_STEP_BLOCK_SIZE, 0, stream>>>(x, dst, k);
+}
+
 static void gelu_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
     const int num_blocks = (k + CUDA_GELU_BLOCK_SIZE - 1) / CUDA_GELU_BLOCK_SIZE;
     gelu_f32<<<num_blocks, CUDA_GELU_BLOCK_SIZE, 0, stream>>>(x, dst, k);
@@ -213,6 +228,20 @@ void ggml_cuda_op_neg(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     neg_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
 }

+void ggml_cuda_op_step(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+    const float * src0_d = (const float *)src0->data;
+    float * dst_d = (float *)dst->data;
+    cudaStream_t stream = ctx.stream();
+
+    GGML_ASSERT(ggml_is_contiguous(src0));
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+    step_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
+}
+
 void ggml_cuda_op_gelu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * src0 = dst->src[0];
     const float * src0_d = (const float *)src0->data;
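
step_f32 is the Heaviside step function; it is exposed as a CUDA unary op presumably so that backward graphs which use ggml_step (for example the ReLU gradient) can stay on the GPU:

    \mathrm{step}(x) = \begin{cases} 1 & x > 0 \\ 0 & x \le 0 \end{cases}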
ggml/src/ggml-cuda/unary.cuh CHANGED
@@ -1,6 +1,7 @@
 #include "common.cuh"

 #define CUDA_NEG_BLOCK_SIZE 256
+#define CUDA_STEP_BLOCK_SIZE 256
 #define CUDA_GELU_BLOCK_SIZE 256
 #define CUDA_SILU_BLOCK_SIZE 256
 #define CUDA_TANH_BLOCK_SIZE 256
@@ -15,6 +16,8 @@

 void ggml_cuda_op_neg(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

+void ggml_cuda_op_step(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
 void ggml_cuda_op_gelu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

 void ggml_cuda_op_silu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
ggml/src/ggml-kompute.cpp CHANGED
@@ -1872,6 +1872,7 @@ static ggml_backend_buffer_i ggml_backend_kompute_buffer_i = {
     /* .free_buffer   = */ ggml_backend_kompute_buffer_free_buffer,
     /* .get_base      = */ ggml_backend_kompute_buffer_get_base,
     /* .init_tensor   = */ NULL,
+    /* .memset_tensor = */ NULL,
     /* .set_tensor    = */ ggml_backend_kompute_buffer_set_tensor,
     /* .get_tensor    = */ ggml_backend_kompute_buffer_get_tensor,
     /* .cpy_tensor    = */ NULL,
ggml/src/ggml-metal.m CHANGED
@@ -3165,6 +3165,7 @@ static struct ggml_backend_buffer_i ggml_backend_metal_buffer_i = {
     /* .free_buffer   = */ ggml_backend_metal_buffer_free_buffer,
     /* .get_base      = */ ggml_backend_metal_buffer_get_base,
     /* .init_tensor   = */ NULL,
+    /* .memset_tensor = */ NULL,
     /* .set_tensor    = */ ggml_backend_metal_buffer_set_tensor,
     /* .get_tensor    = */ ggml_backend_metal_buffer_get_tensor,
     /* .cpy_tensor    = */ ggml_backend_metal_buffer_cpy_tensor,
ggml/src/ggml-rpc.cpp CHANGED
@@ -469,6 +469,7 @@ static ggml_backend_buffer_i ggml_backend_rpc_buffer_interface = {
     /* .free_buffer   = */ ggml_backend_rpc_buffer_free_buffer,
     /* .get_base      = */ ggml_backend_rpc_buffer_get_base,
     /* .init_tensor   = */ ggml_backend_rpc_buffer_init_tensor,
+    /* .memset_tensor = */ NULL,
     /* .set_tensor    = */ ggml_backend_rpc_buffer_set_tensor,
     /* .get_tensor    = */ ggml_backend_rpc_buffer_get_tensor,
     /* .cpy_tensor    = */ ggml_backend_rpc_buffer_cpy_tensor,
ggml/src/ggml-sycl.cpp CHANGED
@@ -4318,6 +4318,7 @@ static struct ggml_backend_buffer_i ggml_backend_sycl_buffer_interface = {
     /* .free_buffer   = */ ggml_backend_sycl_buffer_free_buffer,
     /* .get_base      = */ ggml_backend_sycl_buffer_get_base,
     /* .init_tensor   = */ ggml_backend_sycl_buffer_init_tensor,
+    /* .memset_tensor = */ NULL,
     /* .set_tensor    = */ ggml_backend_sycl_buffer_set_tensor,
     /* .get_tensor    = */ ggml_backend_sycl_buffer_get_tensor,
     /* .cpy_tensor    = */ ggml_backend_sycl_buffer_cpy_tensor,
ggml/src/ggml-vulkan.cpp CHANGED
@@ -6221,6 +6221,7 @@ static ggml_backend_buffer_i ggml_backend_vk_buffer_interface = {
     /* .free_buffer   = */ ggml_backend_vk_buffer_free_buffer,
     /* .get_base      = */ ggml_backend_vk_buffer_get_base,
     /* .init_tensor   = */ ggml_backend_vk_buffer_init_tensor,
+    /* .memset_tensor = */ NULL,
     /* .set_tensor    = */ ggml_backend_vk_buffer_set_tensor,
     /* .get_tensor    = */ ggml_backend_vk_buffer_get_tensor,
     /* .cpy_tensor    = */ ggml_backend_vk_buffer_cpy_tensor,
ggml/src/ggml.c CHANGED
@@ -1,6 +1,7 @@
 #define _CRT_SECURE_NO_DEPRECATE // Disables ridiculous "unsafe" warnings on Windows
 #define _USE_MATH_DEFINES // For M_PI on MSVC

 #include "ggml-impl.h"
 #include "ggml-quants.h"
 #include "ggml.h"
@@ -2977,9 +2978,10 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {

     "CROSS_ENTROPY_LOSS",
     "CROSS_ENTROPY_LOSS_BACK",
 };

-static_assert(GGML_OP_COUNT == 79, "GGML_OP_COUNT != 79");

 static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "none",
@@ -3070,9 +3072,10 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {

     "cross_entropy_loss(x,y)",
     "cross_entropy_loss_back(x,y)",
 };

-static_assert(GGML_OP_COUNT == 79, "GGML_OP_COUNT != 79");

 static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");

@@ -4079,7 +4082,11 @@ static void ggml_set_op_params_f32(struct ggml_tensor * tensor, uint32_t i, floa
 }

 struct ggml_tensor * ggml_set_zero(struct ggml_tensor * tensor) {
-    memset(tensor->data, 0, ggml_nbytes(tensor));
     return tensor;
 }

@@ -8305,11 +8312,46 @@ struct ggml_tensor * ggml_cross_entropy_loss_back(
     return result;
 }

-////////////////////////////////////////////////////////////////////////////////

-void ggml_set_param(
         struct ggml_context * ctx,
-        struct ggml_tensor  * tensor) {
     tensor->flags |= GGML_TENSOR_FLAG_PARAM;

     GGML_ASSERT(tensor->grad == NULL);
@@ -8317,6 +8359,13 @@ void ggml_set_param(
     ggml_format_name(tensor->grad, "%s (grad)", tensor->name);
 }

 // ggml_compute_forward_dup

 static void ggml_compute_forward_dup_same_cont(
@@ -17391,7 +17440,7 @@ static void ggml_compute_forward_cross_entropy_loss_back_f32(
     const int64_t ir0 = dr*ith;
     const int64_t ir1 = MIN(ir0 + dr, nr);

-    float * d = (float *) opt0->data;

     for (int64_t i1 = ir0; i1 < ir1; i1++) {
         float * ds0 = (float *)((char *) dst->data + i1*dst->nb[1]);
@@ -17415,7 +17464,7 @@ static void ggml_compute_forward_cross_entropy_loss_back_f32(

         // grad(src0) = (softmax(src0) - src1) * grad(cross_entropy_loss(src0, src1)) / nr
         ggml_vec_sub_f32(nc, ds0, ds0, s1);
-        ggml_vec_scale_f32(nc, ds0, d[0] / (float) nr);

 #ifndef NDEBUG
         for (int i = 0; i < nc; ++i) {
@@ -17444,6 +17493,94 @@ static void ggml_compute_forward_cross_entropy_loss_back(
         }
 }

 /////////////////

 static void ggml_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor) {
@@ -17789,6 +17926,11 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
                 ggml_compute_forward_cross_entropy_loss_back(params, tensor);
             }
             break;
         case GGML_OP_NONE:
             {
                 // nop
@@ -17943,7 +18085,7 @@ void ggml_build_backward_gradient_checkpointing(
         struct ggml_tensor  * * checkpoints,
         int                     n_checkpoints) {
     ggml_graph_cpy(gf, gb_tmp);
-    ggml_build_backward_expand(ctx, gf, gb_tmp, true);

     if (n_checkpoints <= 0) {
         ggml_graph_cpy(gb_tmp, gb);
@@ -17981,42 +18123,93 @@ void ggml_build_backward_gradient_checkpointing(
     ggml_hash_map_free(replacements);
 }

-// functions to change gradients considering the case that input a might be initial gradient with zero value
-
-static struct ggml_tensor * ggml_add_or_set(struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b, struct ggml_hash_set * zero_table) {
     if (ggml_hash_contains(zero_table, a)) {
         return b;
-    } else {
-        return ggml_add_impl(ctx, a, b, false);
     }
 }

-static struct ggml_tensor * ggml_acc_or_set(struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b, size_t nb1, size_t nb2, size_t nb3, size_t offset, struct ggml_hash_set * zero_table) {
     if (ggml_hash_contains(zero_table, a)) {
-        struct ggml_tensor * a_zero = ggml_scale(ctx, a, 0.0f);
         return ggml_acc_impl(ctx, a_zero, b, nb1, nb2, nb3, offset, false);
-    } else {
-        return ggml_acc_impl(ctx, a, b, nb1, nb2, nb3, offset, false);
     }
 }

-static struct ggml_tensor * ggml_add1_or_set(struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b, struct ggml_hash_set * zero_table) {
     if (ggml_hash_contains(zero_table, a)) {
         return ggml_repeat(ctx, b, a);
-    } else {
-        return ggml_add1_impl(ctx, a, b, false);
     }
 }

-static struct ggml_tensor * ggml_sub_or_set(struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b, struct ggml_hash_set * zero_table) {
     if (ggml_hash_contains(zero_table, a)) {
         return ggml_neg(ctx, b);
-    } else {
-        return ggml_sub_impl(ctx, a, b, false);
     }
 }

-static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor * tensor, struct ggml_hash_set * zero_table) {
     struct ggml_tensor * src0 = tensor->src[0];
     struct ggml_tensor * src1 = tensor->src[1];
     struct ggml_tensor * src2 = tensor->src[2];
@@ -18025,38 +18218,38 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
         case GGML_OP_DUP:
             {
                 if (src0->grad) {
-                    src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table);
                 }
             } break;
         case GGML_OP_ADD:
             {
                 if (src0->grad) {
-                    src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table);
                 }
                 if (src1->grad) {
                     if (ggml_are_same_shape(src0, src1)) {
-                        src1->grad = ggml_add_or_set(ctx, src1->grad, tensor->grad, zero_table);
                     } else {
-                        src1->grad = ggml_add_or_set(ctx, src1->grad, ggml_repeat_back(ctx, tensor->grad, src1), zero_table);
                     }
                 }
             } break;
         case GGML_OP_ADD1:
             {
                 if (src0->grad) {
-                    src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table);
                 }
                 if (src1->grad) {
                     src1->grad = ggml_add_or_set(ctx,
                         src1->grad,
                         ggml_mean(ctx, tensor->grad), // TODO: should probably be sum instead of mean
-                        zero_table);
                 }
             } break;
         case GGML_OP_ACC:
             {
                 if (src0->grad) {
-                    src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table);
                 }
                 if (src1->grad) {
                     const size_t nb1 = ((int32_t *) tensor->op_params)[0];
@@ -18078,16 +18271,16 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                         ggml_reshape(ctx,
                             ggml_cont(ctx, tensor_grad_view),
                             src1->grad),
-                        zero_table);
                 }
             } break;
         case GGML_OP_SUB:
             {
                 if (src0->grad) {
-                    src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table);
                 }
                 if (src1->grad) {
-                    src1->grad = ggml_sub_or_set(ctx, src1->grad, tensor->grad, zero_table);
                 }
             } break;
         case GGML_OP_MUL:
@@ -18097,14 +18290,14 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                     ggml_add_or_set(ctx,
                         src0->grad,
                         ggml_mul(ctx, src1, tensor->grad),
-                        zero_table);
                 }
                 if (src1->grad) {
                     src1->grad =
                         ggml_add_or_set(ctx,
                             src1->grad,
                             ggml_mul(ctx, src0, tensor->grad),
-                            zero_table);
                 }
             } break;
         case GGML_OP_DIV:
@@ -18114,7 +18307,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                     ggml_add_or_set(ctx,
                         src0->grad,
                         ggml_div(ctx, tensor->grad, src1),
-                        zero_table);
                 }
                 if (src1->grad) {
                     src1->grad =
@@ -18123,7 +18316,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                             ggml_mul(ctx,
                                 tensor->grad,
                                 ggml_div(ctx, tensor, src1)),
-                            zero_table);
                 }
             } break;
         case GGML_OP_SQR:
@@ -18135,7 +18328,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                         ggml_scale(ctx,
                             ggml_mul(ctx, src0, tensor->grad),
                             2.0f),
-                        zero_table);
                 }
             } break;
         case GGML_OP_SQRT:
@@ -18149,7 +18342,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                                 tensor->grad,
                                 tensor),
                             0.5f),
-                        zero_table);
                 }
             } break;
         case GGML_OP_LOG:
@@ -18161,7 +18354,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                         ggml_div(ctx,
                             tensor->grad,
                             src0),
-                        zero_table);
                 }
             } break;
         case GGML_OP_SIN:
@@ -18173,7 +18366,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                         ggml_mul(ctx,
                             tensor->grad,
                             ggml_cos(ctx, src0)),
-                        zero_table);
                 }
             } break;
         case GGML_OP_COS:
@@ -18185,7 +18378,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                         ggml_mul(ctx,
                             tensor->grad,
                             ggml_sin(ctx, src0)),
-                        zero_table);
                 }
             } break;
         case GGML_OP_SUM:
@@ -18195,7 +18388,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                     ggml_add1_or_set(ctx,
                         src0->grad,
                         tensor->grad,
-                        zero_table);
                 }
             } break;
         case GGML_OP_SUM_ROWS:
@@ -18207,7 +18400,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                         ggml_repeat(ctx,
                             tensor->grad,
                             src0->grad),
-                        zero_table);
                 }
             } break;
         case GGML_OP_MEAN:
@@ -18222,7 +18415,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                     src0->grad = ggml_add_or_set(ctx,
                         src0->grad,
                         ggml_repeat_back(ctx, tensor->grad, src0->grad),
-                        zero_table);
                 }
             } break;
         case GGML_OP_REPEAT_BACK:
@@ -18232,7 +18425,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                     src0->grad = ggml_add_or_set(ctx,
                         src0->grad,
                         ggml_repeat(ctx, tensor->grad, src0->grad),
-                        zero_table);
                 }
             } break;
         case GGML_OP_CONCAT:
@@ -18257,7 +18450,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                     src0->grad = ggml_add_or_set(ctx,
                         src0->grad,
                         ggml_rms_norm_back(ctx, src0, tensor->grad, eps),
-                        zero_table);
                 }
             } break;
         case GGML_OP_RMS_NORM_BACK:
@@ -18305,7 +18498,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                     ggml_add_or_set(ctx,
                         src0->grad, // [n,m,q1,r1]
                         s1_tg,      // [n,m,q1,r1]
-                        zero_table);
                 }
                 if (src1->grad) {
                     src1->grad =
@@ -18323,7 +18516,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                                 src0,               // [n,m,q1,r1]
                                 ggml_transpose(ctx, // [p,m,qq,rr]
                                     tensor->grad)), // [m,p,qq,rr]
-                            zero_table);
                 }
             } break;
         case GGML_OP_MUL_MAT_ID:
@@ -18345,7 +18538,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                     ggml_add_or_set(ctx,
                         src0->grad,
                         ggml_scale_impl(ctx, tensor->grad, s, false),
-                        zero_table);
                 }
             } break;
         case GGML_OP_SET:
@@ -18374,7 +18567,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                         tensor->grad,
                         ggml_neg(ctx, tensor_grad_view),
  nb1, nb2, nb3, offset, false),
18377
- zero_table);
18378
  }
18379
 
18380
  if (src1->grad) {
@@ -18384,7 +18577,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
18384
  ggml_reshape(ctx,
18385
  ggml_cont(ctx, tensor_grad_view),
18386
  src1->grad),
18387
- zero_table);
18388
  }
18389
  } break;
18390
  case GGML_OP_CPY:
@@ -18395,7 +18588,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
18395
  // tensor = src0 * 1 + src1 * 0
18396
  if (src0->grad) {
18397
  // dsrc0 = dtensor * 1
18398
- src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table);
18399
  }
18400
  if (src1->grad) {
18401
  // dsrc1 = dtensor * 0 -> noop
@@ -18407,7 +18600,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
18407
  if (src0->grad) {
18408
  GGML_ASSERT(ggml_is_contiguous(src0->grad));
18409
  GGML_ASSERT(ggml_is_contiguous(tensor->grad));
18410
- src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table);
18411
  }
18412
  } break;
18413
  case GGML_OP_RESHAPE:
@@ -18421,7 +18614,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
18421
  ? tensor->grad
18422
  : ggml_cont(ctx, tensor->grad),
18423
  src0->grad),
18424
- zero_table);
18425
  }
18426
  } break;
18427
  case GGML_OP_VIEW:
@@ -18450,7 +18643,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
18450
  nb3 = (nb3 / n0) * ng;
18451
  }
18452
 
18453
- src0->grad = ggml_acc_or_set(ctx, src0->grad, tensor->grad, nb1, nb2, nb3, offset, zero_table);
18454
  }
18455
  } break;
18456
  case GGML_OP_PERMUTE:
@@ -18475,7 +18668,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
18475
  axes_backward[1],
18476
  axes_backward[2],
18477
  axes_backward[3]),
18478
- zero_table);
18479
  }
18480
  } break;
18481
  case GGML_OP_TRANSPOSE:
@@ -18485,7 +18678,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
18485
  src0->grad =
18486
  ggml_add_or_set(ctx, src0->grad,
18487
  ggml_transpose(ctx, tensor->grad),
18488
- zero_table);
18489
  }
18490
  } break;
18491
  case GGML_OP_GET_ROWS:
@@ -18497,7 +18690,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
18497
  // last ggml_get_rows_back argument src0->grad is only
18498
  // necessary to setup correct output shape
18499
  ggml_get_rows_back(ctx, tensor->grad, src1, src0->grad),
18500
- zero_table);
18501
  }
18502
  if (src1->grad) {
18503
  // noop
@@ -18521,7 +18714,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
18521
  /* ggml_diag_mask_inf_impl() shouldn't be here */
18522
  /* ref: https://github.com/ggerganov/llama.cpp/pull/4203#discussion_r1412377992 */
18523
  ggml_diag_mask_zero_impl(ctx, tensor->grad, n_past, false),
18524
- zero_table);
18525
  }
18526
  } break;
18527
  case GGML_OP_DIAG_MASK_ZERO:
@@ -18532,7 +18725,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
18532
  src0->grad =
18533
  ggml_add_or_set(ctx, src0->grad,
18534
  ggml_diag_mask_zero_impl(ctx, tensor->grad, n_past, false),
18535
- zero_table);
18536
  }
18537
  } break;
18538
  case GGML_OP_SOFT_MAX:
@@ -18542,7 +18735,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
18542
  src0->grad =
18543
  ggml_add_or_set(ctx, src0->grad,
18544
  ggml_soft_max_back(ctx, tensor->grad, tensor),
18545
- zero_table);
18546
  }
18547
 
18548
  } break;
@@ -18583,7 +18776,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
18583
  attn_factor,
18584
  beta_fast,
18585
  beta_slow),
18586
- zero_table);
18587
  }
18588
  } break;
18589
  case GGML_OP_ROPE_BACK:
@@ -18619,7 +18812,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
18619
  beta_fast,
18620
  beta_slow,
18621
  false),
18622
- zero_table);
18623
  }
18624
  } break;
18625
  case GGML_OP_CLAMP:
@@ -18644,7 +18837,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
18644
  src1->grad = ggml_add_or_set(ctx,
18645
  src1->grad,
18646
  ggml_im2col_back(ctx, src0, tensor->grad, src1->ne, s0, s1, p0, p1, d0, d1, is_2D),
18647
- zero_table);
18648
  }
18649
  } break;
18650
  case GGML_OP_IM2COL_BACK:
@@ -18673,7 +18866,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
18673
  src0->grad = ggml_add_or_set(ctx,
18674
  src0->grad,
18675
  ggml_pool_2d_back(ctx, tensor->grad, src0, op, k0, k1, s0, s1, p0, p1),
18676
- zero_table);
18677
  }
18678
  } break;
18679
  case GGML_OP_POOL_2D_BACK:
@@ -18738,7 +18931,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
18738
  src0->grad = ggml_add_or_set(ctx,
18739
  src0->grad,
18740
  grad_q,
18741
- zero_table);
18742
  }
18743
  if (src1->grad) {
18744
  struct ggml_tensor * view_k = ggml_view_1d(ctx, flash_grad, elem_k, offs_k);
@@ -18746,7 +18939,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
18746
  src1->grad = ggml_add_or_set(ctx,
18747
  src1->grad,
18748
  grad_k,
18749
- zero_table);
18750
  }
18751
  if (src2->grad) {
18752
  struct ggml_tensor * view_v = ggml_view_1d(ctx, flash_grad, elem_v, offs_v);
@@ -18754,7 +18947,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
18754
  src2->grad = ggml_add_or_set(ctx,
18755
  src2->grad,
18756
  grad_v,
18757
- zero_table);
18758
  }
18759
  } break;
18760
  case GGML_OP_FLASH_ATTN_BACK:
@@ -18780,7 +18973,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
18780
  ggml_mul(ctx,
18781
  ggml_sgn(ctx, src0),
18782
  tensor->grad),
18783
- zero_table);
18784
  }
18785
  } break;
18786
  case GGML_UNARY_OP_SGN:
@@ -18792,7 +18985,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
18792
  case GGML_UNARY_OP_NEG:
18793
  {
18794
  if (src0->grad) {
18795
- src0->grad = ggml_sub_or_set(ctx, src0->grad, tensor->grad, zero_table);
18796
  }
18797
  } break;
18798
  case GGML_UNARY_OP_STEP:
@@ -18817,7 +19010,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
18817
  ggml_mul(ctx,
18818
  ggml_step(ctx, src0),
18819
  tensor->grad),
18820
- zero_table);
18821
  }
18822
  } break;
18823
  case GGML_UNARY_OP_SIGMOID:
@@ -18839,7 +19032,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
18839
  src0->grad = ggml_add_or_set(ctx,
18840
  src0->grad,
18841
  ggml_silu_back(ctx, src0, tensor->grad),
18842
- zero_table);
18843
  }
18844
  } break;
18845
  case GGML_UNARY_OP_EXP:
@@ -18848,7 +19041,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
18848
  src0->grad = ggml_add_or_set(ctx,
18849
  src0->grad,
18850
  ggml_mul(ctx, tensor, tensor->grad),
18851
- zero_table);
18852
  }
18853
  } break;
18854
  default:
@@ -18878,13 +19071,17 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
18878
  src0,
18879
  src1,
18880
  tensor->grad),
18881
- zero_table);
18882
  }
18883
  } break;
18884
  case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
18885
  {
18886
  GGML_ABORT("fatal error"); // not supported
18887
  }
 
 
 
 
18888
  case GGML_OP_NONE:
18889
  {
18890
  // nop
@@ -18974,7 +19171,7 @@ void ggml_build_forward_expand(struct ggml_cgraph * cgraph, struct ggml_tensor *
18974
  ggml_build_forward_impl(cgraph, tensor, true);
18975
  }
18976
 
18977
- void ggml_build_backward_expand(struct ggml_context * ctx, struct ggml_cgraph * gf, struct ggml_cgraph * gb, bool keep) {
18978
  GGML_ASSERT(gf->n_nodes > 0);
18979
  GGML_ASSERT(gf->grads);
18980
 
@@ -18990,21 +19187,35 @@ void ggml_build_backward_expand(struct ggml_context * ctx, struct ggml_cgraph *
18990
  }
18991
  }
18992
 
18993
- // remember original gradients which start with zero values
18994
  struct ggml_hash_set zero_table = ggml_hash_set_new(gf->size);
 
18995
  for (int i = 0; i < gf->n_nodes; i++) {
18996
- if (gf->grads[i]) {
18997
- ggml_hash_insert(&zero_table, gf->grads[i]);
18998
  }
18999
  }
19000
 
19001
  for (int i = gf->n_nodes - 1; i >= 0; i--) {
19002
  struct ggml_tensor * node = gf->nodes[i];
19003
 
19004
- // inplace operations to add gradients are not created by ggml_compute_backward
19005
  // use allocator to automatically make inplace operations
19006
  if (node->grad) {
19007
- ggml_compute_backward(ctx, node, &zero_table);
19008
  }
19009
  }
19010
 
@@ -19018,8 +19229,30 @@ void ggml_build_backward_expand(struct ggml_context * ctx, struct ggml_cgraph *
19018
  }
19019
 
19020
  ggml_hash_set_free(&zero_table);
19021
  }
19022
 
 
19023
  static void * incr_ptr_aligned(void ** p, size_t size, size_t align) {
19024
  void * ptr = *p;
19025
  ptr = (void *) GGML_PAD((uintptr_t) ptr, align);
@@ -19147,10 +19380,28 @@ void ggml_graph_reset(struct ggml_cgraph * cgraph) {
19147
  GGML_ASSERT(cgraph->grads != NULL);
19148
 
19149
  for (int i = 0; i < cgraph->n_nodes; i++) {
19150
- struct ggml_tensor * grad = cgraph->grads[i];
19151
 
19152
- if (grad) {
19153
- ggml_set_zero(grad);
19154
  }
19155
  }
19156
  }
@@ -19415,6 +19666,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
19415
  } break;
19416
  case GGML_OP_CROSS_ENTROPY_LOSS:
19417
  case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
 
19418
  {
19419
  n_tasks = n_threads;
19420
  } break;
@@ -21777,7 +22029,7 @@ enum ggml_opt_result ggml_opt_resume(
21777
  ggml_build_forward_expand(gf, f);
21778
 
21779
  struct ggml_cgraph * gb = ggml_graph_dup(ctx, gf);
21780
- ggml_build_backward_expand(ctx, gf, gb, true);
21781
 
21782
  return ggml_opt_resume_g(ctx, opt, f, gf, gb, NULL, NULL);
21783
  }
 
1
  #define _CRT_SECURE_NO_DEPRECATE // Disables ridiculous "unsafe" warnings on Windows
2
  #define _USE_MATH_DEFINES // For M_PI on MSVC
3
 
4
+ #include "ggml-backend.h"
5
  #include "ggml-impl.h"
6
  #include "ggml-quants.h"
7
  #include "ggml.h"
 
2978
 
2979
  "CROSS_ENTROPY_LOSS",
2980
  "CROSS_ENTROPY_LOSS_BACK",
2981
+ "OPT_STEP_ADAMW",
2982
  };
2983
 
2984
+ static_assert(GGML_OP_COUNT == 80, "GGML_OP_COUNT != 80");
2985
 
2986
  static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
2987
  "none",
 
3072
 
3073
  "cross_entropy_loss(x,y)",
3074
  "cross_entropy_loss_back(x,y)",
3075
+ "adamw(x)",
3076
  };
3077
 
3078
+ static_assert(GGML_OP_COUNT == 80, "GGML_OP_COUNT != 80");
3079
 
3080
  static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
3081
 
 
4082
  }
4083
 
4084
  struct ggml_tensor * ggml_set_zero(struct ggml_tensor * tensor) {
4085
+ if (tensor->buffer) {
4086
+ ggml_backend_tensor_memset(tensor, 0, 0, ggml_nbytes(tensor));
4087
+ } else {
4088
+ memset(tensor->data, 0, ggml_nbytes(tensor));
4089
+ }
4090
  return tensor;
4091
  }
4092
 
 
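With this change ggml_set_zero also works for tensors whose data lives in a backend buffer: if tensor->buffer is set it goes through the new ggml_backend_tensor_memset, otherwise it falls back to a plain memset on host memory. A minimal sketch of exercising that path, assuming the CPU backend and a no_alloc context; everything other than the ggml/ggml-backend calls shown is illustrative:

    #include "ggml.h"
    #include "ggml-alloc.h"
    #include "ggml-backend.h"

    int main(void) {
        // no_alloc context: only tensor metadata here, data is placed in a backend buffer below
        struct ggml_init_params ip = { /*.mem_size =*/ ggml_tensor_overhead()*8, /*.mem_buffer =*/ NULL, /*.no_alloc =*/ true };
        struct ggml_context * ctx = ggml_init(ip);

        struct ggml_tensor * t = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 16);

        ggml_backend_t        backend = ggml_backend_cpu_init();
        ggml_backend_buffer_t buf     = ggml_backend_alloc_ctx_tensors(ctx, backend);

        ggml_set_zero(t); // t->buffer != NULL -> dispatches to ggml_backend_tensor_memset(t, 0, 0, ggml_nbytes(t))

        float host[16];
        ggml_backend_tensor_get(t, host, 0, sizeof(host)); // host[] now holds zeros

        ggml_backend_buffer_free(buf);
        ggml_backend_free(backend);
        ggml_free(ctx);
        return 0;
    }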
8312
  return result;
8313
  }
8314
 
8315
+ // opt_step_adamw
8316
 
8317
+ struct ggml_tensor * ggml_opt_step_adamw(
8318
  struct ggml_context * ctx,
8319
+ struct ggml_tensor * a,
8320
+ float alpha,
8321
+ float beta1,
8322
+ float beta2,
8323
+ float eps,
8324
+ float wd) {
8325
+ GGML_ASSERT(a->grad);
8326
+ GGML_ASSERT(alpha > 0.0f);
8327
+ GGML_ASSERT(beta1 >= 0.0f && beta1 <= 1.0f);
8328
+ GGML_ASSERT(beta2 >= 0.0f && beta2 <= 1.0f);
8329
+ GGML_ASSERT(eps >= 0.0f);
8330
+ GGML_ASSERT(wd >= 0.0f && wd <= 1.0f);
8331
+
8332
+ struct ggml_tensor * result = ggml_view_tensor(ctx, a);
8333
+
8334
+ result->op = GGML_OP_OPT_STEP_ADAMW;
8335
+ result->grad = NULL;
8336
+ result->src[0] = a;
8337
+ result->src[1] = a->grad;
8338
+ result->src[2] = ggml_dup_tensor(ctx, a->grad);
8339
+ result->src[3] = ggml_dup_tensor(ctx, a->grad);
8340
+
8341
+ const int64_t iter = 1;
8342
+ memcpy(&result->op_params[0], &iter, sizeof(int64_t));
8343
+ ggml_set_op_params_f32(result, 2, alpha);
8344
+ ggml_set_op_params_f32(result, 3, beta1);
8345
+ ggml_set_op_params_f32(result, 4, beta2);
8346
+ ggml_set_op_params_f32(result, 5, eps);
8347
+ ggml_set_op_params_f32(result, 6, wd);
8348
+
8349
+ return result;
8350
+ }
8351
+
8352
+ ////////////////////////////////////////////////////////////////////////////////
8353
+
8354
+ void ggml_set_param(struct ggml_context * ctx, struct ggml_tensor * tensor) {
8355
  tensor->flags |= GGML_TENSOR_FLAG_PARAM;
8356
 
8357
  GGML_ASSERT(tensor->grad == NULL);
 
8359
  ggml_format_name(tensor->grad, "%s (grad)", tensor->name);
8360
  }
8361
 
8362
+ void ggml_set_loss(struct ggml_tensor * tensor) {
8363
+ GGML_ASSERT(ggml_is_scalar(tensor));
8364
+ GGML_ASSERT(tensor->type == GGML_TYPE_F32);
8365
+ GGML_ASSERT(tensor->grad);
8366
+ tensor->flags |= GGML_TENSOR_FLAG_LOSS;
8367
+ }
8368
+
8369
  // ggml_compute_forward_dup
8370
 
8371
  static void ggml_compute_forward_dup_same_cont(
 
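The GGML_OP_OPT_STEP_ADAMW node built above carries everything the kernel needs: the parameter as src[0], its gradient as src[1], two freshly duplicated tensors for the first and second moments as src[2] and src[3], and the iteration counter plus hyperparameters in op_params. A hedged sketch of attaching one such node to a single trainable tensor (ctx, gb and the shapes are assumed to exist as in a typical training setup):

    // illustrative: one AdamW update node for one trainable tensor
    struct ggml_tensor * weights = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 784, 500);
    ggml_set_param(ctx, weights); // creates weights->grad and marks the tensor as trainable

    // ... forward graph, loss and backward graph (gb) are built in between ...

    struct ggml_tensor * step = ggml_opt_step_adamw(ctx, weights,
        /*alpha =*/ 1e-3f,   // learning rate
        /*beta1 =*/ 0.9f,
        /*beta2 =*/ 0.999f,
        /*eps   =*/ 1e-8f,
        /*wd    =*/ 0.0f);   // decoupled weight decay

    // step->src[2] and step->src[3] are the m and v accumulators; evaluating the node
    // once per optimizer step updates `weights` in place and increments op_params[0].
    ggml_build_forward_expand(gb, step);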
17440
  const int64_t ir0 = dr*ith;
17441
  const int64_t ir1 = MIN(ir0 + dr, nr);
17442
 
17443
+ const float d_by_nr = ((const float *) opt0->data)[0] / (float) nr;
17444
 
17445
  for (int64_t i1 = ir0; i1 < ir1; i1++) {
17446
  float * ds0 = (float *)((char *) dst->data + i1*dst->nb[1]);
 
17464
 
17465
  // grad(src0) = (softmax(src0) - src1) * grad(cross_entropy_loss(src0, src1)) / nr
17466
  ggml_vec_sub_f32(nc, ds0, ds0, s1);
17467
+ ggml_vec_scale_f32(nc, ds0, d_by_nr);
17468
 
17469
  #ifndef NDEBUG
17470
  for (int i = 0; i < nc; ++i) {
 
17493
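Written out, the normalization that d_by_nr carries through the loop above is (restating the in-code comment; s = src0 logits, y = src1 targets, g = the scalar gradient arriving at the loss, n_r = ggml_nrows(src0)):

    \frac{\partial L}{\partial s_{ij}} = \left(\operatorname{softmax}(s_i)_j - y_{ij}\right) \cdot \frac{g}{n_r}

so the subtraction plus the single ggml_vec_scale_f32 call apply both the incoming gradient and the mean over rows in one pass.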
  }
17494
  }
17495
 
17496
+ static void ggml_compute_forward_opt_step_adamw_f32(
17497
+ const struct ggml_compute_params * params,
17498
+ struct ggml_tensor * dst) {
17499
+
17500
+ const struct ggml_tensor * src0 = dst->src[0];
17501
+ const struct ggml_tensor * src0_grad = dst->src[1];
17502
+ const struct ggml_tensor * src0_grad_m = dst->src[2];
17503
+ const struct ggml_tensor * src0_grad_v = dst->src[3];
17504
+ GGML_ASSERT(ggml_are_same_shape(src0, src0_grad));
17505
+
17506
+ const int ith = params->ith;
17507
+ const int nth = params->nth;
17508
+
17509
+ const int nr = ggml_nrows(src0);
17510
+
17511
+ GGML_TENSOR_UNARY_OP_LOCALS
17512
+ GGML_ASSERT(nb00 == sizeof(float));
17513
+
17514
+ // rows per thread
17515
+ const int dr = (nr + nth - 1)/nth;
17516
+
17517
+ // row range for this thread
17518
+ const int ir0 = dr*ith;
17519
+ const int ir1 = MIN(ir0 + dr, nr);
17520
+
17521
+ /* const float gnorm = 1.0f; */
17522
+ int64_t iter; memcpy(&iter, &dst->op_params[0], sizeof(int64_t));
17523
+ const float alpha = ggml_get_op_params_f32(dst, 2);
17524
+ const float beta1 = ggml_get_op_params_f32(dst, 3);
17525
+ const float beta2 = ggml_get_op_params_f32(dst, 4);
17526
+ const float eps = ggml_get_op_params_f32(dst, 5);
17527
+ const float wd = ggml_get_op_params_f32(dst, 6);
17528
+
17529
+ const float beta1h = alpha/(1.0f - powf(beta1, iter));
17530
+ const float beta2h = 1.0f/(1.0f - powf(beta2, iter));
17531
+
17532
+ for (int ir = ir0; ir < ir1; ++ir) {
17533
+ const int64_t i03 = ir/(ne02*ne01);
17534
+ const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
17535
+ const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
17536
+
17537
+ const size_t offset = i03*nb03 + i02*nb02 + i01*nb01;
17538
+
17539
+ float * w = (float *) ((char *) src0->data + offset); // weight
17540
+ const float * g = (const float *) ((const char *) src0_grad->data + offset); // grad
17541
+ float * m = (float *) ((char *) src0_grad_m->data + offset);
17542
+ float * v = (float *) ((char *) src0_grad_v->data + offset);
17543
+
17544
+ for (int i00 = 0; i00 < ne00; ++i00) {
17545
+ m[i00] = m[i00]*beta1 + g[i00]*(1.0f - beta1);
17546
+ v[i00] = v[i00]*beta2 + g[i00]*g[i00]*(1.0f - beta2);
17547
+
17548
+ const float mh = m[i00]*beta1h;
17549
+ const float vh = sqrtf(v[i00]*beta2h) + eps;
17550
+
17551
+ // The weight decay is applied independently of the Adam momenta m and v.
17552
+ // This is NOT equivalent to l2 regularization that adds w[i00]*w[i00] to the loss.
17553
+ // See: https://arxiv.org/pdf/1711.05101v3.pdf
17554
+ w[i00] = w[i00]*(1.0f - alpha*wd) - mh/vh;
17555
+ }
17556
+ }
17557
+
17558
+ ggml_barrier(params->threadpool);
17559
+ if (ith != 0) {
17560
+ return;
17561
+ }
17562
+
17563
+ iter++;
17564
+ memcpy(&dst->op_params[0], &iter, sizeof(int64_t));
17565
+ }
17566
+
17567
+ static void ggml_compute_forward_opt_step_adamw(
17568
+ const struct ggml_compute_params * params,
17569
+ struct ggml_tensor * dst) {
17570
+
17571
+ const struct ggml_tensor * src0 = dst->src[0];
17572
+
17573
+ switch (src0->type) {
17574
+ case GGML_TYPE_F32:
17575
+ {
17576
+ ggml_compute_forward_opt_step_adamw_f32(params, dst);
17577
+ } break;
17578
+ default:
17579
+ {
17580
+ GGML_ABORT("fatal error");
17581
+ }
17582
+ }
17583
+ }
17584
  /////////////////////////////////
17585
 
17586
  static void ggml_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor) {
 
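For reference, the per-element update implemented by the kernel above is the standard AdamW step with decoupled weight decay (Loshchilov & Hutter, the paper linked in the comment), with the bias-correction factors precomputed as beta1h = \alpha/(1-\beta_1^t) and beta2h = 1/(1-\beta_2^t):

    m_t = \beta_1 m_{t-1} + (1-\beta_1)\,g_t
    v_t = \beta_2 v_{t-1} + (1-\beta_2)\,g_t^2
    \hat{m}_t = m_t/(1-\beta_1^t), \quad \hat{v}_t = v_t/(1-\beta_2^t)
    w_t = w_{t-1}\,(1-\alpha\,\mathrm{wd}) - \alpha\,\hat{m}_t/(\sqrt{\hat{v}_t}+\epsilon)

In the code mh = \alpha\,\hat{m}_t and vh = \sqrt{\hat{v}_t}+\epsilon, so the last line is exactly w[i00] = w[i00]*(1.0f - alpha*wd) - mh/vh; only thread 0 advances the iteration counter after the barrier, so every thread uses the same t.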
17926
  ggml_compute_forward_cross_entropy_loss_back(params, tensor);
17927
  }
17928
  break;
17929
+ case GGML_OP_OPT_STEP_ADAMW:
17930
+ {
17931
+ ggml_compute_forward_opt_step_adamw(params, tensor);
17932
+ }
17933
+ break;
17934
  case GGML_OP_NONE:
17935
  {
17936
  // nop
 
18085
  struct ggml_tensor * * checkpoints,
18086
  int n_checkpoints) {
18087
  ggml_graph_cpy(gf, gb_tmp);
18088
+ ggml_build_backward_expand(ctx, gf, gb_tmp, false, true);
18089
 
18090
  if (n_checkpoints <= 0) {
18091
  ggml_graph_cpy(gb_tmp, gb);
 
18123
  ggml_hash_map_free(replacements);
18124
  }
18125
 
18126
+ // utility functions to change gradients
18127
+ // if a is in acc_table, modify gradients in-place and mark result as gradient accumulator
18128
+ // else if a is in zero_table, replace a
18129
+ // else, just add/subtract/etc. the gradients
18130
+
18131
+ static struct ggml_tensor * ggml_add_or_set(
18132
+ struct ggml_context * ctx,
18133
+ struct ggml_tensor * a,
18134
+ struct ggml_tensor * b,
18135
+ struct ggml_hash_set * zero_table,
18136
+ struct ggml_hash_set * acc_table) {
18137
+ if (ggml_hash_contains(acc_table, a)) {
18138
+ struct ggml_tensor * ret = ggml_add_impl(ctx, a, b, true);
18139
+ const size_t insert_result = ggml_hash_insert(acc_table, ret);
18140
+ GGML_ASSERT(insert_result != GGML_HASHSET_FULL);
18141
+ GGML_ASSERT(insert_result != GGML_HASHSET_ALREADY_EXISTS);
18142
+ return ret;
18143
+ }
18144
  if (ggml_hash_contains(zero_table, a)) {
18145
  return b;
 
 
18146
  }
18147
+ return ggml_add_impl(ctx, a, b, false);
18148
  }
18149
 
18150
+ static struct ggml_tensor * ggml_acc_or_set(
18151
+ struct ggml_context * ctx,
18152
+ struct ggml_tensor * a,
18153
+ struct ggml_tensor * b,
18154
+ const size_t nb1,
18155
+ const size_t nb2,
18156
+ const size_t nb3,
18157
+ const size_t offset,
18158
+ struct ggml_hash_set * zero_table,
18159
+ struct ggml_hash_set * acc_table) {
18160
+ if (ggml_hash_contains(acc_table, a)) {
18161
+ struct ggml_tensor * ret = ggml_acc_impl(ctx, a, b, nb1, nb2, nb3, offset, true);
18162
+ const size_t insert_result = ggml_hash_insert(acc_table, ret);
18163
+ GGML_ASSERT(insert_result != GGML_HASHSET_FULL);
18164
+ GGML_ASSERT(insert_result != GGML_HASHSET_ALREADY_EXISTS);
18165
+ return ret;
18166
+ }
18167
  if (ggml_hash_contains(zero_table, a)) {
18168
+ struct ggml_tensor * a_zero = ggml_scale(ctx, a, 0.0f); // FIXME this is going to produce NaN if a contains inf/NaN
18169
  return ggml_acc_impl(ctx, a_zero, b, nb1, nb2, nb3, offset, false);
 
 
18170
  }
18171
+ return ggml_acc_impl(ctx, a, b, nb1, nb2, nb3, offset, false);
18172
  }
18173
 
18174
+ static struct ggml_tensor * ggml_add1_or_set(
18175
+ struct ggml_context * ctx,
18176
+ struct ggml_tensor * a,
18177
+ struct ggml_tensor * b,
18178
+ struct ggml_hash_set * zero_table,
18179
+ struct ggml_hash_set * acc_table) {
18180
+ if (ggml_hash_contains(acc_table, a)) {
18181
+ struct ggml_tensor * ret = ggml_add1_impl(ctx, a, b, true);
18182
+ const size_t insert_result = ggml_hash_insert(acc_table, ret);
18183
+ GGML_ASSERT(insert_result != GGML_HASHSET_FULL);
18184
+ GGML_ASSERT(insert_result != GGML_HASHSET_ALREADY_EXISTS);
18185
+ return ret;
18186
+ }
18187
  if (ggml_hash_contains(zero_table, a)) {
18188
  return ggml_repeat(ctx, b, a);
 
 
18189
  }
18190
+ return ggml_add1_impl(ctx, a, b, false);
18191
  }
18192
 
18193
+ static struct ggml_tensor * ggml_sub_or_set(
18194
+ struct ggml_context * ctx,
18195
+ struct ggml_tensor * a,
18196
+ struct ggml_tensor * b,
18197
+ struct ggml_hash_set * zero_table,
18198
+ struct ggml_hash_set * acc_table) {
18199
+ if (ggml_hash_contains(acc_table, a)) {
18200
+ struct ggml_tensor * ret = ggml_sub_impl(ctx, a, b, true);
18201
+ const size_t insert_result = ggml_hash_insert(acc_table, ret);
18202
+ GGML_ASSERT(insert_result != GGML_HASHSET_FULL);
18203
+ GGML_ASSERT(insert_result != GGML_HASHSET_ALREADY_EXISTS);
18204
+ return ret;
18205
+ }
18206
  if (ggml_hash_contains(zero_table, a)) {
18207
  return ggml_neg(ctx, b);
 
 
18208
  }
18209
+ return ggml_sub_impl(ctx, a, b, false);
18210
  }
18211
 
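In effect: for a gradient that sits in acc_table the helpers above chain in-place ops, so after evaluating the backward graph over k micro-batches the stored value is \sum_{j=1}^{k} \partial L_j/\partial w; for a gradient that is only in zero_table the first contribution replaces the zero-initialized placeholder, and each new evaluation of the graph simply overwrites the previous result instead of adding to it.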
18212
+ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor * tensor, struct ggml_hash_set * zero_table, struct ggml_hash_set * acc_table) {
18213
  struct ggml_tensor * src0 = tensor->src[0];
18214
  struct ggml_tensor * src1 = tensor->src[1];
18215
  struct ggml_tensor * src2 = tensor->src[2];
 
18218
  case GGML_OP_DUP:
18219
  {
18220
  if (src0->grad) {
18221
+ src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table, acc_table);
18222
  }
18223
  } break;
18224
  case GGML_OP_ADD:
18225
  {
18226
  if (src0->grad) {
18227
+ src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table, acc_table);
18228
  }
18229
  if (src1->grad) {
18230
  if (ggml_are_same_shape(src0, src1)) {
18231
+ src1->grad = ggml_add_or_set(ctx, src1->grad, tensor->grad, zero_table, acc_table);
18232
  } else {
18233
+ src1->grad = ggml_add_or_set(ctx, src1->grad, ggml_repeat_back(ctx, tensor->grad, src1), zero_table, acc_table);
18234
  }
18235
  }
18236
  } break;
18237
  case GGML_OP_ADD1:
18238
  {
18239
  if (src0->grad) {
18240
+ src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table, acc_table);
18241
  }
18242
  if (src1->grad) {
18243
  src1->grad = ggml_add_or_set(ctx,
18244
  src1->grad,
18245
  ggml_mean(ctx, tensor->grad), // TODO: should probably be sum instead of mean
18246
+ zero_table, acc_table);
18247
  }
18248
  } break;
18249
  case GGML_OP_ACC:
18250
  {
18251
  if (src0->grad) {
18252
+ src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table, acc_table);
18253
  }
18254
  if (src1->grad) {
18255
  const size_t nb1 = ((int32_t *) tensor->op_params)[0];
 
18271
  ggml_reshape(ctx,
18272
  ggml_cont(ctx, tensor_grad_view),
18273
  src1->grad),
18274
+ zero_table, acc_table);
18275
  }
18276
  } break;
18277
  case GGML_OP_SUB:
18278
  {
18279
  if (src0->grad) {
18280
+ src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table, acc_table);
18281
  }
18282
  if (src1->grad) {
18283
+ src1->grad = ggml_sub_or_set(ctx, src1->grad, tensor->grad, zero_table, acc_table);
18284
  }
18285
  } break;
18286
  case GGML_OP_MUL:
 
18290
  ggml_add_or_set(ctx,
18291
  src0->grad,
18292
  ggml_mul(ctx, src1, tensor->grad),
18293
+ zero_table, acc_table);
18294
  }
18295
  if (src1->grad) {
18296
  src1->grad =
18297
  ggml_add_or_set(ctx,
18298
  src1->grad,
18299
  ggml_mul(ctx, src0, tensor->grad),
18300
+ zero_table, acc_table);
18301
  }
18302
  } break;
18303
  case GGML_OP_DIV:
 
18307
  ggml_add_or_set(ctx,
18308
  src0->grad,
18309
  ggml_div(ctx, tensor->grad, src1),
18310
+ zero_table, acc_table);
18311
  }
18312
  if (src1->grad) {
18313
  src1->grad =
 
18316
  ggml_mul(ctx,
18317
  tensor->grad,
18318
  ggml_div(ctx, tensor, src1)),
18319
+ zero_table, acc_table);
18320
  }
18321
  } break;
18322
  case GGML_OP_SQR:
 
18328
  ggml_scale(ctx,
18329
  ggml_mul(ctx, src0, tensor->grad),
18330
  2.0f),
18331
+ zero_table, acc_table);
18332
  }
18333
  } break;
18334
  case GGML_OP_SQRT:
 
18342
  tensor->grad,
18343
  tensor),
18344
  0.5f),
18345
+ zero_table, acc_table);
18346
  }
18347
  } break;
18348
  case GGML_OP_LOG:
 
18354
  ggml_div(ctx,
18355
  tensor->grad,
18356
  src0),
18357
+ zero_table, acc_table);
18358
  }
18359
  } break;
18360
  case GGML_OP_SIN:
 
18366
  ggml_mul(ctx,
18367
  tensor->grad,
18368
  ggml_cos(ctx, src0)),
18369
+ zero_table, acc_table);
18370
  }
18371
  } break;
18372
  case GGML_OP_COS:
 
18378
  ggml_mul(ctx,
18379
  tensor->grad,
18380
  ggml_sin(ctx, src0)),
18381
+ zero_table, acc_table);
18382
  }
18383
  } break;
18384
  case GGML_OP_SUM:
 
18388
  ggml_add1_or_set(ctx,
18389
  src0->grad,
18390
  tensor->grad,
18391
+ zero_table, acc_table);
18392
  }
18393
  } break;
18394
  case GGML_OP_SUM_ROWS:
 
18400
  ggml_repeat(ctx,
18401
  tensor->grad,
18402
  src0->grad),
18403
+ zero_table, acc_table);
18404
  }
18405
  } break;
18406
  case GGML_OP_MEAN:
 
18415
  src0->grad = ggml_add_or_set(ctx,
18416
  src0->grad,
18417
  ggml_repeat_back(ctx, tensor->grad, src0->grad),
18418
+ zero_table, acc_table);
18419
  }
18420
  } break;
18421
  case GGML_OP_REPEAT_BACK:
 
18425
  src0->grad = ggml_add_or_set(ctx,
18426
  src0->grad,
18427
  ggml_repeat(ctx, tensor->grad, src0->grad),
18428
+ zero_table, acc_table);
18429
  }
18430
  } break;
18431
  case GGML_OP_CONCAT:
 
18450
  src0->grad = ggml_add_or_set(ctx,
18451
  src0->grad,
18452
  ggml_rms_norm_back(ctx, src0, tensor->grad, eps),
18453
+ zero_table, acc_table);
18454
  }
18455
  } break;
18456
  case GGML_OP_RMS_NORM_BACK:
 
18498
  ggml_add_or_set(ctx,
18499
  src0->grad, // [n,m,q1,r1]
18500
  s1_tg, // [n,m,q1,r1]
18501
+ zero_table, acc_table);
18502
  }
18503
  if (src1->grad) {
18504
  src1->grad =
 
18516
  src0, // [n,m,q1,r1]
18517
  ggml_transpose(ctx, // [p,m,qq,rr]
18518
  tensor->grad)), // [m,p,qq,rr]
18519
+ zero_table, acc_table);
18520
  }
18521
  } break;
18522
  case GGML_OP_MUL_MAT_ID:
 
18538
  ggml_add_or_set(ctx,
18539
  src0->grad,
18540
  ggml_scale_impl(ctx, tensor->grad, s, false),
18541
+ zero_table, acc_table);
18542
  }
18543
  } break;
18544
  case GGML_OP_SET:
 
18567
  tensor->grad,
18568
  ggml_neg(ctx, tensor_grad_view),
18569
  nb1, nb2, nb3, offset, false),
18570
+ zero_table, acc_table);
18571
  }
18572
 
18573
  if (src1->grad) {
 
18577
  ggml_reshape(ctx,
18578
  ggml_cont(ctx, tensor_grad_view),
18579
  src1->grad),
18580
+ zero_table, acc_table);
18581
  }
18582
  } break;
18583
  case GGML_OP_CPY:
 
18588
  // tensor = src0 * 1 + src1 * 0
18589
  if (src0->grad) {
18590
  // dsrc0 = dtensor * 1
18591
+ src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table, acc_table);
18592
  }
18593
  if (src1->grad) {
18594
  // dsrc1 = dtensor * 0 -> noop
 
18600
  if (src0->grad) {
18601
  GGML_ASSERT(ggml_is_contiguous(src0->grad));
18602
  GGML_ASSERT(ggml_is_contiguous(tensor->grad));
18603
+ src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table, acc_table);
18604
  }
18605
  } break;
18606
  case GGML_OP_RESHAPE:
 
18614
  ? tensor->grad
18615
  : ggml_cont(ctx, tensor->grad),
18616
  src0->grad),
18617
+ zero_table, acc_table);
18618
  }
18619
  } break;
18620
  case GGML_OP_VIEW:
 
18643
  nb3 = (nb3 / n0) * ng;
18644
  }
18645
 
18646
+ src0->grad = ggml_acc_or_set(ctx, src0->grad, tensor->grad, nb1, nb2, nb3, offset, zero_table, acc_table);
18647
  }
18648
  } break;
18649
  case GGML_OP_PERMUTE:
 
18668
  axes_backward[1],
18669
  axes_backward[2],
18670
  axes_backward[3]),
18671
+ zero_table, acc_table);
18672
  }
18673
  } break;
18674
  case GGML_OP_TRANSPOSE:
 
18678
  src0->grad =
18679
  ggml_add_or_set(ctx, src0->grad,
18680
  ggml_transpose(ctx, tensor->grad),
18681
+ zero_table, acc_table);
18682
  }
18683
  } break;
18684
  case GGML_OP_GET_ROWS:
 
18690
  // last ggml_get_rows_back argument src0->grad is only
18691
  // necessary to setup correct output shape
18692
  ggml_get_rows_back(ctx, tensor->grad, src1, src0->grad),
18693
+ zero_table, acc_table);
18694
  }
18695
  if (src1->grad) {
18696
  // noop
 
18714
  /* ggml_diag_mask_inf_impl() shouldn't be here */
18715
  /* ref: https://github.com/ggerganov/llama.cpp/pull/4203#discussion_r1412377992 */
18716
  ggml_diag_mask_zero_impl(ctx, tensor->grad, n_past, false),
18717
+ zero_table, acc_table);
18718
  }
18719
  } break;
18720
  case GGML_OP_DIAG_MASK_ZERO:
 
18725
  src0->grad =
18726
  ggml_add_or_set(ctx, src0->grad,
18727
  ggml_diag_mask_zero_impl(ctx, tensor->grad, n_past, false),
18728
+ zero_table, acc_table);
18729
  }
18730
  } break;
18731
  case GGML_OP_SOFT_MAX:
 
18735
  src0->grad =
18736
  ggml_add_or_set(ctx, src0->grad,
18737
  ggml_soft_max_back(ctx, tensor->grad, tensor),
18738
+ zero_table, acc_table);
18739
  }
18740
 
18741
  } break;
 
18776
  attn_factor,
18777
  beta_fast,
18778
  beta_slow),
18779
+ zero_table, acc_table);
18780
  }
18781
  } break;
18782
  case GGML_OP_ROPE_BACK:
 
18812
  beta_fast,
18813
  beta_slow,
18814
  false),
18815
+ zero_table, acc_table);
18816
  }
18817
  } break;
18818
  case GGML_OP_CLAMP:
 
18837
  src1->grad = ggml_add_or_set(ctx,
18838
  src1->grad,
18839
  ggml_im2col_back(ctx, src0, tensor->grad, src1->ne, s0, s1, p0, p1, d0, d1, is_2D),
18840
+ zero_table, acc_table);
18841
  }
18842
  } break;
18843
  case GGML_OP_IM2COL_BACK:
 
18866
  src0->grad = ggml_add_or_set(ctx,
18867
  src0->grad,
18868
  ggml_pool_2d_back(ctx, tensor->grad, src0, op, k0, k1, s0, s1, p0, p1),
18869
+ zero_table, acc_table);
18870
  }
18871
  } break;
18872
  case GGML_OP_POOL_2D_BACK:
 
18931
  src0->grad = ggml_add_or_set(ctx,
18932
  src0->grad,
18933
  grad_q,
18934
+ zero_table, acc_table);
18935
  }
18936
  if (src1->grad) {
18937
  struct ggml_tensor * view_k = ggml_view_1d(ctx, flash_grad, elem_k, offs_k);
 
18939
  src1->grad = ggml_add_or_set(ctx,
18940
  src1->grad,
18941
  grad_k,
18942
+ zero_table, acc_table);
18943
  }
18944
  if (src2->grad) {
18945
  struct ggml_tensor * view_v = ggml_view_1d(ctx, flash_grad, elem_v, offs_v);
 
18947
  src2->grad = ggml_add_or_set(ctx,
18948
  src2->grad,
18949
  grad_v,
18950
+ zero_table, acc_table);
18951
  }
18952
  } break;
18953
  case GGML_OP_FLASH_ATTN_BACK:
 
18973
  ggml_mul(ctx,
18974
  ggml_sgn(ctx, src0),
18975
  tensor->grad),
18976
+ zero_table, acc_table);
18977
  }
18978
  } break;
18979
  case GGML_UNARY_OP_SGN:
 
18985
  case GGML_UNARY_OP_NEG:
18986
  {
18987
  if (src0->grad) {
18988
+ src0->grad = ggml_sub_or_set(ctx, src0->grad, tensor->grad, zero_table, acc_table);
18989
  }
18990
  } break;
18991
  case GGML_UNARY_OP_STEP:
 
19010
  ggml_mul(ctx,
19011
  ggml_step(ctx, src0),
19012
  tensor->grad),
19013
+ zero_table, acc_table);
19014
  }
19015
  } break;
19016
  case GGML_UNARY_OP_SIGMOID:
 
19032
  src0->grad = ggml_add_or_set(ctx,
19033
  src0->grad,
19034
  ggml_silu_back(ctx, src0, tensor->grad),
19035
+ zero_table, acc_table);
19036
  }
19037
  } break;
19038
  case GGML_UNARY_OP_EXP:
 
19041
  src0->grad = ggml_add_or_set(ctx,
19042
  src0->grad,
19043
  ggml_mul(ctx, tensor, tensor->grad),
19044
+ zero_table, acc_table);
19045
  }
19046
  } break;
19047
  default:
 
19071
  src0,
19072
  src1,
19073
  tensor->grad),
19074
+ zero_table, acc_table);
19075
  }
19076
  } break;
19077
  case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
19078
  {
19079
  GGML_ABORT("fatal error"); // not supported
19080
  }
19081
+ case GGML_OP_OPT_STEP_ADAMW:
19082
+ {
19083
+ GGML_ABORT("fatal error"); // not supported
19084
+ }
19085
  case GGML_OP_NONE:
19086
  {
19087
  // nop
 
19171
  ggml_build_forward_impl(cgraph, tensor, true);
19172
  }
19173
 
19174
+ void ggml_build_backward_expand(struct ggml_context * ctx, struct ggml_cgraph * gf, struct ggml_cgraph * gb, bool accumulate, bool keep) {
19175
  GGML_ASSERT(gf->n_nodes > 0);
19176
  GGML_ASSERT(gf->grads);
19177
 
 
19187
  }
19188
  }
19189
 
19190
+ // keep tables of original gradients for replacement/accumulation logic
19191
  struct ggml_hash_set zero_table = ggml_hash_set_new(gf->size);
19192
+ struct ggml_hash_set acc_table = ggml_hash_set_new(gf->size);
19193
  for (int i = 0; i < gf->n_nodes; i++) {
19194
+ struct ggml_tensor * node = gf->nodes[i];
19195
+
19196
+ if (node->grad) {
19197
+ {
19198
+ const size_t insert_result = ggml_hash_insert(&zero_table, node->grad);
19199
+ GGML_ASSERT(insert_result != GGML_HASHSET_FULL);
19200
+ GGML_ASSERT(insert_result != GGML_HASHSET_ALREADY_EXISTS);
19201
+ }
19202
+
19203
+ // only gradients of trainable parameters should be accumulated
19204
+ if (accumulate && (node->flags & GGML_TENSOR_FLAG_PARAM)) {
19205
+ const size_t insert_result = ggml_hash_insert(&acc_table, node->grad);
19206
+ GGML_ASSERT(insert_result != GGML_HASHSET_FULL);
19207
+ GGML_ASSERT(insert_result != GGML_HASHSET_ALREADY_EXISTS);
19208
+ }
19209
  }
19210
  }
19211
 
19212
  for (int i = gf->n_nodes - 1; i >= 0; i--) {
19213
  struct ggml_tensor * node = gf->nodes[i];
19214
 
19215
+ // inplace operations to add gradients are not created by ggml_compute_backward except for gradient accumulation
19216
  // use allocator to automatically make inplace operations
19217
  if (node->grad) {
19218
+ ggml_compute_backward(ctx, node, &zero_table, &acc_table);
19219
  }
19220
  }
19221
 
 
19229
  }
19230
 
19231
  ggml_hash_set_free(&zero_table);
19232
+ ggml_hash_set_free(&acc_table);
19233
+ }
19234
+
19235
+ void ggml_build_opt_adamw(
19236
+ struct ggml_context * ctx,
19237
+ struct ggml_cgraph * gf,
19238
+ struct ggml_cgraph * gb,
19239
+ float alpha,
19240
+ float beta1,
19241
+ float beta2,
19242
+ float eps,
19243
+ float wd) {
19244
+ for (int i = 0; i < gf->n_nodes; i++) {
19245
+ struct ggml_tensor * node = gf->nodes[i];
19246
+
19247
+ if (node->flags & GGML_TENSOR_FLAG_PARAM) {
19248
+ GGML_PRINT_DEBUG("%s: found root node %p\n", __func__, (void *) node);
19249
+ struct ggml_tensor * opt_step = ggml_opt_step_adamw(ctx, node, alpha, beta1, beta2, eps, wd);
19250
+ ggml_build_forward_expand(gb, opt_step);
19251
+ }
19252
+ }
19253
  }
19254
 
19255
+
19256
  static void * incr_ptr_aligned(void ** p, size_t size, size_t align) {
19257
  void * ptr = *p;
19258
  ptr = (void *) GGML_PAD((uintptr_t) ptr, align);
 
19380
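Putting the new pieces together, a hedged end-to-end sketch of how a training graph can be built with these APIs (shapes, sizes and hyperparameters are illustrative; the MNIST example touched by this change is the authoritative reference):

    #include "ggml.h"

    // assumed sizes, for illustration only
    static const int64_t n_in = 784, n_out = 10, n_batch = 32;

    static void build_train_graphs(struct ggml_context * ctx,
                                   struct ggml_cgraph ** out_gf, struct ggml_cgraph ** out_gb) {
        struct ggml_tensor * w = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_in, n_out);
        struct ggml_tensor * x = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_in, n_batch);
        struct ggml_tensor * y = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_out, n_batch);
        ggml_set_param(ctx, w);                                 // trainable -> w->grad is created

        struct ggml_tensor * logits = ggml_mul_mat(ctx, w, x);                 // [n_out, n_batch]
        struct ggml_tensor * loss   = ggml_cross_entropy_loss(ctx, logits, y); // scalar F32
        ggml_set_loss(loss);                                    // seeded with 1.0f by ggml_graph_reset

        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx, GGML_DEFAULT_GRAPH_SIZE, /*grads =*/ true);
        ggml_build_forward_expand(gf, loss);

        struct ggml_cgraph * gb = ggml_graph_dup(ctx, gf);
        ggml_build_backward_expand(ctx, gf, gb, /*accumulate =*/ false, /*keep =*/ true);

        // append one GGML_OP_OPT_STEP_ADAMW node per parameter (here: just w)
        ggml_build_opt_adamw(ctx, gf, gb, 1e-3f, 0.9f, 0.999f, 1e-8f, 0.0f);

        *out_gf = gf;
        *out_gb = gb;
    }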
  GGML_ASSERT(cgraph->grads != NULL);
19381
 
19382
  for (int i = 0; i < cgraph->n_nodes; i++) {
19383
+ struct ggml_tensor * node = cgraph->nodes[i];
19384
+
19385
+ // initial gradients of loss should be 1, 0 otherwise
19386
+ if (node->grad) {
19387
+ if (node->flags & GGML_TENSOR_FLAG_LOSS) {
19388
+ GGML_ASSERT(node->grad->buffer);
19389
+ GGML_ASSERT(node->type == GGML_TYPE_F32);
19390
+ GGML_ASSERT(ggml_is_scalar(node));
19391
+
19392
+ const float onef = 1.0f;
19393
+ ggml_backend_tensor_set(node->grad, &onef, 0, ggml_nbytes(node->grad));
19394
+ } else {
19395
+ ggml_set_zero(node->grad);
19396
+ }
19397
+ }
19398
 
19399
+ GGML_ASSERT(node);
19400
+ if (node->op == GGML_OP_OPT_STEP_ADAMW) {
19401
+ // set iteration to 1 and clear momenta
19402
+ ggml_set_op_params_i32(node, 0, 1);
19403
+ ggml_set_zero(node->src[2]);
19404
+ ggml_set_zero(node->src[3]);
19405
  }
19406
  }
19407
  }
 
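Since ggml_graph_reset now also seeds the loss gradient with 1.0f and reinitializes any GGML_OP_OPT_STEP_ADAMW nodes it finds (iteration back to 1, momenta zeroed), one plausible shape for a training loop on top of the sketch above is the following; gb, backend and the input upload are assumed, and with accumulate = true one would additionally keep a gradients-only copy of the backward graph and reset it between optimizer steps, roughly as the MNIST example does:

    // hedged sketch: gb is the backward graph with AdamW nodes, allocated on `backend`
    const int n_steps = 100;

    ggml_graph_reset(gb); // grads = 0, loss grad = 1.0f, AdamW iteration = 1, momenta = 0

    for (int step = 0; step < n_steps; ++step) {
        // ... copy the next batch into the input tensors with ggml_backend_tensor_set ...
        ggml_backend_graph_compute(backend, gb); // forward + backward + one AdamW update per parameter
    }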
19666
  } break;
19667
  case GGML_OP_CROSS_ENTROPY_LOSS:
19668
  case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
19669
+ case GGML_OP_OPT_STEP_ADAMW:
19670
  {
19671
  n_tasks = n_threads;
19672
  } break;
 
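Like the two cross-entropy ops it is grouped with, GGML_OP_OPT_STEP_ADAMW gets n_tasks = n_threads, and the kernel splits whole rows across threads with dr = (nr + nth - 1)/nth. A quick worked example: for nr = 10 rows and nth = 4 threads, dr = 3, so the threads process row ranges [0,3), [3,6), [6,9) and [9,10), and after the barrier only thread 0 increments the shared iteration counter.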
22029
  ggml_build_forward_expand(gf, f);
22030
 
22031
  struct ggml_cgraph * gb = ggml_graph_dup(ctx, gf);
22032
+ ggml_build_backward_expand(ctx, gf, gb, false, true);
22033
 
22034
  return ggml_opt_resume_g(ctx, opt, f, gf, gb, NULL, NULL);
22035
  }