|
|
import json |
|
|
import base64 |
|
|
from offline_utils import packed_bytes_to_pseudo |
|
|
|
|
|
def pack_compressed_spans(data, bits_per_compressed: int, compression_bit_threshold: int, compression_offset: int = 256):
    """
    Convert consecutive compressed values into larger integers.

    A "compressed span" is a maximal run of values >= ``compression_offset``.
    Each span is read as a byte stream (value - compression_offset) and cut into
    groups of ``ceil(compression_bit_threshold / 8)`` bytes. Each group is read
    big-endian and its low padding bits (those beyond
    ``compression_bit_threshold``) are dropped. The remaining
    ``compression_bit_threshold``-bit value is then split, most significant bits
    first, into ``bits_per_compressed``-bit integers, which are re-offset by
    ``compression_offset``. Values below ``compression_offset`` are copied
    through unchanged.

    The dropped padding bits are not stored anywhere. ``unpack_compressed_spans``
    refills them with zeros, so only zero-padded groups round-trip exactly.

    Args:
        data: List of integers where 0-255 are raw bytes and
            compression_offset..compression_offset+255 are compressed bytes.
        bits_per_compressed: Number of bits in each packed output value.
        compression_bit_threshold: Number of meaningful bits in each group of
            compressed bytes. Must be a positive multiple of bits_per_compressed.
        compression_offset: Offset that marks the start of compressed values
            (default 256).

    Returns:
        A new list with every compressed span replaced by its packed values,
        or ``[]`` when data is empty.

    Raises:
        AssertionError: If a bit width is not positive, the threshold is not a
            multiple of bits_per_compressed, or a span's length is not a whole
            number of groups. These are raised explicitly instead of via
            ``assert`` so the checks still run under ``python -O``.
        ValueError: If a compressed value minus compression_offset exceeds 255.
    """
    if not data:
        return []

    # Non-positive widths would otherwise divide by zero or never drain the buffer.
    if bits_per_compressed <= 0 or compression_bit_threshold <= 0:
        raise AssertionError("bits_per_compressed and compression_bit_threshold must be positive")
    if compression_bit_threshold % bits_per_compressed != 0:
        raise AssertionError("compression_bit_threshold must be divisible by bits_per_compressed")

    packing_mask = (1 << bits_per_compressed) - 1
    # Each compressed value occupies whole bytes; the unused low bits are padding.
    group_bytes = (compression_bit_threshold + 7) // 8
    padding_bits = group_bytes * 8 - compression_bit_threshold
    # Right-shift amounts that peel off bits_per_compressed-bit chunks, MSB first.
    chunk_shifts = range(compression_bit_threshold - bits_per_compressed, -1, -bits_per_compressed)

    result = []
    i = 0
    n = len(data)
    while i < n:
        if data[i] < compression_offset:
            # Raw byte: pass through untouched.
            result.append(data[i])
            i += 1
            continue

        span_start = i
        while i < n and data[i] >= compression_offset:
            i += 1
        span = [x - compression_offset for x in data[span_start:i]]

        # A value wider than one byte would bleed into its neighbours' bits.
        if any(b > 0xFF for b in span):
            raise ValueError(
                f"compressed span starting at index {span_start} contains a value "
                f"above compression_offset + 255"
            )
        # A trailing partial group would otherwise be silently discarded.
        if len(span) % group_bytes != 0:
            raise AssertionError(
                f"compressed span starting at index {span_start} has {len(span)} bytes, "
                f"not a multiple of {group_bytes}"
            )

        span_bytes = bytes(span)
        for start in range(0, len(span_bytes), group_bytes):
            group = int.from_bytes(span_bytes[start:start + group_bytes], "big")
            value = group >> padding_bits  # exactly compression_bit_threshold bits
            result.extend(((value >> shift) & packing_mask) + compression_offset for shift in chunk_shifts)

    return result
