"""Compare how stable three encodings are under a small input perturbation.

For JSONL samples, each text is encoded with gzip, a subword tokenizer and the
M1 model-based arithmetic coder (AC), and compared -- via normalized Levenshtein
distance -- with the encoding of the same text with its leading ~10% of
characters removed. One worker process per input file runs on its own GPU; the
per-file results are merged, plotted and summarized at the end.
"""
import json
import base64
import argparse
import os
import sys
import gzip
import time
import math
import gc
import torch
import torch.multiprocessing as mp
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
from tqdm import tqdm
from typing import List, Dict, Any, Callable, Tuple, Optional
import Levenshtein
from collections import defaultdict

# ==========================================
# 0. System path and environment fixes
# ==========================================
# Make modules in the current working directory and in this script's own
# directory importable (presumably so the local `m1_compression` package
# imported below can be found regardless of where the script is launched from).
# NOTE: main() uses the 'spawn' start method, which re-imports this module in
# every worker process, so the messages below are printed once per worker too.
current_dir = os.getcwd()
if current_dir not in sys.path:
    sys.path.append(current_dir)
script_dir = os.path.dirname(os.path.abspath(__file__))
if script_dir not in sys.path:
    sys.path.append(script_dir)
print(f"🔧 System Path Fixed. CWD: {current_dir}")

# ==========================================
# 1. Dependency check
# ==========================================
# Fail fast with a readable message (and exit code 1) when a required package
# is missing. This must run after the sys.path fix above.
try:
    from transformers import AutoTokenizer
except ImportError:
    print("❌ Error: 'transformers' not installed.")
    sys.exit(1)

try:
    from m1_compression import utils
    from m1_compression.compressor import (
        load_m1_model_and_tokenizer,
        ALPHABET_SIZE,
        ARITHMETIC_CODER_BASE,
        ARITHMETIC_CODER_PRECISION
    )
    from m1_compression.hybrid_arithmetic_coder import CPUArithmeticEncoder
    from m1_compression.batched_arithmetic_coder import _pdf_to_cdf
    print("✅ Successfully imported m1_compression modules.")
except ImportError as e:
    print(f"❌ FATAL ERROR: {e}")
    sys.exit(1)
# ==========================================
# 2. Helper functions
# ==========================================
def vread(buf: bytes, i: int) -> Tuple[int, int]:
    """Decode one unsigned LEB128 varint from ``buf`` starting at index ``i``.

    Each byte contributes its low 7 bits, least-significant group first; a byte
    below 0x80 terminates the number.

    Returns:
        ``(value, next_index)``.

    Raises:
        IndexError: if the buffer ends before a terminating byte.
    """
    shift = val = 0
    while True:
        b = buf[i]
        i += 1
        val |= (b & 0x7F) << shift
        if b < 0x80:
            return val, i
        shift += 7


def unpack_windows(input_bytes: bytes, b64_stream: str) -> List[Tuple[bytes, int]]:
    """Split ``input_bytes`` into alternating gap / window segments.

    ``b64_stream`` is the base64 encoding of a sequence of varint pairs
    ``(gap, size)``: each window starts ``gap`` bytes after the end of the
    previous window and spans ``size`` bytes. Segments are tagged ``1`` for
    windows and ``0`` for the bytes between windows and after the last one.

    Returns:
        The tagged segments in order, or ``[]`` when ``b64_stream`` is empty,
        is not valid base64 (including non-ASCII text), or ends in the middle
        of a varint.
    """
    if not b64_stream:
        return []
    try:
        buf = base64.b64decode(b64_stream)
    except ValueError:
        # binascii.Error (bad padding/length) subclasses ValueError, and a
        # non-ASCII str raises a plain ValueError -- both mean "bad metadata".
        return []

    segments = []
    cursor = i = 0
    try:
        while i < len(buf):
            gap, i = vread(buf, i)
            size, i = vread(buf, i)
            start = cursor + gap
            if gap > 0:
                segments.append((input_bytes[cursor:start], 0))
            end = start + size
            segments.append((input_bytes[start:end], 1))
            cursor = end
    except IndexError:
        # Truncated varint: the metadata is unusable as a whole.
        return []

    if cursor < len(input_bytes):
        segments.append((input_bytes[cursor:], 0))
    return segments


