Engininja2 committed on
Commit
aadbd67
·
unverified ·
1 Parent(s): 20a4ca1

cuda : fix 2-bit quants on amd hip (llama/5105)

Browse files

* cuda : fix 2-bit quants on amd hip

* use __low2float intrinsic function for new quants

Files changed (1) hide show
  1. ggml-cuda.cu +3 -3
ggml-cuda.cu CHANGED
@@ -4283,7 +4283,7 @@ static __device__ __forceinline__ float vec_dot_iq2_xxs_q8_1(
4283
  q8 += 8;
4284
  aux32 >>= 7;
4285
  }
4286
- const float d = (float)bq2->d * (0.5f + aux32) * (float)bq8_1[ib32].ds.x * 0.25f;
4287
  return d * sumi;
4288
  #else
4289
  // iqs is 0...15
@@ -4294,7 +4294,7 @@ static __device__ __forceinline__ float vec_dot_iq2_xxs_q8_1(
4294
  const uint8_t * grid1 = (const uint8_t *)(iq2xxs_grid + aux8[2*il+0]);
4295
  const uint8_t * grid2 = (const uint8_t *)(iq2xxs_grid + aux8[2*il+1]);
4296
  const uint32_t aux32 = q2[2] | (q2[3] << 16);
4297
- const float d = (float)bq2->d * (0.5f + (aux32 >> 28)) * (float)bq8_1[ib32].ds.x * 0.25f;
4298
  const uint8_t signs1 = ksigns_iq2xs[(aux32 >> 14*il) & 127];
4299
  const uint8_t signs2 = ksigns_iq2xs[(aux32 >> (14*il + 7)) & 127];
4300
  const int8_t * q8 = bq8_1[ib32].qs + 16*il;
@@ -4339,7 +4339,7 @@ static __device__ __forceinline__ float vec_dot_iq2_xs_q8_1(
4339
  }
4340
  q8 += 8;
4341
  }
4342
- const float d = (float)bq2->d * (float)bq8_1[ib32].ds.x * 0.25f;
4343
  return d * ((0.5f + ls1) * sumi1 + (0.5f + ls2) * sumi2);
4344
  #else
4345
  assert(false);
 
4283
  q8 += 8;
4284
  aux32 >>= 7;
4285
  }
4286
+ const float d = (float)bq2->d * (0.5f + aux32) * __low2float(bq8_1[ib32].ds) * 0.25f;
4287
  return d * sumi;
4288
  #else
4289
  // iqs is 0...15
 
4294
  const uint8_t * grid1 = (const uint8_t *)(iq2xxs_grid + aux8[2*il+0]);
4295
  const uint8_t * grid2 = (const uint8_t *)(iq2xxs_grid + aux8[2*il+1]);
4296
  const uint32_t aux32 = q2[2] | (q2[3] << 16);
4297
+ const float d = (float)bq2->d * (0.5f + (aux32 >> 28)) * __low2float(bq8_1[ib32].ds) * 0.25f;
4298
  const uint8_t signs1 = ksigns_iq2xs[(aux32 >> 14*il) & 127];
4299
  const uint8_t signs2 = ksigns_iq2xs[(aux32 >> (14*il + 7)) & 127];
4300
  const int8_t * q8 = bq8_1[ib32].qs + 16*il;
 
4339
  }
4340
  q8 += 8;
4341
  }
4342
+ const float d = (float)bq2->d * __low2float(bq8_1[ib32].ds) * 0.25f;
4343
  return d * ((0.5f + ls1) * sumi1 + (0.5f + ls2) * sumi2);
4344
  #else
4345
  assert(false);