|
|
|
|
|
def unpack_compressed_spans(packed_data, bits_per_compressed: int, compression_bit_threshold: int, compression_offset: int = 256):
    """
    Reverse operation: unpack larger integers back to consecutive compressed bytes.

    Each maximal run of values >= ``compression_offset`` is read as a bit stream
    of ``bits_per_compressed``-bit chunks (value - compression_offset), MSB
    first. Every ``compression_bit_threshold`` bits form one value. That value is
    left-shifted back into whole bytes, with the padding bits filled with zeros,
    and emitted big-endian with each byte re-offset by ``compression_offset``.
    Values below ``compression_offset`` are copied through unchanged.

    The bit stream is consumed generically, so compression_bit_threshold does not
    have to be a multiple of bits_per_compressed. Each span only needs to carry a
    whole number of values.

    Args:
        packed_data: List with packed compressed spans.
        bits_per_compressed: Number of bits used for each packed value.
        compression_bit_threshold: Number of bits each compressed value actually uses.
        compression_offset: Offset used for compressed values (default 256).

    Returns:
        The original format with consecutive compressed bytes, or ``[]`` when
        packed_data is empty.

    Raises:
        AssertionError: If a bit width is not positive, or a span leaves
            leftover bits that do not form a whole value. These are raised
            explicitly instead of via ``assert`` so they still run under
            ``python -O``.
        ValueError: If a packed value minus compression_offset does not fit in
            bits_per_compressed bits.
    """
    if not packed_data:
        return []

    # A zero threshold never drains the buffer (infinite loop). A zero chunk
    # width would silently drop every span.
    if bits_per_compressed <= 0 or compression_bit_threshold <= 0:
        raise AssertionError("bits_per_compressed and compression_bit_threshold must be positive")

    # Loop invariants, hoisted out of the per-value work.
    chunk_limit = 1 << bits_per_compressed
    compression_mask = (1 << compression_bit_threshold) - 1
    padded_bytes = (compression_bit_threshold + 7) // 8
    padding_bits = padded_bytes * 8 - compression_bit_threshold

    result = []
    i = 0
    n = len(packed_data)
    while i < n:
        if packed_data[i] < compression_offset:
            # Raw byte: pass through untouched.
            result.append(packed_data[i])
            i += 1
            continue

        span_start = i
        while i < n and packed_data[i] >= compression_offset:
            i += 1

        bit_buffer = 0
        bits_in_buffer = 0
        for val in packed_data[span_start:i]:
            chunk = val - compression_offset
            # An over-wide chunk would overwrite bits belonging to its neighbour.
            if chunk >= chunk_limit:
                raise ValueError(
                    f"packed value {val} does not fit in {bits_per_compressed} bits "
                    f"after subtracting compression_offset"
                )
            bit_buffer = (bit_buffer << bits_per_compressed) | chunk
            bits_in_buffer += bits_per_compressed

            while bits_in_buffer >= compression_bit_threshold:
                # Take the oldest compression_bit_threshold bits; keep the rest buffered.
                bits_in_buffer -= compression_bit_threshold
                compressed_val = (bit_buffer >> bits_in_buffer) & compression_mask
                bit_buffer &= (1 << bits_in_buffer) - 1

                # Restore the zero padding so the value spans whole bytes again.
                padded_val = compressed_val << padding_bits
                result.extend(b + compression_offset for b in padded_val.to_bytes(padded_bytes, "big"))

        if bits_in_buffer != 0:
            raise AssertionError(
                f"compressed span starting at index {span_start} leaves {bits_in_buffer} "
                f"unused bits after unpacking"
            )

    return result
