Spaces:
Running
Running
AidanBeltonS
commited on
Commit
·
3984ba6
1
Parent(s):
8325ed5
Update SYCL upscale operation (llama/7321)
Browse files* Update SYCL upscale operation
* Formatting
* Remove messages
- ggml-sycl.cpp +35 -30
ggml-sycl.cpp
CHANGED
|
@@ -3847,21 +3847,27 @@ static void concat_f32(const float *x,const float *y, float *dst, const int ne
|
|
| 3847 |
}
|
| 3848 |
}
|
| 3849 |
|
| 3850 |
-
static void upscale_f32(const float *x, float *dst, const int
|
| 3851 |
-
const
|
| 3852 |
-
|
| 3853 |
-
|
| 3854 |
-
|
| 3855 |
-
|
|
|
|
| 3856 |
return;
|
| 3857 |
}
|
| 3858 |
// operation
|
| 3859 |
-
int
|
| 3860 |
-
int
|
| 3861 |
-
int
|
| 3862 |
-
int
|
| 3863 |
-
|
| 3864 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3865 |
}
|
| 3866 |
|
| 3867 |
static void pad_f32(const float *x, float *dst, const int ne0, const int ne00, const int ne01, const int ne02,
|
|
@@ -10085,18 +10091,17 @@ static void concat_f32_sycl(const float *x, const float *y, float *dst,
|
|
| 10085 |
});
|
| 10086 |
}
|
| 10087 |
|
| 10088 |
-
static void upscale_f32_sycl(const float *x, float *dst, const int
|
| 10089 |
-
const int
|
| 10090 |
-
const int
|
| 10091 |
-
|
| 10092 |
-
int
|
| 10093 |
-
|
|
|
|
| 10094 |
stream->parallel_for(
|
| 10095 |
-
sycl::nd_range<
|
| 10096 |
-
|
| 10097 |
-
|
| 10098 |
-
[=](sycl::nd_item<3> item_ct1) {
|
| 10099 |
-
upscale_f32(x, dst, ne00, ne00 * ne01, scale_factor, item_ct1);
|
| 10100 |
});
|
| 10101 |
}
|
| 10102 |
|
|
@@ -13985,15 +13990,15 @@ inline void ggml_sycl_op_upscale(const ggml_tensor *src0,
|
|
| 13985 |
|
| 13986 |
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
| 13987 |
GGML_ASSERT(dst->type == GGML_TYPE_F32);
|
| 13988 |
-
GGML_ASSERT(src0->ne[3] == 1 && dst->ne[3] == 1); // just 3D tensors
|
| 13989 |
-
|
| 13990 |
-
#pragma message("TODO: generalize upscale operator")
|
| 13991 |
-
#pragma message(" https://github.com/ggerganov/ggml/pull/814")
|
| 13992 |
-
GGML_ASSERT(false && "TODO: generalize upscale operator");
|
| 13993 |
|
| 13994 |
-
const
|
|
|
|
|
|
|
|
|
|
| 13995 |
|
| 13996 |
-
upscale_f32_sycl(src0_dd, dst_dd, src0->
|
|
|
|
|
|
|
| 13997 |
|
| 13998 |
(void) src1;
|
| 13999 |
(void) dst;
|
|
|
|
| 3847 |
}
|
| 3848 |
}
|
| 3849 |
|
| 3850 |
+
static void upscale_f32(const float *x, float *dst, const int nb00, const int nb01,
|
| 3851 |
+
const int nb02, const int nb03, const int ne10, const int ne11,
|
| 3852 |
+
const int ne12, const int ne13, const float sf0, const float sf1,
|
| 3853 |
+
const float sf2, const float sf3, const sycl::nd_item<1> &item_ct1) {
|
| 3854 |
+
int index = item_ct1.get_local_id(0) +
|
| 3855 |
+
item_ct1.get_group(0) * item_ct1.get_local_range(0);
|
| 3856 |
+
if (index >= ne10 * ne11 * ne12 * ne13) {
|
| 3857 |
return;
|
| 3858 |
}
|
| 3859 |
// operation
|
| 3860 |
+
int i10 = index % ne10;
|
| 3861 |
+
int i11 = (index / ne10) % ne11;
|
| 3862 |
+
int i12 = (index / (ne10 * ne11)) % ne12;
|
| 3863 |
+
int i13 = (index / (ne10 * ne11 * ne12)) % ne13;
|
| 3864 |
+
|
| 3865 |
+
int i00 = i10 / sf0;
|
| 3866 |
+
int i01 = i11 / sf1;
|
| 3867 |
+
int i02 = i12 / sf2;
|
| 3868 |
+
int i03 = i13 / sf3;
|
| 3869 |
+
|
| 3870 |
+
dst[index] = *(float *)((char *)x + i03 * nb03 + i02 * nb02 + i01 * nb01 + i00 * nb00);
|
| 3871 |
}
|
| 3872 |
|
| 3873 |
static void pad_f32(const float *x, float *dst, const int ne0, const int ne00, const int ne01, const int ne02,
|
|
|
|
| 10091 |
});
|
| 10092 |
}
|
| 10093 |
|
| 10094 |
+
static void upscale_f32_sycl(const float *x, float *dst, const int nb00, const int nb01,
|
| 10095 |
+
const int nb02, const int nb03, const int ne10, const int ne11,
|
| 10096 |
+
const int ne12, const int ne13, const float sf0, const float sf1,
|
| 10097 |
+
const float sf2, const float sf3, dpct::queue_ptr stream) {
|
| 10098 |
+
int dst_size = ne10 * ne11 * ne12 * ne13;
|
| 10099 |
+
int num_blocks = (dst_size + SYCL_UPSCALE_BLOCK_SIZE - 1) / SYCL_UPSCALE_BLOCK_SIZE;
|
| 10100 |
+
sycl::range<1> gridDim(num_blocks * SYCL_UPSCALE_BLOCK_SIZE);
|
| 10101 |
stream->parallel_for(
|
| 10102 |
+
sycl::nd_range<1>(gridDim, sycl::range<1>(SYCL_UPSCALE_BLOCK_SIZE)),
|
| 10103 |
+
[=](sycl::nd_item<1> item_ct1) {
|
| 10104 |
+
upscale_f32(x, dst, nb00, nb01, nb02, nb03, ne10, ne11, ne12, ne13, sf0, sf1, sf2, sf3, item_ct1);
|
|
|
|
|
|
|
| 10105 |
});
|
| 10106 |
}
|
| 10107 |
|
|
|
|
| 13990 |
|
| 13991 |
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
| 13992 |
GGML_ASSERT(dst->type == GGML_TYPE_F32);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 13993 |
|
| 13994 |
+
const float sf0 = (float)dst->ne[0]/src0->ne[0];
|
| 13995 |
+
const float sf1 = (float)dst->ne[1]/src0->ne[1];
|
| 13996 |
+
const float sf2 = (float)dst->ne[2]/src0->ne[2];
|
| 13997 |
+
const float sf3 = (float)dst->ne[3]/src0->ne[3];
|
| 13998 |
|
| 13999 |
+
upscale_f32_sycl(src0_dd, dst_dd, src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],
|
| 14000 |
+
dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], sf0, sf1, sf2, sf3,
|
| 14001 |
+
main_stream);
|
| 14002 |
|
| 14003 |
(void) src1;
|
| 14004 |
(void) dst;
|