import json import base64 from offline_utils import packed_bytes_to_pseudo def pack_compressed_spans(data, bits_per_compressed: int, compression_bit_threshold: int, compression_offset: int = 256): """ Convert consecutive compressed values into larger integers. Args: data: List of integers where 0-255 are raw bytes, compression_offset+ are compressed bytes bits_per_compressed: Number of bits to use for each packed value compression_bit_threshold: Number of bits each compressed value actually uses compression_offset: Offset that marks start of compressed values (default 256) Returns: List with consecutive compressed spans packed into larger integers """ if not data: return [] result = [] i = 0 assert compression_bit_threshold % bits_per_compressed == 0, "compression_bit_threshold must be divisible by bits_per_compressed" packing_mask = (1 << bits_per_compressed) - 1 compression_mask = (1 << compression_bit_threshold) - 1 # Calculate byte-aligned padded size padded_compression_bit_threshold = ((compression_bit_threshold + 7) // 8) * 8 padded_mask = (1 << padded_compression_bit_threshold) - 1 padding_bits = padded_compression_bit_threshold - compression_bit_threshold while i < len(data): if data[i] >= compression_offset: # Find the end of consecutive compressed bytes span_start = i while i < len(data) and data[i] >= compression_offset: i += 1 # Extract the span of compressed bytes compressed_span = data[span_start:i] base_values = [x - compression_offset for x in compressed_span] # Process bytes incrementally to avoid large numbers bit_buffer = 0 bits_in_buffer = 0 packed_values = [] for val in base_values: # Add this byte to bit buffer bit_buffer = (bit_buffer << 8) | val bits_in_buffer += 8 # Extract padded chunks as soon as we have enough bits while bits_in_buffer >= padded_compression_bit_threshold: shift_amount = bits_in_buffer - padded_compression_bit_threshold padded_val = (bit_buffer >> shift_amount) & padded_mask # Remove the extracted bits from buffer bit_buffer &= (1 << shift_amount) - 1 bits_in_buffer -= padded_compression_bit_threshold # Strip padding by extracting only the meaningful bits extracted_val = (padded_val >> padding_bits) & compression_mask pack_buffer = extracted_val pack_bits = compression_bit_threshold # Pack values as soon as we have enough bits while pack_bits >= bits_per_compressed: pack_shift = pack_bits - bits_per_compressed packed_val = (pack_buffer >> pack_shift) & packing_mask packed_values.append(packed_val + compression_offset) # Remove packed bits from pack buffer pack_buffer &= (1 << pack_shift) - 1 pack_bits -= bits_per_compressed assert bits_in_buffer == 0, "bits_in_buffer must be 0 after processing compressed span" assert pack_bits == 0, "pack_bits must be 0 after packing" result.extend(packed_values) else: # Raw byte (0-255), keep as is result.append(data[i]) i += 1 return result def unpack_compressed_spans(packed_data, bits_per_compressed: int, compression_bit_threshold: int, compression_offset: int = 256): """ Reverse operation: unpack larger integers back to consecutive compressed bytes. Args: packed_data: List with packed compressed spans bits_per_compressed: Number of bits used for packing compression_bit_threshold: Number of bits each compressed value actually uses compression_offset: Offset used for compressed values Returns: Original format with consecutive compressed bytes """ result = [] i = 0 # Calculate byte-aligned padded size padded_compression_bit_threshold = ((compression_bit_threshold + 7) // 8) * 8 padding_bits = padded_compression_bit_threshold - compression_bit_threshold while i < len(packed_data): if packed_data[i] >= compression_offset: # Start of compressed span # Find consecutive packed values span_start = i while i < len(packed_data) and packed_data[i] >= compression_offset: i += 1 packed_span = packed_data[span_start:i] base_values = [x - compression_offset for x in packed_span] # Unpack using two-phase process to handle padding unpacked_bytes = [] bit_buffer = 0 bits_in_buffer = 0 for val in base_values: # Add this packed value to our bit buffer bit_buffer = (bit_buffer << bits_per_compressed) | val bits_in_buffer += bits_per_compressed # Extract compression_bit_threshold values as soon as we have enough bits while bits_in_buffer >= compression_bit_threshold: # Extract the top compression_bit_threshold bits shift_amount = bits_in_buffer - compression_bit_threshold compressed_val = (bit_buffer >> shift_amount) & ((1 << compression_bit_threshold) - 1) # Remove the extracted bits from buffer bit_buffer &= (1 << shift_amount) - 1 bits_in_buffer -= compression_bit_threshold # Add padding back to make it byte-aligned padded_val = compressed_val << padding_bits # Convert padded value back to bytes bytes_needed = padded_compression_bit_threshold // 8 for byte_idx in range(bytes_needed): shift = (bytes_needed - 1 - byte_idx) * 8 byte_val = (padded_val >> shift) & 0xFF unpacked_bytes.append(byte_val + compression_offset) # Verify all bits were processed cleanly assert bits_in_buffer == 0, "bits_in_buffer must be 0 after unpacking compressed span" result.extend(unpacked_bytes) else: # Raw byte, keep as is result.append(packed_data[i]) i += 1 return result def run_test_case(test_name: str, data: list, bits_per_compressed: int, compression_bit_threshold: int): """Run a single test case with comprehensive validation.""" print(f"๐Ÿงช {test_name}") print(f" Original: {data}") print(f" Config: bits_per_compressed={bits_per_compressed}, compression_bit_threshold={compression_bit_threshold}") try: # Test packing packed = pack_compressed_spans(data, bits_per_compressed, compression_bit_threshold) print(f" Packed: {packed}") # Test unpacking unpacked = unpack_compressed_spans(packed, bits_per_compressed, compression_bit_threshold) print(f" Unpacked: {unpacked}") # Verify round-trip success = data == unpacked print(f" Result: {'โœ… PASS' if success else 'โŒ FAIL'}") # Show compression stats original_compressed = len([x for x in data if x >= 256]) packed_compressed = len([x for x in packed if x >= 256]) if original_compressed > 0: ratio = original_compressed / packed_compressed if packed_compressed > 0 else 0 print(f" Stats: {original_compressed} โ†’ {packed_compressed} compressed values ({ratio:.2f}x)") return success except Exception as e: print(f" Result: โŒ ERROR: {e}") return False def test_packing_comprehensive(): from m1_compression import utils import random def random_bytes_generator(n: int, bit_threshold: int): ret = [] length = random.randint(n // 2, n) for _ in range(length): bits = "" for _ in range(bit_threshold): bits += "0" if random.random() < 0.5 else "1" compressed_bytes, _ = utils.bits_to_bytes_padding_to_threshold(bits, bit_threshold) ret.extend([c + 256 for c in list(compressed_bytes)]) ret.extend([random.randint(0, 255)]) return ret """Comprehensive test suite for packing functions.""" print("=" * 60) print("๐Ÿš€ COMPREHENSIVE PACKING TESTS") print("=" * 60) test_results = [] # Test 1: Basic functionality - 16-bit alignment (no padding) test_results.append(run_test_case( "Basic 16-bit packing (no padding)", random_bytes_generator(100, 16), bits_per_compressed=16, compression_bit_threshold=16 )) print() # Test 2: 12-bit values with padding (12 bits stored in 16 bits) test_results.append(run_test_case( "12-bit values with 4-bit padding", random_bytes_generator(100, 12), bits_per_compressed=12, compression_bit_threshold=12 )) print() # Test 3: 20-bit values with padding (20 bits stored in 24 bits) test_results.append(run_test_case( "20-bit values with 4-bit padding", random_bytes_generator(100, 20), bits_per_compressed=20, compression_bit_threshold=20 )) print() # Test 5: Edge case - single compressed byte test_results.append(run_test_case( "Single compressed byte", [100, 256, 200], bits_per_compressed=8, compression_bit_threshold=8 )) print() # Test 6: Edge case - no compressed bytes test_results.append(run_test_case( "No compressed bytes", [100, 200, 50, 150], bits_per_compressed=16, compression_bit_threshold=16 )) print() # Test 7: Edge case - all compressed bytes test_results.append(run_test_case( "All compressed bytes", [256, 257, 258, 259, 260, 261], bits_per_compressed=8, compression_bit_threshold=8 )) print() # Test 8: Mixed compression ratios test_results.append(run_test_case( "24-bit to 12-bit packing (2:1 ratio)", random_bytes_generator(100, 24), bits_per_compressed=12, compression_bit_threshold=24 )) print() # Summary passed = sum(test_results) total = len(test_results) print("=" * 60) print(f"๐Ÿ“Š TEST SUMMARY: {passed}/{total} tests passed") print("=" * 60) if passed == total: print("๐ŸŽ‰ All tests passed! The implementation is working correctly.") else: print("โš ๏ธ Some tests failed. Please review the implementation.") return passed, total def test_real_data(): print("=" * 40) print("๐Ÿ”ง REAL DATA TESTS") print("=" * 40) key = "m1_ac_ow20_escapefb-False_iterative-True" with open("output_compress/m1.chunk.0_out_0_out_0_writer_0.jsonl", "r") as f: for i, line in enumerate(f): data = json.loads(line) # NOTE: for visualization purposes, we replace values > 256 with byte '_' # bytes_array = packed_bytes_to_pseudo(base64.b64decode(data[key])) # bytes_array = [b if b < 256 else ord('_') for b in bytes_array] # bytes_string = bytes(bytes_array).decode("utf-8", errors="replace") # print(bytes_string) # extract bit_threshold in key key_splits = key.split("_") bit_threshold = None for k in key_splits: if k.startswith( ): bit_threshold = int(k[len("ow"):]) break assert bit_threshold is not None print(f"Bit threshold: {bit_threshold}") # NEW: Apply the packing function to the original bytes_array (before replacement) original_bytes_array = packed_bytes_to_pseudo(base64.b64decode(data[key])) run_test_case( f"Packing {bit_threshold}-bit values", original_bytes_array, 10, bit_threshold ) if i > 4: break def test_error_conditions(): """Test error conditions and edge cases.""" print("\n๐Ÿ”ง ERROR CONDITION TESTS") print("=" * 40) # Test invalid bit alignment try: pack_compressed_spans([256, 257], 10, 15) # 15 % 10 != 0 print("โŒ Should have failed on invalid bit alignment") except AssertionError: print("โœ… Correctly caught invalid bit alignment") # Test empty data result = pack_compressed_spans([], 16, 16) print(f"โœ… Empty data handling: {result == []}") print() if __name__ == "__main__": # Run comprehensive tests passed, total = test_packing_comprehensive() # Final result if passed == total: print("๐Ÿ† ALL TESTS COMPLETED SUCCESSFULLY!") else: print("๐Ÿ’ฅ SOME TESTS FAILED!") test_real_data() # Run error condition tests test_error_conditions()