Commit 895e87a (unverified) by ggerganov · Parent(s): 5704da9

sync : ggml (new ops, new backend, etc) (#1602)

* sync : ggml (new ops, new backend, etc)

* whisper : remove obsolete broadcasting code

* ggml : remove backend self-registers + fix ggml_concat + n_task logic

* metal : fix assert

* metal : print resource path

* whisper : fix bug if metal init fails
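
The most visible API change in this sync is that backends no longer self-register; they are added to an explicit registry (see the new ggml_backend_registry code in ggml-backend.c below). The following is a minimal caller-side sketch of that registry API, assuming the ggml_backend_reg_* declarations are exported through ggml-backend.h (whose diff is not reproduced in full here):

#include <stdio.h>
#include "ggml-backend.h"

int main(void) {
    // list everything that was added to the registry via ggml_backend_register()
    for (size_t i = 0; i < ggml_backend_reg_get_count(); i++) {
        printf("backend %zu: %s\n", i, ggml_backend_reg_get_name(i));
    }

    // initialize a backend from a "name" or "name:params" string
    ggml_backend_t backend = ggml_backend_reg_init_backend_from_str("CPU");
    if (backend == NULL) {
        fprintf(stderr, "backend init failed\n");
        return 1;
    }
    printf("using: %s\n", ggml_backend_name(backend));

    ggml_backend_free(backend);
    return 0;
}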

Files changed (16)
  1. ggml-alloc.c +43 -8
  2. ggml-alloc.h +7 -0
  3. ggml-backend-impl.h +46 -21
  4. ggml-backend.c +563 -156
  5. ggml-backend.h +62 -17
  6. ggml-cuda.cu +1261 -428
  7. ggml-cuda.h +9 -1
  8. ggml-impl.h +1 -1
  9. ggml-metal.h +4 -5
  10. ggml-metal.m +537 -187
  11. ggml-metal.metal +888 -237
  12. ggml-opencl.cpp +5 -7
  13. ggml-quants.c +1 -1
  14. ggml.c +478 -134
  15. ggml.h +60 -7
  16. whisper.cpp +24 -54
ggml-alloc.c CHANGED
@@ -137,7 +137,7 @@ void ggml_tallocr_alloc(ggml_tallocr_t alloc, struct ggml_tensor * tensor) {
 
 #ifdef GGML_ALLOCATOR_DEBUG
     add_allocated_tensor(alloc, tensor);
-    size_t cur_max = (char*)addr - (char*)alloc->data + size;
+    size_t cur_max = (char*)addr - (char*)alloc->base + size;
     if (cur_max > alloc->max_size) {
         printf("max_size = %.2f MB: tensors: ", cur_max / 1024.0 / 1024.0);
         for (int i = 0; i < 1024; i++) {
@@ -168,10 +168,6 @@ static void ggml_tallocr_free_tensor(ggml_tallocr_t alloc, struct ggml_tensor *
     size = aligned_offset(NULL, size, alloc->alignment);
     AT_PRINTF("%s: freeing %s at %p (%zu bytes) - n_free_blocks = %d\n", __func__, tensor->name, ptr, size, alloc->n_free_blocks);
 
-    if (!alloc->measure) {
-        ggml_backend_buffer_free_tensor(alloc->buffer, tensor);
-    }
-
 #ifdef GGML_ALLOCATOR_DEBUG
     remove_allocated_tensor(alloc, tensor);
 #endif
@@ -237,7 +233,7 @@ void ggml_tallocr_reset(ggml_tallocr_t alloc) {
 }
 
 ggml_tallocr_t ggml_tallocr_new(void * data, size_t size, size_t alignment) {
-    struct ggml_backend_buffer * buffer = ggml_backend_cpu_buffer_from_ptr(NULL, data, size);
+    struct ggml_backend_buffer * buffer = ggml_backend_cpu_buffer_from_ptr(data, size);
 
     ggml_tallocr_t alloc = (ggml_tallocr_t)malloc(sizeof(struct ggml_tallocr));
 
@@ -449,7 +445,6 @@ static ggml_tallocr_t node_tallocr(ggml_gallocr_t galloc, struct ggml_tensor * n
 static void init_view(ggml_gallocr_t galloc, struct ggml_tensor * view, bool update_backend) {
     ggml_tallocr_t alloc = node_tallocr(galloc, view);
 
-    //printf("init_view: %s from src %s\n", view->name, view->view_src->name);
     GGML_ASSERT(view->view_src != NULL && view->view_src->data != NULL);
     if (update_backend) {
         view->backend = view->view_src->backend;
@@ -459,7 +454,7 @@ static void init_view(ggml_gallocr_t galloc, struct ggml_tensor * view, bool upd
 
     // FIXME: the view should be initialized by the owning buffer, but currently this breaks the CUDA backend
     // due to the ggml_tensor_extra_gpu ring buffer overwriting the KV cache extras
-    assert(ggml_tallocr_is_measure(alloc) || !view->buffer || view->buffer->backend == alloc->buffer->backend);
+    assert(ggml_tallocr_is_measure(alloc) || !view->buffer || view->buffer->buft == alloc->buffer->buft);
 
     if (!alloc->measure) {
         ggml_backend_buffer_init_tensor(alloc->buffer, view);
@@ -765,3 +760,43 @@ size_t ggml_allocr_max_size(ggml_allocr_t alloc) {
 size_t ggml_allocr_alloc_graph(ggml_allocr_t alloc, struct ggml_cgraph * graph) {
     return ggml_gallocr_alloc_graph(alloc->galloc, alloc->talloc, graph);
 }
+
+// utils
+ggml_backend_buffer_t ggml_backend_alloc_ctx_tensors_from_buft(struct ggml_context * ctx, ggml_backend_buffer_type_t buft) {
+    GGML_ASSERT(ggml_get_no_alloc(ctx) == true);
+
+    size_t alignment = ggml_backend_buft_get_alignment(buft);
+
+    size_t nbytes = 0;
+    for (struct ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
+        if (t->data == NULL && t->view_src == NULL) {
+            nbytes += GGML_PAD(ggml_backend_buft_get_alloc_size(buft, t), alignment);
+        }
+    }
+
+    if (nbytes == 0) {
+        fprintf(stderr, "%s: no tensors to allocate\n", __func__);
+        return NULL;
+    }
+
+    ggml_backend_buffer_t buffer = ggml_backend_buft_alloc_buffer(buft, nbytes);
+    ggml_tallocr_t tallocr = ggml_tallocr_new_from_buffer(buffer);
+
+    for (struct ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
+        if (t->data == NULL) {
+            if (t->view_src == NULL) {
+                ggml_tallocr_alloc(tallocr, t);
+            } else {
+                ggml_backend_view_init(buffer, t);
+            }
+        }
+    }
+
+    ggml_tallocr_free(tallocr);
+
+    return buffer;
+}
+
+ggml_backend_buffer_t ggml_backend_alloc_ctx_tensors(struct ggml_context * ctx, ggml_backend_t backend) {
+    return ggml_backend_alloc_ctx_tensors_from_buft(ctx, ggml_backend_get_default_buffer_type(backend));
+}
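
The two ggml_backend_alloc_ctx_tensors helpers added above let a caller size, allocate and initialize every tensor of a no_alloc context in one step instead of running a manual measure pass. A minimal usage sketch (the helper name, tensor shapes and cleanup strategy are illustrative, not part of this commit):

#include "ggml.h"
#include "ggml-alloc.h"
#include "ggml-backend.h"

// Build a no_alloc context, define the tensors, then allocate them all in one backend buffer.
static ggml_backend_buffer_t alloc_model_tensors(ggml_backend_t backend, struct ggml_context ** out_ctx) {
    struct ggml_init_params params = {
        /* .mem_size   = */ 16*ggml_tensor_overhead(), // metadata only, no tensor data
        /* .mem_buffer = */ NULL,
        /* .no_alloc   = */ true,                      // required by ggml_backend_alloc_ctx_tensors
    };
    struct ggml_context * ctx = ggml_init(params);

    // illustrative tensors - a real model would create its weights here
    struct ggml_tensor * w = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 512, 512);
    struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 512);
    (void) w; (void) b;

    // one call replaces the old measure/alloc dance; returns NULL if there is nothing to allocate
    ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors(ctx, backend);

    *out_ctx = ctx; // caller keeps the context (tensor metadata) and later calls
                    // ggml_free(ctx) and ggml_backend_buffer_free(buf)
    return buf;
}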
ggml-alloc.h CHANGED
@@ -8,6 +8,7 @@ extern "C" {
 
 struct ggml_backend;
 struct ggml_backend_buffer;
+struct ggml_backend_buffer_type;
 
 //
 // Legacy API
@@ -80,6 +81,12 @@ GGML_API void ggml_gallocr_alloc_graph_n(
                     struct ggml_hash_set hash_set,
                     ggml_tallocr_t * hash_node_talloc);
 
+
+// Utils
+// Create a buffer and allocate all the tensors in a ggml_context
+GGML_API struct ggml_backend_buffer * ggml_backend_alloc_ctx_tensors_from_buft(struct ggml_context * ctx, struct ggml_backend_buffer_type * buft);
+GGML_API struct ggml_backend_buffer * ggml_backend_alloc_ctx_tensors(struct ggml_context * ctx, struct ggml_backend * backend);
+
 #ifdef __cplusplus
 }
 #endif
ggml-backend-impl.h CHANGED
@@ -12,31 +12,50 @@ extern "C" {
     // Backend buffer
     //
 
+    // buffer type
+    typedef void * ggml_backend_buffer_type_context_t;
+
+    struct ggml_backend_buffer_type_i {
+        ggml_backend_buffer_t (*alloc_buffer)    (ggml_backend_buffer_type_t buft, size_t size);
+        size_t                (*get_alignment)   (ggml_backend_buffer_type_t buft); // tensor alignment
+        size_t                (*get_alloc_size)  (ggml_backend_buffer_type_t buft, struct ggml_tensor * tensor); // data size needed to allocate the tensor, including padding
+        bool                  (*supports_backend)(ggml_backend_buffer_type_t buft, ggml_backend_t backend); // check if the buffer type is usable by the backend
+    };
+
+    struct ggml_backend_buffer_type {
+        struct ggml_backend_buffer_type_i  iface;
+        ggml_backend_buffer_type_context_t context;
+    };
+
+    // buffer
     typedef void * ggml_backend_buffer_context_t;
 
     struct ggml_backend_buffer_i {
-        void   (*free_buffer)   (ggml_backend_buffer_t buffer);
-        void * (*get_base)      (ggml_backend_buffer_t buffer); // get base pointer
-        size_t (*get_alloc_size)(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor); // pre-allocation callback
-        void   (*init_tensor)   (ggml_backend_buffer_t buffer, struct ggml_tensor * tensor); // post-allocation callback
-        void   (*free_tensor)   (ggml_backend_buffer_t buffer, struct ggml_tensor * tensor); // pre-free callback
+        void   (*free_buffer)(ggml_backend_buffer_t buffer);
+        //void (*reset)      (ggml_backend_buffer_t buffer); // reset any internal state due to tensor initialization, such as tensor extras
+        void * (*get_base)   (ggml_backend_buffer_t buffer);
+        void   (*init_tensor)(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor);
+        void   (*set_tensor) (ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size);
+        void   (*get_tensor) (ggml_backend_buffer_t buffer, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size);
+        // (optional) copy tensor between different buffer-type, allow for single-copy tranfers
+        void (*cpy_tensor_from)(ggml_backend_buffer_t buffer, struct ggml_tensor * src, struct ggml_tensor * dst);
+        void (*cpy_tensor_to)  (ggml_backend_buffer_t buffer, struct ggml_tensor * src, struct ggml_tensor * dst);
     };
 
     struct ggml_backend_buffer {
-        struct ggml_backend_buffer_i iface;
-
-        ggml_backend_t backend;
+        struct ggml_backend_buffer_i  iface;
+        ggml_backend_buffer_type_t    buft;
         ggml_backend_buffer_context_t context;
-
         size_t size;
     };
 
-    GGML_API ggml_backend_buffer_t ggml_backend_buffer_init(
-        struct ggml_backend * backend,
+    ggml_backend_buffer_t ggml_backend_buffer_init(
+        ggml_backend_buffer_type_t    buft,
         struct ggml_backend_buffer_i  iface,
         ggml_backend_buffer_context_t context,
         size_t                        size);
 
+
     //
     // Backend
     //
@@ -49,20 +68,17 @@ extern "C" {
         void (*free)(ggml_backend_t backend);
 
         // buffer allocation
-        ggml_backend_buffer_t (*alloc_buffer)(ggml_backend_t backend, size_t size);
+        ggml_backend_buffer_type_t (*get_default_buffer_type)(ggml_backend_t backend);
 
-        // get buffer alignment
-        size_t (*get_alignment)(ggml_backend_t backend);
-
-        // tensor data access
-        // these functions can be asynchronous, helper functions are provided for synchronous access that automatically call synchronize
+        // (optional) asynchroneous tensor data access
         void (*set_tensor_async)(ggml_backend_t backend, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size);
         void (*get_tensor_async)(ggml_backend_t backend, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size);
-        void (*synchronize)     (ggml_backend_t backend);
 
-        // (optional) copy tensor between different backends, allow for single-copy tranfers
-        void (*cpy_tensor_from)(ggml_backend_t backend, struct ggml_tensor * src, struct ggml_tensor * dst);
-        void (*cpy_tensor_to)  (ggml_backend_t backend, struct ggml_tensor * src, struct ggml_tensor * dst);
+        // (optional) asynchroneous tensor copy
+        void (*cpy_tensor_from_async)(ggml_backend_t backend, struct ggml_tensor * src, struct ggml_tensor * dst);
+        void (*cpy_tensor_to_async)  (ggml_backend_t backend, struct ggml_tensor * src, struct ggml_tensor * dst);
+
+        void (*synchronize)(ggml_backend_t backend);
 
         // compute graph with a plan
         ggml_backend_graph_plan_t (*graph_plan_create) (ggml_backend_t backend, struct ggml_cgraph * cgraph);
@@ -82,6 +98,15 @@ extern "C" {
         ggml_backend_context_t context;
     };
 
+
+    //
+    // Backend registry
+    //
+
+    typedef ggml_backend_t (*ggml_backend_init_fn)(const char * params, void * user_data);
+
+    void ggml_backend_register(const char * name, ggml_backend_init_fn init_fn, ggml_backend_buffer_type_t default_buffer_type, void * user_data);
+
 #ifdef __cplusplus
 }
 #endif
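
The header now splits buffer creation (ggml_backend_buffer_type_i) from access to an existing buffer (ggml_backend_buffer_i); the CPU backend further down in ggml-backend.c is the reference implementation. As a rough sketch of what a third-party, host-memory buffer type could look like against these interfaces (all my_* names are hypothetical and the plain malloc storage is an assumption, not part of this commit):

// my_buft.c - hypothetical host-memory buffer type written against the
// interfaces declared above; it mirrors the CPU implementation in ggml-backend.c.
#include <stdbool.h>
#include <stdlib.h>
#include <string.h>
#include "ggml-backend-impl.h"

static void   my_buffer_free_buffer(ggml_backend_buffer_t buffer) { free(buffer->context); }
static void * my_buffer_get_base   (ggml_backend_buffer_t buffer) { return buffer->context; }

static void my_buffer_set_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
    memcpy((char *)tensor->data + offset, data, size); (void) buffer;
}

static void my_buffer_get_tensor(ggml_backend_buffer_t buffer, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) {
    memcpy(data, (const char *)tensor->data + offset, size); (void) buffer;
}

static struct ggml_backend_buffer_i my_buffer_i = {
    /* .free_buffer     = */ my_buffer_free_buffer,
    /* .get_base        = */ my_buffer_get_base,
    /* .init_tensor     = */ NULL, // nothing to do after allocation
    /* .set_tensor      = */ my_buffer_set_tensor,
    /* .get_tensor      = */ my_buffer_get_tensor,
    /* .cpy_tensor_from = */ NULL, // optional
    /* .cpy_tensor_to   = */ NULL, // optional
};

static ggml_backend_buffer_t my_buft_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
    void * data = malloc(size); // the real CPU type over-allocates to guarantee alignment
    return ggml_backend_buffer_init(buft, my_buffer_i, data, size);
}

static size_t my_buft_get_alignment(ggml_backend_buffer_type_t buft) {
    return 64;
    (void) buft;
}

static bool my_buft_supports_backend(ggml_backend_buffer_type_t buft, ggml_backend_t backend) {
    return ggml_backend_is_cpu(backend); // host memory: only directly usable by the CPU backend
    (void) buft;
}

ggml_backend_buffer_type_t my_buffer_type(void) {
    static struct ggml_backend_buffer_type buft = {
        /* .iface = */ {
            /* .alloc_buffer     = */ my_buft_alloc_buffer,
            /* .get_alignment    = */ my_buft_get_alignment,
            /* .get_alloc_size   = */ NULL, // defaults to ggml_nbytes
            /* .supports_backend = */ my_buft_supports_backend,
        },
        /* .context = */ NULL,
    };
    return &buft;
}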
ggml-backend.c CHANGED
@@ -9,14 +9,36 @@
9
  #include <stdlib.h>
10
  #include <string.h>
11
 
12
- #define UNUSED GGML_UNUSED
13
 
14
  #define MAX(a, b) ((a) > (b) ? (a) : (b))
15
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
  // backend buffer
17
 
18
  ggml_backend_buffer_t ggml_backend_buffer_init(
19
- struct ggml_backend * backend,
20
  struct ggml_backend_buffer_i iface,
21
  ggml_backend_buffer_context_t context,
22
  size_t size) {
@@ -26,7 +48,7 @@ ggml_backend_buffer_t ggml_backend_buffer_init(
26
 
27
  (*buffer) = (struct ggml_backend_buffer) {
28
  /* .interface = */ iface,
29
- /* .backend = */ backend,
30
  /* .context = */ context,
31
  /* .size = */ size,
32
  };
@@ -45,10 +67,6 @@ void ggml_backend_buffer_free(ggml_backend_buffer_t buffer) {
45
  free(buffer);
46
  }
47
 
48
- size_t ggml_backend_buffer_get_alignment(ggml_backend_buffer_t buffer) {
49
- return ggml_backend_get_alignment(buffer->backend);
50
- }
51
-
52
  size_t ggml_backend_buffer_get_size(ggml_backend_buffer_t buffer) {
53
  return buffer->size;
54
  }
@@ -61,14 +79,6 @@ void * ggml_backend_buffer_get_base(ggml_backend_buffer_t buffer) {
61
  return base;
62
  }
63
 
64
- size_t ggml_backend_buffer_get_alloc_size(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor) {
65
- // get_alloc_size is optional, defaults to ggml_nbytes
66
- if (buffer->iface.get_alloc_size) {
67
- return buffer->iface.get_alloc_size(buffer, tensor);
68
- }
69
- return ggml_nbytes(tensor);
70
- }
71
-
72
  void ggml_backend_buffer_init_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor) {
73
  // init_tensor is optional
74
  if (buffer->iface.init_tensor) {
@@ -76,19 +86,20 @@ void ggml_backend_buffer_init_tensor(ggml_backend_buffer_t buffer, struct ggml_t
76
  }
77
  }
78
 
79
- void ggml_backend_buffer_free_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor) {
80
- // free_tensor is optional
81
- if (buffer->iface.free_tensor) {
82
- buffer->iface.free_tensor(buffer, tensor);
83
- }
84
  }
85
 
86
- // backend
 
 
87
 
88
- ggml_backend_t ggml_get_backend(const struct ggml_tensor * tensor) {
89
- return tensor->buffer ? tensor->buffer->backend : NULL;
90
  }
91
 
 
 
92
  const char * ggml_backend_name(ggml_backend_t backend) {
93
  if (backend == NULL) {
94
  return "NULL";
@@ -104,43 +115,53 @@ void ggml_backend_free(ggml_backend_t backend) {
104
  backend->iface.free(backend);
105
  }
106
 
 
 
 
 
107
  ggml_backend_buffer_t ggml_backend_alloc_buffer(ggml_backend_t backend, size_t size) {
108
- return backend->iface.alloc_buffer(backend, size);
109
  }
110
 
111
  size_t ggml_backend_get_alignment(ggml_backend_t backend) {
112
- return backend->iface.get_alignment(backend);
113
  }
114
 
115
- void ggml_backend_tensor_set_async(struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
116
- ggml_get_backend(tensor)->iface.set_tensor_async(ggml_get_backend(tensor), tensor, data, offset, size);
 
 
 
117
  }
118
 
119
- void ggml_backend_tensor_get_async(const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) {
120
- ggml_get_backend(tensor)->iface.get_tensor_async(ggml_get_backend(tensor), tensor, data, offset, size);
 
 
 
121
  }
122
 
123
  void ggml_backend_tensor_set(struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
124
- ggml_backend_t backend = ggml_get_backend(tensor);
125
-
126
  GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
127
- GGML_ASSERT(backend != NULL && "tensor backend not set");
 
128
 
129
- backend->iface.set_tensor_async(backend, tensor, data, offset, size);
130
- backend->iface.synchronize(backend);
131
  }
132
 
133
  void ggml_backend_tensor_get(const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) {
134
- ggml_backend_t backend = ggml_get_backend(tensor);
135
-
136
  GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
137
- GGML_ASSERT(backend != NULL && "tensor backend not set");
 
138
 
139
- backend->iface.get_tensor_async(backend, tensor, data, offset, size);
140
- backend->iface.synchronize(backend);
141
  }
142
 
143
  void ggml_backend_synchronize(ggml_backend_t backend) {
 
 
 
 
144
  backend->iface.synchronize(backend);
145
  }
146
 
@@ -154,10 +175,16 @@ void ggml_backend_graph_plan_free(ggml_backend_t backend, ggml_backend_graph_pla
154
 
155
  void ggml_backend_graph_plan_compute(ggml_backend_t backend, ggml_backend_graph_plan_t plan) {
156
  backend->iface.graph_plan_compute(backend, plan);
 
 
 
157
  }
158
 
159
  void ggml_backend_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
160
  backend->iface.graph_compute(backend, cgraph);
 
 
 
161
  }
162
 
163
  bool ggml_backend_supports_op(ggml_backend_t backend, const struct ggml_tensor * op) {
@@ -194,14 +221,15 @@ void ggml_backend_tensor_copy(struct ggml_tensor * src, struct ggml_tensor * dst
194
 
195
  // TODO: allow backends to support copy to/from same backend
196
 
197
- if (ggml_get_backend(dst)->iface.cpy_tensor_from != NULL) {
198
- ggml_get_backend(dst)->iface.cpy_tensor_from(ggml_get_backend(dst)->context, src, dst);
199
- } else if (ggml_get_backend(src)->iface.cpy_tensor_to != NULL) {
200
- ggml_get_backend(src)->iface.cpy_tensor_to(ggml_get_backend(src)->context, src, dst);
201
  } else {
202
  // shouldn't be hit when copying from/to CPU
203
  #ifndef NDEBUG
204
- fprintf(stderr, "ggml_backend_tensor_copy: neither cpy_tensor_from nor cpy_tensor_to are implemented for backends %s and %s, falling back to get/set\n", ggml_backend_name(src->buffer->backend), ggml_backend_name(dst->buffer->backend));
 
205
  #endif
206
  size_t nbytes = ggml_nbytes(src);
207
  void * data = malloc(nbytes);
@@ -211,101 +239,259 @@ void ggml_backend_tensor_copy(struct ggml_tensor * src, struct ggml_tensor * dst
211
  }
212
  }
213
 
214
- // backend CPU
215
 
216
- struct ggml_backend_cpu_context {
217
- int n_threads;
218
- void * work_data;
219
- size_t work_size;
 
 
 
220
  };
221
 
222
- static const char * ggml_backend_cpu_name(ggml_backend_t backend) {
223
- return "CPU";
 
 
 
 
 
 
 
 
 
 
 
224
 
225
- UNUSED(backend);
 
 
 
 
 
 
 
 
 
 
 
 
226
  }
227
 
228
- static void ggml_backend_cpu_free(ggml_backend_t backend) {
229
- struct ggml_backend_cpu_context * cpu_ctx = (struct ggml_backend_cpu_context *)backend->context;
230
- free(cpu_ctx->work_data);
231
- free(cpu_ctx);
232
- free(backend);
233
  }
234
 
 
 
235
  static void * ggml_backend_cpu_buffer_get_base(ggml_backend_buffer_t buffer) {
236
  return (void *)buffer->context;
237
  }
238
 
239
  static void ggml_backend_cpu_buffer_free_buffer(ggml_backend_buffer_t buffer) {
240
  free(buffer->context);
241
- UNUSED(buffer);
242
  }
243
 
244
  static struct ggml_backend_buffer_i cpu_backend_buffer_i = {
245
- /* .free_buffer = */ ggml_backend_cpu_buffer_free_buffer,
246
- /* .get_base = */ ggml_backend_cpu_buffer_get_base,
247
- /* .get_alloc_size = */ NULL, // defaults to ggml_nbytes
248
- /* .init_tensor = */ NULL, // no initialization required
249
- /* .free_tensor = */ NULL, // no cleanup required
 
 
250
  };
251
 
252
  // for buffers from ptr, free is not called
253
  static struct ggml_backend_buffer_i cpu_backend_buffer_i_from_ptr = {
254
- /* .free_buffer = */ NULL, // ptr is not owned by the buffer, so it does not need to be freed
255
- /* .get_base = */ ggml_backend_cpu_buffer_get_base,
256
- /* .get_alloc_size = */ NULL, // defaults to ggml_nbytes
257
- /* .init_tensor = */ NULL,
258
- /* .free_tensor = */ NULL,
 
 
259
  };
260
 
261
  static const size_t TENSOR_ALIGNMENT = 64; // should be enough for AVX 512
262
 
263
- static ggml_backend_buffer_t ggml_backend_cpu_alloc_buffer(ggml_backend_t backend, size_t size) {
264
  size += TENSOR_ALIGNMENT; // malloc may return an address that is not aligned
265
  void * data = malloc(size); // TODO: maybe use GGML_ALIGNED_MALLOC?
266
 
267
  GGML_ASSERT(data != NULL && "failed to allocate buffer");
268
 
269
- return ggml_backend_buffer_init(backend, cpu_backend_buffer_i, data, size);
270
  }
271
 
272
- static size_t ggml_backend_cpu_get_alignment(ggml_backend_t backend) {
273
  return TENSOR_ALIGNMENT;
274
- UNUSED(backend);
275
- }
276
 
277
- static void ggml_backend_cpu_set_tensor_async(ggml_backend_t backend, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
278
- GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor write out of bounds");
279
- GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
280
 
281
- memcpy((char *)tensor->data + offset, data, size);
 
282
 
283
- UNUSED(backend);
284
  }
285
 
286
- static void ggml_backend_cpu_get_tensor_async(ggml_backend_t backend, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) {
287
- GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor read out of bounds");
288
- GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
289
-
290
- memcpy(data, (const char *)tensor->data + offset, size);
 
 
 
 
 
291
 
292
- UNUSED(backend);
293
  }
294
 
295
- static void ggml_backend_cpu_synchronize(ggml_backend_t backend) {
296
- UNUSED(backend);
297
- }
 
 
298
 
299
- static void ggml_backend_cpu_cpy_tensor_from(ggml_backend_t backend, struct ggml_tensor * src, struct ggml_tensor * dst) {
300
- ggml_backend_tensor_get(src, dst->data, 0, ggml_nbytes(src));
301
 
302
- UNUSED(backend);
303
  }
304
 
305
- static void ggml_backend_cpu_cpy_tensor_to(ggml_backend_t backend, struct ggml_tensor * src, struct ggml_tensor * dst) {
306
- ggml_backend_tensor_set(dst, src->data, 0, ggml_nbytes(src));
 
 
 
 
 
 
 
307
 
308
- UNUSED(backend);
309
  }
310
 
311
  struct ggml_backend_plan_cpu {
@@ -334,7 +520,7 @@ static void ggml_backend_cpu_graph_plan_free(ggml_backend_t backend, ggml_backen
334
  free(cpu_plan->cplan.work_data);
335
  free(cpu_plan);
336
 
337
- UNUSED(backend);
338
  }
339
 
340
  static void ggml_backend_cpu_graph_plan_compute(ggml_backend_t backend, ggml_backend_graph_plan_t plan) {
@@ -342,7 +528,7 @@ static void ggml_backend_cpu_graph_plan_compute(ggml_backend_t backend, ggml_bac
342
 
343
  ggml_graph_compute(&cpu_plan->cgraph, &cpu_plan->cplan);
344
 
345
- UNUSED(backend);
346
  }
347
 
348
  static void ggml_backend_cpu_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
@@ -363,25 +549,25 @@ static void ggml_backend_cpu_graph_compute(ggml_backend_t backend, struct ggml_c
363
 
364
  static bool ggml_backend_cpu_supports_op(ggml_backend_t backend, const struct ggml_tensor * op) {
365
  return true;
366
- UNUSED(backend);
367
- UNUSED(op);
 
368
  }
369
 
370
  static struct ggml_backend_i cpu_backend_i = {
371
- /* .get_name = */ ggml_backend_cpu_name,
372
- /* .free = */ ggml_backend_cpu_free,
373
- /* .alloc_buffer = */ ggml_backend_cpu_alloc_buffer,
374
- /* .get_alignment = */ ggml_backend_cpu_get_alignment,
375
- /* .set_tensor_async = */ ggml_backend_cpu_set_tensor_async,
376
- /* .get_tensor_async = */ ggml_backend_cpu_get_tensor_async,
377
- /* .synchronize = */ ggml_backend_cpu_synchronize,
378
- /* .cpy_tensor_from = */ ggml_backend_cpu_cpy_tensor_from,
379
- /* .cpy_tensor_to = */ ggml_backend_cpu_cpy_tensor_to,
380
- /* .graph_plan_create = */ ggml_backend_cpu_graph_plan_create,
381
- /* .graph_plan_free = */ ggml_backend_cpu_graph_plan_free,
382
- /* .graph_plan_compute = */ ggml_backend_cpu_graph_plan_compute,
383
- /* .graph_compute = */ ggml_backend_cpu_graph_compute,
384
- /* .supports_op = */ ggml_backend_cpu_supports_op,
385
  };
386
 
387
  ggml_backend_t ggml_backend_cpu_init(void) {
@@ -411,10 +597,18 @@ void ggml_backend_cpu_set_n_threads(ggml_backend_t backend_cpu, int n_threads) {
411
  ctx->n_threads = n_threads;
412
  }
413
 
414
- ggml_backend_buffer_t ggml_backend_cpu_buffer_from_ptr(ggml_backend_t backend_cpu, void * ptr, size_t size) {
415
- return ggml_backend_buffer_init(backend_cpu, cpu_backend_buffer_i_from_ptr, ptr, size);
 
 
 
 
 
 
 
416
  }
417
 
 
418
  // scheduler
419
 
420
  #define GGML_MAX_BACKENDS 4
@@ -427,7 +621,7 @@ struct ggml_backend_sched_split {
427
  int i_end;
428
  struct ggml_tensor * inputs[GGML_MAX_SPLIT_INPUTS];
429
  int n_inputs;
430
- struct ggml_cgraph * graph;
431
  };
432
 
433
  struct ggml_backend_sched {
@@ -453,7 +647,7 @@ struct ggml_backend_sched {
453
  #else
454
  __attribute__((aligned(GGML_MEM_ALIGN)))
455
  #endif
456
- char context_buffer[GGML_MAX_SPLITS*GGML_MAX_SPLIT_INPUTS*sizeof(struct ggml_tensor) + GGML_MAX_SPLITS*sizeof(struct ggml_cgraph)];
457
  };
458
 
459
  #define hash_id(node) ggml_hash_find_or_insert(sched->hash_set, node)
@@ -482,23 +676,57 @@ static int sched_allocr_prio(ggml_backend_sched_t sched, ggml_tallocr_t allocr)
482
  return INT_MAX;
483
  }
484
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
485
  // returns the backend that should be used for the node based on the current locations
486
- char causes[GGML_DEFAULT_GRAPH_SIZE*4 + GGML_MAX_SPLITS*GGML_MAX_SPLIT_INPUTS][128]; // debug, remove
487
  static ggml_backend_t sched_backend_from_cur(ggml_backend_sched_t sched, struct ggml_tensor * node) {
488
  // if the dst tensor is already allocated in a buffer, we must assume that it is critical to keep it there
489
  // ie. kv cache updates
490
  // note that this doesn't allow fallback to CPU. need to add output tensors to the splits to copy the data back to the original backend.
491
  // dst
492
- ggml_backend_t cur_backend = ggml_get_backend(node);
493
  if (cur_backend != NULL) {
494
- sprintf(causes[hash_id(node)], "1.dst");
495
  return cur_backend;
496
  }
497
 
498
  // view_src
499
- if (node->view_src != NULL && ggml_get_backend(node->view_src) != NULL) {
500
- sprintf(causes[hash_id(node)], "1.vsrc");
501
- return ggml_get_backend(node->view_src);
502
  }
503
 
504
  // src
@@ -510,7 +738,7 @@ static ggml_backend_t sched_backend_from_cur(ggml_backend_sched_t sched, struct
510
  if (src == NULL) {
511
  break;
512
  }
513
- ggml_backend_t src_backend = ggml_get_backend(src);
514
  if (src_backend != NULL) {
515
  int src_prio = sched_backend_prio(sched, src_backend);
516
  size_t src_size = ggml_nbytes(src);
@@ -518,7 +746,7 @@ static ggml_backend_t sched_backend_from_cur(ggml_backend_sched_t sched, struct
518
  cur_prio = src_prio;
519
  cur_size = src_size;
520
  cur_backend = src_backend;
521
- sprintf(causes[hash_id(node)], "1.src%d", i);
522
  }
523
  }
524
  }
@@ -539,10 +767,12 @@ static void sched_print_assignments(ggml_backend_sched_t sched, struct ggml_cgra
539
  int cur_split = 0;
540
  for (int i = 0; i < graph->n_nodes; i++) {
541
  if (cur_split < sched->n_splits && i == sched->splits[cur_split].i_start) {
542
- ggml_backend_t split_backend = ggml_tallocr_get_buffer(sched->splits[cur_split].tallocr)->backend;
543
- fprintf(stderr, "\n## SPLIT #%d: %s # %d inputs: ", cur_split, ggml_backend_name(split_backend), sched->splits[cur_split].n_inputs);
 
544
  for (int j = 0; j < sched->splits[cur_split].n_inputs; j++) {
545
- fprintf(stderr, "[%s (%5.5s)] ", sched->splits[cur_split].inputs[j]->name, fmt_size(ggml_nbytes(sched->splits[cur_split].inputs[j])));
 
546
  }
547
  fprintf(stderr, "\n");
548
  cur_split++;
@@ -552,16 +782,18 @@ static void sched_print_assignments(ggml_backend_sched_t sched, struct ggml_cgra
552
  continue;
553
  }
554
  ggml_tallocr_t node_allocr = node_allocr(node);
555
- ggml_backend_t node_backend = node_allocr ? ggml_tallocr_get_buffer(node_allocr)->backend : NULL;
556
- fprintf(stderr, "node #%3d (%10.10s): %20.20s (%4.4s) [%4.4s %8.8s]:", i, ggml_op_name(node->op), node->name, fmt_size(ggml_nbytes(node)), node_allocr ? ggml_backend_name(node_backend) : "NULL", causes[hash_id(node)]);
 
557
  for (int j = 0; j < GGML_MAX_SRC; j++) {
558
  struct ggml_tensor * src = node->src[j];
559
  if (src == NULL) {
560
  break;
561
  }
562
  ggml_tallocr_t src_allocr = node_allocr(src);
563
- ggml_backend_t src_backend = src_allocr ? ggml_tallocr_get_buffer(src_allocr)->backend : NULL;
564
- fprintf(stderr, " %20.20s (%4.4s) [%4.4s %8.8s]", src->name, fmt_size(ggml_nbytes(src)), src_backend ? ggml_backend_name(src_backend) : "NULL", causes[hash_id(src)]);
 
565
  }
566
  fprintf(stderr, "\n");
567
  }
@@ -587,9 +819,9 @@ static void sched_split_graph(ggml_backend_sched_t sched, struct ggml_cgraph * g
587
  sched->n_splits = 0;
588
 
589
  struct ggml_init_params params = {
590
- /*.mem_size = */ sizeof(sched->context_buffer),
591
- /*.mem_buffer = */ sched->context_buffer,
592
- /*.no_alloc = */ true
593
  };
594
 
595
  if (sched->ctx != NULL) {
@@ -605,9 +837,9 @@ static void sched_split_graph(ggml_backend_sched_t sched, struct ggml_cgraph * g
605
  // do not overwrite user assignments
606
  continue;
607
  }
608
- ggml_backend_t leaf_backend = ggml_get_backend(leaf);
609
  if (leaf_backend == NULL && leaf->view_src != NULL) {
610
- leaf_backend = ggml_get_backend(leaf->view_src);
611
  }
612
  if (leaf_backend != NULL) {
613
  node_allocr(leaf) = ggml_backend_sched_get_tallocr(sched, leaf_backend);
@@ -649,7 +881,7 @@ static void sched_split_graph(ggml_backend_sched_t sched, struct ggml_cgraph * g
649
  cur_prio = src_prio;
650
  cur_size = src_size;
651
  node_allocr = src_allocr;
652
- sprintf(causes[hash_id(node)], "2.src%d", j);
653
  }
654
  }
655
  }
@@ -733,7 +965,7 @@ static void sched_split_graph(ggml_backend_sched_t sched, struct ggml_cgraph * g
733
  struct ggml_tensor * tensor_copy = ggml_dup_tensor_layout(sched->ctx, src);
734
  sched->node_copies[id][cur_backend_id] = tensor_copy;
735
  node_allocr(tensor_copy) = cur_allocr;
736
- ggml_backend_t backend = ggml_tallocr_get_buffer(cur_allocr)->backend;
737
  ggml_format_name(tensor_copy, "%s#%s", ggml_backend_name(backend), src->name);
738
  }
739
  node->src[j] = sched->node_copies[id][cur_backend_id];
@@ -761,8 +993,8 @@ static void sched_split_graph(ggml_backend_sched_t sched, struct ggml_cgraph * g
761
  ggml_tallocr_t src_allocr = node_allocr(src);
762
  if (src_allocr != node_allocr /* && src_backend != NULL */) { // ignore nulls for now
763
  fprintf(stderr, "!!!! %s has backend %s, src %d (%s) has backend %s\n",
764
- node->name, node_allocr ? ggml_backend_name(ggml_tallocr_get_buffer(node_allocr)->backend) : "NULL",
765
- j, src->name, src_allocr ? ggml_backend_name(ggml_tallocr_get_buffer(src_allocr)->backend) : "NULL");
766
  }
767
  }
768
  }
@@ -773,7 +1005,7 @@ static void sched_split_graph(ggml_backend_sched_t sched, struct ggml_cgraph * g
773
  struct ggml_cgraph * graph_copy = ggml_new_graph_custom(sched->ctx, graph->n_nodes + sched->n_splits*GGML_MAX_SPLIT_INPUTS, false);
774
  for (int i = 0; i < sched->n_splits; i++) {
775
  struct ggml_backend_sched_split * split = &sched->splits[i];
776
- split->graph = ggml_graph_view(sched->ctx, graph, split->i_start, split->i_end);
777
 
778
  // add inputs to the graph copy so that they are allocated by ggml-alloc at the start of the split
779
  for (int j = 0; j < split->n_inputs; j++) {
@@ -806,31 +1038,29 @@ static void sched_compute_splits(ggml_backend_sched_t sched) {
806
 
807
  for (int i = 0; i < sched->n_splits; i++) {
808
  struct ggml_backend_sched_split * split = &splits[i];
809
- ggml_backend_t split_backend = ggml_tallocr_get_buffer(split->tallocr)->backend;
810
  int split_backend_id = sched_backend_prio(sched, split_backend);
811
 
812
  // copy the input tensors to the split backend
813
  uint64_t copy_start_us = ggml_time_us();
814
  for (int j = 0; j < split->n_inputs; j++) {
815
- struct ggml_tensor * input_cpy = sched->node_copies[hash_id(split->inputs[j])][sched_backend_prio(sched, split_backend)];
816
- if (split->inputs[j]->buffer == NULL) {
817
- if (split->inputs[j]->view_src == NULL) {
818
- fprintf(stderr, "input %s has no buffer and no view_src\n", split->inputs[j]->name);
 
819
  exit(1);
820
  }
821
- struct ggml_tensor * view = split->inputs[j];
822
- view->backend = view->view_src->backend;
823
- view->buffer = view->view_src->buffer;
824
- view->data = (char *)view->view_src->data + view->view_offs;
825
- ggml_backend_buffer_init_tensor(ggml_backend_sched_get_buffer(sched, view->buffer->backend), view);
826
  }
827
  if (input_cpy->buffer == NULL) {
828
  fprintf(stderr, "input_cpy %s has no buffer\n", input_cpy->name);
829
  exit(1);
830
  }
831
- GGML_ASSERT(split->inputs[j]->buffer->backend != input_cpy->buffer->backend);
832
- GGML_ASSERT(input_cpy->buffer->backend == split_backend);
833
- ggml_backend_tensor_copy(split->inputs[j], input_cpy);
834
  }
835
  // ggml_backend_synchronize(split_backend);
836
  int64_t copy_end_us = ggml_time_us();
@@ -843,7 +1073,7 @@ static void sched_compute_splits(ggml_backend_sched_t sched) {
843
  #endif
844
 
845
  uint64_t compute_start_us = ggml_time_us();
846
- ggml_backend_graph_compute(split_backend, split->graph);
847
  // ggml_backend_synchronize(split_backend);
848
  uint64_t compute_end_us = ggml_time_us();
849
  compute_us[split_backend_id] += compute_end_us - compute_start_us;
@@ -872,8 +1102,6 @@ ggml_backend_sched_t ggml_backend_sched_new(ggml_backend_t * backends, int n_bac
872
  struct ggml_backend_sched * sched = malloc(sizeof(struct ggml_backend_sched));
873
  memset(sched, 0, sizeof(struct ggml_backend_sched));
874
 
875
- fprintf(stderr, "ggml_backend_sched size: %zu KB\n", sizeof(struct ggml_backend_sched)/1024);
876
-
877
  sched->n_backends = n_backends;
878
  for (int i = 0; i < n_backends; i++) {
879
  sched->backends[i] = backends[i];
@@ -948,3 +1176,182 @@ void ggml_backend_sched_set_node_backend(ggml_backend_sched_t sched, struct ggml
948
  GGML_ASSERT(backend_index >= 0 && backend_index < sched->n_backends);
949
  node_allocr(node) = sched->tallocs[backend_index];
950
  }
9
  #include <stdlib.h>
10
  #include <string.h>
11
 
 
12
 
13
  #define MAX(a, b) ((a) > (b) ? (a) : (b))
14
 
15
+
16
+ // backend buffer type
17
+
18
+ ggml_backend_buffer_t ggml_backend_buft_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
19
+ return buft->iface.alloc_buffer(buft, size);
20
+ }
21
+
22
+ size_t ggml_backend_buft_get_alignment(ggml_backend_buffer_type_t buft) {
23
+ return buft->iface.get_alignment(buft);
24
+ }
25
+
26
+ size_t ggml_backend_buft_get_alloc_size(ggml_backend_buffer_type_t buft, struct ggml_tensor * tensor) {
27
+ // get_alloc_size is optional, defaults to ggml_nbytes
28
+ if (buft->iface.get_alloc_size) {
29
+ return buft->iface.get_alloc_size(buft, tensor);
30
+ }
31
+ return ggml_nbytes(tensor);
32
+ }
33
+
34
+ bool ggml_backend_buft_supports_backend(ggml_backend_buffer_type_t buft, ggml_backend_t backend) {
35
+ return buft->iface.supports_backend(buft, backend);
36
+ }
37
+
38
  // backend buffer
39
 
40
  ggml_backend_buffer_t ggml_backend_buffer_init(
41
+ ggml_backend_buffer_type_t buft,
42
  struct ggml_backend_buffer_i iface,
43
  ggml_backend_buffer_context_t context,
44
  size_t size) {
 
48
 
49
  (*buffer) = (struct ggml_backend_buffer) {
50
  /* .interface = */ iface,
51
+ /* .buft = */ buft,
52
  /* .context = */ context,
53
  /* .size = */ size,
54
  };
 
67
  free(buffer);
68
  }
69
 
 
 
 
 
70
  size_t ggml_backend_buffer_get_size(ggml_backend_buffer_t buffer) {
71
  return buffer->size;
72
  }
 
79
  return base;
80
  }
81
 
 
 
 
 
 
 
 
 
82
  void ggml_backend_buffer_init_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor) {
83
  // init_tensor is optional
84
  if (buffer->iface.init_tensor) {
 
86
  }
87
  }
88
 
89
+ size_t ggml_backend_buffer_get_alignment (ggml_backend_buffer_t buffer) {
90
+ return ggml_backend_buft_get_alignment(ggml_backend_buffer_type(buffer));
 
 
 
91
  }
92
 
93
+ size_t ggml_backend_buffer_get_alloc_size(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor) {
94
+ return ggml_backend_buft_get_alloc_size(ggml_backend_buffer_type(buffer), tensor);
95
+ }
96
 
97
+ ggml_backend_buffer_type_t ggml_backend_buffer_type(ggml_backend_buffer_t buffer) {
98
+ return buffer->buft;
99
  }
100
 
101
+ // backend
102
+
103
  const char * ggml_backend_name(ggml_backend_t backend) {
104
  if (backend == NULL) {
105
  return "NULL";
 
115
  backend->iface.free(backend);
116
  }
117
 
118
+ ggml_backend_buffer_type_t ggml_backend_get_default_buffer_type(ggml_backend_t backend) {
119
+ return backend->iface.get_default_buffer_type(backend);
120
+ }
121
+
122
  ggml_backend_buffer_t ggml_backend_alloc_buffer(ggml_backend_t backend, size_t size) {
123
+ return ggml_backend_buft_alloc_buffer(ggml_backend_get_default_buffer_type(backend), size);
124
  }
125
 
126
  size_t ggml_backend_get_alignment(ggml_backend_t backend) {
127
+ return ggml_backend_buft_get_alignment(ggml_backend_get_default_buffer_type(backend));
128
  }
129
 
130
+ void ggml_backend_tensor_set_async(ggml_backend_t backend, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
131
+ GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
132
+ GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor write out of bounds");
133
+
134
+ backend->iface.set_tensor_async(backend, tensor, data, offset, size);
135
  }
136
 
137
+ void ggml_backend_tensor_get_async(ggml_backend_t backend, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) {
138
+ GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
139
+ GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor read out of bounds");
140
+
141
+ backend->iface.get_tensor_async(backend, tensor, data, offset, size);
142
  }
143
 
144
  void ggml_backend_tensor_set(struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
 
 
145
  GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
146
+ GGML_ASSERT(tensor->buffer != NULL && "tensor buffer not set");
147
+ GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor write out of bounds");
148
 
149
+ tensor->buffer->iface.set_tensor(tensor->buffer, tensor, data, offset, size);
 
150
  }
151
 
152
  void ggml_backend_tensor_get(const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) {
 
 
153
  GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
154
+ GGML_ASSERT(tensor->buffer != NULL && "tensor buffer not set");
155
+ GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor read out of bounds");
156
 
157
+ tensor->buffer->iface.get_tensor(tensor->buffer, tensor, data, offset, size);
 
158
  }
159
 
160
  void ggml_backend_synchronize(ggml_backend_t backend) {
161
+ if (backend->iface.synchronize == NULL) {
162
+ return;
163
+ }
164
+
165
  backend->iface.synchronize(backend);
166
  }
167
 
 
175
 
176
  void ggml_backend_graph_plan_compute(ggml_backend_t backend, ggml_backend_graph_plan_t plan) {
177
  backend->iface.graph_plan_compute(backend, plan);
178
+
179
+ // TODO: optional sync
180
+ ggml_backend_synchronize(backend);
181
  }
182
 
183
  void ggml_backend_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
184
  backend->iface.graph_compute(backend, cgraph);
185
+
186
+ // TODO: optional sync
187
+ ggml_backend_synchronize(backend);
188
  }
189
 
190
  bool ggml_backend_supports_op(ggml_backend_t backend, const struct ggml_tensor * op) {
 
221
 
222
  // TODO: allow backends to support copy to/from same backend
223
 
224
+ if (dst->buffer->iface.cpy_tensor_from != NULL) {
225
+ dst->buffer->iface.cpy_tensor_from(dst->buffer, src, dst);
226
+ } else if (src->buffer->iface.cpy_tensor_to != NULL) {
227
+ src->buffer->iface.cpy_tensor_to(src->buffer, src, dst);
228
  } else {
229
  // shouldn't be hit when copying from/to CPU
230
  #ifndef NDEBUG
231
+ fprintf(stderr, "ggml_backend_tensor_copy: neither cpy_tensor_from nor cpy_tensor_to "
232
+ "are implemented for %s and %s, falling back to get/set\n", src->name, dst->name);
233
  #endif
234
  size_t nbytes = ggml_nbytes(src);
235
  void * data = malloc(nbytes);
 
239
  }
240
  }
241
 
242
+ // backend registry
243
 
244
+ #define GGML_MAX_BACKENDS_REG 16
245
+
246
+ struct ggml_backend_reg {
247
+ char name[128];
248
+ ggml_backend_init_fn init_fn;
249
+ ggml_backend_buffer_type_t default_buffer_type;
250
+ void * user_data;
251
  };
252
 
253
+ static struct ggml_backend_reg ggml_backend_registry[GGML_MAX_BACKENDS_REG];
254
+ static size_t ggml_backend_registry_count = 0;
255
+
256
+ static ggml_backend_t ggml_backend_reg_cpu_init(const char * params, void * user_data);
257
+
258
+ static void ggml_backend_registry_init(void) {
259
+ static bool initialized = false;
260
+
261
+ if (initialized) {
262
+ return;
263
+ }
264
+
265
+ initialized = true;
266
 
267
+ ggml_backend_register("CPU", ggml_backend_reg_cpu_init, ggml_backend_cpu_buffer_type(), NULL);
268
+
269
+ // add forward decls here to avoid including the backend headers
270
+ #ifdef GGML_USE_CUBLAS
271
+ extern void ggml_backend_cuda_reg_devices(void);
272
+ ggml_backend_cuda_reg_devices();
273
+ #endif
274
+
275
+ #ifdef GGML_USE_METAL
276
+ extern ggml_backend_t ggml_backend_reg_metal_init(const char * params, void * user_data);
277
+ extern ggml_backend_buffer_type_t ggml_backend_metal_buffer_type(void);
278
+ ggml_backend_register("Metal", ggml_backend_reg_metal_init, ggml_backend_metal_buffer_type(), NULL);
279
+ #endif
280
  }
281
 
282
+ void ggml_backend_register(const char * name, ggml_backend_init_fn init_fn, ggml_backend_buffer_type_t default_buffer_type, void * user_data) {
283
+ GGML_ASSERT(ggml_backend_registry_count < GGML_MAX_BACKENDS_REG);
284
+
285
+ int id = ggml_backend_registry_count;
286
+
287
+ ggml_backend_registry[id] = (struct ggml_backend_reg) {
288
+ /* .name = */ {0},
289
+ /* .fn = */ init_fn,
290
+ /* .default_buffer_type = */ default_buffer_type,
291
+ /* .user_data = */ user_data,
292
+ };
293
+
294
+ snprintf(ggml_backend_registry[id].name, sizeof(ggml_backend_registry[id].name), "%s", name);
295
+
296
+ #ifndef NDEBUG
297
+ fprintf(stderr, "%s: registered backend %s\n", __func__, name);
298
+ #endif
299
+
300
+ ggml_backend_registry_count++;
301
+ }
302
+
303
+ size_t ggml_backend_reg_get_count(void) {
304
+ ggml_backend_registry_init();
305
+
306
+ return ggml_backend_registry_count;
307
+ }
308
+
309
+ size_t ggml_backend_reg_find_by_name(const char * name) {
310
+ ggml_backend_registry_init();
311
+
312
+ for (size_t i = 0; i < ggml_backend_registry_count; i++) {
313
+ // TODO: case insensitive in a portable way
314
+ if (strcmp(ggml_backend_registry[i].name, name) == 0) {
315
+ return i;
316
+ }
317
+ }
318
+ return SIZE_MAX;
319
+ }
320
+
321
+ // init from backend:params string
322
+ ggml_backend_t ggml_backend_reg_init_backend_from_str(const char * backend_str) {
323
+ ggml_backend_registry_init();
324
+
325
+ const char * params = strchr(backend_str, ':');
326
+ char backend_name[128];
327
+ if (params == NULL) {
328
+ strcpy(backend_name, backend_str);
329
+ params = "";
330
+ } else {
331
+ strncpy(backend_name, backend_str, params - backend_str);
332
+ backend_name[params - backend_str] = '\0';
333
+ params++;
334
+ }
335
+
336
+ size_t backend_i = ggml_backend_reg_find_by_name(backend_name);
337
+ if (backend_i == SIZE_MAX) {
338
+ fprintf(stderr, "%s: backend %s not found\n", __func__, backend_name);
339
+ return NULL;
340
+ }
341
+
342
+ return ggml_backend_reg_init_backend(backend_i, params);
343
+ }
344
+
345
+ const char * ggml_backend_reg_get_name(size_t i) {
346
+ ggml_backend_registry_init();
347
+
348
+ GGML_ASSERT(i < ggml_backend_registry_count);
349
+ return ggml_backend_registry[i].name;
350
+ }
351
+
352
+ ggml_backend_t ggml_backend_reg_init_backend(size_t i, const char * params) {
353
+ ggml_backend_registry_init();
354
+
355
+ GGML_ASSERT(i < ggml_backend_registry_count);
356
+ return ggml_backend_registry[i].init_fn(params, ggml_backend_registry[i].user_data);
357
+ }
358
+
359
+ ggml_backend_buffer_type_t ggml_backend_reg_get_default_buffer_type(size_t i) {
360
+ ggml_backend_registry_init();
361
+
362
+ GGML_ASSERT(i < ggml_backend_registry_count);
363
+ return ggml_backend_registry[i].default_buffer_type;
364
+ }
365
+
366
+ ggml_backend_buffer_t ggml_backend_reg_alloc_buffer(size_t i, size_t size) {
367
+ ggml_backend_registry_init();
368
+
369
+ GGML_ASSERT(i < ggml_backend_registry_count);
370
+ return ggml_backend_buft_alloc_buffer(ggml_backend_registry[i].default_buffer_type, size);
371
  }
372
 
373
+ // backend CPU
374
+
375
  static void * ggml_backend_cpu_buffer_get_base(ggml_backend_buffer_t buffer) {
376
  return (void *)buffer->context;
377
  }
378
 
379
  static void ggml_backend_cpu_buffer_free_buffer(ggml_backend_buffer_t buffer) {
380
  free(buffer->context);
381
+ GGML_UNUSED(buffer);
382
+ }
383
+
384
+ static void ggml_backend_cpu_buffer_set_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
385
+ GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor write out of bounds");
386
+ GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
387
+
388
+ memcpy((char *)tensor->data + offset, data, size);
389
+
390
+ GGML_UNUSED(buffer);
391
+ }
392
+
393
+ static void ggml_backend_cpu_buffer_get_tensor(ggml_backend_buffer_t buffer, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) {
394
+ GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor read out of bounds");
395
+ GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
396
+
397
+ memcpy(data, (const char *)tensor->data + offset, size);
398
+
399
+ GGML_UNUSED(buffer);
400
+ }
401
+
402
+ static void ggml_backend_cpu_buffer_cpy_tensor_from(ggml_backend_buffer_t buffer, struct ggml_tensor * src, struct ggml_tensor * dst) {
403
+ ggml_backend_tensor_get(src, dst->data, 0, ggml_nbytes(src));
404
+
405
+ GGML_UNUSED(buffer);
406
+ }
407
+
408
+ static void ggml_backend_cpu_buffer_cpy_tensor_to(ggml_backend_buffer_t buffer, struct ggml_tensor * src, struct ggml_tensor * dst) {
409
+ ggml_backend_tensor_set(dst, src->data, 0, ggml_nbytes(src));
410
+
411
+ GGML_UNUSED(buffer);
412
  }
413
 
414
  static struct ggml_backend_buffer_i cpu_backend_buffer_i = {
415
+ /* .free_buffer = */ ggml_backend_cpu_buffer_free_buffer,
416
+ /* .get_base = */ ggml_backend_cpu_buffer_get_base,
417
+ /* .init_tensor = */ NULL, // no initialization required
418
+ /* .set_tensor = */ ggml_backend_cpu_buffer_set_tensor,
419
+ /* .get_tensor = */ ggml_backend_cpu_buffer_get_tensor,
420
+ /* .cpy_tensor_from = */ ggml_backend_cpu_buffer_cpy_tensor_from,
421
+ /* .cpy_tensor_to = */ ggml_backend_cpu_buffer_cpy_tensor_to,
422
  };
423
 
424
  // for buffers from ptr, free is not called
425
  static struct ggml_backend_buffer_i cpu_backend_buffer_i_from_ptr = {
426
+ /* .free_buffer = */ NULL, // ptr is not owned by the buffer, so it does not need to be freed
427
+ /* .get_base = */ ggml_backend_cpu_buffer_get_base,
428
+ /* .init_tensor = */ NULL, // no initialization required
429
+ /* .set_tensor = */ ggml_backend_cpu_buffer_set_tensor,
430
+ /* .get_tensor = */ ggml_backend_cpu_buffer_get_tensor,
431
+ /* .cpy_tensor_from = */ ggml_backend_cpu_buffer_cpy_tensor_from,
432
+ /* .cpy_tensor_to = */ ggml_backend_cpu_buffer_cpy_tensor_to,
433
  };
434
 
435
  static const size_t TENSOR_ALIGNMENT = 64; // should be enough for AVX 512
436
 
437
+ static ggml_backend_buffer_t ggml_backend_cpu_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
438
  size += TENSOR_ALIGNMENT; // malloc may return an address that is not aligned
439
  void * data = malloc(size); // TODO: maybe use GGML_ALIGNED_MALLOC?
440
 
441
  GGML_ASSERT(data != NULL && "failed to allocate buffer");
442
 
443
+ return ggml_backend_buffer_init(buft, cpu_backend_buffer_i, data, size);
444
  }
445
 
446
+ static size_t ggml_backend_cpu_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
447
  return TENSOR_ALIGNMENT;
 
 
448
 
449
+ GGML_UNUSED(buft);
450
+ }
 
451
 
452
+ static bool ggml_backend_cpu_buffer_type_supports_backend(ggml_backend_buffer_type_t buft, ggml_backend_t backend) {
453
+ return ggml_backend_is_cpu(backend);
454
 
455
+ GGML_UNUSED(buft);
456
  }
457
 
458
+ ggml_backend_buffer_type_t ggml_backend_cpu_buffer_type(void) {
459
+ static struct ggml_backend_buffer_type ggml_backend_buffer_type_cpu = {
460
+ /* .iface = */ {
461
+ /* .alloc_buffer = */ ggml_backend_cpu_buffer_type_alloc_buffer,
462
+ /* .get_alignment = */ ggml_backend_cpu_buffer_type_get_alignment,
463
+ /* .get_alloc_size = */ NULL, // defaults to ggml_nbytes
464
+ /* .supports_backend = */ ggml_backend_cpu_buffer_type_supports_backend,
465
+ },
466
+ /* .context = */ NULL,
467
+ };
468
 
469
+ return &ggml_backend_buffer_type_cpu;
470
  }
471
 
472
+ struct ggml_backend_cpu_context {
473
+ int n_threads;
474
+ void * work_data;
475
+ size_t work_size;
476
+ };
477
 
478
+ static const char * ggml_backend_cpu_name(ggml_backend_t backend) {
479
+ return "CPU";
480
 
481
+ GGML_UNUSED(backend);
482
  }
483
 
484
+ static void ggml_backend_cpu_free(ggml_backend_t backend) {
485
+ struct ggml_backend_cpu_context * cpu_ctx = (struct ggml_backend_cpu_context *)backend->context;
486
+ free(cpu_ctx->work_data);
487
+ free(cpu_ctx);
488
+ free(backend);
489
+ }
490
+
491
+ static ggml_backend_buffer_type_t ggml_backend_cpu_get_default_buffer_type(ggml_backend_t backend) {
492
+ return ggml_backend_cpu_buffer_type();
493
 
494
+ GGML_UNUSED(backend);
495
  }
496
 
497
  struct ggml_backend_plan_cpu {
 
520
  free(cpu_plan->cplan.work_data);
521
  free(cpu_plan);
522
 
523
+ GGML_UNUSED(backend);
524
  }
525
 
526
  static void ggml_backend_cpu_graph_plan_compute(ggml_backend_t backend, ggml_backend_graph_plan_t plan) {
 
528
 
529
  ggml_graph_compute(&cpu_plan->cgraph, &cpu_plan->cplan);
530
 
531
+ GGML_UNUSED(backend);
532
  }
533
 
534
  static void ggml_backend_cpu_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
 
549
 
550
  static bool ggml_backend_cpu_supports_op(ggml_backend_t backend, const struct ggml_tensor * op) {
551
  return true;
552
+
553
+ GGML_UNUSED(backend);
554
+ GGML_UNUSED(op);
555
  }
556
 
557
  static struct ggml_backend_i cpu_backend_i = {
558
+ /* .get_name = */ ggml_backend_cpu_name,
559
+ /* .free = */ ggml_backend_cpu_free,
560
+ /* .get_default_buffer_type = */ ggml_backend_cpu_get_default_buffer_type,
561
+ /* .set_tensor_async = */ NULL,
562
+ /* .get_tensor_async = */ NULL,
563
+ /* .cpy_tensor_from_async = */ NULL,
564
+ /* .cpy_tensor_to_async = */ NULL,
565
+ /* .synchronize = */ NULL,
566
+ /* .graph_plan_create = */ ggml_backend_cpu_graph_plan_create,
567
+ /* .graph_plan_free = */ ggml_backend_cpu_graph_plan_free,
568
+ /* .graph_plan_compute = */ ggml_backend_cpu_graph_plan_compute,
569
+ /* .graph_compute = */ ggml_backend_cpu_graph_compute,
570
+ /* .supports_op = */ ggml_backend_cpu_supports_op,
 
571
  };
572
 
573
  ggml_backend_t ggml_backend_cpu_init(void) {
 
597
  ctx->n_threads = n_threads;
598
  }
599
 
600
+ ggml_backend_buffer_t ggml_backend_cpu_buffer_from_ptr(void * ptr, size_t size) {
601
+ return ggml_backend_buffer_init(ggml_backend_cpu_buffer_type(), cpu_backend_buffer_i_from_ptr, ptr, size);
602
+ }
603
+
604
+ static ggml_backend_t ggml_backend_reg_cpu_init(const char * params, void * user_data) {
605
+ return ggml_backend_cpu_init();
606
+
607
+ GGML_UNUSED(params);
608
+ GGML_UNUSED(user_data);
609
  }
610
 
611
+
612
  // scheduler
613
 
614
  #define GGML_MAX_BACKENDS 4
 
621
  int i_end;
622
  struct ggml_tensor * inputs[GGML_MAX_SPLIT_INPUTS];
623
  int n_inputs;
624
+ struct ggml_cgraph graph;
625
  };
626
 
627
  struct ggml_backend_sched {
 
647
  #else
648
  __attribute__((aligned(GGML_MEM_ALIGN)))
649
  #endif
650
+ char context_buffer[GGML_MAX_SPLITS*GGML_MAX_SPLIT_INPUTS*sizeof(struct ggml_tensor) + sizeof(struct ggml_cgraph)];
651
  };
652
 
653
  #define hash_id(node) ggml_hash_find_or_insert(sched->hash_set, node)
 
676
  return INT_MAX;
677
  }
678
 
679
+ static ggml_backend_t get_buffer_backend(ggml_backend_sched_t sched, ggml_backend_buffer_t buffer) {
680
+ if (buffer == NULL) {
681
+ return NULL;
682
+ }
683
+ // find highest prio backend that supports the buffer type
684
+ for (int i = 0; i < sched->n_backends; i++) {
685
+ if (ggml_backend_buft_supports_backend(buffer->buft, sched->backends[i])) {
686
+ return sched->backends[i];
687
+ }
688
+ }
689
+ GGML_ASSERT(false && "tensor buffer type not supported by any backend");
690
+ }
691
+
692
+ static ggml_backend_t get_allocr_backend(ggml_backend_sched_t sched, ggml_tallocr_t allocr) {
693
+ if (allocr == NULL) {
694
+ return NULL;
695
+ }
696
+ // find highest prio backend that supports the buffer type
697
+ for (int i = 0; i < sched->n_backends; i++) {
698
+ if (sched->tallocs[i] == allocr) {
699
+ return sched->backends[i];
700
+ }
701
+ }
702
+ GGML_UNREACHABLE();
703
+ }
704
+
705
+ #if 0
706
+ static char causes[GGML_DEFAULT_GRAPH_SIZE*8 + GGML_MAX_SPLITS*GGML_MAX_SPLIT_INPUTS][128]; // debug, remove
707
+ #define SET_CAUSE(node, ...) sprintf(causes[hash_id(node)], __VA_ARGS__)
708
+ #define GET_CAUSE(node) causes[hash_id(node)]
709
+ #else
710
+ #define SET_CAUSE(node, ...)
711
+ #define GET_CAUSE(node) ""
712
+ #endif
713
+
714
  // returns the backend that should be used for the node based on the current locations
 
715
  static ggml_backend_t sched_backend_from_cur(ggml_backend_sched_t sched, struct ggml_tensor * node) {
716
  // if the dst tensor is already allocated in a buffer, we must assume that it is critical to keep it there
717
  // ie. kv cache updates
718
  // note that this doesn't allow fallback to CPU. need to add output tensors to the splits to copy the data back to the original backend.
719
  // dst
720
+ ggml_backend_t cur_backend = get_buffer_backend(sched, node->buffer);
721
  if (cur_backend != NULL) {
722
+ SET_CAUSE(node, "1.dst");
723
  return cur_backend;
724
  }
725
 
726
  // view_src
727
+ if (node->view_src != NULL && get_buffer_backend(sched, node->view_src->buffer) != NULL) {
728
+ SET_CAUSE(node, "1.vsrc");
729
+ return get_buffer_backend(sched, node->view_src->buffer);
730
  }
731
 
732
  // src
 
738
  if (src == NULL) {
739
  break;
740
  }
741
+ ggml_backend_t src_backend = get_buffer_backend(sched, src->buffer);
742
  if (src_backend != NULL) {
743
  int src_prio = sched_backend_prio(sched, src_backend);
744
  size_t src_size = ggml_nbytes(src);
 
746
  cur_prio = src_prio;
747
  cur_size = src_size;
748
  cur_backend = src_backend;
749
+ SET_CAUSE(node, "1.src%d", i);
750
  }
751
  }
752
  }
 
767
  int cur_split = 0;
768
  for (int i = 0; i < graph->n_nodes; i++) {
769
  if (cur_split < sched->n_splits && i == sched->splits[cur_split].i_start) {
770
+ ggml_backend_t split_backend = get_allocr_backend(sched, sched->splits[cur_split].tallocr);
771
+ fprintf(stderr, "\n## SPLIT #%d: %s # %d inputs: ", cur_split, ggml_backend_name(split_backend),
772
+ sched->splits[cur_split].n_inputs);
773
  for (int j = 0; j < sched->splits[cur_split].n_inputs; j++) {
774
+ fprintf(stderr, "[%s (%5.5s)] ", sched->splits[cur_split].inputs[j]->name,
775
+ fmt_size(ggml_nbytes(sched->splits[cur_split].inputs[j])));
776
  }
777
  fprintf(stderr, "\n");
778
  cur_split++;
 
782
  continue;
783
  }
784
  ggml_tallocr_t node_allocr = node_allocr(node);
785
+ ggml_backend_t node_backend = node_allocr ? get_allocr_backend(sched, node_allocr) : NULL; // FIXME:
786
+ fprintf(stderr, "node #%3d (%10.10s): %20.20s (%4.4s) [%4.4s %8.8s]:", i, ggml_op_name(node->op), node->name,
787
+ fmt_size(ggml_nbytes(node)), node_allocr ? ggml_backend_name(node_backend) : "NULL", GET_CAUSE(node));
788
  for (int j = 0; j < GGML_MAX_SRC; j++) {
789
  struct ggml_tensor * src = node->src[j];
790
  if (src == NULL) {
791
  break;
792
  }
793
  ggml_tallocr_t src_allocr = node_allocr(src);
794
+ ggml_backend_t src_backend = src_allocr ? get_allocr_backend(sched, src_allocr) : NULL;
795
+ fprintf(stderr, " %20.20s (%4.4s) [%4.4s %8.8s]", src->name,
796
+ fmt_size(ggml_nbytes(src)), src_backend ? ggml_backend_name(src_backend) : "NULL", GET_CAUSE(src));
797
  }
798
  fprintf(stderr, "\n");
799
  }
 
819
  sched->n_splits = 0;
820
 
821
  struct ggml_init_params params = {
822
+ /* .mem_size = */ sizeof(sched->context_buffer),
823
+ /* .mem_buffer = */ sched->context_buffer,
824
+ /* .no_alloc = */ true
825
  };
826
 
827
  if (sched->ctx != NULL) {
 
837
  // do not overwrite user assignments
838
  continue;
839
  }
840
+ ggml_backend_t leaf_backend = get_buffer_backend(sched, leaf->buffer);
841
  if (leaf_backend == NULL && leaf->view_src != NULL) {
842
+ leaf_backend = get_buffer_backend(sched, leaf->view_src->buffer);
843
  }
844
  if (leaf_backend != NULL) {
845
  node_allocr(leaf) = ggml_backend_sched_get_tallocr(sched, leaf_backend);
 
881
  cur_prio = src_prio;
882
  cur_size = src_size;
883
  node_allocr = src_allocr;
884
+ SET_CAUSE(node, "2.src%d", j);
885
  }
886
  }
887
  }
 
965
  struct ggml_tensor * tensor_copy = ggml_dup_tensor_layout(sched->ctx, src);
966
  sched->node_copies[id][cur_backend_id] = tensor_copy;
967
  node_allocr(tensor_copy) = cur_allocr;
968
+ ggml_backend_t backend = get_allocr_backend(sched, cur_allocr);
969
  ggml_format_name(tensor_copy, "%s#%s", ggml_backend_name(backend), src->name);
970
  }
971
  node->src[j] = sched->node_copies[id][cur_backend_id];
 
993
  ggml_tallocr_t src_allocr = node_allocr(src);
994
  if (src_allocr != node_allocr /* && src_backend != NULL */) { // ignore nulls for now
995
  fprintf(stderr, "!!!! %s has backend %s, src %d (%s) has backend %s\n",
996
+ node->name, node_allocr ? ggml_backend_name(get_allocr_backend(sched, node_allocr)) : "NULL",
997
+ j, src->name, src_allocr ? ggml_backend_name(get_allocr_backend(sched, src_allocr)) : "NULL");
998
  }
999
  }
1000
  }
 
1005
  struct ggml_cgraph * graph_copy = ggml_new_graph_custom(sched->ctx, graph->n_nodes + sched->n_splits*GGML_MAX_SPLIT_INPUTS, false);
1006
  for (int i = 0; i < sched->n_splits; i++) {
1007
  struct ggml_backend_sched_split * split = &sched->splits[i];
1008
+ split->graph = ggml_graph_view(graph, split->i_start, split->i_end);
1009
 
1010
  // add inputs to the graph copy so that they are allocated by ggml-alloc at the start of the split
1011
  for (int j = 0; j < split->n_inputs; j++) {
 
1038
 
1039
  for (int i = 0; i < sched->n_splits; i++) {
1040
  struct ggml_backend_sched_split * split = &splits[i];
1041
+ ggml_backend_t split_backend = get_allocr_backend(sched, split->tallocr);
1042
  int split_backend_id = sched_backend_prio(sched, split_backend);
1043
 
1044
  // copy the input tensors to the split backend
1045
  uint64_t copy_start_us = ggml_time_us();
1046
  for (int j = 0; j < split->n_inputs; j++) {
1047
+ struct ggml_tensor * input = split->inputs[j];
1048
+ struct ggml_tensor * input_cpy = sched->node_copies[hash_id(input)][sched_backend_prio(sched, split_backend)];
1049
+ if (input->buffer == NULL) {
1050
+ if (input->view_src == NULL) {
1051
+ fprintf(stderr, "input %s has no buffer and no view_src\n", input->name);
1052
  exit(1);
1053
  }
1054
+ // FIXME: may need to use the sched buffer instead
1055
+ ggml_backend_view_init(input->view_src->buffer, input);
 
 
 
1056
  }
1057
  if (input_cpy->buffer == NULL) {
1058
  fprintf(stderr, "input_cpy %s has no buffer\n", input_cpy->name);
1059
  exit(1);
1060
  }
1061
+ //GGML_ASSERT(input->buffer->backend != input_cpy->buffer->backend);
1062
+ //GGML_ASSERT(input_cpy->buffer->backend == split_backend);
1063
+ ggml_backend_tensor_copy(input, input_cpy);
1064
  }
1065
  // ggml_backend_synchronize(split_backend);
1066
  int64_t copy_end_us = ggml_time_us();
 
1073
  #endif
1074
 
1075
  uint64_t compute_start_us = ggml_time_us();
1076
+ ggml_backend_graph_compute(split_backend, &split->graph);
1077
  // ggml_backend_synchronize(split_backend);
1078
  uint64_t compute_end_us = ggml_time_us();
1079
  compute_us[split_backend_id] += compute_end_us - compute_start_us;
 
1102
  struct ggml_backend_sched * sched = malloc(sizeof(struct ggml_backend_sched));
1103
  memset(sched, 0, sizeof(struct ggml_backend_sched));
1104
 
 
 
1105
  sched->n_backends = n_backends;
1106
  for (int i = 0; i < n_backends; i++) {
1107
  sched->backends[i] = backends[i];
 
1176
  GGML_ASSERT(backend_index >= 0 && backend_index < sched->n_backends);
1177
  node_allocr(node) = sched->tallocs[backend_index];
1178
  }
1179
+
1180
+ // utils
1181
+ void ggml_backend_view_init(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor) {
1182
+ GGML_ASSERT(tensor->buffer == NULL);
1183
+ GGML_ASSERT(tensor->data == NULL);
1184
+ GGML_ASSERT(tensor->view_src != NULL);
1185
+ GGML_ASSERT(tensor->view_src->buffer != NULL);
1186
+ GGML_ASSERT(tensor->view_src->data != NULL);
1187
+
1188
+ tensor->buffer = buffer;
1189
+ tensor->data = (char *)tensor->view_src->data + tensor->view_offs;
1190
+ tensor->backend = tensor->view_src->backend;
1191
+ ggml_backend_buffer_init_tensor(buffer, tensor);
1192
+ }
1193
+
1194
+ void ggml_backend_tensor_alloc(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, void * addr) {
1195
+ GGML_ASSERT(tensor->buffer == NULL);
1196
+ GGML_ASSERT(tensor->data == NULL);
1197
+ GGML_ASSERT(tensor->view_src == NULL);
1198
+ GGML_ASSERT(addr >= ggml_backend_buffer_get_base(buffer));
1199
+ GGML_ASSERT((char *)addr + ggml_backend_buffer_get_alloc_size(buffer, tensor) <=
1200
+ (char *)ggml_backend_buffer_get_base(buffer) + ggml_backend_buffer_get_size(buffer));
1201
+
1202
+ tensor->buffer = buffer;
1203
+ tensor->data = addr;
1204
+ ggml_backend_buffer_init_tensor(buffer, tensor);
1205
+ }
1206
+
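A hedged sketch of how the two initialization helpers above could be driven by a custom allocator; the offset handling is an assumption for illustration:

// sketch: place a non-view tensor at a chosen address inside an existing buffer,
// or initialize a view from its view_src, using the helpers defined above
static void place_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * t, size_t offset) {
    if (t->view_src != NULL) {
        ggml_backend_view_init(t->view_src->buffer, t); // views borrow the data of view_src
    } else {
        char * base = (char *) ggml_backend_buffer_get_base(buffer);
        // the caller is responsible for keeping offset aligned and inside the buffer
        ggml_backend_tensor_alloc(buffer, t, base + offset);
    }
}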
1207
+ static struct ggml_tensor * graph_dup_tensor(struct ggml_hash_set hash_set, struct ggml_tensor ** node_copies,
1208
+ struct ggml_context * ctx_allocated, struct ggml_context * ctx_unallocated, struct ggml_tensor * src) {
1209
+
1210
+ GGML_ASSERT(src != NULL);
1211
+ GGML_ASSERT(src->data && "graph must be allocated");
1212
+
1213
+ size_t id = ggml_hash_insert(hash_set, src);
1214
+ if (id == GGML_HASHTABLE_ALREADY_EXISTS) {
1215
+ return node_copies[ggml_hash_find(hash_set, src)];
1216
+ }
1217
+
1218
+ struct ggml_tensor * dst = ggml_dup_tensor_layout(src->data && !src->view_src ? ctx_allocated : ctx_unallocated, src);
1219
+ if (src->view_src != NULL) {
1220
+ dst->view_src = graph_dup_tensor(hash_set, node_copies, ctx_allocated, ctx_unallocated, src->view_src);
1221
+ dst->view_offs = src->view_offs;
1222
+ }
1223
+ dst->op = src->op;
1224
+ memcpy(dst->op_params, src->op_params, sizeof(dst->op_params));
1225
+ ggml_set_name(dst, src->name);
1226
+
1227
+ // copy src
1228
+ for (int i = 0; i < GGML_MAX_SRC; i++) {
1229
+ struct ggml_tensor * s = src->src[i];
1230
+ if (s == NULL) {
1231
+ break;
1232
+ }
1233
+ dst->src[i] = graph_dup_tensor(hash_set, node_copies, ctx_allocated, ctx_unallocated, s);
1234
+ }
1235
+
1236
+ node_copies[id] = dst;
1237
+ return dst;
1238
+ }
1239
+
1240
+ static void graph_init_tensor(struct ggml_hash_set hash_set, struct ggml_tensor ** node_copies, bool * node_init, struct ggml_tensor * src) {
1241
+ size_t id = ggml_hash_find(hash_set, src);
1242
+ if (node_init[id]) {
1243
+ return;
1244
+ }
1245
+ node_init[id] = true;
1246
+
1247
+ struct ggml_tensor * dst = node_copies[id];
1248
+ if (dst->view_src != NULL) {
1249
+ ggml_backend_view_init(dst->view_src->buffer, dst);
1250
+ }
1251
+ else {
1252
+ ggml_backend_tensor_copy(src, dst);
1253
+ }
1254
+
1255
+ // init src
1256
+ for (int i = 0; i < GGML_MAX_SRC; i++) {
1257
+ struct ggml_tensor * s = src->src[i];
1258
+ if (s == NULL) {
1259
+ break;
1260
+ }
1261
+ graph_init_tensor(hash_set, node_copies, node_init, s);
1262
+ }
1263
+ }
1264
+
1265
+ struct ggml_backend_graph_copy ggml_backend_graph_copy(ggml_backend_t backend, struct ggml_cgraph * graph) {
1266
+ struct ggml_hash_set hash_set = {
1267
+ /* .size = */ graph->visited_hash_table.size,
1268
+ /* .keys = */ calloc(sizeof(hash_set.keys[0]) * graph->visited_hash_table.size, 1)
1269
+ };
1270
+ struct ggml_tensor ** node_copies = calloc(sizeof(node_copies[0]) * hash_set.size, 1);
1271
+ bool * node_init = calloc(sizeof(node_init[0]) * hash_set.size, 1);
1272
+
1273
+ struct ggml_init_params params = {
1274
+ /* .mem_size = */ ggml_tensor_overhead()*hash_set.size + ggml_graph_overhead_custom(graph->size, false),
1275
+ /* .mem_buffer = */ NULL,
1276
+ /* .no_alloc = */ true
1277
+ };
1278
+
1279
+ struct ggml_context * ctx_allocated = ggml_init(params);
1280
+ struct ggml_context * ctx_unallocated = ggml_init(params);
1281
+
1282
+ // dup nodes
1283
+ for (int i = 0; i < graph->n_nodes; i++) {
1284
+ struct ggml_tensor * node = graph->nodes[i];
1285
+ graph_dup_tensor(hash_set, node_copies, ctx_allocated, ctx_unallocated, node);
1286
+ }
1287
+
1288
+ // allocate nodes
1289
+ ggml_backend_buffer_t buffer = ggml_backend_alloc_ctx_tensors(ctx_allocated, backend);
1290
+
1291
+ //printf("copy buffer size: %zu MB\n", ggml_backend_buffer_get_size(buffer) / 1024 / 1024);
1292
+
1293
+ // copy data and init views
1294
+ for (int i = 0; i < graph->n_nodes; i++) {
1295
+ struct ggml_tensor * node = graph->nodes[i];
1296
+ graph_init_tensor(hash_set, node_copies, node_init, node);
1297
+ }
1298
+
1299
+ // build graph copy
1300
+ struct ggml_cgraph * graph_copy = ggml_new_graph_custom(ctx_allocated, graph->size, false);
1301
+ for (int i = 0; i < graph->n_nodes; i++) {
1302
+ struct ggml_tensor * node = graph->nodes[i];
1303
+ struct ggml_tensor * node_copy = node_copies[ggml_hash_find(hash_set, node)];
1304
+ graph_copy->nodes[i] = node_copy;
1305
+ }
1306
+ graph_copy->n_nodes = graph->n_nodes;
1307
+
1308
+ free(hash_set.keys);
1309
+ free(node_copies);
1310
+ free(node_init);
1311
+
1312
+ return (struct ggml_backend_graph_copy) {
1313
+ /* .buffer = */ buffer,
1314
+ /* .ctx_allocated = */ ctx_allocated,
1315
+ /* .ctx_unallocated = */ ctx_unallocated,
1316
+ /* .graph = */ graph_copy,
1317
+ };
1318
+ }
1319
+
1320
+ void ggml_backend_graph_copy_free(struct ggml_backend_graph_copy copy) {
1321
+ ggml_backend_buffer_free(copy.buffer);
1322
+ ggml_free(copy.ctx_allocated);
1323
+ ggml_free(copy.ctx_unallocated);
1324
+ }
1325
+
1326
+ void ggml_backend_compare_graph_backend(ggml_backend_t backend1, ggml_backend_t backend2, struct ggml_cgraph * graph, ggml_backend_eval_callback callback, void * user_data) {
1327
+ struct ggml_backend_graph_copy copy = ggml_backend_graph_copy(backend2, graph);
1328
+ struct ggml_cgraph * g1 = graph;
1329
+ struct ggml_cgraph * g2 = copy.graph;
1330
+
1331
+ assert(g1->n_nodes == g2->n_nodes);
1332
+
1333
+ for (int i = 0; i < g1->n_nodes; i++) {
1334
+ //printf("eval %d/%d\n", i, g1->n_nodes);
1335
+ struct ggml_tensor * t1 = g1->nodes[i];
1336
+ struct ggml_tensor * t2 = g2->nodes[i];
1337
+
1338
+ assert(t1->op == t2->op && ggml_are_same_layout(t1, t2));
1339
+
1340
+ struct ggml_cgraph g1v = ggml_graph_view(g1, i, i + 1);
1341
+ struct ggml_cgraph g2v = ggml_graph_view(g2, i, i + 1);
1342
+
1343
+ ggml_backend_graph_compute(backend1, &g1v);
1344
+ ggml_backend_graph_compute(backend2, &g2v);
1345
+
1346
+ if (ggml_is_view_op(t1->op)) {
1347
+ continue;
1348
+ }
1349
+
1350
+ // compare results, calculate rms etc
1351
+ if (!callback(i, t1, t2, user_data)) {
1352
+ break;
1353
+ }
1354
+ }
1355
+
1356
+ ggml_backend_graph_copy_free(copy);
1357
+ }
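For reference, a minimal sketch of driving the comparison helper above with a user callback; the callback body and the stderr logging are assumptions, not part of the patch:

// sketch: log each compared node and continue; a real callback would read both tensors
// back with ggml_backend_tensor_get() and compute an error metric (assumes <stdio.h>)
static bool eval_callback(int node_index, struct ggml_tensor * t1, struct ggml_tensor * t2, void * user_data) {
    (void) user_data;
    fprintf(stderr, "node %3d: %s vs %s\n", node_index, t1->name, t2->name);
    return true; // returning false stops the comparison early
}
// usage (backends and graph assumed to exist):
//   ggml_backend_compare_graph_backend(backend_cpu, backend_gpu, graph, eval_callback, NULL);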
ggml-backend.h CHANGED
@@ -7,41 +7,44 @@
7
  extern "C" {
8
  #endif
9
 
 
 
 
 
 
10
  //
11
  // Backend buffer
12
  //
13
 
14
- struct ggml_backend_buffer;
15
- typedef struct ggml_backend_buffer * ggml_backend_buffer_t;
 
 
 
16
 
17
- // backend buffer functions
18
  GGML_API void ggml_backend_buffer_free (ggml_backend_buffer_t buffer);
19
- GGML_API size_t ggml_backend_buffer_get_alignment (ggml_backend_buffer_t buffer);
20
  GGML_API void * ggml_backend_buffer_get_base (ggml_backend_buffer_t buffer);
21
  GGML_API size_t ggml_backend_buffer_get_size (ggml_backend_buffer_t buffer);
22
- GGML_API size_t ggml_backend_buffer_get_alloc_size(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor);
23
  GGML_API void ggml_backend_buffer_init_tensor (ggml_backend_buffer_t buffer, struct ggml_tensor * tensor);
24
- GGML_API void ggml_backend_buffer_free_tensor (ggml_backend_buffer_t buffer, struct ggml_tensor * tensor);
 
 
25
 
26
  //
27
  // Backend
28
  //
29
 
30
- struct ggml_backend;
31
- typedef struct ggml_backend * ggml_backend_t;
32
- typedef void * ggml_backend_graph_plan_t;
33
-
34
- GGML_API ggml_backend_t ggml_get_backend(const struct ggml_tensor * tensor);
35
 
36
  GGML_API const char * ggml_backend_name(ggml_backend_t backend);
37
  GGML_API void ggml_backend_free(ggml_backend_t backend);
38
 
39
- GGML_API ggml_backend_buffer_t ggml_backend_alloc_buffer(ggml_backend_t backend, size_t size);
40
-
41
- GGML_API size_t ggml_backend_get_alignment(ggml_backend_t backend);
42
 
43
- GGML_API void ggml_backend_tensor_set_async( struct ggml_tensor * tensor, const void * data, size_t offset, size_t size);
44
- GGML_API void ggml_backend_tensor_get_async(const struct ggml_tensor * tensor, void * data, size_t offset, size_t size);
45
 
46
  GGML_API void ggml_backend_tensor_set( struct ggml_tensor * tensor, const void * data, size_t offset, size_t size);
47
  GGML_API void ggml_backend_tensor_get(const struct ggml_tensor * tensor, void * data, size_t offset, size_t size);
@@ -57,6 +60,7 @@ extern "C" {
57
 
58
  // tensor copy between different backends
59
  GGML_API void ggml_backend_tensor_copy(struct ggml_tensor * src, struct ggml_tensor * dst);
 
60
 
61
  //
62
  // CPU backend
@@ -68,8 +72,23 @@ extern "C" {
68
  GGML_API void ggml_backend_cpu_set_n_threads(ggml_backend_t backend_cpu, int n_threads);
69
 
70
  // Create a backend buffer from an existing pointer
71
- GGML_API ggml_backend_buffer_t ggml_backend_cpu_buffer_from_ptr(ggml_backend_t backend_cpu, void * ptr, size_t size);
 
 
72
 
 
 
 
 
 
 
 
 
 
 
 
 
 
73
 
74
  //
75
  // Backend scheduler
@@ -131,6 +150,32 @@ extern "C" {
131
  ggml_backend_sched_t sched,
132
  struct ggml_cgraph * graph);
133
 
 
 
 
 
 
 
134
  #ifdef __cplusplus
135
  }
136
  #endif
 
7
  extern "C" {
8
  #endif
9
 
10
+ typedef struct ggml_backend_buffer_type * ggml_backend_buffer_type_t;
11
+ typedef struct ggml_backend_buffer * ggml_backend_buffer_t;
12
+ typedef struct ggml_backend * ggml_backend_t;
13
+ typedef void * ggml_backend_graph_plan_t;
14
+
15
  //
16
  // Backend buffer
17
  //
18
 
19
+ // buffer type
20
+ GGML_API ggml_backend_buffer_t ggml_backend_buft_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size);
21
+ GGML_API size_t ggml_backend_buft_get_alignment (ggml_backend_buffer_type_t buft);
22
+ GGML_API size_t ggml_backend_buft_get_alloc_size(ggml_backend_buffer_type_t buft, struct ggml_tensor * tensor);
23
+ GGML_API bool ggml_backend_buft_supports_backend(ggml_backend_buffer_type_t buft, ggml_backend_t backend);
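A small sketch of the buffer-type API declared above; the check-then-allocate pattern and the NULL return are assumptions about intended use:

// sketch: allocate from a buffer type only if the target backend can use it
static ggml_backend_buffer_t try_alloc(ggml_backend_buffer_type_t buft, ggml_backend_t backend, size_t size) {
    if (!ggml_backend_buft_supports_backend(buft, backend)) {
        return NULL;
    }
    // callers may round size up using ggml_backend_buft_get_alignment(buft) if needed
    return ggml_backend_buft_alloc_buffer(buft, size);
}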
24
 
25
+ // buffer
26
  GGML_API void ggml_backend_buffer_free (ggml_backend_buffer_t buffer);
 
27
  GGML_API void * ggml_backend_buffer_get_base (ggml_backend_buffer_t buffer);
28
  GGML_API size_t ggml_backend_buffer_get_size (ggml_backend_buffer_t buffer);
 
29
  GGML_API void ggml_backend_buffer_init_tensor (ggml_backend_buffer_t buffer, struct ggml_tensor * tensor);
30
+ GGML_API size_t ggml_backend_buffer_get_alignment (ggml_backend_buffer_t buffer);
31
+ GGML_API size_t ggml_backend_buffer_get_alloc_size(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor);
32
+ GGML_API ggml_backend_buffer_type_t ggml_backend_buffer_type(ggml_backend_buffer_t buffer);
33
 
34
  //
35
  // Backend
36
  //
37
 
 
 
 
 
 
38
 
39
  GGML_API const char * ggml_backend_name(ggml_backend_t backend);
40
  GGML_API void ggml_backend_free(ggml_backend_t backend);
41
 
42
+ GGML_API ggml_backend_buffer_type_t ggml_backend_get_default_buffer_type(ggml_backend_t backend);
43
+ GGML_API ggml_backend_buffer_t ggml_backend_alloc_buffer(ggml_backend_t backend, size_t size);
44
+ GGML_API size_t ggml_backend_get_alignment(ggml_backend_t backend);
45
 
46
+ GGML_API void ggml_backend_tensor_set_async(ggml_backend_t backend, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size);
47
+ GGML_API void ggml_backend_tensor_get_async(ggml_backend_t backend, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size);
48
 
49
  GGML_API void ggml_backend_tensor_set( struct ggml_tensor * tensor, const void * data, size_t offset, size_t size);
50
  GGML_API void ggml_backend_tensor_get(const struct ggml_tensor * tensor, void * data, size_t offset, size_t size);
 
60
 
61
  // tensor copy between different backends
62
  GGML_API void ggml_backend_tensor_copy(struct ggml_tensor * src, struct ggml_tensor * dst);
63
+ GGML_API void ggml_backend_tensor_copy_async(ggml_backend_t backend, struct ggml_tensor * src, struct ggml_tensor * dst); // automatic fallback to sync copy
64
 
65
  //
66
  // CPU backend
 
72
  GGML_API void ggml_backend_cpu_set_n_threads(ggml_backend_t backend_cpu, int n_threads);
73
 
74
  // Create a backend buffer from an existing pointer
75
+ GGML_API ggml_backend_buffer_t ggml_backend_cpu_buffer_from_ptr(void * ptr, size_t size);
76
+
77
+ GGML_API ggml_backend_buffer_type_t ggml_backend_cpu_buffer_type(void);
78
 
79
+ //
80
+ // Backend registry
81
+ //
82
+
83
+ // The backend registry keeps track of all available backends and allows initializing them in a generic way
84
+
85
+ GGML_API size_t ggml_backend_reg_get_count(void);
86
+ GGML_API size_t ggml_backend_reg_find_by_name(const char * name);
87
+ GGML_API ggml_backend_t ggml_backend_reg_init_backend_from_str(const char * backend_str); // str is name[:params]
88
+ GGML_API const char * ggml_backend_reg_get_name(size_t i);
89
+ GGML_API ggml_backend_t ggml_backend_reg_init_backend(size_t i, const char * params); // params is backend-specific
90
+ GGML_API ggml_backend_buffer_type_t ggml_backend_reg_get_default_buffer_type(size_t i);
91
+ GGML_API ggml_backend_buffer_t ggml_backend_reg_alloc_buffer(size_t i, size_t size);
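A minimal sketch of enumerating the registry declared above (assumes <stdio.h>); picking and initializing a specific entry with ggml_backend_reg_init_backend is left to the caller:

// sketch: list the registered backends by name
static void print_registered_backends(void) {
    for (size_t i = 0; i < ggml_backend_reg_get_count(); i++) {
        fprintf(stderr, "backend %zu: %s\n", i, ggml_backend_reg_get_name(i));
    }
}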
92
 
93
  //
94
  // Backend scheduler
 
150
  ggml_backend_sched_t sched,
151
  struct ggml_cgraph * graph);
152
 
153
+
154
+ //
155
+ // Utils
156
+ //
157
+
158
+ struct ggml_backend_graph_copy {
159
+ ggml_backend_buffer_t buffer;
160
+ struct ggml_context * ctx_allocated;
161
+ struct ggml_context * ctx_unallocated;
162
+ struct ggml_cgraph * graph;
163
+ };
164
+
165
+ // Copy a graph to a different backend
166
+ GGML_API struct ggml_backend_graph_copy ggml_backend_graph_copy(ggml_backend_t backend, struct ggml_cgraph * graph);
167
+ GGML_API void ggml_backend_graph_copy_free(struct ggml_backend_graph_copy copy);
168
+
169
+ typedef bool (*ggml_backend_eval_callback)(int node_index, struct ggml_tensor * t1, struct ggml_tensor * t2, void * user_data);
170
+
171
+ // Compare the output of two backends
172
+ GGML_API void ggml_backend_compare_graph_backend(ggml_backend_t backend1, ggml_backend_t backend2, struct ggml_cgraph * graph, ggml_backend_eval_callback callback, void * user_data);
173
+
174
+ // Tensor initialization
175
+ GGML_API void ggml_backend_tensor_alloc(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, void * addr);
176
+ GGML_API void ggml_backend_view_init(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor);
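A hedged sketch of the graph-copy utility declared above; the backend and graph are assumed to already exist:

// sketch: duplicate a graph onto another backend, run it there, then release the copy
static void run_copy_on(ggml_backend_t backend, struct ggml_cgraph * graph) {
    struct ggml_backend_graph_copy copy = ggml_backend_graph_copy(backend, graph);
    ggml_backend_graph_compute(backend, copy.graph);
    ggml_backend_graph_copy_free(copy); // frees the buffer and both contexts
}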
177
+
178
+
179
  #ifdef __cplusplus
180
  }
181
  #endif
ggml-cuda.cu CHANGED
@@ -1,7 +1,8 @@
1
  #include <algorithm>
2
- #include <cinttypes>
3
  #include <cstddef>
4
  #include <cstdint>
 
 
5
  #include <limits>
6
  #include <stdint.h>
7
  #include <stdio.h>
@@ -69,6 +70,7 @@
69
  #define cudaOccupancyMaxPotentialBlockSize hipOccupancyMaxPotentialBlockSize
70
  #define cudaSetDevice hipSetDevice
71
  #define cudaStreamCreateWithFlags hipStreamCreateWithFlags
 
72
  #define cudaStreamNonBlocking hipStreamNonBlocking
73
  #define cudaStreamSynchronize hipStreamSynchronize
74
  #define cudaStreamWaitEvent(stream, event, flags) hipStreamWaitEvent(stream, event, flags)
@@ -190,7 +192,7 @@ static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size");
190
  fprintf(stderr, "\nCUDA error %d at %s:%d: %s\n", err_, __FILE__, __LINE__, \
191
  cudaGetErrorString(err_)); \
192
  fprintf(stderr, "current device: %d\n", id); \
193
- exit(1); \
194
  } \
195
  } while (0)
196
 
@@ -204,7 +206,7 @@ static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size");
204
  fprintf(stderr, "\ncuBLAS error %d at %s:%d: %s\n", \
205
  err_, __FILE__, __LINE__, cublasGetStatusString(err_)); \
206
  fprintf(stderr, "current device: %d\n", id); \
207
- exit(1); \
208
  } \
209
  } while (0)
210
  #else
@@ -216,7 +218,7 @@ static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size");
216
  cudaGetDevice(&id); \
217
  fprintf(stderr, "\ncuBLAS error %d at %s:%d\n", err_, __FILE__, __LINE__); \
218
  fprintf(stderr, "current device: %d\n", id); \
219
- exit(1); \
220
  } \
221
  } while (0)
222
  #endif // CUDART_VERSION >= 11
@@ -433,8 +435,6 @@ static_assert(sizeof(block_q6_K) == sizeof(ggml_fp16_t) + 13*QK_K/16, "wrong q6_
433
  #define WARP_SIZE 32
434
  #define MATRIX_ROW_PADDING 512 // last row of quant. matrices is a multiple of this to avoid out-of-bounds memory accesses
435
 
436
- #define CUDA_ADD_BLOCK_SIZE 256
437
- #define CUDA_MUL_BLOCK_SIZE 256
438
  #define CUDA_GELU_BLOCK_SIZE 256
439
  #define CUDA_SILU_BLOCK_SIZE 256
440
  #define CUDA_RELU_BLOCK_SIZE 256
@@ -443,6 +443,7 @@ static_assert(sizeof(block_q6_K) == sizeof(ggml_fp16_t) + 13*QK_K/16, "wrong q6_
443
  #define CUDA_SCALE_BLOCK_SIZE 256
444
  #define CUDA_CLAMP_BLOCK_SIZE 256
445
  #define CUDA_ROPE_BLOCK_SIZE 256
 
446
  #define CUDA_ALIBI_BLOCK_SIZE 32
447
  #define CUDA_DIAG_MASK_INF_BLOCK_SIZE 32
448
  #define CUDA_QUANTIZE_BLOCK_SIZE 256
@@ -501,40 +502,112 @@ static size_t g_scratch_offset = 0;
501
 
502
  static cublasHandle_t g_cublas_handles[GGML_CUDA_MAX_DEVICES] = {nullptr};
503
 
504
- static __global__ void add_f32(const float * x, const float * y, float * dst, const int kx, const int ky) {
505
- const int i = blockDim.x*blockIdx.x + threadIdx.x;
506
-
507
- if (i >= kx) {
508
- return;
509
  }
510
- dst[i] = x[i] + y[i%ky];
511
  }
512
 
513
- static __global__ void add_f16_f32_f16(const half * x, const float * y, half * dst, const int k) {
514
- const int i = blockDim.x*blockIdx.x + threadIdx.x;
 
 
 
 
 
 
515
 
516
- if (i >= k) {
517
- return;
 
 
518
  }
519
- dst[i] = __hadd(x[i], __float2half(y[i]));
520
  }
521
 
522
- static __global__ void add_f16_f32_f32(const half * x, const float * y, float * dst, const int k) {
523
- const int i = blockDim.x*blockIdx.x + threadIdx.x;
 
524
 
525
- if (i >= k) {
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
526
  return;
527
  }
528
- dst[i] = __half2float(x[i]) + y[i];
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
529
  }
530
 
531
- static __global__ void mul_f32(const float * x, const float * y, float * dst, const int kx, const int ky) {
 
 
 
 
 
 
532
  const int i = blockDim.x*blockIdx.x + threadIdx.x;
533
 
534
- if (i >= kx) {
 
 
 
 
 
535
  return;
536
  }
537
- dst[i] = x[i] * y[i%ky];
 
 
 
 
 
 
 
 
 
 
 
 
 
 
538
  }
539
 
540
  static __global__ void gelu_f32(const float * x, float * dst, const int k) {
@@ -577,22 +650,11 @@ static __global__ void sqr_f32(const float * x, float * dst, const int k) {
577
  dst[i] = x[i] * x[i];
578
  }
579
 
580
- static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) {
581
- #pragma unroll
582
- for (int mask = 16; mask > 0; mask >>= 1) {
583
- a.x += __shfl_xor_sync(0xffffffff, a.x, mask, 32);
584
- a.y += __shfl_xor_sync(0xffffffff, a.y, mask, 32);
585
- }
586
- return a;
587
- }
588
-
589
  template <int block_size>
590
- static __global__ void norm_f32(const float * x, float * dst, const int ncols) {
591
  const int row = blockIdx.x*blockDim.y + threadIdx.y;
592
  const int tid = threadIdx.x;
593
 
594
- const float eps = 1e-5f;
595
-
596
  float2 mean_var = make_float2(0.f, 0.f);
597
 
598
  for (int col = tid; col < ncols; col += block_size) {
@@ -624,14 +686,6 @@ static __global__ void norm_f32(const float * x, float * dst, const int ncols) {
624
  }
625
  }
626
 
627
- static __device__ __forceinline__ float warp_reduce_sum(float x) {
628
- #pragma unroll
629
- for (int mask = 16; mask > 0; mask >>= 1) {
630
- x += __shfl_xor_sync(0xffffffff, x, mask, 32);
631
- }
632
- return x;
633
- }
634
-
635
  template <int block_size>
636
  static __global__ void rms_norm_f32(const float * x, float * dst, const int ncols, const float eps) {
637
  const int row = blockIdx.x*blockDim.y + threadIdx.y;
@@ -4550,6 +4604,116 @@ static __global__ void cpy_f32_f16(const char * cx, char * cdst, const int ne,
4550
  cpy_1(cx + x_offset, cdst + dst_offset);
4551
  }
4552
 
 
 
 
 
 
 
4553
  static __device__ float rope_yarn_ramp(const float low, const float high, const int i0) {
4554
  const float y = (i0 / 2 - low) / max(0.001f, high - low);
4555
  return 1.0f - min(1.0f, max(0.0f, y));
@@ -4610,8 +4774,8 @@ static __global__ void rope(
4610
 
4611
  template<typename T, bool has_pos>
4612
  static __global__ void rope_neox(
4613
- const T * x, T * dst, int ncols, const int32_t * pos, float freq_scale, int p_delta_rows, float freq_base,
4614
- float ext_factor, float attn_factor, rope_corr_dims corr_dims
4615
  ) {
4616
  const int col = 2*(blockDim.y*blockIdx.y + threadIdx.y);
4617
 
@@ -4620,23 +4784,25 @@ static __global__ void rope_neox(
4620
  }
4621
 
4622
  const int row = blockDim.x*blockIdx.x + threadIdx.x;
4623
- const int i = row*ncols + col/2;
 
 
 
4624
  const int i2 = row/p_delta_rows;
4625
 
4626
- // simplified from `(ib * ncols + col) * (-1 / ncols)`, where ib is assumed to be zero
4627
- const float cur_rot = -float(col)/ncols;
4628
 
4629
  const int p = has_pos ? pos[i2] : 0;
4630
- const float theta_base = p*powf(freq_base, cur_rot);
4631
 
4632
  float cos_theta, sin_theta;
4633
  rope_yarn(theta_base, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor, &cos_theta, &sin_theta);
4634
 
4635
  const float x0 = x[i + 0];
4636
- const float x1 = x[i + ncols/2];
4637
 
4638
- dst[i + 0] = x0*cos_theta - x1*sin_theta;
4639
- dst[i + ncols/2] = x0*sin_theta + x1*cos_theta;
4640
  }
4641
 
4642
  static __global__ void rope_glm_f32(
@@ -4702,6 +4868,65 @@ static __global__ void alibi_f32(const float * x, float * dst, const int ncols,
4702
  dst[i] = col * m_k + x[i];
4703
  }
4704
 
 
 
 
 
 
4705
  static __global__ void diag_mask_inf_f32(const float * x, float * dst, const int ncols, const int rows_per_channel, const int n_past) {
4706
  const int col = blockDim.y*blockIdx.y + threadIdx.y;
4707
  const int row = blockDim.x*blockIdx.x + threadIdx.x;
@@ -4711,49 +4936,79 @@ static __global__ void diag_mask_inf_f32(const float * x, float * dst, const int
4711
  }
4712
 
4713
  const int i = row*ncols + col;
4714
- // dst[i] = col > n_past + row ? -INFINITY : x[i];
4715
- dst[i] = x[i] - (col > n_past + row % rows_per_channel) * INT_MAX; // equivalent within rounding error but slightly faster on GPU
 
4716
  }
4717
 
4718
- // the CUDA soft max implementation differs from the CPU implementation
4719
- // instead of doubles floats are used
4720
- static __global__ void soft_max_f32(const float * x, float * dst, const int ncols) {
4721
- const int row = blockDim.x*blockIdx.x + threadIdx.x;
4722
- const int block_size = blockDim.y;
4723
- const int tid = threadIdx.y;
 
 
 
 
 
4724
 
4725
  float max_val = -INFINITY;
4726
 
4727
  for (int col = tid; col < ncols; col += block_size) {
4728
- const int i = row*ncols + col;
4729
- max_val = max(max_val, x[i]);
 
4730
  }
4731
 
4732
  // find the max value in the block
4733
- #pragma unroll
4734
- for (int mask = 16; mask > 0; mask >>= 1) {
4735
- max_val = max(max_val, __shfl_xor_sync(0xffffffff, max_val, mask, 32));
 
 
 
 
 
 
 
 
 
 
 
4736
  }
4737
 
4738
  float tmp = 0.f;
4739
 
4740
  for (int col = tid; col < ncols; col += block_size) {
4741
- const int i = row*ncols + col;
4742
- const float val = expf(x[i] - max_val);
 
4743
  tmp += val;
4744
- dst[i] = val;
4745
  }
4746
 
4747
- // sum up partial sums
4748
- #pragma unroll
4749
- for (int mask = 16; mask > 0; mask >>= 1) {
4750
- tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
 
 
 
 
 
 
 
 
 
 
 
4751
  }
4752
 
4753
  const float inv_tmp = 1.f / tmp;
4754
 
4755
  for (int col = tid; col < ncols; col += block_size) {
4756
- const int i = row*ncols + col;
4757
  dst[i] *= inv_tmp;
4758
  }
4759
  }
@@ -4805,25 +5060,119 @@ static void get_rows_cuda(const void * x, const int32_t * y, float * dst, const
4805
  k_get_rows<qk, qr, dq><<<block_nums, block_dims, 0, stream>>>(x, y, dst, ncols);
4806
  }
4807
 
4808
- static void add_f32_cuda(const float * x, const float * y, float * dst, const int kx, const int ky, cudaStream_t stream) {
4809
- const int num_blocks = (kx + CUDA_ADD_BLOCK_SIZE - 1) / CUDA_ADD_BLOCK_SIZE;
4810
- add_f32<<<num_blocks, CUDA_ADD_BLOCK_SIZE, 0, stream>>>(x, y, dst, kx, ky);
4811
- }
4812
-
4813
- static void add_f16_f32_f16_cuda(const half * x, const float * y, half * dst, const int k, cudaStream_t stream) {
4814
- const int num_blocks = (k + CUDA_ADD_BLOCK_SIZE - 1) / CUDA_ADD_BLOCK_SIZE;
4815
- add_f16_f32_f16<<<num_blocks, CUDA_ADD_BLOCK_SIZE, 0, stream>>>(x, y, dst, k);
4816
- }
4817
-
4818
- static void add_f16_f32_f32_cuda(const half * x, const float * y, float * dst, const int k, cudaStream_t stream) {
4819
- const int num_blocks = (k + CUDA_ADD_BLOCK_SIZE - 1) / CUDA_ADD_BLOCK_SIZE;
4820
- add_f16_f32_f32<<<num_blocks, CUDA_ADD_BLOCK_SIZE, 0, stream>>>(x, y, dst, k);
4821
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4822
 
4823
- static void mul_f32_cuda(const float * x, const float * y, float * dst, const int kx, const int ky, cudaStream_t stream) {
4824
- const int num_blocks = (kx + CUDA_MUL_BLOCK_SIZE - 1) / CUDA_MUL_BLOCK_SIZE;
4825
- mul_f32<<<num_blocks, CUDA_MUL_BLOCK_SIZE, 0, stream>>>(x, y, dst, kx, ky);
4826
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4827
 
4828
  static void gelu_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
4829
  const int num_blocks = (k + CUDA_GELU_BLOCK_SIZE - 1) / CUDA_GELU_BLOCK_SIZE;
@@ -4845,14 +5194,14 @@ static void sqr_f32_cuda(const float * x, float * dst, const int k, cudaStream_t
4845
  sqr_f32<<<num_blocks, CUDA_SQR_BLOCK_SIZE, 0, stream>>>(x, dst, k);
4846
  }
4847
 
4848
- static void norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
4849
  GGML_ASSERT(ncols % WARP_SIZE == 0);
4850
  if (ncols < 1024) {
4851
  const dim3 block_dims(WARP_SIZE, 1, 1);
4852
- norm_f32<WARP_SIZE><<<nrows, block_dims, 0, stream>>>(x, dst, ncols);
4853
  } else {
4854
  const dim3 block_dims(1024, 1, 1);
4855
- norm_f32<1024><<<nrows, block_dims, 0, stream>>>(x, dst, ncols);
4856
  }
4857
  }
4858
 
@@ -4874,34 +5223,10 @@ static void quantize_row_q8_1_cuda(const float * x, void * vy, const int kx, con
4874
  quantize_q8_1<<<num_blocks, block_size, 0, stream>>>(x, vy, kx, kx_padded);
4875
  }
4876
 
4877
- template<typename dst_t>
4878
- static void dequantize_row_q4_0_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
4879
- const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
4880
- dequantize_block<QK4_0, QR4_0, dequantize_q4_0><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
4881
- }
4882
-
4883
- template<typename dst_t>
4884
- static void dequantize_row_q4_1_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
4885
- const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
4886
- dequantize_block<QK4_1, QR4_1, dequantize_q4_1><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
4887
- }
4888
-
4889
- template<typename dst_t>
4890
- static void dequantize_row_q5_0_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
4891
- const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
4892
- dequantize_block<QK5_0, QR5_0, dequantize_q5_0><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
4893
- }
4894
-
4895
- template<typename dst_t>
4896
- static void dequantize_row_q5_1_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
4897
- const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
4898
- dequantize_block<QK5_1, QR5_1, dequantize_q5_1><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
4899
- }
4900
-
4901
- template<typename dst_t>
4902
- static void dequantize_row_q8_0_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
4903
  const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
4904
- dequantize_block<QK8_0, QR8_0, dequantize_q8_0><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
4905
  }
4906
 
4907
  template<typename dst_t>
@@ -4950,6 +5275,64 @@ static void dequantize_row_q6_K_cuda(const void * vx, dst_t * y, const int k, cu
4950
  #endif
4951
  }
4952
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4953
  static void dequantize_mul_mat_vec_q4_0_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
4954
  GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
4955
  const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
@@ -5038,13 +5421,22 @@ static void dequantize_mul_mat_vec_q6_K_cuda(const void * vx, const float * y, f
5038
  dequantize_mul_mat_vec_q6_k<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
5039
  }
5040
 
5041
- static void mul_mat_vec_q4_0_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
5042
- GGML_ASSERT(ncols % QK4_0 == 0);
5043
  const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
5044
  const dim3 block_nums(block_num_y, 1, 1);
5045
  const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
5046
- mul_mat_vec_q<QK4_0, QI4_0, block_q4_0, VDR_Q4_0_Q8_1_MMVQ, vec_dot_q4_0_q8_1>
5047
- <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
 
 
 
 
 
 
 
 
 
5048
  }
5049
 
5050
  static void mul_mat_vec_q4_1_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
@@ -5128,83 +5520,6 @@ static void mul_mat_vec_q6_K_q8_1_cuda(const void * vx, const void * vy, float *
5128
  <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
5129
  }
5130
 
5131
- static void convert_fp16_to_fp32_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
5132
- const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
5133
- dequantize_block<1, 1, convert_f16><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
5134
- }
5135
-
5136
- static void convert_fp32_to_fp16_cuda(const void * vx, half * y, const int k, cudaStream_t stream) {
5137
- const int num_blocks = (k + CUDA_QUANTIZE_BLOCK_SIZE - 1) / CUDA_QUANTIZE_BLOCK_SIZE;
5138
- dequantize_block<1, 1, convert_f32><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
5139
- }
5140
-
5141
- static void convert_mul_mat_vec_f16_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
5142
- GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
5143
- const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
5144
- const dim3 block_nums(block_num_y, 1, 1);
5145
- const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
5146
- dequantize_mul_mat_vec<1, 1, convert_f16>
5147
- <<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
5148
- }
5149
-
5150
- static to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
5151
- switch (type) {
5152
- case GGML_TYPE_Q4_0:
5153
- return dequantize_row_q4_0_cuda;
5154
- case GGML_TYPE_Q4_1:
5155
- return dequantize_row_q4_1_cuda;
5156
- case GGML_TYPE_Q5_0:
5157
- return dequantize_row_q5_0_cuda;
5158
- case GGML_TYPE_Q5_1:
5159
- return dequantize_row_q5_1_cuda;
5160
- case GGML_TYPE_Q8_0:
5161
- return dequantize_row_q8_0_cuda;
5162
- case GGML_TYPE_Q2_K:
5163
- return dequantize_row_q2_K_cuda;
5164
- case GGML_TYPE_Q3_K:
5165
- return dequantize_row_q3_K_cuda;
5166
- case GGML_TYPE_Q4_K:
5167
- return dequantize_row_q4_K_cuda;
5168
- case GGML_TYPE_Q5_K:
5169
- return dequantize_row_q5_K_cuda;
5170
- case GGML_TYPE_Q6_K:
5171
- return dequantize_row_q6_K_cuda;
5172
- case GGML_TYPE_F32:
5173
- return convert_fp32_to_fp16_cuda;
5174
- default:
5175
- return nullptr;
5176
- }
5177
- }
5178
-
5179
- static to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
5180
- switch (type) {
5181
- case GGML_TYPE_Q4_0:
5182
- return dequantize_row_q4_0_cuda;
5183
- case GGML_TYPE_Q4_1:
5184
- return dequantize_row_q4_1_cuda;
5185
- case GGML_TYPE_Q5_0:
5186
- return dequantize_row_q5_0_cuda;
5187
- case GGML_TYPE_Q5_1:
5188
- return dequantize_row_q5_1_cuda;
5189
- case GGML_TYPE_Q8_0:
5190
- return dequantize_row_q8_0_cuda;
5191
- case GGML_TYPE_Q2_K:
5192
- return dequantize_row_q2_K_cuda;
5193
- case GGML_TYPE_Q3_K:
5194
- return dequantize_row_q3_K_cuda;
5195
- case GGML_TYPE_Q4_K:
5196
- return dequantize_row_q4_K_cuda;
5197
- case GGML_TYPE_Q5_K:
5198
- return dequantize_row_q5_K_cuda;
5199
- case GGML_TYPE_Q6_K:
5200
- return dequantize_row_q6_K_cuda;
5201
- case GGML_TYPE_F16:
5202
- return convert_fp16_to_fp32_cuda;
5203
- default:
5204
- return nullptr;
5205
- }
5206
- }
5207
-
5208
  static void ggml_mul_mat_q4_0_q8_1_cuda(
5209
  const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x,
5210
  const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
@@ -5697,6 +6012,39 @@ static void ggml_cpy_f32_f16_cuda(
5697
  (cx, cdst, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12);
5698
  }
5699
 
 
 
 
 
 
 
 
5700
  static void ggml_cpy_f16_f16_cuda(
5701
  const char * cx, char * cdst, const int ne,
5702
  const int ne00, const int ne01, const int nb00, const int nb01, const int nb02,
@@ -5739,20 +6087,26 @@ static void rope_cuda(
5739
 
5740
  template<typename T>
5741
  static void rope_neox_cuda(
5742
- const T * x, T * dst, int ncols, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows,
5743
  float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, cudaStream_t stream
5744
  ) {
5745
  GGML_ASSERT(ncols % 2 == 0);
5746
  const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
5747
  const int num_blocks_x = (ncols + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
5748
  const dim3 block_nums(nrows, num_blocks_x, 1);
 
 
 
 
5749
  if (pos == nullptr) {
5750
  rope_neox<T, false><<<block_nums, block_dims, 0, stream>>>(
5751
- x, dst, ncols, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims
 
5752
  );
5753
  } else {
5754
  rope_neox<T, true><<<block_nums, block_dims, 0, stream>>>(
5755
- x, dst, ncols, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims
 
5756
  );
5757
  }
5758
  }
@@ -5777,6 +6131,27 @@ static void alibi_f32_cuda(const float * x, float * dst, const int ncols, const
5777
  alibi_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, k_rows, n_heads_log2_floor, m0, m1);
5778
  }
5779
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5780
  static void diag_mask_inf_f32_cuda(const float * x, float * dst, const int ncols_x, const int nrows_x, const int rows_per_channel, const int n_past, cudaStream_t stream) {
5781
  const dim3 block_dims(1, CUDA_DIAG_MASK_INF_BLOCK_SIZE, 1);
5782
  const int block_num_x = (ncols_x + CUDA_DIAG_MASK_INF_BLOCK_SIZE - 1) / CUDA_DIAG_MASK_INF_BLOCK_SIZE;
@@ -5784,10 +6159,12 @@ static void diag_mask_inf_f32_cuda(const float * x, float * dst, const int ncols
5784
  diag_mask_inf_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols_x, rows_per_channel, n_past);
5785
  }
5786
 
5787
- static void soft_max_f32_cuda(const float * x, float * dst, const int ncols_x, const int nrows_x, cudaStream_t stream) {
5788
- const dim3 block_dims(1, WARP_SIZE, 1);
 
 
5789
  const dim3 block_nums(nrows_x, 1, 1);
5790
- soft_max_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols_x);
5791
  }
5792
 
5793
  static void im2col_f32_f16_cuda(const float * x, half * dst,
@@ -5867,7 +6244,7 @@ static void * ggml_cuda_pool_malloc(size_t size, size_t * actual_size) {
5867
  return ptr;
5868
  }
5869
  #ifdef DEBUG_CUDA_MALLOC
5870
- fprintf(stderr, "%s: %d buffers, max_size = %u MiB, tot_size = %u MiB, requested %u MiB\n", __func__, nnz,
5871
  (uint32_t)(max_size/1024/1024), (uint32_t)(tot_size/1024/1024), (uint32_t)(size/1024/1024));
5872
  #endif
5873
  void * ptr;
@@ -6005,7 +6382,7 @@ void * ggml_cuda_host_malloc(size_t size) {
6005
  // The allocation error can be bypassed. A null ptr will be assigned out of this function.
6006
  // This can fix the OOM error in WSL.
6007
  cudaGetLastError();
6008
- fprintf(stderr, "WARNING: failed to allocate %.2f MiB of pinned memory: %s\n",
6009
  size/1024.0/1024.0, cudaGetErrorString(err));
6010
  return nullptr;
6011
  }
@@ -6064,63 +6441,6 @@ static cudaError_t ggml_cuda_cpy_tensor_2d(
6064
  }
6065
  }
6066
 
6067
- static void ggml_cuda_op_repeat(
6068
- const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
6069
- const float * src0_d, const float * src1_d, float * dst_d, const cudaStream_t & stream) {
6070
- // guaranteed to be an integer due to the check in ggml_can_repeat
6071
- const int64_t ne0 = dst->ne[0];
6072
- const int64_t ne1 = dst->ne[1];
6073
- const int64_t ne2 = dst->ne[2];
6074
- const int64_t ne3 = dst->ne[3];
6075
-
6076
- const int64_t ne00 = src0->ne[0];
6077
- const int64_t ne01 = src0->ne[1];
6078
- const int64_t ne02 = src0->ne[2];
6079
- const int64_t ne03 = src0->ne[3];
6080
-
6081
- const size_t nb0 = dst->nb[0];
6082
- const size_t nb1 = dst->nb[1];
6083
- const size_t nb2 = dst->nb[2];
6084
- const size_t nb3 = dst->nb[3];
6085
-
6086
- const size_t nb00 = src0->nb[0];
6087
- const size_t nb01 = src0->nb[1];
6088
- const size_t nb02 = src0->nb[2];
6089
- const size_t nb03 = src0->nb[3];
6090
-
6091
- const int nr0 = (int)(ne0/ne00);
6092
- const int nr1 = (int)(ne1/ne01);
6093
- const int nr2 = (int)(ne2/ne02);
6094
- const int nr3 = (int)(ne3/ne03);
6095
-
6096
- // TODO: support for transposed / permuted tensors
6097
- GGML_ASSERT(nb0 == sizeof(float));
6098
- GGML_ASSERT(nb00 == sizeof(float));
6099
-
6100
- // TODO: very inefficient, implement in a kernel, or fewer cudaMemcpyAsync calls for contiguous tensors
6101
- for (int i3 = 0; i3 < nr3; i3++) {
6102
- for (int k3 = 0; k3 < ne03; k3++) {
6103
- for (int i2 = 0; i2 < nr2; i2++) {
6104
- for (int k2 = 0; k2 < ne02; k2++) {
6105
- for (int i1 = 0; i1 < nr1; i1++) {
6106
- for (int k1 = 0; k1 < ne01; k1++) {
6107
- for (int i0 = 0; i0 < nr0; i0++) {
6108
- CUDA_CHECK(cudaMemcpyAsync(
6109
- (char *) dst_d + (i3*ne03 + k3)*nb3 + (i2*ne02 + k2)*nb2 + (i1*ne01 + k1)*nb1 + (i0*ne00)*nb0,
6110
- (const char *) src0_d + ( k3)*nb03 + ( k2)*nb02 + ( k1)*nb01,
6111
- ne00*nb0, cudaMemcpyDeviceToDevice, stream));
6112
- }
6113
- }
6114
- }
6115
- }
6116
- }
6117
- }
6118
- }
6119
-
6120
- (void) src1;
6121
- (void) src1_d;
6122
- }
6123
-
6124
  static void ggml_cuda_op_get_rows(
6125
  const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
6126
  const float * src0_d, const float * src1_d, float * dst_d, const cudaStream_t & stream) {
@@ -6165,47 +6485,55 @@ static void ggml_cuda_op_get_rows(
6165
  }
6166
  }
6167
 
6168
- inline void ggml_cuda_op_add(
 
6169
  const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
6170
  const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
6171
 
6172
- GGML_ASSERT(ggml_is_contiguous(src0));
6173
- GGML_ASSERT(ggml_is_contiguous(src1));
6174
-
6175
  GGML_ASSERT(src1->type == GGML_TYPE_F32);
6176
 
6177
- const int64_t ne10 = src1->ne[0];
6178
- const int64_t ne11 = src1->ne[1];
6179
-
6180
  if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
6181
- add_f32_cuda(src0_dd, src1_dd, dst_dd, ggml_nelements(src0), ne10*ne11, main_stream);
6182
  } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
6183
- add_f16_f32_f16_cuda((const half *) src0_dd, src1_dd, (half *) dst_dd, ggml_nelements(src0), main_stream);
6184
  } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F32) {
6185
- add_f16_f32_f32_cuda((const half *) src0_dd, src1_dd, dst_dd, ggml_nelements(src0), main_stream);
6186
  } else {
6187
- fprintf(stderr, "src0->type: %d dst->type: %d\n", src0->type, dst->type);
 
6188
  GGML_ASSERT(false);
6189
  }
 
 
 
 
 
 
 
6190
 
6191
  (void) src1;
6192
- (void) dst;
6193
  }
6194
 
6195
- inline void ggml_cuda_op_mul(
6196
  const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
6197
  const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
6198
 
6199
- GGML_ASSERT(src0->type == GGML_TYPE_F32);
6200
- GGML_ASSERT(src1->type == GGML_TYPE_F32);
6201
- GGML_ASSERT( dst->type == GGML_TYPE_F32);
6202
 
6203
- const int64_t ne10 = src1->ne[0];
6204
- const int64_t ne11 = src1->ne[1];
 
6205
 
6206
- mul_f32_cuda(src0_dd, src1_dd, dst_dd, ggml_nelements(src0), ne10*ne11, main_stream);
 
6207
 
6208
- (void) dst;
 
 
 
 
6209
  }
6210
 
6211
  inline void ggml_cuda_op_gelu(
@@ -6274,7 +6602,10 @@ inline void ggml_cuda_op_norm(
6274
  const int64_t ne00 = src0->ne[0];
6275
  const int64_t nrows = ggml_nrows(src0);
6276
 
6277
- norm_f32_cuda(src0_dd, dst_dd, ne00, nrows, main_stream);
 
 
 
6278
 
6279
  (void) src1;
6280
  (void) dst;
@@ -6429,6 +6760,8 @@ inline void ggml_cuda_op_mul_mat_vec_q(
6429
  const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
6430
  const int64_t src1_padded_row_size, const cudaStream_t & stream) {
6431
 
 
 
6432
  const int64_t ne00 = src0->ne[0];
6433
  const int64_t row_diff = row_high - row_low;
6434
 
@@ -6488,7 +6821,8 @@ inline void ggml_cuda_op_dequantize_mul_mat_vec(
6488
  size_t ash;
6489
  dfloat * src1_dfloat = nullptr; // dfloat == half
6490
 
6491
- bool src1_convert_f16 = src0->type == GGML_TYPE_Q4_0 || src0->type == GGML_TYPE_Q4_1 ||
 
6492
  src0->type == GGML_TYPE_Q5_0 || src0->type == GGML_TYPE_Q5_1 ||
6493
  src0->type == GGML_TYPE_Q8_0 || src0->type == GGML_TYPE_F16;
6494
 
@@ -6710,15 +7044,14 @@ inline void ggml_cuda_op_rope(
6710
  GGML_ASSERT(false);
6711
  rope_glm_f32_cuda(src0_dd, dst_dd, ne00, nrows, pos, freq_scale, ne01, freq_base, n_ctx, main_stream);
6712
  } else if (is_neox) {
6713
- GGML_ASSERT(ne00 == n_dims && "ne00 != n_dims is not implemented for CUDA yet");
6714
  if (src0->type == GGML_TYPE_F32) {
6715
  rope_neox_cuda(
6716
- (const float *)src0_dd, (float *)dst_dd, ne00, nrows, pos, freq_scale, ne01, freq_base, ext_factor,
6717
  attn_factor, corr_dims, main_stream
6718
  );
6719
  } else if (src0->type == GGML_TYPE_F16) {
6720
  rope_neox_cuda(
6721
- (const half *)src0_dd, (half *)dst_dd, ne00, nrows, pos, freq_scale, ne01, freq_base, ext_factor,
6722
  attn_factor, corr_dims, main_stream
6723
  );
6724
  } else {
@@ -6815,6 +7148,42 @@ inline void ggml_cuda_op_im2col(
6815
  (void) src0_dd;
6816
  }
6817
 
 
 
 
 
 
 
 
6818
  inline void ggml_cuda_op_diag_mask_inf(
6819
  const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
6820
  const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
@@ -6842,14 +7211,18 @@ inline void ggml_cuda_op_soft_max(
6842
  GGML_ASSERT(src0->type == GGML_TYPE_F32);
6843
  GGML_ASSERT( dst->type == GGML_TYPE_F32);
6844
 
 
 
6845
  const int64_t ne00 = src0->ne[0];
6846
- const int64_t nrows = ggml_nrows(src0);
 
6847
 
6848
- soft_max_f32_cuda(src0_dd, dst_dd, ne00, nrows, main_stream);
 
 
 
6849
 
6850
- (void) src1;
6851
  (void) dst;
6852
- (void) src1_dd;
6853
  }
6854
 
6855
  inline void ggml_cuda_op_scale(
@@ -7019,7 +7392,7 @@ static void ggml_cuda_op_mul_mat(
7019
  const int64_t ne01 = src0->ne[1];
7020
  const int64_t ne02 = src0->ne[2];
7021
  const int64_t ne03 = src0->ne[3];
7022
- // const int64_t nrows0 = ggml_nrows(src0);
7023
 
7024
  const int64_t ne10 = src1->ne[0];
7025
  const int64_t ne11 = src1->ne[1];
@@ -7055,10 +7428,9 @@ static void ggml_cuda_op_mul_mat(
7055
 
7056
  const bool src0_on_device = src0->backend == GGML_BACKEND_GPU || src0->backend == GGML_BACKEND_GPU_SPLIT;
7057
  const bool src0_is_contiguous = ggml_is_contiguous(src0);
7058
-
7059
  const bool src1_is_contiguous = ggml_is_contiguous(src1);
7060
- const int64_t src1_padded_col_size = ne10 % MATRIX_ROW_PADDING == 0 ?
7061
- ne10 : ne10 - ne10 % MATRIX_ROW_PADDING + MATRIX_ROW_PADDING;
7062
 
7063
  const bool split = src0->backend == GGML_BACKEND_GPU_SPLIT;
7064
  GGML_ASSERT(!(split && ne02 > 1));
@@ -7183,7 +7555,7 @@ static void ggml_cuda_op_mul_mat(
7183
  const size_t src1_ddq_i_offset = (i0*ne11 + src1_col_0) * src1_padded_col_size*q8_1_ts/q8_1_bs;
7184
 
7185
  // for split tensors the data begins at i0 == i0_offset_low
7186
- char * src0_dd_i = src0_dd[id] + (i0/i02_divisor) * ne01*ne00*src0_ts/src0_bs;
7187
  float * src1_ddf_i = src1_ddf[id] + (i0*ne11 + src1_col_0) * ne10;
7188
  char * src1_ddq_i = src1_ddq[id] + src1_ddq_i_offset;
7189
  float * dst_dd_i = dst_dd[id] + (i0*ne1 + src1_col_0) * (dst_on_device ? ne0 : row_diff);
@@ -7328,6 +7700,10 @@ static void ggml_cuda_mul(const ggml_tensor * src0, const ggml_tensor * src1, gg
7328
  ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_mul);
7329
  }
7330
 
 
 
 
 
7331
  static void ggml_cuda_gelu(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
7332
  ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_gelu);
7333
  }
@@ -7353,7 +7729,7 @@ static void ggml_cuda_rms_norm(const ggml_tensor * src0, const ggml_tensor * src
7353
  }
7354
 
7355
  bool ggml_cuda_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst) {
7356
- if (!g_cublas_loaded) { return false; }
7357
 
7358
  const int64_t ne10 = src1->ne[0];
7359
 
@@ -7431,7 +7807,7 @@ static void ggml_cuda_mul_mat_vec_nc(const ggml_tensor * src0, const ggml_tensor
7431
  ggml_mul_mat_vec_nc_f16_f32_cuda(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, row_stride_x, ne02, ne12, channel_stride_x, main_stream);
7432
  }
7433
 
7434
- __global__ static void k_compute_batched_ptrs(
7435
  const half * src0_as_f16, const half * src1_as_f16, half * dst_f16,
7436
  const void ** ptrs_src, void ** ptrs_dst,
7437
  int ne12, int ne13,
@@ -7487,9 +7863,7 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const
7487
  CUDA_CHECK(ggml_cuda_set_device(g_main_device));
7488
  cudaStream_t main_stream = g_cudaStreams[g_main_device][0];
7489
 
7490
- int id;
7491
- CUDA_CHECK(cudaGetDevice(&id));
7492
- CUBLAS_CHECK(cublasSetStream(g_cublas_handles[id], main_stream));
7493
 
7494
  ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra;
7495
  void * src0_ddq = src0_extra->data_device[g_main_device];
@@ -7546,7 +7920,7 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const
7546
  // there is no broadcast and src0, src1 are contiguous across dims 2, 3
7547
  // use cublasGemmStridedBatchedEx
7548
  CUBLAS_CHECK(
7549
- cublasGemmStridedBatchedEx(g_cublas_handles[id], CUBLAS_OP_T, CUBLAS_OP_N,
7550
  ne01, ne11, ne10,
7551
  &alpha_f16, (const char *) src0_as_f16, CUDA_R_16F, nb01/sizeof(half), src0->nb[2]/sizeof(half), // strideA
7552
  (const char *) src1_as_f16, CUDA_R_16F, nb11/sizeof(float), src1->nb[2]/sizeof(float), // strideB
@@ -7580,7 +7954,7 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const
7580
  CUDA_CHECK(cudaGetLastError());
7581
 
7582
  CUBLAS_CHECK(
7583
- cublasGemmBatchedEx(g_cublas_handles[id], CUBLAS_OP_T, CUBLAS_OP_N,
7584
  ne01, ne11, ne10,
7585
  &alpha_f16, (const void **) (ptrs_src + 0*ne23), CUDA_R_16F, nb01/sizeof(half),
7586
  (const void **) (ptrs_src + 1*ne23), CUDA_R_16F, nb11/sizeof(float),
@@ -7650,10 +8024,11 @@ static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1
7650
  #ifdef GGML_CUDA_FORCE_DMMV
7651
  const bool use_mul_mat_vec_q = false;
7652
  #else
7653
- const bool use_mul_mat_vec_q = min_compute_capability >= MIN_CC_DP4A && ggml_is_quantized(src0->type);
7654
  #endif // GGML_CUDA_FORCE_DMMV
7655
 
7656
  if (use_mul_mat_vec_q) {
 
7657
  ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_mul_mat_vec_q, true);
7658
  } else {
7659
  ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_dequantize_mul_mat_vec, false);
@@ -7678,42 +8053,255 @@ static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1
7678
  }
7679
  }
7680
 
7681
- static void ggml_cuda_scale(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
7682
- ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_scale);
7683
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7684
 
7685
- static void ggml_cuda_clamp(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
7686
- ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_clamp);
 
 
 
 
 
 
 
 
 
 
 
7687
  }
7688
 
7689
- static void ggml_cuda_cpy(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
7690
- const int64_t ne = ggml_nelements(src0);
7691
- GGML_ASSERT(ne == ggml_nelements(src1));
 
7692
 
7693
- GGML_ASSERT(src0->backend == GGML_BACKEND_GPU);
7694
- GGML_ASSERT(src1->backend == GGML_BACKEND_GPU);
7695
 
7696
- GGML_ASSERT(ggml_nbytes(src0) <= INT_MAX);
7697
- GGML_ASSERT(ggml_nbytes(src1) <= INT_MAX);
7698
 
7699
- const int64_t ne00 = src0->ne[0];
7700
- const int64_t ne01 = src0->ne[1];
7701
- GGML_ASSERT(src0->ne[3] == 1);
7702
 
7703
- const int64_t nb00 = src0->nb[0];
7704
- const int64_t nb01 = src0->nb[1];
7705
- const int64_t nb02 = src0->nb[2];
 
 
 
 
 
7706
 
7707
  const int64_t ne10 = src1->ne[0];
7708
  const int64_t ne11 = src1->ne[1];
7709
- GGML_ASSERT(src1->ne[3] == 1);
 
7710
 
7711
- const int64_t nb10 = src1->nb[0];
7712
- const int64_t nb11 = src1->nb[1];
7713
- const int64_t nb12 = src1->nb[2];
7714
 
7715
- CUDA_CHECK(ggml_cuda_set_device(g_main_device));
7716
- cudaStream_t main_stream = g_cudaStreams[g_main_device][0];
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7717
 
7718
  const ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra;
7719
  const ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu *) src1->extra;
@@ -7722,14 +8310,17 @@ static void ggml_cuda_cpy(const ggml_tensor * src0, const ggml_tensor * src1, gg
7722
  char * src1_ddc = (char *) src1_extra->data_device[g_main_device];
7723
 
7724
  if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
7725
- ggml_cpy_f32_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02,
7726
- ne10, ne11, nb10, nb11, nb12, main_stream);
7727
  } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) {
7728
- ggml_cpy_f32_f16_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02,
7729
- ne10, ne11, nb10, nb11, nb12, main_stream);
 
 
 
 
 
7730
  } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
7731
- ggml_cpy_f16_f16_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02,
7732
- ne10, ne11, nb10, nb11, nb12, main_stream);
7733
  } else {
7734
  fprintf(stderr, "%s: unsupported type combination (%s to %s)\n", __func__,
7735
  ggml_type_name(src0->type), ggml_type_name(src1->type));
@@ -7740,6 +8331,7 @@ static void ggml_cuda_cpy(const ggml_tensor * src0, const ggml_tensor * src1, gg
7740
  }
7741
 
7742
  static void ggml_cuda_dup(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
 
7743
  ggml_cuda_cpy(src0, dst, nullptr);
7744
  (void) src1;
7745
  }
@@ -7765,6 +8357,16 @@ static void ggml_cuda_im2col(const ggml_tensor * src0, const ggml_tensor * src1,
7765
  ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_im2col);
7766
  }
7767
 
 
 
 
 
7768
  static void ggml_cuda_nop(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
7769
  (void) src0;
7770
  (void) src1;
@@ -8020,8 +8622,9 @@ void ggml_cuda_set_main_device(const int main_device) {
8020
  main_device, g_device_count, g_main_device);
8021
  return;
8022
  }
8023
- g_main_device = main_device;
8024
- if (g_device_count > 1) {
 
8025
  cudaDeviceProp prop;
8026
  CUDA_CHECK(cudaGetDeviceProperties(&prop, g_main_device));
8027
  fprintf(stderr, "%s: using device %d (%s) as main device\n", __func__, g_main_device, prop.name);
@@ -8047,7 +8650,7 @@ void ggml_cuda_free_scratch() {
8047
  }
8048
 
8049
  bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor) {
8050
- if (!g_cublas_loaded) { return false; }
8051
 
8052
  ggml_cuda_func_t func;
8053
  const bool any_on_device = tensor->backend == GGML_BACKEND_GPU
@@ -8083,6 +8686,9 @@ bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_
8083
  case GGML_OP_MUL:
8084
  func = ggml_cuda_mul;
8085
  break;
 
 
 
8086
  case GGML_OP_UNARY:
8087
  switch (ggml_get_unary_op(tensor)) {
8088
  case GGML_UNARY_OP_GELU:
@@ -8096,7 +8702,8 @@ bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_
8096
  break;
8097
  default:
8098
  return false;
8099
- } break;
 
8100
  case GGML_OP_NORM:
8101
  func = ggml_cuda_norm;
8102
  break;
@@ -8109,6 +8716,12 @@ bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_
8109
  }
8110
  func = ggml_cuda_mul_mat;
8111
  break;
 
 
 
 
 
 
8112
  case GGML_OP_SCALE:
8113
  func = ggml_cuda_scale;
8114
  break;
@@ -8148,6 +8761,12 @@ bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_
8148
  case GGML_OP_IM2COL:
8149
  func = ggml_cuda_im2col;
8150
  break;
 
 
 
 
 
 
8151
  default:
8152
  return false;
8153
  }
@@ -8164,7 +8783,9 @@ bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_
8164
 
8165
  int ggml_cuda_get_device_count() {
8166
  int device_count;
8167
- CUDA_CHECK(cudaGetDeviceCount(&device_count));
 
 
8168
  return device_count;
8169
  }
8170
 
@@ -8180,27 +8801,16 @@ void ggml_cuda_get_device_description(int device, char * description, size_t des
8180
 
8181
  #define UNUSED GGML_UNUSED
8182
 
8183
- struct ggml_backend_context_cuda {
8184
- };
8185
-
8186
- static const char * ggml_backend_cuda_name(ggml_backend_t backend) {
8187
- return GGML_CUDA_NAME;
8188
-
8189
- UNUSED(backend);
8190
- }
8191
-
8192
- static void ggml_backend_cuda_free(ggml_backend_t backend) {
8193
- ggml_backend_context_cuda * cuda_ctx = (ggml_backend_context_cuda *)backend->context;
8194
- delete cuda_ctx;
8195
- delete backend;
8196
- }
8197
 
8198
  struct ggml_backend_buffer_context_cuda {
8199
- void * device;
8200
-
8201
  ggml_tensor_extra_gpu * temp_tensor_extras = nullptr;
8202
  size_t temp_tensor_extra_index = 0;
8203
 
 
 
8204
  ~ggml_backend_buffer_context_cuda() {
8205
  delete[] temp_tensor_extras;
8206
  }
@@ -8221,41 +8831,20 @@ struct ggml_backend_buffer_context_cuda {
8221
 
8222
  static void ggml_backend_cuda_buffer_free_buffer(ggml_backend_buffer_t buffer) {
8223
  ggml_backend_buffer_context_cuda * ctx = (ggml_backend_buffer_context_cuda *)buffer->context;
8224
- CUDA_CHECK(cudaFree(ctx->device));
8225
  delete ctx;
8226
  }
8227
 
8228
  static void * ggml_backend_cuda_buffer_get_base(ggml_backend_buffer_t buffer) {
8229
  ggml_backend_buffer_context_cuda * ctx = (ggml_backend_buffer_context_cuda *)buffer->context;
8230
- return ctx->device;
8231
- }
8232
-
8233
- static size_t ggml_backend_cuda_buffer_get_alloc_size(ggml_backend_buffer_t buffer, ggml_tensor * tensor) {
8234
- int64_t row_low = 0;
8235
- int64_t row_high = ggml_nrows(tensor);
8236
- int64_t nrows_split = row_high - row_low;
8237
-
8238
- size_t size = ggml_nbytes_split(tensor, nrows_split);
8239
-
8240
- int64_t ne0 = tensor->ne[0];
8241
-
8242
- if (ggml_is_quantized(tensor->type)) {
8243
- if (ne0 % MATRIX_ROW_PADDING != 0) {
8244
- size += (MATRIX_ROW_PADDING - ne0 % MATRIX_ROW_PADDING)
8245
- * ggml_type_size(tensor->type)/ggml_blck_size(tensor->type);
8246
- }
8247
- }
8248
-
8249
- return size;
8250
-
8251
- UNUSED(buffer);
8252
  }
8253
 
8254
  static void ggml_backend_cuda_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) {
8255
  ggml_backend_buffer_context_cuda * ctx = (ggml_backend_buffer_context_cuda *)buffer->context;
8256
 
8257
  if (tensor->view_src != NULL && tensor->view_offs == 0) {
8258
- assert(tensor->view_src->buffer->backend == buffer->backend);
8259
  tensor->backend = tensor->view_src->backend;
8260
  tensor->extra = tensor->view_src->extra;
8261
  return;
@@ -8263,7 +8852,7 @@ static void ggml_backend_cuda_buffer_init_tensor(ggml_backend_buffer_t buffer, g
8263
 
8264
  ggml_tensor_extra_gpu * extra = ctx->ggml_cuda_alloc_temp_tensor_extra();
8265
 
8266
- extra->data_device[g_main_device] = tensor->data;
8267
 
8268
  tensor->backend = GGML_BACKEND_GPU;
8269
  tensor->extra = extra;
@@ -8275,64 +8864,208 @@ static void ggml_backend_cuda_buffer_init_tensor(ggml_backend_buffer_t buffer, g
8275
  int64_t nrows_split = row_high - row_low;
8276
 
8277
  size_t original_size = ggml_nbytes_split(tensor, nrows_split);
8278
- size_t padded_size = ggml_backend_cuda_buffer_get_alloc_size(tensor->buffer, tensor);
8279
 
8280
  if (padded_size > original_size && tensor->view_src == nullptr) {
8281
- CUDA_CHECK(cudaMemsetAsync((char *)tensor->data + original_size, 0, padded_size - original_size, g_cudaStreams[g_main_device][0]));
8282
  }
8283
  }
8284
 
8285
  UNUSED(buffer);
8286
  }
8287
 
 
 
 
 
 
8288
  static struct ggml_backend_buffer_i cuda_backend_buffer_interface = {
8289
- /* .free_buffer = */ ggml_backend_cuda_buffer_free_buffer,
8290
- /* .get_base = */ ggml_backend_cuda_buffer_get_base,
8291
- /* .get_alloc_size = */ ggml_backend_cuda_buffer_get_alloc_size,
8292
- /* .init_tensor = */ ggml_backend_cuda_buffer_init_tensor,
8293
- /* .free_tensor = */ NULL,
 
 
8294
  };
8295
 
8296
- static ggml_backend_buffer_t ggml_backend_cuda_alloc_buffer(ggml_backend_t backend, size_t size) {
8297
- ggml_cuda_set_device(g_main_device);
8298
 
8299
- ggml_backend_buffer_context_cuda * ctx = new ggml_backend_buffer_context_cuda;
 
 
 
8300
 
8301
  size = std::max(size, (size_t)1); // cudaMalloc returns null for size 0
8302
 
8303
- ggml_cuda_set_device(g_main_device);
8304
- CUDA_CHECK(cudaMalloc(&ctx->device, size));
8305
 
8306
- return ggml_backend_buffer_init(backend, cuda_backend_buffer_interface, ctx, size);
 
 
8307
  }
8308
 
8309
- static size_t ggml_backend_cuda_get_alignment(ggml_backend_t backend) {
8310
  return 128;
 
 
 
 
8311
  UNUSED(backend);
8312
  }
8313
 
 
 
 
 
8314
  static void ggml_backend_cuda_set_tensor_async(ggml_backend_t backend, ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
 
 
 
8315
  GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor write out of bounds");
8316
  GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
8317
  GGML_ASSERT(tensor->backend == GGML_BACKEND_GPU);
8318
 
8319
- CUDA_CHECK(cudaMemcpyAsync((char *)tensor->data + offset, data, size, cudaMemcpyHostToDevice, g_cudaStreams[g_main_device][0]));
8320
-
8321
- UNUSED(backend);
8322
  }
8323
 
8324
  static void ggml_backend_cuda_get_tensor_async(ggml_backend_t backend, const ggml_tensor * tensor, void * data, size_t offset, size_t size) {
 
 
 
8325
  GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor read out of bounds");
8326
  GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
8327
  GGML_ASSERT(tensor->backend == GGML_BACKEND_GPU);
8328
 
8329
- CUDA_CHECK(cudaMemcpyAsync(data, (const char *)tensor->data + offset, size, cudaMemcpyDeviceToHost, g_cudaStreams[g_main_device][0]));
8330
-
8331
- UNUSED(backend);
8332
  }
8333
 
8334
  static void ggml_backend_cuda_synchronize(ggml_backend_t backend) {
8335
- CUDA_CHECK(cudaStreamSynchronize(g_cudaStreams[g_main_device][0]));
 
 
8336
 
8337
  UNUSED(backend);
8338
  }
@@ -8346,14 +9079,14 @@ static ggml_backend_graph_plan_t ggml_backend_cuda_graph_plan_create(ggml_backen
8346
  UNUSED(cgraph);
8347
  }
8348
 
8349
- [[noreturn]] static void ggml_backend_cuda_graph_plan_free(ggml_backend_t backend, ggml_backend_graph_plan_t plan) {
8350
  GGML_ASSERT(!"not implemented");
8351
 
8352
  UNUSED(backend);
8353
  UNUSED(plan);
8354
  }
8355
 
8356
- [[noreturn]] static void ggml_backend_cuda_graph_plan_compute(ggml_backend_t backend, ggml_backend_graph_plan_t plan) {
8357
  GGML_ASSERT(!"not implemented");
8358
 
8359
  UNUSED(backend);
@@ -8361,7 +9094,9 @@ static ggml_backend_graph_plan_t ggml_backend_cuda_graph_plan_create(ggml_backen
8361
  }
8362
 
8363
  static void ggml_backend_cuda_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
8364
- ggml_cuda_set_device(g_main_device);
 
 
8365
 
8366
  ggml_compute_params params = {};
8367
  params.type = GGML_TASK_COMPUTE;
@@ -8369,13 +9104,18 @@ static void ggml_backend_cuda_graph_compute(ggml_backend_t backend, ggml_cgraph
8369
  for (int i = 0; i < cgraph->n_nodes; i++) {
8370
  ggml_tensor * node = cgraph->nodes[i];
8371
 
8372
- if (node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE) {
8373
  continue;
8374
- }
8375
  assert(node->backend == GGML_BACKEND_GPU);
 
 
 
8376
  for (int j = 0; j < GGML_MAX_SRC; j++) {
8377
  if (node->src[j] != nullptr) {
8378
  assert(node->src[j]->backend == GGML_BACKEND_GPU);
 
 
8379
  }
8380
  }
8381
 
@@ -8412,27 +9152,98 @@ static void ggml_backend_cuda_graph_compute(ggml_backend_t backend, ggml_cgraph
8412
  UNUSED(backend);
8413
  }
8414
 
 
 
 
 
8415
  static ggml_backend_i cuda_backend_i = {
8416
- /* .get_name = */ ggml_backend_cuda_name,
8417
- /* .free = */ ggml_backend_cuda_free,
8418
- /* .alloc_buffer = */ ggml_backend_cuda_alloc_buffer,
8419
- /* .get_alignment = */ ggml_backend_cuda_get_alignment,
8420
- /* .set_tensor_async = */ ggml_backend_cuda_set_tensor_async,
8421
- /* .get_tensor_async = */ ggml_backend_cuda_get_tensor_async,
8422
- /* .synchronize = */ ggml_backend_cuda_synchronize,
8423
- /* .cpy_tensor_from = */ nullptr,
8424
- /* .cpy_tensor_to = */ nullptr,
8425
- /* .graph_plan_create = */ ggml_backend_cuda_graph_plan_create,
8426
- /* .graph_plan_free = */ ggml_backend_cuda_graph_plan_free,
8427
- /* .graph_plan_compute = */ ggml_backend_cuda_graph_plan_compute,
8428
- /* .graph_compute = */ ggml_backend_cuda_graph_compute,
8429
- /* .supports_op = */ nullptr,
8430
  };
8431
 
8432
- ggml_backend_t ggml_backend_cuda_init() {
8433
  ggml_init_cublas(); // TODO: remove from ggml.c
8434
 
8435
- ggml_backend_context_cuda * ctx = new ggml_backend_context_cuda;
 
 
 
 
8436
 
8437
  ggml_backend_t cuda_backend = new ggml_backend {
8438
  /* .interface = */ cuda_backend_i,
@@ -8441,3 +9252,25 @@ ggml_backend_t ggml_backend_cuda_init() {
8441
 
8442
  return cuda_backend;
8443
  }
 
 
 
 
1
  #include <algorithm>
 
2
  #include <cstddef>
3
  #include <cstdint>
4
+ #include <cinttypes>
5
+ #include <float.h>
6
  #include <limits>
7
  #include <stdint.h>
8
  #include <stdio.h>
 
70
  #define cudaOccupancyMaxPotentialBlockSize hipOccupancyMaxPotentialBlockSize
71
  #define cudaSetDevice hipSetDevice
72
  #define cudaStreamCreateWithFlags hipStreamCreateWithFlags
73
+ #define cudaStreamFireAndForget hipStreamFireAndForget
74
  #define cudaStreamNonBlocking hipStreamNonBlocking
75
  #define cudaStreamSynchronize hipStreamSynchronize
76
  #define cudaStreamWaitEvent(stream, event, flags) hipStreamWaitEvent(stream, event, flags)
 
192
  fprintf(stderr, "\nCUDA error %d at %s:%d: %s\n", err_, __FILE__, __LINE__, \
193
  cudaGetErrorString(err_)); \
194
  fprintf(stderr, "current device: %d\n", id); \
195
+ GGML_ASSERT(!"CUDA error"); \
196
  } \
197
  } while (0)
198
 
 
206
  fprintf(stderr, "\ncuBLAS error %d at %s:%d: %s\n", \
207
  err_, __FILE__, __LINE__, cublasGetStatusString(err_)); \
208
  fprintf(stderr, "current device: %d\n", id); \
209
+ GGML_ASSERT(!"cuBLAS error"); \
210
  } \
211
  } while (0)
212
  #else
 
218
  cudaGetDevice(&id); \
219
  fprintf(stderr, "\ncuBLAS error %d at %s:%d\n", err_, __FILE__, __LINE__); \
220
  fprintf(stderr, "current device: %d\n", id); \
221
+ GGML_ASSERT(!"cuBLAS error"); \
222
  } \
223
  } while (0)
224
  #endif // CUDART_VERSION >= 11
 
435
  #define WARP_SIZE 32
436
  #define MATRIX_ROW_PADDING 512 // last row of quant. matrices is a multiple of this to avoid out-of-bounds memory accesses
437
 
 
 
438
  #define CUDA_GELU_BLOCK_SIZE 256
439
  #define CUDA_SILU_BLOCK_SIZE 256
440
  #define CUDA_RELU_BLOCK_SIZE 256
 
443
  #define CUDA_SCALE_BLOCK_SIZE 256
444
  #define CUDA_CLAMP_BLOCK_SIZE 256
445
  #define CUDA_ROPE_BLOCK_SIZE 256
446
+ #define CUDA_SOFT_MAX_BLOCK_SIZE 1024
447
  #define CUDA_ALIBI_BLOCK_SIZE 32
448
  #define CUDA_DIAG_MASK_INF_BLOCK_SIZE 32
449
  #define CUDA_QUANTIZE_BLOCK_SIZE 256
 
502
 
503
  static cublasHandle_t g_cublas_handles[GGML_CUDA_MAX_DEVICES] = {nullptr};
504
 
505
+ static __device__ __forceinline__ float warp_reduce_sum(float x) {
506
+ #pragma unroll
507
+ for (int mask = 16; mask > 0; mask >>= 1) {
508
+ x += __shfl_xor_sync(0xffffffff, x, mask, 32);
 
509
  }
510
+ return x;
511
  }
512
 
513
+ static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) {
514
+ #pragma unroll
515
+ for (int mask = 16; mask > 0; mask >>= 1) {
516
+ a.x += __shfl_xor_sync(0xffffffff, a.x, mask, 32);
517
+ a.y += __shfl_xor_sync(0xffffffff, a.y, mask, 32);
518
+ }
519
+ return a;
520
+ }
521
 
522
+ static __device__ __forceinline__ float warp_reduce_max(float x) {
523
+ #pragma unroll
524
+ for (int mask = 16; mask > 0; mask >>= 1) {
525
+ x = fmaxf(x, __shfl_xor_sync(0xffffffff, x, mask, 32));
526
  }
527
+ return x;
528
  }
529
 
530
+ static __device__ __forceinline__ float op_repeat(const float a, const float b) {
531
+ return b;
532
+ }
533
 
534
+ static __device__ __forceinline__ float op_add(const float a, const float b) {
535
+ return a + b;
536
+ }
537
+
538
+ static __device__ __forceinline__ float op_mul(const float a, const float b) {
539
+ return a * b;
540
+ }
541
+
542
+ static __device__ __forceinline__ float op_div(const float a, const float b) {
543
+ return a / b;
544
+ }
545
+
546
+ template<float (*bin_op)(const float, const float), typename src0_t, typename src1_t, typename dst_t>
547
+ static __global__ void k_bin_bcast(const src0_t * src0, const src1_t * src1, dst_t * dst,
548
+ int ne0, int ne1, int ne2, int ne3,
549
+ int ne10, int ne11, int ne12, int ne13,
550
+ /*int s0, */ int s1, int s2, int s3,
551
+ /*int s10,*/ int s11, int s12, int s13) {
552
+ const int i0s = blockDim.x*blockIdx.x + threadIdx.x;
553
+ const int i1 = (blockDim.y*blockIdx.y + threadIdx.y);
554
+ const int i2 = (blockDim.z*blockIdx.z + threadIdx.z) / ne3;
555
+ const int i3 = (blockDim.z*blockIdx.z + threadIdx.z) % ne3;
556
+
557
+ if (i0s >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3) {
558
  return;
559
  }
560
+
561
+ const int i11 = i1 % ne11;
562
+ const int i12 = i2 % ne12;
563
+ const int i13 = i3 % ne13;
564
+
565
+ const size_t i_src0 = i3*s3 + i2*s2 + i1*s1;
566
+ const size_t i_src1 = i13*s13 + i12*s12 + i11*s11;
567
+ const size_t i_dst = i_src0;
568
+
569
+ const src0_t * src0_row = src0 + i_src0;
570
+ const src1_t * src1_row = src1 + i_src1;
571
+ dst_t * dst_row = dst + i_dst;
572
+
573
+ for (int i0 = i0s; i0 < ne0; i0 += blockDim.x*gridDim.x) {
574
+ const int i10 = i0 % ne10;
575
+ dst_row[i0] = (dst_t)bin_op(src0 ? (float)src0_row[i0] : 0.0f, (float)src1_row[i10]);
576
+ }
577
  }
578
 
579
+ template<float (*bin_op)(const float, const float), typename src0_t, typename src1_t, typename dst_t>
580
+ static __global__ void k_bin_bcast_unravel(const src0_t * src0, const src1_t * src1, dst_t * dst,
581
+ int ne0, int ne1, int ne2, int ne3,
582
+ int ne10, int ne11, int ne12, int ne13,
583
+ /*int s0, */ int s1, int s2, int s3,
584
+ /*int s10,*/ int s11, int s12, int s13) {
585
+
586
  const int i = blockDim.x*blockIdx.x + threadIdx.x;
587
 
588
+ const int i3 = i/(ne2*ne1*ne0);
589
+ const int i2 = (i/(ne1*ne0)) % ne2;
590
+ const int i1 = (i/ne0) % ne1;
591
+ const int i0 = i % ne0;
592
+
593
+ if (i0 >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3) {
594
  return;
595
  }
596
+
597
+ const int i11 = i1 % ne11;
598
+ const int i12 = i2 % ne12;
599
+ const int i13 = i3 % ne13;
600
+
601
+ const size_t i_src0 = i3*s3 + i2*s2 + i1*s1;
602
+ const size_t i_src1 = i13*s13 + i12*s12 + i11*s11;
603
+ const size_t i_dst = i_src0;
604
+
605
+ const src0_t * src0_row = src0 + i_src0;
606
+ const src1_t * src1_row = src1 + i_src1;
607
+ dst_t * dst_row = dst + i_dst;
608
+
609
+ const int i10 = i0 % ne10;
610
+ dst_row[i0] = (dst_t)bin_op(src0 ? (float)src0_row[i0] : 0.0f, (float)src1_row[i10]);
611
  }
612
 
613
  static __global__ void gelu_f32(const float * x, float * dst, const int k) {
 
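The warp_reduce_sum / warp_reduce_max helpers added above use the usual XOR-shuffle butterfly: five __shfl_xor_sync steps leave every lane of the warp holding the full reduction, which the later sum_rows and soft_max kernels rely on. A minimal standalone sketch of the same pattern (illustrative only, not part of this commit; assumes a 32-lane warp and a CUDA 9+ toolchain):

// warp_sum_demo.cu -- illustrative sketch, not part of the patch
#include <stdio.h>
#include <cuda_runtime.h>

// butterfly reduction: after 5 XOR-shuffle steps every lane holds the warp-wide sum
static __device__ __forceinline__ float warp_sum(float x) {
#pragma unroll
    for (int mask = 16; mask > 0; mask >>= 1) {
        x += __shfl_xor_sync(0xffffffff, x, mask, 32);
    }
    return x;
}

static __global__ void sum32(const float * x, float * out) {
    const float s = warp_sum(x[threadIdx.x]);
    if (threadIdx.x == 0) {
        *out = s; // every lane holds the sum; lane 0 writes it out
    }
}

int main() {
    float h[32];
    for (int i = 0; i < 32; ++i) h[i] = 1.0f;
    float * dx = NULL; float * dout = NULL; float hout = 0.0f;
    cudaMalloc(&dx, sizeof(h));
    cudaMalloc(&dout, sizeof(float));
    cudaMemcpy(dx, h, sizeof(h), cudaMemcpyHostToDevice);
    sum32<<<1, 32>>>(dx, dout);
    cudaMemcpy(&hout, dout, sizeof(float), cudaMemcpyDeviceToHost);
    printf("sum = %.1f\n", hout); // expected: 32.0
    cudaFree(dx); cudaFree(dout);
    return 0;
}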
650
  dst[i] = x[i] * x[i];
651
  }
652
 
 
 
 
 
653
  template <int block_size>
654
+ static __global__ void norm_f32(const float * x, float * dst, const int ncols, const float eps) {
655
  const int row = blockIdx.x*blockDim.y + threadIdx.y;
656
  const int tid = threadIdx.x;
657
 
 
 
658
  float2 mean_var = make_float2(0.f, 0.f);
659
 
660
  for (int col = tid; col < ncols; col += block_size) {
 
686
  }
687
  }
688
 
 
 
 
 
689
  template <int block_size>
690
  static __global__ void rms_norm_f32(const float * x, float * dst, const int ncols, const float eps) {
691
  const int row = blockIdx.x*blockDim.y + threadIdx.y;
 
4604
  cpy_1(cx + x_offset, cdst + dst_offset);
4605
  }
4606
 
4607
+ static __device__ void cpy_blck_f32_q8_0(const char * cxi, char * cdsti) {
4608
+ const float * xi = (const float *) cxi;
4609
+ block_q8_0 * dsti = (block_q8_0 *) cdsti;
4610
+
4611
+ float amax = 0.0f; // absolute max
4612
+
4613
+ for (int j = 0; j < QK8_0; j++) {
4614
+ const float v = xi[j];
4615
+ amax = fmaxf(amax, fabsf(v));
4616
+ }
4617
+
4618
+ const float d = amax / ((1 << 7) - 1);
4619
+ const float id = d ? 1.0f/d : 0.0f;
4620
+
4621
+ dsti->d = d;
4622
+
4623
+ for (int j = 0; j < QK8_0; ++j) {
4624
+ const float x0 = xi[j]*id;
4625
+
4626
+ dsti->qs[j] = roundf(x0);
4627
+ }
4628
+ }
4629
+
4630
+ static __device__ void cpy_blck_f32_q4_0(const char * cxi, char * cdsti) {
4631
+ const float * xi = (const float *) cxi;
4632
+ block_q4_0 * dsti = (block_q4_0 *) cdsti;
4633
+
4634
+ float amax = 0.0f;
4635
+ float vmax = 0.0f;
4636
+
4637
+ for (int j = 0; j < QK4_0; ++j) {
4638
+ const float v = xi[j];
4639
+ if (amax < fabsf(v)) {
4640
+ amax = fabsf(v);
4641
+ vmax = v;
4642
+ }
4643
+ }
4644
+
4645
+ const float d = vmax / -8;
4646
+ const float id = d ? 1.0f/d : 0.0f;
4647
+
4648
+ dsti->d = d;
4649
+
4650
+ for (int j = 0; j < QK4_0/2; ++j) {
4651
+ const float x0 = xi[0 + j]*id;
4652
+ const float x1 = xi[QK4_0/2 + j]*id;
4653
+
4654
+ const uint8_t xi0 = min(15, (int8_t)(x0 + 8.5f));
4655
+ const uint8_t xi1 = min(15, (int8_t)(x1 + 8.5f));
4656
+
4657
+ dsti->qs[j] = xi0;
4658
+ dsti->qs[j] |= xi1 << 4;
4659
+ }
4660
+ }
4661
+
4662
+ static __device__ void cpy_blck_f32_q4_1(const char * cxi, char * cdsti) {
4663
+ const float * xi = (const float *) cxi;
4664
+ block_q4_1 * dsti = (block_q4_1 *) cdsti;
4665
+
4666
+ float vmin = FLT_MAX;
4667
+ float vmax = -FLT_MAX;
4668
+
4669
+ for (int j = 0; j < QK4_1; ++j) {
4670
+ const float v = xi[j];
4671
+
4672
+ if (v < vmin) vmin = v;
4673
+ if (v > vmax) vmax = v;
4674
+ }
4675
+
4676
+ const float d = (vmax - vmin) / ((1 << 4) - 1);
4677
+ const float id = d ? 1.0f/d : 0.0f;
4678
+
4679
+ dsti->dm.x = d;
4680
+ dsti->dm.y = vmin;
4681
+
4682
+ for (int j = 0; j < QK4_1/2; ++j) {
4683
+ const float x0 = (xi[0 + j] - vmin)*id;
4684
+ const float x1 = (xi[QK4_1/2 + j] - vmin)*id;
4685
+
4686
+ const uint8_t xi0 = min(15, (int8_t)(x0 + 0.5f));
4687
+ const uint8_t xi1 = min(15, (int8_t)(x1 + 0.5f));
4688
+
4689
+ dsti->qs[j] = xi0;
4690
+ dsti->qs[j] |= xi1 << 4;
4691
+ }
4692
+ }
4693
+
4694
+ template <cpy_kernel_t cpy_blck, int qk>
4695
+ static __global__ void cpy_f32_q(const char * cx, char * cdst, const int ne,
4696
+ const int ne00, const int ne01, const int nb00, const int nb01, const int nb02,
4697
+ const int ne10, const int ne11, const int nb10, const int nb11, const int nb12) {
4698
+ const int i = (blockDim.x*blockIdx.x + threadIdx.x)*qk;
4699
+
4700
+ if (i >= ne) {
4701
+ return;
4702
+ }
4703
+
4704
+ const int i02 = i / (ne00*ne01);
4705
+ const int i01 = (i - i02*ne01*ne00) / ne00;
4706
+ const int i00 = (i - i02*ne01*ne00 - i01*ne00);
4707
+ const int x_offset = i00*nb00 + i01*nb01 + i02*nb02;
4708
+
4709
+ const int i12 = i / (ne10*ne11);
4710
+ const int i11 = (i - i12*ne10*ne11) / ne10;
4711
+ const int i10 = (i - i12*ne10*ne11 - i11*ne10)/qk;
4712
+ const int dst_offset = i10*nb10 + i11*nb11 + i12*nb12;
4713
+
4714
+ cpy_blck(cx + x_offset, cdst + dst_offset);
4715
+ }
4716
+
4717
  static __device__ float rope_yarn_ramp(const float low, const float high, const int i0) {
4718
  const float y = (i0 / 2 - low) / max(0.001f, high - low);
4719
  return 1.0f - min(1.0f, max(0.0f, y));
 
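cpy_blck_f32_q8_0 above quantizes each block of QK8_0 = 32 floats into one scale d = amax/127 plus 32 int8 values; the q4_0/q4_1 variants follow the same shape with 4-bit packing. A host-side reference of the q8_0 math (illustrative sketch only; the real block_q8_0 stores d as a half and is packed, which is glossed over here):

// q8_0 round-trip reference (not part of the patch)
#include <math.h>
#include <stdio.h>
#include <stdint.h>

#define QK8_0 32

struct block_q8_0_ref {
    float  d;          // per-block scale
    int8_t qs[QK8_0];  // quantized values
};

static void quantize_block_q8_0(const float * x, struct block_q8_0_ref * y) {
    float amax = 0.0f;                   // absolute max of the block
    for (int j = 0; j < QK8_0; ++j) {
        amax = fmaxf(amax, fabsf(x[j]));
    }
    const float d  = amax / 127.0f;      // maps [-amax, amax] onto [-127, 127]
    const float id = d ? 1.0f/d : 0.0f;
    y->d = d;
    for (int j = 0; j < QK8_0; ++j) {
        y->qs[j] = (int8_t) roundf(x[j]*id);
    }
}

int main() {
    float x[QK8_0];
    for (int j = 0; j < QK8_0; ++j) {
        x[j] = 0.1f*(j - 16);
    }
    struct block_q8_0_ref b;
    quantize_block_q8_0(x, &b);
    // dequantization is just q*d, so the round trip should be close to the input
    printf("x[5] = %f  reconstructed = %f\n", x[5], b.qs[5]*b.d);
    return 0;
}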
4774
 
4775
  template<typename T, bool has_pos>
4776
  static __global__ void rope_neox(
4777
+ const T * x, T * dst, int ncols, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows,
4778
+ float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, float inv_ndims
4779
  ) {
4780
  const int col = 2*(blockDim.y*blockIdx.y + threadIdx.y);
4781
 
 
4784
  }
4785
 
4786
  const int row = blockDim.x*blockIdx.x + threadIdx.x;
4787
+ const int ib = col / n_dims;
4788
+ const int ic = col % n_dims;
4789
+
4790
+ const int i = row*ncols + ib*n_dims + ic/2;
4791
  const int i2 = row/p_delta_rows;
4792
 
4793
+ float cur_rot = inv_ndims * ic - ib;
 
4794
 
4795
  const int p = has_pos ? pos[i2] : 0;
4796
+ const float theta_base = p*freq_scale*powf(theta_scale, col/2.0f);
4797
 
4798
  float cos_theta, sin_theta;
4799
  rope_yarn(theta_base, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor, &cos_theta, &sin_theta);
4800
 
4801
  const float x0 = x[i + 0];
4802
+ const float x1 = x[i + n_dims/2];
4803
 
4804
+ dst[i + 0] = x0*cos_theta - x1*sin_theta;
4805
+ dst[i + n_dims/2] = x0*sin_theta + x1*cos_theta;
4806
  }
4807
 
4808
  static __global__ void rope_glm_f32(
 
4868
  dst[i] = col * m_k + x[i];
4869
  }
4870
 
4871
+ static __global__ void k_sum_rows_f32(const float * x, float * dst, const int ncols) {
4872
+ const int row = blockIdx.y;
4873
+ const int col = threadIdx.x;
4874
+
4875
+ float sum = 0.0f;
4876
+ for (int i = col; i < ncols; i += blockDim.x) {
4877
+ sum += x[row * ncols + i];
4878
+ }
4879
+
4880
+ sum = warp_reduce_sum(sum);
4881
+
4882
+ if (col == 0) {
4883
+ dst[row] = sum;
4884
+ }
4885
+ }
4886
+
4887
+ template<typename T>
4888
+ static inline __device__ void swap(T & a, T & b) {
4889
+ T tmp = a;
4890
+ a = b;
4891
+ b = tmp;
4892
+ }
4893
+
4894
+ template<ggml_sort_order order>
4895
+ static __global__ void k_argsort_f32_i32(const float * x, int * dst, const int ncols) {
4896
+ // bitonic sort
4897
+ int col = threadIdx.x;
4898
+ int row = blockIdx.y;
4899
+
4900
+ if (col >= ncols) return;
4901
+
4902
+ const float * x_row = x + row * ncols;
4903
+ int * dst_row = dst + row * ncols;
4904
+
4905
+ // initialize indices
4906
+ if (col < ncols) {
4907
+ dst_row[col] = col;
4908
+ }
4909
+ __syncthreads();
4910
+
4911
+ for (int k = 2; k <= ncols; k *= 2) {
4912
+ for (int j = k / 2; j > 0; j /= 2) {
4913
+ int ixj = col ^ j;
4914
+ if (ixj > col) {
4915
+ if ((col & k) == 0) {
4916
+ if (order == GGML_SORT_ASC ? x_row[dst_row[col]] > x_row[dst_row[ixj]] : x_row[dst_row[col]] < x_row[dst_row[ixj]]) {
4917
+ swap(dst_row[col], dst_row[ixj]);
4918
+ }
4919
+ } else {
4920
+ if (order == GGML_SORT_ASC ? x_row[dst_row[col]] < x_row[dst_row[ixj]] : x_row[dst_row[col]] > x_row[dst_row[ixj]]) {
4921
+ swap(dst_row[col], dst_row[ixj]);
4922
+ }
4923
+ }
4924
+ }
4925
+ __syncthreads();
4926
+ }
4927
+ }
4928
+ }
4929
+
4930
  static __global__ void diag_mask_inf_f32(const float * x, float * dst, const int ncols, const int rows_per_channel, const int n_past) {
4931
  const int col = blockDim.y*blockIdx.y + threadIdx.y;
4932
  const int row = blockDim.x*blockIdx.x + threadIdx.x;
 
4936
  }
4937
 
4938
  const int i = row*ncols + col;
4939
+ //dst[i] = col > (n_past + row % rows_per_channel) ? -INFINITY : x[i];
4940
+ //dst[i] = x[i] - (col > n_past + row % rows_per_channel) * INT_MAX; // equivalent within rounding error but slightly faster on GPU
4941
+ dst[i] = x[i] - (col > n_past + row % rows_per_channel) * FLT_MAX;
4942
  }
4943
 
4944
+ static __global__ void soft_max_f32(const float * x, const float * y, float * dst, const int ncols, const int nrows_y, const float scale) {
4945
+ const int tid = threadIdx.x;
4946
+ const int rowx = blockIdx.x;
4947
+ const int rowy = rowx % nrows_y; // broadcast the mask (y) in the row dimension
4948
+
4949
+ const int block_size = blockDim.x;
4950
+
4951
+ const int warp_id = threadIdx.x / WARP_SIZE;
4952
+ const int lane_id = threadIdx.x % WARP_SIZE;
4953
+
4954
+ __shared__ float buf[CUDA_SOFT_MAX_BLOCK_SIZE/WARP_SIZE];
4955
 
4956
  float max_val = -INFINITY;
4957
 
4958
  for (int col = tid; col < ncols; col += block_size) {
4959
+ const int ix = rowx*ncols + col;
4960
+ const int iy = rowy*ncols + col;
4961
+ max_val = max(max_val, x[ix]*scale + (y ? y[iy] : 0.0f));
4962
  }
4963
 
4964
  // find the max value in the block
4965
+ max_val = warp_reduce_max(max_val);
4966
+ if (block_size > WARP_SIZE) {
4967
+ if (warp_id == 0) {
4968
+ buf[lane_id] = -INFINITY;
4969
+ }
4970
+ __syncthreads();
4971
+
4972
+ if (lane_id == 0) {
4973
+ buf[warp_id] = max_val;
4974
+ }
4975
+ __syncthreads();
4976
+
4977
+ max_val = buf[lane_id];
4978
+ max_val = warp_reduce_max(max_val);
4979
  }
4980
 
4981
  float tmp = 0.f;
4982
 
4983
  for (int col = tid; col < ncols; col += block_size) {
4984
+ const int ix = rowx*ncols + col;
4985
+ const int iy = rowy*ncols + col;
4986
+ const float val = expf((x[ix]*scale + (y ? y[iy] : 0.0f)) - max_val);
4987
  tmp += val;
4988
+ dst[ix] = val;
4989
  }
4990
 
4991
+ // find the sum of exps in the block
4992
+ tmp = warp_reduce_sum(tmp);
4993
+ if (block_size > WARP_SIZE) {
4994
+ if (warp_id == 0) {
4995
+ buf[lane_id] = 0.f;
4996
+ }
4997
+ __syncthreads();
4998
+
4999
+ if (lane_id == 0) {
5000
+ buf[warp_id] = tmp;
5001
+ }
5002
+ __syncthreads();
5003
+
5004
+ tmp = buf[lane_id];
5005
+ tmp = warp_reduce_sum(tmp);
5006
  }
5007
 
5008
  const float inv_tmp = 1.f / tmp;
5009
 
5010
  for (int col = tid; col < ncols; col += block_size) {
5011
+ const int i = rowx*ncols + col;
5012
  dst[i] *= inv_tmp;
5013
  }
5014
  }
 
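The soft_max_f32 kernel above fuses the scale and the optional additive mask into one numerically stable softmax per row: reduce max(x*scale + mask), exponentiate against that max, then divide by the row sum, with the warp_reduce helpers plus a small shared-memory buffer handling the cross-warp reductions. The per-row math as a host-side reference (illustrative sketch, single row, not part of the commit):

// stable softmax with scale and optional mask (reference only)
#include <math.h>
#include <stdio.h>

static void soft_max_ref(const float * x, const float * mask, float * dst, int n, float scale) {
    float max_val = -INFINITY;
    for (int i = 0; i < n; ++i) {
        max_val = fmaxf(max_val, x[i]*scale + (mask ? mask[i] : 0.0f));
    }
    float sum = 0.0f;
    for (int i = 0; i < n; ++i) {
        dst[i] = expf(x[i]*scale + (mask ? mask[i] : 0.0f) - max_val); // exponent <= 0, no overflow
        sum += dst[i];
    }
    for (int i = 0; i < n; ++i) {
        dst[i] /= sum;
    }
}

int main() {
    const float x[4]    = { 1.0f, 2.0f, 3.0f, 4.0f };
    const float mask[4] = { 0.0f, 0.0f, -INFINITY, 0.0f }; // masked position ends up with probability 0
    float dst[4];
    soft_max_ref(x, mask, dst, 4, 1.0f);
    for (int i = 0; i < 4; ++i) {
        printf("%f ", dst[i]);
    }
    printf("\n");
    return 0;
}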
5060
  k_get_rows<qk, qr, dq><<<block_nums, block_dims, 0, stream>>>(x, y, dst, ncols);
5061
  }
5062
 
5063
+ template<float (*bin_op)(const float, const float)>
5064
+ struct bin_bcast_cuda {
5065
+ template<typename src0_t, typename src1_t, typename dst_t>
5066
+ void operator()(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst,
5067
+ const src0_t * src0_dd, const src1_t * src1_dd, dst_t * dst_dd,
5068
+ cudaStream_t stream) {
5069
+
5070
+ GGML_TENSOR_BINARY_OP_LOCALS
5071
+
5072
+
5073
+ int nr0 = ne10/ne0;
5074
+ int nr1 = ne11/ne1;
5075
+ int nr2 = ne12/ne2;
5076
+ int nr3 = ne13/ne3;
5077
+
5078
+ int nr[4] = { nr0, nr1, nr2, nr3 };
5079
+
5080
+ // collapse dimensions until first broadcast dimension
5081
+ int64_t cne0[] = {ne0, ne1, ne2, ne3};
5082
+ int64_t cne1[] = {ne10, ne11, ne12, ne13};
5083
+ size_t cnb0[] = {nb0, nb1, nb2, nb3};
5084
+ size_t cnb1[] = {nb10, nb11, nb12, nb13};
5085
+ auto collapse = [](int64_t cne[]) {
5086
+ cne[0] *= cne[1];
5087
+ cne[1] = cne[2];
5088
+ cne[2] = cne[3];
5089
+ cne[3] = 1;
5090
+ };
5091
+
5092
+ auto collapse_nb = [](size_t cnb[], int64_t cne[]) {
5093
+ cnb[1] *= cne[1];
5094
+ cnb[2] *= cne[2];
5095
+ cnb[3] *= cne[3];
5096
+ };
5097
+
5098
+ for (int i = 0; i < 4; i++) {
5099
+ if (nr[i] != 1) {
5100
+ break;
5101
+ }
5102
+ if (i > 0) {
5103
+ collapse_nb(cnb0, cne0);
5104
+ collapse_nb(cnb1, cne1);
5105
+ collapse(cne0);
5106
+ collapse(cne1);
5107
+ }
5108
+ }
5109
+ {
5110
+ int64_t ne0 = cne0[0];
5111
+ int64_t ne1 = cne0[1];
5112
+ int64_t ne2 = cne0[2];
5113
+ int64_t ne3 = cne0[3];
5114
+
5115
+ int64_t ne10 = cne1[0];
5116
+ int64_t ne11 = cne1[1];
5117
+ int64_t ne12 = cne1[2];
5118
+ int64_t ne13 = cne1[3];
5119
+
5120
+ //size_t nb0 = cnb0[0];
5121
+ size_t nb1 = cnb0[1];
5122
+ size_t nb2 = cnb0[2];
5123
+ size_t nb3 = cnb0[3];
5124
+
5125
+ //size_t nb10 = cnb1[0];
5126
+ size_t nb11 = cnb1[1];
5127
+ size_t nb12 = cnb1[2];
5128
+ size_t nb13 = cnb1[3];
5129
+
5130
+ //size_t s0 = nb0 / sizeof(src1_t);
5131
+ size_t s1 = nb1 / sizeof(src1_t);
5132
+ size_t s2 = nb2 / sizeof(src1_t);
5133
+ size_t s3 = nb3 / sizeof(src1_t);
5134
+
5135
+ //size_t s10 = nb10 / sizeof(src1_t);
5136
+ size_t s11 = nb11 / sizeof(src1_t);
5137
+ size_t s12 = nb12 / sizeof(src1_t);
5138
+ size_t s13 = nb13 / sizeof(src1_t);
5139
+
5140
+
5141
+ const int block_size = 128;
5142
+
5143
+ int64_t hne0 = std::max(ne0/2LL, 1LL);
5144
+
5145
+ dim3 block_dims;
5146
+ block_dims.x = std::min<unsigned int>(hne0, block_size);
5147
+ block_dims.y = std::min<unsigned int>(ne1, block_size / block_dims.x);
5148
+ block_dims.z = std::min(std::min<unsigned int>(ne2*ne3, block_size / block_dims.x / block_dims.y), 64U);
5149
+
5150
+ dim3 block_nums(
5151
+ (hne0 + block_dims.x - 1) / block_dims.x,
5152
+ (ne1 + block_dims.y - 1) / block_dims.y,
5153
+ (ne2*ne3 + block_dims.z - 1) / block_dims.z
5154
+ );
5155
 
5156
+ if (block_nums.z > 65535) {
5157
+ // this is the maximum number of blocks in z direction, fallback to 1D grid kernel
5158
+ int block_num = (ne0*ne1*ne2*ne3 + block_size - 1) / block_size;
5159
+ k_bin_bcast_unravel<bin_op><<<block_num, block_size, 0, stream>>>(
5160
+ src0_dd, src1_dd, dst_dd,
5161
+ ne0, ne1, ne2, ne3,
5162
+ ne10, ne11, ne12, ne13,
5163
+ /* s0, */ s1, s2, s3,
5164
+ /* s10, */ s11, s12, s13);
5165
+ } else {
5166
+ k_bin_bcast<bin_op><<<block_nums, block_dims, 0, stream>>>(
5167
+ src0_dd, src1_dd, dst_dd,
5168
+ ne0, ne1, ne2, ne3,
5169
+ ne10, ne11, ne12, ne13,
5170
+ /* s0, */ s1, s2, s3,
5171
+ /* s10, */ s11, s12, s13);
5172
+ }
5173
+ }
5174
+ }
5175
+ };
5176
 
5177
  static void gelu_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
5178
  const int num_blocks = (k + CUDA_GELU_BLOCK_SIZE - 1) / CUDA_GELU_BLOCK_SIZE;
 
5194
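The k_bin_bcast kernels together with the bin_bcast_cuda launcher above implement ggml-style broadcasting for add/mul/div/repeat: every dst coordinate (i0,i1,i2,i3) reads src1 at (i0%ne10, i1%ne11, i2%ne12, i3%ne13), and the launcher first collapses leading non-broadcast dimensions so the launch grid stays small. The indexing rule in a 2-D host-side sketch (illustrative only, not part of the commit):

// broadcast add, 2-D reference (not part of the patch)
#include <stdio.h>

static void add_bcast_2d(const float * a, const float * b, float * dst,
                         int ne0, int ne1,      // shape of a and dst
                         int ne10, int ne11) {  // shape of b; each dim divides a's
    for (int i1 = 0; i1 < ne1; ++i1) {
        for (int i0 = 0; i0 < ne0; ++i0) {
            dst[i1*ne0 + i0] = a[i1*ne0 + i0] + b[(i1 % ne11)*ne10 + (i0 % ne10)];
        }
    }
}

int main() {
    const float a[2*4] = { 0, 1, 2, 3, 4, 5, 6, 7 };
    const float b[1*4] = { 10, 20, 30, 40 };   // one row, broadcast over both rows of a
    float dst[2*4];
    add_bcast_2d(a, b, dst, 4, 2, 4, 1);
    for (int i = 0; i < 8; ++i) {
        printf("%g ", dst[i]);                 // 10 21 32 43 14 25 36 47
    }
    printf("\n");
    return 0;
}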
  sqr_f32<<<num_blocks, CUDA_SQR_BLOCK_SIZE, 0, stream>>>(x, dst, k);
5195
  }
5196
 
5197
+ static void norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float eps, cudaStream_t stream) {
5198
  GGML_ASSERT(ncols % WARP_SIZE == 0);
5199
  if (ncols < 1024) {
5200
  const dim3 block_dims(WARP_SIZE, 1, 1);
5201
+ norm_f32<WARP_SIZE><<<nrows, block_dims, 0, stream>>>(x, dst, ncols, eps);
5202
  } else {
5203
  const dim3 block_dims(1024, 1, 1);
5204
+ norm_f32<1024><<<nrows, block_dims, 0, stream>>>(x, dst, ncols, eps);
5205
  }
5206
  }
5207
 
 
5223
  quantize_q8_1<<<num_blocks, block_size, 0, stream>>>(x, vy, kx, kx_padded);
5224
  }
5225
 
5226
+ template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
5227
+ static void dequantize_block_cuda(const void * __restrict__ vx, dst_t * __restrict__ y, const int k, cudaStream_t stream) {
 
 
 
 
5228
  const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
5229
+ dequantize_block<qk, qr, dequantize_kernel><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
5230
  }
5231
 
5232
  template<typename dst_t>
 
5275
  #endif
5276
  }
5277
 
5278
+ static to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
5279
+ switch (type) {
5280
+ case GGML_TYPE_Q4_0:
5281
+ return dequantize_block_cuda<QK4_0, QR4_0, dequantize_q4_0>;
5282
+ case GGML_TYPE_Q4_1:
5283
+ return dequantize_block_cuda<QK4_1, QR4_1, dequantize_q4_1>;
5284
+ case GGML_TYPE_Q5_0:
5285
+ return dequantize_block_cuda<QK5_0, QR5_0, dequantize_q5_0>;
5286
+ case GGML_TYPE_Q5_1:
5287
+ return dequantize_block_cuda<QK5_1, QR5_1, dequantize_q5_1>;
5288
+ case GGML_TYPE_Q8_0:
5289
+ return dequantize_block_cuda<QK8_0, QR8_0, dequantize_q8_0>;
5290
+ case GGML_TYPE_Q2_K:
5291
+ return dequantize_row_q2_K_cuda;
5292
+ case GGML_TYPE_Q3_K:
5293
+ return dequantize_row_q3_K_cuda;
5294
+ case GGML_TYPE_Q4_K:
5295
+ return dequantize_row_q4_K_cuda;
5296
+ case GGML_TYPE_Q5_K:
5297
+ return dequantize_row_q5_K_cuda;
5298
+ case GGML_TYPE_Q6_K:
5299
+ return dequantize_row_q6_K_cuda;
5300
+ case GGML_TYPE_F32:
5301
+ return dequantize_block_cuda<1, 1, convert_f32>;
5302
+ default:
5303
+ return nullptr;
5304
+ }
5305
+ }
5306
+
5307
+ static to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
5308
+ switch (type) {
5309
+ case GGML_TYPE_Q4_0:
5310
+ return dequantize_block_cuda<QK4_0, QR4_0, dequantize_q4_0>;
5311
+ case GGML_TYPE_Q4_1:
5312
+ return dequantize_block_cuda<QK4_1, QR4_1, dequantize_q4_1>;
5313
+ case GGML_TYPE_Q5_0:
5314
+ return dequantize_block_cuda<QK5_0, QR5_0, dequantize_q5_0>;
5315
+ case GGML_TYPE_Q5_1:
5316
+ return dequantize_block_cuda<QK5_1, QR5_1, dequantize_q5_1>;
5317
+ case GGML_TYPE_Q8_0:
5318
+ return dequantize_block_cuda<QK8_0, QR8_0, dequantize_q8_0>;
5319
+ case GGML_TYPE_Q2_K:
5320
+ return dequantize_row_q2_K_cuda;
5321
+ case GGML_TYPE_Q3_K:
5322
+ return dequantize_row_q3_K_cuda;
5323
+ case GGML_TYPE_Q4_K:
5324
+ return dequantize_row_q4_K_cuda;
5325
+ case GGML_TYPE_Q5_K:
5326
+ return dequantize_row_q5_K_cuda;
5327
+ case GGML_TYPE_Q6_K:
5328
+ return dequantize_row_q6_K_cuda;
5329
+ case GGML_TYPE_F16:
5330
+ return dequantize_block_cuda<1, 1, convert_f16>;
5331
+ default:
5332
+ return nullptr;
5333
+ }
5334
+ }
5335
+
5336
  static void dequantize_mul_mat_vec_q4_0_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
5337
  GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
5338
  const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
 
5421
  dequantize_mul_mat_vec_q6_k<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
5422
  }
5423
 
5424
+ static void convert_mul_mat_vec_f16_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
5425
+ GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
5426
  const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
5427
  const dim3 block_nums(block_num_y, 1, 1);
5428
  const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
5429
+ dequantize_mul_mat_vec<1, 1, convert_f16>
5430
+ <<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
5431
+ }
5432
+
5433
+ static void mul_mat_vec_q4_0_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
5434
+ GGML_ASSERT(ncols % QK4_0 == 0);
5435
+ const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
5436
+ const dim3 block_nums(block_num_y, 1, 1);
5437
+ const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
5438
+ mul_mat_vec_q<QK4_0, QI4_0, block_q4_0, VDR_Q4_0_Q8_1_MMVQ, vec_dot_q4_0_q8_1>
5439
+ <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
5440
  }
5441
 
5442
  static void mul_mat_vec_q4_1_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
 
5520
  <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
5521
  }
5522
 
 
 
 
 
 
5523
  static void ggml_mul_mat_q4_0_q8_1_cuda(
5524
  const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x,
5525
  const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
 
6012
  (cx, cdst, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12);
6013
  }
6014
 
6015
+ static void ggml_cpy_f32_q8_0_cuda(
6016
+ const char * cx, char * cdst, const int ne,
6017
+ const int ne00, const int ne01, const int nb00, const int nb01, const int nb02,
6018
+ const int ne10, const int ne11, const int nb10, const int nb11, const int nb12, cudaStream_t stream) {
6019
+
6020
+ GGML_ASSERT(ne % QK8_0 == 0);
6021
+ const int num_blocks = ne / QK8_0;
6022
+ cpy_f32_q<cpy_blck_f32_q8_0, QK8_0><<<num_blocks, 1, 0, stream>>>
6023
+ (cx, cdst, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12);
6024
+ }
6025
+
6026
+ static void ggml_cpy_f32_q4_0_cuda(
6027
+ const char * cx, char * cdst, const int ne,
6028
+ const int ne00, const int ne01, const int nb00, const int nb01, const int nb02,
6029
+ const int ne10, const int ne11, const int nb10, const int nb11, const int nb12, cudaStream_t stream) {
6030
+
6031
+ GGML_ASSERT(ne % QK4_0 == 0);
6032
+ const int num_blocks = ne / QK4_0;
6033
+ cpy_f32_q<cpy_blck_f32_q4_0, QK4_0><<<num_blocks, 1, 0, stream>>>
6034
+ (cx, cdst, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12);
6035
+ }
6036
+
6037
+ static void ggml_cpy_f32_q4_1_cuda(
6038
+ const char * cx, char * cdst, const int ne,
6039
+ const int ne00, const int ne01, const int nb00, const int nb01, const int nb02,
6040
+ const int ne10, const int ne11, const int nb10, const int nb11, const int nb12, cudaStream_t stream) {
6041
+
6042
+ GGML_ASSERT(ne % QK4_1 == 0);
6043
+ const int num_blocks = ne / QK4_1;
6044
+ cpy_f32_q<cpy_blck_f32_q4_1, QK4_1><<<num_blocks, 1, 0, stream>>>
6045
+ (cx, cdst, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12);
6046
+ }
6047
+
6048
  static void ggml_cpy_f16_f16_cuda(
6049
  const char * cx, char * cdst, const int ne,
6050
  const int ne00, const int ne01, const int nb00, const int nb01, const int nb02,
 
6087
 
6088
  template<typename T>
6089
  static void rope_neox_cuda(
6090
+ const T * x, T * dst, int ncols, int n_dims, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows,
6091
  float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, cudaStream_t stream
6092
  ) {
6093
  GGML_ASSERT(ncols % 2 == 0);
6094
  const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
6095
  const int num_blocks_x = (ncols + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
6096
  const dim3 block_nums(nrows, num_blocks_x, 1);
6097
+
6098
+ const float theta_scale = powf(freq_base, -2.0f/n_dims);
6099
+ const float inv_ndims = -1.0f / n_dims;
6100
+
6101
  if (pos == nullptr) {
6102
  rope_neox<T, false><<<block_nums, block_dims, 0, stream>>>(
6103
+ x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
6104
+ theta_scale, inv_ndims
6105
  );
6106
  } else {
6107
  rope_neox<T, true><<<block_nums, block_dims, 0, stream>>>(
6108
+ x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
6109
+ theta_scale, inv_ndims
6110
  );
6111
  }
6112
  }
 
6131
  alibi_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, k_rows, n_heads_log2_floor, m0, m1);
6132
  }
6133
 
6134
+ static void sum_rows_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
6135
+ const dim3 block_dims(WARP_SIZE, 1, 1);
6136
+ const dim3 block_nums(1, nrows, 1);
6137
+ k_sum_rows_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols);
6138
+ }
6139
+
6140
+ static void argsort_f32_i32_cuda(const float * x, int * dst, const int ncols, const int nrows, ggml_sort_order order, cudaStream_t stream) {
6141
+ // bitonic sort requires ncols to be power of 2
6142
+ GGML_ASSERT((ncols & (ncols - 1)) == 0);
6143
+
6144
+ const dim3 block_dims(ncols, 1, 1);
6145
+ const dim3 block_nums(1, nrows, 1);
6146
+ if (order == GGML_SORT_ASC) {
6147
+ k_argsort_f32_i32<GGML_SORT_ASC><<<block_nums, block_dims, 0, stream>>>(x, dst, ncols);
6148
+ } else if (order == GGML_SORT_DESC) {
6149
+ k_argsort_f32_i32<GGML_SORT_DESC><<<block_nums, block_dims, 0, stream>>>(x, dst, ncols);
6150
+ } else {
6151
+ GGML_ASSERT(false);
6152
+ }
6153
+ }
6154
+
6155
  static void diag_mask_inf_f32_cuda(const float * x, float * dst, const int ncols_x, const int nrows_x, const int rows_per_channel, const int n_past, cudaStream_t stream) {
6156
  const dim3 block_dims(1, CUDA_DIAG_MASK_INF_BLOCK_SIZE, 1);
6157
  const int block_num_x = (ncols_x + CUDA_DIAG_MASK_INF_BLOCK_SIZE - 1) / CUDA_DIAG_MASK_INF_BLOCK_SIZE;
 
6159
  diag_mask_inf_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols_x, rows_per_channel, n_past);
6160
  }
6161
 
6162
+ static void soft_max_f32_cuda(const float * x, const float * y, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, cudaStream_t stream) {
6163
+ int nth = WARP_SIZE;
6164
+ while (nth < ncols_x && nth < CUDA_SOFT_MAX_BLOCK_SIZE) nth *= 2;
6165
+ const dim3 block_dims(nth, 1, 1);
6166
  const dim3 block_nums(nrows_x, 1, 1);
6167
+ soft_max_f32<<<block_nums, block_dims, 0, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
6168
  }
6169
 
6170
  static void im2col_f32_f16_cuda(const float * x, half * dst,
 
6244
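k_argsort_f32_i32 (added earlier in this file) sorts row indices with a bitonic compare-exchange network, which is why argsort_f32_i32_cuda above asserts that ncols is a power of two; each (k, j) phase touches disjoint pairs, so one thread per column plus __syncthreads between phases is sufficient. The same network run serially on the host (illustrative sketch, not part of the commit):

// bitonic argsort reference, ascending order, n must be a power of two
#include <stdio.h>

static void argsort_bitonic_asc(const float * x, int * idx, int n) {
    for (int i = 0; i < n; ++i) idx[i] = i;
    for (int k = 2; k <= n; k *= 2) {          // size of the bitonic subsequences
        for (int j = k / 2; j > 0; j /= 2) {   // compare-exchange distance
            for (int i = 0; i < n; ++i) {
                const int ixj = i ^ j;
                if (ixj > i) {
                    const int up = (i & k) == 0;  // sort direction of this subsequence
                    const int need_swap = up ? x[idx[i]] > x[idx[ixj]]
                                             : x[idx[i]] < x[idx[ixj]];
                    if (need_swap) {
                        const int t = idx[i]; idx[i] = idx[ixj]; idx[ixj] = t;
                    }
                }
            }
        }
    }
}

int main() {
    const float x[8] = { 3.f, 1.f, 7.f, 0.f, 5.f, 2.f, 6.f, 4.f };
    int idx[8];
    argsort_bitonic_asc(x, idx, 8);
    for (int i = 0; i < 8; ++i) {
        printf("%d ", idx[i]);                 // expected: 3 1 5 0 7 4 6 2
    }
    printf("\n");
    return 0;
}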
  return ptr;
6245
  }
6246
  #ifdef DEBUG_CUDA_MALLOC
6247
+ fprintf(stderr, "%s: %d buffers, max_size = %u MB, tot_size = %u MB, requested %u MB\n", __func__, nnz,
6248
  (uint32_t)(max_size/1024/1024), (uint32_t)(tot_size/1024/1024), (uint32_t)(size/1024/1024));
6249
  #endif
6250
  void * ptr;
 
6382
  // The allocation error can be bypassed. A null ptr will assigned out of this function.
6383
  // This can fixed the OOM error in WSL.
6384
  cudaGetLastError();
6385
+ fprintf(stderr, "WARNING: failed to allocate %.2f MB of pinned memory: %s\n",
6386
  size/1024.0/1024.0, cudaGetErrorString(err));
6387
  return nullptr;
6388
  }
 
6441
  }
6442
  }
6443
 
 
 
 
 
6444
  static void ggml_cuda_op_get_rows(
6445
  const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
6446
  const float * src0_d, const float * src1_d, float * dst_d, const cudaStream_t & stream) {
 
6485
  }
6486
  }
6487
 
6488
+ template<class op>
6489
+ inline void ggml_cuda_op_bin_bcast(
6490
  const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
6491
  const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
6492
 
 
 
 
6493
  GGML_ASSERT(src1->type == GGML_TYPE_F32);
6494
 
 
 
 
6495
  if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
6496
+ op()(src0, src1, dst, src0_dd, src1_dd, dst_dd, main_stream);
6497
  } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
6498
+ op()(src0, src1, dst, (const half *) src0_dd, src1_dd, (half *) dst_dd, main_stream);
6499
  } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F32) {
6500
+ op()(src0, src1, dst, (const half *) src0_dd, src1_dd, dst_dd, main_stream);
6501
  } else {
6502
+ fprintf(stderr, "%s: unsupported types: dst: %s, src0: %s, src1: %s\n", __func__,
6503
+ ggml_type_name(dst->type), ggml_type_name(src0->type), ggml_type_name(src1->type));
6504
  GGML_ASSERT(false);
6505
  }
6506
+ }
6507
+
6508
+ static void ggml_cuda_op_repeat(
6509
+ const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
6510
+ const float * src0_d, const float * src1_d, float * dst_d, const cudaStream_t & main_stream) {
6511
+
6512
+ ggml_cuda_op_bin_bcast<bin_bcast_cuda<op_repeat>>(dst, src0, dst, nullptr, src0_d, dst_d, main_stream);
6513
 
6514
  (void) src1;
6515
+ (void) src1_d;
6516
  }
6517
 
6518
+ inline void ggml_cuda_op_add(
6519
  const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
6520
  const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
6521
 
6522
+ ggml_cuda_op_bin_bcast<bin_bcast_cuda<op_add>>(src0, src1, dst, src0_dd, src1_dd, dst_dd, main_stream);
6523
+ }
 
6524
 
6525
+ inline void ggml_cuda_op_mul(
6526
+ const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
6527
+ const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
6528
 
6529
+ ggml_cuda_op_bin_bcast<bin_bcast_cuda<op_mul>>(src0, src1, dst, src0_dd, src1_dd, dst_dd, main_stream);
6530
+ }
6531
 
6532
+ inline void ggml_cuda_op_div(
6533
+ const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
6534
+ const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
6535
+
6536
+ ggml_cuda_op_bin_bcast<bin_bcast_cuda<op_div>>(src0, src1, dst, src0_dd, src1_dd, dst_dd, main_stream);
6537
  }
6538
 
6539
  inline void ggml_cuda_op_gelu(
 
6602
  const int64_t ne00 = src0->ne[0];
6603
  const int64_t nrows = ggml_nrows(src0);
6604
 
6605
+ float eps;
6606
+ memcpy(&eps, dst->op_params, sizeof(float));
6607
+
6608
+ norm_f32_cuda(src0_dd, dst_dd, ne00, nrows, eps, main_stream);
6609
 
6610
  (void) src1;
6611
  (void) dst;
 
6760
  const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
6761
  const int64_t src1_padded_row_size, const cudaStream_t & stream) {
6762
 
6763
+ GGML_ASSERT(ggml_nrows(src1) == 1);
6764
+
6765
  const int64_t ne00 = src0->ne[0];
6766
  const int64_t row_diff = row_high - row_low;
6767
 
 
6821
  size_t ash;
6822
  dfloat * src1_dfloat = nullptr; // dfloat == half
6823
 
6824
+ bool src1_convert_f16 =
6825
+ src0->type == GGML_TYPE_Q4_0 || src0->type == GGML_TYPE_Q4_1 ||
6826
  src0->type == GGML_TYPE_Q5_0 || src0->type == GGML_TYPE_Q5_1 ||
6827
  src0->type == GGML_TYPE_Q8_0 || src0->type == GGML_TYPE_F16;
6828
 
 
7044
  GGML_ASSERT(false);
7045
  rope_glm_f32_cuda(src0_dd, dst_dd, ne00, nrows, pos, freq_scale, ne01, freq_base, n_ctx, main_stream);
7046
  } else if (is_neox) {
 
7047
  if (src0->type == GGML_TYPE_F32) {
7048
  rope_neox_cuda(
7049
+ (const float *)src0_dd, (float *)dst_dd, ne00, n_dims, nrows, pos, freq_scale, ne01, freq_base, ext_factor,
7050
  attn_factor, corr_dims, main_stream
7051
  );
7052
  } else if (src0->type == GGML_TYPE_F16) {
7053
  rope_neox_cuda(
7054
+ (const half *)src0_dd, (half *)dst_dd, ne00, n_dims, nrows, pos, freq_scale, ne01, freq_base, ext_factor,
7055
  attn_factor, corr_dims, main_stream
7056
  );
7057
  } else {
 
7148
  (void) src0_dd;
7149
  }
7150
 
7151
+ inline void ggml_cuda_op_sum_rows(
7152
+ const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
7153
+ const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
7154
+
7155
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
7156
+ GGML_ASSERT( dst->type == GGML_TYPE_F32);
7157
+
7158
+ const int64_t ncols = src0->ne[0];
7159
+ const int64_t nrows = ggml_nrows(src0);
7160
+
7161
+ sum_rows_f32_cuda(src0_dd, dst_dd, ncols, nrows, main_stream);
7162
+
7163
+ (void) src1;
7164
+ (void) dst;
7165
+ (void) src1_dd;
7166
+ }
7167
+
7168
+ inline void ggml_cuda_op_argsort(
7169
+ const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
7170
+ const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
7171
+
7172
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
7173
+ GGML_ASSERT( dst->type == GGML_TYPE_I32);
7174
+
7175
+ const int64_t ncols = src0->ne[0];
7176
+ const int64_t nrows = ggml_nrows(src0);
7177
+
7178
+ enum ggml_sort_order order = (enum ggml_sort_order) dst->op_params[0];
7179
+
7180
+ argsort_f32_i32_cuda(src0_dd, (int *)dst_dd, ncols, nrows, order, main_stream);
7181
+
7182
+ (void) src1;
7183
+ (void) dst;
7184
+ (void) src1_dd;
7185
+ }
7186
+
7187
  inline void ggml_cuda_op_diag_mask_inf(
7188
  const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
7189
  const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
 
7211
  GGML_ASSERT(src0->type == GGML_TYPE_F32);
7212
  GGML_ASSERT( dst->type == GGML_TYPE_F32);
7213
 
7214
+ GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F32); // src1 contains mask and it is optional
7215
+
7216
  const int64_t ne00 = src0->ne[0];
7217
+ const int64_t nrows_x = ggml_nrows(src0);
7218
+ const int64_t nrows_y = src1 ? ggml_nrows(src1) : 1;
7219
 
7220
+ float scale = 1.0f;
7221
+ memcpy(&scale, dst->op_params, sizeof(float));
7222
+
7223
+ soft_max_f32_cuda(src0_dd, src1 ? src1_dd : nullptr, dst_dd, ne00, nrows_x, nrows_y, scale, main_stream);
7224
 
 
7225
  (void) dst;
 
7226
  }
7227
 
7228
  inline void ggml_cuda_op_scale(
 
7392
  const int64_t ne01 = src0->ne[1];
7393
  const int64_t ne02 = src0->ne[2];
7394
  const int64_t ne03 = src0->ne[3];
7395
+ const int64_t nrows0 = ggml_nrows(src0);
7396
 
7397
  const int64_t ne10 = src1->ne[0];
7398
  const int64_t ne11 = src1->ne[1];
 
7428
 
7429
  const bool src0_on_device = src0->backend == GGML_BACKEND_GPU || src0->backend == GGML_BACKEND_GPU_SPLIT;
7430
  const bool src0_is_contiguous = ggml_is_contiguous(src0);
 
7431
  const bool src1_is_contiguous = ggml_is_contiguous(src1);
7432
+
7433
+ const int64_t src1_padded_col_size = GGML_PAD(ne10, MATRIX_ROW_PADDING);
7434
 
7435
  const bool split = src0->backend == GGML_BACKEND_GPU_SPLIT;
7436
  GGML_ASSERT(!(split && ne02 > 1));
 
7555
  const size_t src1_ddq_i_offset = (i0*ne11 + src1_col_0) * src1_padded_col_size*q8_1_ts/q8_1_bs;
7556
 
7557
  // for split tensors the data begins at i0 == i0_offset_low
7558
+ char * src0_dd_i = src0_dd[id] + (i0/i02_divisor) * (ne01*ne00*src0_ts)/src0_bs;
7559
  float * src1_ddf_i = src1_ddf[id] + (i0*ne11 + src1_col_0) * ne10;
7560
  char * src1_ddq_i = src1_ddq[id] + src1_ddq_i_offset;
7561
  float * dst_dd_i = dst_dd[id] + (i0*ne1 + src1_col_0) * (dst_on_device ? ne0 : row_diff);
 
7700
  ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_mul);
7701
  }
7702
 
7703
+ static void ggml_cuda_div(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
7704
+ ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_div);
7705
+ }
7706
+
7707
  static void ggml_cuda_gelu(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
7708
  ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_gelu);
7709
  }
 
7729
  }
7730
 
7731
  bool ggml_cuda_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst) {
7732
+ if (!g_cublas_loaded) return false;
7733
 
7734
  const int64_t ne10 = src1->ne[0];
7735
 
 
7807
  ggml_mul_mat_vec_nc_f16_f32_cuda(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, row_stride_x, ne02, ne12, channel_stride_x, main_stream);
7808
  }
7809
 
7810
+ static __global__ void k_compute_batched_ptrs(
7811
  const half * src0_as_f16, const half * src1_as_f16, half * dst_f16,
7812
  const void ** ptrs_src, void ** ptrs_dst,
7813
  int ne12, int ne13,
 
7863
  CUDA_CHECK(ggml_cuda_set_device(g_main_device));
7864
  cudaStream_t main_stream = g_cudaStreams[g_main_device][0];
7865
 
7866
+ CUBLAS_CHECK(cublasSetStream(g_cublas_handles[g_main_device], main_stream));
 
 
7867
 
7868
  ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra;
7869
  void * src0_ddq = src0_extra->data_device[g_main_device];
 
7920
  // there is no broadcast and src0, src1 are contiguous across dims 2, 3
7921
  // use cublasGemmStridedBatchedEx
7922
  CUBLAS_CHECK(
7923
+ cublasGemmStridedBatchedEx(g_cublas_handles[g_main_device], CUBLAS_OP_T, CUBLAS_OP_N,
7924
  ne01, ne11, ne10,
7925
  &alpha_f16, (const char *) src0_as_f16, CUDA_R_16F, nb01/sizeof(half), src0->nb[2]/sizeof(half), // strideA
7926
  (const char *) src1_as_f16, CUDA_R_16F, nb11/sizeof(float), src1->nb[2]/sizeof(float), // strideB
 
7954
  CUDA_CHECK(cudaGetLastError());
7955
 
7956
  CUBLAS_CHECK(
7957
+ cublasGemmBatchedEx(g_cublas_handles[g_main_device], CUBLAS_OP_T, CUBLAS_OP_N,
7958
  ne01, ne11, ne10,
7959
  &alpha_f16, (const void **) (ptrs_src + 0*ne23), CUDA_R_16F, nb01/sizeof(half),
7960
  (const void **) (ptrs_src + 1*ne23), CUDA_R_16F, nb11/sizeof(float),
 
8024
  #ifdef GGML_CUDA_FORCE_DMMV
8025
  const bool use_mul_mat_vec_q = false;
8026
  #else
8027
+ const bool use_mul_mat_vec_q = min_compute_capability >= MIN_CC_DP4A && ggml_is_quantized(src0->type) && ggml_nrows(src1) == 1;
8028
  #endif // GGML_CUDA_FORCE_DMMV
8029
 
8030
  if (use_mul_mat_vec_q) {
8031
+ // NOTE: this kernel does not support ggml_nrows(src1) > 1
8032
  ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_mul_mat_vec_q, true);
8033
  } else {
8034
  ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_dequantize_mul_mat_vec, false);
 
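The updated use_mul_mat_vec_q condition above now also requires ggml_nrows(src1) == 1, matching the added NOTE that the quantized mat-vec kernel only handles a single src1 row, while the MIN_CC_DP4A check keeps that path to GPUs with dp4a support. A tiny probe of the same gate (illustrative sketch; the 610 threshold is an assumption mirroring ggml-cuda.cu's MIN_CC_DP4A, i.e. compute capability 6.1):

// device capability probe (not part of the patch)
#include <stdio.h>
#include <cuda_runtime.h>

int main() {
    cudaDeviceProp prop;
    cudaGetDeviceProperties(&prop, 0);
    const int cc = 100*prop.major + 10*prop.minor;   // same encoding ggml-cuda.cu uses
    const int MIN_CC_DP4A_ASSUMED = 610;             // assumption: mirrors MIN_CC_DP4A
    printf("compute capability %d -> %s\n", cc,
           cc >= MIN_CC_DP4A_ASSUMED ? "mul_mat_vec_q (dp4a) path available"
                                     : "dequantize_mul_mat_vec fallback");
    return 0;
}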
8053
  }
8054
  }
8055
 
8056
+ #if 0
8057
+ template<typename ... Srcs>
8058
+ static __global__ void k_compute_batched_ptrs_id(
8059
+ const void ** ptrs_src, void ** ptrs_dst,
8060
+ int ne12, int ne13,
8061
+ int ne23,
8062
+ int nb02, int nb03,
8063
+ int nb12, int nb13,
8064
+ int nb2, int nb3,
8065
+ int r2, int r3,
8066
+ ggml_type src0_type, half * src0_as_f16, int64_t src0_ne,
8067
+ const half * src1_f16, half * dst_f16,
8068
+ const int32_t * ids, const int id,
8069
+ Srcs... src0s) {
8070
+
8071
+ int i = ids[id];
8072
+
8073
+ half * src0_f16;
8074
+ const void * srcs_ar[] = { (const half *) src0s... };
8075
+ if (src0_type == GGML_TYPE_F16) {
8076
+ src0_f16 = (half *) srcs_ar[i];
8077
+ } else {
8078
+ src0_f16 = src0_as_f16;
8079
+ if (threadIdx.x == 0 && threadIdx.y == 0) {
8080
+ const to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(src0_type);
8081
+ to_fp16(srcs_ar[i], src0_f16, src0_ne, cudaStreamFireAndForget);
8082
+ }
8083
+ }
8084
 
8085
+ int i13 = blockIdx.x * blockDim.x + threadIdx.x;
8086
+ int i12 = blockIdx.y * blockDim.y + threadIdx.y;
8087
+
8088
+ if (i13 >= ne13 || i12 >= ne12) {
8089
+ return;
8090
+ }
8091
+
8092
+ int i03 = i13 / r3;
8093
+ int i02 = i12 / r2;
8094
+
8095
+ ptrs_src[0*ne23 + i12 + i13*ne12] = (const char *) src0_f16 + i02*nb02 + i03*nb03;
8096
+ ptrs_src[1*ne23 + i12 + i13*ne12] = (const char *) src1_f16 + i12*nb12/2 + i13*nb13/2;
8097
+ ptrs_dst[0*ne23 + i12 + i13*ne12] = ( char *) dst_f16 + i12* nb2/2 + i13* nb3/2;
8098
  }
8099
 
8100
+ static void ggml_cuda_mul_mat_id_cublas(ggml_tensor * dst) {
8101
+ const struct ggml_tensor * ids = dst->src[0];
8102
+ const struct ggml_tensor * src1 = dst->src[1];
8103
+ const struct ggml_tensor * src00 = dst->src[2];
8104
 
8105
+ const int id = dst->op_params[0];
 
8106
 
8107
+ GGML_ASSERT(!ggml_is_transposed(src00));
8108
+ GGML_ASSERT(!ggml_is_transposed(src1));
8109
 
8110
+ GGML_ASSERT(src00->backend != GGML_BACKEND_GPU_SPLIT);
8111
+ GGML_ASSERT(src1->type == GGML_TYPE_F32);
 
8112
 
8113
+ const int64_t ne00 = src00->ne[0]; GGML_UNUSED(ne00);
8114
+ const int64_t ne01 = src00->ne[1];
8115
+ const int64_t ne02 = src00->ne[2];
8116
+ const int64_t ne03 = src00->ne[3];
8117
+
8118
+ //const int64_t nb01 = src00->nb[1];
8119
+ const int64_t nb02 = src00->nb[2]; GGML_UNUSED(nb02);
8120
+ const int64_t nb03 = src00->nb[3]; GGML_UNUSED(nb03);
8121
 
8122
  const int64_t ne10 = src1->ne[0];
8123
  const int64_t ne11 = src1->ne[1];
8124
+ const int64_t ne12 = src1->ne[2];
8125
+ const int64_t ne13 = src1->ne[3];
8126
 
8127
+ //const int64_t nb11 = src1->nb[1];
8128
+ const int64_t nb12 = src1->nb[2]; GGML_UNUSED(nb12);
8129
+ const int64_t nb13 = src1->nb[3]; GGML_UNUSED(nb13);
8130
 
8131
+ const int64_t ne1 = ggml_nelements(src1);
8132
+ const int64_t ne = ggml_nelements(dst);
8133
+
8134
+ CUDA_CHECK(ggml_cuda_set_device(g_main_device));
8135
+ cudaStream_t main_stream = g_cudaStreams[g_main_device][0];
8136
+
8137
+ CUBLAS_CHECK(cublasSetStream(g_cublas_handles[g_main_device], main_stream));
8138
+
8139
+ //ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra;
8140
+ //void * src0_ddq = src0_extra->data_device[g_main_device];
8141
+ //half * src0_as_f16 = (half *) src0_ddq;
8142
+
8143
+ ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu *) src1->extra;
8144
+ float * src1_ddf = (float *) src1_extra->data_device[g_main_device];
8145
+
8146
+ ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) dst->extra;
8147
+ float * dst_ddf = (float *) dst_extra->data_device[g_main_device];
8148
+
8149
+ // convert src1 to fp16
8150
+ const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src1->type);
8151
+ GGML_ASSERT(to_fp16_cuda != nullptr);
8152
+
8153
+ size_t src1_as = 0;
8154
+ half * src1_as_f16 = (half *) ggml_cuda_pool_malloc(ne1 * sizeof(half), &src1_as);
8155
+ to_fp16_cuda(src1_ddf, src1_as_f16, ne1, main_stream);
8156
+
8157
+ size_t dst_as = 0;
8158
+ half * dst_f16 = (half *) ggml_cuda_pool_malloc(ne * sizeof(half), &dst_as);
8159
+
8160
+ GGML_ASSERT(ne12 % ne02 == 0);
8161
+ GGML_ASSERT(ne13 % ne03 == 0);
8162
+
8163
+ // broadcast factors
8164
+ const int64_t r2 = ne12/ne02;
8165
+ const int64_t r3 = ne13/ne03;
8166
+
8167
+ const half alpha_f16 = 1.0f;
8168
+ const half beta_f16 = 0.0f;
8169
+
8170
+ // use cublasGemmBatchedEx
8171
+ const int ne23 = ne12*ne13;
8172
+
8173
+ const void ** ptrs_src = nullptr;
8174
+ void ** ptrs_dst = nullptr;
8175
+
8176
+ size_t ptrs_src_s = 0;
8177
+ size_t ptrs_dst_s = 0;
8178
+
8179
+ ptrs_src = (const void **) ggml_cuda_pool_malloc(2*ne23*sizeof(void *), &ptrs_src_s);
8180
+ ptrs_dst = ( void **) ggml_cuda_pool_malloc(1*ne23*sizeof(void *), &ptrs_dst_s);
8181
+
8182
+ int64_t src0_ne = ggml_nelements(src00);
8183
+ half * src0_as_f16 = nullptr;
8184
+ size_t src0_as = 0;
8185
+ if (src00->type != GGML_TYPE_F16) {
8186
+ src0_as_f16 = (half *) ggml_cuda_pool_malloc(src0_ne * sizeof(half), &src0_as);
8187
+ }
8188
+
8189
+ static_assert(GGML_MAX_SRC == 6, "GGML_MAX_SRC == 6");
8190
+ dim3 block_dims(ne13, ne12);
8191
+ k_compute_batched_ptrs_id<<<1, block_dims, 0, main_stream>>>(
8192
+ ptrs_src, ptrs_dst,
8193
+ ne12, ne13,
8194
+ ne23,
8195
+ ne00*ne01*sizeof(half), ne00*ne01*ne02*sizeof(half),
8196
+ nb12, nb13,
8197
+ dst->nb[2], dst->nb[3],
8198
+ r2, r3,
8199
+ src00->type, src0_as_f16, src0_ne,
8200
+ src1_as_f16, dst_f16,
8201
+ (const int *)((ggml_tensor_extra_gpu *)ids->extra)->data_device[g_main_device], id,
8202
+ dst->src[2] ? (const half *)((ggml_tensor_extra_gpu *)dst->src[2]->extra)->data_device[g_main_device] : nullptr,
8203
+ dst->src[3] ? (const half *)((ggml_tensor_extra_gpu *)dst->src[3]->extra)->data_device[g_main_device] : nullptr,
8204
+ dst->src[4] ? (const half *)((ggml_tensor_extra_gpu *)dst->src[4]->extra)->data_device[g_main_device] : nullptr,
8205
+ dst->src[5] ? (const half *)((ggml_tensor_extra_gpu *)dst->src[5]->extra)->data_device[g_main_device] : nullptr
8206
+ );
8207
+ CUDA_CHECK(cudaGetLastError());
8208
+
8209
+ CUBLAS_CHECK(
8210
+ cublasGemmBatchedEx(g_cublas_handles[g_main_device], CUBLAS_OP_T, CUBLAS_OP_N,
8211
+ ne01, ne11, ne10,
8212
+ &alpha_f16, (const void **) (ptrs_src + 0*ne23), CUDA_R_16F, ne00,
8213
+ (const void **) (ptrs_src + 1*ne23), CUDA_R_16F, ne10,
8214
+ &beta_f16, ( void **) (ptrs_dst + 0*ne23), CUDA_R_16F, ne01,
8215
+ ne23,
8216
+ CUBLAS_COMPUTE_16F,
8217
+ CUBLAS_GEMM_DEFAULT_TENSOR_OP));
8218
+
8219
+ if (src0_as != 0) {
8220
+ ggml_cuda_pool_free(src0_as_f16, src0_as);
8221
+ }
8222
+ if (ptrs_src_s != 0) {
8223
+ ggml_cuda_pool_free(ptrs_src, ptrs_src_s);
8224
+ }
8225
+ if (ptrs_dst_s != 0) {
8226
+ ggml_cuda_pool_free(ptrs_dst, ptrs_dst_s);
8227
+ }
8228
+
8229
+ const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
8230
+ to_fp32_cuda(dst_f16, dst_ddf, ne, main_stream);
8231
+
8232
+ ggml_cuda_pool_free(src1_as_f16, src1_as);
8233
+ ggml_cuda_pool_free(dst_f16, dst_as);
8234
+ }
8235
+ #endif
8236
+
8237
+ static void ggml_cuda_mul_mat_id(const ggml_tensor * _src0, const ggml_tensor * _src1, ggml_tensor * dst) {
8238
+ #if 0
8239
+ //#ifdef CUDA_USE_TENSOR_CORES
8240
+ // const bool use_tensor_cores = true;
8241
+ //#else
8242
+ // const bool use_tensor_cores = false;
8243
+ //#endif
8244
+
8245
+ ggml_cuda_mul_mat_id_cublas(dst);
8246
+
8247
+ // TODO: mmq/mmv support
8248
+ #else
8249
+ const struct ggml_tensor * ids = dst->src[0];
8250
+ const struct ggml_tensor * src1 = dst->src[1];
8251
+ const int id = dst->op_params[0];
8252
+
8253
+ int32_t * ids_dev = (int32_t *)((ggml_tensor_extra_gpu *)ids->extra)->data_device[g_main_device];
8254
+
8255
+ int32_t a_id;
8256
+ CUDA_CHECK(cudaMemcpyAsync(&a_id, ids_dev + id, sizeof(int32_t), cudaMemcpyDeviceToHost, g_cudaStreams[g_main_device][0]));
8257
+ CUDA_CHECK(cudaStreamSynchronize(g_cudaStreams[g_main_device][0]));
8258
+
8259
+ GGML_ASSERT(a_id >= 0 && a_id < ids->ne[0]);
8260
+ const struct ggml_tensor * src0 = dst->src[a_id + 2];
8261
+
8262
+ ggml_cuda_mul_mat(src0, src1, dst);
8263
+ #endif
8264
+
8265
+ (void) _src0;
8266
+ (void) _src1;
8267
+ }
8268
+
8269
+ static void ggml_cuda_scale(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
8270
+ ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_scale);
8271
+ }
8272
+
8273
+ static void ggml_cuda_clamp(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
8274
+ ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_clamp);
8275
+ }
8276
+
8277
+ static void ggml_cuda_cpy(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
8278
+ const int64_t ne = ggml_nelements(src0);
8279
+ GGML_ASSERT(ne == ggml_nelements(src1));
8280
+
8281
+ GGML_ASSERT(src0->backend == GGML_BACKEND_GPU);
8282
+ GGML_ASSERT(src1->backend == GGML_BACKEND_GPU);
8283
+
8284
+ GGML_ASSERT(ggml_nbytes(src0) <= INT_MAX);
8285
+ GGML_ASSERT(ggml_nbytes(src1) <= INT_MAX);
8286
+
8287
+ const int64_t ne00 = src0->ne[0];
8288
+ const int64_t ne01 = src0->ne[1];
8289
+ GGML_ASSERT(src0->ne[3] == 1);
8290
+
8291
+ const int64_t nb00 = src0->nb[0];
8292
+ const int64_t nb01 = src0->nb[1];
8293
+ const int64_t nb02 = src0->nb[2];
8294
+
8295
+ const int64_t ne10 = src1->ne[0];
8296
+ const int64_t ne11 = src1->ne[1];
8297
+ GGML_ASSERT(src1->ne[3] == 1);
8298
+
8299
+ const int64_t nb10 = src1->nb[0];
8300
+ const int64_t nb11 = src1->nb[1];
8301
+ const int64_t nb12 = src1->nb[2];
8302
+
8303
+ CUDA_CHECK(ggml_cuda_set_device(g_main_device));
8304
+ cudaStream_t main_stream = g_cudaStreams[g_main_device][0];
8305
 
8306
  const ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra;
8307
  const ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu *) src1->extra;
 
8310
  char * src1_ddc = (char *) src1_extra->data_device[g_main_device];
8311
 
8312
  if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
8313
+ ggml_cpy_f32_f32_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12, main_stream);
 
8314
  } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) {
8315
+ ggml_cpy_f32_f16_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12, main_stream);
8316
+ } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q8_0) {
8317
+ ggml_cpy_f32_q8_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12, main_stream);
8318
+ } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_0) {
8319
+ ggml_cpy_f32_q4_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12, main_stream);
8320
+ } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_1) {
8321
+ ggml_cpy_f32_q4_1_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12, main_stream);
8322
  } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
8323
+ ggml_cpy_f16_f16_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12, main_stream);
 
8324
  } else {
8325
  fprintf(stderr, "%s: unsupported type combination (%s to %s)\n", __func__,
8326
  ggml_type_name(src0->type), ggml_type_name(src1->type));
 
8331
  }
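The new F32 -> Q8_0/Q4_0/Q4_1 copy kernels above make it possible to quantize on the fly through a regular GGML_OP_CPY node. A hedged usage sketch on the graph-building side (standard ggml API, made-up shapes; ne0 must be a multiple of the quantization block size):

    struct ggml_tensor * src = ggml_new_tensor_2d(ctx, GGML_TYPE_F32,  4096, 32);
    struct ggml_tensor * dst = ggml_new_tensor_2d(ctx, GGML_TYPE_Q8_0, 4096, 32);
    struct ggml_tensor * cpy = ggml_cpy(ctx, src, dst); // dispatched to ggml_cpy_f32_q8_0_cuda when both tensors are on the GPU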
8332
 
8333
  static void ggml_cuda_dup(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
8334
+ // TODO: why do we pass dst as src1 here?
8335
  ggml_cuda_cpy(src0, dst, nullptr);
8336
  (void) src1;
8337
  }
 
8357
  ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_im2col);
8358
  }
8359
 
8360
+ static void ggml_cuda_sum_rows(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
8361
+ GGML_ASSERT(ggml_is_contiguous(src0));
8362
+ ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_sum_rows);
8363
+ }
8364
+
8365
+ static void ggml_cuda_argsort(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
8366
+ GGML_ASSERT(ggml_is_contiguous(src0));
8367
+ ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_argsort);
8368
+ }
8369
+
8370
  static void ggml_cuda_nop(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
8371
  (void) src0;
8372
  (void) src1;
 
8622
  main_device, g_device_count, g_main_device);
8623
  return;
8624
  }
8625
+
8626
+ if (g_main_device != main_device && g_device_count > 1) {
8627
+ g_main_device = main_device;
8628
  cudaDeviceProp prop;
8629
  CUDA_CHECK(cudaGetDeviceProperties(&prop, g_main_device));
8630
  fprintf(stderr, "%s: using device %d (%s) as main device\n", __func__, g_main_device, prop.name);
 
8650
  }
8651
 
8652
  bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor) {
8653
+ if (!g_cublas_loaded) return false;
8654
 
8655
  ggml_cuda_func_t func;
8656
  const bool any_on_device = tensor->backend == GGML_BACKEND_GPU
 
8686
  case GGML_OP_MUL:
8687
  func = ggml_cuda_mul;
8688
  break;
8689
+ case GGML_OP_DIV:
8690
+ func = ggml_cuda_div;
8691
+ break;
8692
  case GGML_OP_UNARY:
8693
  switch (ggml_get_unary_op(tensor)) {
8694
  case GGML_UNARY_OP_GELU:
 
8702
  break;
8703
  default:
8704
  return false;
8705
+ }
8706
+ break;
8707
  case GGML_OP_NORM:
8708
  func = ggml_cuda_norm;
8709
  break;
 
8716
  }
8717
  func = ggml_cuda_mul_mat;
8718
  break;
8719
+ case GGML_OP_MUL_MAT_ID:
8720
+ if (!any_on_device && !ggml_cuda_can_mul_mat(tensor->src[2], tensor->src[1], tensor)) {
8721
+ return false;
8722
+ }
8723
+ func = ggml_cuda_mul_mat_id;
8724
+ break;
8725
  case GGML_OP_SCALE:
8726
  func = ggml_cuda_scale;
8727
  break;
 
8761
  case GGML_OP_IM2COL:
8762
  func = ggml_cuda_im2col;
8763
  break;
8764
+ case GGML_OP_SUM_ROWS:
8765
+ func = ggml_cuda_sum_rows;
8766
+ break;
8767
+ case GGML_OP_ARGSORT:
8768
+ func = ggml_cuda_argsort;
8769
+ break;
8770
  default:
8771
  return false;
8772
  }
 
8783
 
8784
  int ggml_cuda_get_device_count() {
8785
  int device_count;
8786
+ if (cudaGetDeviceCount(&device_count) != cudaSuccess) {
8787
+ return 0;
8788
+ }
8789
  return device_count;
8790
  }
8791
 
 
8801
 
8802
  #define UNUSED GGML_UNUSED
8803
 
8804
+ // cuda buffer

 
8805
 
8806
  struct ggml_backend_buffer_context_cuda {
8807
+ int device;
8808
+ void * dev_ptr = nullptr;
8809
  ggml_tensor_extra_gpu * temp_tensor_extras = nullptr;
8810
  size_t temp_tensor_extra_index = 0;
8811
 
8812
+ ggml_backend_buffer_context_cuda(int device, void * dev_ptr) : device(device), dev_ptr(dev_ptr) {}
8813
+
8814
  ~ggml_backend_buffer_context_cuda() {
8815
  delete[] temp_tensor_extras;
8816
  }
 
8831
 
8832
  static void ggml_backend_cuda_buffer_free_buffer(ggml_backend_buffer_t buffer) {
8833
  ggml_backend_buffer_context_cuda * ctx = (ggml_backend_buffer_context_cuda *)buffer->context;
8834
+ CUDA_CHECK(cudaFree(ctx->dev_ptr));
8835
  delete ctx;
8836
  }
8837
 
8838
  static void * ggml_backend_cuda_buffer_get_base(ggml_backend_buffer_t buffer) {
8839
  ggml_backend_buffer_context_cuda * ctx = (ggml_backend_buffer_context_cuda *)buffer->context;
8840
+ return ctx->dev_ptr;
 
8841
  }
8842
 
8843
  static void ggml_backend_cuda_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) {
8844
  ggml_backend_buffer_context_cuda * ctx = (ggml_backend_buffer_context_cuda *)buffer->context;
8845
 
8846
  if (tensor->view_src != NULL && tensor->view_offs == 0) {
8847
+ assert(tensor->view_src->buffer->buft == buffer->buft); // TODO
8848
  tensor->backend = tensor->view_src->backend;
8849
  tensor->extra = tensor->view_src->extra;
8850
  return;
 
8852
 
8853
  ggml_tensor_extra_gpu * extra = ctx->ggml_cuda_alloc_temp_tensor_extra();
8854
 
8855
+ extra->data_device[ctx->device] = tensor->data;
8856
 
8857
  tensor->backend = GGML_BACKEND_GPU;
8858
  tensor->extra = extra;
 
8864
  int64_t nrows_split = row_high - row_low;
8865
 
8866
  size_t original_size = ggml_nbytes_split(tensor, nrows_split);
8867
+ size_t padded_size = ggml_backend_buft_get_alloc_size(buffer->buft, tensor);
8868
 
8869
  if (padded_size > original_size && tensor->view_src == nullptr) {
8870
+ CUDA_CHECK(cudaMemsetAsync((char *)tensor->data + original_size, 0, padded_size - original_size, g_cudaStreams[ctx->device][0]));
8871
  }
8872
  }
8873
 
8874
  UNUSED(buffer);
8875
  }
8876
 
8877
+ static void ggml_backend_cuda_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
8878
+ GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor write out of bounds");
8879
+ GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
8880
+ GGML_ASSERT(tensor->backend == GGML_BACKEND_GPU);
8881
+
8882
+ CUDA_CHECK(cudaMemcpy((char *)tensor->data + offset, data, size, cudaMemcpyHostToDevice));
8883
+
8884
+ UNUSED(buffer);
8885
+ }
8886
+
8887
+ static void ggml_backend_cuda_buffer_get_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * tensor, void * data, size_t offset, size_t size) {
8888
+ GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor read out of bounds");
8889
+ GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
8890
+ GGML_ASSERT(tensor->backend == GGML_BACKEND_GPU);
8891
+
8892
+ CUDA_CHECK(cudaMemcpy(data, (const char *)tensor->data + offset, size, cudaMemcpyDeviceToHost));
8893
+
8894
+ UNUSED(buffer);
8895
+ }
8896
+
8897
  static struct ggml_backend_buffer_i cuda_backend_buffer_interface = {
8898
+ /* .free_buffer = */ ggml_backend_cuda_buffer_free_buffer,
8899
+ /* .get_base = */ ggml_backend_cuda_buffer_get_base,
8900
+ /* .init_tensor = */ ggml_backend_cuda_buffer_init_tensor,
8901
+ /* .set_tensor = */ ggml_backend_cuda_buffer_set_tensor,
8902
+ /* .get_tensor = */ ggml_backend_cuda_buffer_get_tensor,
8903
+ /* .cpy_tensor_from = */ NULL,
8904
+ /* .cpy_tensor_to = */ NULL,
8905
  };
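These callbacks are what the generic ggml-backend tensor I/O resolves to for tensors living in a CUDA buffer. A minimal caller-side sketch (fragment; assumes t was allocated in such a buffer):

    float * host = (float *) malloc(ggml_nbytes(t));
    ggml_backend_tensor_set(t, host, 0, ggml_nbytes(t)); // -> .set_tensor -> cudaMemcpy host-to-device
    ggml_backend_tensor_get(t, host, 0, ggml_nbytes(t)); // -> .get_tensor -> cudaMemcpy device-to-host
    free(host);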
8906
 
8907
+ // cuda buffer type
 
8908
 
8909
+ static ggml_backend_buffer_t ggml_backend_cuda_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
8910
+ int device = (int) (intptr_t) buft->context;
8911
+
8912
+ ggml_cuda_set_device(device);
8913
 
8914
  size = std::max(size, (size_t)1); // cudaMalloc returns null for size 0
8915
 
8916
+ void * dev_ptr;
8917
+ CUDA_CHECK(cudaMalloc(&dev_ptr, size));
8918
 
8919
+ ggml_backend_buffer_context_cuda * ctx = new ggml_backend_buffer_context_cuda(device, dev_ptr);
8920
+
8921
+ return ggml_backend_buffer_init(buft, cuda_backend_buffer_interface, ctx, size);
8922
  }
8923
 
8924
+ static size_t ggml_backend_cuda_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
8925
  return 128;
8926
+
8927
+ UNUSED(buft);
8928
+ }
8929
+
8930
+ static size_t ggml_backend_cuda_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, ggml_tensor * tensor) {
8931
+ int64_t row_low = 0;
8932
+ int64_t row_high = ggml_nrows(tensor);
8933
+ int64_t nrows_split = row_high - row_low;
8934
+
8935
+ size_t size = ggml_nbytes_split(tensor, nrows_split);
8936
+
8937
+ int64_t ne0 = tensor->ne[0];
8938
+
8939
+ if (ggml_is_quantized(tensor->type)) {
8940
+ if (ne0 % MATRIX_ROW_PADDING != 0) {
8941
+ size += (MATRIX_ROW_PADDING - ne0 % MATRIX_ROW_PADDING)
8942
+ * ggml_type_size(tensor->type)/ggml_blck_size(tensor->type);
8943
+ }
8944
+ }
8945
+
8946
+ return size;
8947
+
8948
+ UNUSED(buft);
8949
+ }
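For quantized types the reported allocation size is padded, and init_tensor above zero-fills the padding, presumably so the quantized kernels can read whole blocks past the end of the last row. A standalone sketch of the same arithmetic (type_size/blck_size as returned by ggml_type_size/ggml_blck_size; padding stands for MATRIX_ROW_PADDING):

    static size_t padded_row_extra_bytes(int64_t ne0, size_t type_size, int64_t blck_size, int64_t padding) {
        if (ne0 % padding == 0) {
            return 0; // row already a multiple of the padding, no extra bytes needed
        }
        return (size_t) ((padding - ne0 % padding) * type_size / blck_size);
    }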
8950
+
8951
+ static bool ggml_backend_cuda_buffer_type_supports_backend(ggml_backend_buffer_type_t buft, ggml_backend_t backend) {
8952
+ return ggml_backend_is_cuda(backend);
8953
+
8954
+ UNUSED(buft);
8955
+ }
8956
+
8957
+ static ggml_backend_buffer_type_i cuda_backend_buffer_type_interface = {
8958
+ /* .alloc_buffer = */ ggml_backend_cuda_buffer_type_alloc_buffer,
8959
+ /* .get_alignment = */ ggml_backend_cuda_buffer_type_get_alignment,
8960
+ /* .get_alloc_size = */ ggml_backend_cuda_buffer_type_get_alloc_size,
8961
+ /* .supports_backend = */ ggml_backend_cuda_buffer_type_supports_backend,
8962
+ };
8963
+
8964
+ ggml_backend_buffer_type_t ggml_backend_cuda_buffer_type(int device) {
8965
+ static struct ggml_backend_buffer_type ggml_backend_buffer_type_cuda[GGML_CUDA_MAX_DEVICES];
8966
+ static bool ggml_backend_buffer_type_cuda_initialized = false;
8967
+ if (!ggml_backend_buffer_type_cuda_initialized) {
8968
+ for (int i = 0; i < GGML_CUDA_MAX_DEVICES; i++) {
8969
+ ggml_backend_buffer_type_cuda[i] = {
8970
+ /* .iface = */ cuda_backend_buffer_type_interface,
8971
+ /* .context = */ (ggml_backend_buffer_type_context_t) (intptr_t) i,
8972
+ };
8973
+ }
8974
+ ggml_backend_buffer_type_cuda_initialized = true;
8975
+ }
8976
+
8977
+ return &ggml_backend_buffer_type_cuda[device];
8978
+ }
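With a buffer type per device, buffers are obtained through the generic buffer-type API rather than a CUDA-specific allocator. A hedged usage sketch (assumes the ggml_backend_buft_* helpers added to ggml-backend.h in this sync):

    ggml_backend_buffer_type_t buft = ggml_backend_cuda_buffer_type(0); // device 0
    ggml_backend_buffer_t      buf  = ggml_backend_buft_alloc_buffer(buft, 256u*1024*1024);
    // ... place tensors in buf via ggml-alloc, run graphs ...
    ggml_backend_buffer_free(buf);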
8979
+
8980
+ // host buffer type
8981
+
8982
+ static void ggml_backend_cuda_host_buffer_free_buffer(ggml_backend_buffer_t buffer) {
8983
+ ggml_backend_buffer_context_cuda * ctx = (ggml_backend_buffer_context_cuda *)buffer->context;
8984
+ CUDA_CHECK(cudaFreeHost(ctx->dev_ptr));
8985
+ delete ctx;
8986
+ }
8987
+
8988
+ static ggml_backend_buffer_t ggml_backend_cuda_host_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
8989
+ void * ptr;
8990
+ CUDA_CHECK(cudaMallocHost(&ptr, size));
8991
+
8992
+ // FIXME: this is a hack to avoid having to implement a new buffer type
8993
+ ggml_backend_buffer_t buffer = ggml_backend_cpu_buffer_from_ptr(ptr, size);
8994
+ buffer->buft = buft;
8995
+ buffer->iface.free_buffer = ggml_backend_cuda_host_buffer_free_buffer;
8996
+
8997
+ return buffer;
8998
+
8999
+ UNUSED(buft);
9000
+ }
9001
+
9002
+ struct ggml_backend_buffer_type_i cuda_backend_host_buffer_type_interface = {
9003
+ /* .alloc_buffer = */ ggml_backend_cuda_host_buffer_type_alloc_buffer,
9004
+ /* .get_alignment = */ ggml_backend_cpu_buffer_type()->iface.get_alignment,
9005
+ /* .get_alloc_size = */ ggml_backend_cpu_buffer_type()->iface.get_alloc_size,
9006
+ /* .supports_backend = */ ggml_backend_cpu_buffer_type()->iface.supports_backend,
9007
+ };
9008
+
9009
+ ggml_backend_buffer_type_t ggml_backend_cuda_host_buffer_type() {
9010
+ static struct ggml_backend_buffer_type ggml_backend_buffer_type_cuda_host = {
9011
+ /* .iface = */ cuda_backend_host_buffer_type_interface,
9012
+ /* .context = */ nullptr,
9013
+ };
9014
+
9015
+ return &ggml_backend_buffer_type_cuda_host;
9016
+ }
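The host buffer type returns pinned (cudaMallocHost) memory but otherwise behaves like a CPU buffer, which is why it can reuse the CPU buffer interface above. A hedged sketch of its intended use as a staging area (same ggml_backend_buft_alloc_buffer assumption as before):

    ggml_backend_buffer_t pinned = ggml_backend_buft_alloc_buffer(ggml_backend_cuda_host_buffer_type(), 64u*1024*1024);
    // tensors placed here are computed by the CPU backend, but copies to/from CUDA buffers are faster
    ggml_backend_buffer_free(pinned);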
9017
+
9018
+ // backend
9019
+
9020
+ struct ggml_backend_context_cuda {
9021
+ int device;
9022
+ };
9023
+
9024
+ static const char * ggml_backend_cuda_name(ggml_backend_t backend) {
9025
+ return GGML_CUDA_NAME;
9026
+
9027
  UNUSED(backend);
9028
  }
9029
 
9030
+ static void ggml_backend_cuda_free(ggml_backend_t backend) {
9031
+ ggml_backend_context_cuda * cuda_ctx = (ggml_backend_context_cuda *)backend->context;
9032
+
9033
+ delete cuda_ctx;
9034
+ delete backend;
9035
+ }
9036
+
9037
+ static ggml_backend_buffer_type_t ggml_backend_cuda_get_default_buffer_type(ggml_backend_t backend) {
9038
+ ggml_backend_context_cuda * cuda_ctx = (ggml_backend_context_cuda *)backend->context;
9039
+
9040
+ return ggml_backend_cuda_buffer_type(cuda_ctx->device);
9041
+ }
9042
+
9043
  static void ggml_backend_cuda_set_tensor_async(ggml_backend_t backend, ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
9044
+ ggml_backend_context_cuda * cuda_ctx = (ggml_backend_context_cuda *)backend->context;
9045
+
9046
+ GGML_ASSERT(tensor->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device) && "unsupported buffer type");
9047
  GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor write out of bounds");
9048
  GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
9049
  GGML_ASSERT(tensor->backend == GGML_BACKEND_GPU);
9050
 
9051
+ CUDA_CHECK(cudaMemcpyAsync((char *)tensor->data + offset, data, size, cudaMemcpyHostToDevice, g_cudaStreams[cuda_ctx->device][0]));
 
 
9052
  }
9053
 
9054
  static void ggml_backend_cuda_get_tensor_async(ggml_backend_t backend, const ggml_tensor * tensor, void * data, size_t offset, size_t size) {
9055
+ ggml_backend_context_cuda * cuda_ctx = (ggml_backend_context_cuda *)backend->context;
9056
+
9057
+ GGML_ASSERT(tensor->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device) && "unsupported buffer type");
9058
  GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor read out of bounds");
9059
  GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
9060
  GGML_ASSERT(tensor->backend == GGML_BACKEND_GPU);
9061
 
9062
+ CUDA_CHECK(cudaMemcpyAsync(data, (const char *)tensor->data + offset, size, cudaMemcpyDeviceToHost, g_cudaStreams[cuda_ctx->device][0]));
 
 
9063
  }
9064
 
9065
  static void ggml_backend_cuda_synchronize(ggml_backend_t backend) {
9066
+ ggml_backend_context_cuda * cuda_ctx = (ggml_backend_context_cuda *)backend->context;
9067
+
9068
+ CUDA_CHECK(cudaStreamSynchronize(g_cudaStreams[cuda_ctx->device][0]));
9069
 
9070
  UNUSED(backend);
9071
  }
 
9079
  UNUSED(cgraph);
9080
  }
9081
 
9082
+ static void ggml_backend_cuda_graph_plan_free(ggml_backend_t backend, ggml_backend_graph_plan_t plan) {
9083
  GGML_ASSERT(!"not implemented");
9084
 
9085
  UNUSED(backend);
9086
  UNUSED(plan);
9087
  }
9088
 
9089
+ static void ggml_backend_cuda_graph_plan_compute(ggml_backend_t backend, ggml_backend_graph_plan_t plan) {
9090
  GGML_ASSERT(!"not implemented");
9091
 
9092
  UNUSED(backend);
 
9094
  }
9095
 
9096
  static void ggml_backend_cuda_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
9097
+ ggml_backend_context_cuda * cuda_ctx = (ggml_backend_context_cuda *)backend->context;
9098
+
9099
+ ggml_cuda_set_main_device(cuda_ctx->device);
9100
 
9101
  ggml_compute_params params = {};
9102
  params.type = GGML_TASK_COMPUTE;
 
9104
  for (int i = 0; i < cgraph->n_nodes; i++) {
9105
  ggml_tensor * node = cgraph->nodes[i];
9106
 
9107
+ if (node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE)
9108
  continue;
9109
+
9110
  assert(node->backend == GGML_BACKEND_GPU);
9111
+ assert(node->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device));
9112
+ assert(node->extra != nullptr);
9113
+
9114
  for (int j = 0; j < GGML_MAX_SRC; j++) {
9115
  if (node->src[j] != nullptr) {
9116
  assert(node->src[j]->backend == GGML_BACKEND_GPU);
9117
+ assert(node->src[j]->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device));
9118
+ assert(node->src[j]->extra != nullptr);
9119
  }
9120
  }
9121
 
 
9152
  UNUSED(backend);
9153
  }
9154
 
9155
+ static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, const ggml_tensor * op) {
9156
+ switch (op->op) {
9157
+ case GGML_OP_UNARY:
9158
+ switch (ggml_get_unary_op(op)) {
9159
+ case GGML_UNARY_OP_GELU:
9160
+ case GGML_UNARY_OP_SILU:
9161
+ case GGML_UNARY_OP_RELU:
9162
+ return true;
9163
+ default:
9164
+ return false;
9165
+ }
9166
+ break;
9167
+ case GGML_OP_MUL_MAT:
9168
+ case GGML_OP_MUL_MAT_ID:
9169
+ {
9170
+ struct ggml_tensor * a;
9171
+ struct ggml_tensor * b;
9172
+ if (op->op == GGML_OP_MUL_MAT) {
9173
+ a = op->src[0];
9174
+ b = op->src[1];
9175
+ } else {
9176
+ a = op->src[2];
9177
+ b = op->src[1];
9178
+ }
9179
+ if (a->ne[3] != b->ne[3]) {
9180
+ return false;
9181
+ }
9182
+ return true;
9183
+ } break;
9184
+ case GGML_OP_NONE:
9185
+ case GGML_OP_RESHAPE:
9186
+ case GGML_OP_VIEW:
9187
+ case GGML_OP_PERMUTE:
9188
+ case GGML_OP_TRANSPOSE:
9189
+ case GGML_OP_NORM:
9190
+ case GGML_OP_REPEAT:
9191
+ case GGML_OP_GET_ROWS:
9192
+ case GGML_OP_DUP:
9193
+ case GGML_OP_ADD:
9194
+ case GGML_OP_MUL:
9195
+ case GGML_OP_DIV:
9196
+ case GGML_OP_RMS_NORM:
9197
+ case GGML_OP_SCALE:
9198
+ case GGML_OP_SQR:
9199
+ case GGML_OP_CLAMP:
9200
+ case GGML_OP_CPY:
9201
+ case GGML_OP_CONT:
9202
+ case GGML_OP_DIAG_MASK_INF:
9203
+ case GGML_OP_SOFT_MAX:
9204
+ case GGML_OP_ROPE:
9205
+ case GGML_OP_ALIBI:
9206
+ case GGML_OP_IM2COL:
9207
+ case GGML_OP_SUM_ROWS:
9208
+ case GGML_OP_ARGSORT:
9209
+ return true;
9210
+ default:
9211
+ return false;
9212
+ }
9213
+
9214
+ UNUSED(backend);
9215
+ }
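One consequence of the mul_mat check above: broadcasting over dim 3 is reported as unsupported here even when ggml_mul_mat accepts the shapes, so a scheduler consulting supports_op can fall back to another backend for such a node. A small illustration with made-up shapes:

    struct ggml_tensor * a = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, 64, 32, 8, 1);
    struct ggml_tensor * b = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 64, 16, 8, 2);
    struct ggml_tensor * c = ggml_mul_mat(ctx, a, b); // a->ne[3] != b->ne[3] -> rejected by this backend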
9216
+
9217
  static ggml_backend_i cuda_backend_i = {
9218
+ /* .get_name = */ ggml_backend_cuda_name,
9219
+ /* .free = */ ggml_backend_cuda_free,
9220
+ /* .get_default_buffer_type = */ ggml_backend_cuda_get_default_buffer_type,
9221
+ /* .set_tensor_async = */ ggml_backend_cuda_set_tensor_async,
9222
+ /* .get_tensor_async = */ ggml_backend_cuda_get_tensor_async,
9223
+ /* .cpy_tensor_from_async = */ NULL,
9224
+ /* .cpy_tensor_to_async = */ NULL,
9225
+ /* .synchronize = */ ggml_backend_cuda_synchronize,
9226
+ /* .graph_plan_create = */ ggml_backend_cuda_graph_plan_create,
9227
+ /* .graph_plan_free = */ ggml_backend_cuda_graph_plan_free,
9228
+ /* .graph_plan_compute = */ ggml_backend_cuda_graph_plan_compute,
9229
+ /* .graph_compute = */ ggml_backend_cuda_graph_compute,
9230
+ /* .supports_op = */ ggml_backend_cuda_supports_op,
 
9231
  };
9232
 
9233
+ ggml_backend_t ggml_backend_cuda_init(int device) {
9234
  ggml_init_cublas(); // TODO: remove from ggml.c
9235
 
9236
+ if (device < 0 || device >= ggml_cuda_get_device_count()) {
9237
+ fprintf(stderr, "%s: error: invalid device %d\n", __func__, device);
9238
+ return nullptr;
9239
+ }
9240
+
9241
+ // not strictly necessary, but it may reduce the overhead of the first graph_compute
9242
+ ggml_cuda_set_main_device(device);
9243
+
9244
+ ggml_backend_context_cuda * ctx = new ggml_backend_context_cuda {
9245
+ /* .device = */ device
9246
+ };
9247
 
9248
  ggml_backend_t cuda_backend = new ggml_backend {
9249
  /* .interface = */ cuda_backend_i,
 
9252
 
9253
  return cuda_backend;
9254
  }
9255
+
9256
+ bool ggml_backend_is_cuda(ggml_backend_t backend) {
9257
+ return backend->iface.get_name == ggml_backend_cuda_name;
9258
+ }
9259
+
9260
+ static ggml_backend_t ggml_backend_reg_cuda_init(const char * params, void * user_data) {
9261
+ ggml_backend_t cuda_backend = ggml_backend_cuda_init((int) (intptr_t) user_data);
9262
+ return cuda_backend;
9263
+
9264
+ UNUSED(params);
9265
+ }
9266
+
9267
+ extern "C" int ggml_backend_cuda_reg_devices() {
9268
+ int device_count = ggml_cuda_get_device_count();
9269
+ //int device_count = 1; // DEBUG: some tools require delaying CUDA initialization
9270
+ for (int i = 0; i < device_count; i++) {
9271
+ char name[128];
9272
+ snprintf(name, sizeof(name), "%s%d", GGML_CUDA_NAME, i);
9273
+ ggml_backend_register(name, ggml_backend_reg_cuda_init, ggml_backend_cuda_buffer_type(i), (void *) (intptr_t) i);
9274
+ }
9275
+ return device_count;
9276
+ }
ggml-cuda.h CHANGED
@@ -49,7 +49,15 @@ GGML_API int ggml_cuda_get_device_count(void);
49
  GGML_API void ggml_cuda_get_device_description(int device, char * description, size_t description_size);
50
 
51
  // backend API
52
- GGML_API ggml_backend_t ggml_backend_cuda_init(void); // TODO: take a list of devices to use
 
 
 
 
 
 
 
 
53
 
54
  #ifdef __cplusplus
55
  }
 
49
  GGML_API void ggml_cuda_get_device_description(int device, char * description, size_t description_size);
50
 
51
  // backend API
52
+ GGML_API ggml_backend_t ggml_backend_cuda_init(int device);
53
+
54
+ GGML_API bool ggml_backend_is_cuda(ggml_backend_t backend);
55
+ GGML_API int ggml_backend_cuda_get_device(ggml_backend_t backend);
56
+
57
+ GGML_API ggml_backend_buffer_type_t ggml_backend_cuda_buffer_type(int device);
58
+
59
+ // pinned host buffer for use with CPU backend for faster copies between CPU and GPU
60
+ GGML_API ggml_backend_buffer_type_t ggml_backend_cuda_host_buffer_type(void);
61
 
62
  #ifdef __cplusplus
63
  }
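A hedged end-to-end sketch of the per-device API declared above, using the generic ggml-backend entry points (ggml_backend_cpu_init, ggml_backend_graph_compute, ggml_backend_free); error handling kept minimal:

    ggml_backend_t backend = ggml_backend_cuda_init(0); // device 0
    if (backend == NULL) {
        fprintf(stderr, "CUDA backend init failed, falling back to CPU\n");
        backend = ggml_backend_cpu_init();
    }
    // build a graph, allocate its tensors in a buffer of the backend's default buffer type,
    // then: ggml_backend_graph_compute(backend, graph);
    ggml_backend_free(backend);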
ggml-impl.h CHANGED
@@ -232,7 +232,7 @@ bool ggml_hash_contains (const struct ggml_hash_set hash_set, struct ggml
232
  // returns GGML_HASHTABLE_FULL if table is full, otherwise the current index of the key or where it should be inserted
233
  size_t ggml_hash_find (const struct ggml_hash_set hash_set, struct ggml_tensor * key);
234
 
235
- // returns GGML_HAHSHTABLE_ALREADY_EXISTS if key already exists, index otherwise, asserts if table is full
236
  size_t ggml_hash_insert ( struct ggml_hash_set hash_set, struct ggml_tensor * key);
237
 
238
  // return index, asserts if table is full
 
232
  // returns GGML_HASHTABLE_FULL if table is full, otherwise the current index of the key or where it should be inserted
233
  size_t ggml_hash_find (const struct ggml_hash_set hash_set, struct ggml_tensor * key);
234
 
235
+ // returns GGML_HASHTABLE_ALREADY_EXISTS if key already exists, index otherwise, asserts if table is full
236
  size_t ggml_hash_insert ( struct ggml_hash_set hash_set, struct ggml_tensor * key);
237
 
238
  // return index, asserts if table is full
ggml-metal.h CHANGED
@@ -52,11 +52,6 @@ void ggml_metal_free(struct ggml_metal_context * ctx);
52
  void * ggml_metal_host_malloc(size_t n);
53
  void ggml_metal_host_free (void * data);
54
 
55
- // helper to check if the device supports a specific family
56
- // ideally, the user code should be doing these checks
57
- // ref: https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf
58
- bool ggml_metal_supports_family(struct ggml_metal_context * ctx, int family);
59
-
60
  // set the number of command buffers to use
61
  void ggml_metal_set_n_cb(struct ggml_metal_context * ctx, int n_cb);
62
 
@@ -104,7 +99,11 @@ GGML_API ggml_backend_t ggml_backend_metal_init(void);
104
  GGML_API bool ggml_backend_is_metal(ggml_backend_t backend);
105
 
106
  GGML_API void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb);
 
107
 
 
 
 
108
  GGML_API bool ggml_backend_metal_supports_family(ggml_backend_t backend, int family);
109
 
110
  #ifdef __cplusplus
 
52
  void * ggml_metal_host_malloc(size_t n);
53
  void ggml_metal_host_free (void * data);
54
 
 
 
 
 
 
55
  // set the number of command buffers to use
56
  void ggml_metal_set_n_cb(struct ggml_metal_context * ctx, int n_cb);
57
 
 
99
  GGML_API bool ggml_backend_is_metal(ggml_backend_t backend);
100
 
101
  GGML_API void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb);
102
+ GGML_API ggml_backend_buffer_type_t ggml_backend_metal_buffer_type(void);
103
 
104
+ // helper to check if the device supports a specific family
105
+ // ideally, the user code should be doing these checks
106
+ // ref: https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf
107
  GGML_API bool ggml_backend_metal_supports_family(ggml_backend_t backend, int family);
108
 
109
  #ifdef __cplusplus
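A hedged usage sketch of the relocated Metal helpers above; family is 1-based over the MTLGPUFamilyApple enum (family 7 -> MTLGPUFamilyApple7, i.e. M1-generation GPUs):

    ggml_backend_t backend = ggml_backend_metal_init();
    if (backend != NULL) {
        if (!ggml_backend_metal_supports_family(backend, 7)) {
            fprintf(stderr, "warning: GPU family < Apple7, some kernels may be slower or unavailable\n");
        }
        ggml_backend_metal_set_n_cb(backend, 4); // number of command buffers used per graph
    }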
ggml-metal.m CHANGED
@@ -62,6 +62,8 @@ struct ggml_metal_context {
62
  GGML_METAL_DECL_KERNEL(add_row); // TODO: avoid this extra kernel, instead extend the "add" kernel to support broadcast
63
  GGML_METAL_DECL_KERNEL(mul);
64
  GGML_METAL_DECL_KERNEL(mul_row); // TODO: avoid this extra kernel, instead extend the "mul" kernel to support broadcast
 
 
65
  GGML_METAL_DECL_KERNEL(scale);
66
  GGML_METAL_DECL_KERNEL(scale_4);
67
  GGML_METAL_DECL_KERNEL(silu);
@@ -112,15 +114,35 @@ struct ggml_metal_context {
112
  GGML_METAL_DECL_KERNEL(mul_mm_q4_K_f32);
113
  GGML_METAL_DECL_KERNEL(mul_mm_q5_K_f32);
114
  GGML_METAL_DECL_KERNEL(mul_mm_q6_K_f32);
 
 
 
 
 
 
 
 
 
 
 
 
115
  GGML_METAL_DECL_KERNEL(rope_f32);
116
  GGML_METAL_DECL_KERNEL(rope_f16);
117
  GGML_METAL_DECL_KERNEL(alibi_f32);
118
  GGML_METAL_DECL_KERNEL(im2col_f16);
 
 
119
  GGML_METAL_DECL_KERNEL(cpy_f32_f16);
120
  GGML_METAL_DECL_KERNEL(cpy_f32_f32);
 
 
 
 
 
121
  GGML_METAL_DECL_KERNEL(cpy_f16_f16);
122
  GGML_METAL_DECL_KERNEL(concat);
123
  GGML_METAL_DECL_KERNEL(sqr);
 
124
 
125
  #undef GGML_METAL_DECL_KERNEL
126
  };
@@ -164,12 +186,10 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){
164
  }
165
  }
166
 
167
-
168
-
169
  struct ggml_metal_context * ggml_metal_init(int n_cb) {
170
  GGML_METAL_LOG_INFO("%s: allocating\n", __func__);
171
 
172
- id <MTLDevice> device;
173
  NSString * s;
174
 
175
  #if TARGET_OS_OSX
@@ -215,6 +235,9 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
215
 
216
  NSString * sourcePath;
217
  NSString * ggmlMetalPathResources = [[NSProcessInfo processInfo].environment objectForKey:@"GGML_METAL_PATH_RESOURCES"];
 
 
 
218
  if (ggmlMetalPathResources) {
219
  sourcePath = [ggmlMetalPathResources stringByAppendingPathComponent:@"ggml-metal.metal"];
220
  } else {
@@ -245,6 +268,29 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
245
  }
246
  }
247
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
248
  // load kernels
249
  {
250
  NSError * error = nil;
@@ -266,6 +312,8 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
266
  GGML_METAL_ADD_KERNEL(add_row);
267
  GGML_METAL_ADD_KERNEL(mul);
268
  GGML_METAL_ADD_KERNEL(mul_row);
 
 
269
  GGML_METAL_ADD_KERNEL(scale);
270
  GGML_METAL_ADD_KERNEL(scale_4);
271
  GGML_METAL_ADD_KERNEL(silu);
@@ -317,43 +365,40 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
317
  GGML_METAL_ADD_KERNEL(mul_mm_q4_K_f32);
318
  GGML_METAL_ADD_KERNEL(mul_mm_q5_K_f32);
319
  GGML_METAL_ADD_KERNEL(mul_mm_q6_K_f32);
 
 
 
 
 
 
 
 
 
 
 
 
320
  }
321
  GGML_METAL_ADD_KERNEL(rope_f32);
322
  GGML_METAL_ADD_KERNEL(rope_f16);
323
  GGML_METAL_ADD_KERNEL(alibi_f32);
324
  GGML_METAL_ADD_KERNEL(im2col_f16);
 
 
325
  GGML_METAL_ADD_KERNEL(cpy_f32_f16);
326
  GGML_METAL_ADD_KERNEL(cpy_f32_f32);
 
 
 
 
 
327
  GGML_METAL_ADD_KERNEL(cpy_f16_f16);
328
  GGML_METAL_ADD_KERNEL(concat);
329
  GGML_METAL_ADD_KERNEL(sqr);
 
330
 
331
  #undef GGML_METAL_ADD_KERNEL
332
  }
333
 
334
- #if TARGET_OS_OSX
335
- // print MTL GPU family:
336
- GGML_METAL_LOG_INFO("%s: GPU name: %s\n", __func__, [[ctx->device name] UTF8String]);
337
-
338
- // determine max supported GPU family
339
- // https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf
340
- // https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf
341
- for (int i = MTLGPUFamilyApple1 + 20; i >= MTLGPUFamilyApple1; --i) {
342
- if ([ctx->device supportsFamily:i]) {
343
- GGML_METAL_LOG_INFO("%s: GPU family: MTLGPUFamilyApple%d (%d)\n", __func__, i - (int) MTLGPUFamilyApple1 + 1, i);
344
- break;
345
- }
346
- }
347
-
348
- GGML_METAL_LOG_INFO("%s: hasUnifiedMemory = %s\n", __func__, ctx->device.hasUnifiedMemory ? "true" : "false");
349
- GGML_METAL_LOG_INFO("%s: recommendedMaxWorkingSetSize = %8.2f MB\n", __func__, ctx->device.recommendedMaxWorkingSetSize / 1e6);
350
- if (ctx->device.maxTransferRate != 0) {
351
- GGML_METAL_LOG_INFO("%s: maxTransferRate = %8.2f MB/s\n", __func__, ctx->device.maxTransferRate / 1e6);
352
- } else {
353
- GGML_METAL_LOG_INFO("%s: maxTransferRate = built-in GPU\n", __func__);
354
- }
355
- #endif
356
-
357
  return ctx;
358
  }
359
 
@@ -367,6 +412,8 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
367
  GGML_METAL_DEL_KERNEL(add_row);
368
  GGML_METAL_DEL_KERNEL(mul);
369
  GGML_METAL_DEL_KERNEL(mul_row);
 
 
370
  GGML_METAL_DEL_KERNEL(scale);
371
  GGML_METAL_DEL_KERNEL(scale_4);
372
  GGML_METAL_DEL_KERNEL(silu);
@@ -418,16 +465,36 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
418
  GGML_METAL_DEL_KERNEL(mul_mm_q4_K_f32);
419
  GGML_METAL_DEL_KERNEL(mul_mm_q5_K_f32);
420
  GGML_METAL_DEL_KERNEL(mul_mm_q6_K_f32);
 
 
 
 
 
 
 
 
 
 
 
 
421
  }
422
  GGML_METAL_DEL_KERNEL(rope_f32);
423
  GGML_METAL_DEL_KERNEL(rope_f16);
424
  GGML_METAL_DEL_KERNEL(alibi_f32);
425
  GGML_METAL_DEL_KERNEL(im2col_f16);
 
 
426
  GGML_METAL_DEL_KERNEL(cpy_f32_f16);
427
  GGML_METAL_DEL_KERNEL(cpy_f32_f32);
 
 
 
 
 
428
  GGML_METAL_DEL_KERNEL(cpy_f16_f16);
429
  GGML_METAL_DEL_KERNEL(concat);
430
  GGML_METAL_DEL_KERNEL(sqr);
 
431
 
432
  #undef GGML_METAL_DEL_KERNEL
433
 
@@ -459,10 +526,6 @@ void ggml_metal_host_free(void * data) {
459
  free(data);
460
  }
461
 
462
- bool ggml_metal_supports_family(struct ggml_metal_context * ctx, int family) {
463
- return [ctx->device supportsFamily:(MTLGPUFamilyApple1 + family - 1)];
464
- }
465
-
466
  void ggml_metal_set_n_cb(struct ggml_metal_context * ctx, int n_cb) {
467
  ctx->n_cb = MIN(n_cb, GGML_METAL_MAX_BUFFERS);
468
  }
@@ -475,6 +538,13 @@ int * ggml_metal_get_concur_list(struct ggml_metal_context * ctx) {
475
  return ctx->concur_list;
476
  }
477
 
 
 
 
 
 
 
 
478
  // finds the Metal buffer that contains the tensor data on the GPU device
479
  // the assumption is that there is 1-to-1 mapping between the host and device memory buffers, so we can find the
480
  // Metal buffer based on the host memory pointer
@@ -484,8 +554,17 @@ static id<MTLBuffer> ggml_metal_get_buffer(struct ggml_metal_context * ctx, stru
484
 
485
  const int64_t tsize = ggml_nbytes(t);
486
 
487
- if (t->buffer && t->buffer->backend && t->buffer->backend->context) {
488
- ctx = t->buffer->backend->context;
 
 
 
 
 
 
 
 
 
489
  }
490
 
491
  // find the view that contains the tensor fully
@@ -545,11 +624,11 @@ bool ggml_metal_add_buffer(
545
  ctx->buffers[ctx->n_buffers].metal = [ctx->device newBufferWithBytesNoCopy:data length:size_aligned options:MTLResourceStorageModeShared deallocator:nil];
546
 
547
  if (ctx->buffers[ctx->n_buffers].metal == nil) {
548
- GGML_METAL_LOG_ERROR("%s: error: failed to allocate '%-16s' buffer, size = %8.2f MB\n", __func__, name, size_aligned / 1e6);
549
  return false;
550
  }
551
 
552
- GGML_METAL_LOG_INFO("%s: allocated '%-16s' buffer, size = %8.2f MB", __func__, name, size_aligned / 1e6);
553
 
554
  ++ctx->n_buffers;
555
  } else {
@@ -569,11 +648,11 @@ bool ggml_metal_add_buffer(
569
  ctx->buffers[ctx->n_buffers].metal = [ctx->device newBufferWithBytesNoCopy:(void *) ((uint8_t *) data + i) length:size_step_aligned options:MTLResourceStorageModeShared deallocator:nil];
570
 
571
  if (ctx->buffers[ctx->n_buffers].metal == nil) {
572
- GGML_METAL_LOG_ERROR("%s: error: failed to allocate '%-16s' buffer, size = %8.2f MB\n", __func__, name, size_step_aligned / 1e6);
573
  return false;
574
  }
575
 
576
- GGML_METAL_LOG_INFO("%s: allocated '%-16s' buffer, size = %8.2f MB, offs = %12ld", __func__, name, size_step_aligned / 1e6, i);
577
  if (i + size_step < size) {
578
  GGML_METAL_LOG_INFO("\n");
579
  }
@@ -584,8 +663,8 @@ bool ggml_metal_add_buffer(
584
 
585
  #if TARGET_OS_OSX
586
  GGML_METAL_LOG_INFO(", (%8.2f / %8.2f)",
587
- ctx->device.currentAllocatedSize / 1e6,
588
- ctx->device.recommendedMaxWorkingSetSize / 1e6);
589
 
590
  if (ctx->device.currentAllocatedSize > ctx->device.recommendedMaxWorkingSetSize) {
591
  GGML_METAL_LOG_WARN("%s: warning: current allocated size is greater than the recommended max working set size\n", __func__);
@@ -593,7 +672,7 @@ bool ggml_metal_add_buffer(
593
  GGML_METAL_LOG_INFO("\n");
594
  }
595
  #else
596
- GGML_METAL_LOG_INFO(", (%8.2f)\n", ctx->device.currentAllocatedSize / 1e6);
597
  #endif
598
  }
599
 
@@ -710,6 +789,51 @@ void ggml_metal_graph_find_concurrency(
710
  }
711
  }
712
 
 
 
713
  void ggml_metal_graph_compute(
714
  struct ggml_metal_context * ctx,
715
  struct ggml_cgraph * gf) {
@@ -780,6 +904,8 @@ void ggml_metal_graph_compute(
780
  } break;
781
  }
782
 
 
 
783
  const int64_t ne00 = src0 ? src0->ne[0] : 0;
784
  const int64_t ne01 = src0 ? src0->ne[1] : 0;
785
  const int64_t ne02 = src0 ? src0->ne[2] : 0;
@@ -872,6 +998,8 @@ void ggml_metal_graph_compute(
872
  [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
873
  } break;
874
  case GGML_OP_ADD:
 
 
875
  {
876
  GGML_ASSERT(ggml_is_contiguous(src0));
877
  GGML_ASSERT(ggml_is_contiguous(src1));
@@ -885,11 +1013,21 @@ void ggml_metal_graph_compute(
885
  GGML_ASSERT(ne11 == 1);
886
 
887
  nb = ne00 / 4;
888
- [encoder setComputePipelineState:ctx->pipeline_add_row];
 
 
 
 
 
889
 
890
  bcast_row = true;
891
  } else {
892
- [encoder setComputePipelineState:ctx->pipeline_add];
 
 
 
 
 
893
  }
894
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
895
  [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
@@ -930,31 +1068,6 @@ void ggml_metal_graph_compute(
930
  [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
931
  }
932
  } break;
933
- case GGML_OP_MUL:
934
- {
935
- GGML_ASSERT(ggml_is_contiguous(src0));
936
- GGML_ASSERT(ggml_is_contiguous(src1));
937
-
938
- // utilize float4
939
- GGML_ASSERT(ne00 % 4 == 0);
940
- const int64_t nb = ne00/4;
941
-
942
- if (ggml_nelements(src1) == ne10) {
943
- // src1 is a row
944
- GGML_ASSERT(ne11 == 1);
945
- [encoder setComputePipelineState:ctx->pipeline_mul_row];
946
- } else {
947
- [encoder setComputePipelineState:ctx->pipeline_mul];
948
- }
949
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
950
- [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
951
- [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
952
- [encoder setBytes:&nb length:sizeof(nb) atIndex:3];
953
-
954
- const int64_t n = ggml_nelements(dst)/4;
955
-
956
- [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
957
- } break;
958
  case GGML_OP_SCALE:
959
  {
960
  GGML_ASSERT(ggml_is_contiguous(src0));
@@ -1027,25 +1140,66 @@ void ggml_metal_graph_compute(
1027
  const int64_t n = ggml_nelements(dst);
1028
  [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1029
  } break;
 
 
 
1030
  case GGML_OP_SOFT_MAX:
1031
  {
1032
  int nth = 32; // SIMD width
1033
 
1034
  if (ne00%4 == 0) {
 
 
 
1035
  [encoder setComputePipelineState:ctx->pipeline_soft_max_4];
1036
  } else {
1037
- do {
1038
  nth *= 2;
1039
- } while (nth <= ne00 && nth <= 1024);
1040
- nth /= 2;
1041
  [encoder setComputePipelineState:ctx->pipeline_soft_max];
1042
  }
1043
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1044
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1045
- [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
1046
- [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
1047
- [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
1048
- [encoder setThreadgroupMemoryLength:GGML_PAD(nth/32*sizeof(float), 16) atIndex:0];
 
 
 
 
 
1049
 
1050
  [encoder dispatchThreadgroups:MTLSizeMake(ne01*ne02*ne03, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
1051
  } break;
@@ -1074,9 +1228,13 @@ void ggml_metal_graph_compute(
1074
  case GGML_OP_MUL_MAT:
1075
  {
1076
  GGML_ASSERT(ne00 == ne10);
1077
- GGML_ASSERT(ne03 == ne13);
1078
 
1079
- const unsigned int gqa = ne12/ne02;
 
 
 
 
 
1080
 
1081
  // find the break-even point where the matrix-matrix kernel becomes more efficient compared
1082
  // to the matrix-vector kernel
@@ -1111,7 +1269,7 @@ void ggml_metal_graph_compute(
1111
  !ggml_is_transposed(src1) &&
1112
  src1t == GGML_TYPE_F32 &&
1113
  ne00 % 32 == 0 && ne00 >= 64 &&
1114
- ne11 > ne11_mm_min) {
1115
  //printf("matrix: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
1116
  switch (src0->type) {
1117
  case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_mul_mm_f32_f32]; break;
@@ -1141,9 +1299,10 @@ void ggml_metal_graph_compute(
1141
  [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:10];
1142
  [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:11];
1143
  [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:12];
1144
- [encoder setBytes:&gqa length:sizeof(gqa) atIndex:13];
 
1145
  [encoder setThreadgroupMemoryLength:8192 atIndex:0];
1146
- [encoder dispatchThreadgroups:MTLSizeMake( (ne11 + 31)/32, (ne01 + 63)/64, ne12) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
1147
  } else {
1148
  int nth0 = 32;
1149
  int nth1 = 1;
@@ -1179,90 +1338,60 @@ void ggml_metal_graph_compute(
1179
  } break;
1180
  case GGML_TYPE_Q4_0:
1181
  {
1182
- GGML_ASSERT(ne02 == 1);
1183
- GGML_ASSERT(ne12 == 1);
1184
-
1185
  nth0 = 8;
1186
  nth1 = 8;
1187
  [encoder setComputePipelineState:ctx->pipeline_mul_mv_q4_0_f32];
1188
  } break;
1189
  case GGML_TYPE_Q4_1:
1190
  {
1191
- GGML_ASSERT(ne02 == 1);
1192
- GGML_ASSERT(ne12 == 1);
1193
-
1194
  nth0 = 8;
1195
  nth1 = 8;
1196
  [encoder setComputePipelineState:ctx->pipeline_mul_mv_q4_1_f32];
1197
  } break;
1198
  case GGML_TYPE_Q5_0:
1199
  {
1200
- GGML_ASSERT(ne02 == 1);
1201
- GGML_ASSERT(ne12 == 1);
1202
-
1203
  nth0 = 8;
1204
  nth1 = 8;
1205
  [encoder setComputePipelineState:ctx->pipeline_mul_mv_q5_0_f32];
1206
  } break;
1207
  case GGML_TYPE_Q5_1:
1208
  {
1209
- GGML_ASSERT(ne02 == 1);
1210
- GGML_ASSERT(ne12 == 1);
1211
-
1212
  nth0 = 8;
1213
  nth1 = 8;
1214
  [encoder setComputePipelineState:ctx->pipeline_mul_mv_q5_1_f32];
1215
  } break;
1216
  case GGML_TYPE_Q8_0:
1217
  {
1218
- GGML_ASSERT(ne02 == 1);
1219
- GGML_ASSERT(ne12 == 1);
1220
-
1221
  nth0 = 8;
1222
  nth1 = 8;
1223
  [encoder setComputePipelineState:ctx->pipeline_mul_mv_q8_0_f32];
1224
  } break;
1225
  case GGML_TYPE_Q2_K:
1226
  {
1227
- GGML_ASSERT(ne02 == 1);
1228
- GGML_ASSERT(ne12 == 1);
1229
-
1230
  nth0 = 2;
1231
  nth1 = 32;
1232
  [encoder setComputePipelineState:ctx->pipeline_mul_mv_q2_K_f32];
1233
  } break;
1234
  case GGML_TYPE_Q3_K:
1235
  {
1236
- GGML_ASSERT(ne02 == 1);
1237
- GGML_ASSERT(ne12 == 1);
1238
-
1239
  nth0 = 2;
1240
  nth1 = 32;
1241
  [encoder setComputePipelineState:ctx->pipeline_mul_mv_q3_K_f32];
1242
  } break;
1243
  case GGML_TYPE_Q4_K:
1244
  {
1245
- GGML_ASSERT(ne02 == 1);
1246
- GGML_ASSERT(ne12 == 1);
1247
-
1248
  nth0 = 4; //1;
1249
  nth1 = 8; //32;
1250
  [encoder setComputePipelineState:ctx->pipeline_mul_mv_q4_K_f32];
1251
  } break;
1252
  case GGML_TYPE_Q5_K:
1253
  {
1254
- GGML_ASSERT(ne02 == 1);
1255
- GGML_ASSERT(ne12 == 1);
1256
-
1257
  nth0 = 2;
1258
  nth1 = 32;
1259
  [encoder setComputePipelineState:ctx->pipeline_mul_mv_q5_K_f32];
1260
  } break;
1261
  case GGML_TYPE_Q6_K:
1262
  {
1263
- GGML_ASSERT(ne02 == 1);
1264
- GGML_ASSERT(ne12 == 1);
1265
-
1266
  nth0 = 2;
1267
  nth1 = 32;
1268
  [encoder setComputePipelineState:ctx->pipeline_mul_mv_q6_K_f32];
@@ -1291,32 +1420,125 @@ void ggml_metal_graph_compute(
1291
  [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:14];
1292
  [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:15];
1293
  [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:16];
1294
- [encoder setBytes:&gqa length:sizeof(gqa) atIndex:17];
 
1295
 
1296
  if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 ||
1297
  src0t == GGML_TYPE_Q5_0 || src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 ||
1298
  src0t == GGML_TYPE_Q2_K) { // || src0t == GGML_TYPE_Q4_K) {
1299
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1300
  }
1301
  else if (src0t == GGML_TYPE_Q4_K) {
1302
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1303
  }
1304
  else if (src0t == GGML_TYPE_Q3_K) {
1305
  #ifdef GGML_QKK_64
1306
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1307
  #else
1308
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1309
  #endif
1310
  }
1311
  else if (src0t == GGML_TYPE_Q5_K) {
1312
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1313
  }
1314
  else if (src0t == GGML_TYPE_Q6_K) {
1315
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1316
  } else {
1317
  int64_t ny = (ne11 + nrows - 1)/nrows;
1318
- [encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
 
 
 
 
 
 
 
 
 
 
 
 
 
1319
  }
 
 
 
1320
  }
1321
  } break;
1322
  case GGML_OP_GET_ROWS:
@@ -1355,15 +1577,19 @@ void ggml_metal_graph_compute(
1355
  float eps;
1356
  memcpy(&eps, dst->op_params, sizeof(float));
1357
 
1358
- const int nth = MIN(512, ne00);
 
 
 
 
1359
 
1360
  [encoder setComputePipelineState:ctx->pipeline_rms_norm];
1361
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1362
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1363
- [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
1364
- [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
1365
- [encoder setBytes:&eps length:sizeof( float) atIndex:4];
1366
- [encoder setThreadgroupMemoryLength:GGML_PAD(nth/32*sizeof(float), 16) atIndex:0];
1367
 
1368
  const int64_t nrows = ggml_nrows(src0);
1369
 
@@ -1437,7 +1663,8 @@ void ggml_metal_graph_compute(
1437
  const int n_past = ((int32_t *) dst->op_params)[0];
1438
  const int n_dims = ((int32_t *) dst->op_params)[1];
1439
  const int mode = ((int32_t *) dst->op_params)[2];
1440
- const int n_orig_ctx = ((int32_t *) dst->op_params)[3];
 
1441
 
1442
  float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
1443
  memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
@@ -1537,18 +1764,48 @@ void ggml_metal_graph_compute(
1537
 
1538
  [encoder dispatchThreadgroups:MTLSizeMake(IC, OH, OW) threadsPerThreadgroup:MTLSizeMake(N, KH, KW)];
1539
  } break;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1540
  case GGML_OP_DUP:
1541
  case GGML_OP_CPY:
1542
  case GGML_OP_CONT:
1543
  {
1544
- const int nth = MIN(1024, ne00);
 
 
1545
 
1546
  switch (src0t) {
1547
  case GGML_TYPE_F32:
1548
  {
 
 
1549
  switch (dstt) {
1550
- case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_f16]; break;
1551
- case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_f32]; break;
 
 
 
 
 
1552
  default: GGML_ASSERT(false && "not implemented");
1553
  };
1554
  } break;
@@ -1623,81 +1880,150 @@ void ggml_metal_graph_compute(
1623
 
1624
  // backend interface
1625
 
1626
- static const char * ggml_backend_metal_name(ggml_backend_t backend) {
1627
- return "Metal";
1628
 
1629
- UNUSED(backend);
 
 
 
 
 
 
 
1630
  }
1631
 
1632
- static void ggml_backend_metal_free(ggml_backend_t backend) {
1633
- struct ggml_metal_context * ctx = (struct ggml_metal_context *)backend->context;
1634
- ggml_metal_free(ctx);
1635
- free(backend);
 
 
 
 
 
1636
  }
1637
 
1638
  static void * ggml_backend_metal_buffer_get_base(ggml_backend_buffer_t buffer) {
1639
- return (void *)buffer->context;
 
 
1640
  }
1641
 
1642
  static void ggml_backend_metal_buffer_free_buffer(ggml_backend_buffer_t buffer) {
1643
- free(buffer->context);
 
 
 
 
 
 
 
 
1644
  UNUSED(buffer);
1645
  }
1646
 
1647
  static struct ggml_backend_buffer_i metal_backend_buffer_i = {
1648
- /* .free_buffer = */ ggml_backend_metal_buffer_free_buffer,
1649
- /* .get_base = */ ggml_backend_metal_buffer_get_base,
1650
- /* .get_alloc_size = */ NULL, // defaults to ggml_nbytes
1651
- /* .init_tensor = */ NULL, // no initialization required
1652
- /* .free_tensor = */ NULL, // no cleanup required
 
 
1653
  };
1654
 
1655
- static ggml_backend_buffer_t ggml_backend_metal_alloc_buffer(ggml_backend_t backend, size_t size) {
1656
- struct ggml_metal_context * ctx = (struct ggml_metal_context *)backend->context;
1657
 
1658
- void * data = ggml_metal_host_malloc(size);
1659
 
1660
- // TODO: set proper name of the buffers
1661
- ggml_metal_add_buffer(ctx, "backend", data, size, 0);
 
 
 
 
 
 
 
 
1662
 
1663
- return ggml_backend_buffer_init(backend, metal_backend_buffer_i, data, size);
1664
  }
1665
 
1666
- static size_t ggml_backend_metal_get_alignment(ggml_backend_t backend) {
1667
  return 32;
1668
- UNUSED(backend);
1669
  }
1670
 
1671
- static void ggml_backend_metal_set_tensor_async(ggml_backend_t backend, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
1672
- GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor write out of bounds");
1673
- GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
1674
-
1675
- memcpy((char *)tensor->data + offset, data, size);
1676
 
1677
- UNUSED(backend);
1678
  }
1679
 
1680
- static void ggml_backend_metal_get_tensor_async(ggml_backend_t backend, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) {
1681
- GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor read out of bounds");
1682
- GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
1683
-
1684
- memcpy(data, (const char *)tensor->data + offset, size);
 
 
 
 
 
1685
 
1686
- UNUSED(backend);
1687
  }
1688
 
1689
- static void ggml_backend_metal_synchronize(ggml_backend_t backend) {
 
 
1690
  UNUSED(backend);
1691
  }
1692
 
1693
- static void ggml_backend_metal_cpy_tensor_from(ggml_backend_t backend, struct ggml_tensor * src, struct ggml_tensor * dst) {
1694
- ggml_backend_tensor_get(src, dst->data, 0, ggml_nbytes(src));
 
 
 
1695
 
 
1696
  UNUSED(backend);
1697
  }
1698
 
1699
- static void ggml_backend_metal_cpy_tensor_to(ggml_backend_t backend, struct ggml_tensor * src, struct ggml_tensor * dst) {
1700
- ggml_backend_tensor_set_async(dst, src->data, 0, ggml_nbytes(src));
1701
 
1702
  UNUSED(backend);
1703
  }
@@ -1709,32 +2035,43 @@ static void ggml_backend_metal_graph_compute(ggml_backend_t backend, struct ggml
1709
  }
1710
 
1711
  static bool ggml_backend_metal_supports_op(ggml_backend_t backend, const struct ggml_tensor * op) {
1712
- return true;
 
1713
  UNUSED(backend);
1714
- UNUSED(op);
1715
  }
1716
 
1717
  static struct ggml_backend_i metal_backend_i = {
1718
- /* .get_name = */ ggml_backend_metal_name,
1719
- /* .free = */ ggml_backend_metal_free,
1720
- /* .alloc_buffer = */ ggml_backend_metal_alloc_buffer,
1721
- /* .get_alignment = */ ggml_backend_metal_get_alignment,
1722
- /* .set_tensor_async = */ ggml_backend_metal_set_tensor_async,
1723
- /* .get_tensor_async = */ ggml_backend_metal_get_tensor_async,
1724
- /* .synchronize = */ ggml_backend_metal_synchronize,
1725
- /* .cpy_tensor_from = */ ggml_backend_metal_cpy_tensor_from,
1726
- /* .cpy_tensor_to = */ ggml_backend_metal_cpy_tensor_to,
1727
- /* .graph_plan_create = */ NULL, // the metal implementation does not require creating graph plans atm
1728
- /* .graph_plan_free = */ NULL,
1729
- /* .graph_plan_compute = */ NULL,
1730
- /* .graph_compute = */ ggml_backend_metal_graph_compute,
1731
- /* .supports_op = */ ggml_backend_metal_supports_op,
1732
  };
1733
 
 
 
 
 
 
 
 
 
1734
  ggml_backend_t ggml_backend_metal_init(void) {
1735
- struct ggml_metal_context * ctx = malloc(sizeof(struct ggml_metal_context));
1736
 
1737
- ctx = ggml_metal_init(GGML_DEFAULT_N_THREADS);
 
 
 
 
1738
 
1739
  ggml_backend_t metal_backend = malloc(sizeof(struct ggml_backend));
1740
 
@@ -1751,13 +2088,26 @@ bool ggml_backend_is_metal(ggml_backend_t backend) {
1751
  }
1752
 
1753
  void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {
 
 
1754
  struct ggml_metal_context * ctx = (struct ggml_metal_context *)backend->context;
1755
 
1756
  ggml_metal_set_n_cb(ctx, n_cb);
1757
  }
1758
 
1759
  bool ggml_backend_metal_supports_family(ggml_backend_t backend, int family) {
 
 
1760
  struct ggml_metal_context * ctx = (struct ggml_metal_context *)backend->context;
1761
 
1762
- return ggml_metal_supports_family(ctx, family);
 
 
 
 
 
 
 
 
 
1763
  }
 
62
  GGML_METAL_DECL_KERNEL(add_row); // TODO: avoid this extra kernel, instead extend the "add" kernel to support broadcast
63
  GGML_METAL_DECL_KERNEL(mul);
64
  GGML_METAL_DECL_KERNEL(mul_row); // TODO: avoid this extra kernel, instead extend the "mul" kernel to support broadcast
65
+ GGML_METAL_DECL_KERNEL(div);
66
+ GGML_METAL_DECL_KERNEL(div_row);
67
  GGML_METAL_DECL_KERNEL(scale);
68
  GGML_METAL_DECL_KERNEL(scale_4);
69
  GGML_METAL_DECL_KERNEL(silu);
 
114
  GGML_METAL_DECL_KERNEL(mul_mm_q4_K_f32);
115
  GGML_METAL_DECL_KERNEL(mul_mm_q5_K_f32);
116
  GGML_METAL_DECL_KERNEL(mul_mm_q6_K_f32);
117
+ GGML_METAL_DECL_KERNEL(mul_mm_id_f32_f32);
118
+ GGML_METAL_DECL_KERNEL(mul_mm_id_f16_f32);
119
+ GGML_METAL_DECL_KERNEL(mul_mm_id_q4_0_f32);
120
+ GGML_METAL_DECL_KERNEL(mul_mm_id_q4_1_f32);
121
+ GGML_METAL_DECL_KERNEL(mul_mm_id_q5_0_f32);
122
+ GGML_METAL_DECL_KERNEL(mul_mm_id_q5_1_f32);
123
+ GGML_METAL_DECL_KERNEL(mul_mm_id_q8_0_f32);
124
+ GGML_METAL_DECL_KERNEL(mul_mm_id_q2_K_f32);
125
+ GGML_METAL_DECL_KERNEL(mul_mm_id_q3_K_f32);
126
+ GGML_METAL_DECL_KERNEL(mul_mm_id_q4_K_f32);
127
+ GGML_METAL_DECL_KERNEL(mul_mm_id_q5_K_f32);
128
+ GGML_METAL_DECL_KERNEL(mul_mm_id_q6_K_f32);
129
  GGML_METAL_DECL_KERNEL(rope_f32);
130
  GGML_METAL_DECL_KERNEL(rope_f16);
131
  GGML_METAL_DECL_KERNEL(alibi_f32);
132
  GGML_METAL_DECL_KERNEL(im2col_f16);
133
+ GGML_METAL_DECL_KERNEL(argsort_f32_i32_asc);
134
+ GGML_METAL_DECL_KERNEL(argsort_f32_i32_desc);
135
  GGML_METAL_DECL_KERNEL(cpy_f32_f16);
136
  GGML_METAL_DECL_KERNEL(cpy_f32_f32);
137
+ GGML_METAL_DECL_KERNEL(cpy_f32_q8_0);
138
+ GGML_METAL_DECL_KERNEL(cpy_f32_q4_0);
139
+ GGML_METAL_DECL_KERNEL(cpy_f32_q4_1);
140
+ //GGML_METAL_DECL_KERNEL(cpy_f32_q5_0);
141
+ //GGML_METAL_DECL_KERNEL(cpy_f32_q5_1);
142
  GGML_METAL_DECL_KERNEL(cpy_f16_f16);
143
  GGML_METAL_DECL_KERNEL(concat);
144
  GGML_METAL_DECL_KERNEL(sqr);
145
+ GGML_METAL_DECL_KERNEL(sum_rows);
146
 
147
  #undef GGML_METAL_DECL_KERNEL
148
  };
 
186
  }
187
  }
188
 
 
 
189
  struct ggml_metal_context * ggml_metal_init(int n_cb) {
190
  GGML_METAL_LOG_INFO("%s: allocating\n", __func__);
191
 
192
+ id<MTLDevice> device;
193
  NSString * s;
194
 
195
  #if TARGET_OS_OSX
 
235
 
236
  NSString * sourcePath;
237
  NSString * ggmlMetalPathResources = [[NSProcessInfo processInfo].environment objectForKey:@"GGML_METAL_PATH_RESOURCES"];
238
+
239
+ GGML_METAL_LOG_INFO("%s: GGML_METAL_PATH_RESOURCES = %s\n", __func__, ggmlMetalPathResources ? [ggmlMetalPathResources UTF8String] : "nil");
240
+
241
  if (ggmlMetalPathResources) {
242
  sourcePath = [ggmlMetalPathResources stringByAppendingPathComponent:@"ggml-metal.metal"];
243
  } else {
 
268
  }
269
  }
270
 
271
+ #if TARGET_OS_OSX
272
+ // print MTL GPU family:
273
+ GGML_METAL_LOG_INFO("%s: GPU name: %s\n", __func__, [[ctx->device name] UTF8String]);
274
+
275
+ // determine max supported GPU family
276
+ // https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf
277
+ // https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf
278
+ for (int i = MTLGPUFamilyApple1 + 20; i >= MTLGPUFamilyApple1; --i) {
279
+ if ([ctx->device supportsFamily:i]) {
280
+ GGML_METAL_LOG_INFO("%s: GPU family: MTLGPUFamilyApple%d (%d)\n", __func__, i - (int) MTLGPUFamilyApple1 + 1, i);
281
+ break;
282
+ }
283
+ }
284
+
285
+ GGML_METAL_LOG_INFO("%s: hasUnifiedMemory = %s\n", __func__, ctx->device.hasUnifiedMemory ? "true" : "false");
286
+ GGML_METAL_LOG_INFO("%s: recommendedMaxWorkingSetSize = %8.2f MB\n", __func__, ctx->device.recommendedMaxWorkingSetSize / 1e6);
287
+ if (ctx->device.maxTransferRate != 0) {
288
+ GGML_METAL_LOG_INFO("%s: maxTransferRate = %8.2f MB/s\n", __func__, ctx->device.maxTransferRate / 1e6);
289
+ } else {
290
+ GGML_METAL_LOG_INFO("%s: maxTransferRate = built-in GPU\n", __func__);
291
+ }
292
+ #endif
293
+
294
  // load kernels
295
  {
296
  NSError * error = nil;
 
312
  GGML_METAL_ADD_KERNEL(add_row);
313
  GGML_METAL_ADD_KERNEL(mul);
314
  GGML_METAL_ADD_KERNEL(mul_row);
315
+ GGML_METAL_ADD_KERNEL(div);
316
+ GGML_METAL_ADD_KERNEL(div_row);
317
  GGML_METAL_ADD_KERNEL(scale);
318
  GGML_METAL_ADD_KERNEL(scale_4);
319
  GGML_METAL_ADD_KERNEL(silu);
 
365
  GGML_METAL_ADD_KERNEL(mul_mm_q4_K_f32);
366
  GGML_METAL_ADD_KERNEL(mul_mm_q5_K_f32);
367
  GGML_METAL_ADD_KERNEL(mul_mm_q6_K_f32);
368
+ GGML_METAL_ADD_KERNEL(mul_mm_id_f32_f32);
369
+ GGML_METAL_ADD_KERNEL(mul_mm_id_f16_f32);
370
+ GGML_METAL_ADD_KERNEL(mul_mm_id_q4_0_f32);
371
+ GGML_METAL_ADD_KERNEL(mul_mm_id_q4_1_f32);
372
+ GGML_METAL_ADD_KERNEL(mul_mm_id_q5_0_f32);
373
+ GGML_METAL_ADD_KERNEL(mul_mm_id_q5_1_f32);
374
+ GGML_METAL_ADD_KERNEL(mul_mm_id_q8_0_f32);
375
+ GGML_METAL_ADD_KERNEL(mul_mm_id_q2_K_f32);
376
+ GGML_METAL_ADD_KERNEL(mul_mm_id_q3_K_f32);
377
+ GGML_METAL_ADD_KERNEL(mul_mm_id_q4_K_f32);
378
+ GGML_METAL_ADD_KERNEL(mul_mm_id_q5_K_f32);
379
+ GGML_METAL_ADD_KERNEL(mul_mm_id_q6_K_f32);
380
  }
381
  GGML_METAL_ADD_KERNEL(rope_f32);
382
  GGML_METAL_ADD_KERNEL(rope_f16);
383
  GGML_METAL_ADD_KERNEL(alibi_f32);
384
  GGML_METAL_ADD_KERNEL(im2col_f16);
385
+ GGML_METAL_ADD_KERNEL(argsort_f32_i32_asc);
386
+ GGML_METAL_ADD_KERNEL(argsort_f32_i32_desc);
387
  GGML_METAL_ADD_KERNEL(cpy_f32_f16);
388
  GGML_METAL_ADD_KERNEL(cpy_f32_f32);
389
+ GGML_METAL_ADD_KERNEL(cpy_f32_q8_0);
390
+ GGML_METAL_ADD_KERNEL(cpy_f32_q4_0);
391
+ GGML_METAL_ADD_KERNEL(cpy_f32_q4_1);
392
+ //GGML_METAL_ADD_KERNEL(cpy_f32_q5_0);
393
+ //GGML_METAL_ADD_KERNEL(cpy_f32_q5_1);
394
  GGML_METAL_ADD_KERNEL(cpy_f16_f16);
395
  GGML_METAL_ADD_KERNEL(concat);
396
  GGML_METAL_ADD_KERNEL(sqr);
397
+ GGML_METAL_ADD_KERNEL(sum_rows);
398
 
399
  #undef GGML_METAL_ADD_KERNEL
400
  }
401
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
402
  return ctx;
403
  }
404
 
 
412
  GGML_METAL_DEL_KERNEL(add_row);
413
  GGML_METAL_DEL_KERNEL(mul);
414
  GGML_METAL_DEL_KERNEL(mul_row);
415
+ GGML_METAL_DEL_KERNEL(div);
416
+ GGML_METAL_DEL_KERNEL(div_row);
417
  GGML_METAL_DEL_KERNEL(scale);
418
  GGML_METAL_DEL_KERNEL(scale_4);
419
  GGML_METAL_DEL_KERNEL(silu);
 
465
  GGML_METAL_DEL_KERNEL(mul_mm_q4_K_f32);
466
  GGML_METAL_DEL_KERNEL(mul_mm_q5_K_f32);
467
  GGML_METAL_DEL_KERNEL(mul_mm_q6_K_f32);
468
+ GGML_METAL_DEL_KERNEL(mul_mm_id_f32_f32);
469
+ GGML_METAL_DEL_KERNEL(mul_mm_id_f16_f32);
470
+ GGML_METAL_DEL_KERNEL(mul_mm_id_q4_0_f32);
471
+ GGML_METAL_DEL_KERNEL(mul_mm_id_q4_1_f32);
472
+ GGML_METAL_DEL_KERNEL(mul_mm_id_q5_0_f32);
473
+ GGML_METAL_DEL_KERNEL(mul_mm_id_q5_1_f32);
474
+ GGML_METAL_DEL_KERNEL(mul_mm_id_q8_0_f32);
475
+ GGML_METAL_DEL_KERNEL(mul_mm_id_q2_K_f32);
476
+ GGML_METAL_DEL_KERNEL(mul_mm_id_q3_K_f32);
477
+ GGML_METAL_DEL_KERNEL(mul_mm_id_q4_K_f32);
478
+ GGML_METAL_DEL_KERNEL(mul_mm_id_q5_K_f32);
479
+ GGML_METAL_DEL_KERNEL(mul_mm_id_q6_K_f32);
480
  }
481
  GGML_METAL_DEL_KERNEL(rope_f32);
482
  GGML_METAL_DEL_KERNEL(rope_f16);
483
  GGML_METAL_DEL_KERNEL(alibi_f32);
484
  GGML_METAL_DEL_KERNEL(im2col_f16);
485
+ GGML_METAL_DEL_KERNEL(argsort_f32_i32_asc);
486
+ GGML_METAL_DEL_KERNEL(argsort_f32_i32_desc);
487
  GGML_METAL_DEL_KERNEL(cpy_f32_f16);
488
  GGML_METAL_DEL_KERNEL(cpy_f32_f32);
489
+ GGML_METAL_DEL_KERNEL(cpy_f32_q8_0);
490
+ GGML_METAL_DEL_KERNEL(cpy_f32_q4_0);
491
+ GGML_METAL_DEL_KERNEL(cpy_f32_q4_1);
492
+ //GGML_METAL_DEL_KERNEL(cpy_f32_q5_0);
493
+ //GGML_METAL_DEL_KERNEL(cpy_f32_q5_1);
494
  GGML_METAL_DEL_KERNEL(cpy_f16_f16);
495
  GGML_METAL_DEL_KERNEL(concat);
496
  GGML_METAL_DEL_KERNEL(sqr);
497
+ GGML_METAL_DEL_KERNEL(sum_rows);
498
 
499
  #undef GGML_METAL_DEL_KERNEL
500
 
 
526
  free(data);
527
  }
528
 
 
 
 
 
529
  void ggml_metal_set_n_cb(struct ggml_metal_context * ctx, int n_cb) {
530
  ctx->n_cb = MIN(n_cb, GGML_METAL_MAX_BUFFERS);
531
  }
 
538
  return ctx->concur_list;
539
  }
540
 
541
+ // temporarily defined here for compatibility between ggml-backend and the old API
542
+ struct ggml_backend_metal_buffer_context {
543
+ void * data;
544
+
545
+ id<MTLBuffer> metal;
546
+ };
547
+
548
  // finds the Metal buffer that contains the tensor data on the GPU device
549
  // the assumption is that there is 1-to-1 mapping between the host and device memory buffers, so we can find the
550
  // Metal buffer based on the host memory pointer
 
554
 
555
  const int64_t tsize = ggml_nbytes(t);
556
 
557
+ // compatibility with ggml-backend
558
+ if (t->buffer && t->buffer->buft == ggml_backend_metal_buffer_type()) {
559
+ struct ggml_backend_metal_buffer_context * buf_ctx = (struct ggml_backend_metal_buffer_context *) t->buffer->context;
560
+
561
+ const int64_t ioffs = (int64_t) t->data - (int64_t) buf_ctx->data;
562
+
563
+ GGML_ASSERT(ioffs >= 0 && ioffs + tsize <= (int64_t) t->buffer->size);
564
+
565
+ *offs = (size_t) ioffs;
566
+
567
+ return buf_ctx->metal;
568
  }
569
 
570
  // find the view that contains the tensor fully
 
624
  ctx->buffers[ctx->n_buffers].metal = [ctx->device newBufferWithBytesNoCopy:data length:size_aligned options:MTLResourceStorageModeShared deallocator:nil];
625
 
626
  if (ctx->buffers[ctx->n_buffers].metal == nil) {
627
+ GGML_METAL_LOG_ERROR("%s: error: failed to allocate '%-16s' buffer, size = %8.2f MiB\n", __func__, name, size_aligned / 1024.0 / 1024.0);
628
  return false;
629
  }
630
 
631
+ GGML_METAL_LOG_INFO("%s: allocated '%-16s' buffer, size = %8.2f MiB", __func__, name, size_aligned / 1024.0 / 1024.0);
632
 
633
  ++ctx->n_buffers;
634
  } else {
 
648
  ctx->buffers[ctx->n_buffers].metal = [ctx->device newBufferWithBytesNoCopy:(void *) ((uint8_t *) data + i) length:size_step_aligned options:MTLResourceStorageModeShared deallocator:nil];
649
 
650
  if (ctx->buffers[ctx->n_buffers].metal == nil) {
651
+ GGML_METAL_LOG_ERROR("%s: error: failed to allocate '%-16s' buffer, size = %8.2f MiB\n", __func__, name, size_step_aligned / 1024.0 / 1024.0);
652
  return false;
653
  }
654
 
655
+ GGML_METAL_LOG_INFO("%s: allocated '%-16s' buffer, size = %8.2f MiB, offs = %12ld", __func__, name, size_step_aligned / 1024.0 / 1024.0, i);
656
  if (i + size_step < size) {
657
  GGML_METAL_LOG_INFO("\n");
658
  }
 
663
 
664
  #if TARGET_OS_OSX
665
  GGML_METAL_LOG_INFO(", (%8.2f / %8.2f)",
666
+ ctx->device.currentAllocatedSize / 1024.0 / 1024.0,
667
+ ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
668
 
669
  if (ctx->device.currentAllocatedSize > ctx->device.recommendedMaxWorkingSetSize) {
670
  GGML_METAL_LOG_WARN("%s: warning: current allocated size is greater than the recommended max working set size\n", __func__);
 
672
  GGML_METAL_LOG_INFO("\n");
673
  }
674
  #else
675
+ GGML_METAL_LOG_INFO(", (%8.2f)\n", ctx->device.currentAllocatedSize / 1024.0 / 1024.0);
676
  #endif
677
  }
678
 
 
789
  }
790
  }
791
 
792
+ static bool ggml_metal_supports_op(const struct ggml_tensor * op) {
793
+ switch (op->op) {
794
+ case GGML_OP_UNARY:
795
+ switch (ggml_get_unary_op(op)) {
796
+ case GGML_UNARY_OP_SILU:
797
+ case GGML_UNARY_OP_RELU:
798
+ case GGML_UNARY_OP_GELU:
799
+ return true;
800
+ default:
801
+ return false;
802
+ }
803
+ case GGML_OP_NONE:
804
+ case GGML_OP_RESHAPE:
805
+ case GGML_OP_VIEW:
806
+ case GGML_OP_TRANSPOSE:
807
+ case GGML_OP_PERMUTE:
808
+ case GGML_OP_CONCAT:
809
+ case GGML_OP_ADD:
810
+ case GGML_OP_MUL:
811
+ case GGML_OP_DIV:
812
+ case GGML_OP_SCALE:
813
+ case GGML_OP_SQR:
814
+ case GGML_OP_SUM_ROWS:
815
+ case GGML_OP_SOFT_MAX:
816
+ case GGML_OP_RMS_NORM:
817
+ case GGML_OP_NORM:
818
+ case GGML_OP_ALIBI:
819
+ case GGML_OP_ROPE:
820
+ case GGML_OP_IM2COL:
821
+ case GGML_OP_ARGSORT:
822
+ case GGML_OP_DUP:
823
+ case GGML_OP_CPY:
824
+ case GGML_OP_CONT:
825
+ case GGML_OP_MUL_MAT:
826
+ case GGML_OP_MUL_MAT_ID:
827
+ return true;
828
+ case GGML_OP_DIAG_MASK_INF:
829
+ case GGML_OP_GET_ROWS:
830
+ {
831
+ return op->ne[0] % 4 == 0;
832
+ }
833
+ default:
834
+ return false;
835
+ }
836
+ }
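For context (not part of this diff): the per-op capability check above is what lets a host decide which graph nodes to offload to Metal and which to keep on the CPU. A minimal C sketch, assuming the public wrapper ggml_backend_supports_op() from ggml-backend.h:

    #include "ggml.h"
    #include "ggml-backend.h"

    // illustrative only: count how many nodes of a graph the backend claims to support
    // (ggml_backend_supports_op() is assumed to forward to the backend's .supports_op member)
    static int count_supported_nodes(ggml_backend_t backend, struct ggml_cgraph * gf) {
        int n_supported = 0;
        for (int i = 0; i < gf->n_nodes; ++i) {
            if (ggml_backend_supports_op(backend, gf->nodes[i])) {
                n_supported++;
            }
        }
        return n_supported;
    }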
837
  void ggml_metal_graph_compute(
838
  struct ggml_metal_context * ctx,
839
  struct ggml_cgraph * gf) {
 
904
  } break;
905
  }
906
 
907
+ GGML_ASSERT(ggml_metal_supports_op(dst));
908
+
909
  const int64_t ne00 = src0 ? src0->ne[0] : 0;
910
  const int64_t ne01 = src0 ? src0->ne[1] : 0;
911
  const int64_t ne02 = src0 ? src0->ne[2] : 0;
 
998
  [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
999
  } break;
1000
  case GGML_OP_ADD:
1001
+ case GGML_OP_MUL:
1002
+ case GGML_OP_DIV:
1003
  {
1004
  GGML_ASSERT(ggml_is_contiguous(src0));
1005
  GGML_ASSERT(ggml_is_contiguous(src1));
 
1013
  GGML_ASSERT(ne11 == 1);
1014
 
1015
  nb = ne00 / 4;
1016
+ switch (dst->op) {
1017
+ case GGML_OP_ADD: [encoder setComputePipelineState:ctx->pipeline_add_row]; break;
1018
+ case GGML_OP_MUL: [encoder setComputePipelineState:ctx->pipeline_mul_row]; break;
1019
+ case GGML_OP_DIV: [encoder setComputePipelineState:ctx->pipeline_div_row]; break;
1020
+ default: GGML_ASSERT(false);
1021
+ }
1022
 
1023
  bcast_row = true;
1024
  } else {
1025
+ switch (dst->op) {
1026
+ case GGML_OP_ADD: [encoder setComputePipelineState:ctx->pipeline_add]; break;
1027
+ case GGML_OP_MUL: [encoder setComputePipelineState:ctx->pipeline_mul]; break;
1028
+ case GGML_OP_DIV: [encoder setComputePipelineState:ctx->pipeline_div]; break;
1029
+ default: GGML_ASSERT(false);
1030
+ }
1031
  }
1032
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1033
  [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
 
1068
  [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
1069
  }
1070
  } break;
 
 
 
1071
  case GGML_OP_SCALE:
1072
  {
1073
  GGML_ASSERT(ggml_is_contiguous(src0));
 
1140
  const int64_t n = ggml_nelements(dst);
1141
  [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1142
  } break;
1143
+ case GGML_OP_SUM_ROWS:
1144
+ {
1145
+ GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type));
1146
+
1147
+ [encoder setComputePipelineState:ctx->pipeline_sum_rows];
1148
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1149
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1150
+ [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
1151
+ [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
1152
+ [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
1153
+ [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
1154
+ [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
1155
+ [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
1156
+ [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
1157
+ [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
1158
+ [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:10];
1159
+ [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:11];
1160
+ [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:12];
1161
+ [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:13];
1162
+ [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:14];
1163
+ [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:15];
1164
+ [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:16];
1165
+ [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:17];
1166
+ [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:18];
1167
+ [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:19];
1168
+ [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:20];
1169
+ [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:21];
1170
+ [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:22];
1171
+ [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:23];
1172
+ [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:24];
1173
+ [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:25];
1174
+
1175
+ [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1176
+ } break;
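A plain-C reference (illustrative, not part of the commit) of what the SUM_ROWS dispatch above computes for contiguous f32 data: one output value per row of src0, with dst having ne0 == 1.

    #include <stdint.h>

    // dst[i1,i2,i3] = sum over i0 of src[i0,i1,i2,i3]
    static void sum_rows_f32_ref(const float * src, float * dst,
                                 int64_t ne00, int64_t ne01, int64_t ne02, int64_t ne03) {
        for (int64_t i3 = 0; i3 < ne03; ++i3) {
            for (int64_t i2 = 0; i2 < ne02; ++i2) {
                for (int64_t i1 = 0; i1 < ne01; ++i1) {
                    const float * row = src + ((i3*ne02 + i2)*ne01 + i1)*ne00;
                    float sum = 0.0f;
                    for (int64_t i0 = 0; i0 < ne00; ++i0) {
                        sum += row[i0];
                    }
                    dst[(i3*ne02 + i2)*ne01 + i1] = sum;
                }
            }
        }
    }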
1177
  case GGML_OP_SOFT_MAX:
1178
  {
1179
  int nth = 32; // SIMD width
1180
 
1181
  if (ne00%4 == 0) {
1182
+ while (nth < ne00/4 && nth < 256) {
1183
+ nth *= 2;
1184
+ }
1185
  [encoder setComputePipelineState:ctx->pipeline_soft_max_4];
1186
  } else {
1187
+ while (nth < ne00 && nth < 1024) {
1188
  nth *= 2;
1189
+ }
 
1190
  [encoder setComputePipelineState:ctx->pipeline_soft_max];
1191
  }
1192
+
1193
+ const float scale = ((float *) dst->op_params)[0];
1194
+
1195
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1196
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
1197
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
1198
+ [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
1199
+ [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
1200
+ [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
1201
+ [encoder setBytes:&scale length:sizeof(scale) atIndex:6];
1202
+ [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
1203
 
1204
  [encoder dispatchThreadgroups:MTLSizeMake(ne01*ne02*ne03, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
1205
  } break;
 
1228
  case GGML_OP_MUL_MAT:
1229
  {
1230
  GGML_ASSERT(ne00 == ne10);
 
1231
 
1232
+ // TODO: assert that dim2 and dim3 are contiguous
1233
+ GGML_ASSERT(ne12 % ne02 == 0);
1234
+ GGML_ASSERT(ne13 % ne03 == 0);
1235
+
1236
+ const uint r2 = ne12/ne02;
1237
+ const uint r3 = ne13/ne03;
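The ratios r2 and r3 describe how many src1/dst batches share a single src0 matrix along dims 2 and 3 (the broadcast used e.g. for grouped-query attention). A hedged sketch of the index mapping the kernels are expected to apply (illustrative, not taken from the kernel source):

    #include <stdint.h>

    // map a src1/dst batch index (i12, i13) to the src0 batch that provides the weights,
    // given r2 = ne12/ne02 and r3 = ne13/ne03 (both asserted above to divide evenly)
    static void broadcast_src0_batch(int64_t i12, int64_t i13,
                                     uint32_t r2, uint32_t r3,
                                     int64_t * i02, int64_t * i03) {
        *i02 = i12 / r2; // r2 consecutive src1 batches reuse the same src0 slice
        *i03 = i13 / r3;
    }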
1238
 
1239
  // find the break-even point where the matrix-matrix kernel becomes more efficient compared
1240
  // to the matrix-vector kernel
 
1269
  !ggml_is_transposed(src1) &&
1270
  src1t == GGML_TYPE_F32 &&
1271
  ne00 % 32 == 0 && ne00 >= 64 &&
1272
+ (ne11 > ne11_mm_min || (ggml_is_quantized(src0t) && ne12 > 1))) {
1273
  //printf("matrix: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
1274
  switch (src0->type) {
1275
  case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_mul_mm_f32_f32]; break;
 
1299
  [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:10];
1300
  [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:11];
1301
  [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:12];
1302
+ [encoder setBytes:&r2 length:sizeof(r2) atIndex:13];
1303
+ [encoder setBytes:&r3 length:sizeof(r3) atIndex:14];
1304
  [encoder setThreadgroupMemoryLength:8192 atIndex:0];
1305
+ [encoder dispatchThreadgroups:MTLSizeMake( (ne11 + 31)/32, (ne01 + 63)/64, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
1306
  } else {
1307
  int nth0 = 32;
1308
  int nth1 = 1;
 
1338
  } break;
1339
  case GGML_TYPE_Q4_0:
1340
  {
 
 
 
1341
  nth0 = 8;
1342
  nth1 = 8;
1343
  [encoder setComputePipelineState:ctx->pipeline_mul_mv_q4_0_f32];
1344
  } break;
1345
  case GGML_TYPE_Q4_1:
1346
  {
 
 
 
1347
  nth0 = 8;
1348
  nth1 = 8;
1349
  [encoder setComputePipelineState:ctx->pipeline_mul_mv_q4_1_f32];
1350
  } break;
1351
  case GGML_TYPE_Q5_0:
1352
  {
 
 
 
1353
  nth0 = 8;
1354
  nth1 = 8;
1355
  [encoder setComputePipelineState:ctx->pipeline_mul_mv_q5_0_f32];
1356
  } break;
1357
  case GGML_TYPE_Q5_1:
1358
  {
 
 
 
1359
  nth0 = 8;
1360
  nth1 = 8;
1361
  [encoder setComputePipelineState:ctx->pipeline_mul_mv_q5_1_f32];
1362
  } break;
1363
  case GGML_TYPE_Q8_0:
1364
  {
 
 
 
1365
  nth0 = 8;
1366
  nth1 = 8;
1367
  [encoder setComputePipelineState:ctx->pipeline_mul_mv_q8_0_f32];
1368
  } break;
1369
  case GGML_TYPE_Q2_K:
1370
  {
 
 
 
1371
  nth0 = 2;
1372
  nth1 = 32;
1373
  [encoder setComputePipelineState:ctx->pipeline_mul_mv_q2_K_f32];
1374
  } break;
1375
  case GGML_TYPE_Q3_K:
1376
  {
 
 
 
1377
  nth0 = 2;
1378
  nth1 = 32;
1379
  [encoder setComputePipelineState:ctx->pipeline_mul_mv_q3_K_f32];
1380
  } break;
1381
  case GGML_TYPE_Q4_K:
1382
  {
 
 
 
1383
  nth0 = 4; //1;
1384
  nth1 = 8; //32;
1385
  [encoder setComputePipelineState:ctx->pipeline_mul_mv_q4_K_f32];
1386
  } break;
1387
  case GGML_TYPE_Q5_K:
1388
  {
 
 
 
1389
  nth0 = 2;
1390
  nth1 = 32;
1391
  [encoder setComputePipelineState:ctx->pipeline_mul_mv_q5_K_f32];
1392
  } break;
1393
  case GGML_TYPE_Q6_K:
1394
  {
 
 
 
1395
  nth0 = 2;
1396
  nth1 = 32;
1397
  [encoder setComputePipelineState:ctx->pipeline_mul_mv_q6_K_f32];
 
1420
  [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:14];
1421
  [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:15];
1422
  [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:16];
1423
+ [encoder setBytes:&r2 length:sizeof(r2) atIndex:17];
1424
+ [encoder setBytes:&r3 length:sizeof(r3) atIndex:18];
1425
 
1426
  if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 ||
1427
  src0t == GGML_TYPE_Q5_0 || src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 ||
1428
  src0t == GGML_TYPE_Q2_K) { // || src0t == GGML_TYPE_Q4_K) {
1429
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1430
  }
1431
  else if (src0t == GGML_TYPE_Q4_K) {
1432
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1433
  }
1434
  else if (src0t == GGML_TYPE_Q3_K) {
1435
  #ifdef GGML_QKK_64
1436
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1437
  #else
1438
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1439
  #endif
1440
  }
1441
  else if (src0t == GGML_TYPE_Q5_K) {
1442
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1443
  }
1444
  else if (src0t == GGML_TYPE_Q6_K) {
1445
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1446
  } else {
1447
  int64_t ny = (ne11 + nrows - 1)/nrows;
1448
+ [encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1449
+ }
1450
+ }
1451
+ } break;
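All of the matrix-vector dispatches above follow one pattern: the grid's x dimension covers the ne01 output rows, with a per-quantization number of rows handled by each threadgroup (8 for the 4/5/8-bit and Q2_K kernels, 4 for Q4_K/Q5_K and regular Q3_K, 2 for Q6_K and the QKK_64 build of Q3_K). The ceiling division as a C one-liner (illustrative only):

    #include <stdint.h>

    // threadgroups along x so that rows_per_tg rows per group cover all ne01 rows
    static int64_t mv_threadgroups_x(int64_t ne01, int rows_per_tg) {
        return (ne01 + rows_per_tg - 1) / rows_per_tg; // e.g. (ne01 + 7)/8 for Q4_0
    }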
1452
+ case GGML_OP_MUL_MAT_ID:
1453
+ {
1454
+ //GGML_ASSERT(ne00 == ne10);
1455
+ //GGML_ASSERT(ne03 == ne13);
1456
+
1457
+ GGML_ASSERT(src0t == GGML_TYPE_I32);
1458
+
1459
+ const int n_as = ne00;
1460
+
1461
+ // TODO: make this more general
1462
+ GGML_ASSERT(n_as <= 8);
1463
+
1464
+ struct ggml_tensor * src2 = gf->nodes[i]->src[2];
1465
+
1466
+ const int64_t ne20 = src2 ? src2->ne[0] : 0;
1467
+ const int64_t ne21 = src2 ? src2->ne[1] : 0;
1468
+ const int64_t ne22 = src2 ? src2->ne[2] : 0;
1469
+ const int64_t ne23 = src2 ? src2->ne[3] : 0; GGML_UNUSED(ne23);
1470
+
1471
+ const uint64_t nb20 = src2 ? src2->nb[0] : 0; GGML_UNUSED(nb20);
1472
+ const uint64_t nb21 = src2 ? src2->nb[1] : 0;
1473
+ const uint64_t nb22 = src2 ? src2->nb[2] : 0;
1474
+ const uint64_t nb23 = src2 ? src2->nb[3] : 0; GGML_UNUSED(nb23);
1475
+
1476
+ const enum ggml_type src2t = src2 ? src2->type : GGML_TYPE_COUNT; GGML_UNUSED(src2t);
1477
+
1478
+ GGML_ASSERT(!ggml_is_transposed(src2));
1479
+ GGML_ASSERT(!ggml_is_transposed(src1));
1480
+
1481
+ GGML_ASSERT(ne20 % 32 == 0);
1482
+ // !!!!!!!!! TODO: this assert is probably required but not sure!
1483
+ //GGML_ASSERT(ne20 >= 64);
1484
+ GGML_ASSERT(src1t == GGML_TYPE_F32);
1485
+
1486
+ const uint r2 = ne12/ne22;
1487
+ const uint r3 = ne13/ne23;
1488
+
1489
+ // find the break-even point where the matrix-matrix kernel becomes more efficient compared
1490
+ // to the matrix-vector kernel
1491
+ int ne11_mm_min = 0;
1492
+
1493
+ const int idx = ((int32_t *) dst->op_params)[0];
1494
+
1495
+ // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
1496
+ // AMD GPUs and older A-chips will reuse the matrix-vector multiplication kernel
1497
+ if ([ctx->device supportsFamily:MTLGPUFamilyApple7] &&
1498
+ ne11 > ne11_mm_min) {
1499
+ switch (src2->type) {
1500
+ case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_f32_f32]; break;
1501
+ case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_f16_f32]; break;
1502
+ case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q4_0_f32]; break;
1503
+ case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q4_1_f32]; break;
1504
+ case GGML_TYPE_Q5_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q5_0_f32]; break;
1505
+ case GGML_TYPE_Q5_1: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q5_1_f32]; break;
1506
+ case GGML_TYPE_Q8_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q8_0_f32]; break;
1507
+ case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q2_K_f32]; break;
1508
+ case GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q3_K_f32]; break;
1509
+ case GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q4_K_f32]; break;
1510
+ case GGML_TYPE_Q5_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q5_K_f32]; break;
1511
+ case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q6_K_f32]; break;
1512
+ default: GGML_ASSERT(false && "MUL_MAT_ID not implemented");
1513
+ }
1514
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1515
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
1516
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
1517
+ [encoder setBytes:&ne20 length:sizeof(ne20) atIndex:3];
1518
+ [encoder setBytes:&ne22 length:sizeof(ne22) atIndex:4];
1519
+ [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:5];
1520
+ [encoder setBytes:&nb22 length:sizeof(nb22) atIndex:6];
1521
+ [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:7];
1522
+ [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:8];
1523
+ [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:9];
1524
+ [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:10];
1525
+ [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:11];
1526
+ [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:12];
1527
+ [encoder setBytes:&r2 length:sizeof(r2) atIndex:13];
1528
+ [encoder setBytes:&r3 length:sizeof(r3) atIndex:14];
1529
+ [encoder setBytes:&idx length:sizeof(idx) atIndex:15];
1530
+ // TODO: how to make this an array? read Metal docs
1531
+ for (int j = 0; j < n_as; ++j) {
1532
+ struct ggml_tensor * src_cur = dst->src[2 + j];
1533
+
1534
+ size_t offs_src_cur = 0;
1535
+ id<MTLBuffer> id_src_cur = ggml_metal_get_buffer(ctx, src_cur, &offs_src_cur);
1536
+
1537
+ [encoder setBuffer:id_src_cur offset:offs_src_cur atIndex:16 + j];
1538
  }
1539
+
1540
+ [encoder setThreadgroupMemoryLength:8192 atIndex:0];
1541
+ [encoder dispatchThreadgroups:MTLSizeMake( (ne11 + 31)/32, (ne21 + 63)/64, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
1542
  }
1543
  } break;
1544
  case GGML_OP_GET_ROWS:
 
1577
  float eps;
1578
  memcpy(&eps, dst->op_params, sizeof(float));
1579
 
1580
+ int nth = 32; // SIMD width
1581
+
1582
+ while (nth < ne00/4 && nth < 1024) {
1583
+ nth *= 2;
1584
+ }
1585
 
1586
  [encoder setComputePipelineState:ctx->pipeline_rms_norm];
1587
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1588
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1589
+ [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
1590
+ [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
1591
+ [encoder setBytes:&eps length:sizeof( float) atIndex:4];
1592
+ [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
1593
 
1594
  const int64_t nrows = ggml_nrows(src0);
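As a reminder of the math behind the RMS_NORM launch above, a plain-C reference for one contiguous f32 row (illustrative, not part of the diff): y = x / sqrt(mean(x^2) + eps).

    #include <math.h>
    #include <stdint.h>

    static void rms_norm_row_ref(const float * x, float * y, int64_t ne00, float eps) {
        float sum = 0.0f;
        for (int64_t i = 0; i < ne00; ++i) {
            sum += x[i]*x[i];
        }
        const float mean  = sum / ne00;
        const float scale = 1.0f / sqrtf(mean + eps);
        for (int64_t i = 0; i < ne00; ++i) {
            y[i] = x[i] * scale;
        }
    }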
1595
 
 
1663
  const int n_past = ((int32_t *) dst->op_params)[0];
1664
  const int n_dims = ((int32_t *) dst->op_params)[1];
1665
  const int mode = ((int32_t *) dst->op_params)[2];
1666
+ // skip 3, n_ctx, used in GLM RoPE, unimplemented in metal
1667
+ const int n_orig_ctx = ((int32_t *) dst->op_params)[4];
1668
 
1669
  float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
1670
  memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
 
1764
 
1765
  [encoder dispatchThreadgroups:MTLSizeMake(IC, OH, OW) threadsPerThreadgroup:MTLSizeMake(N, KH, KW)];
1766
  } break;
1767
+ case GGML_OP_ARGSORT:
1768
+ {
1769
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
1770
+ GGML_ASSERT( dst->type == GGML_TYPE_I32);
1771
+
1772
+ const int nrows = ggml_nrows(src0);
1773
+
1774
+ enum ggml_sort_order order = (enum ggml_sort_order) dst->op_params[0];
1775
+
1776
+ switch (order) {
1777
+ case GGML_SORT_ASC: [encoder setComputePipelineState:ctx->pipeline_argsort_f32_i32_asc]; break;
1778
+ case GGML_SORT_DESC: [encoder setComputePipelineState:ctx->pipeline_argsort_f32_i32_desc]; break;
1779
+ default: GGML_ASSERT(false);
1780
+ };
1781
+
1782
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1783
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1784
+ [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
1785
+
1786
+ [encoder dispatchThreadgroups:MTLSizeMake(1, nrows, 1) threadsPerThreadgroup:MTLSizeMake(ne00, 1, 1)];
1787
+ } break;
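The dispatch above launches one threadgroup per row with ne00 threads, so ne00 is implicitly bounded by the device's maximum threadgroup size. A small C reference of the result the op produces (illustrative; it uses qsort rather than the GPU's sorting network):

    #include <stdint.h>
    #include <stdlib.h>

    static const float * g_argsort_row;  // row being sorted (qsort has no context pointer)
    static int           g_argsort_desc;

    static int argsort_cmp(const void * pa, const void * pb) {
        const int32_t a = *(const int32_t *) pa;
        const int32_t b = *(const int32_t *) pb;
        const float   d = g_argsort_row[a] - g_argsort_row[b];
        const int     r = (d > 0.0f) - (d < 0.0f);
        return g_argsort_desc ? -r : r;
    }

    // dst receives the index permutation that sorts the row (ascending or descending)
    static void argsort_f32_i32_ref(const float * row, int32_t * dst, int64_t ne00, int desc) {
        for (int64_t i = 0; i < ne00; ++i) {
            dst[i] = (int32_t) i;
        }
        g_argsort_row  = row;
        g_argsort_desc = desc;
        qsort(dst, (size_t) ne00, sizeof(int32_t), argsort_cmp);
    }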
1788
  case GGML_OP_DUP:
1789
  case GGML_OP_CPY:
1790
  case GGML_OP_CONT:
1791
  {
1792
+ GGML_ASSERT(ne00 % ggml_blck_size(src0->type) == 0);
1793
+
1794
+ int nth = MIN(1024, ne00/ggml_blck_size(src0->type));
1795
 
1796
  switch (src0t) {
1797
  case GGML_TYPE_F32:
1798
  {
1799
+ GGML_ASSERT(ne0 % ggml_blck_size(dst->type) == 0);
1800
+
1801
  switch (dstt) {
1802
+ case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_f16]; break;
1803
+ case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_f32]; break;
1804
+ case GGML_TYPE_Q8_0: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_q8_0]; break;
1805
+ case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_q4_0]; break;
1806
+ case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_q4_1]; break;
1807
+ //case GGML_TYPE_Q5_0: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_q5_0]; break;
1808
+ //case GGML_TYPE_Q5_1: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_q5_1]; break;
1809
  default: GGML_ASSERT(false && "not implemented");
1810
  };
1811
  } break;
 
1880
 
1881
  // backend interface
1882
 
1883
+ static id<MTLDevice> g_backend_device = nil;
1884
+ static int g_backend_device_ref_count = 0;
1885
 
1886
+ static id<MTLDevice> ggml_backend_metal_get_device(void) {
1887
+ if (g_backend_device == nil) {
1888
+ g_backend_device = MTLCreateSystemDefaultDevice();
1889
+ }
1890
+
1891
+ g_backend_device_ref_count++;
1892
+
1893
+ return g_backend_device;
1894
  }
1895
 
1896
+ static void ggml_backend_metal_free_device(void) {
1897
+ assert(g_backend_device_ref_count > 0);
1898
+
1899
+ g_backend_device_ref_count--;
1900
+
1901
+ if (g_backend_device_ref_count == 0) {
1902
+ [g_backend_device release];
1903
+ g_backend_device = nil;
1904
+ }
1905
  }
1906
 
1907
  static void * ggml_backend_metal_buffer_get_base(ggml_backend_buffer_t buffer) {
1908
+ struct ggml_backend_metal_buffer_context * ctx = (struct ggml_backend_metal_buffer_context *)buffer->context;
1909
+
1910
+ return ctx->data;
1911
  }
1912
 
1913
  static void ggml_backend_metal_buffer_free_buffer(ggml_backend_buffer_t buffer) {
1914
+ struct ggml_backend_metal_buffer_context * ctx = (struct ggml_backend_metal_buffer_context *)buffer->context;
1915
+
1916
+ [ctx->metal release];
1917
+ ggml_backend_metal_free_device();
1918
+
1919
+ free(ctx->data);
1920
+ free(ctx);
1921
+
1922
+ UNUSED(buffer);
1923
+ }
1924
+
1925
+ static void ggml_backend_metal_buffer_set_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
1926
+ GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor write out of bounds");
1927
+ GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
1928
+
1929
+ memcpy((char *)tensor->data + offset, data, size);
1930
+
1931
+ UNUSED(buffer);
1932
+ }
1933
+
1934
+ static void ggml_backend_metal_buffer_get_tensor(ggml_backend_buffer_t buffer, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) {
1935
+ GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor read out of bounds");
1936
+ GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
1937
+
1938
+ memcpy(data, (const char *)tensor->data + offset, size);
1939
+
1940
+ UNUSED(buffer);
1941
+ }
1942
+
1943
+ static void ggml_backend_metal_buffer_cpy_tensor_from(ggml_backend_buffer_t buffer, struct ggml_tensor * src, struct ggml_tensor * dst) {
1944
+ ggml_backend_tensor_get(src, dst->data, 0, ggml_nbytes(src));
1945
+
1946
+ UNUSED(buffer);
1947
+ }
1948
+
1949
+ static void ggml_backend_metal_buffer_cpy_tensor_to(ggml_backend_buffer_t buffer, struct ggml_tensor * src, struct ggml_tensor * dst) {
1950
+ ggml_backend_tensor_set(dst, src->data, 0, ggml_nbytes(src));
1951
+
1952
  UNUSED(buffer);
1953
  }
1954
 
1955
  static struct ggml_backend_buffer_i metal_backend_buffer_i = {
1956
+ /* .free_buffer = */ ggml_backend_metal_buffer_free_buffer,
1957
+ /* .get_base = */ ggml_backend_metal_buffer_get_base,
1958
+ /* .init_tensor = */ NULL,
1959
+ /* .set_tensor = */ ggml_backend_metal_buffer_set_tensor,
1960
+ /* .get_tensor = */ ggml_backend_metal_buffer_get_tensor,
1961
+ /* .cpy_tensor_from = */ ggml_backend_metal_buffer_cpy_tensor_from,
1962
+ /* .cpy_tensor_to = */ ggml_backend_metal_buffer_cpy_tensor_to,
1963
  };
1964
 
1965
+ static ggml_backend_buffer_t ggml_backend_metal_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
1966
+ struct ggml_backend_metal_buffer_context * ctx = malloc(sizeof(struct ggml_backend_metal_buffer_context));
1967
 
1968
+ const size_t size_page = sysconf(_SC_PAGESIZE);
1969
 
1970
+ size_t size_aligned = size;
1971
+ if ((size_aligned % size_page) != 0) {
1972
+ size_aligned += (size_page - (size_aligned % size_page));
1973
+ }
1974
+
1975
+ ctx->data = ggml_metal_host_malloc(size);
1976
+ ctx->metal = [ggml_backend_metal_get_device() newBufferWithBytesNoCopy:ctx->data
1977
+ length:size_aligned
1978
+ options:MTLResourceStorageModeShared
1979
+ deallocator:nil];
1980
 
1981
+ return ggml_backend_buffer_init(buft, metal_backend_buffer_i, ctx, size);
1982
  }
1983
 
1984
+ static size_t ggml_backend_metal_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
1985
  return 32;
1986
+ UNUSED(buft);
1987
  }
1988
 
1989
+ static bool ggml_backend_metal_buffer_type_supports_backend(ggml_backend_buffer_type_t buft, ggml_backend_t backend) {
1990
+ return ggml_backend_is_metal(backend) || ggml_backend_is_cpu(backend);
 
 
 
1991
 
1992
+ GGML_UNUSED(buft);
1993
  }
1994
 
1995
+ ggml_backend_buffer_type_t ggml_backend_metal_buffer_type(void) {
1996
+ static struct ggml_backend_buffer_type ggml_backend_buffer_type_metal = {
1997
+ /* .iface = */ {
1998
+ /* .alloc_buffer = */ ggml_backend_metal_buffer_type_alloc_buffer,
1999
+ /* .get_alignment = */ ggml_backend_metal_buffer_type_get_alignment,
2000
+ /* .get_alloc_size = */ NULL, // defaults to ggml_nbytes
2001
+ /* .supports_backend = */ ggml_backend_metal_buffer_type_supports_backend,
2002
+ },
2003
+ /* .context = */ NULL,
2004
+ };
2005
 
2006
+ return &ggml_backend_buffer_type_metal;
2007
  }
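A hedged usage sketch (not part of this commit): with buffer types, host code can allocate Metal-visible memory without holding a backend instance. ggml_backend_buft_alloc_buffer() and ggml_backend_buffer_free() are assumed from the ggml-backend API introduced in this sync.

    #include "ggml-backend.h"
    #include "ggml-metal.h"

    // allocate 16 MiB through the Metal buffer type (sketch)
    static ggml_backend_buffer_t alloc_metal_scratch(void) {
        ggml_backend_buffer_type_t buft = ggml_backend_metal_buffer_type();
        return ggml_backend_buft_alloc_buffer(buft, 16u*1024*1024);
    }

    // later: ggml_backend_buffer_free(buf);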
2008
 
2009
+ static const char * ggml_backend_metal_name(ggml_backend_t backend) {
2010
+ return "Metal";
2011
+
2012
  UNUSED(backend);
2013
  }
2014
 
2015
+ static void ggml_backend_metal_free(ggml_backend_t backend) {
2016
+ struct ggml_metal_context * ctx = (struct ggml_metal_context *)backend->context;
2017
+ ggml_metal_free(ctx);
2018
+ free(backend);
2019
+ }
2020
 
2021
+ static void ggml_backend_metal_synchronize(ggml_backend_t backend) {
2022
  UNUSED(backend);
2023
  }
2024
 
2025
+ static ggml_backend_buffer_type_t ggml_backend_metal_get_default_buffer_type(ggml_backend_t backend) {
2026
+ return ggml_backend_metal_buffer_type();
2027
 
2028
  UNUSED(backend);
2029
  }
 
2035
  }
2036
 
2037
  static bool ggml_backend_metal_supports_op(ggml_backend_t backend, const struct ggml_tensor * op) {
2038
+ return ggml_metal_supports_op(op);
2039
+
2040
  UNUSED(backend);
 
2041
  }
2042
 
2043
  static struct ggml_backend_i metal_backend_i = {
2044
+ /* .get_name = */ ggml_backend_metal_name,
2045
+ /* .free = */ ggml_backend_metal_free,
2046
+ /* .get_default_buffer_type = */ ggml_backend_metal_get_default_buffer_type,
2047
+ /* .set_tensor_async = */ NULL,
2048
+ /* .get_tensor_async = */ NULL,
2049
+ /* .cpy_tensor_from_async = */ NULL,
2050
+ /* .cpy_tensor_to_async = */ NULL,
2051
+ /* .synchronize = */ ggml_backend_metal_synchronize,
2052
+ /* .graph_plan_create = */ NULL, // the metal implementation does not require creating graph plans atm
2053
+ /* .graph_plan_free = */ NULL,
2054
+ /* .graph_plan_compute = */ NULL,
2055
+ /* .graph_compute = */ ggml_backend_metal_graph_compute,
2056
+ /* .supports_op = */ ggml_backend_metal_supports_op,
 
2057
  };
2058
 
2059
+ // TODO: make a common log callback for all backends in ggml-backend
2060
+ static void ggml_backend_log_callback(enum ggml_log_level level, const char * msg, void * user_data) {
2061
+ fprintf(stderr, "%s", msg);
2062
+
2063
+ UNUSED(level);
2064
+ UNUSED(user_data);
2065
+ }
2066
+
2067
  ggml_backend_t ggml_backend_metal_init(void) {
2068
+ ggml_metal_log_set_callback(ggml_backend_log_callback, NULL);
2069
 
2070
+ struct ggml_metal_context * ctx = ggml_metal_init(GGML_DEFAULT_N_THREADS);
2071
+
2072
+ if (ctx == NULL) {
2073
+ return NULL;
2074
+ }
2075
 
2076
  ggml_backend_t metal_backend = malloc(sizeof(struct ggml_backend));
2077
 
 
2088
  }
2089
 
2090
  void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {
2091
+ GGML_ASSERT(ggml_backend_is_metal(backend));
2092
+
2093
  struct ggml_metal_context * ctx = (struct ggml_metal_context *)backend->context;
2094
 
2095
  ggml_metal_set_n_cb(ctx, n_cb);
2096
  }
2097
 
2098
  bool ggml_backend_metal_supports_family(ggml_backend_t backend, int family) {
2099
+ GGML_ASSERT(ggml_backend_is_metal(backend));
2100
+
2101
  struct ggml_metal_context * ctx = (struct ggml_metal_context *)backend->context;
2102
 
2103
+ return [ctx->device supportsFamily:(MTLGPUFamilyApple1 + family - 1)];
2104
+ }
2105
+
2106
+ ggml_backend_t ggml_backend_reg_metal_init(const char * params, void * user_data); // silence warning
2107
+
2108
+ ggml_backend_t ggml_backend_reg_metal_init(const char * params, void * user_data) {
2109
+ return ggml_backend_metal_init();
2110
+
2111
+ GGML_UNUSED(params);
2112
+ GGML_UNUSED(user_data);
2113
  }
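A short usage sketch tying the backend entry points together (illustrative, loosely mirroring how whisper.cpp in this commit falls back to the CPU when Metal initialization fails; ggml_backend_cpu_init() and ggml_backend_free() are assumed from ggml-backend.h):

    #include <stdio.h>
    #include "ggml-backend.h"
    #include "ggml-metal.h"

    static ggml_backend_t init_best_backend(void) {
        ggml_backend_t backend = ggml_backend_metal_init();
        if (backend && !ggml_backend_metal_supports_family(backend, 7)) {
            // family 7 (Apple7) is what the simdgroup matrix kernels above are gated on
            fprintf(stderr, "Metal GPU family 7 not supported - falling back to CPU\n");
            ggml_backend_free(backend);
            backend = NULL;
        }
        if (backend == NULL) {
            backend = ggml_backend_cpu_init(); // CPU fallback if Metal init fails
        }
        return backend;
    }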
ggml-metal.metal CHANGED
@@ -3,6 +3,8 @@
3
  using namespace metal;
4
 
5
  #define MAX(x, y) ((x) > (y) ? (x) : (y))
 
 
6
 
7
  #define QK4_0 32
8
  #define QR4_0 2
@@ -39,8 +41,15 @@ typedef struct {
39
  int8_t qs[QK8_0]; // quants
40
  } block_q8_0;
41
 
42
- // general-purpose kernel for addition of two tensors
43
- // pros: works for non-contiguous tensors, supports broadcast across dims 1, 2 and 3
 
 
 
 
 
 
 
44
  // cons: not very efficient
45
  kernel void kernel_add(
46
  device const char * src0,
@@ -81,16 +90,111 @@ kernel void kernel_add(
81
  const int64_t i12 = i02 % ne12;
82
  const int64_t i11 = i01 % ne11;
83
 
84
- device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01 + tpitg.x*nb00;
85
- device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11 + tpitg.x*nb10;
86
- device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1 + tpitg.x*nb0;
 
 
 
87
 
88
  for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
89
- ((device float *)dst_ptr)[0] = ((device float *)src0_ptr)[0] + ((device float *)src1_ptr)[0];
 
 
 
90
 
91
- src0_ptr += ntg.x*nb00;
92
- src1_ptr += ntg.x*nb10;
93
- dst_ptr += ntg.x*nb0;
94
  }
95
  }
96
 
@@ -105,23 +209,22 @@ kernel void kernel_add_row(
105
  dst[tpig] = src0[tpig] + src1[tpig % nb];
106
  }
107
 
108
- kernel void kernel_mul(
109
  device const float4 * src0,
110
  device const float4 * src1,
111
  device float4 * dst,
 
112
  uint tpig[[thread_position_in_grid]]) {
113
- dst[tpig] = src0[tpig] * src1[tpig];
114
  }
115
 
116
- // assumption: src1 is a row
117
- // broadcast src1 into src0
118
- kernel void kernel_mul_row(
119
  device const float4 * src0,
120
  device const float4 * src1,
121
  device float4 * dst,
122
- constant int64_t & nb,
123
  uint tpig[[thread_position_in_grid]]) {
124
- dst[tpig] = src0[tpig] * src1[tpig % nb];
125
  }
126
 
127
  kernel void kernel_scale(
@@ -162,6 +265,54 @@ kernel void kernel_sqr(
162
  dst[tpig] = src0[tpig] * src0[tpig];
163
  }
164
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
165
  constant float GELU_COEF_A = 0.044715f;
166
  constant float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
167
 
@@ -180,10 +331,12 @@ kernel void kernel_gelu(
180
 
181
  kernel void kernel_soft_max(
182
  device const float * src0,
 
183
  device float * dst,
184
  constant int64_t & ne00,
185
  constant int64_t & ne01,
186
  constant int64_t & ne02,
 
187
  threadgroup float * buf [[threadgroup(0)]],
188
  uint tgpig[[threadgroup_position_in_grid]],
189
  uint tpitg[[thread_position_in_threadgroup]],
@@ -194,73 +347,77 @@ kernel void kernel_soft_max(
194
  const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
195
  const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
196
 
197
- device const float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
198
- device float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
 
199
 
200
  // parallel max
201
- float lmax = tpitg < ne00 ? psrc0[tpitg] : -INFINITY;
202
 
203
- for (int i00 = tpitg + ntg; i00 < ne00; i00 += ntg) {
204
- lmax = MAX(lmax, psrc0[i00]);
205
  }
206
 
207
- float max = simd_max(lmax);
208
- if (tiisg == 0) {
209
- buf[sgitg] = max;
210
- }
 
 
211
 
212
- threadgroup_barrier(mem_flags::mem_threadgroup);
213
 
214
- // broadcast, simd group number is ntg / 32
215
- for (uint i = ntg / 32 / 2; i > 0; i /= 2) {
216
- if (tpitg < i) {
217
- buf[tpitg] = MAX(buf[tpitg], buf[tpitg + i]);
218
- }
219
- }
220
 
221
- threadgroup_barrier(mem_flags::mem_threadgroup);
222
 
223
- max = buf[0];
 
 
224
 
225
  // parallel sum
226
  float lsum = 0.0f;
227
  for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
228
- const float exp_psrc0 = exp(psrc0[i00] - max);
229
  lsum += exp_psrc0;
230
- // Remember the result of exp here. exp is expensive, so we really do not
231
- // wish to compute it twice.
232
  pdst[i00] = exp_psrc0;
233
  }
234
 
235
  float sum = simd_sum(lsum);
236
- if (tiisg == 0) {
237
- buf[sgitg] = sum;
238
- }
 
239
 
240
- threadgroup_barrier(mem_flags::mem_threadgroup);
241
 
242
- // broadcast, simd group number is ntg / 32
243
- for (uint i = ntg / 32 / 2; i > 0; i /= 2) {
244
- if (tpitg < i) {
245
- buf[tpitg] += buf[tpitg + i];
246
- }
247
- }
248
 
249
- threadgroup_barrier(mem_flags::mem_threadgroup);
 
 
 
 
250
 
251
- sum = buf[0];
252
 
253
  for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
254
- pdst[i00] /= sum;
255
  }
256
  }
257
 
258
  kernel void kernel_soft_max_4(
259
  device const float * src0,
 
260
  device float * dst,
261
  constant int64_t & ne00,
262
  constant int64_t & ne01,
263
  constant int64_t & ne02,
 
264
  threadgroup float * buf [[threadgroup(0)]],
265
  uint tgpig[[threadgroup_position_in_grid]],
266
  uint tpitg[[thread_position_in_threadgroup]],
@@ -271,64 +428,68 @@ kernel void kernel_soft_max_4(
271
  const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
272
  const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
273
 
274
- device const float4 * psrc4 = (device const float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
275
- device float4 * pdst4 = (device float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
 
276
 
277
  // parallel max
278
- float4 lmax4 = tpitg < ne00/4 ? psrc4[tpitg] : -INFINITY;
279
 
280
- for (int i00 = tpitg + ntg; i00 < ne00/4; i00 += ntg) {
281
- lmax4 = fmax(lmax4, psrc4[i00]);
282
  }
283
 
284
  const float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3]));
285
- float max = simd_max(lmax);
286
- if (tiisg == 0) {
287
- buf[sgitg] = max;
288
- }
289
 
290
- threadgroup_barrier(mem_flags::mem_threadgroup);
 
 
 
 
291
 
292
- // broadcast, simd group number is ntg / 32
293
- for (uint i = ntg / 32 / 2; i > 0; i /= 2) {
294
- if (tpitg < i) {
295
- buf[tpitg] = MAX(buf[tpitg], buf[tpitg + i]);
296
- }
297
- }
298
 
299
- threadgroup_barrier(mem_flags::mem_threadgroup);
 
 
300
 
301
- max = buf[0];
 
 
 
 
302
 
303
  // parallel sum
304
  float4 lsum4 = 0.0f;
305
  for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
306
- const float4 exp_psrc4 = exp(psrc4[i00] - max);
307
  lsum4 += exp_psrc4;
308
  pdst4[i00] = exp_psrc4;
309
  }
310
 
311
  const float lsum = lsum4[0] + lsum4[1] + lsum4[2] + lsum4[3];
312
  float sum = simd_sum(lsum);
313
- if (tiisg == 0) {
314
- buf[sgitg] = sum;
315
- }
 
316
 
317
- threadgroup_barrier(mem_flags::mem_threadgroup);
318
 
319
- // broadcast, simd group number is ntg / 32
320
- for (uint i = ntg / 32 / 2; i > 0; i /= 2) {
321
- if (tpitg < i) {
322
- buf[tpitg] += buf[tpitg + i];
323
- }
324
- }
325
 
326
- threadgroup_barrier(mem_flags::mem_threadgroup);
 
 
 
 
327
 
328
- sum = buf[0];
329
 
330
  for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
331
- pdst4[i00] /= sum;
332
  }
333
  }
334
 
@@ -435,14 +596,13 @@ kernel void kernel_rms_norm(
435
  constant int64_t & ne00,
436
  constant uint64_t & nb01,
437
  constant float & eps,
438
- threadgroup float * sum [[threadgroup(0)]],
439
  uint tgpig[[threadgroup_position_in_grid]],
440
  uint tpitg[[thread_position_in_threadgroup]],
441
  uint sgitg[[simdgroup_index_in_threadgroup]],
442
  uint tiisg[[thread_index_in_simdgroup]],
443
  uint ntg[[threads_per_threadgroup]]) {
444
- device const float4 * x = (device const float4 *) ((device const char *) src0 + tgpig*nb01);
445
- device const float * x_scalar = (device const float *) x;
446
 
447
  float4 sumf = 0;
448
  float all_sum = 0;
@@ -453,40 +613,30 @@ kernel void kernel_rms_norm(
453
  }
454
  all_sum = sumf[0] + sumf[1] + sumf[2] + sumf[3];
455
  all_sum = simd_sum(all_sum);
456
- if (tiisg == 0) {
457
- sum[sgitg] = all_sum;
458
- }
 
459
 
460
- threadgroup_barrier(mem_flags::mem_threadgroup);
461
 
462
- // broadcast, simd group number is ntg / 32
463
- for (uint i = ntg / 32 / 2; i > 0; i /= 2) {
464
- if (tpitg < i) {
465
- sum[tpitg] += sum[tpitg + i];
466
- }
467
- }
468
- if (tpitg == 0) {
469
- for (int i = 4 * (ne00 / 4); i < ne00; i++) {
470
- sum[0] += x_scalar[i];
471
  }
472
- sum[0] /= ne00;
473
- }
474
 
475
- threadgroup_barrier(mem_flags::mem_threadgroup);
 
 
 
 
476
 
477
- const float mean = sum[0];
478
  const float scale = 1.0f/sqrt(mean + eps);
479
 
480
  device float4 * y = (device float4 *) (dst + tgpig*ne00);
481
- device float * y_scalar = (device float *) y;
482
  for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
483
  y[i00] = x[i00] * scale;
484
  }
485
- if (tpitg == 0) {
486
- for (int i00 = 4 * (ne00 / 4); i00 < ne00; i00++) {
487
- y_scalar[i00] = x_scalar[i00] * scale;
488
- }
489
- }
490
  }
491
 
492
  // function for calculate inner product between half a q4_0 block and 16 floats (yl), sumy is SUM(yl[i])
@@ -576,15 +726,25 @@ inline float block_q_n_dot_y(device const block_q5_1 * qb_curr, float sumy, thre
576
  // putting them in the kernel cause a significant performance penalty
577
  #define N_DST 4 // each SIMD group works on 4 rows
578
  #define N_SIMDGROUP 2 // number of SIMD groups in a thread group
579
- #define N_SIMDWIDTH 32 // assuming SIMD group size is 32
580
  //Note: This is a template, but strictly speaking it only applies to
581
  // quantizations where the block size is 32. It also does not
582
  // guard against the number of rows not being divisible by
583
  // N_DST, so this is another explicit assumption of the implementation.
584
  template<typename block_q_type, int nr, int nsg, int nw>
585
- void mul_vec_q_n_f32(device const void * src0, device const float * src1, device float * dst,
586
- int64_t ne00, int64_t ne01, int64_t ne02, int64_t ne10, int64_t ne12, int64_t ne0, int64_t ne1, uint gqa,
587
- uint3 tgpig, uint tiisg, uint sgitg) {
 
 
 
 
 
 
 
 
 
 
 
588
  const int nb = ne00/QK4_0;
589
 
590
  const int r0 = tgpig.x;
@@ -593,7 +753,10 @@ void mul_vec_q_n_f32(device const void * src0, device const float * src1, device
593
 
594
  const int first_row = (r0 * nsg + sgitg) * nr;
595
 
596
- const uint offset0 = first_row * nb + im/gqa*(nb*ne0);
 
 
 
597
 
598
  device const block_q_type * x = (device const block_q_type *) src0 + offset0;
599
  device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
@@ -643,13 +806,14 @@ kernel void kernel_mul_mv_q4_0_f32(
643
  constant int64_t & ne02[[buffer(5)]],
644
  constant int64_t & ne10[[buffer(9)]],
645
  constant int64_t & ne12[[buffer(11)]],
646
- constant int64_t & ne0[[buffer(15)]],
647
- constant int64_t & ne1[[buffer(16)]],
648
- constant uint & gqa[[buffer(17)]],
 
649
  uint3 tgpig[[threadgroup_position_in_grid]],
650
  uint tiisg[[thread_index_in_simdgroup]],
651
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
652
- mul_vec_q_n_f32<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg);
653
  }
654
 
655
  kernel void kernel_mul_mv_q4_1_f32(
@@ -661,13 +825,14 @@ kernel void kernel_mul_mv_q4_1_f32(
661
  constant int64_t & ne02[[buffer(5)]],
662
  constant int64_t & ne10[[buffer(9)]],
663
  constant int64_t & ne12[[buffer(11)]],
664
- constant int64_t & ne0[[buffer(15)]],
665
- constant int64_t & ne1[[buffer(16)]],
666
- constant uint & gqa[[buffer(17)]],
 
667
  uint3 tgpig[[threadgroup_position_in_grid]],
668
  uint tiisg[[thread_index_in_simdgroup]],
669
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
670
- mul_vec_q_n_f32<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg);
671
  }
672
 
673
  kernel void kernel_mul_mv_q5_0_f32(
@@ -679,13 +844,14 @@ kernel void kernel_mul_mv_q5_0_f32(
679
  constant int64_t & ne02[[buffer(5)]],
680
  constant int64_t & ne10[[buffer(9)]],
681
  constant int64_t & ne12[[buffer(11)]],
682
- constant int64_t & ne0[[buffer(15)]],
683
- constant int64_t & ne1[[buffer(16)]],
684
- constant uint & gqa[[buffer(17)]],
 
685
  uint3 tgpig[[threadgroup_position_in_grid]],
686
  uint tiisg[[thread_index_in_simdgroup]],
687
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
688
- mul_vec_q_n_f32<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg);
689
  }
690
 
691
  kernel void kernel_mul_mv_q5_1_f32(
@@ -697,13 +863,14 @@ kernel void kernel_mul_mv_q5_1_f32(
697
  constant int64_t & ne02[[buffer(5)]],
698
  constant int64_t & ne10[[buffer(9)]],
699
  constant int64_t & ne12[[buffer(11)]],
700
- constant int64_t & ne0[[buffer(15)]],
701
- constant int64_t & ne1[[buffer(16)]],
702
- constant uint & gqa[[buffer(17)]],
 
703
  uint3 tgpig[[threadgroup_position_in_grid]],
704
  uint tiisg[[thread_index_in_simdgroup]],
705
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
706
- mul_vec_q_n_f32<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg);
707
  }
708
 
709
 
@@ -718,9 +885,10 @@ kernel void kernel_mul_mv_q8_0_f32(
718
  constant int64_t & ne02[[buffer(5)]],
719
  constant int64_t & ne10[[buffer(9)]],
720
  constant int64_t & ne12[[buffer(11)]],
721
- constant int64_t & ne0[[buffer(15)]],
722
- constant int64_t & ne1[[buffer(16)]],
723
- constant uint & gqa[[buffer(17)]],
 
724
  uint3 tgpig[[threadgroup_position_in_grid]],
725
  uint tiisg[[thread_index_in_simdgroup]],
726
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -732,8 +900,14 @@ kernel void kernel_mul_mv_q8_0_f32(
732
  const int r0 = tgpig.x;
733
  const int r1 = tgpig.y;
734
  const int im = tgpig.z;
 
735
  const int first_row = (r0 * nsg + sgitg) * nr;
736
- const uint offset0 = first_row * nb + im/gqa*(nb*ne0);
 
 
 
 
 
737
  device const block_q8_0 * x = (device const block_q8_0 *) src0 + offset0;
738
  device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
739
 
@@ -791,6 +965,8 @@ kernel void kernel_mul_mv_f32_f32(
791
  constant uint64_t & nb12,
792
  constant int64_t & ne0,
793
  constant int64_t & ne1,
 
 
794
  uint3 tgpig[[threadgroup_position_in_grid]],
795
  uint tiisg[[thread_index_in_simdgroup]]) {
796
 
@@ -798,7 +974,12 @@ kernel void kernel_mul_mv_f32_f32(
798
  const int64_t rb = tgpig.y*N_F32_F32;
799
  const int64_t im = tgpig.z;
800
 
801
- device const float * x = (device const float *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02);
 
 
 
 
 
802
 
803
  if (ne00 < 128) {
804
  for (int row = 0; row < N_F32_F32; ++row) {
@@ -864,6 +1045,8 @@ kernel void kernel_mul_mv_f16_f16(
864
  constant uint64_t & nb12,
865
  constant int64_t & ne0,
866
  constant int64_t & ne1,
 
 
867
  uint3 tgpig[[threadgroup_position_in_grid]],
868
  uint tiisg[[thread_index_in_simdgroup]]) {
869
 
@@ -871,7 +1054,12 @@ kernel void kernel_mul_mv_f16_f16(
871
  const int64_t rb = tgpig.y*N_F16_F16;
872
  const int64_t im = tgpig.z;
873
 
874
- device const half * x = (device const half *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02);
 
 
 
 
 
875
 
876
  if (ne00 < 128) {
877
  for (int row = 0; row < N_F16_F16; ++row) {
@@ -935,6 +1123,8 @@ kernel void kernel_mul_mv_f16_f32_1row(
935
  constant uint64_t & nb12,
936
  constant int64_t & ne0,
937
  constant int64_t & ne1,
 
 
938
  uint3 tgpig[[threadgroup_position_in_grid]],
939
  uint tiisg[[thread_index_in_simdgroup]]) {
940
 
@@ -942,7 +1132,12 @@ kernel void kernel_mul_mv_f16_f32_1row(
942
  const int64_t r1 = tgpig.y;
943
  const int64_t im = tgpig.z;
944
 
945
- device const half * x = (device const half *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02);
 
 
 
 
 
946
  device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
947
 
948
  float sumf = 0;
@@ -989,6 +1184,8 @@ kernel void kernel_mul_mv_f16_f32(
989
  constant uint64_t & nb12,
990
  constant int64_t & ne0,
991
  constant int64_t & ne1,
 
 
992
  uint3 tgpig[[threadgroup_position_in_grid]],
993
  uint tiisg[[thread_index_in_simdgroup]]) {
994
 
@@ -996,7 +1193,12 @@ kernel void kernel_mul_mv_f16_f32(
996
  const int64_t rb = tgpig.y*N_F16_F32;
997
  const int64_t im = tgpig.z;
998
 
999
- device const half * x = (device const half *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02);
 
 
 
 
 
1000
 
1001
  if (ne00 < 128) {
1002
  for (int row = 0; row < N_F16_F32; ++row) {
@@ -1061,6 +1263,8 @@ kernel void kernel_mul_mv_f16_f32_l4(
1061
  constant uint64_t & nb12,
1062
  constant int64_t & ne0,
1063
  constant int64_t & ne1,
 
 
1064
  uint3 tgpig[[threadgroup_position_in_grid]],
1065
  uint tiisg[[thread_index_in_simdgroup]]) {
1066
 
@@ -1068,7 +1272,12 @@ kernel void kernel_mul_mv_f16_f32_l4(
1068
  const int64_t r0 = tgpig.x;
1069
  const int64_t im = tgpig.z;
1070
 
1071
- device const half4 * x4 = (device const half4 *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02);
 
 
 
 
 
1072
 
1073
  for (int r1 = 0; r1 < nrows; ++r1) {
1074
  device const float4 * y4 = (device const float4 *) (src1 + r1*nb11 + im*nb12);
@@ -1120,17 +1329,21 @@ kernel void kernel_alibi_f32(
1120
  const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
1121
  const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
1122
  const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
 
1123
 
1124
- device float * dst_data = (device float *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
1125
  float m_k;
1126
- if (i2 < n_heads_log2_floor) {
1127
- m_k = pow(m0, i2 + 1);
1128
  } else {
1129
- m_k = pow(m1, 2 * (i2 - n_heads_log2_floor) + 1);
1130
  }
 
 
 
1131
  for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
1132
- device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
1133
- dst_data[i00] = src[0] + m_k * (i00 - ne00 + 1);
 
1134
  }
1135
  }
1136
 
@@ -1335,6 +1548,58 @@ kernel void kernel_im2col_f16(
1335
  }
1336
  }
1337
1338
  kernel void kernel_cpy_f16_f16(
1339
  device const half * src0,
1340
  device half * dst,
@@ -1460,6 +1725,197 @@ kernel void kernel_cpy_f32_f32(
1460
  }
1461
  }
1462
1463
  kernel void kernel_concat(
1464
  device const char * src0,
1465
  device const char * src1,
@@ -1617,23 +2073,30 @@ kernel void kernel_mul_mv_q2_K_f32(
1617
  constant int64_t & ne02[[buffer(5)]],
1618
  constant int64_t & ne10[[buffer(9)]],
1619
  constant int64_t & ne12[[buffer(11)]],
1620
- constant int64_t & ne0[[buffer(15)]],
1621
- constant int64_t & ne1[[buffer(16)]],
1622
- constant uint & gqa[[buffer(17)]],
 
1623
  uint3 tgpig[[threadgroup_position_in_grid]],
1624
- uint tiisg[[thread_index_in_simdgroup]],
1625
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
1626
 
1627
  const int nb = ne00/QK_K;
1628
  const int r0 = tgpig.x;
1629
  const int r1 = tgpig.y;
1630
- const int r2 = tgpig.z;
1631
 
1632
  const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
1633
  const int ib_row = first_row * nb;
1634
- const uint offset0 = r2/gqa*(nb*ne0);
1635
  device const block_q2_K * x = (device const block_q2_K *) src0 + ib_row + offset0;
1636
- device const float * y = (device const float *) src1 + r1*ne10 + r2*ne00*ne1;
 
1637
  float yl[32];
1638
  float sumf[N_DST]={0.f}, all_sum;
1639
 
@@ -1642,11 +2105,11 @@ kernel void kernel_mul_mv_q2_K_f32(
1642
  #if QK_K == 256
1643
  const int ix = tiisg/8; // 0...3
1644
  const int it = tiisg%8; // 0...7
1645
- const int im = it/4; // 0 or 1
1646
  const int ir = it%4; // 0...3
1647
  const int is = (8*ir)/16;// 0 or 1
1648
 
1649
- device const float * y4 = y + ix * QK_K + 128 * im + 8 * ir;
1650
 
1651
  for (int ib = ix; ib < nb; ib += 4) {
1652
 
@@ -1658,8 +2121,8 @@ kernel void kernel_mul_mv_q2_K_f32(
1658
  yl[i+24] = y4[i+96]; sumy[3] += yl[i+24];
1659
  }
1660
 
1661
- device const uint8_t * sc = (device const uint8_t *)x[ib].scales + 8*im + is;
1662
- device const uint16_t * qs = (device const uint16_t *)x[ib].qs + 16 * im + 4 * ir;
1663
  device const half * dh = &x[ib].d;
1664
 
1665
  for (int row = 0; row < N_DST; row++) {
@@ -1746,7 +2209,7 @@ kernel void kernel_mul_mv_q2_K_f32(
1746
  for (int row = 0; row < N_DST; ++row) {
1747
  all_sum = simd_sum(sumf[row]);
1748
  if (tiisg == 0) {
1749
- dst[r1*ne0 + r2*ne0*ne1 + first_row + row] = all_sum;
1750
  }
1751
  }
1752
  }
@@ -1761,9 +2224,10 @@ kernel void kernel_mul_mv_q3_K_f32(
1761
  constant int64_t & ne02[[buffer(5)]],
1762
  constant int64_t & ne10[[buffer(9)]],
1763
  constant int64_t & ne12[[buffer(11)]],
1764
- constant int64_t & ne0[[buffer(15)]],
1765
- constant int64_t & ne1[[buffer(16)]],
1766
- constant uint & gqa[[buffer(17)]],
 
1767
  uint3 tgpig[[threadgroup_position_in_grid]],
1768
  uint tiisg[[thread_index_in_simdgroup]],
1769
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -1772,12 +2236,17 @@ kernel void kernel_mul_mv_q3_K_f32(
1772
 
1773
  const int64_t r0 = tgpig.x;
1774
  const int64_t r1 = tgpig.y;
1775
- const int64_t r2 = tgpig.z;
1776
 
1777
  const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2;
1778
- const uint offset0 = r2/gqa*(nb*ne0);
 
 
 
 
 
1779
  device const block_q3_K * x = (device const block_q3_K *) src0 + first_row*nb + offset0;
1780
- device const float * yy = (device const float *) src1 + r1*ne10 + r2*ne00*ne1;
1781
 
1782
  float yl[32];
1783
 
@@ -1899,7 +2368,7 @@ kernel void kernel_mul_mv_q3_K_f32(
1899
  }
1900
  if (tiisg == 0) {
1901
  for (int row = 0; row < 2; ++row) {
1902
- dst[r1*ne0 + r2*ne0*ne1 + first_row + row] = sumf1[row];
1903
  }
1904
  }
1905
  }
@@ -1913,26 +2382,33 @@ kernel void kernel_mul_mv_q3_K_f32(
1913
  constant int64_t & ne02[[buffer(5)]],
1914
  constant int64_t & ne10[[buffer(9)]],
1915
  constant int64_t & ne12[[buffer(11)]],
1916
- constant int64_t & ne0[[buffer(15)]],
1917
- constant int64_t & ne1[[buffer(16)]],
1918
- constant uint & gqa[[buffer(17)]],
 
1919
  uint3 tgpig[[threadgroup_position_in_grid]],
1920
- uint tiisg[[thread_index_in_simdgroup]],
1921
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
1922
 
1923
  const int nb = ne00/QK_K;
1924
 
1925
  const int64_t r0 = tgpig.x;
1926
  const int64_t r1 = tgpig.y;
1927
- const int64_t r2 = tgpig.z;
1928
 
1929
  const int row = 2 * r0 + sgitg;
1930
- const uint offset0 = r2/gqa*(nb*ne0);
1931
  device const block_q3_K * x = (device const block_q3_K *) src0 + row*nb + offset0;
1932
- device const float * yy = (device const float *) src1 + r1*ne10 + r2*ne00*ne1;
 
1933
  const int ix = tiisg/4;
1934
  const int il = 4 * (tiisg%4);// 0, 4, 8, 12
1935
- const int im = il/8; // 0, 0, 1, 1
1936
  const int in = il%8; // 0, 4, 0, 4
1937
 
1938
  float2 sum = {0.f, 0.f};
@@ -1952,7 +2428,7 @@ kernel void kernel_mul_mv_q3_K_f32(
1952
  const float d4 = d_all * ((int32_t)(s[0] & 0xF000) - 32768) * 1.f/262144.f;
1953
 
1954
  for (int l = 0; l < 4; l += 2) {
1955
- const uint16_t hm = h[l/2] >> im;
1956
  sum[0] += y[l+ 0] * d1 * ((int32_t)(q[l/2] & 0x0003) - ((hm & 0x0001) ? 0 : 4))
1957
  + y[l+16] * d2 * ((int32_t)(q[l/2] & 0x000c) - ((hm & 0x0004) ? 0 : 16))
1958
  + y[l+32] * d3 * ((int32_t)(q[l/2] & 0x0030) - ((hm & 0x0010) ? 0 : 64))
@@ -1968,7 +2444,7 @@ kernel void kernel_mul_mv_q3_K_f32(
1968
 
1969
  const float tot = simd_sum(sumf);
1970
  if (tiisg == 0) {
1971
- dst[r1*ne0 + r2*ne0*ne1 + row] = tot;
1972
  }
1973
 
1974
  }
@@ -1986,10 +2462,11 @@ kernel void kernel_mul_mv_q4_K_f32(
1986
  constant int64_t & ne12 [[buffer(11)]],
1987
  constant int64_t & ne0 [[buffer(15)]],
1988
  constant int64_t & ne1 [[buffer(16)]],
1989
- constant uint & gqa [[buffer(17)]],
 
1990
  uint3 tgpig[[threadgroup_position_in_grid]],
1991
- uint tiisg[[thread_index_in_simdgroup]],
1992
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
1993
 
1994
  const uint16_t kmask1 = 0x3f3f;
1995
  const uint16_t kmask2 = 0x0f0f;
@@ -1997,26 +2474,32 @@ kernel void kernel_mul_mv_q4_K_f32(
1997
 
1998
  const int ix = tiisg/8; // 0...3
1999
  const int it = tiisg%8; // 0...7
2000
- const int im = it/4; // 0 or 1
2001
  const int ir = it%4; // 0...3
2002
 
2003
  const int nb = ne00/QK_K;
2004
  const int r0 = tgpig.x;
2005
  const int r1 = tgpig.y;
2006
- const int r2 = tgpig.z;
2007
  //const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
2008
  const int first_row = r0 * N_DST;
2009
  const int ib_row = first_row * nb;
2010
- const uint offset0 = r2/gqa*(nb*ne0);
2011
  device const block_q4_K * x = (device const block_q4_K *) src0 + ib_row + offset0;
2012
- device const float * y = (device const float *) src1 + r1*ne10 + r2*ne00*ne1;
 
2013
  float yl[16];
2014
  float yh[16];
2015
  float sumf[N_DST]={0.f}, all_sum;
2016
 
2017
  const int step = sizeof(block_q4_K) * nb / 2;
2018
 
2019
- device const float * y4 = y + ix * QK_K + 64 * im + 8 * ir;
2020
 
2021
  uint16_t sc16[4];
2022
  thread const uint8_t * sc8 = (thread const uint8_t *)sc16;
@@ -2031,8 +2514,8 @@ kernel void kernel_mul_mv_q4_K_f32(
2031
  yh[i+8] = y4[i+160]; sumy[3] += yh[i+8];
2032
  }
2033
 
2034
- device const uint16_t * sc = (device const uint16_t *)x[ib].scales + im;
2035
- device const uint16_t * q1 = (device const uint16_t *)x[ib].qs + 16 * im + 4 * ir;
2036
  device const half * dh = &x[ib].d;
2037
 
2038
  for (int row = 0; row < N_DST; row++) {
@@ -2076,7 +2559,7 @@ kernel void kernel_mul_mv_q4_K_f32(
2076
  for (int row = 0; row < N_DST; ++row) {
2077
  all_sum = simd_sum(sumf[row]);
2078
  if (tiisg == 0) {
2079
- dst[r1*ne0 + r2*ne0*ne1 + first_row + row] = all_sum;
2080
  }
2081
  }
2082
  }
@@ -2090,9 +2573,10 @@ kernel void kernel_mul_mv_q4_K_f32(
2090
  constant int64_t & ne02[[buffer(5)]],
2091
  constant int64_t & ne10[[buffer(9)]],
2092
  constant int64_t & ne12[[buffer(11)]],
2093
- constant int64_t & ne0[[buffer(15)]],
2094
- constant int64_t & ne1[[buffer(16)]],
2095
- constant uint & gqa[[buffer(17)]],
 
2096
  uint3 tgpig[[threadgroup_position_in_grid]],
2097
  uint tiisg[[thread_index_in_simdgroup]],
2098
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -2103,12 +2587,18 @@ kernel void kernel_mul_mv_q4_K_f32(
2103
  const int nb = ne00/QK_K;
2104
  const int r0 = tgpig.x;
2105
  const int r1 = tgpig.y;
2106
- const int r2 = tgpig.z;
2107
  const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
2108
  const int ib_row = first_row * nb;
2109
- const uint offset0 = r2/gqa*(nb*ne0);
2110
  device const block_q4_K * x = (device const block_q4_K *) src0 + ib_row + offset0;
2111
- device const float * y = (device const float *) src1 + r1*ne10 + r2*ne00*ne1;
 
2112
  float yl[8];
2113
  float yh[8];
2114
  float sumf[N_DST]={0.f}, all_sum;
@@ -2164,7 +2654,7 @@ kernel void kernel_mul_mv_q4_K_f32(
2164
  for (int row = 0; row < N_DST; ++row) {
2165
  all_sum = simd_sum(sumf[row]);
2166
  if (tiisg == 0) {
2167
- dst[r1*ne0+ r2*ne0*ne1 + first_row + row] = all_sum;
2168
  }
2169
  }
2170
  }
@@ -2179,9 +2669,10 @@ kernel void kernel_mul_mv_q5_K_f32(
2179
  constant int64_t & ne02[[buffer(5)]],
2180
  constant int64_t & ne10[[buffer(9)]],
2181
  constant int64_t & ne12[[buffer(11)]],
2182
- constant int64_t & ne0[[buffer(15)]],
2183
- constant int64_t & ne1[[buffer(16)]],
2184
- constant uint & gqa[[buffer(17)]],
 
2185
  uint3 tgpig[[threadgroup_position_in_grid]],
2186
  uint tiisg[[thread_index_in_simdgroup]],
2187
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -2190,12 +2681,17 @@ kernel void kernel_mul_mv_q5_K_f32(
2190
 
2191
  const int64_t r0 = tgpig.x;
2192
  const int64_t r1 = tgpig.y;
2193
- const int r2 = tgpig.z;
2194
 
2195
  const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2;
2196
- const uint offset0 = r2/gqa*(nb*ne0);
2197
  device const block_q5_K * x = (device const block_q5_K *) src0 + first_row*nb + offset0;
2198
- device const float * yy = (device const float *) src1 + r1*ne10 + r2*ne00*ne1;
2199
 
2200
  float sumf[2]={0.f};
2201
 
@@ -2211,15 +2707,15 @@ kernel void kernel_mul_mv_q5_K_f32(
2211
 
2212
  const int tid = tiisg/4;
2213
  const int ix = tiisg%4;
2214
- const int im = tid/4;
2215
  const int ir = tid%4;
2216
  const int n = 8;
2217
 
2218
  const int l0 = n*ir;
2219
- const int q_offset = 32*im + l0;
2220
- const int y_offset = 64*im + l0;
2221
 
2222
- const uint8_t hm1 = 1u << (2*im);
2223
  const uint8_t hm2 = hm1 << 1;
2224
  const uint8_t hm3 = hm1 << 4;
2225
  const uint8_t hm4 = hm2 << 4;
@@ -2234,7 +2730,7 @@ kernel void kernel_mul_mv_q5_K_f32(
2234
  device const uint8_t * q1 = x[i].qs + q_offset;
2235
  device const uint8_t * qh = x[i].qh + l0;
2236
  device const half * dh = &x[i].d;
2237
- device const uint16_t * a = (device const uint16_t *)x[i].scales + im;
2238
 
2239
  device const float * y2 = y1 + 128;
2240
  float4 sumy = {0.f, 0.f, 0.f, 0.f};
@@ -2290,7 +2786,7 @@ kernel void kernel_mul_mv_q5_K_f32(
2290
 
2291
  const int il = 4 * (tiisg/8); // 0, 4, 8, 12
2292
  const int ix = tiisg%8;
2293
- const int im = il/8; // 0, 0, 1, 1
2294
  const int in = il%8; // 0, 4, 0, 4
2295
 
2296
  device const float * y = yy + ix*QK_K + il;
@@ -2315,7 +2811,7 @@ kernel void kernel_mul_mv_q5_K_f32(
2315
 
2316
  float2 acc = {0.f, 0.f};
2317
  for (int l = 0; l < 4; ++l) {
2318
- const uint8_t hl = h[l] >> im;
2319
  acc[0] += yl[l+0] * s[0] * ((int16_t)(q[l+ 0] & 0x0F) - (hl & 0x01 ? 0 : 16))
2320
  + yl[l+4] * s[1] * ((int16_t)(q[l+16] & 0x0F) - (hl & 0x04 ? 0 : 16));
2321
  acc[1] += yh[l+0] * s[2] * ((int16_t)(q[l+ 0] & 0xF0) - (hl & 0x10 ? 0 : 256))
@@ -2337,7 +2833,7 @@ kernel void kernel_mul_mv_q5_K_f32(
2337
  for (int row = 0; row < 2; ++row) {
2338
  const float tot = simd_sum(sumf[row]);
2339
  if (tiisg == 0) {
2340
- dst[r1*ne0 + r2*ne0*ne1 + first_row + row] = tot;
2341
  }
2342
  }
2343
 
@@ -2352,9 +2848,10 @@ kernel void kernel_mul_mv_q6_K_f32(
2352
  constant int64_t & ne02[[buffer(5)]],
2353
  constant int64_t & ne10[[buffer(9)]],
2354
  constant int64_t & ne12[[buffer(11)]],
2355
- constant int64_t & ne0[[buffer(15)]],
2356
- constant int64_t & ne1[[buffer(16)]],
2357
- constant uint & gqa[[buffer(17)]],
 
2358
  uint3 tgpig[[threadgroup_position_in_grid]],
2359
  uint tiisg[[thread_index_in_simdgroup]],
2360
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -2368,12 +2865,17 @@ kernel void kernel_mul_mv_q6_K_f32(
2368
 
2369
  const int64_t r0 = tgpig.x;
2370
  const int64_t r1 = tgpig.y;
2371
- const int r2 = tgpig.z;
2372
 
2373
  const int row = 2 * r0 + sgitg;
2374
- const uint offset0 = r2/gqa*(nb*ne0);
 
 
 
 
 
2375
  device const block_q6_K * x = (device const block_q6_K *) src0 + row * nb + offset0;
2376
- device const float * yy = (device const float *) src1 + r1*ne10 + r2*ne00*ne1;
2377
 
2378
  float sumf = 0;
2379
 
@@ -2439,7 +2941,7 @@ kernel void kernel_mul_mv_q6_K_f32(
2439
 
2440
  const float tot = simd_sum(sumf);
2441
  if (tiisg == 0) {
2442
- dst[r1*ne0 + r2*ne0*ne1 + row] = tot;
2443
  }
2444
  }
2445
 
@@ -2749,24 +3251,25 @@ kernel void kernel_get_rows(
2749
 
2750
  // each block_q contains 16*nl weights
2751
  template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
2752
- kernel void kernel_mul_mm(device const uchar * src0,
2753
- device const uchar * src1,
2754
- device float * dst,
2755
- constant int64_t & ne00,
2756
- constant int64_t & ne02,
2757
- constant int64_t & nb01,
2758
- constant int64_t & nb02,
2759
- constant int64_t & ne12,
2760
- constant int64_t & nb10,
2761
- constant int64_t & nb11,
2762
- constant int64_t & nb12,
2763
- constant int64_t & ne0,
2764
- constant int64_t & ne1,
2765
- constant uint & gqa,
2766
- threadgroup uchar * shared_memory [[threadgroup(0)]],
2767
- uint3 tgpig[[threadgroup_position_in_grid]],
2768
- uint tiitg[[thread_index_in_threadgroup]],
2769
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
 
2770
 
2771
  threadgroup half * sa = (threadgroup half *)(shared_memory);
2772
  threadgroup float * sb = (threadgroup float *)(shared_memory + 4096);
@@ -2792,7 +3295,10 @@ kernel void kernel_mul_mm(device const uchar * src0,
2792
 
2793
  short il = (tiitg % THREAD_PER_ROW);
2794
 
2795
- uint offset0 = im/gqa*nb02;
2796
  ushort offset1 = il/nl;
2797
 
2798
  device const block_q * x = (device const block_q *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01 + offset0) + offset1;
@@ -2876,14 +3382,116 @@ kernel void kernel_mul_mm(device const uchar * src0,
2876
  }
2877
  }
2878
 
2879
  #if QK_K == 256
2880
  #define QK_NL 16
2881
  #else
2882
  #define QK_NL 4
2883
  #endif
2884
 
2885
- typedef void (get_rows_t)(device const void *, device const int *, device float *, constant int64_t &, \
2886
- constant uint64_t &, constant uint64_t &, uint, uint, uint);
 
 
 
 
 
 
2887
 
2888
  template [[host_name("kernel_get_rows_f32")]] kernel get_rows_t kernel_get_rows<float4x4, 1, dequantize_f32>;
2889
  template [[host_name("kernel_get_rows_f16")]] kernel get_rows_t kernel_get_rows<half4x4, 1, dequantize_f16>;
@@ -2912,8 +3520,10 @@ typedef void (mat_mm_t)(
2912
  constant int64_t & nb12,
2913
  constant int64_t & ne0,
2914
  constant int64_t & ne1,
2915
- constant uint & gqa,
2916
- threadgroup uchar *, uint3, uint, uint);
2917
 
2918
  template [[host_name("kernel_mul_mm_f32_f32")]] kernel mat_mm_t kernel_mul_mm<float4x4, 1, dequantize_f32>;
2919
  template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm<half4x4, 1, dequantize_f16>;
@@ -2927,3 +3537,44 @@ template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mat_mm_t kernel_mul_mm<b
2927
  template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_K, QK_NL, dequantize_q4_K>;
2928
  template [[host_name("kernel_mul_mm_q5_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q5_K, QK_NL, dequantize_q5_K>;
2929
  template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q6_K, QK_NL, dequantize_q6_K>;
 
3
  using namespace metal;
4
 
5
  #define MAX(x, y) ((x) > (y) ? (x) : (y))
6
+ #define MIN(x, y) ((x) < (y) ? (x) : (y))
7
+ #define SWAP(x, y) { auto tmp = (x); (x) = (y); (y) = tmp; }
8
 
9
  #define QK4_0 32
10
  #define QR4_0 2
 
41
  int8_t qs[QK8_0]; // quants
42
  } block_q8_0;
43
 
44
+ #define N_SIMDWIDTH 32 // assuming SIMD group size is 32
45
+
46
+ enum ggml_sort_order {
47
+ GGML_SORT_ASC,
48
+ GGML_SORT_DESC,
49
+ };
50
+
51
+ // general-purpose kernel for addition, multiplication and division of two tensors
52
+ // pros: works for non-contiguous tensors, supports broadcast across all dims
53
  // cons: not very efficient
54
  kernel void kernel_add(
55
  device const char * src0,
 
90
  const int64_t i12 = i02 % ne12;
91
  const int64_t i11 = i01 % ne11;
92
 
93
+ device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01;
94
+ device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11;
95
+ device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1;
96
+
97
+ for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
98
+ const int i10 = i0 % ne10;
99
+ *((device float *)(dst_ptr + i0*nb0)) = *((device float *)(src0_ptr + i0*nb00)) + *((device float *)(src1_ptr + i10*nb10));
100
+ }
101
+ }
102
+
103
+ kernel void kernel_mul(
104
+ device const char * src0,
105
+ device const char * src1,
106
+ device char * dst,
107
+ constant int64_t & ne00,
108
+ constant int64_t & ne01,
109
+ constant int64_t & ne02,
110
+ constant int64_t & ne03,
111
+ constant int64_t & nb00,
112
+ constant int64_t & nb01,
113
+ constant int64_t & nb02,
114
+ constant int64_t & nb03,
115
+ constant int64_t & ne10,
116
+ constant int64_t & ne11,
117
+ constant int64_t & ne12,
118
+ constant int64_t & ne13,
119
+ constant int64_t & nb10,
120
+ constant int64_t & nb11,
121
+ constant int64_t & nb12,
122
+ constant int64_t & nb13,
123
+ constant int64_t & ne0,
124
+ constant int64_t & ne1,
125
+ constant int64_t & ne2,
126
+ constant int64_t & ne3,
127
+ constant int64_t & nb0,
128
+ constant int64_t & nb1,
129
+ constant int64_t & nb2,
130
+ constant int64_t & nb3,
131
+ uint3 tgpig[[threadgroup_position_in_grid]],
132
+ uint3 tpitg[[thread_position_in_threadgroup]],
133
+ uint3 ntg[[threads_per_threadgroup]]) {
134
+ const int64_t i03 = tgpig.z;
135
+ const int64_t i02 = tgpig.y;
136
+ const int64_t i01 = tgpig.x;
137
+
138
+ const int64_t i13 = i03 % ne13;
139
+ const int64_t i12 = i02 % ne12;
140
+ const int64_t i11 = i01 % ne11;
141
+
142
+ device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01;
143
+ device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11;
144
+ device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1;
145
 
146
  for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
147
+ const int i10 = i0 % ne10;
148
+ *((device float *)(dst_ptr + i0*nb0)) = *((device float *)(src0_ptr + i0*nb00)) * *((device float *)(src1_ptr + i10*nb10));
149
+ }
150
+ }
151
+
152
+ kernel void kernel_div(
153
+ device const char * src0,
154
+ device const char * src1,
155
+ device char * dst,
156
+ constant int64_t & ne00,
157
+ constant int64_t & ne01,
158
+ constant int64_t & ne02,
159
+ constant int64_t & ne03,
160
+ constant int64_t & nb00,
161
+ constant int64_t & nb01,
162
+ constant int64_t & nb02,
163
+ constant int64_t & nb03,
164
+ constant int64_t & ne10,
165
+ constant int64_t & ne11,
166
+ constant int64_t & ne12,
167
+ constant int64_t & ne13,
168
+ constant int64_t & nb10,
169
+ constant int64_t & nb11,
170
+ constant int64_t & nb12,
171
+ constant int64_t & nb13,
172
+ constant int64_t & ne0,
173
+ constant int64_t & ne1,
174
+ constant int64_t & ne2,
175
+ constant int64_t & ne3,
176
+ constant int64_t & nb0,
177
+ constant int64_t & nb1,
178
+ constant int64_t & nb2,
179
+ constant int64_t & nb3,
180
+ uint3 tgpig[[threadgroup_position_in_grid]],
181
+ uint3 tpitg[[thread_position_in_threadgroup]],
182
+ uint3 ntg[[threads_per_threadgroup]]) {
183
+ const int64_t i03 = tgpig.z;
184
+ const int64_t i02 = tgpig.y;
185
+ const int64_t i01 = tgpig.x;
186
+
187
+ const int64_t i13 = i03 % ne13;
188
+ const int64_t i12 = i02 % ne12;
189
+ const int64_t i11 = i01 % ne11;
190
+
191
+ device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01;
192
+ device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11;
193
+ device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1;
194
 
195
+ for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
196
+ const int i10 = i0 % ne10;
197
+ *((device float *)(dst_ptr + i0*nb0)) = *((device float *)(src0_ptr + i0*nb00)) / *((device float *)(src1_ptr + i10*nb10));
198
  }
199
  }
200
 
 
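For reference, the broadcast rule used by kernel_add, kernel_mul and kernel_div above is: src0 and dst share a shape, and every src1 index is taken modulo the src1 size in that dimension, so src1 may be a single row, a matrix or a full tensor. A plain-C sketch of the same loop (host-side, illustrative only; the function name and array-based argument layout are not from ggml):

#include <stdint.h>

// Sketch of the broadcast addition done by kernel_add; strides are in bytes, as in ggml.
static void add_f32_broadcast(const char * src0, const char * src1, char * dst,
                              const int64_t ne[4],   // dst (== src0) shape
                              const int64_t ne1[4],  // src1 shape
                              const int64_t nb0[4], const int64_t nb1[4], const int64_t nbd[4]) {
    for (int64_t i3 = 0; i3 < ne[3]; ++i3) {
    for (int64_t i2 = 0; i2 < ne[2]; ++i2) {
    for (int64_t i1 = 0; i1 < ne[1]; ++i1) {
        const int64_t i13 = i3 % ne1[3];
        const int64_t i12 = i2 % ne1[2];
        const int64_t i11 = i1 % ne1[1];
        for (int64_t i0 = 0; i0 < ne[0]; ++i0) {
            const int64_t i10 = i0 % ne1[0];
            const float a = *(const float *)(src0 + i3*nb0[3] + i2*nb0[2] + i1*nb0[1] + i0*nb0[0]);
            const float b = *(const float *)(src1 + i13*nb1[3] + i12*nb1[2] + i11*nb1[1] + i10*nb1[0]);
            *(float *)(dst + i3*nbd[3] + i2*nbd[2] + i1*nbd[1] + i0*nbd[0]) = a + b;
        }
    }}}
}

kernel_mul and kernel_div are identical except for the operator, which is why the three kernels share the same argument list.
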
209
  dst[tpig] = src0[tpig] + src1[tpig % nb];
210
  }
211
 
212
+ kernel void kernel_mul_row(
213
  device const float4 * src0,
214
  device const float4 * src1,
215
  device float4 * dst,
216
+ constant int64_t & nb [[buffer(27)]],
217
  uint tpig[[thread_position_in_grid]]) {
218
+ dst[tpig] = src0[tpig] * src1[tpig % nb];
219
  }
220
 
221
+ kernel void kernel_div_row(
222
  device const float4 * src0,
223
  device const float4 * src1,
224
  device float4 * dst,
225
+ constant int64_t & nb [[buffer(27)]],
226
  uint tpig[[thread_position_in_grid]]) {
227
+ dst[tpig] = src0[tpig] / src1[tpig % nb];
228
  }
229
 
230
  kernel void kernel_scale(
 
265
  dst[tpig] = src0[tpig] * src0[tpig];
266
  }
267
 
268
+ kernel void kernel_sum_rows(
269
+ device const float * src0,
270
+ device float * dst,
271
+ constant int64_t & ne00,
272
+ constant int64_t & ne01,
273
+ constant int64_t & ne02,
274
+ constant int64_t & ne03,
275
+ constant int64_t & nb00,
276
+ constant int64_t & nb01,
277
+ constant int64_t & nb02,
278
+ constant int64_t & nb03,
279
+ constant int64_t & ne10,
280
+ constant int64_t & ne11,
281
+ constant int64_t & ne12,
282
+ constant int64_t & ne13,
283
+ constant int64_t & nb10,
284
+ constant int64_t & nb11,
285
+ constant int64_t & nb12,
286
+ constant int64_t & nb13,
287
+ constant int64_t & ne0,
288
+ constant int64_t & ne1,
289
+ constant int64_t & ne2,
290
+ constant int64_t & ne3,
291
+ constant int64_t & nb0,
292
+ constant int64_t & nb1,
293
+ constant int64_t & nb2,
294
+ constant int64_t & nb3,
295
+ uint3 tpig[[thread_position_in_grid]]) {
296
+ int64_t i3 = tpig.z;
297
+ int64_t i2 = tpig.y;
298
+ int64_t i1 = tpig.x;
299
+
300
+ if (i3 >= ne03 || i2 >= ne02 || i1 >= ne01) {
301
+ return;
302
+ }
303
+
304
+ device const float * src_row = (device const float *) ((device const char *) src0 + i1*nb01 + i2*nb02 + i3*nb03);
305
+ device float * dst_row = (device float *) ((device char *) dst + i1*nb1 + i2*nb2 + i3*nb3);
306
+
307
+ float row_sum = 0;
308
+
309
+ for (int64_t i0 = 0; i0 < ne00; i0++) {
310
+ row_sum += src_row[i0];
311
+ }
312
+
313
+ dst_row[0] = row_sum;
314
+ }
315
+
316
  constant float GELU_COEF_A = 0.044715f;
317
  constant float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
318
 
 
331
 
332
  kernel void kernel_soft_max(
333
  device const float * src0,
334
+ device const float * src1,
335
  device float * dst,
336
  constant int64_t & ne00,
337
  constant int64_t & ne01,
338
  constant int64_t & ne02,
339
+ constant float & scale,
340
  threadgroup float * buf [[threadgroup(0)]],
341
  uint tgpig[[threadgroup_position_in_grid]],
342
  uint tpitg[[thread_position_in_threadgroup]],
 
347
  const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
348
  const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
349
 
350
+ device const float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
351
+ device const float * pmask = src1 ? src1 + i01*ne00 : nullptr;
352
+ device float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
353
 
354
  // parallel max
355
+ float lmax = -INFINITY;
356
 
357
+ for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
358
+ lmax = MAX(lmax, psrc0[i00]*scale + (pmask ? pmask[i00] : 0.0f));
359
  }
360
 
361
+ // find the max value in the block
362
+ float max_val = simd_max(lmax);
363
+ if (ntg > N_SIMDWIDTH) {
364
+ if (sgitg == 0) {
365
+ buf[tiisg] = -INFINITY;
366
+ }
367
 
368
+ threadgroup_barrier(mem_flags::mem_threadgroup);
369
 
370
+ if (tiisg == 0) {
371
+ buf[sgitg] = max_val;
372
+ }
 
 
 
373
 
374
+ threadgroup_barrier(mem_flags::mem_threadgroup);
375
 
376
+ max_val = buf[tiisg];
377
+ max_val = simd_max(max_val);
378
+ }
379
 
380
  // parallel sum
381
  float lsum = 0.0f;
382
  for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
383
+ const float exp_psrc0 = exp((psrc0[i00]*scale + (pmask ? pmask[i00] : 0.0f)) - max_val);
384
  lsum += exp_psrc0;
 
 
385
  pdst[i00] = exp_psrc0;
386
  }
387
 
388
  float sum = simd_sum(lsum);
389
+ if (ntg > N_SIMDWIDTH) {
390
+ if (sgitg == 0) {
391
+ buf[tiisg] = 0.0f;
392
+ }
393
 
394
+ threadgroup_barrier(mem_flags::mem_threadgroup);
395
 
396
+ if (tiisg == 0) {
397
+ buf[sgitg] = sum;
398
+ }
 
 
 
399
 
400
+ threadgroup_barrier(mem_flags::mem_threadgroup);
401
+
402
+ sum = buf[tiisg];
403
+ sum = simd_sum(sum);
404
+ }
405
 
406
+ const float inv_sum = 1.0f/sum;
407
 
408
  for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
409
+ pdst[i00] *= inv_sum;
410
  }
411
  }
412
 
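kernel_soft_max above (and its float4 variant kernel_soft_max_4 below) now fuses an optional additive mask src1 and a scale factor, and uses a two-level reduction (simd_max/simd_sum plus a threadgroup buffer) when more than one SIMD group is active. The per-row math, as a plain-C sketch (function name is illustrative, not a ggml API):

#include <math.h>
#include <stdint.h>

// One row: dst = softmax(scale*src0 + mask), stabilized by subtracting the row max.
// mask may be NULL, matching the optional src1 in the kernel.
static void soft_max_row(const float * src0, const float * mask, float * dst,
                         int64_t ne00, float scale) {
    float max_val = -INFINITY;
    for (int64_t i = 0; i < ne00; ++i) {
        const float v = src0[i]*scale + (mask ? mask[i] : 0.0f);
        if (v > max_val) max_val = v;
    }

    float sum = 0.0f;
    for (int64_t i = 0; i < ne00; ++i) {
        const float e = expf(src0[i]*scale + (mask ? mask[i] : 0.0f) - max_val);
        dst[i] = e;
        sum   += e;
    }

    const float inv_sum = 1.0f/sum;
    for (int64_t i = 0; i < ne00; ++i) {
        dst[i] *= inv_sum;
    }
}
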
413
  kernel void kernel_soft_max_4(
414
  device const float * src0,
415
+ device const float * src1,
416
  device float * dst,
417
  constant int64_t & ne00,
418
  constant int64_t & ne01,
419
  constant int64_t & ne02,
420
+ constant float & scale,
421
  threadgroup float * buf [[threadgroup(0)]],
422
  uint tgpig[[threadgroup_position_in_grid]],
423
  uint tpitg[[thread_position_in_threadgroup]],
 
428
  const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
429
  const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
430
 
431
+ device const float4 * psrc4 = (device const float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
432
+ device const float4 * pmask = src1 ? (device const float4 *)(src1 + i01*ne00) : nullptr;
433
+ device float4 * pdst4 = (device float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
434
 
435
  // parallel max
436
+ float4 lmax4 = -INFINITY;
437
 
438
+ for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
439
+ lmax4 = fmax(lmax4, psrc4[i00]*scale + (pmask ? pmask[i00] : 0.0f));
440
  }
441
 
442
  const float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3]));
 
 
 
 
443
 
444
+ float max_val = simd_max(lmax);
445
+ if (ntg > N_SIMDWIDTH) {
446
+ if (sgitg == 0) {
447
+ buf[tiisg] = -INFINITY;
448
+ }
449
 
450
+ threadgroup_barrier(mem_flags::mem_threadgroup);
 
 
 
 
 
451
 
452
+ if (tiisg == 0) {
453
+ buf[sgitg] = max_val;
454
+ }
455
 
456
+ threadgroup_barrier(mem_flags::mem_threadgroup);
457
+
458
+ max_val = buf[tiisg];
459
+ max_val = simd_max(max_val);
460
+ }
461
 
462
  // parallel sum
463
  float4 lsum4 = 0.0f;
464
  for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
465
+ const float4 exp_psrc4 = exp((psrc4[i00]*scale + (pmask ? pmask[i00] : 0.0f)) - max_val);
466
  lsum4 += exp_psrc4;
467
  pdst4[i00] = exp_psrc4;
468
  }
469
 
470
  const float lsum = lsum4[0] + lsum4[1] + lsum4[2] + lsum4[3];
471
  float sum = simd_sum(lsum);
472
+ if (ntg > N_SIMDWIDTH) {
473
+ if (sgitg == 0) {
474
+ buf[tiisg] = 0.0f;
475
+ }
476
 
477
+ threadgroup_barrier(mem_flags::mem_threadgroup);
478
 
479
+ if (tiisg == 0) {
480
+ buf[sgitg] = sum;
481
+ }
 
 
 
482
 
483
+ threadgroup_barrier(mem_flags::mem_threadgroup);
484
+
485
+ sum = buf[tiisg];
486
+ sum = simd_sum(sum);
487
+ }
488
 
489
+ const float inv_sum = 1.0f/sum;
490
 
491
  for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
492
+ pdst4[i00] *= inv_sum;
493
  }
494
  }
495
 
 
596
  constant int64_t & ne00,
597
  constant uint64_t & nb01,
598
  constant float & eps,
599
+ threadgroup float * buf [[threadgroup(0)]],
600
  uint tgpig[[threadgroup_position_in_grid]],
601
  uint tpitg[[thread_position_in_threadgroup]],
602
  uint sgitg[[simdgroup_index_in_threadgroup]],
603
  uint tiisg[[thread_index_in_simdgroup]],
604
  uint ntg[[threads_per_threadgroup]]) {
605
+ device const float4 * x = (device const float4 *) ((device const char *) src0 + tgpig*nb01);
 
606
 
607
  float4 sumf = 0;
608
  float all_sum = 0;
 
613
  }
614
  all_sum = sumf[0] + sumf[1] + sumf[2] + sumf[3];
615
  all_sum = simd_sum(all_sum);
616
+ if (ntg > N_SIMDWIDTH) {
617
+ if (sgitg == 0) {
618
+ buf[tiisg] = 0.0f;
619
+ }
620
 
621
+ threadgroup_barrier(mem_flags::mem_threadgroup);
622
 
623
+ if (tiisg == 0) {
624
+ buf[sgitg] = all_sum;
625
  }
 
 
626
 
627
+ threadgroup_barrier(mem_flags::mem_threadgroup);
628
+
629
+ all_sum = buf[tiisg];
630
+ all_sum = simd_sum(all_sum);
631
+ }
632
 
633
+ const float mean = all_sum/ne00;
634
  const float scale = 1.0f/sqrt(mean + eps);
635
 
636
  device float4 * y = (device float4 *) (dst + tgpig*ne00);
 
637
  for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
638
  y[i00] = x[i00] * scale;
639
  }
640
  }
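The RMS-normalization kernel above (the one that scales each row by 1/sqrt(mean(x^2) + eps)) gains the same two-level simdgroup reduction; the math itself is unchanged. A scalar plain-C sketch for one row (illustrative names only):

#include <math.h>
#include <stdint.h>

// y = x / sqrt(mean(x^2) + eps), computed over one row of ne00 elements
static void rms_norm_row(const float * x, float * y, int64_t ne00, float eps) {
    float sum = 0.0f;
    for (int64_t i = 0; i < ne00; ++i) {
        sum += x[i]*x[i];
    }
    const float mean  = sum/ne00;
    const float scale = 1.0f/sqrtf(mean + eps);
    for (int64_t i = 0; i < ne00; ++i) {
        y[i] = x[i]*scale;
    }
}
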
641
 
642
  // function to calculate the inner product between half a q4_0 block and 16 floats (yl), sumy is SUM(yl[i])
 
726
  // putting them in the kernel causes a significant performance penalty
727
  #define N_DST 4 // each SIMD group works on 4 rows
728
  #define N_SIMDGROUP 2 // number of SIMD groups in a thread group
 
729
  // Note: This is a template, but strictly speaking it only applies to
730
  // quantizations where the block size is 32. It also does not
731
  // guard against the number of rows not being divisible by
732
  // N_DST, so this is another explicit assumption of the implementation.
733
  template<typename block_q_type, int nr, int nsg, int nw>
734
+ void mul_vec_q_n_f32(
735
+ device const void * src0,
736
+ device const float * src1,
737
+ device float * dst,
738
+ int64_t ne00,
739
+ int64_t ne01,
740
+ int64_t ne02,
741
+ int64_t ne10,
742
+ int64_t ne12,
743
+ int64_t ne0,
744
+ int64_t ne1,
745
+ uint r2,
746
+ uint r3,
747
+ uint3 tgpig, uint tiisg, uint sgitg) {
748
  const int nb = ne00/QK4_0;
749
 
750
  const int r0 = tgpig.x;
 
753
 
754
  const int first_row = (r0 * nsg + sgitg) * nr;
755
 
756
+ const uint i12 = im%ne12;
757
+ const uint i13 = im/ne12;
758
+
759
+ const uint offset0 = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
760
 
761
  device const block_q_type * x = (device const block_q_type *) src0 + offset0;
762
  device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
 
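The new r2/r3 arguments replace the old gqa parameter: they are the broadcast ratios along dims 2 and 3 (presumably r2 = ne12/ne02 and r3 = ne13/ne03, set by the host code — an assumption, since the host side is not shown here). The offset computation used by mul_vec_q_n_f32 and the other mul_mv kernels can be read as this plain-C sketch (function name is illustrative):

#include <stdint.h>

// im is the flattened dst/src1 batch index (tgpig.z), nb = ne00/QK is the number of
// quant blocks per src0 row, and r2/r3 are the dim-2/dim-3 broadcast ratios.
static int64_t src0_block_offset(int64_t im, int64_t first_row,
                                 int64_t nb, int64_t ne01, int64_t ne02, int64_t ne12,
                                 uint32_t r2, uint32_t r3) {
    const int64_t i12 = im % ne12;   // batch index along dim 2
    const int64_t i13 = im / ne12;   // batch index along dim 3
    return first_row*nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
}

Dividing the batch indices by r2/r3 is what lets several src1 matrices reuse the same src0 matrix, which is the grouped-query-attention case the old gqa flag used to approximate.
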
806
  constant int64_t & ne02[[buffer(5)]],
807
  constant int64_t & ne10[[buffer(9)]],
808
  constant int64_t & ne12[[buffer(11)]],
809
+ constant int64_t & ne0 [[buffer(15)]],
810
+ constant int64_t & ne1 [[buffer(16)]],
811
+ constant uint & r2 [[buffer(17)]],
812
+ constant uint & r3 [[buffer(18)]],
813
  uint3 tgpig[[threadgroup_position_in_grid]],
814
  uint tiisg[[thread_index_in_simdgroup]],
815
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
816
+ mul_vec_q_n_f32<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
817
  }
818
 
819
  kernel void kernel_mul_mv_q4_1_f32(
 
825
  constant int64_t & ne02[[buffer(5)]],
826
  constant int64_t & ne10[[buffer(9)]],
827
  constant int64_t & ne12[[buffer(11)]],
828
+ constant int64_t & ne0 [[buffer(15)]],
829
+ constant int64_t & ne1 [[buffer(16)]],
830
+ constant uint & r2 [[buffer(17)]],
831
+ constant uint & r3 [[buffer(18)]],
832
  uint3 tgpig[[threadgroup_position_in_grid]],
833
  uint tiisg[[thread_index_in_simdgroup]],
834
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
835
+ mul_vec_q_n_f32<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
836
  }
837
 
838
  kernel void kernel_mul_mv_q5_0_f32(
 
844
  constant int64_t & ne02[[buffer(5)]],
845
  constant int64_t & ne10[[buffer(9)]],
846
  constant int64_t & ne12[[buffer(11)]],
847
+ constant int64_t & ne0 [[buffer(15)]],
848
+ constant int64_t & ne1 [[buffer(16)]],
849
+ constant uint & r2 [[buffer(17)]],
850
+ constant uint & r3 [[buffer(18)]],
851
  uint3 tgpig[[threadgroup_position_in_grid]],
852
  uint tiisg[[thread_index_in_simdgroup]],
853
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
854
+ mul_vec_q_n_f32<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
855
  }
856
 
857
  kernel void kernel_mul_mv_q5_1_f32(
 
863
  constant int64_t & ne02[[buffer(5)]],
864
  constant int64_t & ne10[[buffer(9)]],
865
  constant int64_t & ne12[[buffer(11)]],
866
+ constant int64_t & ne0 [[buffer(15)]],
867
+ constant int64_t & ne1 [[buffer(16)]],
868
+ constant uint & r2 [[buffer(17)]],
869
+ constant uint & r3 [[buffer(18)]],
870
  uint3 tgpig[[threadgroup_position_in_grid]],
871
  uint tiisg[[thread_index_in_simdgroup]],
872
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
873
+ mul_vec_q_n_f32<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
874
  }
875
 
876
 
 
885
  constant int64_t & ne02[[buffer(5)]],
886
  constant int64_t & ne10[[buffer(9)]],
887
  constant int64_t & ne12[[buffer(11)]],
888
+ constant int64_t & ne0 [[buffer(15)]],
889
+ constant int64_t & ne1 [[buffer(16)]],
890
+ constant uint & r2 [[buffer(17)]],
891
+ constant uint & r3 [[buffer(18)]],
892
  uint3 tgpig[[threadgroup_position_in_grid]],
893
  uint tiisg[[thread_index_in_simdgroup]],
894
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
 
900
  const int r0 = tgpig.x;
901
  const int r1 = tgpig.y;
902
  const int im = tgpig.z;
903
+
904
  const int first_row = (r0 * nsg + sgitg) * nr;
905
+
906
+ const uint i12 = im%ne12;
907
+ const uint i13 = im/ne12;
908
+
909
+ const uint offset0 = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
910
+
911
  device const block_q8_0 * x = (device const block_q8_0 *) src0 + offset0;
912
  device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
913
 
 
965
  constant uint64_t & nb12,
966
  constant int64_t & ne0,
967
  constant int64_t & ne1,
968
+ constant uint & r2 [[buffer(17)]],
969
+ constant uint & r3 [[buffer(18)]],
970
  uint3 tgpig[[threadgroup_position_in_grid]],
971
  uint tiisg[[thread_index_in_simdgroup]]) {
972
 
 
974
  const int64_t rb = tgpig.y*N_F32_F32;
975
  const int64_t im = tgpig.z;
976
 
977
+ const uint i12 = im%ne12;
978
+ const uint i13 = im/ne12;
979
+
980
+ const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02;
981
+
982
+ device const float * x = (device const float *) (src0 + offset0);
983
 
984
  if (ne00 < 128) {
985
  for (int row = 0; row < N_F32_F32; ++row) {
 
1045
  constant uint64_t & nb12,
1046
  constant int64_t & ne0,
1047
  constant int64_t & ne1,
1048
+ constant uint & r2 [[buffer(17)]],
1049
+ constant uint & r3 [[buffer(18)]],
1050
  uint3 tgpig[[threadgroup_position_in_grid]],
1051
  uint tiisg[[thread_index_in_simdgroup]]) {
1052
 
 
1054
  const int64_t rb = tgpig.y*N_F16_F16;
1055
  const int64_t im = tgpig.z;
1056
 
1057
+ const uint i12 = im%ne12;
1058
+ const uint i13 = im/ne12;
1059
+
1060
+ const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02;
1061
+
1062
+ device const half * x = (device const half *) (src0 + offset0);
1063
 
1064
  if (ne00 < 128) {
1065
  for (int row = 0; row < N_F16_F16; ++row) {
 
1123
  constant uint64_t & nb12,
1124
  constant int64_t & ne0,
1125
  constant int64_t & ne1,
1126
+ constant uint & r2 [[buffer(17)]],
1127
+ constant uint & r3 [[buffer(18)]],
1128
  uint3 tgpig[[threadgroup_position_in_grid]],
1129
  uint tiisg[[thread_index_in_simdgroup]]) {
1130
 
 
1132
  const int64_t r1 = tgpig.y;
1133
  const int64_t im = tgpig.z;
1134
 
1135
+ const uint i12 = im%ne12;
1136
+ const uint i13 = im/ne12;
1137
+
1138
+ const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02;
1139
+
1140
+ device const half * x = (device const half *) (src0 + offset0);
1141
  device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
1142
 
1143
  float sumf = 0;
 
1184
  constant uint64_t & nb12,
1185
  constant int64_t & ne0,
1186
  constant int64_t & ne1,
1187
+ constant uint & r2 [[buffer(17)]],
1188
+ constant uint & r3 [[buffer(18)]],
1189
  uint3 tgpig[[threadgroup_position_in_grid]],
1190
  uint tiisg[[thread_index_in_simdgroup]]) {
1191
 
 
1193
  const int64_t rb = tgpig.y*N_F16_F32;
1194
  const int64_t im = tgpig.z;
1195
 
1196
+ const uint i12 = im%ne12;
1197
+ const uint i13 = im/ne12;
1198
+
1199
+ const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02;
1200
+
1201
+ device const half * x = (device const half *) (src0 + offset0);
1202
 
1203
  if (ne00 < 128) {
1204
  for (int row = 0; row < N_F16_F32; ++row) {
 
1263
  constant uint64_t & nb12,
1264
  constant int64_t & ne0,
1265
  constant int64_t & ne1,
1266
+ constant uint & r2 [[buffer(17)]],
1267
+ constant uint & r3 [[buffer(18)]],
1268
  uint3 tgpig[[threadgroup_position_in_grid]],
1269
  uint tiisg[[thread_index_in_simdgroup]]) {
1270
 
 
1272
  const int64_t r0 = tgpig.x;
1273
  const int64_t im = tgpig.z;
1274
 
1275
+ const uint i12 = im%ne12;
1276
+ const uint i13 = im/ne12;
1277
+
1278
+ const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02;
1279
+
1280
+ device const half4 * x4 = (device const half4 *) (src0 + offset0);
1281
 
1282
  for (int r1 = 0; r1 < nrows; ++r1) {
1283
  device const float4 * y4 = (device const float4 *) (src1 + r1*nb11 + im*nb12);
 
1329
  const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
1330
  const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
1331
  const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
1332
+ const int64_t k = i3*ne3 + i2;
1333
 
 
1334
  float m_k;
1335
+ if (k < n_heads_log2_floor) {
1336
+ m_k = pow(m0, k + 1);
1337
  } else {
1338
+ m_k = pow(m1, 2 * (k - n_heads_log2_floor) + 1);
1339
  }
1340
+
1341
+ device char * dst_row = (device char *) dst + i3*nb3 + i2*nb2 + i1*nb1;
1342
+ device const char * src_row = (device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01;
1343
  for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
1344
+ const float src_v = *(device float *)(src_row + i00*nb00);
1345
+ device float * dst_v = (device float *)(dst_row + i00*nb0);
1346
+ *dst_v = i00 * m_k + src_v;
1347
  }
1348
  }
1349
 
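The rewritten kernel_alibi_f32 in the hunk above adds a per-head linear bias i00 * m_k to every element of a row, where the head index is computed in the kernel as k = i3*ne3 + i2. The slope selection can be read as this plain-C sketch (m0, m1 and n_heads_log2_floor are the values passed in by the host; the function name is illustrative):

#include <math.h>

// Slope for head k, mirroring the branch in kernel_alibi_f32: heads below the nearest
// power of two use powers of m0, the remaining heads use odd powers of m1.
static float alibi_slope(int k, int n_heads_log2_floor, float m0, float m1) {
    if (k < n_heads_log2_floor) {
        return powf(m0, k + 1);
    }
    return powf(m1, 2*(k - n_heads_log2_floor) + 1);
}

// The bias added to element i00 of that head's row is then i00 * alibi_slope(...).
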
 
1548
  }
1549
  }
1550
 
1551
+ // bitonic sort implementation following the CUDA kernels as reference
1552
+ typedef void (argsort_t)(
1553
+ device const float * x,
1554
+ device int32_t * dst,
1555
+ constant int64_t & ncols,
1556
+ uint3 tgpig[[threadgroup_position_in_grid]],
1557
+ uint3 tpitg[[thread_position_in_threadgroup]]);
1558
+
1559
+ template<ggml_sort_order order>
1560
+ kernel void kernel_argsort_f32_i32(
1561
+ device const float * x,
1562
+ device int32_t * dst,
1563
+ constant int64_t & ncols,
1564
+ uint3 tgpig[[threadgroup_position_in_grid]],
1565
+ uint3 tpitg[[thread_position_in_threadgroup]]) {
1566
+ // bitonic sort
1567
+ int col = tpitg[0];
1568
+ int row = tgpig[1];
1569
+
1570
+ if (col >= ncols) return;
1571
+
1572
+ device const float * x_row = x + row * ncols;
1573
+ device int32_t * dst_row = dst + row * ncols;
1574
+
1575
+ // initialize indices
1576
+ if (col < ncols) {
1577
+ dst_row[col] = col;
1578
+ }
1579
+ threadgroup_barrier(mem_flags::mem_threadgroup);
1580
+
1581
+ for (int k = 2; k <= ncols; k *= 2) {
1582
+ for (int j = k / 2; j > 0; j /= 2) {
1583
+ int ixj = col ^ j;
1584
+ if (ixj > col) {
1585
+ if ((col & k) == 0) {
1586
+ if (order == GGML_SORT_ASC ? x_row[dst_row[col]] > x_row[dst_row[ixj]] : x_row[dst_row[col]] < x_row[dst_row[ixj]]) {
1587
+ SWAP(dst_row[col], dst_row[ixj]);
1588
+ }
1589
+ } else {
1590
+ if (order == GGML_SORT_ASC ? x_row[dst_row[col]] < x_row[dst_row[ixj]] : x_row[dst_row[col]] > x_row[dst_row[ixj]]) {
1591
+ SWAP(dst_row[col], dst_row[ixj]);
1592
+ }
1593
+ }
1594
+ }
1595
+ threadgroup_barrier(mem_flags::mem_threadgroup);
1596
+ }
1597
+ }
1598
+ }
1599
+
1600
+ template [[host_name("kernel_argsort_f32_i32_asc")]] kernel argsort_t kernel_argsort_f32_i32<GGML_SORT_ASC>;
1601
+ template [[host_name("kernel_argsort_f32_i32_desc")]] kernel argsort_t kernel_argsort_f32_i32<GGML_SORT_DESC>;
1602
+
1603
  kernel void kernel_cpy_f16_f16(
1604
  device const half * src0,
1605
  device half * dst,
 
1725
  }
1726
  }
1727
 
1728
+ kernel void kernel_cpy_f32_q8_0(
1729
+ device const float * src0,
1730
+ device void * dst,
1731
+ constant int64_t & ne00,
1732
+ constant int64_t & ne01,
1733
+ constant int64_t & ne02,
1734
+ constant int64_t & ne03,
1735
+ constant uint64_t & nb00,
1736
+ constant uint64_t & nb01,
1737
+ constant uint64_t & nb02,
1738
+ constant uint64_t & nb03,
1739
+ constant int64_t & ne0,
1740
+ constant int64_t & ne1,
1741
+ constant int64_t & ne2,
1742
+ constant int64_t & ne3,
1743
+ constant uint64_t & nb0,
1744
+ constant uint64_t & nb1,
1745
+ constant uint64_t & nb2,
1746
+ constant uint64_t & nb3,
1747
+ uint3 tgpig[[threadgroup_position_in_grid]],
1748
+ uint3 tpitg[[thread_position_in_threadgroup]],
1749
+ uint3 ntg[[threads_per_threadgroup]]) {
1750
+ const int64_t i03 = tgpig[2];
1751
+ const int64_t i02 = tgpig[1];
1752
+ const int64_t i01 = tgpig[0];
1753
+
1754
+ const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
1755
+
1756
+ const int64_t i3 = n / (ne2*ne1*ne0);
1757
+ const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
1758
+ const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
1759
+ const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0)/QK8_0;
1760
+
1761
+ device block_q8_0 * dst_data = (device block_q8_0 *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
1762
+
1763
+ for (int64_t i00 = tpitg.x*QK8_0; i00 < ne00; i00 += ntg.x*QK8_0) {
1764
+ device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
1765
+
1766
+ float amax = 0.0f; // absolute max
1767
+
1768
+ for (int j = 0; j < QK8_0; j++) {
1769
+ const float v = src[j];
1770
+ amax = MAX(amax, fabs(v));
1771
+ }
1772
+
1773
+ const float d = amax / ((1 << 7) - 1);
1774
+ const float id = d ? 1.0f/d : 0.0f;
1775
+
1776
+ dst_data[i00/QK8_0].d = d;
1777
+
1778
+ for (int j = 0; j < QK8_0; ++j) {
1779
+ const float x0 = src[j]*id;
1780
+
1781
+ dst_data[i00/QK8_0].qs[j] = round(x0);
1782
+ }
1783
+ }
1784
+ }
1785
+
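kernel_cpy_f32_q8_0 above quantizes 32-float blocks on the fly using the usual Q8_0 scheme. A plain-C sketch for one block (the real block_q8_0 stores d as a half; float is used here only to keep the sketch short, and the names are illustrative):

#include <math.h>
#include <stdint.h>

#define QK8_0 32

// One Q8_0 block: a scale d plus 32 signed 8-bit quants.
typedef struct { float d; int8_t qs[QK8_0]; } block_q8_0_ref;

static void quantize_block_q8_0(const float * x, block_q8_0_ref * y) {
    float amax = 0.0f; // absolute max of the block
    for (int j = 0; j < QK8_0; ++j) {
        const float v = fabsf(x[j]);
        if (v > amax) amax = v;
    }

    const float d  = amax / 127.0f;   // (1 << 7) - 1
    const float id = d ? 1.0f/d : 0.0f;

    y->d = d;
    for (int j = 0; j < QK8_0; ++j) {
        y->qs[j] = (int8_t) roundf(x[j]*id);
    }
}
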
1786
+ kernel void kernel_cpy_f32_q4_0(
1787
+ device const float * src0,
1788
+ device void * dst,
1789
+ constant int64_t & ne00,
1790
+ constant int64_t & ne01,
1791
+ constant int64_t & ne02,
1792
+ constant int64_t & ne03,
1793
+ constant uint64_t & nb00,
1794
+ constant uint64_t & nb01,
1795
+ constant uint64_t & nb02,
1796
+ constant uint64_t & nb03,
1797
+ constant int64_t & ne0,
1798
+ constant int64_t & ne1,
1799
+ constant int64_t & ne2,
1800
+ constant int64_t & ne3,
1801
+ constant uint64_t & nb0,
1802
+ constant uint64_t & nb1,
1803
+ constant uint64_t & nb2,
1804
+ constant uint64_t & nb3,
1805
+ uint3 tgpig[[threadgroup_position_in_grid]],
1806
+ uint3 tpitg[[thread_position_in_threadgroup]],
1807
+ uint3 ntg[[threads_per_threadgroup]]) {
1808
+ const int64_t i03 = tgpig[2];
1809
+ const int64_t i02 = tgpig[1];
1810
+ const int64_t i01 = tgpig[0];
1811
+
1812
+ const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
1813
+
1814
+ const int64_t i3 = n / (ne2*ne1*ne0);
1815
+ const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
1816
+ const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
1817
+ const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0)/QK4_0;
1818
+
1819
+ device block_q4_0 * dst_data = (device block_q4_0 *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
1820
+
1821
+ for (int64_t i00 = tpitg.x*QK4_0; i00 < ne00; i00 += ntg.x*QK4_0) {
1822
+ device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
1823
+
1824
+ float amax = 0.0f; // absolute max
1825
+ float max = 0.0f;
1826
+
1827
+ for (int j = 0; j < QK4_0; j++) {
1828
+ const float v = src[j];
1829
+ if (amax < fabs(v)) {
1830
+ amax = fabs(v);
1831
+ max = v;
1832
+ }
1833
+ }
1834
+
1835
+ const float d = max / -8;
1836
+ const float id = d ? 1.0f/d : 0.0f;
1837
+
1838
+ dst_data[i00/QK4_0].d = d;
1839
+
1840
+ for (int j = 0; j < QK4_0/2; ++j) {
1841
+ const float x0 = src[0 + j]*id;
1842
+ const float x1 = src[QK4_0/2 + j]*id;
1843
+
1844
+ const uint8_t xi0 = MIN(15, (int8_t)(x0 + 8.5f));
1845
+ const uint8_t xi1 = MIN(15, (int8_t)(x1 + 8.5f));
1846
+
1847
+ dst_data[i00/QK4_0].qs[j] = xi0;
1848
+ dst_data[i00/QK4_0].qs[j] |= xi1 << 4;
1849
+ }
1850
+ }
1851
+ }
1852
+
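kernel_cpy_f32_q4_0 above does the same for Q4_0: it picks the signed value with the largest magnitude in the block, derives the scale from it, and packs two 4-bit quants per byte. A plain-C sketch of one block (again with d as float for brevity; names are illustrative):

#include <math.h>
#include <stdint.h>

#define QK4_0 32
#define MIN_I(a, b) ((a) < (b) ? (a) : (b))

typedef struct { float d; uint8_t qs[QK4_0/2]; } block_q4_0_ref;

static void quantize_block_q4_0(const float * x, block_q4_0_ref * y) {
    float amax = 0.0f; // largest magnitude in the block
    float max  = 0.0f; // its signed value
    for (int j = 0; j < QK4_0; ++j) {
        if (fabsf(x[j]) > amax) {
            amax = fabsf(x[j]);
            max  = x[j];
        }
    }

    const float d  = max / -8.0f;
    const float id = d ? 1.0f/d : 0.0f;

    y->d = d;
    for (int j = 0; j < QK4_0/2; ++j) {
        const uint8_t xi0 = MIN_I(15, (int8_t)(x[j]*id           + 8.5f));
        const uint8_t xi1 = MIN_I(15, (int8_t)(x[QK4_0/2 + j]*id + 8.5f));
        y->qs[j] = xi0 | (xi1 << 4);
    }
}

kernel_cpy_f32_q4_1 follows the same pattern but stores a per-block minimum as well, quantizing (x - min)/d instead of x/d.
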
1853
+ kernel void kernel_cpy_f32_q4_1(
1854
+ device const float * src0,
1855
+ device void * dst,
1856
+ constant int64_t & ne00,
1857
+ constant int64_t & ne01,
1858
+ constant int64_t & ne02,
1859
+ constant int64_t & ne03,
1860
+ constant uint64_t & nb00,
1861
+ constant uint64_t & nb01,
1862
+ constant uint64_t & nb02,
1863
+ constant uint64_t & nb03,
1864
+ constant int64_t & ne0,
1865
+ constant int64_t & ne1,
1866
+ constant int64_t & ne2,
1867
+ constant int64_t & ne3,
1868
+ constant uint64_t & nb0,
1869
+ constant uint64_t & nb1,
1870
+ constant uint64_t & nb2,
1871
+ constant uint64_t & nb3,
1872
+ uint3 tgpig[[threadgroup_position_in_grid]],
1873
+ uint3 tpitg[[thread_position_in_threadgroup]],
1874
+ uint3 ntg[[threads_per_threadgroup]]) {
1875
+ const int64_t i03 = tgpig[2];
1876
+ const int64_t i02 = tgpig[1];
1877
+ const int64_t i01 = tgpig[0];
1878
+
1879
+ const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
1880
+
1881
+ const int64_t i3 = n / (ne2*ne1*ne0);
1882
+ const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
1883
+ const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
1884
+ const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0)/QK4_1;
1885
+
1886
+ device block_q4_1 * dst_data = (device block_q4_1 *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
1887
+
1888
+ for (int64_t i00 = tpitg.x*QK4_1; i00 < ne00; i00 += ntg.x*QK4_1) {
1889
+ device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
1890
+
1891
+ float min = FLT_MAX;
1892
+ float max = -FLT_MAX;
1893
+
1894
+ for (int j = 0; j < QK4_1; j++) {
1895
+ const float v = src[j];
1896
+ if (min > v) min = v;
1897
+ if (max < v) max = v;
1898
+ }
1899
+
1900
+ const float d = (max - min) / ((1 << 4) - 1);
1901
+ const float id = d ? 1.0f/d : 0.0f;
1902
+
1903
+ dst_data[i00/QK4_1].d = d;
1904
+ dst_data[i00/QK4_1].m = min;
1905
+
1906
+ for (int j = 0; j < QK4_1/2; ++j) {
1907
+ const float x0 = (src[0 + j] - min)*id;
1908
+ const float x1 = (src[QK4_1/2 + j] - min)*id;
1909
+
1910
+ const uint8_t xi0 = MIN(15, (int8_t)(x0 + 0.5f));
1911
+ const uint8_t xi1 = MIN(15, (int8_t)(x1 + 0.5f));
1912
+
1913
+ dst_data[i00/QK4_1].qs[j] = xi0;
1914
+ dst_data[i00/QK4_1].qs[j] |= xi1 << 4;
1915
+ }
1916
+ }
1917
+ }
1918
+
1919
  kernel void kernel_concat(
1920
  device const char * src0,
1921
  device const char * src1,
 
2073
  constant int64_t & ne02[[buffer(5)]],
2074
  constant int64_t & ne10[[buffer(9)]],
2075
  constant int64_t & ne12[[buffer(11)]],
2076
+ constant int64_t & ne0 [[buffer(15)]],
2077
+ constant int64_t & ne1 [[buffer(16)]],
2078
+ constant uint & r2 [[buffer(17)]],
2079
+ constant uint & r3 [[buffer(18)]],
2080
  uint3 tgpig[[threadgroup_position_in_grid]],
2081
+ uint tiisg[[thread_index_in_simdgroup]],
2082
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
2083
 
2084
  const int nb = ne00/QK_K;
2085
  const int r0 = tgpig.x;
2086
  const int r1 = tgpig.y;
2087
+ const int im = tgpig.z;
2088
 
2089
  const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
2090
  const int ib_row = first_row * nb;
2091
+
2092
+ const uint i12 = im%ne12;
2093
+ const uint i13 = im/ne12;
2094
+
2095
+ const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
2096
+
2097
  device const block_q2_K * x = (device const block_q2_K *) src0 + ib_row + offset0;
2098
+ device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
2099
+
2100
  float yl[32];
2101
  float sumf[N_DST]={0.f}, all_sum;
2102
 
 
2105
  #if QK_K == 256
2106
  const int ix = tiisg/8; // 0...3
2107
  const int it = tiisg%8; // 0...7
2108
+ const int iq = it/4; // 0 or 1
2109
  const int ir = it%4; // 0...3
2110
  const int is = (8*ir)/16;// 0 or 1
2111
 
2112
+ device const float * y4 = y + ix * QK_K + 128 * iq + 8 * ir;
2113
 
2114
  for (int ib = ix; ib < nb; ib += 4) {
2115
 
 
2121
  yl[i+24] = y4[i+96]; sumy[3] += yl[i+24];
2122
  }
2123
 
2124
+ device const uint8_t * sc = (device const uint8_t *)x[ib].scales + 8*iq + is;
2125
+ device const uint16_t * qs = (device const uint16_t *)x[ib].qs + 16 * iq + 4 * ir;
2126
  device const half * dh = &x[ib].d;
2127
 
2128
  for (int row = 0; row < N_DST; row++) {
 
2209
  for (int row = 0; row < N_DST; ++row) {
2210
  all_sum = simd_sum(sumf[row]);
2211
  if (tiisg == 0) {
2212
+ dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum;
2213
  }
2214
  }
2215
  }
 
2224
  constant int64_t & ne02[[buffer(5)]],
2225
  constant int64_t & ne10[[buffer(9)]],
2226
  constant int64_t & ne12[[buffer(11)]],
2227
+ constant int64_t & ne0 [[buffer(15)]],
2228
+ constant int64_t & ne1 [[buffer(16)]],
2229
+ constant uint & r2 [[buffer(17)]],
2230
+ constant uint & r3 [[buffer(18)]],
2231
  uint3 tgpig[[threadgroup_position_in_grid]],
2232
  uint tiisg[[thread_index_in_simdgroup]],
2233
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
 
2236
 
2237
  const int64_t r0 = tgpig.x;
2238
  const int64_t r1 = tgpig.y;
2239
+ const int64_t im = tgpig.z;
2240
 
2241
  const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2;
2242
+
2243
+ const uint i12 = im%ne12;
2244
+ const uint i13 = im/ne12;
2245
+
2246
+ const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
2247
+
2248
  device const block_q3_K * x = (device const block_q3_K *) src0 + first_row*nb + offset0;
2249
+ device const float * yy = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
2250
 
2251
  float yl[32];
2252
 
 
2368
  }
2369
  if (tiisg == 0) {
2370
  for (int row = 0; row < 2; ++row) {
2371
+ dst[r1*ne0 + im*ne0*ne1 + first_row + row] = sumf1[row];
2372
  }
2373
  }
2374
  }
 
2382
  constant int64_t & ne02[[buffer(5)]],
2383
  constant int64_t & ne10[[buffer(9)]],
2384
  constant int64_t & ne12[[buffer(11)]],
2385
+ constant int64_t & ne0 [[buffer(15)]],
2386
+ constant int64_t & ne1 [[buffer(16)]],
2387
+ constant uint & r2 [[buffer(17)]],
2388
+ constant uint & r3 [[buffer(18)]],
2389
  uint3 tgpig[[threadgroup_position_in_grid]],
2390
+ uint tiisg[[thread_index_in_simdgroup]],
2391
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
2392
 
2393
  const int nb = ne00/QK_K;
2394
 
2395
  const int64_t r0 = tgpig.x;
2396
  const int64_t r1 = tgpig.y;
2397
+ const int64_t im = tgpig.z;
2398
 
2399
  const int row = 2 * r0 + sgitg;
2400
+
2401
+ const uint i12 = im%ne12;
2402
+ const uint i13 = im/ne12;
2403
+
2404
+ const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
2405
+
2406
  device const block_q3_K * x = (device const block_q3_K *) src0 + row*nb + offset0;
2407
+ device const float * yy = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
2408
+
2409
  const int ix = tiisg/4;
2410
  const int il = 4 * (tiisg%4);// 0, 4, 8, 12
2411
+ const int iq = il/8; // 0, 0, 1, 1
2412
  const int in = il%8; // 0, 4, 0, 4
2413
 
2414
  float2 sum = {0.f, 0.f};
 
2428
  const float d4 = d_all * ((int32_t)(s[0] & 0xF000) - 32768) * 1.f/262144.f;
2429
 
2430
  for (int l = 0; l < 4; l += 2) {
2431
+ const uint16_t hm = h[l/2] >> iq;
2432
  sum[0] += y[l+ 0] * d1 * ((int32_t)(q[l/2] & 0x0003) - ((hm & 0x0001) ? 0 : 4))
2433
  + y[l+16] * d2 * ((int32_t)(q[l/2] & 0x000c) - ((hm & 0x0004) ? 0 : 16))
2434
  + y[l+32] * d3 * ((int32_t)(q[l/2] & 0x0030) - ((hm & 0x0010) ? 0 : 64))
 
2444
 
2445
  const float tot = simd_sum(sumf);
2446
  if (tiisg == 0) {
2447
+ dst[r1*ne0 + im*ne0*ne1 + row] = tot;
2448
  }
2449
 
2450
  }
 
2462
  constant int64_t & ne12 [[buffer(11)]],
2463
  constant int64_t & ne0 [[buffer(15)]],
2464
  constant int64_t & ne1 [[buffer(16)]],
2465
+ constant uint & r2 [[buffer(17)]],
2466
+ constant uint & r3 [[buffer(18)]],
2467
  uint3 tgpig[[threadgroup_position_in_grid]],
2468
+ uint tiisg[[thread_index_in_simdgroup]],
2469
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
2470
 
2471
  const uint16_t kmask1 = 0x3f3f;
2472
  const uint16_t kmask2 = 0x0f0f;
 
2474
 
2475
  const int ix = tiisg/8; // 0...3
2476
  const int it = tiisg%8; // 0...7
2477
+ const int iq = it/4; // 0 or 1
2478
  const int ir = it%4; // 0...3
2479
 
2480
  const int nb = ne00/QK_K;
2481
  const int r0 = tgpig.x;
2482
  const int r1 = tgpig.y;
2483
+ const int im = tgpig.z;
2484
  //const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
2485
  const int first_row = r0 * N_DST;
2486
  const int ib_row = first_row * nb;
2487
+
2488
+ const uint i12 = im%ne12;
2489
+ const uint i13 = im/ne12;
2490
+
2491
+ const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
2492
+
2493
  device const block_q4_K * x = (device const block_q4_K *) src0 + ib_row + offset0;
2494
+ device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
2495
+
2496
  float yl[16];
2497
  float yh[16];
2498
  float sumf[N_DST]={0.f}, all_sum;
2499
 
2500
  const int step = sizeof(block_q4_K) * nb / 2;
2501
 
2502
+ device const float * y4 = y + ix * QK_K + 64 * iq + 8 * ir;
2503
 
2504
  uint16_t sc16[4];
2505
  thread const uint8_t * sc8 = (thread const uint8_t *)sc16;
 
2514
  yh[i+8] = y4[i+160]; sumy[3] += yh[i+8];
2515
  }
2516
 
2517
+ device const uint16_t * sc = (device const uint16_t *)x[ib].scales + iq;
2518
+ device const uint16_t * q1 = (device const uint16_t *)x[ib].qs + 16 * iq + 4 * ir;
2519
  device const half * dh = &x[ib].d;
2520
 
2521
  for (int row = 0; row < N_DST; row++) {
 
2559
  for (int row = 0; row < N_DST; ++row) {
2560
  all_sum = simd_sum(sumf[row]);
2561
  if (tiisg == 0) {
2562
+ dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum;
2563
  }
2564
  }
2565
  }
 
2573
  constant int64_t & ne02[[buffer(5)]],
2574
  constant int64_t & ne10[[buffer(9)]],
2575
  constant int64_t & ne12[[buffer(11)]],
2576
+ constant int64_t & ne0 [[buffer(15)]],
2577
+ constant int64_t & ne1 [[buffer(16)]],
2578
+ constant uint & r2 [[buffer(17)]],
2579
+ constant uint & r3 [[buffer(18)]],
2580
  uint3 tgpig[[threadgroup_position_in_grid]],
2581
  uint tiisg[[thread_index_in_simdgroup]],
2582
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
 
2587
  const int nb = ne00/QK_K;
2588
  const int r0 = tgpig.x;
2589
  const int r1 = tgpig.y;
2590
+ const int im = tgpig.z;
2591
  const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
2592
  const int ib_row = first_row * nb;
2593
+
2594
+ const uint i12 = im%ne12;
2595
+ const uint i13 = im/ne12;
2596
+
2597
+ const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
2598
+
2599
  device const block_q4_K * x = (device const block_q4_K *) src0 + ib_row + offset0;
2600
+ device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
2601
+
2602
  float yl[8];
2603
  float yh[8];
2604
  float sumf[N_DST]={0.f}, all_sum;
 
2654
  for (int row = 0; row < N_DST; ++row) {
2655
  all_sum = simd_sum(sumf[row]);
2656
  if (tiisg == 0) {
2657
+ dst[r1*ne0+ im*ne0*ne1 + first_row + row] = all_sum;
2658
  }
2659
  }
2660
  }
 
2669
  constant int64_t & ne02[[buffer(5)]],
2670
  constant int64_t & ne10[[buffer(9)]],
2671
  constant int64_t & ne12[[buffer(11)]],
2672
+ constant int64_t & ne0 [[buffer(15)]],
2673
+ constant int64_t & ne1 [[buffer(16)]],
2674
+ constant uint & r2 [[buffer(17)]],
2675
+ constant uint & r3 [[buffer(18)]],
2676
  uint3 tgpig[[threadgroup_position_in_grid]],
2677
  uint tiisg[[thread_index_in_simdgroup]],
2678
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
 
2681
 
2682
  const int64_t r0 = tgpig.x;
2683
  const int64_t r1 = tgpig.y;
2684
+ const int im = tgpig.z;
2685
 
2686
  const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2;
2687
+
2688
+ const uint i12 = im%ne12;
2689
+ const uint i13 = im/ne12;
2690
+
2691
+ const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
2692
+
2693
  device const block_q5_K * x = (device const block_q5_K *) src0 + first_row*nb + offset0;
2694
+ device const float * yy = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
2695
 
2696
  float sumf[2]={0.f};
2697
 
 
2707
 
2708
  const int tid = tiisg/4;
2709
  const int ix = tiisg%4;
2710
+ const int iq = tid/4;
2711
  const int ir = tid%4;
2712
  const int n = 8;
2713
 
2714
  const int l0 = n*ir;
2715
+ const int q_offset = 32*iq + l0;
2716
+ const int y_offset = 64*iq + l0;
2717
 
2718
+ const uint8_t hm1 = 1u << (2*iq);
2719
  const uint8_t hm2 = hm1 << 1;
2720
  const uint8_t hm3 = hm1 << 4;
2721
  const uint8_t hm4 = hm2 << 4;
 
2730
  device const uint8_t * q1 = x[i].qs + q_offset;
2731
  device const uint8_t * qh = x[i].qh + l0;
2732
  device const half * dh = &x[i].d;
2733
+ device const uint16_t * a = (device const uint16_t *)x[i].scales + iq;
2734
 
2735
  device const float * y2 = y1 + 128;
2736
  float4 sumy = {0.f, 0.f, 0.f, 0.f};
 
2786
 
2787
  const int il = 4 * (tiisg/8); // 0, 4, 8, 12
2788
  const int ix = tiisg%8;
2789
+ const int iq = il/8; // 0, 0, 1, 1
2790
  const int in = il%8; // 0, 4, 0, 4
2791
 
2792
  device const float * y = yy + ix*QK_K + il;
 
2811
 
2812
  float2 acc = {0.f, 0.f};
2813
  for (int l = 0; l < 4; ++l) {
2814
+ const uint8_t hl = h[l] >> iq;
2815
  acc[0] += yl[l+0] * s[0] * ((int16_t)(q[l+ 0] & 0x0F) - (hl & 0x01 ? 0 : 16))
2816
  + yl[l+4] * s[1] * ((int16_t)(q[l+16] & 0x0F) - (hl & 0x04 ? 0 : 16));
2817
  acc[1] += yh[l+0] * s[2] * ((int16_t)(q[l+ 0] & 0xF0) - (hl & 0x10 ? 0 : 256))
 
2833
  for (int row = 0; row < 2; ++row) {
2834
  const float tot = simd_sum(sumf[row]);
2835
  if (tiisg == 0) {
2836
+ dst[r1*ne0 + im*ne0*ne1 + first_row + row] = tot;
2837
  }
2838
  }
2839
 
 
2848
  constant int64_t & ne02[[buffer(5)]],
2849
  constant int64_t & ne10[[buffer(9)]],
2850
  constant int64_t & ne12[[buffer(11)]],
2851
+ constant int64_t & ne0 [[buffer(15)]],
2852
+ constant int64_t & ne1 [[buffer(16)]],
2853
+ constant uint & r2 [[buffer(17)]],
2854
+ constant uint & r3 [[buffer(18)]],
2855
  uint3 tgpig[[threadgroup_position_in_grid]],
2856
  uint tiisg[[thread_index_in_simdgroup]],
2857
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
 
2865
 
2866
  const int64_t r0 = tgpig.x;
2867
  const int64_t r1 = tgpig.y;
2868
+ const int im = tgpig.z;
2869
 
2870
  const int row = 2 * r0 + sgitg;
2871
+
2872
+ const uint i12 = im%ne12;
2873
+ const uint i13 = im/ne12;
2874
+
2875
+ const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
2876
+
2877
  device const block_q6_K * x = (device const block_q6_K *) src0 + row * nb + offset0;
2878
+ device const float * yy = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
2879
 
2880
  float sumf = 0;
2881
 
 
2941
 
2942
  const float tot = simd_sum(sumf);
2943
  if (tiisg == 0) {
2944
+ dst[r1*ne0 + im*ne0*ne1 + row] = tot;
2945
  }
2946
  }
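In the updated mul_mv kernels, im = tgpig.z walks the src1 batch dimensions and offset0 maps each batch onto the matching src0 slice; r2 and r3 are assumed to be the broadcast ratios ne12/ne02 and ne13/ne03 supplied by the host code. A minimal C sketch of the same index arithmetic, for reference only:

    #include <stdint.h>
    #include <stddef.h>

    // mirrors offset0 in the kernels above; r2/r3 are assumed broadcast ratios
    static size_t src0_block_offset(int im, int64_t ne01, int64_t ne02, int64_t ne12,
                                    uint32_t r2, uint32_t r3, int64_t nb /* blocks per row */) {
        const int64_t i12 = im % ne12;   // batch coordinate within src1 dim 2
        const int64_t i13 = im / ne12;   // batch coordinate within src1 dim 3
        // integer division by the broadcast ratios maps several src1 batches onto one src0 slice
        return (size_t) ((i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02));
    }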
2947
 
 
3251
 
3252
  // each block_q contains 16*nl weights
3253
  template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
3254
+ void kernel_mul_mm_impl(device const uchar * src0,
3255
+ device const uchar * src1,
3256
+ device float * dst,
3257
+ constant int64_t & ne00,
3258
+ constant int64_t & ne02,
3259
+ constant int64_t & nb01,
3260
+ constant int64_t & nb02,
3261
+ constant int64_t & ne12,
3262
+ constant int64_t & nb10,
3263
+ constant int64_t & nb11,
3264
+ constant int64_t & nb12,
3265
+ constant int64_t & ne0,
3266
+ constant int64_t & ne1,
3267
+ constant uint & r2,
3268
+ constant uint & r3,
3269
+ threadgroup uchar * shared_memory [[threadgroup(0)]],
3270
+ uint3 tgpig[[threadgroup_position_in_grid]],
3271
+ uint tiitg[[thread_index_in_threadgroup]],
3272
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
3273
 
3274
  threadgroup half * sa = (threadgroup half *)(shared_memory);
3275
  threadgroup float * sb = (threadgroup float *)(shared_memory + 4096);
 
3295
 
3296
  short il = (tiitg % THREAD_PER_ROW);
3297
 
3298
+ const uint i12 = im%ne12;
3299
+ const uint i13 = im/ne12;
3300
+
3301
+ uint offset0 = (i12/r2)*nb02 + (i13/r3)*(nb02*ne02);
3302
  ushort offset1 = il/nl;
3303
 
3304
  device const block_q * x = (device const block_q *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01 + offset0) + offset1;
 
3382
  }
3383
  }
3384
 
3385
+ template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
3386
+ kernel void kernel_mul_mm(device const uchar * src0,
3387
+ device const uchar * src1,
3388
+ device float * dst,
3389
+ constant int64_t & ne00,
3390
+ constant int64_t & ne02,
3391
+ constant int64_t & nb01,
3392
+ constant int64_t & nb02,
3393
+ constant int64_t & ne12,
3394
+ constant int64_t & nb10,
3395
+ constant int64_t & nb11,
3396
+ constant int64_t & nb12,
3397
+ constant int64_t & ne0,
3398
+ constant int64_t & ne1,
3399
+ constant uint & r2,
3400
+ constant uint & r3,
3401
+ threadgroup uchar * shared_memory [[threadgroup(0)]],
3402
+ uint3 tgpig[[threadgroup_position_in_grid]],
3403
+ uint tiitg[[thread_index_in_threadgroup]],
3404
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
3405
+ kernel_mul_mm_impl<block_q, nl, dequantize_func>(
3406
+ src0,
3407
+ src1,
3408
+ dst,
3409
+ ne00,
3410
+ ne02,
3411
+ nb01,
3412
+ nb02,
3413
+ ne12,
3414
+ nb10,
3415
+ nb11,
3416
+ nb12,
3417
+ ne0,
3418
+ ne1,
3419
+ r2,
3420
+ r3,
3421
+ shared_memory,
3422
+ tgpig,
3423
+ tiitg,
3424
+ sgitg);
3425
+ }
3426
+
3427
+ template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
3428
+ kernel void kernel_mul_mm_id(
3429
+ device const int32_t * ids,
3430
+ device const uchar * src1,
3431
+ device float * dst,
3432
+ constant int64_t & ne00,
3433
+ constant int64_t & ne02,
3434
+ constant int64_t & nb01,
3435
+ constant int64_t & nb02,
3436
+ constant int64_t & ne12,
3437
+ constant int64_t & nb10,
3438
+ constant int64_t & nb11,
3439
+ constant int64_t & nb12,
3440
+ constant int64_t & ne0,
3441
+ constant int64_t & ne1,
3442
+ constant uint & r2,
3443
+ constant uint & r3,
3444
+ constant int & idx,
3445
+ device const uchar * src00,
3446
+ device const uchar * src01,
3447
+ device const uchar * src02,
3448
+ device const uchar * src03,
3449
+ device const uchar * src04,
3450
+ device const uchar * src05,
3451
+ device const uchar * src06,
3452
+ device const uchar * src07,
3453
+ threadgroup uchar * shared_memory [[threadgroup(0)]],
3454
+ uint3 tgpig[[threadgroup_position_in_grid]],
3455
+ uint tiitg[[thread_index_in_threadgroup]],
3456
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
3457
+ device const uchar * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
3458
+
3459
+ kernel_mul_mm_impl<block_q, nl, dequantize_func>(
3460
+ src0[ids[idx]],
3461
+ src1,
3462
+ dst,
3463
+ ne00,
3464
+ ne02,
3465
+ nb01,
3466
+ nb02,
3467
+ ne12,
3468
+ nb10,
3469
+ nb11,
3470
+ nb12,
3471
+ ne0,
3472
+ ne1,
3473
+ r2,
3474
+ r3,
3475
+ shared_memory,
3476
+ tgpig,
3477
+ tiitg,
3478
+ sgitg);
3479
+ }
3480
+
3481
  #if QK_K == 256
3482
  #define QK_NL 16
3483
  #else
3484
  #define QK_NL 4
3485
  #endif
3486
 
3487
+ typedef void (get_rows_t)(
3488
+ device const void * src0,
3489
+ device const int * src1,
3490
+ device float * dst,
3491
+ constant int64_t & ne00,
3492
+ constant uint64_t & nb01,
3493
+ constant uint64_t & nb1,
3494
+ uint, uint, uint);
3495
 
3496
  template [[host_name("kernel_get_rows_f32")]] kernel get_rows_t kernel_get_rows<float4x4, 1, dequantize_f32>;
3497
  template [[host_name("kernel_get_rows_f16")]] kernel get_rows_t kernel_get_rows<half4x4, 1, dequantize_f16>;
 
3520
  constant int64_t & nb12,
3521
  constant int64_t & ne0,
3522
  constant int64_t & ne1,
3523
+ constant uint & r2,
3524
+ constant uint & r3,
3525
+ threadgroup uchar *,
3526
+ uint3, uint, uint);
3527
 
3528
  template [[host_name("kernel_mul_mm_f32_f32")]] kernel mat_mm_t kernel_mul_mm<float4x4, 1, dequantize_f32>;
3529
  template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm<half4x4, 1, dequantize_f16>;
 
3537
  template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_K, QK_NL, dequantize_q4_K>;
3538
  template [[host_name("kernel_mul_mm_q5_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q5_K, QK_NL, dequantize_q5_K>;
3539
  template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q6_K, QK_NL, dequantize_q6_K>;
3540
+
3541
+ typedef void (mat_mm_id_t)(
3542
+ device const int32_t * ids,
3543
+ device const uchar * src1,
3544
+ device float * dst,
3545
+ constant int64_t & ne00,
3546
+ constant int64_t & ne02,
3547
+ constant int64_t & nb01,
3548
+ constant int64_t & nb02,
3549
+ constant int64_t & ne12,
3550
+ constant int64_t & nb10,
3551
+ constant int64_t & nb11,
3552
+ constant int64_t & nb12,
3553
+ constant int64_t & ne0,
3554
+ constant int64_t & ne1,
3555
+ constant uint & r2,
3556
+ constant uint & r3,
3557
+ constant int & idx,
3558
+ device const uchar * src00,
3559
+ device const uchar * src01,
3560
+ device const uchar * src02,
3561
+ device const uchar * src03,
3562
+ device const uchar * src04,
3563
+ device const uchar * src05,
3564
+ device const uchar * src06,
3565
+ device const uchar * src07,
3566
+ threadgroup uchar *,
3567
+ uint3, uint, uint);
3568
+
3569
+ template [[host_name("kernel_mul_mm_id_f32_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<float4x4, 1, dequantize_f32>;
3570
+ template [[host_name("kernel_mul_mm_id_f16_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<half4x4, 1, dequantize_f16>;
3571
+ template [[host_name("kernel_mul_mm_id_q4_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q4_0, 2, dequantize_q4_0>;
3572
+ template [[host_name("kernel_mul_mm_id_q4_1_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q4_1, 2, dequantize_q4_1>;
3573
+ template [[host_name("kernel_mul_mm_id_q5_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q5_0, 2, dequantize_q5_0>;
3574
+ template [[host_name("kernel_mul_mm_id_q5_1_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q5_1, 2, dequantize_q5_1>;
3575
+ template [[host_name("kernel_mul_mm_id_q8_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q8_0, 2, dequantize_q8_0>;
3576
+ template [[host_name("kernel_mul_mm_id_q2_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q2_K, QK_NL, dequantize_q2_K>;
3577
+ template [[host_name("kernel_mul_mm_id_q3_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q3_K, QK_NL, dequantize_q3_K>;
3578
+ template [[host_name("kernel_mul_mm_id_q4_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q4_K, QK_NL, dequantize_q4_K>;
3579
+ template [[host_name("kernel_mul_mm_id_q5_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q5_K, QK_NL, dequantize_q5_K>;
3580
+ template [[host_name("kernel_mul_mm_id_q6_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q6_K, QK_NL, dequantize_q6_K>;
ggml-opencl.cpp CHANGED
@@ -1,20 +1,18 @@
 
1
  #include "ggml-opencl.h"
2
 
3
  #include <array>
4
  #include <atomic>
 
 
 
 
5
  #include <sstream>
6
  #include <vector>
7
- #include <limits>
8
 
9
  #define CL_TARGET_OPENCL_VERSION 110
10
  #include <clblast.h>
11
 
12
- #include <stdlib.h>
13
- #include <stdio.h>
14
- #include <string.h>
15
-
16
- #include "ggml.h"
17
-
18
  #if defined(_MSC_VER)
19
  #pragma warning(disable: 4244 4267) // possible loss of data
20
  #endif
 
1
+ #include "ggml.h"
2
  #include "ggml-opencl.h"
3
 
4
  #include <array>
5
  #include <atomic>
6
+ #include <cstdio>
7
+ #include <cstdlib>
8
+ #include <cstring>
9
+ #include <limits>
10
  #include <sstream>
11
  #include <vector>
 
12
 
13
  #define CL_TARGET_OPENCL_VERSION 110
14
  #include <clblast.h>
15
 
 
 
 
 
 
 
16
  #if defined(_MSC_VER)
17
  #pragma warning(disable: 4244 4267) // possible loss of data
18
  #endif
ggml-quants.c CHANGED
@@ -19,7 +19,7 @@
19
  #ifdef __wasm_simd128__
20
  #include <wasm_simd128.h>
21
  #else
22
- #ifdef __POWER9_VECTOR__
23
  #include <altivec.h>
24
  #undef bool
25
  #define bool _Bool
 
19
  #ifdef __wasm_simd128__
20
  #include <wasm_simd128.h>
21
  #else
22
+ #if defined(__POWER9_VECTOR__) || defined(__powerpc64__)
23
  #include <altivec.h>
24
  #undef bool
25
  #define bool _Bool
ggml.c CHANGED
@@ -233,24 +233,6 @@ inline static void * ggml_aligned_malloc(size_t size) {
233
  #define UNUSED GGML_UNUSED
234
  #define SWAP(x, y, T) do { T SWAP = x; x = y; y = SWAP; } while (0)
235
 
236
- //
237
- // tensor access macros
238
- //
239
-
240
- #define GGML_TENSOR_UNARY_OP_LOCALS \
241
- GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) \
242
- GGML_TENSOR_LOCALS(size_t, nb0, src0, nb) \
243
- GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) \
244
- GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
245
-
246
- #define GGML_TENSOR_BINARY_OP_LOCALS \
247
- GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) \
248
- GGML_TENSOR_LOCALS(size_t, nb0, src0, nb) \
249
- GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne) \
250
- GGML_TENSOR_LOCALS(size_t, nb1, src1, nb) \
251
- GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) \
252
- GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
253
-
254
  #if defined(GGML_USE_ACCELERATE)
255
  #include <Accelerate/Accelerate.h>
256
  #if defined(GGML_USE_CLBLAST) // allow usage of CLBlast alongside Accelerate functions
@@ -1613,6 +1595,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
1613
  "GROUP_NORM",
1614
 
1615
  "MUL_MAT",
 
1616
  "OUT_PROD",
1617
 
1618
  "SCALE",
@@ -1640,6 +1623,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
1640
  "POOL_1D",
1641
  "POOL_2D",
1642
  "UPSCALE",
 
1643
 
1644
  "FLASH_ATTN",
1645
  "FLASH_FF",
@@ -1666,7 +1650,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
1666
  "CROSS_ENTROPY_LOSS_BACK",
1667
  };
1668
 
1669
- static_assert(GGML_OP_COUNT == 68, "GGML_OP_COUNT != 68");
1670
 
1671
  static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
1672
  "none",
@@ -1695,6 +1679,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
1695
  "group_norm(x)",
1696
 
1697
  "X*Y",
 
1698
  "X*Y",
1699
 
1700
  "x*v",
@@ -1722,6 +1707,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
1722
  "pool_1d(x)",
1723
  "pool_2d(x)",
1724
  "upscale(x)",
 
1725
 
1726
  "flash_attn(x)",
1727
  "flash_ff(x)",
@@ -1748,10 +1734,28 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
1748
  "cross_entropy_loss_back(x,y)",
1749
  };
1750
 
1751
- static_assert(GGML_OP_COUNT == 68, "GGML_OP_COUNT != 68");
1752
 
1753
  static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
1754
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1755
  static_assert(sizeof(struct ggml_object)%GGML_MEM_ALIGN == 0, "ggml_object size must be a multiple of GGML_MEM_ALIGN");
1756
  static_assert(sizeof(struct ggml_tensor)%GGML_MEM_ALIGN == 0, "ggml_tensor size must be a multiple of GGML_MEM_ALIGN");
1757
 
@@ -1771,6 +1775,7 @@ static void ggml_setup_op_has_task_pass(void) {
1771
 
1772
  p[GGML_OP_ACC ] = true;
1773
  p[GGML_OP_MUL_MAT ] = true;
 
1774
  p[GGML_OP_OUT_PROD ] = true;
1775
  p[GGML_OP_SET ] = true;
1776
  p[GGML_OP_GET_ROWS_BACK ] = true;
@@ -2023,6 +2028,20 @@ const char * ggml_op_symbol(enum ggml_op op) {
2023
  return GGML_OP_SYMBOL[op];
2024
  }
2025
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2026
  size_t ggml_element_size(const struct ggml_tensor * tensor) {
2027
  return ggml_type_size(tensor->type);
2028
  }
@@ -3154,9 +3173,7 @@ static struct ggml_tensor * ggml_add_impl(
3154
  struct ggml_tensor * a,
3155
  struct ggml_tensor * b,
3156
  bool inplace) {
3157
- // TODO: support less-strict constraint
3158
- // GGML_ASSERT(ggml_can_repeat(b, a));
3159
- GGML_ASSERT(ggml_can_repeat_rows(b, a));
3160
 
3161
  bool is_node = false;
3162
 
@@ -3371,9 +3388,7 @@ static struct ggml_tensor * ggml_mul_impl(
3371
  struct ggml_tensor * a,
3372
  struct ggml_tensor * b,
3373
  bool inplace) {
3374
- // TODO: support less-strict constraint
3375
- // GGML_ASSERT(ggml_can_repeat(b, a));
3376
- GGML_ASSERT(ggml_can_repeat_rows(b, a));
3377
 
3378
  bool is_node = false;
3379
 
@@ -3418,7 +3433,7 @@ static struct ggml_tensor * ggml_div_impl(
3418
  struct ggml_tensor * a,
3419
  struct ggml_tensor * b,
3420
  bool inplace) {
3421
- GGML_ASSERT(ggml_are_same_shape(a, b));
3422
 
3423
  bool is_node = false;
3424
 
@@ -4056,6 +4071,49 @@ struct ggml_tensor * ggml_mul_mat(
4056
  return result;
4057
  }
4058
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4059
  // ggml_out_prod
4060
 
4061
  struct ggml_tensor * ggml_out_prod(
@@ -4209,7 +4267,7 @@ struct ggml_tensor * ggml_set_2d_inplace(
4209
  struct ggml_tensor * b,
4210
  size_t nb1,
4211
  size_t offset) {
4212
- return ggml_set_impl(ctx, a, b, nb1, a->nb[2], a->nb[3], offset, false);
4213
  }
4214
 
4215
  // ggml_cpy
@@ -4826,7 +4884,17 @@ struct ggml_tensor * ggml_diag_mask_zero_inplace(
4826
  static struct ggml_tensor * ggml_soft_max_impl(
4827
  struct ggml_context * ctx,
4828
  struct ggml_tensor * a,
 
 
4829
  bool inplace) {
 
 
 
 
 
 
 
 
4830
  bool is_node = false;
4831
 
4832
  if (a->grad) {
@@ -4835,9 +4903,13 @@ static struct ggml_tensor * ggml_soft_max_impl(
4835
 
4836
  struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
4837
 
 
 
 
4838
  result->op = GGML_OP_SOFT_MAX;
4839
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
4840
  result->src[0] = a;
 
4841
 
4842
  return result;
4843
  }
@@ -4845,13 +4917,21 @@ static struct ggml_tensor * ggml_soft_max_impl(
4845
  struct ggml_tensor * ggml_soft_max(
4846
  struct ggml_context * ctx,
4847
  struct ggml_tensor * a) {
4848
- return ggml_soft_max_impl(ctx, a, false);
4849
  }
4850
 
4851
  struct ggml_tensor * ggml_soft_max_inplace(
4852
  struct ggml_context * ctx,
4853
  struct ggml_tensor * a) {
4854
- return ggml_soft_max_impl(ctx, a, true);
 
 
 
 
 
 
 
 
4855
  }
4856
 
4857
  // ggml_soft_max_back
@@ -5446,6 +5526,43 @@ struct ggml_tensor * ggml_upscale(
5446
  return ggml_upscale_impl(ctx, a, scale_factor);
5447
  }
5448
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5449
  // ggml_flash_attn
5450
 
5451
  struct ggml_tensor * ggml_flash_attn(
@@ -6805,7 +6922,7 @@ static void ggml_compute_forward_add_f32(
6805
  const struct ggml_tensor * src0,
6806
  const struct ggml_tensor * src1,
6807
  struct ggml_tensor * dst) {
6808
- GGML_ASSERT(ggml_can_repeat_rows(src1, src0) && ggml_are_same_shape(src0, dst));
6809
 
6810
  if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
6811
  return;
@@ -6838,16 +6955,19 @@ static void ggml_compute_forward_add_f32(
6838
  const int64_t i13 = i03 % ne13;
6839
  const int64_t i12 = i02 % ne12;
6840
  const int64_t i11 = i01 % ne11;
 
6841
 
6842
  float * dst_ptr = (float *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 );
6843
  float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
6844
  float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11);
6845
 
 
6846
  #ifdef GGML_USE_ACCELERATE
6847
- vDSP_vadd(src0_ptr, 1, src1_ptr, 1, dst_ptr, 1, ne00);
6848
  #else
6849
- ggml_vec_add_f32(ne00, dst_ptr, src0_ptr, src1_ptr);
6850
  #endif
 
6851
  }
6852
  } else {
6853
  // src1 is not contiguous
@@ -6864,8 +6984,9 @@ static void ggml_compute_forward_add_f32(
6864
  float * dst_ptr = (float *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 );
6865
  float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
6866
 
6867
- for (int i0 = 0; i0 < ne0; i0++) {
6868
- float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + i0*nb10);
 
6869
 
6870
  dst_ptr[i0] = src0_ptr[i0] + *src1_ptr;
6871
  }
@@ -7585,7 +7706,7 @@ static void ggml_compute_forward_mul_f32(
7585
  const struct ggml_tensor * src0,
7586
  const struct ggml_tensor * src1,
7587
  struct ggml_tensor * dst) {
7588
- GGML_ASSERT(ggml_can_repeat_rows(src1, src0) && ggml_are_same_shape(src0, dst));
7589
 
7590
  if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
7591
  return;
@@ -7608,7 +7729,6 @@ static void ggml_compute_forward_mul_f32(
7608
 
7609
  GGML_ASSERT( nb0 == sizeof(float));
7610
  GGML_ASSERT(nb00 == sizeof(float));
7611
- GGML_ASSERT(ne00 == ne10);
7612
 
7613
  if (nb10 == sizeof(float)) {
7614
  for (int64_t ir = ith; ir < nr; ir += nth) {
@@ -7620,20 +7740,21 @@ static void ggml_compute_forward_mul_f32(
7620
  const int64_t i13 = i03 % ne13;
7621
  const int64_t i12 = i02 % ne12;
7622
  const int64_t i11 = i01 % ne11;
 
7623
 
7624
  float * dst_ptr = (float *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 );
7625
  float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
7626
  float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11);
7627
 
 
7628
  #ifdef GGML_USE_ACCELERATE
7629
- UNUSED(ggml_vec_mul_f32);
7630
 
7631
- vDSP_vmul( src0_ptr, 1, src1_ptr, 1, dst_ptr, 1, ne00);
7632
  #else
7633
- ggml_vec_mul_f32(ne00, dst_ptr, src0_ptr, src1_ptr);
7634
  #endif
7635
- // }
7636
- // }
7637
  }
7638
  } else {
7639
  // src1 is not contiguous
@@ -7651,8 +7772,9 @@ static void ggml_compute_forward_mul_f32(
7651
  float * dst_ptr = (float *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 );
7652
  float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
7653
 
7654
- for (int64_t i0 = 0; i0 < ne00; i0++) {
7655
- float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + i0*nb10);
 
7656
 
7657
  dst_ptr[i0] = src0_ptr[i0] * (*src1_ptr);
7658
  }
@@ -7686,14 +7808,16 @@ static void ggml_compute_forward_div_f32(
7686
  const struct ggml_tensor * src0,
7687
  const struct ggml_tensor * src1,
7688
  struct ggml_tensor * dst) {
7689
- assert(params->ith == 0);
7690
- assert(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
7691
 
7692
  if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
7693
  return;
7694
  }
7695
 
7696
- const int nr = ggml_nrows(src0);
 
 
 
7697
 
7698
  GGML_TENSOR_BINARY_OP_LOCALS
7699
 
@@ -7701,41 +7825,50 @@ static void ggml_compute_forward_div_f32(
7701
  GGML_ASSERT(nb00 == sizeof(float));
7702
 
7703
  if (nb10 == sizeof(float)) {
7704
- for (int ir = 0; ir < nr; ++ir) {
7705
- // src0, src1 and dst are same shape => same indices
7706
- const int i3 = ir/(ne2*ne1);
7707
- const int i2 = (ir - i3*ne2*ne1)/ne1;
7708
- const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
 
 
 
 
 
7709
 
 
 
 
 
 
7710
  #ifdef GGML_USE_ACCELERATE
7711
- UNUSED(ggml_vec_div_f32);
7712
 
7713
- vDSP_vdiv(
7714
- (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11), 1,
7715
- (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01), 1,
7716
- (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 ), 1,
7717
- ne0);
7718
  #else
7719
- ggml_vec_div_f32(ne0,
7720
- (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 ),
7721
- (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01),
7722
- (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11));
7723
  #endif
7724
- // }
7725
- // }
7726
  }
7727
  } else {
7728
  // src1 is not contiguous
7729
- for (int ir = 0; ir < nr; ++ir) {
7730
- // src0, src1 and dst are same shape => same indices
7731
- const int i3 = ir/(ne2*ne1);
7732
- const int i2 = (ir - i3*ne2*ne1)/ne1;
7733
- const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
 
7734
 
7735
- float * dst_ptr = (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 );
7736
- float * src0_ptr = (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
7737
- for (int i0 = 0; i0 < ne0; i0++) {
7738
- float * src1_ptr = (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11 + i0*nb10);
 
 
 
 
 
 
7739
 
7740
  dst_ptr[i0] = src0_ptr[i0] / (*src1_ptr);
7741
  }
@@ -8181,7 +8314,7 @@ static void ggml_compute_forward_repeat_f16(
8181
  return;
8182
  }
8183
 
8184
- GGML_TENSOR_UNARY_OP_LOCALS;
8185
 
8186
  // guaranteed to be an integer due to the check in ggml_can_repeat
8187
  const int nr0 = (int)(ne0/ne00);
@@ -8326,6 +8459,7 @@ static void ggml_compute_forward_concat_f32(
8326
  GGML_ASSERT(src0->nb[0] == sizeof(float));
8327
 
8328
  const int ith = params->ith;
 
8329
 
8330
  GGML_TENSOR_BINARY_OP_LOCALS
8331
 
@@ -8335,7 +8469,7 @@ static void ggml_compute_forward_concat_f32(
8335
  GGML_ASSERT(nb10 == sizeof(float));
8336
 
8337
  for (int i3 = 0; i3 < ne3; i3++) {
8338
- for (int i2 = ith; i2 < ne2; i2++) {
8339
  if (i2 < ne02) { // src0
8340
  for (int i1 = 0; i1 < ne1; i1++) {
8341
  for (int i0 = 0; i0 < ne0; i0++) {
@@ -9495,6 +9629,8 @@ static void ggml_compute_forward_mul_mat(
9495
  char * wdata = params->wdata;
9496
  const size_t row_size = ne10*ggml_type_size(vec_dot_type)/ggml_blck_size(vec_dot_type);
9497
 
 
 
9498
  for (int64_t i13 = 0; i13 < ne13; ++i13) {
9499
  for (int64_t i12 = 0; i12 < ne12; ++i12) {
9500
  for (int64_t i11 = 0; i11 < ne11; ++i11) {
@@ -9596,6 +9732,26 @@ static void ggml_compute_forward_mul_mat(
9596
  }
9597
  }
9598
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9599
  // ggml_compute_forward_out_prod
9600
 
9601
  static void ggml_compute_forward_out_prod_f32(
@@ -9611,10 +9767,12 @@ static void ggml_compute_forward_out_prod_f32(
9611
  const int ith = params->ith;
9612
  const int nth = params->nth;
9613
 
 
 
 
9614
  GGML_ASSERT(ne02 == ne12);
9615
- GGML_ASSERT(ne03 == ne13);
9616
- GGML_ASSERT(ne2 == ne12);
9617
  GGML_ASSERT(ne3 == ne13);
 
9618
 
9619
  // we don't support permuted src0 or src1
9620
  GGML_ASSERT(nb00 == sizeof(float));
@@ -9625,18 +9783,25 @@ static void ggml_compute_forward_out_prod_f32(
9625
  // GGML_ASSERT(nb1 <= nb2);
9626
  // GGML_ASSERT(nb2 <= nb3);
9627
 
9628
- GGML_ASSERT(ne0 == ne00);
9629
- GGML_ASSERT(ne1 == ne10);
9630
- GGML_ASSERT(ne2 == ne02);
9631
- GGML_ASSERT(ne3 == ne03);
9632
-
9633
  // nb01 >= nb00 - src0 is not transposed
9634
  // compute by src0 rows
9635
 
9636
  // TODO: #if defined(GGML_USE_CUBLAS) ggml_cuda_out_prod
9637
- // TODO: #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CLBLAST)
 
 
 
 
 
 
 
9638
 
9639
  if (params->type == GGML_TASK_INIT) {
 
 
 
 
 
9640
  ggml_vec_set_f32(ne0*ne1*ne2*ne3, dst->data, 0);
9641
  return;
9642
  }
@@ -9645,6 +9810,50 @@ static void ggml_compute_forward_out_prod_f32(
9645
  return;
9646
  }
9647
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9648
  // dst[:,:,:,:] = 0
9649
  // for i2,i3:
9650
  // for i1:
@@ -10498,20 +10707,25 @@ static void ggml_compute_forward_diag_mask_zero(
10498
  static void ggml_compute_forward_soft_max_f32(
10499
  const struct ggml_compute_params * params,
10500
  const struct ggml_tensor * src0,
10501
- struct ggml_tensor * dst) {
10502
- GGML_ASSERT(ggml_is_contiguous(src0));
10503
- GGML_ASSERT(ggml_is_contiguous(dst));
10504
- GGML_ASSERT(ggml_are_same_shape(src0, dst));
10505
 
10506
  if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
10507
  return;
10508
  }
10509
 
 
 
 
10510
  // TODO: handle transposed/permuted matrices
10511
 
10512
  const int ith = params->ith;
10513
  const int nth = params->nth;
10514
 
 
 
10515
  const int nc = src0->ne[0];
10516
  const int nr = ggml_nrows(src0);
10517
 
@@ -10522,29 +10736,40 @@ static void ggml_compute_forward_soft_max_f32(
10522
  const int ir0 = dr*ith;
10523
  const int ir1 = MIN(ir0 + dr, nr);
10524
 
 
 
10525
  for (int i1 = ir0; i1 < ir1; i1++) {
10526
- float *sp = (float *)((char *) src0->data + i1*src0->nb[1]);
10527
- float *dp = (float *)((char *) dst->data + i1*dst->nb[1]);
 
 
 
 
 
 
 
 
 
10528
 
10529
  #ifndef NDEBUG
10530
  for (int i = 0; i < nc; ++i) {
10531
  //printf("p[%d] = %f\n", i, p[i]);
10532
- assert(!isnan(sp[i]));
10533
  }
10534
  #endif
10535
 
10536
  float max = -INFINITY;
10537
- ggml_vec_max_f32(nc, &max, sp);
10538
 
10539
  ggml_float sum = 0.0;
10540
 
10541
  uint16_t scvt;
10542
  for (int i = 0; i < nc; i++) {
10543
- if (sp[i] == -INFINITY) {
10544
  dp[i] = 0.0f;
10545
  } else {
10546
- // const float val = (sp[i] == -INFINITY) ? 0.0 : exp(sp[i] - max);
10547
- ggml_fp16_t s = GGML_FP32_TO_FP16(sp[i] - max);
10548
  memcpy(&scvt, &s, sizeof(scvt));
10549
  const float val = GGML_FP16_TO_FP32(ggml_table_exp_f16[scvt]);
10550
  sum += (ggml_float)val;
@@ -10569,11 +10794,12 @@ static void ggml_compute_forward_soft_max_f32(
10569
  static void ggml_compute_forward_soft_max(
10570
  const struct ggml_compute_params * params,
10571
  const struct ggml_tensor * src0,
10572
- struct ggml_tensor * dst) {
 
10573
  switch (src0->type) {
10574
  case GGML_TYPE_F32:
10575
  {
10576
- ggml_compute_forward_soft_max_f32(params, src0, dst);
10577
  } break;
10578
  default:
10579
  {
@@ -11929,6 +12155,67 @@ static void ggml_compute_forward_upscale(
11929
  }
11930
  }
11931
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11932
  // ggml_compute_forward_flash_attn
11933
 
11934
  static void ggml_compute_forward_flash_attn_f32(
@@ -13752,6 +14039,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
13752
  {
13753
  ggml_compute_forward_mul_mat(params, tensor->src[0], tensor->src[1], tensor);
13754
  } break;
 
 
 
 
13755
  case GGML_OP_OUT_PROD:
13756
  {
13757
  ggml_compute_forward_out_prod(params, tensor->src[0], tensor->src[1], tensor);
@@ -13810,7 +14101,7 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
13810
  } break;
13811
  case GGML_OP_SOFT_MAX:
13812
  {
13813
- ggml_compute_forward_soft_max(params, tensor->src[0], tensor);
13814
  } break;
13815
  case GGML_OP_SOFT_MAX_BACK:
13816
  {
@@ -13856,6 +14147,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
13856
  {
13857
  ggml_compute_forward_upscale(params, tensor->src[0], tensor);
13858
  } break;
 
 
 
 
13859
  case GGML_OP_FLASH_ATTN:
13860
  {
13861
  const int32_t t = ggml_get_op_params_i32(tensor, 0);
@@ -14506,6 +14801,10 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
14506
  zero_table);
14507
  }
14508
  } break;
 
 
 
 
14509
  case GGML_OP_OUT_PROD:
14510
  {
14511
  GGML_ASSERT(false); // TODO: not implemented
@@ -14844,6 +15143,10 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
14844
  {
14845
  GGML_ASSERT(false); // TODO: not implemented
14846
  } break;
 
 
 
 
14847
  case GGML_OP_FLASH_ATTN:
14848
  {
14849
  struct ggml_tensor * flash_grad = NULL;
@@ -15204,12 +15507,8 @@ struct ggml_cgraph * ggml_new_graph(struct ggml_context * ctx) {
15204
  return ggml_new_graph_custom(ctx, GGML_DEFAULT_GRAPH_SIZE, false);
15205
  }
15206
 
15207
- struct ggml_cgraph * ggml_graph_view(struct ggml_context * ctx, struct ggml_cgraph * cgraph0, int i0, int i1) {
15208
- const size_t obj_size = sizeof(struct ggml_cgraph);
15209
- struct ggml_object * obj = ggml_new_object(ctx, GGML_OBJECT_GRAPH, obj_size);
15210
- struct ggml_cgraph * cgraph = (struct ggml_cgraph *) ((char *) ctx->mem_buffer + obj->offs);
15211
-
15212
- *cgraph = (struct ggml_cgraph) {
15213
  /*.size =*/ 0,
15214
  /*.n_nodes =*/ i1 - i0,
15215
  /*.n_leafs =*/ 0,
@@ -15444,7 +15743,6 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
15444
  n_tasks = n_threads;
15445
  } break;
15446
  case GGML_OP_SUB:
15447
- case GGML_OP_DIV:
15448
  case GGML_OP_SQR:
15449
  case GGML_OP_SQRT:
15450
  case GGML_OP_LOG:
@@ -15477,10 +15775,13 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
15477
  {
15478
  n_tasks = n_threads;
15479
  } break;
 
 
15480
  }
15481
  break;
15482
  case GGML_OP_SILU_BACK:
15483
  case GGML_OP_MUL:
 
15484
  case GGML_OP_NORM:
15485
  case GGML_OP_RMS_NORM:
15486
  case GGML_OP_RMS_NORM_BACK:
@@ -15518,6 +15819,11 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
15518
  }
15519
  #endif
15520
  } break;
 
 
 
 
 
15521
  case GGML_OP_OUT_PROD:
15522
  {
15523
  n_tasks = n_threads;
@@ -15537,7 +15843,6 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
15537
  } break;
15538
  case GGML_OP_DIAG_MASK_ZERO:
15539
  case GGML_OP_DIAG_MASK_INF:
15540
- case GGML_OP_SOFT_MAX:
15541
  case GGML_OP_SOFT_MAX_BACK:
15542
  case GGML_OP_ROPE:
15543
  case GGML_OP_ROPE_BACK:
@@ -15553,6 +15858,10 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
15553
  {
15554
  n_tasks = 1; //TODO
15555
  } break;
 
 
 
 
15556
  case GGML_OP_CONV_TRANSPOSE_1D:
15557
  {
15558
  n_tasks = n_threads;
@@ -15574,6 +15883,10 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
15574
  {
15575
  n_tasks = n_threads;
15576
  } break;
 
 
 
 
15577
  case GGML_OP_FLASH_ATTN:
15578
  {
15579
  n_tasks = n_threads;
@@ -15642,7 +15955,12 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
15642
  } break;
15643
  default:
15644
  {
15645
- printf("%s: op %s not implemented\n", __func__, ggml_op_name(node->op));
 
 
 
 
 
15646
  GGML_ASSERT(false);
15647
  } break;
15648
  }
@@ -15783,18 +16101,16 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
15783
 
15784
  // thread scheduling for the different operations + work buffer size estimation
15785
  for (int i = 0; i < cgraph->n_nodes; i++) {
15786
- int n_tasks = 1;
15787
-
15788
  struct ggml_tensor * node = cgraph->nodes[i];
15789
 
 
 
15790
  size_t cur = 0;
15791
 
15792
  switch (node->op) {
15793
  case GGML_OP_CPY:
15794
  case GGML_OP_DUP:
15795
  {
15796
- n_tasks = n_threads;
15797
-
15798
  if (ggml_is_quantized(node->type)) {
15799
  cur = ggml_type_size(GGML_TYPE_F32) * node->ne[0] * n_tasks;
15800
  }
@@ -15802,16 +16118,12 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
15802
  case GGML_OP_ADD:
15803
  case GGML_OP_ADD1:
15804
  {
15805
- n_tasks = n_threads;
15806
-
15807
  if (ggml_is_quantized(node->src[0]->type)) {
15808
  cur = ggml_type_size(GGML_TYPE_F32) * node->src[0]->ne[0] * n_tasks;
15809
  }
15810
  } break;
15811
  case GGML_OP_ACC:
15812
  {
15813
- n_tasks = n_threads;
15814
-
15815
  if (ggml_is_quantized(node->src[0]->type)) {
15816
  cur = ggml_type_size(GGML_TYPE_F32) * node->src[1]->ne[0] * n_tasks;
15817
  }
@@ -15837,14 +16149,33 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
15837
  cur = ggml_type_size(vec_dot_type)*ggml_nelements(node->src[1])/ggml_blck_size(vec_dot_type);
15838
  }
15839
  } break;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15840
  case GGML_OP_OUT_PROD:
15841
  {
15842
- n_tasks = n_threads;
15843
-
15844
  if (ggml_is_quantized(node->src[0]->type)) {
15845
  cur = ggml_type_size(GGML_TYPE_F32) * node->src[0]->ne[0] * n_tasks;
15846
  }
15847
  } break;
 
 
 
 
15848
  case GGML_OP_CONV_TRANSPOSE_1D:
15849
  {
15850
  GGML_ASSERT(node->src[0]->ne[3] == 1);
@@ -15870,10 +16201,6 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
15870
  GGML_ASSERT(false);
15871
  }
15872
  } break;
15873
- case GGML_OP_IM2COL:
15874
- {
15875
- n_tasks = n_threads;
15876
- } break;
15877
  case GGML_OP_CONV_TRANSPOSE_2D:
15878
  {
15879
  const int64_t ne00 = node->src[0]->ne[0]; // W
@@ -15890,8 +16217,6 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
15890
  } break;
15891
  case GGML_OP_FLASH_ATTN:
15892
  {
15893
- n_tasks = n_threads;
15894
-
15895
  const int64_t ne11 = ggml_up(node->src[1]->ne[1], GGML_SOFT_MAX_UNROLL);
15896
 
15897
  if (node->src[1]->type == GGML_TYPE_F32) {
@@ -15904,8 +16229,6 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
15904
  } break;
15905
  case GGML_OP_FLASH_FF:
15906
  {
15907
- n_tasks = n_threads;
15908
-
15909
  if (node->src[1]->type == GGML_TYPE_F32) {
15910
  cur = sizeof(float)*node->src[1]->ne[1]*n_tasks; // TODO: this can become (n_tasks-1)
15911
  cur += sizeof(float)*node->src[1]->ne[1]*n_tasks; // this is overestimated by x2
@@ -15916,8 +16239,6 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
15916
  } break;
15917
  case GGML_OP_FLASH_ATTN_BACK:
15918
  {
15919
- n_tasks = n_threads;
15920
-
15921
  const int64_t D = node->src[0]->ne[0];
15922
  const int64_t ne11 = ggml_up(node->src[1]->ne[1], GGML_SOFT_MAX_UNROLL);
15923
  const int64_t mxDn = MAX(D, ne11) * 2; // *2 because of S and SM in ggml_compute_forward_flash_attn_back
@@ -15932,8 +16253,6 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
15932
 
15933
  case GGML_OP_CROSS_ENTROPY_LOSS:
15934
  {
15935
- n_tasks = n_threads;
15936
-
15937
  cur = ggml_type_size(node->type)*(n_tasks + node->src[0]->ne[0]*n_tasks);
15938
  } break;
15939
  case GGML_OP_COUNT:
@@ -17720,8 +18039,8 @@ size_t ggml_quantize_q5_0(const float * src, void * dst, int n, int k, int64_t *
17720
  memcpy(&qh, &y[i].qh, sizeof(qh));
17721
 
17722
  for (int j = 0; j < QK5_0; j += 2) {
17723
- const uint8_t vh0 = ((qh & (1u << (j + 0 ))) >> (j + 0 )) << 4;
17724
- const uint8_t vh1 = ((qh & (1u << (j + 16))) >> (j + 12));
17725
 
17726
  // cast to 16 bins
17727
  const uint8_t vi0 = ((y[i].qs[j/2] & 0x0F) | vh0) / 2;
@@ -17750,8 +18069,8 @@ size_t ggml_quantize_q5_1(const float * src, void * dst, int n, int k, int64_t *
17750
  memcpy(&qh, &y[i].qh, sizeof(qh));
17751
 
17752
  for (int j = 0; j < QK5_1; j += 2) {
17753
- const uint8_t vh0 = ((qh & (1u << (j + 0 ))) >> (j + 0 )) << 4;
17754
- const uint8_t vh1 = ((qh & (1u << (j + 16))) >> (j + 12));
17755
 
17756
  // cast to 16 bins
17757
  const uint8_t vi0 = ((y[i].qs[j/2] & 0x0F) | vh0) / 2;
@@ -17941,6 +18260,7 @@ struct gguf_kv {
17941
 
17942
  struct gguf_header {
17943
  char magic[4];
 
17944
  uint32_t version;
17945
  uint64_t n_tensors; // GGUFv2
17946
  uint64_t n_kv; // GGUFv2
@@ -18030,7 +18350,7 @@ struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_p
18030
 
18031
  for (uint32_t i = 0; i < sizeof(magic); i++) {
18032
  if (magic[i] != GGUF_MAGIC[i]) {
18033
- fprintf(stderr, "%s: invalid magic characters %s.\n", __func__, magic);
18034
  fclose(file);
18035
  return NULL;
18036
  }
@@ -18045,7 +18365,6 @@ struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_p
18045
  {
18046
  strncpy(ctx->header.magic, magic, 4);
18047
 
18048
-
18049
  ctx->kv = NULL;
18050
  ctx->infos = NULL;
18051
  ctx->data = NULL;
@@ -18399,24 +18718,29 @@ int gguf_find_key(const struct gguf_context * ctx, const char * key) {
18399
  }
18400
 
18401
  const char * gguf_get_key(const struct gguf_context * ctx, int key_id) {
 
18402
  return ctx->kv[key_id].key.data;
18403
  }
18404
 
18405
  enum gguf_type gguf_get_kv_type(const struct gguf_context * ctx, int key_id) {
 
18406
  return ctx->kv[key_id].type;
18407
  }
18408
 
18409
  enum gguf_type gguf_get_arr_type(const struct gguf_context * ctx, int key_id) {
 
18410
  GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_ARRAY);
18411
  return ctx->kv[key_id].value.arr.type;
18412
  }
18413
 
18414
  const void * gguf_get_arr_data(const struct gguf_context * ctx, int key_id) {
 
18415
  GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_ARRAY);
18416
  return ctx->kv[key_id].value.arr.data;
18417
  }
18418
 
18419
  const char * gguf_get_arr_str(const struct gguf_context * ctx, int key_id, int i) {
 
18420
  GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_ARRAY);
18421
  struct gguf_kv * kv = &ctx->kv[key_id];
18422
  struct gguf_str * str = &((struct gguf_str *) kv->value.arr.data)[i];
@@ -18424,70 +18748,90 @@ const char * gguf_get_arr_str(const struct gguf_context * ctx, int key_id, int i
18424
  }
18425
 
18426
  int gguf_get_arr_n(const struct gguf_context * ctx, int key_id) {
 
18427
  GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_ARRAY);
18428
  return ctx->kv[key_id].value.arr.n;
18429
  }
18430
 
18431
  uint8_t gguf_get_val_u8(const struct gguf_context * ctx, int key_id) {
 
18432
  GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_UINT8);
18433
  return ctx->kv[key_id].value.uint8;
18434
  }
18435
 
18436
  int8_t gguf_get_val_i8(const struct gguf_context * ctx, int key_id) {
 
18437
  GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_INT8);
18438
  return ctx->kv[key_id].value.int8;
18439
  }
18440
 
18441
  uint16_t gguf_get_val_u16(const struct gguf_context * ctx, int key_id) {
 
18442
  GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_UINT16);
18443
  return ctx->kv[key_id].value.uint16;
18444
  }
18445
 
18446
  int16_t gguf_get_val_i16(const struct gguf_context * ctx, int key_id) {
 
18447
  GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_INT16);
18448
  return ctx->kv[key_id].value.int16;
18449
  }
18450
 
18451
  uint32_t gguf_get_val_u32(const struct gguf_context * ctx, int key_id) {
 
18452
  GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_UINT32);
18453
  return ctx->kv[key_id].value.uint32;
18454
  }
18455
 
18456
  int32_t gguf_get_val_i32(const struct gguf_context * ctx, int key_id) {
 
18457
  GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_INT32);
18458
  return ctx->kv[key_id].value.int32;
18459
  }
18460
 
18461
  float gguf_get_val_f32(const struct gguf_context * ctx, int key_id) {
 
18462
  GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_FLOAT32);
18463
  return ctx->kv[key_id].value.float32;
18464
  }
18465
 
18466
  uint64_t gguf_get_val_u64(const struct gguf_context * ctx, int key_id) {
 
18467
  GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_UINT64);
18468
  return ctx->kv[key_id].value.uint64;
18469
  }
18470
 
18471
  int64_t gguf_get_val_i64(const struct gguf_context * ctx, int key_id) {
 
18472
  GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_INT64);
18473
  return ctx->kv[key_id].value.int64;
18474
  }
18475
 
18476
  double gguf_get_val_f64(const struct gguf_context * ctx, int key_id) {
 
18477
  GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_FLOAT64);
18478
  return ctx->kv[key_id].value.float64;
18479
  }
18480
 
18481
  bool gguf_get_val_bool(const struct gguf_context * ctx, int key_id) {
 
18482
  GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_BOOL);
18483
  return ctx->kv[key_id].value.bool_;
18484
  }
18485
 
18486
  const char * gguf_get_val_str(const struct gguf_context * ctx, int key_id) {
 
18487
  GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_STRING);
18488
  return ctx->kv[key_id].value.str.data;
18489
  }
18490
 
 
 
 
 
 
 
 
18491
  int gguf_get_n_tensors(const struct gguf_context * ctx) {
18492
  return ctx->header.n_tensors;
18493
  }
 
233
  #define UNUSED GGML_UNUSED
234
  #define SWAP(x, y, T) do { T SWAP = x; x = y; y = SWAP; } while (0)
235
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
236
  #if defined(GGML_USE_ACCELERATE)
237
  #include <Accelerate/Accelerate.h>
238
  #if defined(GGML_USE_CLBLAST) // allow usage of CLBlast alongside Accelerate functions
 
1595
  "GROUP_NORM",
1596
 
1597
  "MUL_MAT",
1598
+ "MUL_MAT_ID",
1599
  "OUT_PROD",
1600
 
1601
  "SCALE",
 
1623
  "POOL_1D",
1624
  "POOL_2D",
1625
  "UPSCALE",
1626
+ "ARGSORT",
1627
 
1628
  "FLASH_ATTN",
1629
  "FLASH_FF",
 
1650
  "CROSS_ENTROPY_LOSS_BACK",
1651
  };
1652
 
1653
+ static_assert(GGML_OP_COUNT == 70, "GGML_OP_COUNT != 70");
1654
 
1655
  static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
1656
  "none",
 
1679
  "group_norm(x)",
1680
 
1681
  "X*Y",
1682
+ "X[i]*Y",
1683
  "X*Y",
1684
 
1685
  "x*v",
 
1707
  "pool_1d(x)",
1708
  "pool_2d(x)",
1709
  "upscale(x)",
1710
+ "argsort(x)",
1711
 
1712
  "flash_attn(x)",
1713
  "flash_ff(x)",
 
1734
  "cross_entropy_loss_back(x,y)",
1735
  };
1736
 
1737
+ static_assert(GGML_OP_COUNT == 70, "GGML_OP_COUNT != 70");
1738
 
1739
  static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
1740
 
1741
+
1742
+ static const char * GGML_UNARY_OP_NAME[GGML_UNARY_OP_COUNT] = {
1743
+ "ABS",
1744
+ "SGN",
1745
+ "NEG",
1746
+ "STEP",
1747
+ "TANH",
1748
+ "ELU",
1749
+ "RELU",
1750
+ "GELU",
1751
+ "GELU_QUICK",
1752
+ "SILU",
1753
+ "LEAKY",
1754
+ };
1755
+
1756
+ static_assert(GGML_UNARY_OP_COUNT == 11, "GGML_UNARY_OP_COUNT != 11");
1757
+
1758
+
1759
  static_assert(sizeof(struct ggml_object)%GGML_MEM_ALIGN == 0, "ggml_object size must be a multiple of GGML_MEM_ALIGN");
1760
  static_assert(sizeof(struct ggml_tensor)%GGML_MEM_ALIGN == 0, "ggml_tensor size must be a multiple of GGML_MEM_ALIGN");
1761
 
 
1775
 
1776
  p[GGML_OP_ACC ] = true;
1777
  p[GGML_OP_MUL_MAT ] = true;
1778
+ p[GGML_OP_MUL_MAT_ID ] = true;
1779
  p[GGML_OP_OUT_PROD ] = true;
1780
  p[GGML_OP_SET ] = true;
1781
  p[GGML_OP_GET_ROWS_BACK ] = true;
 
2028
  return GGML_OP_SYMBOL[op];
2029
  }
2030
 
2031
+ const char * ggml_unary_op_name(enum ggml_unary_op op) {
2032
+ return GGML_UNARY_OP_NAME[op];
2033
+ }
2034
+
2035
+ const char * ggml_op_desc(const struct ggml_tensor * t) {
2036
+ if (t->op == GGML_OP_UNARY) {
2037
+ enum ggml_unary_op uop = ggml_get_unary_op(t);
2038
+ return ggml_unary_op_name(uop);
2039
+ }
2040
+ else {
2041
+ return ggml_op_name(t->op);
2042
+ }
2043
+ }
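ggml_op_desc resolves GGML_OP_UNARY nodes to the underlying unary op name, which is convenient when dumping graphs for debugging. A minimal sketch, assuming a populated graph and the public struct ggml_cgraph fields of this version of ggml.h:

    #include <stdio.h>
    #include "ggml.h"

    // print one line per graph node: index, op description (unary ops resolved), tensor name
    static void print_graph_ops(const struct ggml_cgraph * gf) {
        for (int i = 0; i < gf->n_nodes; ++i) {
            const struct ggml_tensor * node = gf->nodes[i];
            printf("%3d: %-16s %s\n", i, ggml_op_desc(node), node->name);
        }
    }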
2044
+
2045
  size_t ggml_element_size(const struct ggml_tensor * tensor) {
2046
  return ggml_type_size(tensor->type);
2047
  }
 
3173
  struct ggml_tensor * a,
3174
  struct ggml_tensor * b,
3175
  bool inplace) {
3176
+ GGML_ASSERT(ggml_can_repeat(b, a));
 
 
3177
 
3178
  bool is_node = false;
3179
 
 
3388
  struct ggml_tensor * a,
3389
  struct ggml_tensor * b,
3390
  bool inplace) {
3391
+ GGML_ASSERT(ggml_can_repeat(b, a));
 
 
3392
 
3393
  bool is_node = false;
3394
 
 
3433
  struct ggml_tensor * a,
3434
  struct ggml_tensor * b,
3435
  bool inplace) {
3436
+ GGML_ASSERT(ggml_can_repeat(b, a));
3437
 
3438
  bool is_node = false;
3439
 
 
4071
  return result;
4072
  }
4073
 
4074
+ // ggml_mul_mat_id
4075
+
4076
+ struct ggml_tensor * ggml_mul_mat_id(
4077
+ struct ggml_context * ctx,
4078
+ struct ggml_tensor * as[],
4079
+ struct ggml_tensor * ids,
4080
+ int id,
4081
+ struct ggml_tensor * b) {
4082
+
4083
+ int64_t n_as = ids->ne[0];
4084
+
4085
+ GGML_ASSERT(ids->type == GGML_TYPE_I32);
4086
+ GGML_ASSERT(ggml_is_vector(ids));
4087
+ GGML_ASSERT(n_as > 0 && n_as <= GGML_MAX_SRC - 2);
4088
+ GGML_ASSERT(id >= 0 && id < n_as);
4089
+
4090
+ bool is_node = false;
4091
+
4092
+ if (as[0]->grad || b->grad) {
4093
+ is_node = true;
4094
+ }
4095
+
4096
+ const int64_t ne[4] = { as[0]->ne[1], b->ne[1], b->ne[2], b->ne[3] };
4097
+ struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, MAX(as[0]->n_dims, b->n_dims), ne);
4098
+
4099
+ ggml_set_op_params_i32(result, 0, id);
4100
+
4101
+ result->op = GGML_OP_MUL_MAT_ID;
4102
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
4103
+ result->src[0] = ids;
4104
+ result->src[1] = b;
4105
+
4106
+ for (int64_t i = 0; i < n_as; i++) {
4107
+ struct ggml_tensor * a = as[i];
4108
+ GGML_ASSERT(ggml_are_same_shape(as[0], a));
4109
+ GGML_ASSERT(ggml_can_mul_mat(a, b));
4110
+ GGML_ASSERT(!ggml_is_transposed(a));
4111
+ result->src[i + 2] = a;
4112
+ }
4113
+
4114
+ return result;
4115
+ }
4116
+
4117
  // ggml_out_prod
4118
 
4119
  struct ggml_tensor * ggml_out_prod(
 
4267
  struct ggml_tensor * b,
4268
  size_t nb1,
4269
  size_t offset) {
4270
+ return ggml_set_impl(ctx, a, b, nb1, a->nb[2], a->nb[3], offset, true);
4271
  }
4272
 
4273
  // ggml_cpy
 
4884
  static struct ggml_tensor * ggml_soft_max_impl(
4885
  struct ggml_context * ctx,
4886
  struct ggml_tensor * a,
4887
+ struct ggml_tensor * mask,
4888
+ float scale,
4889
  bool inplace) {
4890
+ GGML_ASSERT(ggml_is_contiguous(a));
4891
+ if (mask) {
4892
+ GGML_ASSERT(ggml_is_contiguous(mask));
4893
+ GGML_ASSERT(mask->ne[2] == 1);
4894
+ GGML_ASSERT(mask->ne[3] == 1);
4895
+ GGML_ASSERT(ggml_can_repeat_rows(mask, a));
4896
+ }
4897
+
4898
  bool is_node = false;
4899
 
4900
  if (a->grad) {
 
4903
 
4904
  struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
4905
 
4906
+ float params[] = { scale };
4907
+ ggml_set_op_params(result, params, sizeof(params));
4908
+
4909
  result->op = GGML_OP_SOFT_MAX;
4910
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
4911
  result->src[0] = a;
4912
+ result->src[1] = mask;
4913
 
4914
  return result;
4915
  }
 
4917
  struct ggml_tensor * ggml_soft_max(
4918
  struct ggml_context * ctx,
4919
  struct ggml_tensor * a) {
4920
+ return ggml_soft_max_impl(ctx, a, NULL, 1.0f, false);
4921
  }
4922
 
4923
  struct ggml_tensor * ggml_soft_max_inplace(
4924
  struct ggml_context * ctx,
4925
  struct ggml_tensor * a) {
4926
+ return ggml_soft_max_impl(ctx, a, NULL, 1.0f, true);
4927
+ }
4928
+
4929
+ struct ggml_tensor * ggml_soft_max_ext(
4930
+ struct ggml_context * ctx,
4931
+ struct ggml_tensor * a,
4932
+ struct ggml_tensor * mask,
4933
+ float scale) {
4934
+ return ggml_soft_max_impl(ctx, a, mask, scale, false);
4935
  }
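ggml_soft_max_ext extends soft_max with an optional additive mask and a scale factor applied before exponentiation, conceptually soft_max(a*scale + mask) with the mask broadcast across rows. A minimal attention-style sketch; masked_softmax, kq, kq_mask and n_embd_head are illustrative names:

    #include <math.h>
    #include "ggml.h"

    // fused masked softmax over attention scores; kq and kq_mask are assumed graph tensors
    static struct ggml_tensor * masked_softmax(struct ggml_context * ctx,
                                               struct ggml_tensor  * kq,
                                               struct ggml_tensor  * kq_mask,
                                               int                   n_embd_head) {
        const float kq_scale = 1.0f/sqrtf((float) n_embd_head);  // illustrative scaling choice
        return ggml_soft_max_ext(ctx, kq, kq_mask, kq_scale);
    }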
4936
 
4937
  // ggml_soft_max_back
 
5526
  return ggml_upscale_impl(ctx, a, scale_factor);
5527
  }
5528
 
5529
+ // ggml_argsort
5530
+
5531
+ struct ggml_tensor * ggml_argsort(
5532
+ struct ggml_context * ctx,
5533
+ struct ggml_tensor * a,
5534
+ enum ggml_sort_order order) {
5535
+ bool is_node = false;
5536
+
5537
+ struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_I32, a->n_dims, a->ne);
5538
+
5539
+ ggml_set_op_params_i32(result, 0, (int32_t) order);
5540
+
5541
+ result->op = GGML_OP_ARGSORT;
5542
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
5543
+ result->src[0] = a;
5544
+
5545
+ return result;
5546
+ }
5547
+
5548
+ // ggml_top_k
5549
+
5550
+ struct ggml_tensor * ggml_top_k(
5551
+ struct ggml_context * ctx,
5552
+ struct ggml_tensor * a,
5553
+ int k) {
5554
+ GGML_ASSERT(a->ne[0] >= k);
5555
+
5556
+ struct ggml_tensor * result = ggml_argsort(ctx, a, GGML_SORT_DESC);
5557
+
5558
+ result = ggml_view_4d(ctx, result,
5559
+ k, result->ne[1], result->ne[2], result->ne[3],
5560
+ result->nb[1], result->nb[2], result->nb[3],
5561
+ 0);
5562
+
5563
+ return result;
5564
+ }
5565
+
5566
  // ggml_flash_attn
5567
 
5568
  struct ggml_tensor * ggml_flash_attn(
 
6922
  const struct ggml_tensor * src0,
6923
  const struct ggml_tensor * src1,
6924
  struct ggml_tensor * dst) {
6925
+ GGML_ASSERT(ggml_can_repeat(src1, src0) && ggml_are_same_shape(src0, dst));
6926
 
6927
  if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
6928
  return;
 
6955
  const int64_t i13 = i03 % ne13;
6956
  const int64_t i12 = i02 % ne12;
6957
  const int64_t i11 = i01 % ne11;
6958
+ const int64_t nr0 = ne00 / ne10;
6959
 
6960
  float * dst_ptr = (float *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 );
6961
  float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
6962
  float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11);
6963
 
6964
+ for (int64_t r = 0; r < nr0; ++r) {
6965
  #ifdef GGML_USE_ACCELERATE
6966
+ vDSP_vadd(src0_ptr + r*ne10, 1, src1_ptr, 1, dst_ptr + r*ne10, 1, ne10);
6967
  #else
6968
+ ggml_vec_add_f32(ne10, dst_ptr + r*ne10, src0_ptr + r*ne10, src1_ptr);
6969
  #endif
6970
+ }
6971
  }
6972
  } else {
6973
  // src1 is not contiguous
 
6984
  float * dst_ptr = (float *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 );
6985
  float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
6986
 
6987
+ for (int64_t i0 = 0; i0 < ne0; ++i0) {
6988
+ const int64_t i10 = i0 % ne10;
6989
+ float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + i10*nb10);
6990
 
6991
  dst_ptr[i0] = src0_ptr[i0] + *src1_ptr;
6992
  }
 
7706
  const struct ggml_tensor * src0,
7707
  const struct ggml_tensor * src1,
7708
  struct ggml_tensor * dst) {
7709
+ GGML_ASSERT(ggml_can_repeat(src1, src0) && ggml_are_same_shape(src0, dst));
7710
 
7711
  if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
7712
  return;
 
7729
 
7730
  GGML_ASSERT( nb0 == sizeof(float));
7731
  GGML_ASSERT(nb00 == sizeof(float));
 
7732
 
7733
  if (nb10 == sizeof(float)) {
7734
  for (int64_t ir = ith; ir < nr; ir += nth) {
 
7740
  const int64_t i13 = i03 % ne13;
7741
  const int64_t i12 = i02 % ne12;
7742
  const int64_t i11 = i01 % ne11;
7743
+ const int64_t nr0 = ne00 / ne10;
7744
 
7745
  float * dst_ptr = (float *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 );
7746
  float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
7747
  float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11);
7748
 
7749
+ for (int64_t r = 0 ; r < nr0; ++r) {
7750
  #ifdef GGML_USE_ACCELERATE
7751
+ UNUSED(ggml_vec_mul_f32);
7752
 
7753
+ vDSP_vmul(src0_ptr + r*ne10, 1, src1_ptr, 1, dst_ptr + r*ne10, 1, ne10);
7754
  #else
7755
+ ggml_vec_mul_f32(ne10, dst_ptr + r*ne10, src0_ptr + r*ne10, src1_ptr);
7756
  #endif
7757
+ }
 
7758
  }
7759
  } else {
7760
  // src1 is not contiguous
 
7772
  float * dst_ptr = (float *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 );
7773
  float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
7774
 
7775
+ for (int64_t i0 = 0; i0 < ne00; ++i0) {
7776
+ const int64_t i10 = i0 % ne10;
7777
+ float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + i10*nb10);
7778
 
7779
  dst_ptr[i0] = src0_ptr[i0] * (*src1_ptr);
7780
  }
 
7808
  const struct ggml_tensor * src0,
7809
  const struct ggml_tensor * src1,
7810
  struct ggml_tensor * dst) {
7811
+ GGML_ASSERT(ggml_can_repeat(src1, src0) && ggml_are_same_shape(src0, dst));
 
7812
 
7813
  if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
7814
  return;
7815
  }
7816
 
7817
+ const int ith = params->ith;
7818
+ const int nth = params->nth;
7819
+
7820
+ const int64_t nr = ggml_nrows(src0);
7821
 
7822
  GGML_TENSOR_BINARY_OP_LOCALS
7823
 
 
7825
  GGML_ASSERT(nb00 == sizeof(float));
7826
 
7827
  if (nb10 == sizeof(float)) {
7828
+ for (int64_t ir = ith; ir < nr; ir += nth) {
7829
+ // src0 and dst are same shape => same indices
7830
+ const int64_t i03 = ir/(ne02*ne01);
7831
+ const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
7832
+ const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
7833
+
7834
+ const int64_t i13 = i03 % ne13;
7835
+ const int64_t i12 = i02 % ne12;
7836
+ const int64_t i11 = i01 % ne11;
7837
+ const int64_t nr0 = ne00 / ne10;
7838
 
7839
+ float * dst_ptr = (float *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 );
7840
+ float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
7841
+ float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11);
7842
+
7843
+ for (int64_t r = 0; r < nr0; ++r) {
7844
  #ifdef GGML_USE_ACCELERATE
7845
+ UNUSED(ggml_vec_div_f32);
7846
 
7847
+ vDSP_vdiv(src1_ptr, 1, src0_ptr + r*ne10, 1, dst_ptr + r*ne10, 1, ne10);
 
 
 
 
7848
  #else
7849
+ ggml_vec_div_f32(ne10, dst_ptr + r*ne10, src0_ptr + r*ne10, src1_ptr);
 
 
 
7850
  #endif
7851
+ }
 
7852
  }
7853
  } else {
7854
  // src1 is not contiguous
7855
+ for (int64_t ir = ith; ir < nr; ir += nth) {
7856
+ // src0 and dst are same shape => same indices
7857
+ // src1 is broadcastable across src0 and dst in i1, i2, i3
7858
+ const int64_t i03 = ir/(ne02*ne01);
7859
+ const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
7860
+ const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
7861
 
7862
+ const int64_t i13 = i03 % ne13;
7863
+ const int64_t i12 = i02 % ne12;
7864
+ const int64_t i11 = i01 % ne11;
7865
+
7866
+ float * dst_ptr = (float *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 );
7867
+ float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
7868
+
7869
+ for (int64_t i0 = 0; i0 < ne00; ++i0) {
7870
+ const int64_t i10 = i0 % ne10;
7871
+ float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + i10*nb10);
7872
 
7873
  dst_ptr[i0] = src0_ptr[i0] / (*src1_ptr);
7874
  }
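The relaxed GGML_ASSERT(ggml_can_repeat(src1, src0)) lets add/mul/div broadcast src1 along any dimension whose size divides the matching src0 dimension, including dim 0, which is what the nr0 = ne00/ne10 inner loops above handle. A minimal sketch, assuming an existing ggml_context * ctx:

    // x: 64x32 matrix, s: a single value broadcast over every element of x
    struct ggml_tensor * x = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 64, 32);
    struct ggml_tensor * s = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1);
    struct ggml_tensor * y = ggml_mul(ctx, x, s);   // each row reuses s: nr0 = ne00/ne10 = 64 inner repeats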
 
8314
  return;
8315
  }
8316
 
8317
+ GGML_TENSOR_UNARY_OP_LOCALS
8318
 
8319
  // guaranteed to be an integer due to the check in ggml_can_repeat
8320
  const int nr0 = (int)(ne0/ne00);
 
8459
  GGML_ASSERT(src0->nb[0] == sizeof(float));
8460
 
8461
  const int ith = params->ith;
8462
+ const int nth = params->nth;
8463
 
8464
  GGML_TENSOR_BINARY_OP_LOCALS
8465
 
 
8469
  GGML_ASSERT(nb10 == sizeof(float));
8470
 
8471
  for (int i3 = 0; i3 < ne3; i3++) {
8472
+ for (int i2 = ith; i2 < ne2; i2 += nth) {
8473
  if (i2 < ne02) { // src0
8474
  for (int i1 = 0; i1 < ne1; i1++) {
8475
  for (int i0 = 0; i0 < ne0; i0++) {
 
9629
  char * wdata = params->wdata;
9630
  const size_t row_size = ne10*ggml_type_size(vec_dot_type)/ggml_blck_size(vec_dot_type);
9631
 
9632
+ assert(params->wsize >= ne11*ne12*ne13*row_size);
9633
+
9634
  for (int64_t i13 = 0; i13 < ne13; ++i13) {
9635
  for (int64_t i12 = 0; i12 < ne12; ++i12) {
9636
  for (int64_t i11 = 0; i11 < ne11; ++i11) {
 
9732
  }
9733
  }
9734
 
9735
+ // ggml_compute_forward_mul_mat_id
9736
+
9737
+ static void ggml_compute_forward_mul_mat_id(
9738
+ const struct ggml_compute_params * params,
9739
+ struct ggml_tensor * dst) {
9740
+
9741
+ const struct ggml_tensor * ids = dst->src[0];
9742
+ const struct ggml_tensor * src1 = dst->src[1];
9743
+
9744
+ const int id = ggml_get_op_params_i32(dst, 0);
9745
+
9746
+ const int a_id = ((int32_t *)ids->data)[id];
9747
+
9748
+ GGML_ASSERT(a_id >= 0 && a_id < ids->ne[0]);
9749
+
9750
+ const struct ggml_tensor * src0 = dst->src[a_id + 2];
9751
+
9752
+ ggml_compute_forward_mul_mat(params, src0, src1, dst);
9753
+ }
9754
+
9755
  // ggml_compute_forward_out_prod
9756
 
9757
  static void ggml_compute_forward_out_prod_f32(
 
9767
  const int ith = params->ith;
9768
  const int nth = params->nth;
9769
 
9770
+ GGML_ASSERT(ne0 == ne00);
9771
+ GGML_ASSERT(ne1 == ne10);
9772
+ GGML_ASSERT(ne2 == ne02);
9773
  GGML_ASSERT(ne02 == ne12);
 
 
9774
  GGML_ASSERT(ne3 == ne13);
9775
+ GGML_ASSERT(ne03 == ne13);
9776
 
9777
  // we don't support permuted src0 or src1
9778
  GGML_ASSERT(nb00 == sizeof(float));
 
9783
  // GGML_ASSERT(nb1 <= nb2);
9784
  // GGML_ASSERT(nb2 <= nb3);
9785
 
 
 
 
 
 
9786
  // nb01 >= nb00 - src0 is not transposed
9787
  // compute by src0 rows
9788
 
9789
  // TODO: #if defined(GGML_USE_CUBLAS) ggml_cuda_out_prod
9790
+ // TODO: #if defined(GGML_USE_CLBLAST)
9791
+
9792
+ #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
9793
+ bool use_blas = ggml_is_matrix(src0) &&
9794
+ ggml_is_matrix(src1) &&
9795
+ ggml_is_contiguous(src0) &&
9796
+ (ggml_is_contiguous(src1) || ggml_is_transposed(src1));
9797
+ #endif
9798
 
9799
  if (params->type == GGML_TASK_INIT) {
9800
+ #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) // gemm beta will zero dst
9801
+ if (use_blas) {
9802
+ return;
9803
+ }
9804
+ #endif
9805
  ggml_vec_set_f32(ne0*ne1*ne2*ne3, dst->data, 0);
9806
  return;
9807
  }
 
9810
  return;
9811
  }
9812
 
9813
+ #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
9814
+ if (use_blas) {
9815
+ if (params->ith != 0) { // All threads other than the first do no work.
9816
+ return;
9817
+ }
9818
+ // Arguments to ggml_compute_forward_out_prod (expressed as major,minor)
9819
+ // src0: (k,n)
9820
+ // src1: (k,m)
9821
+ // dst: (m,n)
9822
+ //
9823
+ // Arguments to sgemm (see https://github.com/Reference-LAPACK/lapack/blob/master/BLAS/SRC/sgemm.f)
9824
+ // Also expressed as (major,minor)
9825
+ // a: (m,k): so src1 transposed
9826
+ // b: (k,n): so src0
9827
+ // c: (m,n)
9828
+ //
9829
+ // However, if ggml_is_transposed(src1) is true, then
9830
+ // src1->data already contains a transposed version, so sgemm mustn't
9831
+ // transpose it further.
9832
+
9833
+ int n = src0->ne[0];
9834
+ int k = src0->ne[1];
9835
+ int m = src1->ne[0];
9836
+
9837
+ int transposeA, lda;
9838
+
9839
+ if (!ggml_is_transposed(src1)) {
9840
+ transposeA = CblasTrans;
9841
+ lda = m;
9842
+ } else {
9843
+ transposeA = CblasNoTrans;
9844
+ lda = k;
9845
+ }
9846
+
9847
+ float * a = (float *) ((char *) src1->data);
9848
+ float * b = (float *) ((char *) src0->data);
9849
+ float * c = (float *) ((char *) dst->data);
9850
+
9851
+ cblas_sgemm(CblasRowMajor, transposeA, CblasNoTrans, m, n, k, 1.0, a, lda, b, n, 0.0, c, n);
9852
+
9853
+ return;
9854
+ }
9855
+ #endif
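The comment block above maps ggml's out_prod shapes (src0 (k,n), src1 (k,m), dst (m,n) in its major,minor notation) onto the reference sgemm signature, with lda switching between m and k depending on whether src1 is already transposed. As a refresher on the row-major CBLAS conventions being relied on, here is a small self-contained call; it is a generic OpenBLAS-style example with made-up matrices, not the tensors from this kernel:

```c
// C = A * B, row-major: A is M x K, B is K x N, C is M x N
// cblas_sgemm(order, transA, transB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc)
#include <stdio.h>
#include <cblas.h>

int main(void) {
    enum { M = 2, K = 3, N = 2 };

    const float A[M*K] = { 1, 2, 3,
                           4, 5, 6 };
    const float B[K*N] = { 1, 0,
                           0, 1,
                           1, 1 };
    float C[M*N] = { 0 };

    // with CblasRowMajor and no transposition, the leading dimensions are the row lengths: K, N, N
    cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans,
                M, N, K, 1.0f, A, K, B, N, 0.0f, C, N);

    printf("%g %g\n%g %g\n", C[0], C[1], C[2], C[3]); // 4 5 / 10 11
    return 0;
}
```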
9856
+
9857
  // dst[:,:,:,:] = 0
9858
  // for i2,i3:
9859
  // for i1:
 
10707
  static void ggml_compute_forward_soft_max_f32(
10708
  const struct ggml_compute_params * params,
10709
  const struct ggml_tensor * src0,
10710
+ const struct ggml_tensor * src1,
10711
+ struct ggml_tensor * dst) {
10712
+ assert(ggml_is_contiguous(dst));
10713
+ assert(ggml_are_same_shape(src0, dst));
10714
 
10715
  if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
10716
  return;
10717
  }
10718
 
10719
+ float scale = 1.0f;
10720
+ memcpy(&scale, (float *) dst->op_params + 0, sizeof(float));
10721
+
10722
  // TODO: handle transposed/permuted matrices
10723
 
10724
  const int ith = params->ith;
10725
  const int nth = params->nth;
10726
 
10727
+ const int64_t ne11 = src1 ? src1->ne[1] : 1;
10728
+
10729
  const int nc = src0->ne[0];
10730
  const int nr = ggml_nrows(src0);
10731
 
 
10736
  const int ir0 = dr*ith;
10737
  const int ir1 = MIN(ir0 + dr, nr);
10738
 
10739
+ float * wp = (float *) params->wdata + (nc + CACHE_LINE_SIZE_F32) * ith;
10740
+
10741
  for (int i1 = ir0; i1 < ir1; i1++) {
10742
+ float * sp = (float *)((char *) src0->data + i1*src0->nb[1]);
10743
+ float * dp = (float *)((char *) dst->data + i1*dst->nb[1]);
10744
+
10745
+ // broadcast the mask across rows
10746
+ float * mp = src1 ? (float *)((char *) src1->data + (i1%ne11)*src1->nb[1]) : NULL;
10747
+
10748
+ ggml_vec_cpy_f32 (nc, wp, sp);
10749
+ ggml_vec_scale_f32(nc, wp, scale);
10750
+ if (mp) {
10751
+ ggml_vec_acc_f32(nc, wp, mp);
10752
+ }
10753
 
10754
  #ifndef NDEBUG
10755
  for (int i = 0; i < nc; ++i) {
10756
  //printf("p[%d] = %f\n", i, p[i]);
10757
+ assert(!isnan(wp[i]));
10758
  }
10759
  #endif
10760
 
10761
  float max = -INFINITY;
10762
+ ggml_vec_max_f32(nc, &max, wp);
10763
 
10764
  ggml_float sum = 0.0;
10765
 
10766
  uint16_t scvt;
10767
  for (int i = 0; i < nc; i++) {
10768
+ if (wp[i] == -INFINITY) {
10769
  dp[i] = 0.0f;
10770
  } else {
10771
+ // const float val = (wp[i] == -INFINITY) ? 0.0 : exp(wp[i] - max);
10772
+ ggml_fp16_t s = GGML_FP32_TO_FP16(wp[i] - max);
10773
  memcpy(&scvt, &s, sizeof(scvt));
10774
  const float val = GGML_FP16_TO_FP32(ggml_table_exp_f16[scvt]);
10775
  sum += (ggml_float)val;
 
10794
  static void ggml_compute_forward_soft_max(
10795
  const struct ggml_compute_params * params,
10796
  const struct ggml_tensor * src0,
10797
+ const struct ggml_tensor * src1,
10798
+ struct ggml_tensor * dst) {
10799
  switch (src0->type) {
10800
  case GGML_TYPE_F32:
10801
  {
10802
+ ggml_compute_forward_soft_max_f32(params, src0, src1, dst);
10803
  } break;
10804
  default:
10805
  {
 
12155
  }
12156
  }
12157
 
12158
+ // ggml_compute_forward_argsort
12159
+
12160
+ static void ggml_compute_forward_argsort_f32(
12161
+ const struct ggml_compute_params * params,
12162
+ const struct ggml_tensor * src0,
12163
+ struct ggml_tensor * dst) {
12164
+
12165
+ if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
12166
+ return;
12167
+ }
12168
+
12169
+ GGML_TENSOR_UNARY_OP_LOCALS
12170
+
12171
+ GGML_ASSERT(nb0 == sizeof(float));
12172
+
12173
+ const int ith = params->ith;
12174
+ const int nth = params->nth;
12175
+
12176
+ const int64_t nr = ggml_nrows(src0);
12177
+
12178
+ enum ggml_sort_order order = (enum ggml_sort_order) ggml_get_op_params_i32(dst, 0);
12179
+
12180
+ for (int64_t i = ith; i < nr; i += nth) {
12181
+ int32_t * dst_data = (int32_t *)((char *) dst->data + i*nb1);
12182
+ const float * src_data = (float *)((char *) src0->data + i*nb01);
12183
+
12184
+ for (int64_t j = 0; j < ne0; j++) {
12185
+ dst_data[j] = j;
12186
+ }
12187
+
12188
+ // C doesn't have a functional sort, so we do a bubble sort instead
12189
+ for (int64_t j = 0; j < ne0; j++) {
12190
+ for (int64_t k = j + 1; k < ne0; k++) {
12191
+ if ((order == GGML_SORT_ASC && src_data[dst_data[j]] > src_data[dst_data[k]]) ||
12192
+ (order == GGML_SORT_DESC && src_data[dst_data[j]] < src_data[dst_data[k]])) {
12193
+ int32_t tmp = dst_data[j];
12194
+ dst_data[j] = dst_data[k];
12195
+ dst_data[k] = tmp;
12196
+ }
12197
+ }
12198
+ }
12199
+ }
12200
+ }
12201
+
12202
+ static void ggml_compute_forward_argsort(
12203
+ const struct ggml_compute_params * params,
12204
+ const struct ggml_tensor * src0,
12205
+ struct ggml_tensor * dst) {
12206
+
12207
+ switch (src0->type) {
12208
+ case GGML_TYPE_F32:
12209
+ {
12210
+ ggml_compute_forward_argsort_f32(params, src0, dst);
12211
+ } break;
12212
+ default:
12213
+ {
12214
+ GGML_ASSERT(false);
12215
+ } break;
12216
+ }
12217
+ }
12218
+
12219
  // ggml_compute_forward_flash_attn
12220
 
12221
  static void ggml_compute_forward_flash_attn_f32(
 
14039
  {
14040
  ggml_compute_forward_mul_mat(params, tensor->src[0], tensor->src[1], tensor);
14041
  } break;
14042
+ case GGML_OP_MUL_MAT_ID:
14043
+ {
14044
+ ggml_compute_forward_mul_mat_id(params, tensor);
14045
+ } break;
14046
  case GGML_OP_OUT_PROD:
14047
  {
14048
  ggml_compute_forward_out_prod(params, tensor->src[0], tensor->src[1], tensor);
 
14101
  } break;
14102
  case GGML_OP_SOFT_MAX:
14103
  {
14104
+ ggml_compute_forward_soft_max(params, tensor->src[0], tensor->src[1], tensor);
14105
  } break;
14106
  case GGML_OP_SOFT_MAX_BACK:
14107
  {
 
14147
  {
14148
  ggml_compute_forward_upscale(params, tensor->src[0], tensor);
14149
  } break;
14150
+ case GGML_OP_ARGSORT:
14151
+ {
14152
+ ggml_compute_forward_argsort(params, tensor->src[0], tensor);
14153
+ } break;
14154
  case GGML_OP_FLASH_ATTN:
14155
  {
14156
  const int32_t t = ggml_get_op_params_i32(tensor, 0);
 
14801
  zero_table);
14802
  }
14803
  } break;
14804
+ case GGML_OP_MUL_MAT_ID:
14805
+ {
14806
+ GGML_ASSERT(false); // TODO: not implemented
14807
+ } break;
14808
  case GGML_OP_OUT_PROD:
14809
  {
14810
  GGML_ASSERT(false); // TODO: not implemented
 
15143
  {
15144
  GGML_ASSERT(false); // TODO: not implemented
15145
  } break;
15146
+ case GGML_OP_ARGSORT:
15147
+ {
15148
+ GGML_ASSERT(false); // TODO: not implemented
15149
+ } break;
15150
  case GGML_OP_FLASH_ATTN:
15151
  {
15152
  struct ggml_tensor * flash_grad = NULL;
 
15507
  return ggml_new_graph_custom(ctx, GGML_DEFAULT_GRAPH_SIZE, false);
15508
  }
15509
 
15510
+ struct ggml_cgraph ggml_graph_view(struct ggml_cgraph * cgraph0, int i0, int i1) {
15511
+ struct ggml_cgraph cgraph = {
 
 
 
 
15512
  /*.size =*/ 0,
15513
  /*.n_nodes =*/ i1 - i0,
15514
  /*.n_leafs =*/ 0,
 
15743
  n_tasks = n_threads;
15744
  } break;
15745
  case GGML_OP_SUB:
 
15746
  case GGML_OP_SQR:
15747
  case GGML_OP_SQRT:
15748
  case GGML_OP_LOG:
 
15775
  {
15776
  n_tasks = n_threads;
15777
  } break;
15778
+ default:
15779
+ GGML_ASSERT(false);
15780
  }
15781
  break;
15782
  case GGML_OP_SILU_BACK:
15783
  case GGML_OP_MUL:
15784
+ case GGML_OP_DIV:
15785
  case GGML_OP_NORM:
15786
  case GGML_OP_RMS_NORM:
15787
  case GGML_OP_RMS_NORM_BACK:
 
15819
  }
15820
  #endif
15821
  } break;
15822
+ case GGML_OP_MUL_MAT_ID:
15823
+ {
15824
+ // FIXME: blas
15825
+ n_tasks = n_threads;
15826
+ } break;
15827
  case GGML_OP_OUT_PROD:
15828
  {
15829
  n_tasks = n_threads;
 
15843
  } break;
15844
  case GGML_OP_DIAG_MASK_ZERO:
15845
  case GGML_OP_DIAG_MASK_INF:
 
15846
  case GGML_OP_SOFT_MAX_BACK:
15847
  case GGML_OP_ROPE:
15848
  case GGML_OP_ROPE_BACK:
 
15858
  {
15859
  n_tasks = 1; //TODO
15860
  } break;
15861
+ case GGML_OP_SOFT_MAX:
15862
+ {
15863
+ n_tasks = MIN(MIN(4, n_threads), ggml_nrows(node->src[0]));
15864
+ } break;
15865
  case GGML_OP_CONV_TRANSPOSE_1D:
15866
  {
15867
  n_tasks = n_threads;
 
15883
  {
15884
  n_tasks = n_threads;
15885
  } break;
15886
+ case GGML_OP_ARGSORT:
15887
+ {
15888
+ n_tasks = n_threads;
15889
+ } break;
15890
  case GGML_OP_FLASH_ATTN:
15891
  {
15892
  n_tasks = n_threads;
 
15955
  } break;
15956
  default:
15957
  {
15958
+ fprintf(stderr, "%s: op not implemented: ", __func__);
15959
+ if (node->op < GGML_OP_COUNT) {
15960
+ fprintf(stderr, "%s\n", ggml_op_name(node->op));
15961
+ } else {
15962
+ fprintf(stderr, "%d\n", node->op);
15963
+ }
15964
  GGML_ASSERT(false);
15965
  } break;
15966
  }
 
16101
 
16102
  // thread scheduling for the different operations + work buffer size estimation
16103
  for (int i = 0; i < cgraph->n_nodes; i++) {
 
 
16104
  struct ggml_tensor * node = cgraph->nodes[i];
16105
 
16106
+ const int n_tasks = ggml_get_n_tasks(node, n_threads);
16107
+
16108
  size_t cur = 0;
16109
 
16110
  switch (node->op) {
16111
  case GGML_OP_CPY:
16112
  case GGML_OP_DUP:
16113
  {
 
 
16114
  if (ggml_is_quantized(node->type)) {
16115
  cur = ggml_type_size(GGML_TYPE_F32) * node->ne[0] * n_tasks;
16116
  }
 
16118
  case GGML_OP_ADD:
16119
  case GGML_OP_ADD1:
16120
  {
 
 
16121
  if (ggml_is_quantized(node->src[0]->type)) {
16122
  cur = ggml_type_size(GGML_TYPE_F32) * node->src[0]->ne[0] * n_tasks;
16123
  }
16124
  } break;
16125
  case GGML_OP_ACC:
16126
  {
 
 
16127
  if (ggml_is_quantized(node->src[0]->type)) {
16128
  cur = ggml_type_size(GGML_TYPE_F32) * node->src[1]->ne[0] * n_tasks;
16129
  }
 
16149
  cur = ggml_type_size(vec_dot_type)*ggml_nelements(node->src[1])/ggml_blck_size(vec_dot_type);
16150
  }
16151
  } break;
16152
+ case GGML_OP_MUL_MAT_ID:
16153
+ {
16154
+ const struct ggml_tensor * a = node->src[2];
16155
+ const struct ggml_tensor * b = node->src[1];
16156
+ const enum ggml_type vec_dot_type = type_traits[a->type].vec_dot_type;
16157
+ #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
16158
+ if (ggml_compute_forward_mul_mat_use_blas(a, b, node)) {
16159
+ if (a->type != GGML_TYPE_F32) {
16160
+ // here we need memory just for a single 2D matrix from src0
16161
+ cur = ggml_type_size(GGML_TYPE_F32)*(a->ne[0]*a->ne[1]);
16162
+ }
16163
+ } else
16164
+ #endif
16165
+ if (b->type != vec_dot_type) {
16166
+ cur = ggml_type_size(vec_dot_type)*ggml_nelements(b)/ggml_blck_size(vec_dot_type);
16167
+ }
16168
+ } break;
16169
  case GGML_OP_OUT_PROD:
16170
  {
 
 
16171
  if (ggml_is_quantized(node->src[0]->type)) {
16172
  cur = ggml_type_size(GGML_TYPE_F32) * node->src[0]->ne[0] * n_tasks;
16173
  }
16174
  } break;
16175
+ case GGML_OP_SOFT_MAX:
16176
+ {
16177
+ cur = ggml_type_size(GGML_TYPE_F32) * node->ne[0] * n_tasks;
16178
+ } break;
16179
  case GGML_OP_CONV_TRANSPOSE_1D:
16180
  {
16181
  GGML_ASSERT(node->src[0]->ne[3] == 1);
 
16201
  GGML_ASSERT(false);
16202
  }
16203
  } break;
 
 
 
 
16204
  case GGML_OP_CONV_TRANSPOSE_2D:
16205
  {
16206
  const int64_t ne00 = node->src[0]->ne[0]; // W
 
16217
  } break;
16218
  case GGML_OP_FLASH_ATTN:
16219
  {
 
 
16220
  const int64_t ne11 = ggml_up(node->src[1]->ne[1], GGML_SOFT_MAX_UNROLL);
16221
 
16222
  if (node->src[1]->type == GGML_TYPE_F32) {
 
16229
  } break;
16230
  case GGML_OP_FLASH_FF:
16231
  {
 
 
16232
  if (node->src[1]->type == GGML_TYPE_F32) {
16233
  cur = sizeof(float)*node->src[1]->ne[1]*n_tasks; // TODO: this can become (n_tasks-1)
16234
  cur += sizeof(float)*node->src[1]->ne[1]*n_tasks; // this is overestimated by x2
 
16239
  } break;
16240
  case GGML_OP_FLASH_ATTN_BACK:
16241
  {
 
 
16242
  const int64_t D = node->src[0]->ne[0];
16243
  const int64_t ne11 = ggml_up(node->src[1]->ne[1], GGML_SOFT_MAX_UNROLL);
16244
  const int64_t mxDn = MAX(D, ne11) * 2; // *2 because of S and SM in ggml_compute_forward_flash_attn_back
 
16253
 
16254
  case GGML_OP_CROSS_ENTROPY_LOSS:
16255
  {
 
 
16256
  cur = ggml_type_size(node->type)*(n_tasks + node->src[0]->ne[0]*n_tasks);
16257
  } break;
16258
  case GGML_OP_COUNT:
 
18039
  memcpy(&qh, &y[i].qh, sizeof(qh));
18040
 
18041
  for (int j = 0; j < QK5_0; j += 2) {
18042
+ const uint8_t vh0 = ((qh & (1u << (j/2 + 0 ))) >> (j/2 + 0 )) << 4;
18043
+ const uint8_t vh1 = ((qh & (1u << (j/2 + 16))) >> (j/2 + 12));
18044
 
18045
  // cast to 16 bins
18046
  const uint8_t vi0 = ((y[i].qs[j/2] & 0x0F) | vh0) / 2;
 
18069
  memcpy(&qh, &y[i].qh, sizeof(qh));
18070
 
18071
  for (int j = 0; j < QK5_1; j += 2) {
18072
+ const uint8_t vh0 = ((qh & (1u << (j/2 + 0 ))) >> (j/2 + 0 )) << 4;
18073
+ const uint8_t vh1 = ((qh & (1u << (j/2 + 16))) >> (j/2 + 12));
18074
 
18075
  // cast to 16 bins
18076
  const uint8_t vi0 = ((y[i].qs[j/2] & 0x0F) | vh0) / 2;
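Both hunks above change how the fifth bit is pulled out of the 32-bit qh mask for the Q5_0/Q5_1 histogram code. Restated compactly (an illustrative helper, not part of the patch): each expression selects one bit of qh and parks it at bit position 4, so that OR-ing it with a 4-bit nibble produces the full 5-bit quantized value.

```c
#include <stdint.h>

// equivalent form of the expressions above, with b = j/2:
//   vh0 == q5_high_bit(qh, b)        (low-nibble element)
//   vh1 == q5_high_bit(qh, b + 16)   (high-nibble element)
static inline uint8_t q5_high_bit(uint32_t qh, int b) {
    return (uint8_t) (((qh >> b) & 1u) << 4);   // 0 or 16
}
```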
 
18260
 
18261
  struct gguf_header {
18262
  char magic[4];
18263
+
18264
  uint32_t version;
18265
  uint64_t n_tensors; // GGUFv2
18266
  uint64_t n_kv; // GGUFv2
 
18350
 
18351
  for (uint32_t i = 0; i < sizeof(magic); i++) {
18352
  if (magic[i] != GGUF_MAGIC[i]) {
18353
+ fprintf(stderr, "%s: invalid magic characters '%c%c%c%c'\n", __func__, magic[0], magic[1], magic[2], magic[3]);
18354
  fclose(file);
18355
  return NULL;
18356
  }
 
18365
  {
18366
  strncpy(ctx->header.magic, magic, 4);
18367
 
 
18368
  ctx->kv = NULL;
18369
  ctx->infos = NULL;
18370
  ctx->data = NULL;
 
18718
  }
18719
 
18720
  const char * gguf_get_key(const struct gguf_context * ctx, int key_id) {
18721
+ GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
18722
  return ctx->kv[key_id].key.data;
18723
  }
18724
 
18725
  enum gguf_type gguf_get_kv_type(const struct gguf_context * ctx, int key_id) {
18726
+ GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
18727
  return ctx->kv[key_id].type;
18728
  }
18729
 
18730
  enum gguf_type gguf_get_arr_type(const struct gguf_context * ctx, int key_id) {
18731
+ GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
18732
  GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_ARRAY);
18733
  return ctx->kv[key_id].value.arr.type;
18734
  }
18735
 
18736
  const void * gguf_get_arr_data(const struct gguf_context * ctx, int key_id) {
18737
+ GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
18738
  GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_ARRAY);
18739
  return ctx->kv[key_id].value.arr.data;
18740
  }
18741
 
18742
  const char * gguf_get_arr_str(const struct gguf_context * ctx, int key_id, int i) {
18743
+ GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
18744
  GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_ARRAY);
18745
  struct gguf_kv * kv = &ctx->kv[key_id];
18746
  struct gguf_str * str = &((struct gguf_str *) kv->value.arr.data)[i];
 
18748
  }
18749
 
18750
  int gguf_get_arr_n(const struct gguf_context * ctx, int key_id) {
18751
+ GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
18752
  GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_ARRAY);
18753
  return ctx->kv[key_id].value.arr.n;
18754
  }
18755
 
18756
  uint8_t gguf_get_val_u8(const struct gguf_context * ctx, int key_id) {
18757
+ GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
18758
  GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_UINT8);
18759
  return ctx->kv[key_id].value.uint8;
18760
  }
18761
 
18762
  int8_t gguf_get_val_i8(const struct gguf_context * ctx, int key_id) {
18763
+ GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
18764
  GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_INT8);
18765
  return ctx->kv[key_id].value.int8;
18766
  }
18767
 
18768
  uint16_t gguf_get_val_u16(const struct gguf_context * ctx, int key_id) {
18769
+ GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
18770
  GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_UINT16);
18771
  return ctx->kv[key_id].value.uint16;
18772
  }
18773
 
18774
  int16_t gguf_get_val_i16(const struct gguf_context * ctx, int key_id) {
18775
+ GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
18776
  GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_INT16);
18777
  return ctx->kv[key_id].value.int16;
18778
  }
18779
 
18780
  uint32_t gguf_get_val_u32(const struct gguf_context * ctx, int key_id) {
18781
+ GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
18782
  GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_UINT32);
18783
  return ctx->kv[key_id].value.uint32;
18784
  }
18785
 
18786
  int32_t gguf_get_val_i32(const struct gguf_context * ctx, int key_id) {
18787
+ GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
18788
  GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_INT32);
18789
  return ctx->kv[key_id].value.int32;
18790
  }
18791
 
18792
  float gguf_get_val_f32(const struct gguf_context * ctx, int key_id) {
18793
+ GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
18794
  GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_FLOAT32);
18795
  return ctx->kv[key_id].value.float32;
18796
  }
18797
 
18798
  uint64_t gguf_get_val_u64(const struct gguf_context * ctx, int key_id) {
18799
+ GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
18800
  GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_UINT64);
18801
  return ctx->kv[key_id].value.uint64;
18802
  }
18803
 
18804
  int64_t gguf_get_val_i64(const struct gguf_context * ctx, int key_id) {
18805
+ GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
18806
  GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_INT64);
18807
  return ctx->kv[key_id].value.int64;
18808
  }
18809
 
18810
  double gguf_get_val_f64(const struct gguf_context * ctx, int key_id) {
18811
+ GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
18812
  GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_FLOAT64);
18813
  return ctx->kv[key_id].value.float64;
18814
  }
18815
 
18816
  bool gguf_get_val_bool(const struct gguf_context * ctx, int key_id) {
18817
+ GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
18818
  GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_BOOL);
18819
  return ctx->kv[key_id].value.bool_;
18820
  }
18821
 
18822
  const char * gguf_get_val_str(const struct gguf_context * ctx, int key_id) {
18823
+ GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
18824
  GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_STRING);
18825
  return ctx->kv[key_id].value.str.data;
18826
  }
18827
 
18828
+ const void * gguf_get_val_data(const struct gguf_context * ctx, int key_id) {
18829
+ GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
18830
+ GGML_ASSERT(ctx->kv[key_id].type != GGUF_TYPE_ARRAY);
18831
+ GGML_ASSERT(ctx->kv[key_id].type != GGUF_TYPE_STRING);
18832
+ return &ctx->kv[key_id].value;
18833
+ }
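A usage sketch for the new accessor (a fragment: assumes a loaded gguf_context, the usual includes, and an example key name used purely for illustration). It is meant for generic readers that dispatch on gguf_get_kv_type instead of calling one typed getter per key:

```c
const int key_id = gguf_find_key(ctx, "general.quantization_version"); // example key
if (key_id >= 0 && gguf_get_kv_type(ctx, key_id) == GGUF_TYPE_UINT32) {
    // valid for any non-array, non-string KV; the pointer aliases the value union
    const uint32_t v = *(const uint32_t *) gguf_get_val_data(ctx, key_id);
    printf("quantization_version = %u\n", v);
}
```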
18834
+
18835
  int gguf_get_n_tensors(const struct gguf_context * ctx) {
18836
  return ctx->header.n_tensors;
18837
  }
ggml.h CHANGED
@@ -244,11 +244,10 @@
244
  #define GGML_ASSERT(x) \
245
  do { \
246
  if (!(x)) { \
247
- fprintf(stderr, "GGML_ASSERT: %s:%d: %s\n", __FILE__, __LINE__, #x); \
248
- fflush(stderr); \
249
  fflush(stdout); \
 
250
  ggml_print_backtrace(); \
251
- exit(1); \
252
  } \
253
  } while (0)
254
 
@@ -284,6 +283,20 @@
284
  const type prefix##3 = (pointer)->array[3]; \
285
  GGML_UNUSED(prefix##3);
286
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
287
  #ifdef __cplusplus
288
  extern "C" {
289
  #endif
@@ -382,6 +395,7 @@ extern "C" {
382
  GGML_OP_GROUP_NORM,
383
 
384
  GGML_OP_MUL_MAT,
 
385
  GGML_OP_OUT_PROD,
386
 
387
  GGML_OP_SCALE,
@@ -408,8 +422,8 @@ extern "C" {
408
  GGML_OP_CONV_TRANSPOSE_2D,
409
  GGML_OP_POOL_1D,
410
  GGML_OP_POOL_2D,
411
-
412
  GGML_OP_UPSCALE, // nearest interpolate
 
413
 
414
  GGML_OP_FLASH_ATTN,
415
  GGML_OP_FLASH_FF,
@@ -449,7 +463,9 @@ extern "C" {
449
  GGML_UNARY_OP_GELU,
450
  GGML_UNARY_OP_GELU_QUICK,
451
  GGML_UNARY_OP_SILU,
452
- GGML_UNARY_OP_LEAKY
 
 
453
  };
454
 
455
  enum ggml_object_type {
@@ -632,6 +648,9 @@ extern "C" {
632
  GGML_API const char * ggml_op_name (enum ggml_op op);
633
  GGML_API const char * ggml_op_symbol(enum ggml_op op);
634
 
 
 
 
635
  GGML_API size_t ggml_element_size(const struct ggml_tensor * tensor);
636
 
637
  GGML_API bool ggml_is_quantized(enum ggml_type type);
@@ -1028,6 +1047,15 @@ extern "C" {
1028
  struct ggml_tensor * a,
1029
  struct ggml_tensor * b);
1030
 
 
 
 
 
 
 
 
 
 
1031
  // A: m columns, n rows,
1032
  // B: p columns, n rows,
1033
  // result is m columns, p rows
@@ -1283,6 +1311,14 @@ extern "C" {
1283
  struct ggml_context * ctx,
1284
  struct ggml_tensor * a);
1285
 
 
 
 
 
 
 
 
 
1286
  GGML_API struct ggml_tensor * ggml_soft_max_back(
1287
  struct ggml_context * ctx,
1288
  struct ggml_tensor * a,
@@ -1513,6 +1549,23 @@ extern "C" {
1513
  struct ggml_tensor * a,
1514
  int scale_factor);
1515
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1516
  GGML_API struct ggml_tensor * ggml_flash_attn(
1517
  struct ggml_context * ctx,
1518
  struct ggml_tensor * q,
@@ -1574,7 +1627,6 @@ extern "C" {
1574
  int kh);
1575
 
1576
  // used in sam
1577
-
1578
  GGML_API struct ggml_tensor * ggml_add_rel_pos(
1579
  struct ggml_context * ctx,
1580
  struct ggml_tensor * a,
@@ -1749,7 +1801,7 @@ extern "C" {
1749
  GGML_API struct ggml_cgraph * ggml_new_graph (struct ggml_context * ctx); // size = GGML_DEFAULT_GRAPH_SIZE, grads = false
1750
  GGML_API struct ggml_cgraph * ggml_new_graph_custom (struct ggml_context * ctx, size_t size, bool grads);
1751
  GGML_API struct ggml_cgraph * ggml_graph_dup (struct ggml_context * ctx, struct ggml_cgraph * cgraph);
1752
- GGML_API struct ggml_cgraph * ggml_graph_view (struct ggml_context * ctx, struct ggml_cgraph * cgraph, int i0, int i1);
1753
  GGML_API void ggml_graph_cpy (struct ggml_cgraph * src, struct ggml_cgraph * dst);
1754
  GGML_API void ggml_graph_reset (struct ggml_cgraph * cgraph); // zero grads
1755
  GGML_API void ggml_graph_clear (struct ggml_cgraph * cgraph);
@@ -2045,6 +2097,7 @@ extern "C" {
2045
  GGML_API double gguf_get_val_f64 (const struct gguf_context * ctx, int key_id);
2046
  GGML_API bool gguf_get_val_bool(const struct gguf_context * ctx, int key_id);
2047
  GGML_API const char * gguf_get_val_str (const struct gguf_context * ctx, int key_id);
 
2048
  GGML_API int gguf_get_arr_n (const struct gguf_context * ctx, int key_id);
2049
  GGML_API const void * gguf_get_arr_data(const struct gguf_context * ctx, int key_id);
2050
  GGML_API const char * gguf_get_arr_str (const struct gguf_context * ctx, int key_id, int i);
 
244
  #define GGML_ASSERT(x) \
245
  do { \
246
  if (!(x)) { \
 
 
247
  fflush(stdout); \
248
+ fprintf(stderr, "GGML_ASSERT: %s:%d: %s\n", __FILE__, __LINE__, #x); \
249
  ggml_print_backtrace(); \
250
+ abort(); \
251
  } \
252
  } while (0)
253
 
 
283
  const type prefix##3 = (pointer)->array[3]; \
284
  GGML_UNUSED(prefix##3);
285
 
286
+ #define GGML_TENSOR_UNARY_OP_LOCALS \
287
+ GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) \
288
+ GGML_TENSOR_LOCALS(size_t, nb0, src0, nb) \
289
+ GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) \
290
+ GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
291
+
292
+ #define GGML_TENSOR_BINARY_OP_LOCALS \
293
+ GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) \
294
+ GGML_TENSOR_LOCALS(size_t, nb0, src0, nb) \
295
+ GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne) \
296
+ GGML_TENSOR_LOCALS(size_t, nb1, src1, nb) \
297
+ GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) \
298
+ GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
299
+
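For reference, a fragment showing what these macros put in scope inside a kernel (illustrative only; src0, src1 and dst are assumed to be in scope, as in the compute functions in ggml.c):

```c
// after this line the kernel can use:
//   ne00..ne03 / nb00..nb03  -- src0 extents and byte strides
//   ne10..ne13 / nb10..nb13  -- src1 extents and byte strides
//   ne0 ..ne3  / nb0 ..nb3   -- dst  extents and byte strides
GGML_TENSOR_BINARY_OP_LOCALS

for (int64_t i01 = 0; i01 < ne01; ++i01) {
    const float * src_row = (const float *) ((const char *) src0->data + i01*nb01);
    float       * dst_row = (float       *) ((char       *) dst->data  + i01*nb1);
    for (int64_t i00 = 0; i00 < ne00; ++i00) {
        dst_row[i00] = src_row[i00]; // strided row copy; dims 2/3 omitted for brevity
    }
}
```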
300
  #ifdef __cplusplus
301
  extern "C" {
302
  #endif
 
395
  GGML_OP_GROUP_NORM,
396
 
397
  GGML_OP_MUL_MAT,
398
+ GGML_OP_MUL_MAT_ID,
399
  GGML_OP_OUT_PROD,
400
 
401
  GGML_OP_SCALE,
 
422
  GGML_OP_CONV_TRANSPOSE_2D,
423
  GGML_OP_POOL_1D,
424
  GGML_OP_POOL_2D,
 
425
  GGML_OP_UPSCALE, // nearest interpolate
426
+ GGML_OP_ARGSORT,
427
 
428
  GGML_OP_FLASH_ATTN,
429
  GGML_OP_FLASH_FF,
 
463
  GGML_UNARY_OP_GELU,
464
  GGML_UNARY_OP_GELU_QUICK,
465
  GGML_UNARY_OP_SILU,
466
+ GGML_UNARY_OP_LEAKY,
467
+
468
+ GGML_UNARY_OP_COUNT,
469
  };
470
 
471
  enum ggml_object_type {
 
648
  GGML_API const char * ggml_op_name (enum ggml_op op);
649
  GGML_API const char * ggml_op_symbol(enum ggml_op op);
650
 
651
+ GGML_API const char * ggml_unary_op_name(enum ggml_unary_op op);
652
+ GGML_API const char * ggml_op_desc(const struct ggml_tensor * t); // unary or op name
653
+
654
  GGML_API size_t ggml_element_size(const struct ggml_tensor * tensor);
655
 
656
  GGML_API bool ggml_is_quantized(enum ggml_type type);
 
1047
  struct ggml_tensor * a,
1048
  struct ggml_tensor * b);
1049
 
1050
+ // indirect matrix multiplication
1051
+ // ggml_mul_mat_id(ctx, as, ids, id, b) ~= ggml_mul_mat(as[ids[id]], b)
1052
+ GGML_API struct ggml_tensor * ggml_mul_mat_id(
1053
+ struct ggml_context * ctx,
1054
+ struct ggml_tensor * as[],
1055
+ struct ggml_tensor * ids,
1056
+ int id,
1057
+ struct ggml_tensor * b);
1058
+
1059
  // A: m columns, n rows,
1060
  // B: p columns, n rows,
1061
  // result is m columns, p rows
 
1311
  struct ggml_context * ctx,
1312
  struct ggml_tensor * a);
1313
 
1314
+ // fused soft_max(a*scale + mask)
1315
+ // mask is optional
1316
+ GGML_API struct ggml_tensor * ggml_soft_max_ext(
1317
+ struct ggml_context * ctx,
1318
+ struct ggml_tensor * a,
1319
+ struct ggml_tensor * mask,
1320
+ float scale);
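A usage sketch for the fused op (fragment with hypothetical attention tensors; the mask row is broadcast across the extra dimensions of a, matching the i1 % ne11 rule in the CPU kernel):

```c
// kq:      attention scores for all heads
// kq_mask: additive mask, e.g. -INFINITY at masked positions
const float kq_scale = 1.0f/sqrtf((float) n_embd_head);

struct ggml_tensor * probs = ggml_soft_max_ext(ctx0, kq, kq_mask, kq_scale);
// computes soft_max(kq*kq_scale + kq_mask) in a single pass, without
// materializing the scaled and masked intermediates as separate tensors
```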
1321
+
1322
  GGML_API struct ggml_tensor * ggml_soft_max_back(
1323
  struct ggml_context * ctx,
1324
  struct ggml_tensor * a,
 
1549
  struct ggml_tensor * a,
1550
  int scale_factor);
1551
 
1552
+ // sort rows
1553
+ enum ggml_sort_order {
1554
+ GGML_SORT_ASC,
1555
+ GGML_SORT_DESC,
1556
+ };
1557
+
1558
+ GGML_API struct ggml_tensor * ggml_argsort(
1559
+ struct ggml_context * ctx,
1560
+ struct ggml_tensor * a,
1561
+ enum ggml_sort_order order);
1562
+
1563
+ // top k elements per row
1564
+ GGML_API struct ggml_tensor * ggml_top_k(
1565
+ struct ggml_context * ctx,
1566
+ struct ggml_tensor * a,
1567
+ int k);
1568
+
1569
  GGML_API struct ggml_tensor * ggml_flash_attn(
1570
  struct ggml_context * ctx,
1571
  struct ggml_tensor * q,
 
1627
  int kh);
1628
 
1629
  // used in sam
 
1630
  GGML_API struct ggml_tensor * ggml_add_rel_pos(
1631
  struct ggml_context * ctx,
1632
  struct ggml_tensor * a,
 
1801
  GGML_API struct ggml_cgraph * ggml_new_graph (struct ggml_context * ctx); // size = GGML_DEFAULT_GRAPH_SIZE, grads = false
1802
  GGML_API struct ggml_cgraph * ggml_new_graph_custom (struct ggml_context * ctx, size_t size, bool grads);
1803
  GGML_API struct ggml_cgraph * ggml_graph_dup (struct ggml_context * ctx, struct ggml_cgraph * cgraph);
1804
+ GGML_API struct ggml_cgraph ggml_graph_view (struct ggml_cgraph * cgraph, int i0, int i1);
1805
  GGML_API void ggml_graph_cpy (struct ggml_cgraph * src, struct ggml_cgraph * dst);
1806
  GGML_API void ggml_graph_reset (struct ggml_cgraph * cgraph); // zero grads
1807
  GGML_API void ggml_graph_clear (struct ggml_cgraph * cgraph);
 
2097
  GGML_API double gguf_get_val_f64 (const struct gguf_context * ctx, int key_id);
2098
  GGML_API bool gguf_get_val_bool(const struct gguf_context * ctx, int key_id);
2099
  GGML_API const char * gguf_get_val_str (const struct gguf_context * ctx, int key_id);
2100
+ GGML_API const void * gguf_get_val_data(const struct gguf_context * ctx, int key_id);
2101
  GGML_API int gguf_get_arr_n (const struct gguf_context * ctx, int key_id);
2102
  GGML_API const void * gguf_get_arr_data(const struct gguf_context * ctx, int key_id);
2103
  GGML_API const char * gguf_get_arr_str (const struct gguf_context * ctx, int key_id, int i);
whisper.cpp CHANGED
@@ -1063,7 +1063,7 @@ static ggml_backend_t whisper_backend_init(const whisper_context_params & params
1063
  #ifdef GGML_USE_CUBLAS
1064
  if (params.use_gpu && ggml_cublas_loaded()) {
1065
  WHISPER_LOG_INFO("%s: using CUDA backend\n", __func__);
1066
- backend_gpu = ggml_backend_cuda_init();
1067
  if (!backend_gpu) {
1068
  WHISPER_LOG_ERROR("%s: ggml_backend_cuda_init() failed\n", __func__);
1069
  }
@@ -1077,8 +1077,7 @@ static ggml_backend_t whisper_backend_init(const whisper_context_params & params
1077
  backend_gpu = ggml_backend_metal_init();
1078
  if (!backend_gpu) {
1079
  WHISPER_LOG_ERROR("%s: ggml_backend_metal_init() failed\n", __func__);
1080
- }
1081
- if (!ggml_backend_metal_supports_family(backend_gpu, 7)) {
1082
  WHISPER_LOG_ERROR("%s: Metal GPU does not support family 7 - falling back to CPU\n", __func__);
1083
  ggml_backend_free(backend_gpu);
1084
  backend_gpu = NULL;
@@ -1346,10 +1345,10 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
1346
  model.e_pe = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_audio_state, n_audio_ctx);
1347
 
1348
  model.e_conv_1_w = ggml_new_tensor_3d(ctx, vtype, 3, n_mels, n_audio_state);
1349
- model.e_conv_1_b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 2*n_audio_ctx, n_audio_state);
1350
 
1351
  model.e_conv_2_w = ggml_new_tensor_3d(ctx, vtype, 3, n_audio_state, n_audio_state);
1352
- model.e_conv_2_b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_audio_ctx, n_audio_state);
1353
 
1354
  model.e_ln_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
1355
  model.e_ln_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
@@ -1579,29 +1578,25 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
1579
 
1580
  auto tensor = model.tensors[name.data()];
1581
 
1582
- const bool is_conv_bias = (name == "encoder.conv1.bias" || name == "encoder.conv2.bias");
1583
-
1584
- if (!is_conv_bias) {
1585
- if (ggml_nelements(tensor) != nelements) {
1586
- WHISPER_LOG_ERROR("%s: tensor '%s' has wrong size in model file\n", __func__, name.data());
1587
- WHISPER_LOG_ERROR("%s: shape: [%d, %d, %d], expected: [%d, %d, %d]\n",
1588
- __func__, ne[0], ne[1], ne[2], (int) tensor->ne[0], (int) tensor->ne[1], (int) tensor->ne[2]);
1589
- return false;
1590
- }
1591
 
1592
- if (tensor->ne[0] != ne[0] || tensor->ne[1] != ne[1] || tensor->ne[2] != ne[2]) {
1593
- WHISPER_LOG_ERROR("%s: tensor '%s' has wrong shape in model file: got [%d, %d, %d], expected [%d, %d, %d]\n",
1594
- __func__, name.data(), (int) tensor->ne[0], (int) tensor->ne[1], (int) tensor->ne[2], ne[0], ne[1], ne[2]);
1595
- return false;
1596
- }
1597
 
1598
- const size_t bpe = ggml_type_size(ggml_type(ttype));
1599
 
1600
- if ((nelements*bpe)/ggml_blck_size(tensor->type) != ggml_nbytes(tensor)) {
1601
- WHISPER_LOG_ERROR("%s: tensor '%s' has wrong size in model file: got %zu, expected %zu\n",
1602
- __func__, name.data(), ggml_nbytes(tensor), nelements*bpe);
1603
- return false;
1604
- }
1605
  }
1606
 
1607
  ggml_backend_t backend = wctx.backend;
@@ -1612,7 +1607,7 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
1612
  #ifdef GGML_USE_METAL
1613
  || ggml_backend_is_metal(backend)
1614
  #endif
1615
- ) && !is_conv_bias) {
1616
  // for the CPU and Metal backend, we can read directly into the tensor
1617
  loader->read(loader->context, tensor->data, ggml_nbytes(tensor));
1618
  BYTESWAP_TENSOR(tensor);
@@ -1620,24 +1615,7 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
1620
  // read into a temporary buffer first, then copy to device memory
1621
  read_buf.resize(ggml_nbytes(tensor));
1622
 
1623
- // we repeat the 2 bias tensors along dim 0:
1624
- // [1, 512] -> [3000, 512] (conv1.bias)
1625
- // [1, 512] -> [1500, 512] (conv2.bias)
1626
- if (is_conv_bias) {
1627
- loader->read(loader->context, read_buf.data(), read_buf.size() / tensor->ne[0]);
1628
-
1629
- float * data_f32 = (float *) read_buf.data();
1630
- for (int64_t y = 0; y < tensor->ne[1]; ++y) {
1631
- const int64_t yy = tensor->ne[1] - y - 1;
1632
- const float val = data_f32[yy];
1633
-
1634
- for (int64_t x = 0; x < tensor->ne[0]; ++x) {
1635
- data_f32[yy*tensor->ne[0] + x] = val;
1636
- }
1637
- }
1638
- } else {
1639
- loader->read(loader->context, read_buf.data(), read_buf.size());
1640
- }
1641
 
1642
  ggml_backend_tensor_set(tensor, read_buf.data(), 0, ggml_nbytes(tensor));
1643
  }
@@ -1737,20 +1715,12 @@ static struct ggml_cgraph * whisper_build_graph_conv(
1737
  // convolution + gelu
1738
  {
1739
  cur = ggml_conv_1d_ph(ctx0, model.e_conv_1_w, mel, 1, 1);
1740
- if (n_ctx == hparams.n_audio_ctx) {
1741
- cur = ggml_add(ctx0, cur, model.e_conv_1_b);
1742
- } else {
1743
- cur = ggml_add(ctx0, cur, ggml_cont(ctx0, ggml_view_2d(ctx0, model.e_conv_1_b, cur->ne[0], cur->ne[1], model.e_conv_1_b->nb[1], 0)));
1744
- }
1745
 
1746
  cur = ggml_gelu(ctx0, cur);
1747
 
1748
  cur = ggml_conv_1d_ph(ctx0, model.e_conv_2_w, cur, 2, 1);
1749
- if (n_ctx == hparams.n_audio_ctx) {
1750
- cur = ggml_add(ctx0, cur, model.e_conv_2_b);
1751
- } else {
1752
- cur = ggml_add(ctx0, cur, ggml_cont(ctx0, ggml_view_2d(ctx0, model.e_conv_2_b, cur->ne[0], cur->ne[1], model.e_conv_2_b->nb[1], 0)));
1753
- }
1754
 
1755
  cur = ggml_gelu(ctx0, cur);
1756
  }
 
1063
  #ifdef GGML_USE_CUBLAS
1064
  if (params.use_gpu && ggml_cublas_loaded()) {
1065
  WHISPER_LOG_INFO("%s: using CUDA backend\n", __func__);
1066
+ backend_gpu = ggml_backend_cuda_init(0);
1067
  if (!backend_gpu) {
1068
  WHISPER_LOG_ERROR("%s: ggml_backend_cuda_init() failed\n", __func__);
1069
  }
 
1077
  backend_gpu = ggml_backend_metal_init();
1078
  if (!backend_gpu) {
1079
  WHISPER_LOG_ERROR("%s: ggml_backend_metal_init() failed\n", __func__);
1080
+ } else if (!ggml_backend_metal_supports_family(backend_gpu, 7)) {
 
1081
  WHISPER_LOG_ERROR("%s: Metal GPU does not support family 7 - falling back to CPU\n", __func__);
1082
  ggml_backend_free(backend_gpu);
1083
  backend_gpu = NULL;
 
1345
  model.e_pe = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_audio_state, n_audio_ctx);
1346
 
1347
  model.e_conv_1_w = ggml_new_tensor_3d(ctx, vtype, 3, n_mels, n_audio_state);
1348
+ model.e_conv_1_b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 1, n_audio_state);
1349
 
1350
  model.e_conv_2_w = ggml_new_tensor_3d(ctx, vtype, 3, n_audio_state, n_audio_state);
1351
+ model.e_conv_2_b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 1, n_audio_state);
1352
 
1353
  model.e_ln_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
1354
  model.e_ln_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
 
1578
 
1579
  auto tensor = model.tensors[name.data()];
1580
 
1581
+ if (ggml_nelements(tensor) != nelements) {
1582
+ WHISPER_LOG_ERROR("%s: tensor '%s' has wrong size in model file\n", __func__, name.data());
1583
+ WHISPER_LOG_ERROR("%s: shape: [%d, %d, %d], expected: [%d, %d, %d]\n",
1584
+ __func__, ne[0], ne[1], ne[2], (int) tensor->ne[0], (int) tensor->ne[1], (int) tensor->ne[2]);
1585
+ return false;
1586
+ }
 
 
 
1587
 
1588
+ if (tensor->ne[0] != ne[0] || tensor->ne[1] != ne[1] || tensor->ne[2] != ne[2]) {
1589
+ WHISPER_LOG_ERROR("%s: tensor '%s' has wrong shape in model file: got [%d, %d, %d], expected [%d, %d, %d]\n",
1590
+ __func__, name.data(), (int) tensor->ne[0], (int) tensor->ne[1], (int) tensor->ne[2], ne[0], ne[1], ne[2]);
1591
+ return false;
1592
+ }
1593
 
1594
+ const size_t bpe = ggml_type_size(ggml_type(ttype));
1595
 
1596
+ if ((nelements*bpe)/ggml_blck_size(tensor->type) != ggml_nbytes(tensor)) {
1597
+ WHISPER_LOG_ERROR("%s: tensor '%s' has wrong size in model file: got %zu, expected %zu\n",
1598
+ __func__, name.data(), ggml_nbytes(tensor), nelements*bpe);
1599
+ return false;
 
1600
  }
1601
 
1602
  ggml_backend_t backend = wctx.backend;
 
1607
  #ifdef GGML_USE_METAL
1608
  || ggml_backend_is_metal(backend)
1609
  #endif
1610
+ )) {
1611
  // for the CPU and Metal backend, we can read directly into the tensor
1612
  loader->read(loader->context, tensor->data, ggml_nbytes(tensor));
1613
  BYTESWAP_TENSOR(tensor);
 
1615
  // read into a temporary buffer first, then copy to device memory
1616
  read_buf.resize(ggml_nbytes(tensor));
1617
 
1618
+ loader->read(loader->context, read_buf.data(), read_buf.size());
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1619
 
1620
  ggml_backend_tensor_set(tensor, read_buf.data(), 0, ggml_nbytes(tensor));
1621
  }
 
1715
  // convolution + gelu
1716
  {
1717
  cur = ggml_conv_1d_ph(ctx0, model.e_conv_1_w, mel, 1, 1);
1718
+ cur = ggml_add(ctx0, cur, model.e_conv_1_b);
 
 
 
 
1719
 
1720
  cur = ggml_gelu(ctx0, cur);
1721
 
1722
  cur = ggml_conv_1d_ph(ctx0, model.e_conv_2_w, cur, 2, 1);
1723
+ cur = ggml_add(ctx0, cur, model.e_conv_2_b);
 
 
 
 
1724
 
1725
  cur = ggml_gelu(ctx0, cur);
1726
  }
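Since ggml_add now broadcasts src1 whenever ggml_can_repeat(src1, src0) holds, the conv biases can stay as [1, n_audio_state] and the old path that pre-repeated them to [2*n_audio_ctx, n_audio_state] / [n_audio_ctx, n_audio_state] (and sliced them with ggml_view_2d for shorter contexts) is gone. A minimal sketch of the pattern, with hypothetical names outside the whisper graph:

```c
// act:  [T, n_state]  -- e.g. the conv output, T positions per channel row
// bias: [1, n_state]  -- one value per channel
// ggml_can_repeat(bias, act) holds (T % 1 == 0), so the add broadcasts bias along dim 0
struct ggml_tensor * act  = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, T, n_state);
struct ggml_tensor * bias = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 1, n_state);
struct ggml_tensor * out  = ggml_add(ctx, act, bias); // same shape as act
```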