Spaces:
Running
Running
cuda : non-cont concat support (llama/7610)
Browse files* tests : add non-cont concat tests
* cuda : non-cont concat support
ggml-ci
- ggml-cuda/concat.cu +88 -22
ggml-cuda/concat.cu
CHANGED
|
@@ -1,5 +1,6 @@
|
|
| 1 |
#include "concat.cuh"
|
| 2 |
|
|
|
|
| 3 |
static __global__ void concat_f32_dim0(const float * x, const float * y, float * dst, const int ne0, const int ne00) {
|
| 4 |
int nidx = threadIdx.x + blockIdx.x * blockDim.x;
|
| 5 |
if (nidx >= ne0) {
|
|
@@ -92,39 +93,104 @@ static void concat_f32_cuda(const float * x, const float * y, float * dst, int n
|
|
| 92 |
concat_f32_dim2<<<gridDim, CUDA_CONCAT_BLOCK_SIZE, 0, stream>>>(x, y, dst, ne0, ne02);
|
| 93 |
}
|
| 94 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 95 |
void ggml_cuda_op_concat(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
| 96 |
const ggml_tensor * src0 = dst->src[0];
|
| 97 |
const ggml_tensor * src1 = dst->src[1];
|
| 98 |
|
| 99 |
-
const float * src0_d = (const float *)src0->data;
|
| 100 |
-
const float * src1_d = (const float *)src1->data;
|
| 101 |
-
|
| 102 |
-
float * dst_d = (float *)dst->data;
|
| 103 |
cudaStream_t stream = ctx.stream();
|
| 104 |
|
| 105 |
const int32_t dim = ((int32_t *) dst->op_params)[0];
|
| 106 |
|
| 107 |
-
GGML_ASSERT(ggml_is_contiguous(src0));
|
| 108 |
-
GGML_ASSERT(ggml_is_contiguous(src1));
|
| 109 |
-
|
| 110 |
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
| 111 |
GGML_ASSERT(src1->type == GGML_TYPE_F32);
|
| 112 |
-
GGML_ASSERT(dst->type
|
| 113 |
-
|
| 114 |
-
if (
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 122 |
}
|
| 123 |
} else {
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 129 |
}
|
| 130 |
}
|
|
|
|
| 1 |
#include "concat.cuh"
|
| 2 |
|
| 3 |
+
// contiguous kernels
|
| 4 |
static __global__ void concat_f32_dim0(const float * x, const float * y, float * dst, const int ne0, const int ne00) {
|
| 5 |
int nidx = threadIdx.x + blockIdx.x * blockDim.x;
|
| 6 |
if (nidx >= ne0) {
|
|
|
|
| 93 |
concat_f32_dim2<<<gridDim, CUDA_CONCAT_BLOCK_SIZE, 0, stream>>>(x, y, dst, ne0, ne02);
|
| 94 |
}
|
| 95 |
|
| 96 |
+
// non-contiguous kernel (slow)
|
| 97 |
+
static __global__ void concat_f32_non_cont(
|
| 98 |
+
const char * src0,
|
| 99 |
+
const char * src1,
|
| 100 |
+
char * dst,
|
| 101 |
+
int64_t ne00,
|
| 102 |
+
int64_t ne01,
|
| 103 |
+
int64_t ne02,
|
| 104 |
+
int64_t ne03,
|
| 105 |
+
uint64_t nb00,
|
| 106 |
+
uint64_t nb01,
|
| 107 |
+
uint64_t nb02,
|
| 108 |
+
uint64_t nb03,
|
| 109 |
+
int64_t /*ne10*/,
|
| 110 |
+
int64_t /*ne11*/,
|
| 111 |
+
int64_t /*ne12*/,
|
| 112 |
+
int64_t /*ne13*/,
|
| 113 |
+
uint64_t nb10,
|
| 114 |
+
uint64_t nb11,
|
| 115 |
+
uint64_t nb12,
|
| 116 |
+
uint64_t nb13,
|
| 117 |
+
int64_t ne0,
|
| 118 |
+
int64_t /*ne1*/,
|
| 119 |
+
int64_t /*ne2*/,
|
| 120 |
+
int64_t /*ne3*/,
|
| 121 |
+
uint64_t nb0,
|
| 122 |
+
uint64_t nb1,
|
| 123 |
+
uint64_t nb2,
|
| 124 |
+
uint64_t nb3,
|
| 125 |
+
int32_t dim) {
|
| 126 |
+
const int64_t i3 = blockIdx.z;
|
| 127 |
+
const int64_t i2 = blockIdx.y;
|
| 128 |
+
const int64_t i1 = blockIdx.x;
|
| 129 |
+
|
| 130 |
+
int64_t o[4] = {0, 0, 0, 0};
|
| 131 |
+
o[dim] = dim == 0 ? ne00 : (dim == 1 ? ne01 : (dim == 2 ? ne02 : ne03));
|
| 132 |
+
|
| 133 |
+
const float * x;
|
| 134 |
+
|
| 135 |
+
for (int i0 = threadIdx.x; i0 < ne0; i0 += blockDim.x) {
|
| 136 |
+
if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) {
|
| 137 |
+
x = (const float *)(src0 + (i3 )*nb03 + (i2 )*nb02 + (i1 )*nb01 + (i0 )*nb00);
|
| 138 |
+
} else {
|
| 139 |
+
x = (const float *)(src1 + (i3 - o[3])*nb13 + (i2 - o[2])*nb12 + (i1 - o[1])*nb11 + (i0 - o[0])*nb10);
|
| 140 |
+
}
|
| 141 |
+
|
| 142 |
+
float * y = (float *)(dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
| 143 |
+
|
| 144 |
+
*y = *x;
|
| 145 |
+
}
|
| 146 |
+
}
|
| 147 |
+
|
| 148 |
+
|
| 149 |
void ggml_cuda_op_concat(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
| 150 |
const ggml_tensor * src0 = dst->src[0];
|
| 151 |
const ggml_tensor * src1 = dst->src[1];
|
| 152 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 153 |
cudaStream_t stream = ctx.stream();
|
| 154 |
|
| 155 |
const int32_t dim = ((int32_t *) dst->op_params)[0];
|
| 156 |
|
|
|
|
|
|
|
|
|
|
| 157 |
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
| 158 |
GGML_ASSERT(src1->type == GGML_TYPE_F32);
|
| 159 |
+
GGML_ASSERT(dst->type == GGML_TYPE_F32);
|
| 160 |
+
|
| 161 |
+
if (ggml_is_contiguous(src0) && ggml_is_contiguous(src1)) {
|
| 162 |
+
const float * src0_d = (const float *)src0->data;
|
| 163 |
+
const float * src1_d = (const float *)src1->data;
|
| 164 |
+
|
| 165 |
+
float * dst_d = (float *)dst->data;
|
| 166 |
+
|
| 167 |
+
if (dim != 3) {
|
| 168 |
+
for (int i3 = 0; i3 < dst->ne[3]; i3++) {
|
| 169 |
+
concat_f32_cuda(
|
| 170 |
+
src0_d + i3 * (src0->nb[3] / 4),
|
| 171 |
+
src1_d + i3 * (src1->nb[3] / 4),
|
| 172 |
+
dst_d + i3 * ( dst->nb[3] / 4),
|
| 173 |
+
src0->ne[0], src0->ne[1], src0->ne[2],
|
| 174 |
+
dst->ne[0], dst->ne[1], dst->ne[2], dim, stream);
|
| 175 |
+
}
|
| 176 |
+
} else {
|
| 177 |
+
const size_t size0 = ggml_nbytes(src0);
|
| 178 |
+
const size_t size1 = ggml_nbytes(src1);
|
| 179 |
+
|
| 180 |
+
CUDA_CHECK(cudaMemcpyAsync(dst_d, src0_d, size0, cudaMemcpyDeviceToDevice, stream));
|
| 181 |
+
CUDA_CHECK(cudaMemcpyAsync(dst_d + size0/4, src1_d, size1, cudaMemcpyDeviceToDevice, stream));
|
| 182 |
}
|
| 183 |
} else {
|
| 184 |
+
dim3 grid_dim(dst->ne[1], dst->ne[2], dst->ne[3]);
|
| 185 |
+
concat_f32_non_cont<<<grid_dim, CUDA_CONCAT_BLOCK_SIZE, 0, stream>>>(
|
| 186 |
+
(const char *)src0->data,
|
| 187 |
+
(const char *)src1->data,
|
| 188 |
+
( char *)dst->data,
|
| 189 |
+
src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
|
| 190 |
+
src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],
|
| 191 |
+
src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3],
|
| 192 |
+
src1->nb[0], src1->nb[1], src1->nb[2], src1->nb[3],
|
| 193 |
+
dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3],
|
| 194 |
+
dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3], dim);
|
| 195 |
}
|
| 196 |
}
|