Spaces:
Running
Running
Engininja2
committed on
cuda : fix 2-bit quants on amd hip (llama/5105)
Browse files
* cuda : fix 2-bit quants on amd hip
* use __low2float intrinsic function for new quants
- ggml-cuda.cu +3 -3
ggml-cuda.cu
CHANGED
|
@@ -4283,7 +4283,7 @@ static __device__ __forceinline__ float vec_dot_iq2_xxs_q8_1(
|
|
| 4283 |
q8 += 8;
|
| 4284 |
aux32 >>= 7;
|
| 4285 |
}
|
| 4286 |
-
const float d = (float)bq2->d * (0.5f + aux32) * (
|
| 4287 |
return d * sumi;
|
| 4288 |
#else
|
| 4289 |
// iqs is 0...15
|
|
@@ -4294,7 +4294,7 @@ static __device__ __forceinline__ float vec_dot_iq2_xxs_q8_1(
|
|
| 4294 |
const uint8_t * grid1 = (const uint8_t *)(iq2xxs_grid + aux8[2*il+0]);
|
| 4295 |
const uint8_t * grid2 = (const uint8_t *)(iq2xxs_grid + aux8[2*il+1]);
|
| 4296 |
const uint32_t aux32 = q2[2] | (q2[3] << 16);
|
| 4297 |
-
const float d = (float)bq2->d * (0.5f + (aux32 >> 28)) * (
|
| 4298 |
const uint8_t signs1 = ksigns_iq2xs[(aux32 >> 14*il) & 127];
|
| 4299 |
const uint8_t signs2 = ksigns_iq2xs[(aux32 >> (14*il + 7)) & 127];
|
| 4300 |
const int8_t * q8 = bq8_1[ib32].qs + 16*il;
|
|
@@ -4339,7 +4339,7 @@ static __device__ __forceinline__ float vec_dot_iq2_xs_q8_1(
|
|
| 4339 |
}
|
| 4340 |
q8 += 8;
|
| 4341 |
}
|
| 4342 |
-
const float d = (float)bq2->d * (
|
| 4343 |
return d * ((0.5f + ls1) * sumi1 + (0.5f + ls2) * sumi2);
|
| 4344 |
#else
|
| 4345 |
assert(false);
|
|
|
|
| 4283 |
q8 += 8;
|
| 4284 |
aux32 >>= 7;
|
| 4285 |
}
|
| 4286 |
+
const float d = (float)bq2->d * (0.5f + aux32) * __low2float(bq8_1[ib32].ds) * 0.25f;
|
| 4287 |
return d * sumi;
|
| 4288 |
#else
|
| 4289 |
// iqs is 0...15
|
|
|
|
| 4294 |
const uint8_t * grid1 = (const uint8_t *)(iq2xxs_grid + aux8[2*il+0]);
|
| 4295 |
const uint8_t * grid2 = (const uint8_t *)(iq2xxs_grid + aux8[2*il+1]);
|
| 4296 |
const uint32_t aux32 = q2[2] | (q2[3] << 16);
|
| 4297 |
+
const float d = (float)bq2->d * (0.5f + (aux32 >> 28)) * __low2float(bq8_1[ib32].ds) * 0.25f;
|
| 4298 |
const uint8_t signs1 = ksigns_iq2xs[(aux32 >> 14*il) & 127];
|
| 4299 |
const uint8_t signs2 = ksigns_iq2xs[(aux32 >> (14*il + 7)) & 127];
|
| 4300 |
const int8_t * q8 = bq8_1[ib32].qs + 16*il;
|
|
|
|
| 4339 |
}
|
| 4340 |
q8 += 8;
|
| 4341 |
}
|
| 4342 |
+
const float d = (float)bq2->d * __low2float(bq8_1[ib32].ds) * 0.25f;
|
| 4343 |
return d * ((0.5f + ls1) * sumi1 + (0.5f + ls2) * sumi2);
|
| 4344 |
#else
|
| 4345 |
assert(false);
|