# Byte-lingua-code / offline_packing.py
# Author: 2ira
# Commit: 72c0672 "offline_compression_graph_code" (verified)
import json
import base64
from offline_utils import packed_bytes_to_pseudo
def pack_compressed_spans(data, bits_per_compressed: int, compression_bit_threshold: int, compression_offset: int = 256):
    """
    Convert consecutive compressed values into larger integers.

    Each maximal run of values ``>= compression_offset`` is read as a stream of
    bytes (``value - compression_offset``). The stream is cut into byte-aligned
    chunks of ``ceil(compression_bit_threshold / 8)`` bytes. The low-order
    padding bits of every chunk are discarded. The remaining
    ``compression_bit_threshold`` bits are split, most-significant first, into
    ``compression_bit_threshold // bits_per_compressed`` values of
    ``bits_per_compressed`` bits, and each value is re-offset by
    ``compression_offset``. Values below ``compression_offset`` are copied
    through unchanged.

    Padding bits are dropped without being checked. ``unpack_compressed_spans``
    restores them as zeros, so the round trip is lossless only when they are zero.

    Args:
        data: List of integers where 0-255 are raw bytes, compression_offset+ are compressed bytes
        bits_per_compressed: Number of bits to use for each packed value
        compression_bit_threshold: Number of bits each compressed value actually uses;
            must be a positive multiple of ``bits_per_compressed``
        compression_offset: Offset that marks start of compressed values (default 256)

    Returns:
        List with consecutive compressed spans packed into larger integers.
        Empty input returns ``[]`` before any validation.

    Raises:
        ValueError: if either bit width is not positive, or a compressed value
            does not encode a single byte (i.e. is ``>= compression_offset + 256``).
        AssertionError: if ``compression_bit_threshold`` is not divisible by
            ``bits_per_compressed``, or a compressed span does not contain a whole
            number of padded chunks. These are raised explicitly rather than via
            ``assert``, so they still fire under ``python -O`` instead of silently
            dropping the trailing bytes of a span.
    """
    if not data:
        return []
    if bits_per_compressed <= 0 or compression_bit_threshold <= 0:
        # A zero threshold would make the chunk-extraction loop below spin forever.
        raise ValueError(
            f"bit widths must be positive, got bits_per_compressed={bits_per_compressed}, "
            f"compression_bit_threshold={compression_bit_threshold}"
        )
    if compression_bit_threshold % bits_per_compressed != 0:
        raise AssertionError("compression_bit_threshold must be divisible by bits_per_compressed")

    packing_mask = (1 << bits_per_compressed) - 1
    compression_mask = (1 << compression_bit_threshold) - 1
    # Compressed values are stored byte-aligned: round the threshold up to whole bytes.
    padded_bits = ((compression_bit_threshold + 7) // 8) * 8
    padded_mask = (1 << padded_bits) - 1
    padding_bits = padded_bits - compression_bit_threshold
    # Shifts for splitting one threshold-wide value into packed fields, MSB first.
    # Divisibility (checked above) guarantees the split is exact.
    pack_shifts = range(compression_bit_threshold - bits_per_compressed, -1, -bits_per_compressed)

    result = []
    i = 0
    n = len(data)
    while i < n:
        if data[i] < compression_offset:
            # Raw byte (0-255), keep as is
            result.append(data[i])
            i += 1
            continue

        # Find the end of this run of consecutive compressed values.
        span_start = i
        while i < n and data[i] >= compression_offset:
            i += 1

        # Stream the span through a small bit buffer so we never build one huge integer.
        bit_buffer = 0
        bits_in_buffer = 0
        for pos in range(span_start, i):
            byte_val = data[pos] - compression_offset
            if byte_val > 0xFF:
                # Wider values would be ORed into neighbouring bits and corrupt the stream.
                raise ValueError(
                    f"compressed value {data[pos]} at index {pos} is outside "
                    f"[{compression_offset}, {compression_offset + 0xFF}]"
                )
            bit_buffer = (bit_buffer << 8) | byte_val
            bits_in_buffer += 8

            # Emit every complete padded chunk as soon as it is available.
            while bits_in_buffer >= padded_bits:
                shift_amount = bits_in_buffer - padded_bits
                padded_val = (bit_buffer >> shift_amount) & padded_mask
                bit_buffer &= (1 << shift_amount) - 1
                bits_in_buffer -= padded_bits
                # Padding occupies the low-order bits of the chunk; drop it.
                value = (padded_val >> padding_bits) & compression_mask
                for pack_shift in pack_shifts:
                    result.append(((value >> pack_shift) & packing_mask) + compression_offset)

        if bits_in_buffer != 0:
            raise AssertionError("bits_in_buffer must be 0 after processing compressed span")
    return result
def unpack_compressed_spans(packed_data, bits_per_compressed: int, compression_bit_threshold: int, compression_offset: int = 256):
    """
    Reverse operation: unpack larger integers back to consecutive compressed bytes.

    Each maximal run of values ``>= compression_offset`` is read as a bit stream
    of ``bits_per_compressed``-bit fields (``value - compression_offset``). Every
    ``compression_bit_threshold`` bits taken from that stream are shifted left to
    the next whole byte, with zero padding in the low-order bits. Each result is
    emitted as ``ceil(compression_bit_threshold / 8)`` bytes, most-significant
    first, each offset by ``compression_offset``. Values below
    ``compression_offset`` are copied through unchanged.

    Args:
        packed_data: List with packed compressed spans
        bits_per_compressed: Number of bits used for packing
        compression_bit_threshold: Number of bits each compressed value actually uses
        compression_offset: Offset used for compressed values

    Returns:
        Original format with consecutive compressed bytes.
        Empty input returns ``[]`` before any validation.

    Raises:
        ValueError: if either bit width is not positive, or a packed value does
            not fit in ``bits_per_compressed`` bits once the offset is removed.
        AssertionError: if a packed span does not decode to a whole number of
            ``compression_bit_threshold``-bit values. This is raised explicitly
            rather than via ``assert``, so it still fires under ``python -O``
            instead of silently dropping the trailing bits.
    """
    if not packed_data:
        return []
    if bits_per_compressed <= 0 or compression_bit_threshold <= 0:
        # A zero threshold would make the extraction loop below spin forever.
        raise ValueError(
            f"bit widths must be positive, got bits_per_compressed={bits_per_compressed}, "
            f"compression_bit_threshold={compression_bit_threshold}"
        )

    # Each decoded value is re-expanded to a whole number of bytes.
    padded_bits = ((compression_bit_threshold + 7) // 8) * 8
    padding_bits = padded_bits - compression_bit_threshold
    bytes_needed = padded_bits // 8
    value_mask = (1 << compression_bit_threshold) - 1
    max_field = (1 << bits_per_compressed) - 1
    # Byte shifts for emitting a padded value most-significant byte first.
    byte_shifts = range((bytes_needed - 1) * 8, -1, -8)

    result = []
    i = 0
    n = len(packed_data)
    while i < n:
        if packed_data[i] < compression_offset:
            # Raw byte, keep as is
            result.append(packed_data[i])
            i += 1
            continue

        # Find the end of this run of consecutive packed values.
        span_start = i
        while i < n and packed_data[i] >= compression_offset:
            i += 1

        unpacked_bytes = []
        bit_buffer = 0
        bits_in_buffer = 0
        for pos in range(span_start, i):
            field = packed_data[pos] - compression_offset
            if field > max_field:
                # Wider fields would be ORed into neighbouring bits and corrupt the stream.
                raise ValueError(
                    f"packed value {packed_data[pos]} at index {pos} is outside "
                    f"[{compression_offset}, {compression_offset + max_field}] "
                    f"for bits_per_compressed={bits_per_compressed}"
                )
            bit_buffer = (bit_buffer << bits_per_compressed) | field
            bits_in_buffer += bits_per_compressed

            # Extract each complete compression_bit_threshold-bit value as soon as possible.
            while bits_in_buffer >= compression_bit_threshold:
                shift_amount = bits_in_buffer - compression_bit_threshold
                value = (bit_buffer >> shift_amount) & value_mask
                bit_buffer &= (1 << shift_amount) - 1
                bits_in_buffer -= compression_bit_threshold
                # Restore zero padding in the low-order bits to get back to byte alignment.
                padded_val = value << padding_bits
                for byte_shift in byte_shifts:
                    unpacked_bytes.append(((padded_val >> byte_shift) & 0xFF) + compression_offset)

        if bits_in_buffer != 0:
            raise AssertionError("bits_in_buffer must be 0 after unpacking compressed span")
        result.extend(unpacked_bytes)
    return result
def run_test_case(test_name: str, data: list, bits_per_compressed: int, compression_bit_threshold: int):
    """
    Round-trip one input through pack/unpack and print a report.

    Prints the input, the configuration, the packed and unpacked forms, and a
    PASS/FAIL verdict. When the input contains compressed values (>= 256), it
    also prints how many compressed values remain after packing. Any exception
    raised during the round trip is printed and reported as a failure.

    Returns:
        The result of ``data == unpacked`` on success; False if an exception occurred.
    """
    print(f"πŸ§ͺ {test_name}")
    print(f" Original: {data}")
    print(f" Config: bits_per_compressed={bits_per_compressed}, compression_bit_threshold={compression_bit_threshold}")
    try:
        packed = pack_compressed_spans(data, bits_per_compressed, compression_bit_threshold)
        print(f" Packed: {packed}")
        restored = unpack_compressed_spans(packed, bits_per_compressed, compression_bit_threshold)
        print(f" Unpacked: {restored}")

        matches = data == restored
        verdict = "βœ… PASS" if matches else "❌ FAIL"
        print(f" Result: {verdict}")

        # 256 matches the default compression_offset used by the calls above.
        compressed_before = sum(1 for value in data if value >= 256)
        compressed_after = sum(1 for value in packed if value >= 256)
        if compressed_before:
            ratio = 0 if not compressed_after else compressed_before / compressed_after
            print(f" Stats: {compressed_before} β†’ {compressed_after} compressed values ({ratio:.2f}x)")
        return matches
    except Exception as err:
        # Report-and-continue: a crashing case counts as a failed case.
        print(f" Result: ❌ ERROR: {err}")
        return False
def test_packing_comprehensive(seed=None):
    """Comprehensive test suite for packing functions.

    Runs a fixed list of cases through ``run_test_case``, which packs, unpacks,
    and compares each one. A blank line is printed after each case, followed by
    a pass/fail summary. Four cases use randomly generated spans; the rest use
    literal inputs.

    Args:
        seed: Optional seed for the random data generator so that a failing run
            can be reproduced. ``None`` (the default) seeds from system entropy.

    Returns:
        Tuple ``(passed, total)``: the number of passing cases and the number run.
    """
    from m1_compression import utils
    import random

    # A local generator keeps runs reproducible without touching global random state.
    rng = random.Random(seed)

    def random_bytes_generator(n: int, bit_threshold: int):
        """Build between n//2 and n groups. Each group is one random bit_threshold-bit
        string, converted by utils.bits_to_bytes_padding_to_threshold and offset by 256,
        followed by one random raw byte (0-255)."""
        ret = []
        for _ in range(rng.randint(n // 2, n)):
            bits = "".join("0" if rng.random() < 0.5 else "1" for _ in range(bit_threshold))
            compressed_bytes, _ = utils.bits_to_bytes_padding_to_threshold(bits, bit_threshold)
            ret.extend(c + 256 for c in compressed_bytes)
            ret.append(rng.randint(0, 255))
        return ret

    print("=" * 60)
    print("πŸš€ COMPREHENSIVE PACKING TESTS")
    print("=" * 60)
    test_results = []

    def run(name, data, bits_per_compressed, compression_bit_threshold):
        # Record one case's outcome and separate its output from the next case.
        test_results.append(run_test_case(
            name,
            data,
            bits_per_compressed=bits_per_compressed,
            compression_bit_threshold=compression_bit_threshold,
        ))
        print()

    # Test 1: Basic functionality - 16-bit alignment (no padding)
    run("Basic 16-bit packing (no padding)", random_bytes_generator(100, 16), 16, 16)
    # Test 2: 12-bit values with padding (12 bits stored in 16 bits)
    run("12-bit values with 4-bit padding", random_bytes_generator(100, 12), 12, 12)
    # Test 3: 20-bit values with padding (20 bits stored in 24 bits)
    run("20-bit values with 4-bit padding", random_bytes_generator(100, 20), 20, 20)
    # Test 4: Edge case - single compressed byte
    run("Single compressed byte", [100, 256, 200], 8, 8)
    # Test 5: Edge case - no compressed bytes
    run("No compressed bytes", [100, 200, 50, 150], 16, 16)
    # Test 6: Edge case - all compressed bytes
    run("All compressed bytes", [256, 257, 258, 259, 260, 261], 8, 8)
    # Test 7: Each 24-bit compressed value split into two 12-bit packed values
    run("24-bit to 12-bit packing (2:1 ratio)", random_bytes_generator(100, 24), 12, 24)

    # Summary
    passed = sum(test_results)
    total = len(test_results)
    print("=" * 60)
    print(f"πŸ“Š TEST SUMMARY: {passed}/{total} tests passed")
    print("=" * 60)
    if passed == total:
        print("πŸŽ‰ All tests passed! The implementation is working correctly.")
    else:
        print("⚠️ Some tests failed. Please review the implementation.")
    return passed, total
def test_real_data():
    """Round-trip the first few records of a real compressed output file.

    Reads ``output_compress/m1.chunk.0_out_0_out_0_writer_0.jsonl`` relative to
    the working directory. For each of the first six records, it base64-decodes
    the field named by ``key``, converts it with ``packed_bytes_to_pseudo``, and
    runs the result through ``run_test_case`` with 10-bit packing. The
    compression bit threshold is parsed from the ``ow<N>`` component of ``key``.

    Raises:
        ValueError: if ``key`` has no ``ow<N>`` component.
    """
    print("=" * 40)
    print("πŸ”§ REAL DATA TESTS")
    print("=" * 40)
    key = "m1_ac_ow20_escapefb-False_iterative-True"

    # Extract bit_threshold from the "ow<N>" part of the key. It does not depend
    # on the record, so parse it once instead of once per line.
    bit_threshold = None
    for part in key.split("_"):
        if part.startswith("ow"):
            bit_threshold = int(part[len("ow"):])
            break
    if bit_threshold is None:
        raise ValueError(f"no 'ow<N>' bit threshold component in key {key!r}")

    with open("output_compress/m1.chunk.0_out_0_out_0_writer_0.jsonl", "r", encoding="utf-8") as f:
        for i, line in enumerate(f):
            data = json.loads(line)
            # NOTE: for visualization purposes, values > 256 can be replaced with byte '_':
            # bytes_array = packed_bytes_to_pseudo(base64.b64decode(data[key]))
            # bytes_array = [b if b < 256 else ord('_') for b in bytes_array]
            # bytes_string = bytes(bytes_array).decode("utf-8", errors="replace")
            # print(bytes_string)
            print(f"Bit threshold: {bit_threshold}")
            # Apply the packing round trip to the original pseudo-byte array (before any replacement).
            original_bytes_array = packed_bytes_to_pseudo(base64.b64decode(data[key]))
            run_test_case(
                f"Packing {bit_threshold}-bit values",
                original_bytes_array,
                10,
                bit_threshold
            )
            # Stop after the sixth record (i = 0..5).
            if i > 4:
                break
def test_error_conditions():
    """Check how pack_compressed_spans handles a bad configuration and empty input.

    Prints whether a threshold that is not a multiple of the packing width was
    rejected with AssertionError, and whether empty input packs to an empty
    list. Exceptions other than AssertionError propagate.
    """
    print("\nπŸ”§ ERROR CONDITION TESTS")
    print("=" * 40)

    # 15 is not a multiple of 10, so packing must be rejected.
    rejected = False
    try:
        pack_compressed_spans([256, 257], 10, 15)
    except AssertionError:
        rejected = True
    if rejected:
        print("βœ… Correctly caught invalid bit alignment")
    else:
        print("❌ Should have failed on invalid bit alignment")

    # Empty input should pack to an empty list.
    empty_ok = pack_compressed_spans([], 16, 16) == []
    print(f"βœ… Empty data handling: {empty_ok}")
    print()
if __name__ == "__main__":
    # Synthetic round-trip suite first; it reports (passed, total).
    passed, total = test_packing_comprehensive()
    print("πŸ† ALL TESTS COMPLETED SUCCESSFULLY!" if passed == total else "πŸ’₯ SOME TESTS FAILED!")
    # NOTE: test_real_data reads a JSONL file from output_compress/. If it raises,
    # the error-condition checks below never run.
    test_real_data()
    test_error_conditions()