|
|
|
|
|
def run_test_case(test_name: str, data: list, bits_per_compressed: int, compression_bit_threshold: int, compression_offset: int = 256):
    """Pack and unpack ``data``, print every stage, and report whether it round-trips.

    Args:
        test_name: Label printed as the case header.
        data: Input list for pack_compressed_spans.
        bits_per_compressed: Forwarded to pack/unpack.
        compression_bit_threshold: Forwarded to pack/unpack.
        compression_offset: Forwarded to pack/unpack. It is also used to count
            compressed values in the stats line (default 256).

    Returns:
        True if unpacking the packed data reproduces ``data`` exactly. False on a
        mismatch or if packing/unpacking raised any exception; the exception is
        printed, not propagated.
    """
    print(f"🧪 {test_name}")
    print(f" Original: {data}")
    print(f" Config: bits_per_compressed={bits_per_compressed}, compression_bit_threshold={compression_bit_threshold}")

    try:
        packed = pack_compressed_spans(data, bits_per_compressed, compression_bit_threshold, compression_offset)
        print(f" Packed: {packed}")

        unpacked = unpack_compressed_spans(packed, bits_per_compressed, compression_bit_threshold, compression_offset)
        print(f" Unpacked: {unpacked}")

        success = data == unpacked
        print(f" Result: {'✅ PASS' if success else '❌ FAIL'}")

        original_compressed = sum(1 for x in data if x >= compression_offset)
        packed_compressed = sum(1 for x in packed if x >= compression_offset)
        if original_compressed > 0:
            ratio = original_compressed / packed_compressed if packed_compressed > 0 else 0
            print(f" Stats: {original_compressed} → {packed_compressed} compressed values ({ratio:.2f}x)")

        return success

    except Exception as e:  # harness boundary: report the failure and let the suite continue
        print(f" Result: ❌ ERROR: {type(e).__name__}: {e}")
        return False