def list_to_comparable_str(int_list: List[int]) -> str:
    """Map non-negative ints to a ``str`` with one code point per item.

    This lets ``Levenshtein.distance`` compare byte/token sequences. Values
    above 0x10FFFF (the largest code point) are clamped and so compare equal.
    """
    return "".join(chr(min(x, 0x10FFFF)) for x in int_list)


def pad_batch(batch: List[bytes]) -> Tuple[torch.Tensor, torch.Tensor]:
    """Right-pad byte strings with 0 into a ``(B, T_max)`` int64 tensor.

    Returns:
        ``(padded_batch, lengths)`` where ``lengths[b] == len(batch[b])``.
        An empty ``batch`` yields a ``(0, 0)`` tensor and an empty lengths
        tensor instead of failing inside ``pad_sequence``.
    """
    if not batch:
        return torch.zeros((0, 0), dtype=torch.int64), torch.zeros(0, dtype=torch.int64)
    batch_tensors = [torch.tensor(list(data), dtype=torch.int64) for data in batch]
    lengths = torch.tensor([len(data) for data in batch], dtype=torch.int64)
    padded_batch = torch.nn.utils.rnn.pad_sequence(
        batch_tensors, batch_first=True, padding_value=0
    )
    return padded_batch, lengths
# ==========================================
# 3. Core compression logic (arithmetic coding)
# ==========================================
def batched_m1_compress_predict_fn(model):
    """Wrap ``model`` into a function returning next-byte probabilities.

    The returned ``predict_fn(input_tensor, **kwargs)`` adds a batch dimension
    to 1-D input, runs ``model`` without gradients, keeps the first 256 entries
    of the last logits dimension (one per byte value) and returns their softmax
    in float32.
    """
    def predict_fn(input_tensor: torch.Tensor, **kwargs) -> torch.Tensor:
        if input_tensor.dim() == 1:
            input_tensor = input_tensor.unsqueeze(0)
        with torch.no_grad():
            logits = model(input_tensor, **kwargs)
            logits = logits[..., :256].float()
            probs = torch.softmax(logits, dim=-1)
        return probs
    return predict_fn


def compress_segments_ac_impl(
    sorted_segments: List[bytes],
    batched_predict_fn: Callable,
    first_byte_prob: torch.Tensor,
    device: torch.device
) -> List[List[int]]:
    """Arithmetic-code each byte segment independently with the M1 model.

    Identical segments are compressed only once. Unique non-empty segments are
    sent to ``device`` in mini-batches for probability prediction; the
    resulting coding intervals are encoded on the CPU.

    Args:
        sorted_segments: byte segments to compress. Output order follows input
            order; despite the name, the segments need not be sorted.
        batched_predict_fn: maps a ``(B, T)`` int64 byte tensor to next-byte
            probabilities (see ``batched_m1_compress_predict_fn``).
        first_byte_prob: distribution for the first byte of every segment;
            must support ``expand(B, -1, -1)``.
        device: device the model lives on.

    Returns:
        One list of ints per input segment; empty segments map to ``[]``. If a
        mini-batch fails, its segments fall back to their raw bytes and a
        warning is printed.
    """
    M = len(sorted_segments)
    if M == 0:
        return []

    # Deduplicate: identical segments always produce identical codes.
    segment_to_indices = defaultdict(list)
    for idx, seg in enumerate(sorted_segments):
        segment_to_indices[seg].append(idx)
    unique_segments = [seg for seg in segment_to_indices if len(seg) > 0]

    segment_to_compressed = {}
    encoder = CPUArithmeticEncoder(base=ARITHMETIC_CODER_BASE, precision=ARITHMETIC_CODER_PRECISION)

    # Segments per forward pass; bounded by GPU memory.
    GPU_BATCH_SIZE = 256
    for i in range(0, len(unique_segments), GPU_BATCH_SIZE):
        batch_segments = unique_segments[i : i + GPU_BATCH_SIZE]
        try:
            padded_batch, lengths = pad_batch(batch_segments)
            padded_batch = padded_batch.to(device)  # `lengths` stays on CPU for the encoder
            with torch.no_grad():
                prompt_probs = batched_predict_fn(padded_batch)
                # Byte t is coded with the prediction made after byte t-1; byte 0
                # has no context and uses `first_byte_prob` instead.
                final_probs = torch.cat(
                    [
                        first_byte_prob.expand(prompt_probs.shape[0], -1, -1),
                        prompt_probs[:, :-1, ...],
                    ],
                    dim=1,
                )
                final_probs = utils.batched_normalize_pdf_for_arithmetic_coding(final_probs)
                cdfs_gpu = _pdf_to_cdf(final_probs)
                # CDF values at each actual byte and the next symbol bound its coding interval.
                cdf_low = cdfs_gpu.gather(2, padded_batch.unsqueeze(-1)).squeeze(-1)
                cdf_high = cdfs_gpu.gather(2, (padded_batch + 1).unsqueeze(-1)).squeeze(-1)
                cdf_ends = torch.stack([cdf_low, cdf_high], dim=-1)

            chunked_compressed_bytes, _, _ = encoder.incremental_batched_encode(
                cdf_ends.cpu(),
                ALPHABET_SIZE,
                lengths,
                bit_threshold=16,
                force_padding_to_threshold=False,
                return_num_padded_bits=True,
            )
            for seg, code in zip(batch_segments, chunked_compressed_bytes):
                segment_to_compressed[seg] = list(code)
        except Exception as e:
            # Best-effort fallback, but never silent: raw bytes in place of codes
            # distort the AC statistics, so the user must be able to see it.
            print(f"⚠️ AC batch of {len(batch_segments)} segments failed; storing raw bytes instead: {e}")
            for seg in batch_segments:
                segment_to_compressed[seg] = list(seg)

    all_results = [None] * M
    for seg, indices in segment_to_indices.items():
        res = segment_to_compressed.get(seg, list(seg))
        for idx in indices:
            all_results[idx] = res
    return all_results


