JohannesGaessler committed
Commit 52069b8 · 1 Parent(s): 2a0805f

ggml/ex: calculate accuracy in graph, adapt MNIST (ggml/980)

ggml/include/ggml.h CHANGED
@@ -466,6 +466,7 @@ extern "C" {
         GGML_OP_SUM_ROWS,
         GGML_OP_MEAN,
         GGML_OP_ARGMAX,
+        GGML_OP_COUNT_EQUAL,
         GGML_OP_REPEAT,
         GGML_OP_REPEAT_BACK,
         GGML_OP_CONCAT,
@@ -1004,6 +1005,12 @@ extern "C" {
             struct ggml_context * ctx,
             struct ggml_tensor  * a);
 
+    // count number of equal elements in a and b
+    GGML_API struct ggml_tensor * ggml_count_equal(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b);
+
     // if a is the same shape as b, and a is not parameter, return a
     // otherwise, return a new tensor: repeat(a) to fit in b
     GGML_API struct ggml_tensor * ggml_repeat(
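The new op pairs with ggml_argmax so that classification accuracy can be computed entirely inside the compute graph, which is what the commit title refers to for the MNIST example. Below is a minimal CPU-only sketch of that pattern; the tensor names, sizes, and dummy data are illustrative and not taken from the commit.

// accuracy_sketch.c -- hypothetical example built against ggml.h
#include "ggml.h"

#include <stdint.h>
#include <stdio.h>

int main(void) {
    const int nclasses = 10;
    const int nbatch   = 4;

    struct ggml_init_params params = {
        /*.mem_size   =*/ 16*1024*1024,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context * ctx = ggml_init(params);

    struct ggml_tensor * logits = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, nclasses, nbatch);
    struct ggml_tensor * labels = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, nbatch);

    // dummy data: for sample i the largest logit is at class i, and the label agrees
    for (int i = 0; i < nbatch; ++i) {
        for (int c = 0; c < nclasses; ++c) {
            ((float *) logits->data)[i*nclasses + c] = (c == i) ? 1.0f : 0.0f;
        }
        ((int32_t *) labels->data)[i] = i;
    }

    // accuracy is part of the graph: argmax over each row, then count matches with the labels
    struct ggml_tensor * pred      = ggml_argmax(ctx, logits);             // I32, one index per row
    struct ggml_tensor * n_correct = ggml_count_equal(ctx, pred, labels);  // I64 scalar

    struct ggml_cgraph * gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, n_correct);
    ggml_graph_compute_with_ctx(ctx, gf, /*n_threads =*/ 4);

    const int64_t correct = ((const int64_t *) n_correct->data)[0];
    printf("accuracy: %.2f\n", (double) correct / nbatch);  // expected: 1.00

    ggml_free(ctx);
    return 0;
}

Since ggml_count_equal returns a one-element GGML_TYPE_I64 tensor, the count is read back as an int64_t after the graph has been computed.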
ggml/src/ggml-cuda.cu CHANGED
@@ -5,12 +5,14 @@
 #include "ggml-cuda/common.cuh"
 #include "ggml-cuda/acc.cuh"
 #include "ggml-cuda/arange.cuh"
+#include "ggml-cuda/argmax.cuh"
 #include "ggml-cuda/argsort.cuh"
 #include "ggml-cuda/binbcast.cuh"
 #include "ggml-cuda/clamp.cuh"
 #include "ggml-cuda/concat.cuh"
 #include "ggml-cuda/conv-transpose-1d.cuh"
 #include "ggml-cuda/convert.cuh"
+#include "ggml-cuda/count-equal.cuh"
 #include "ggml-cuda/cpy.cuh"
 #include "ggml-cuda/cross-entropy-loss.cuh"
 #include "ggml-cuda/diagmask.cuh"
@@ -2178,6 +2180,12 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
     }
 
     switch (dst->op) {
+        case GGML_OP_ARGMAX:
+            ggml_cuda_argmax(ctx, dst);
+            break;
+        case GGML_OP_COUNT_EQUAL:
+            ggml_cuda_count_equal(ctx, dst);
+            break;
         case GGML_OP_REPEAT:
             ggml_cuda_op_repeat(ctx, dst);
             break;
@@ -2929,6 +2937,15 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
                 return false;
             } break;
         case GGML_OP_DUP:
+            {
+                ggml_type src0_type = op->src[0]->type;
+                return src0_type != GGML_TYPE_I32 && src0_type != GGML_TYPE_I16;
+            } break;
+        case GGML_OP_ARGMAX:
+        case GGML_OP_COUNT_EQUAL:
+            {
+                return true;
+            } break;
         case GGML_OP_REPEAT:
             {
                 ggml_type src0_type = op->src[0]->type;
ggml/src/ggml-cuda/argmax.cu ADDED
@@ -0,0 +1,79 @@
+#include "common.cuh"
+#include "argmax.cuh"
+#include "sum.cuh"
+
+#include <cstdint>
+
+static __global__ void argmax_f32(
+    const float * x, int32_t * dst, const int64_t ncols, const int64_t nrows) {
+
+    int argmax_thread = 0;
+    const int64_t row0 = (int64_t)blockIdx.x*WARP_SIZE;
+
+#pragma unroll
+    for (int64_t row1 = 0; row1 < WARP_SIZE; ++row1) {
+        const int64_t row = row0 + row1;
+
+        if (row >= nrows) {
+            break;
+        }
+
+        float maxval = -FLT_MAX;
+        int   argmax = -1;
+
+        for (int32_t col = threadIdx.x; col < ncols; col += WARP_SIZE) {
+            const float val        = x[row*ncols + col];
+            const int   bigger     = val > maxval;
+            const int   not_bigger = bigger ^ 0x00000001;
+
+            maxval = maxval*not_bigger + val*bigger;
+            argmax = argmax*not_bigger + col*bigger;
+        }
+
+#pragma unroll
+        for (int mask = 16; mask > 0; mask >>= 1) {
+            const float val        = __shfl_xor_sync(0xFFFFFFFF, maxval, mask, WARP_SIZE);
+            const int   col        = __shfl_xor_sync(0xFFFFFFFF, argmax, mask, WARP_SIZE);
+            const int   bigger     = val > maxval;
+            const int   not_bigger = bigger ^ 0x00000001;
+
+            maxval = maxval*not_bigger + val*bigger;
+            argmax = argmax*not_bigger + col*bigger;
+        }
+
+        const int store = row1 == threadIdx.x;
+        argmax_thread += store*argmax;
+    }
+
+    const int row = row0 + threadIdx.x;
+
+    if (row >= nrows) {
+        return;
+    }
+
+    dst[row] = argmax_thread;
+}
+
+void ggml_cuda_argmax(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_I32);
+
+    GGML_ASSERT(ggml_is_contiguous(src0));
+
+    const int64_t ne00  = src0->ne[0];
+    const int64_t nrows = ggml_nrows(src0);
+
+    const float * src0_d = (const float *) src0->data;
+    int32_t     * dst_d  = (int32_t     *) dst->data;
+
+    cudaStream_t stream = ctx.stream();
+
+    const int64_t num_blocks = (nrows + WARP_SIZE - 1) / WARP_SIZE;
+
+    const dim3 blocks_dim(WARP_SIZE, 1, 1);
+    const dim3 blocks_num(num_blocks, 1, 1);
+
+    argmax_f32<<<blocks_num, blocks_dim, 0, stream>>>(src0_d, dst_d, ne00, nrows);
+}
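The kernel gives each block one warp and WARP_SIZE rows; within a row, the branchless updates and the __shfl_xor_sync butterfly merge the per-lane (maxval, argmax) pairs. The following standalone CUDA sketch is not part of the commit: it strips the idea down to a single row with an if-based select instead of the branchless one, just to make the shuffle pattern easy to follow. All names in it are made up for illustration.

// toy_argmax.cu -- illustrative only, compile with nvcc
#include <cfloat>
#include <cstdio>

static __global__ void toy_argmax_row(const float * x, int * out, const int ncols) {
    float maxval = -FLT_MAX;
    int   argmax = -1;

    // strided scan: lane i looks at columns i, i+32, i+64, ...
    for (int col = threadIdx.x; col < ncols; col += 32) {
        if (x[col] > maxval) {
            maxval = x[col];
            argmax = col;
        }
    }

    // butterfly reduction: after 5 exchanges every lane holds the row-wide maximum
    for (int mask = 16; mask > 0; mask >>= 1) {
        const float other_val = __shfl_xor_sync(0xFFFFFFFF, maxval, mask, 32);
        const int   other_col = __shfl_xor_sync(0xFFFFFFFF, argmax, mask, 32);
        if (other_val > maxval) {
            maxval = other_val;
            argmax = other_col;
        }
    }

    if (threadIdx.x == 0) {
        *out = argmax;
    }
}

int main() {
    const int ncols = 1000;
    float h_x[ncols];
    for (int i = 0; i < ncols; ++i) {
        h_x[i] = (float) ((i*37) % 1000); // maximum value 999 sits at column 27 (27*37 == 999)
    }

    float * d_x   = nullptr;
    int   * d_out = nullptr;
    cudaMalloc((void **) &d_x,   ncols*sizeof(float));
    cudaMalloc((void **) &d_out, sizeof(int));
    cudaMemcpy(d_x, h_x, ncols*sizeof(float), cudaMemcpyHostToDevice);

    toy_argmax_row<<<1, 32>>>(d_x, d_out, ncols);

    int h_out = -1;
    cudaMemcpy(&h_out, d_out, sizeof(int), cudaMemcpyDeviceToHost);
    printf("argmax = %d\n", h_out); // expected: 27

    cudaFree(d_x);
    cudaFree(d_out);
    return 0;
}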
ggml/src/ggml-cuda/argmax.cuh ADDED
@@ -0,0 +1,3 @@
+#include "common.cuh"
+
+void ggml_cuda_argmax(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
ggml/src/ggml-cuda/common.cuh CHANGED
@@ -175,6 +175,18 @@ static __device__ void no_device_code(
 #define NO_DEVICE_CODE //GGML_ABORT("NO_DEVICE_CODE not valid in host code.")
 #endif // __CUDA_ARCH__
 
+static __device__ __forceinline__ int warp_reduce_sum(int x) {
+#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_AMPERE
+    return __reduce_add_sync(0xffffffff, x);
+#else
+#pragma unroll
+    for (int mask = 16; mask > 0; mask >>= 1) {
+        x += __shfl_xor_sync(0xffffffff, x, mask, 32);
+    }
+    return x;
+#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_AMPERE
+}
+
 static __device__ __forceinline__ float warp_reduce_sum(float x) {
 #pragma unroll
     for (int mask = 16; mask > 0; mask >>= 1) {
ggml/src/ggml-cuda/count-equal.cu ADDED
@@ -0,0 +1,64 @@
+#include "common.cuh"
+#include "count-equal.cuh"
+
+#include <cstdint>
+
+template <typename T>
+static __global__ void count_equal(const T * __restrict__ x, const T * __restrict__ y, int64_t * __restrict__ dst, const int64_t dk, const int64_t k) {
+    const int64_t i0 = (int64_t) blockIdx.x*dk;
+    const int64_t i1 = min(i0 + dk, k);
+
+    int nequal = 0;
+
+    for (int64_t i = i0 + threadIdx.x; i < i1; i += WARP_SIZE) {
+        const T xi = x[i];
+        const T yi = y[i];
+        nequal += xi == yi;
+    }
+
+    nequal = warp_reduce_sum(nequal);
+
+    if (threadIdx.x != 0) {
+        return;
+    }
+
+    atomicAdd((int *) dst, nequal);
+}
+
+void ggml_cuda_count_equal(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+
+    GGML_ASSERT(src0->type == src1->type);
+    GGML_ASSERT( dst->type == GGML_TYPE_I64);
+
+    GGML_ASSERT(ggml_are_same_shape(src0, src1));
+    GGML_ASSERT(ggml_is_contiguous(src0));
+    GGML_ASSERT(ggml_is_contiguous(src1));
+    GGML_ASSERT(ggml_is_contiguous(dst));
+
+    int64_t * dst_d = (int64_t *) dst->data;
+
+    cudaStream_t stream = ctx.stream();
+    const int nsm = ggml_cuda_info().devices[ggml_cuda_get_device()].nsm;
+
+    const int64_t ne = ggml_nelements(src0);
+    GGML_ASSERT(ne < (1 << 30) && "atomicAdd implementation only supports int");
+    const int64_t dne = GGML_PAD(ne / (4*nsm), CUDA_COUNT_EQUAL_CHUNK_SIZE);
+
+    CUDA_CHECK(cudaMemsetAsync(dst_d, 0, ggml_nbytes(dst), stream));
+
+    const dim3 blocks_dim(WARP_SIZE, 1, 1);
+    const dim3 blocks_num(std::min((int64_t)4*nsm, (ne + CUDA_COUNT_EQUAL_CHUNK_SIZE - 1)/CUDA_COUNT_EQUAL_CHUNK_SIZE), 1, 1);
+
+    switch (src0->type) {
+        case GGML_TYPE_I32: {
+            const int * src0_d = (const int *) src0->data;
+            const int * src1_d = (const int *) src1->data;
+            count_equal<<<blocks_num, blocks_dim, 0, stream>>>(src0_d, src1_d, dst_d, dne, ne);
+        } break;
+        default:
+            GGML_ASSERT(false);
+            break;
+    }
+}
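A detail that is easy to miss (my reading, not stated in the commit): the kernel accumulates with atomicAdd on an int even though the destination tensor is GGML_TYPE_I64. That is safe here because dst is zeroed with cudaMemsetAsync first, the assert keeps the total below 2^31, and CUDA GPUs are little-endian, so adding into the low 32-bit word of a zeroed 64-bit slot produces the same value as a full 64-bit add would. A tiny host-side illustration of the layout argument (illustrative only):

#include <assert.h>
#include <stdint.h>
#include <string.h>

int main(void) {
    int64_t slot;
    memset(&slot, 0, sizeof(slot));     // mirrors the cudaMemsetAsync of dst

    int count = 12345;                  // a partial sum, guaranteed < 2^31 by the assert on ne
    memcpy(&slot, &count, sizeof(int)); // atomicAdd((int *) dst, ...) only ever touches these bytes

    assert(slot == 12345);              // holds on little-endian targets such as CUDA GPUs
    return 0;
}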
ggml/src/ggml-cuda/count-equal.cuh ADDED
@@ -0,0 +1,5 @@
+#include "common.cuh"
+
+#define CUDA_COUNT_EQUAL_CHUNK_SIZE 128
+
+void ggml_cuda_count_equal(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
ggml/src/ggml-cuda/fattn-tile-f16.cu CHANGED
@@ -259,7 +259,7 @@ static __global__ void flash_attn_tile_ext_f16(
         }
 
         half kqsum_j = __low2half(kqsum[j_VKQ_0/nwarps]) + __high2half(kqsum[j_VKQ_0/nwarps]);
-        kqsum_j = warp_reduce_sum(kqsum_j);
+        kqsum_j = warp_reduce_sum((float)kqsum_j);
 
 #pragma unroll
         for (int i00 = 0; i00 < D; i00 += 2*WARP_SIZE) {
ggml/src/ggml-cuda/fattn-vec-f16.cuh CHANGED
@@ -196,7 +196,7 @@ static __global__ void flash_attn_vec_ext_f16(
 #pragma unroll
         for (int j = 0; j < ncols; ++j) {
             half sum = vec_dot_KQ(K + (k_VKQ_0 + i_KQ)*nb11, Q_h2[j], Q_i32[j], Q_ds[j]);
-            sum = warp_reduce_sum(sum);
+            sum = warp_reduce_sum((float)sum);
 
             if (use_logit_softcap) {
                 sum = logit_softcap*tanhf(sum);
@@ -265,7 +265,7 @@ static __global__ void flash_attn_vec_ext_f16(
 
 #pragma unroll
     for (int j = 0; j < ncols; ++j) {
-        kqsum[j] = warp_reduce_sum(kqsum[j]);
+        kqsum[j] = warp_reduce_sum((float)kqsum[j]);
         if (threadIdx.x == 0) {
             kqsum_shared[j][threadIdx.y] = kqsum[j];
         }
@@ -280,7 +280,7 @@ static __global__ void flash_attn_vec_ext_f16(
     }
 
     kqsum[j_VKQ] = kqsum_shared[j_VKQ][threadIdx.x];
-    kqsum[j_VKQ] = warp_reduce_sum(kqsum[j_VKQ]);
+    kqsum[j_VKQ] = warp_reduce_sum((float)kqsum[j_VKQ]);
 
     half dst_val = (__low2half(VKQ[j_VKQ]) + __high2half(VKQ[j_VKQ]));
     if (parallel_blocks == 1) {
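The explicit (float) casts in fattn-tile-f16.cu and fattn-vec-f16.cuh are most plausibly a consequence of the new int overload of warp_reduce_sum added in common.cuh: a half argument converts implicitly to both float and int, so the previously unambiguous calls would now fail to resolve. A minimal host-side C++ sketch of the effect (Half is a toy stand-in for CUDA's __half; nothing here is taken from the commit):

// overload_ambiguity.cpp -- illustrative only
struct Half {
    float v;
    operator float() const { return v; }        // __half converts to float ...
    operator int()   const { return (int) v; }  // ... and also to int
};

float warp_reduce_sum(float x) { return x; } // host stand-ins for the
int   warp_reduce_sum(int   x) { return x; } // real device overloads

int main() {
    Half h{1.5f};
    // warp_reduce_sum(h);                    // would not compile: ambiguous overload
    return (int) warp_reduce_sum((float) h);  // explicit cast selects the float overload
}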
ggml/src/ggml.c CHANGED
@@ -2957,6 +2957,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "SUM_ROWS",
     "MEAN",
     "ARGMAX",
+    "COUNT_EQUAL",
     "REPEAT",
     "REPEAT_BACK",
     "CONCAT",
@@ -3030,7 +3031,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "OPT_STEP_ADAMW",
 };
 
-static_assert(GGML_OP_COUNT == 80, "GGML_OP_COUNT != 80");
+static_assert(GGML_OP_COUNT == 81, "GGML_OP_COUNT != 81");
 
 static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "none",
@@ -3051,6 +3052,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "Σx_k",
     "Σx/n",
    "argmax(x)",
+    "count_equal(x)",
     "repeat(x)",
     "repeat_back(x)",
     "concat(x, y)",
@@ -3124,7 +3126,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "adamw(x)",
 };
 
-static_assert(GGML_OP_COUNT == 80, "GGML_OP_COUNT != 80");
+static_assert(GGML_OP_COUNT == 81, "GGML_OP_COUNT != 81");
 
 static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
 
@@ -5185,6 +5187,23 @@ struct ggml_tensor * ggml_argmax(
     return result;
 }
 
+// ggml_count_equal
+
+struct ggml_tensor * ggml_count_equal(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b) {
+    GGML_ASSERT(ggml_are_same_shape(a, b));
+
+    struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_I64, 1);
+
+    result->op     = GGML_OP_COUNT_EQUAL;
+    result->src[0] = a;
+    result->src[1] = b;
+
+    return result;
+}
+
 // ggml_repeat
 
 struct ggml_tensor * ggml_repeat(
@@ -10772,6 +10791,86 @@ static void ggml_compute_forward_argmax(
     }
 }
 
+// ggml_compute_forward_count_equal
+
+static void ggml_compute_forward_count_equal_i32(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+    const struct ggml_tensor * src1 = dst->src[1];
+
+    GGML_TENSOR_BINARY_OP_LOCALS;
+
+    GGML_ASSERT(src0->type == GGML_TYPE_I32);
+    GGML_ASSERT(src1->type == GGML_TYPE_I32);
+    GGML_ASSERT(ggml_are_same_shape(src0, src1));
+    GGML_ASSERT(ggml_is_scalar(dst));
+    GGML_ASSERT(dst->type == GGML_TYPE_I64);
+
+    const int64_t nr = ggml_nrows(src0);
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    int64_t * sums = (int64_t *) params->wdata;
+    int64_t sum_thread = 0;
+
+    // rows per thread
+    const int64_t dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int64_t ir0 = dr*ith;
+    const int64_t ir1 = MIN(ir0 + dr, nr);
+
+    for (int64_t ir = ir0; ir < ir1; ++ir) {
+        const int64_t i03 =  ir / (ne02*ne01);
+        const int64_t i02 = (ir - i03*ne03) / ne01;
+        const int64_t i01 =  ir - i03*ne03 - i02*ne02;
+
+        const char * data0 = (const char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01;
+        const char * data1 = (const char *) src1->data + i03*nb13 + i02*nb12 + i01*nb11;
+
+        for (int64_t i00 = 0; i00 < ne00; ++i00) {
+            const int32_t val0 = *((const int32_t *) (data0 + i00*nb00));
+            const int32_t val1 = *((const int32_t *) (data1 + i00*nb10));
+
+            sum_thread += val0 == val1;
+        }
+    }
+    if (ith != 0) {
+        sums[ith] = sum_thread;
+    }
+    ggml_barrier(params->threadpool);
+
+    if (ith != 0) {
+        return;
+    }
+
+    for (int ith_other = 1; ith_other < nth; ++ith_other) {
+        sum_thread += sums[ith_other];
+    }
+    *((int64_t *) dst->data) = sum_thread;
+}
+
+static void ggml_compute_forward_count_equal(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_I32:
+            {
+                ggml_compute_forward_count_equal_i32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
 // ggml_compute_forward_repeat
 
 static void ggml_compute_forward_repeat_f32(
@@ -17146,6 +17245,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_argmax(params, tensor);
             } break;
+        case GGML_OP_COUNT_EQUAL:
+            {
+                ggml_compute_forward_count_equal(params, tensor);
+            } break;
         case GGML_OP_REPEAT:
             {
                 ggml_compute_forward_repeat(params, tensor);
@@ -17896,6 +17999,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
             } break;
         case GGML_OP_MEAN:
         case GGML_OP_ARGMAX:
+        case GGML_OP_COUNT_EQUAL:
            {
                GGML_ABORT("fatal error"); // TODO: implement
            }
@@ -18669,6 +18773,10 @@ void ggml_build_backward_expand(struct ggml_context * ctx, struct ggml_cgraph *
     for (int i = 0; i < gf->n_nodes; ++i) {
         struct ggml_tensor * node = gf->nodes[i];
 
+        if (node->type == GGML_TYPE_I32) {
+            continue;
+        }
+
         bool needs_grad = node->flags & GGML_TENSOR_FLAG_PARAM;
         bool ignore_src[GGML_MAX_SRC] = {false};
         switch (node->op) {
@@ -19072,6 +19180,13 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
         case GGML_OP_SUM_ROWS:
         case GGML_OP_MEAN:
         case GGML_OP_ARGMAX:
+            {
+                n_tasks = 1;
+            } break;
+        case GGML_OP_COUNT_EQUAL:
+            {
+                n_tasks = n_threads;
+            } break;
         case GGML_OP_REPEAT:
         case GGML_OP_REPEAT_BACK:
         case GGML_OP_LEAKY_RELU:
@@ -19570,6 +19685,10 @@ struct ggml_cplan ggml_graph_plan(
                     cur = ggml_type_size(GGML_TYPE_F32) * node->src[1]->ne[0] * n_tasks;
                 }
             } break;
+        case GGML_OP_COUNT_EQUAL:
+            {
+                cur = ggml_type_size(node->type)*n_tasks;
+            } break;
         case GGML_OP_MUL_MAT:
            {
                const enum ggml_type vec_dot_type = type_traits[node->src[0]->type].vec_dot_type;
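On the CPU side, ggml_compute_forward_count_equal_i32 splits the rows statically across threads, lets each thread keep a private count, parks the non-primary counts in the cplan work buffer (hence the ggml_type_size(node->type)*n_tasks work size above), and has thread 0 fold them after ggml_barrier. The standalone C++20 sketch below reproduces that partial-sum-then-fold pattern outside of ggml; std::barrier stands in for ggml_barrier, a plain vector for params->wdata, and all names are illustrative.

// count_equal_threads.cpp -- illustrative only, requires C++20
#include <algorithm>
#include <barrier>
#include <cstdint>
#include <cstdio>
#include <thread>
#include <vector>

int main() {
    const int nth = 4;                          // worker count, like params->nth
    const std::vector<int32_t> a = {1, 2, 3, 4, 5, 6, 7, 8};
    const std::vector<int32_t> b = {1, 0, 3, 0, 5, 0, 7, 8};

    std::vector<int64_t> sums(nth, 0);          // plays the role of params->wdata
    int64_t total = 0;
    std::barrier sync(nth);                     // plays the role of ggml_barrier

    std::vector<std::thread> workers;
    for (int ith = 0; ith < nth; ++ith) {
        workers.emplace_back([&, ith]() {
            // static split of the elements across threads, as the kernel splits rows
            const size_t dr  = (a.size() + nth - 1)/nth;
            const size_t ir0 = dr*ith;
            const size_t ir1 = std::min(ir0 + dr, a.size());

            int64_t sum_thread = 0;
            for (size_t i = ir0; i < ir1; ++i) {
                sum_thread += a[i] == b[i];
            }
            sums[ith] = sum_thread;             // publish the partial count

            sync.arrive_and_wait();             // all partial counts are visible past this point
            if (ith == 0) {                     // only thread 0 folds the partial sums
                for (int t = 0; t < nth; ++t) {
                    total += sums[t];
                }
            }
        });
    }
    for (auto & w : workers) {
        w.join();
    }
    printf("equal elements: %lld\n", (long long) total);  // prints 5 for the toy inputs
    return 0;
}

For these inputs the program prints 5, which matches what ggml_count_equal would return for the same two I32 tensors.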