|
|
|
|
|
|
|
|
def test_packing_comprehensive(seed=None):
    """Comprehensive test suite for packing functions.

    Every case is passed to ``run_test_case``, which packs and unpacks it and
    checks that the original data comes back unchanged.

    Args:
        seed: Optional seed for the random input generator so a failing run can
            be reproduced. ``None`` (the default) leaves it unseeded, as before.

    Returns:
        Tuple ``(passed, total)``: number of passing cases and number of cases run.
    """
    import random

    from m1_compression import utils

    rng = random.Random(seed)

    def random_bytes_generator(n: int, bit_threshold: int):
        """Return between n // 2 and n random compressed groups, each followed by one raw byte.

        Each group is a random ``bit_threshold``-bit string converted to bytes by
        ``utils.bits_to_bytes_padding_to_threshold``. That conversion presumably
        left-aligns the bits with zero padding (confirm against
        m1_compression.utils). Each resulting byte is offset by 256.
        """
        ret = []
        for _group in range(rng.randint(n // 2, n)):
            bits = "".join("0" if rng.random() < 0.5 else "1" for _ in range(bit_threshold))
            compressed_bytes = utils.bits_to_bytes_padding_to_threshold(bits, bit_threshold)[0]
            ret.extend(c + 256 for c in compressed_bytes)
            # A raw byte (< 256) ends the compressed span, so groups never merge.
            ret.append(rng.randint(0, 255))
        return ret

    print("=" * 60)
    print("🚀 COMPREHENSIVE PACKING TESTS")
    print("=" * 60)

    # (name, data, bits_per_compressed, compression_bit_threshold)
    cases = [
        ("Basic 16-bit packing (no padding)", random_bytes_generator(100, 16), 16, 16),
        ("12-bit values with 4-bit padding", random_bytes_generator(100, 12), 12, 12),
        ("20-bit values with 4-bit padding", random_bytes_generator(100, 20), 20, 20),
        ("Single compressed byte", [100, 256, 200], 8, 8),
        ("No compressed bytes", [100, 200, 50, 150], 16, 16),
        ("All compressed bytes", [256, 257, 258, 259, 260, 261], 8, 8),
        ("24-bit to 12-bit packing (2:1 ratio)", random_bytes_generator(100, 24), 12, 24),
    ]

    test_results = []
    for name, case_data, bits_per_compressed, compression_bit_threshold in cases:
        test_results.append(run_test_case(
            name,
            case_data,
            bits_per_compressed=bits_per_compressed,
            compression_bit_threshold=compression_bit_threshold,
        ))
        print()

    passed = sum(test_results)
    total = len(test_results)
    print("=" * 60)
    print(f"📊 TEST SUMMARY: {passed}/{total} tests passed")
    print("=" * 60)

    if passed == total:
        print("🎉 All tests passed! The implementation is working correctly.")
    else:
        print("⚠️ Some tests failed. Please review the implementation.")

    return passed, total
|
|
|
|
|
def test_real_data(
    path="output_compress/m1.chunk.0_out_0_out_0_writer_0.jsonl",
    key="m1_ac_ow20_escapefb-False_iterative-True",
    bits_per_compressed=10,
    max_lines=6,
):
    """Round-trip the first few records of a compressor output file.

    Each JSON line's ``key`` field is base64-decoded and converted with
    ``packed_bytes_to_pseudo``. The result is passed to ``run_test_case``.

    Args:
        path: JSONL file to read. Defaults to the previously hard-coded file.
        key: Record field to test. Its ``ow<N>`` component gives the bit threshold.
        bits_per_compressed: Packing width passed to ``run_test_case``.
        max_lines: Maximum number of lines to process. The default of 6 matches
            the original loop, which broke only after line index 5.

    Raises:
        ValueError: If ``key`` has no ``ow<N>`` component.
        FileNotFoundError: If ``path`` does not exist.
    """
    print("=" * 40)
    print("🔧 REAL DATA TESTS")
    print("=" * 40)

    # The bit threshold is encoded in the key as an "ow<N>" component (e.g. "ow20").
    # Parse it once; it does not depend on the records.
    bit_threshold = None
    for part in key.split("_"):
        if part.startswith("ow"):
            bit_threshold = int(part[len("ow"):])
            break
    if bit_threshold is None:
        raise ValueError(f"no 'ow<bits>' component found in key {key!r}")
    print(f"Bit threshold: {bit_threshold}")

    with open(path, "r", encoding="utf-8") as f:
        for line_number, line in enumerate(f):
            if line_number >= max_lines:
                break
            if not line.strip():
                continue  # tolerate blank lines, which json.loads would reject

            record = json.loads(line)
            original_bytes_array = packed_bytes_to_pseudo(base64.b64decode(record[key]))

            run_test_case(
                f"Packing {bit_threshold}-bit values",
                original_bytes_array,
                bits_per_compressed,
                bit_threshold,
            )
|
|
|
|
|
|
|
|
def test_error_conditions():
    """Test error conditions and edge cases.

    Checks that pack_compressed_spans rejects a threshold that is not a multiple
    of the packing width, and that empty input packs to an empty list.
    """
    print("\n🔧 ERROR CONDITION TESTS")
    print("=" * 40)

    # 15 is not a multiple of 10, so packing must be refused with AssertionError.
    try:
        pack_compressed_spans([256, 257], 10, 15)
    except AssertionError:
        print("✅ Correctly caught invalid bit alignment")
    else:
        print("❌ Should have failed on invalid bit alignment")

    result = pack_compressed_spans([], 16, 16)
    empty_ok = result == []
    # Only show the success mark when the check actually passed.
    print(f"{'✅' if empty_ok else '❌'} Empty data handling: {empty_ok}")

    print()
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    import sys

    passed, total = test_packing_comprehensive()

    if passed == total:
        print("🎉 ALL TESTS COMPLETED SUCCESSFULLY!")
    else:
        print("💥 SOME TESTS FAILED!")

    # The real-data file comes from a separate compression run and may be absent.
    # Skip those tests rather than aborting before the error-condition checks.
    try:
        test_real_data()
    except FileNotFoundError as err:
        print(f"Skipping real data tests: {err}")

    test_error_conditions()

    # A non-zero exit status lets CI and shell callers detect failures.
    sys.exit(0 if passed == total else 1)
|
|
|