class M1ACManager:
    """Hold an M1 model on one GPU and arithmetic-code batches of texts with it."""

    def __init__(self, model_path, first_prob_path, device_id):
        """Load the M1 model onto ``cuda:{device_id}`` plus the first-byte distribution.

        Args:
            model_path: checkpoint passed to ``load_m1_model_and_tokenizer``.
            first_prob_path: JSON file holding the distribution used for the
                first byte of every segment. If it is falsy or does not exist,
                a uniform distribution over ``ALPHABET_SIZE`` symbols is used.
            device_id: CUDA device index.
        """
        self.device = torch.device(f"cuda:{device_id}")
        print(f"[GPU {device_id}] Loading M1 Model...")
        self.model, _, _ = load_m1_model_and_tokenizer(model_path)
        self.model.to(self.device)
        self.model.eval()
        self.predict_fn = batched_m1_compress_predict_fn(self.model)

        if first_prob_path and os.path.exists(first_prob_path):
            with open(first_prob_path, 'r') as f:
                prob_data = json.load(f)
            self.first_byte_prob = torch.tensor(prob_data, dtype=torch.float32, device=self.device)
            # A flat vector becomes (1, 1, V) so it can be expanded per batch.
            if self.first_byte_prob.dim() == 1:
                self.first_byte_prob = self.first_byte_prob.unsqueeze(0).unsqueeze(0)
        else:
            if first_prob_path:
                print(f"⚠️ [GPU {device_id}] First-byte prob file not found: {first_prob_path}; using uniform.")
            self.first_byte_prob = torch.ones((1, 1, ALPHABET_SIZE), dtype=torch.float32, device=self.device) / ALPHABET_SIZE

    def compress_batch(self, inputs: List[Tuple[str, Optional[str]]]) -> List[List[int]]:
        """Compress several texts, returning one flat code list per text.

        Args:
            inputs: ``(text, windows_b64)`` pairs. With metadata, the UTF-8 text
                is split into the gap/window segments of ``unpack_windows``;
                without it (``None`` or empty), into fixed 512-byte chunks.

        Returns:
            For each input, the concatenated codes of its segments (``[]`` when
            the text produced no segments).
        """
        all_segments_flat = []
        # (start, end) slice of `all_segments_flat` belonging to each sample.
        sample_segment_map = []
        current_idx = 0

        # 1. Split every sample into segments.
        for text, windows_b64 in inputs:
            raw_bytes = text.encode('utf-8')
            if windows_b64:
                # Original data: split along the stored window metadata.
                sample_segs = [seg for seg, _ in unpack_windows(raw_bytes, windows_b64) if len(seg) > 0]
            else:
                # Perturbed data: fixed-size chunking.
                CHUNK = 512
                sample_segs = [raw_bytes[i : i + CHUNK] for i in range(0, len(raw_bytes), CHUNK)]
            count = len(sample_segs)
            sample_segment_map.append((current_idx, current_idx + count))
            all_segments_flat.extend(sample_segs)
            current_idx += count

        if not all_segments_flat:
            return [[] for _ in inputs]

        # 2. Compress all segments in one call; GPU mini-batching happens inside.
        compressed_chunks_flat = compress_segments_ac_impl(
            all_segments_flat, self.predict_fn, self.first_byte_prob, self.device
        )

        # 3. Reassemble: concatenate the codes of each sample's segments.
        return [
            [x for chunk in compressed_chunks_flat[start:end] for x in chunk]
            for start, end in sample_segment_map
        ]


