Srihari-mcw committed on
Commit
cf52931
·
1 Parent(s): 2bfeba3

Add provisions for windows support for BF16 code including CMake provision for enabling AVX512_BF16 (llama/7258)

Browse files
Files changed (3) hide show
  1. ggml-impl.h +12 -0
  2. ggml.c +16 -8
  3. ggml.h +1 -0
ggml-impl.h CHANGED
@@ -17,6 +17,18 @@
17
  #define MIN(a, b) ((a) < (b) ? (a) : (b))
18
  #define MAX(a, b) ((a) > (b) ? (a) : (b))
19
 
 
 
 
 
 
 
 
 
 
 
 
 
20
  /**
21
  * Converts brain16 to float32.
22
  *
 
17
  #define MIN(a, b) ((a) < (b) ? (a) : (b))
18
  #define MAX(a, b) ((a) > (b) ? (a) : (b))
19
 
20
+ #if defined(_WIN32)
21
+
22
+ #define m512bh(p) p
23
+ #define m512i(p) p
24
+
25
+ #else
26
+
27
+ #define m512bh(p) (__m512bh)(p)
28
+ #define m512i(p) (__m512i)(p)
29
+
30
+ #endif
31
+
32
  /**
33
  * Converts brain16 to float32.
34
  *
ggml.c CHANGED
@@ -406,10 +406,10 @@ void ggml_fp32_to_bf16_row(const float * x, ggml_bf16_t * y, int64_t n) {
406
  int i = 0;
407
  #if defined(__AVX512BF16__)
408
  for (; i + 32 <= n; i += 32) {
409
- _mm512_storeu_ps(
410
- (__m512 *)(y + i),
411
- (__m512)_mm512_cvtne2ps_pbh(_mm512_loadu_ps(x + i + 16),
412
- _mm512_loadu_ps(x + i)));
413
  }
414
  #endif
415
  for (; i < n; i++) {
@@ -1666,10 +1666,10 @@ static void ggml_vec_dot_bf16(int n, float * restrict s, size_t bs, ggml_bf16_t
1666
  __m512 c1 = _mm512_setzero_ps();
1667
  __m512 c2 = _mm512_setzero_ps();
1668
  for (; i + 64 <= n; i += 64) {
1669
- c1 = _mm512_dpbf16_ps(c1, (__m512bh)_mm512_loadu_ps((const float *)(x + i)),
1670
- (__m512bh)_mm512_loadu_ps((const float *)(y + i)));
1671
- c2 = _mm512_dpbf16_ps(c2, (__m512bh)_mm512_loadu_ps((const float *)(x + i + 32)),
1672
- (__m512bh)_mm512_loadu_ps((const float *)(y + i + 32)));
1673
  }
1674
  sumf += (ggml_float)_mm512_reduce_add_ps(c1);
1675
  sumf += (ggml_float)_mm512_reduce_add_ps(c2);
@@ -23137,6 +23137,14 @@ int ggml_cpu_has_avx512_vnni(void) {
23137
  #endif
23138
  }
23139
 
 
 
 
 
 
 
 
 
23140
  int ggml_cpu_has_fma(void) {
23141
  #if defined(__FMA__)
23142
  return 1;
 
406
  int i = 0;
407
  #if defined(__AVX512BF16__)
408
  for (; i + 32 <= n; i += 32) {
409
+ _mm512_storeu_si512(
410
+ (__m512i *)(y + i),
411
+ m512i(_mm512_cvtne2ps_pbh(_mm512_loadu_ps(x + i + 16),
412
+ _mm512_loadu_ps(x + i))));
413
  }
414
  #endif
415
  for (; i < n; i++) {
 
1666
  __m512 c1 = _mm512_setzero_ps();
1667
  __m512 c2 = _mm512_setzero_ps();
1668
  for (; i + 64 <= n; i += 64) {
1669
+ c1 = _mm512_dpbf16_ps(c1, m512bh(_mm512_loadu_si512((x + i))),
1670
+ m512bh(_mm512_loadu_si512((y + i))));
1671
+ c2 = _mm512_dpbf16_ps(c2, m512bh(_mm512_loadu_si512((x + i + 32))),
1672
+ m512bh(_mm512_loadu_si512((y + i + 32))));
1673
  }
1674
  sumf += (ggml_float)_mm512_reduce_add_ps(c1);
1675
  sumf += (ggml_float)_mm512_reduce_add_ps(c2);
 
23137
  #endif
23138
  }
23139
 
23140
+ int ggml_cpu_has_avx512_bf16(void) {
23141
+ #if defined(__AVX512BF16__)
23142
+ return 1;
23143
+ #else
23144
+ return 0;
23145
+ #endif
23146
+ }
23147
+
23148
  int ggml_cpu_has_fma(void) {
23149
  #if defined(__FMA__)
23150
  return 1;
ggml.h CHANGED
@@ -2390,6 +2390,7 @@ extern "C" {
2390
  GGML_API int ggml_cpu_has_avx512 (void);
2391
  GGML_API int ggml_cpu_has_avx512_vbmi(void);
2392
  GGML_API int ggml_cpu_has_avx512_vnni(void);
 
2393
  GGML_API int ggml_cpu_has_fma (void);
2394
  GGML_API int ggml_cpu_has_neon (void);
2395
  GGML_API int ggml_cpu_has_arm_fma (void);
 
2390
  GGML_API int ggml_cpu_has_avx512 (void);
2391
  GGML_API int ggml_cpu_has_avx512_vbmi(void);
2392
  GGML_API int ggml_cpu_has_avx512_vnni(void);
2393
+ GGML_API int ggml_cpu_has_avx512_bf16(void);
2394
  GGML_API int ggml_cpu_has_fma (void);
2395
  GGML_API int ggml_cpu_has_neon (void);
2396
  GGML_API int ggml_cpu_has_arm_fma (void);