# ==========================================
# 4. Worker process (batched processing)
# ==========================================
def _partial_result_name(rank: int, file_path: str) -> str:
    """File name of the partial result written by worker ``rank`` for ``file_path``."""
    return f"partial_result_{rank}_{os.path.basename(file_path)}.json"


def _parse_sample(line: str, min_chars: int = 100) -> Optional[Tuple[str, Optional[str]]]:
    """Parse one JSONL line into ``(text, windows_b64)``, or ``None`` if unusable.

    A line is skipped when it is not a JSON object, when its ``text`` is not a
    string of at least ``min_chars`` characters, or when the text cannot be
    encoded as UTF-8 (e.g. lone surrogates from JSON escapes). Such texts would
    otherwise make every encoder fail on them.
    """
    try:
        data = json.loads(line)
    except ValueError:  # json.JSONDecodeError subclasses ValueError
        return None
    if not isinstance(data, dict):
        return None
    text = data.get('text', '')
    if not isinstance(text, str) or len(text) < min_chars:
        return None
    try:
        text.encode('utf-8')
    except UnicodeEncodeError:
        return None
    return text, data.get('windows_starts_lens_b64')


def process_file_worker(rank, gpu_id, file_path, output_dir, model_path, prob_path, max_lines):
    """Compute per-sample stability scores for one JSONL file on one GPU.

    Each accepted line's ``text`` is paired with a perturbed copy that drops the
    first 10% of its characters (at least one). For gzip, the tokenizer and M1
    arithmetic coding, ``Levenshtein(enc(text), enc(perturbed)) / len(enc(text))``
    is recorded. Results go to ``output_dir/partial_result_{rank}_{basename}.json``.

    Args:
        rank: worker index (part of the output file name).
        gpu_id: CUDA device index to run on.
        file_path: input ``.jsonl`` file.
        output_dir: directory for the partial result file.
        model_path: M1 checkpoint passed to ``M1ACManager``.
        prob_path: first-byte distribution JSON passed to ``M1ACManager``.
        max_lines: stop after this many raw lines (``<= 0`` means no limit).
    """
    try:
        torch.cuda.set_device(gpu_id)

        # Tokenizer: prefer OpenCoder, fall back to GPT-2 if it cannot be loaded.
        try:
            tokenizer = AutoTokenizer.from_pretrained("infly/OpenCoder-1.5B-Base", trust_remote_code=True)
            if tokenizer.pad_token is None:
                tokenizer.pad_token = tokenizer.eos_token
        except Exception:
            tokenizer = AutoTokenizer.from_pretrained("gpt2")

        try:
            ac_manager = M1ACManager(model_path, prob_path, gpu_id)
        except Exception as e:
            print(f"❌ [GPU {gpu_id}] Init Failed: {e}")
            return

        results = {"Gzip": [], "Tokenizer": [], "AC_M1": []}
        filename = os.path.basename(file_path)
        print(f"[GPU {gpu_id}] Processing {filename}...")

        # Samples buffered before they are compressed together; larger values
        # give the GPU bigger batches, bounded by memory and text length.
        WORKER_BATCH_SIZE = 200
        PROGRESS_EVERY = 500
        batch_texts = []       # original texts
        batch_pert_texts = []  # perturbed texts
        batch_metas = []       # window metadata (base64 str) or None
        processed_count = 0

        def record(key, reference, other):
            """Append the normalized edit distance under ``key``; skip empty references."""
            if reference:
                d = Levenshtein.distance(list_to_comparable_str(reference), list_to_comparable_str(other))
                results[key].append(d / len(reference))

        def flush_batch():
            nonlocal batch_texts, batch_pert_texts, batch_metas
            if not batch_texts:
                return
            # Take the buffered samples and reset the buffers *first*. If an
            # error below left them full, the same samples would be re-flushed
            # (re-appending partial results) on every following line.
            texts, pert_texts, metas = batch_texts, batch_pert_texts, batch_metas
            batch_texts, batch_pert_texts, batch_metas = [], [], []

            # 1. Gzip and 2. tokenizer: cheap CPU work, done per sample.
            for t, tp in zip(texts, pert_texts):
                try:
                    record("Gzip",
                           list(gzip.compress(t.encode('utf-8'))),
                           list(gzip.compress(tp.encode('utf-8'))))
                    record("Tokenizer",
                           tokenizer.encode(t, add_special_tokens=False),
                           tokenizer.encode(tp, add_special_tokens=False))
                except Exception as e:
                    print(f"[GPU {gpu_id}] CPU metric error: {e}")

            # 3. M1 arithmetic coding, batched on the GPU. Originals are split by
            # their window metadata; perturbed texts (meta=None) use fixed chunks.
            try:
                ac1_list = ac_manager.compress_batch(list(zip(texts, metas)))
                ac2_list = ac_manager.compress_batch([(tp, None) for tp in pert_texts])
                for ac1, ac2 in zip(ac1_list, ac2_list):
                    record("AC_M1", ac1, ac2)
            except Exception as e:
                print(f"[GPU {gpu_id}] AC Batch Error: {e}")

        with open(file_path, 'r', encoding='utf-8', errors='ignore') as f:
            for i, line in enumerate(f):
                if max_lines > 0 and i >= max_lines:
                    break
                sample = _parse_sample(line)
                if sample is None:
                    continue
                text, windows_b64 = sample

                # Perturbation: drop the first 10% of characters (at least one).
                cut_idx = max(1, int(len(text) * 0.1))
                batch_texts.append(text)
                batch_pert_texts.append(text[cut_idx:])
                batch_metas.append(windows_b64)
                processed_count += 1

                if processed_count % PROGRESS_EVERY == 0:
                    print(f"[GPU {gpu_id}] Processed {processed_count} lines...")
                if len(batch_texts) >= WORKER_BATCH_SIZE:
                    flush_batch()

        flush_batch()  # remaining samples

        output_file = os.path.join(output_dir, _partial_result_name(rank, file_path))
        with open(output_file, 'w') as f:
            json.dump(results, f)
        print(f"✅ [GPU {gpu_id}] Done {filename}. Total: {processed_count}")

    except Exception as e:
        print(f"❌ [GPU {gpu_id}] Worker failed: {e}")
        import traceback
        traceback.print_exc()


# ==========================================
# 5. Main
# ==========================================
def main():
    """Run one worker per input file (at most one per GPU), then merge, plot and summarize."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--input_dir", type=str, required=True)
    parser.add_argument("--m1_model", type=str, required=True)
    parser.add_argument("--first_prob_path", type=str, required=True)
    parser.add_argument("-o", "--output_dir", type=str, default="analysis_output_parallel")
    parser.add_argument("--max_lines", type=int, default=10000)
    args = parser.parse_args()

    os.makedirs(args.output_dir, exist_ok=True)

    files = sorted(
        os.path.join(args.input_dir, f)
        for f in os.listdir(args.input_dir)
        if f.endswith('.jsonl') and "writer" not in f
    )

    num_gpus = torch.cuda.device_count()
    if num_gpus == 0:
        print("❌ No CUDA devices available; nothing to do.")
        return

    # One worker per GPU; surplus files are not processed in this run.
    if len(files) > num_gpus:
        print(f"⚠️ Found {len(files)} input files but only {num_gpus} GPU(s); "
              f"only the first {num_gpus} will be processed.")
        files = files[:num_gpus]

    # Remove this run's partial outputs left over from an earlier run, so a
    # failed worker cannot be masked by stale data during the merge.
    partial_paths = [
        os.path.join(args.output_dir, _partial_result_name(rank, file_path))
        for rank, file_path in enumerate(files)
    ]
    for path in partial_paths:
        if os.path.exists(path):
            os.remove(path)

    actual_procs = len(files)
    print(f"🚀 Launching {actual_procs} processes (Batch Mode)...")

    mp.set_start_method('spawn', force=True)
    processes = []
    for rank in range(actual_procs):
        p = mp.Process(
            target=process_file_worker,
            args=(rank, rank % num_gpus, files[rank], args.output_dir,
                  args.m1_model, args.first_prob_path, args.max_lines)
        )
        p.start()
        processes.append(p)

    for p in processes:
        p.join()
        if p.exitcode != 0:
            print(f"⚠️ Worker {p.name} exited with code {p.exitcode}")

    print("✅ All workers finished. Merging results...")

    # Merge only the partial results produced by this run's workers.
    final_results = {"Gzip": [], "Tokenizer": [], "AC_M1": []}
    for path in partial_paths:
        try:
            with open(path, 'r') as f:
                data = json.load(f)
        except (OSError, ValueError) as e:
            print(f"⚠️ Skipping partial result {path}: {e}")
            continue
        for k in final_results:
            if k in data:
                final_results[k].extend(data[k])

    for k, v in final_results.items():
        print(f"  -> {k}: {len(v)} samples collected.")

    # Outliers (>= 2.0) are dropped from the plot only, not from the stats.
    plot_records = [
        {"Algorithm": algo, "Normalized Edit Distance": v}
        for algo, vals in final_results.items()
        for v in vals
        if v < 2.0
    ]

    if not plot_records:
        print("❌ No data collected.")
        return

    print("📊 Generating plot...")
    try:
        df = pd.DataFrame(plot_records)
        plt.figure(figsize=(12, 7))
        sns.set_style("whitegrid")
        sns.kdeplot(data=df, x="Normalized Edit Distance", hue="Algorithm",
                    fill=True, common_norm=False, palette="tab10", alpha=0.5)
        plt.title("Compression Stability Analysis")
        plt.xlabel("Normalized Levenshtein Distance")
        plt.xlim(0, 1.2)
        plt.savefig(os.path.join(args.output_dir, "stability_parallel_batch.png"), dpi=300)
    except Exception as e:
        print(f"⚠️ Plotting failed: {e}")
    finally:
        plt.close('all')

    stats = {k: {"mean": float(np.mean(v)), "count": len(v)} for k, v in final_results.items() if v}
    with open(os.path.join(args.output_dir, "final_stats.json"), 'w') as f:
        json.dump(stats, f, indent=2)

    print("🎉 Done!")


if __name__ == "__main__":
    main()

# Example invocation (kept as an unused string below): there are 8 JSON files,
# test a single one first. Some modules may be missing; xformers can be
# installed as shown.
"""
# 有 8 个json 文件 先测试一个文件
python compare_three_compression_lv.py \
    --input_dir /mnt/hdfs/user/linzheng/data/ocpython_subsampled_50G_entropy90_splits_chunk512_ow20_iterative-true_forcepadding-true_merged_ac \
    --m1_model /mnt/bn/tiktok-mm-5/aiic/users/linzheng/artifacts/m1_checkpoints/m1_40M_lr1e-3_steps200k_bs8_seqlen2048_python/checkpoints/0000200000 \
    --first_prob_path /mnt/bn/tiktok-mm-5/aiic/users/linzheng/artifacts/ac_unigram_probs/python500k_unigram_prob.json

这里可能出现缺少某些模块
pip install xformers==0.0.23.post1 -i https://mirrors.aliyun.com/pypi/simple/ --trusted-host mirrors.aliyun.com
"""