"""Compute entropy-based split points over JSONL text samples with an M1 byte-level
model, and write each sample back out annotated with its compressed window
starts/lengths."""

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
import json
import numpy as np
from pathlib import Path
from typing import List, Dict, Any, Callable, Tuple, Optional
import logging
import argparse
import re
import gc
import multiprocessing as mp
from collections import defaultdict, Counter

from m1_compression.compressor import (
    load_m1_model_and_tokenizer,
    ALPHABET_SIZE,
)
from offline_utils import (
    compress_windows_starts_lens,
    decompress_windows_starts_lens,
    unpack_windows,
    InterleavedJsonlDataset,
    batched_m1_compress_predict_fn,
    find_next_batch_range,
)

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger()

# Maximum number of bytes per line; longer lines are split into MAX_LINE_LEN-sized chunks.
MAX_LINE_LEN = 512

def print_windows(text: str,
                  starts: List[int],
                  lens: List[int],
                  sample_idx: Optional[int] = None,
                  ):
    """Render the selected windows of `text` as coloured spans and save them as an SVG."""
    from rich.console import Console
    from rich.text import Text
    import io

    PALETTE = (
        "#c6f6d5", "#bee3f8", "#fbb6ce",
        "#faf089", "#fed7e2", "#b2f5ea",
    )
    string_io = io.StringIO()
    console = Console(record=True, force_terminal=True, color_system="truecolor", file=string_io)

    t = Text()
    last_end = 0
    colour_idx = 0

    for s, l in sorted(zip(starts, lens)):
        # Unhighlighted text between the previous window and this one.
        t.append(text[last_end:s])

        # Rotate the colour only when two windows touch, so adjacent windows stay distinct.
        if s == last_end:
            colour_idx = (colour_idx + 1) % len(PALETTE)

        t.append(text[s:s + l],
                 style=f"on {PALETTE[colour_idx]} bold black")
        last_end = s + l

    t.append(text[last_end:])
    console.print(t)

    save_idx = (sample_idx if sample_idx is not None else 0) % 100
    return console.save_svg(f"window_visualize_{save_idx}.svg")

def collect_lines(batched_bytes_data: List[bytes], max_len: int = 2048) -> Tuple[List[bytes], Dict[int, Tuple[int, int]]]:
    """Split each sample into lines (keeping trailing newline bytes attached) and return
    the lines plus a mapping from line index to (sample index, byte offset). Lines longer
    than `max_len` bytes are further split into `max_len`-sized chunks."""
    batched_lines = []
    line_id_to_sample_offsets = {}
    line_idx = 0

    for sample_idx, data_bytes in enumerate(batched_bytes_data):
        if len(data_bytes) == 0:
            continue

        lines_with_positions = []
        for match in re.finditer(b'[^\r\n]*(?:\r\n|\r|\n)*', data_bytes):
            if match.group():
                lines_with_positions.append((match.group(), match.start()))

        for line, byte_offset in lines_with_positions:
            if len(line) > max_len:
                logger.info("Line too long with {} bytes, splitting into chunks...".format(len(line)))

                for chunk_start in range(0, len(line), max_len):
                    chunk_end = min(chunk_start + max_len, len(line))
                    batched_lines.append(line[chunk_start:chunk_end])

                    chunk_byte_offset = byte_offset + chunk_start
                    line_id_to_sample_offsets[line_idx] = (sample_idx, chunk_byte_offset)
                    line_idx += 1
            else:
                batched_lines.append(line)
                line_id_to_sample_offsets[line_idx] = (sample_idx, byte_offset)
                line_idx += 1

    return batched_lines, line_id_to_sample_offsets
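
# Tiny illustrative example of collect_lines (hypothetical input; behaviour as coded above):
# trailing newline bytes stay attached to each line, and offsets are relative to the sample
# they came from.
#
#   lines, offsets = collect_lines([b"foo\nbar\n"])
#   # lines   == [b"foo\n", b"bar\n"]
#   # offsets == {0: (0, 0), 1: (0, 4)}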

def calculate_skew(entropy: torch.Tensor) -> torch.Tensor:
    """Fisher-Pearson moment coefficient of skewness: the mean of the cubed z-scores."""
    mean = torch.mean(entropy)
    diffs = entropy - mean
    var = torch.mean(torch.pow(diffs, 2.0))
    std = torch.pow(var, 0.5)
    if std == 0.0:
        return torch.tensor(0.0)
    zscores = diffs / std
    skews = torch.mean(torch.pow(zscores, 3.0))
    return skews

def get_split_points(
    probs: torch.Tensor,
    next_bytes: torch.Tensor,
    lengths: torch.Tensor,
    base_global_quantile: float,
    base_monotonic_quantile: float,
    debug: bool = False,
):
    """Mark a byte position as a split point when its cross-entropy, or the first difference
    of the cross-entropy, exceeds a quantile-based threshold. Both quantiles are lowered when
    the corresponding distribution is right-skewed, so heavy-tailed samples receive more
    split points."""
    B, L = probs.shape[0], probs.shape[1]
    arange_ids = torch.arange(L, device=probs.device).unsqueeze(0)
    pad_mask = arange_ids < lengths.unsqueeze(1)
    padded_cross_entropy = F.cross_entropy(
        probs.transpose(1, 2),
        next_bytes,
        reduction="none"
    )

    flattened_cross_entropy = padded_cross_entropy[pad_mask]
    assert flattened_cross_entropy.dim() == 1

    skew_flattened_cross_entropy = calculate_skew(flattened_cross_entropy.float())
    if skew_flattened_cross_entropy > 0.0:
        base_global_quantile = base_global_quantile - 0.04 * skew_flattened_cross_entropy.item()
        base_global_quantile = min(max(base_global_quantile, 0.0), 1.0)

    threshold = torch.quantile(flattened_cross_entropy, base_global_quantile).clamp(0.1, 10.0)

    # First difference of the cross-entropy along the sequence, left-padded with a zero
    # column so it keeps the same shape as padded_cross_entropy.
    padded_cross_entropy_diff = torch.diff(padded_cross_entropy, dim=1)
    padded_cross_entropy_diff = torch.cat(
        [
            torch.zeros(B, 1, device=padded_cross_entropy_diff.device),
            padded_cross_entropy_diff
        ],
        dim=1
    )
    flattened_cross_entropy_diff = padded_cross_entropy_diff[pad_mask]

    skew_flattened_cross_entropy_diff = calculate_skew(flattened_cross_entropy_diff.float())
    if skew_flattened_cross_entropy_diff > 0.0:
        base_monotonic_quantile = base_monotonic_quantile - 0.04 * skew_flattened_cross_entropy_diff.item()
        base_monotonic_quantile = min(max(base_monotonic_quantile, 0.0), 1.0)

    diff_threshold = torch.quantile(flattened_cross_entropy_diff, base_monotonic_quantile).clamp(0.01, 10.0)
    split_points_mask = ((padded_cross_entropy > threshold) | (padded_cross_entropy_diff > diff_threshold)) & pad_mask

    if debug:
        logger.info(f"skew_flattened_cross_entropy: {skew_flattened_cross_entropy}")
        logger.info(f"skew_flattened_cross_entropy_diff: {skew_flattened_cross_entropy_diff}")
        logger.info(f"base_global_quantile: {base_global_quantile}")
        logger.info(f"base_monotonic_quantile: {base_monotonic_quantile}")
        logger.info(f"threshold: {threshold}")
        logger.info(f"diff_threshold: {diff_threshold}")
    return split_points_mask
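
# Numerical illustration of the skew adjustment above: with base_global_quantile=0.9 and a
# cross-entropy skew of 1.5, the effective quantile becomes 0.9 - 0.04 * 1.5 = 0.84, so
# right-skewed samples receive a lower threshold and therefore more split points.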

def get_batch_size_for_length(window_len, max_batch_size):
    """Determines the batch size for a given window length."""
    BATCH_SIZE_TIERS = {
        512: max_batch_size,
        1024: max(max_batch_size // 2, 1),
        2048: max(max_batch_size // 4, 1),
    }
    for max_len, batch_size in BATCH_SIZE_TIERS.items():
        if window_len <= max_len:
            return batch_size
    return 1
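
# Tier example (assuming max_batch_size=4096): a 100-byte window gets 4096, an 800-byte
# window gets 2048, a 1500-byte window gets 1024, and anything longer than 2048 bytes
# falls back to a batch size of 1.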

def calculate_entropy_and_split_points_fn(
    batch: List[Dict[str, Any]],
    predict_fn: Callable,
    chunk_size: int = 512,
    base_global_quantile: float = 0.9,
    base_monotonic_quantile: float = 0.9,
    unigram_probs: Optional[torch.Tensor] = None,
    max_m1_batch_size: int = 2048,
    line_split: bool = False,
    debug: bool = False,
) -> List[Dict[str, Any]]:
    """Chunk each sample (by line or by fixed-size blocks), score the chunks with the M1
    model, derive split points from the per-byte cross-entropy, and return the input items
    annotated with their compressed window starts/lengths."""
    batched_bytes_data = [item["text"].encode('utf-8') for item in batch]

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if unigram_probs is not None:
        unigram_probs = unigram_probs.to(device)
    else:
        # Fall back to a uniform first-byte distribution (mirrors the default used in main).
        unigram_probs = torch.ones((1, 1, ALPHABET_SIZE), dtype=torch.float32, device=device) / ALPHABET_SIZE

    all_split_point_masks = []

    if line_split:
        chunks, chunk_to_sample_and_offset = collect_lines(batched_bytes_data, max_len=MAX_LINE_LEN)

        # Sort chunks by length so each mini-batch contains similar-length sequences.
        sorted_chunks = sorted(enumerate(chunks), key=lambda x: len(x[1]))
        sorted_idx, sorted_chunks = zip(*sorted_chunks)
        sorted_chunks = list(sorted_chunks)
        chunk_idx_map = {
            orig_idx: new_idx
            for new_idx, orig_idx in enumerate(sorted_idx)
        }

        chunks_np = [np.frombuffer(bytes(data), dtype=np.uint8) for data in sorted_chunks]
        num_chunks = len(sorted_chunks)
        start_idx = 0
        while start_idx < num_chunks:
            start_idx, end_idx = find_next_batch_range(chunks_np, start_idx, max_m1_batch_size, get_batch_size_for_length)
            batch_chunks_np = chunks_np[start_idx:end_idx]
            effective_batch_size = end_idx - start_idx

            lengths_pt = torch.tensor([len(chunk) for chunk in batch_chunks_np], dtype=torch.long, device=device)
            batch_chunks_pt = torch.zeros(
                (effective_batch_size, int(lengths_pt.max())),
                dtype=torch.long,
                device=device
            )
            for i, chunk_np in enumerate(batch_chunks_np):
                batch_chunks_pt[i, :len(chunk_np)] = torch.tensor(chunk_np, dtype=torch.long, device=device)

            cur_batch = batch_chunks_pt[:effective_batch_size]
            cur_lengths = lengths_pt[:effective_batch_size]
            with torch.no_grad():
                probs = predict_fn(cur_batch)

            # Shift predictions right by one so position i is scored with the model output
            # for position i-1; the first byte is scored with the unigram prior.
            first_prob = unigram_probs.expand(
                effective_batch_size, 1, -1)
            final_probs = torch.cat([first_prob, probs[:, :-1, :]], dim=1)

            split_points_mask = get_split_points(
                final_probs,
                cur_batch,
                cur_lengths,
                base_global_quantile,
                base_monotonic_quantile,
                debug,
            )
            all_split_point_masks.append(split_points_mask)
            start_idx = end_idx

        # Concatenate the per-batch masks, re-indexing chunk indices into the global order.
        split_point_chunk_idx_lst = []
        split_point_position_idx_lst = []
        processed_chunks = 0
        for mask in all_split_point_masks:
            split_point_chunk_idx, split_point_position_idx = mask.cpu().nonzero(as_tuple=True)
            split_point_chunk_idx_lst.append(split_point_chunk_idx + processed_chunks)
            split_point_position_idx_lst.append(split_point_position_idx)
            processed_chunks = processed_chunks + mask.shape[0]
        split_point_chunk_idx = torch.cat(split_point_chunk_idx_lst)
        split_point_position_idx = torch.cat(split_point_position_idx_lst)
    else:
        chunk_idx_map = None

        chunks = []
        chunk_to_sample_and_offset = {}
        chunk_idx = 0
        for sample_idx, data_bytes in enumerate(batched_bytes_data):
            logger.debug(f"Processing sample {sample_idx+1} (bytes: {len(data_bytes)})")

            if len(data_bytes) == 0:
                continue

            byte_len = len(data_bytes)

            for i in range(0, byte_len, chunk_size):
                chunk_start = i
                chunk_end = min(i + chunk_size, byte_len)
                chunk = data_bytes[chunk_start:chunk_end]
                chunks.append(chunk)
                chunk_to_sample_and_offset[chunk_idx] = (sample_idx, chunk_start)

                chunk_idx += 1

        all_split_point_masks = []

        # Reusable padded buffers; only the first `effective_batch_size` rows are read each step.
        batch_chunks_pt = torch.zeros(
            (max_m1_batch_size, chunk_size),
            dtype=torch.long,
            device=device
        )
        lengths_pt = torch.zeros(max_m1_batch_size, dtype=torch.long, device=device)
        num_chunks = len(chunks)

        for start_idx in range(0, num_chunks, max_m1_batch_size):
            end_idx = min(start_idx + max_m1_batch_size, num_chunks)
            batch_chunks = chunks[start_idx:end_idx]
            batch_chunks_np = [np.frombuffer(bytes(data), dtype=np.uint8) for data in batch_chunks]

            effective_batch_size = end_idx - start_idx

            for i, chunk_np in enumerate(batch_chunks_np):
                batch_chunks_pt[i, :len(chunk_np)] = torch.tensor(chunk_np, dtype=torch.long, device=device)
                lengths_pt[i] = len(chunk_np)

            cur_batch = batch_chunks_pt[:effective_batch_size]
            cur_lengths = lengths_pt[:effective_batch_size]
            with torch.no_grad():
                probs = predict_fn(cur_batch)

            # Shift predictions right by one, scoring the first byte with the unigram prior.
            first_prob = unigram_probs.expand(
                effective_batch_size, 1, -1)
            final_probs = torch.cat([first_prob, probs[:, :-1, :]], dim=1)

            split_points_mask = get_split_points(
                final_probs,
                cur_batch,
                cur_lengths,
                base_global_quantile,
                base_monotonic_quantile,
                debug,
            )
            all_split_point_masks.append(split_points_mask)

        all_split_point_masks = torch.cat(all_split_point_masks, dim=0)

        all_split_points_tuple = all_split_point_masks.nonzero(as_tuple=True)

        split_point_chunk_idx, split_point_position_idx = all_split_points_tuple[0].cpu(), all_split_points_tuple[1].cpu()

    # Map split points back from (chunk index, position within chunk) to absolute byte
    # offsets within each sample.
    sample_idx_to_split_positions = defaultdict(list)

    split_point_chunk_idx_np = split_point_chunk_idx.numpy()
    split_point_position_idx_np = split_point_position_idx.numpy()
    chunk_to_splits = defaultdict(list)

    for i in range(len(split_point_chunk_idx_np)):
        chunk_idx = split_point_chunk_idx_np[i]
        position = split_point_position_idx_np[i]
        chunk_to_splits[chunk_idx].append(position)

    for chunk_idx in range(num_chunks):
        chunk = chunks[chunk_idx]
        sample_idx, chunk_start = chunk_to_sample_and_offset[chunk_idx]
        if line_split:
            # Masks were computed over the length-sorted chunks, so translate the original
            # chunk index into its sorted position first.
            sorted_chunk_idx = chunk_idx_map[chunk_idx]
            split_points = chunk_to_splits[sorted_chunk_idx]
        else:
            split_points = chunk_to_splits[chunk_idx]

        if len(split_points) == 0:
            split_points = [0]
        if split_points[0] != 0:
            split_points.insert(0, 0)
        split_points.append(len(chunk))

        offset_split_points = [s + chunk_start for s in split_points]
        sample_idx_to_split_positions[sample_idx].extend(offset_split_points)

    sample_idx_to_split_positions = {k: sorted(v) for k, v in sample_idx_to_split_positions.items()}

    write_results = []
    min_window_size = 3

    if debug:
        extreme_compression_results = []
    for sample_idx, item in enumerate(batch):
        # Samples with empty text never produced chunks, so they simply get no windows.
        split_points = sample_idx_to_split_positions.get(sample_idx, [])
        split_windows_starts = []
        split_windows_lens = []
        cur_l = 0
        cur_r = 0
        for i in range(len(split_points) - 1):
            cur_l = split_points[i]
            cur_r = split_points[i + 1]
            if cur_r - cur_l >= min_window_size:
                split_windows_starts.append(cur_l)
                split_windows_lens.append(cur_r - cur_l)

        compressed_windows_starts_lens_b64 = compress_windows_starts_lens(split_windows_starts, split_windows_lens)
        result = {
            **item,
            "windows_starts_lens_b64": compressed_windows_starts_lens_b64
        }
        if debug:
            print_windows(item["text"], split_windows_starts, split_windows_lens, sample_idx=sample_idx)
            _debug_starts_lens = decompress_windows_starts_lens(compressed_windows_starts_lens_b64)
            _debug_starts, _debug_lens = _debug_starts_lens
            assert len(_debug_starts) == len(_debug_lens), f"Window starts and lens have different lengths: {len(_debug_starts)} != {len(_debug_lens)}"
            assert _debug_starts == split_windows_starts, f"Window starts do not match: {_debug_starts} != {split_windows_starts}"
            assert _debug_lens == split_windows_lens, f"Window lens do not match: {_debug_lens} != {split_windows_lens}"

            # Optimistic bound: count each compressed window as a single byte and every byte
            # outside the windows as raw.
            debug_sample = item["text"].encode('utf-8')
            raw_bytes = len(debug_sample) - sum(_debug_lens)
            compressed_bytes = len(_debug_starts)
            extreme_compression_rate = (compressed_bytes + raw_bytes) / len(debug_sample)
            extreme_compression_results.append(extreme_compression_rate)
            logger.info(f"[Extreme compression rate] for sample idx {sample_idx}: {extreme_compression_rate:.4f}")
            debug_byte_windows = unpack_windows(debug_sample, compressed_windows_starts_lens_b64)
            debug_bytes_windows, debug_indicators = zip(*debug_byte_windows)
            assert b"".join(debug_bytes_windows) == debug_sample, f"Debug bytes windows do not match: {b''.join(debug_bytes_windows)} != {debug_sample}"

            debug_split_points = sample_idx_to_split_positions.get(sample_idx, [])
            logger.info(f"Original byte length: {len(debug_sample)}")
            logger.info(f"num split_points: {len(debug_split_points)}")

            _debug_compressed_windows = [x[0] for x in debug_byte_windows if x[1]]
            _debug_sorted_compressed_windows = sorted(_debug_compressed_windows, key=lambda x: len(x), reverse=True)
            _debug_raw_windows = [x[0] for x in debug_byte_windows if not x[1]]
            _debug_sorted_raw_windows = sorted(_debug_raw_windows, key=lambda x: len(x), reverse=True)
            for i, byte_window in enumerate(_debug_sorted_compressed_windows):
                logger.info(f"compressed byte_window[{i}]: {byte_window}")
                if i > 10:
                    break
            for i, byte_window in enumerate(_debug_sorted_raw_windows):
                logger.info(f"raw byte_window[{i}]: {byte_window}")
                if i > 10:
                    break
        write_results.append(result)
    if debug:
        logger.info(f"[Extreme compression rate] for all samples: {np.mean(extreme_compression_results):.4f}")
    return write_results
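
# Minimal downstream sketch (illustrative; assumes the offline_utils API used above).
# Each output record keeps the original JSONL fields and adds "windows_starts_lens_b64",
# which can be decoded back into window starts and lengths:
#
#   with open("example_out_0.jsonl", "r", encoding="utf-8") as f:
#       record = json.loads(f.readline())
#   starts, lens = decompress_windows_starts_lens(record["windows_starts_lens_b64"])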

def writer_consumer(write_queue, output_file, buffer_size=100):
    """
    Writer consumer: reads compressed results from write_queue and writes them to file.
    Maintains its own buffer and writes when the buffer is full or a sentinel is received.
    """
    write_buf = []

    try:
        with open(output_file, 'w', encoding='utf-8') as f:
            while True:
                item = write_queue.get()
                if item is None:
                    # Sentinel: stop consuming and flush whatever is left.
                    break

                write_buf.extend(item)

                if len(write_buf) >= buffer_size:
                    logger.info(f"Writer: Dumping buffer of {len(write_buf)} items to {output_file}")
                    for buffered_item in write_buf:
                        f.write(json.dumps(buffered_item) + '\n')
                    f.flush()
                    write_buf = []

            if write_buf:
                logger.info(f"Writer: Dumping remaining {len(write_buf)} items to {output_file}")
                for buffered_item in write_buf:
                    f.write(json.dumps(buffered_item) + '\n')
                f.flush()

    except Exception as e:
        logger.error(f"Writer process error: {e}")
        raise

def main():
    parser = argparse.ArgumentParser(description='Process JSONL files using M1 arithmetic compression with buffer-based approach')
    parser.add_argument('--input_file', type=str, required=True,
                        help='Path to the input JSONL file')
    parser.add_argument('--output_dir', type=str, required=True,
                        help='Directory to write compressed results')
    parser.add_argument('--entropy_model_path', type=str, required=True,
                        help='Path to the M1 entropy model checkpoint')
    parser.add_argument('--compression_model_path', type=str, required=True,
                        help='Path to the M1 compression model checkpoint')
    parser.add_argument('--data_batch_size', type=int, default=512,
                        help='Number of samples per processing batch (default: 512)')
    parser.add_argument('--output_window_size', type=int, default=16,
                        help='Size of window for compression (default: 16)')
    parser.add_argument('--max_window_size', type=int, default=1024,
                        help='Maximum window size for reading from each file (default: 1024)')
    parser.add_argument('--max_entropy_batch_size', type=int, default=4096,
                        help='Maximum batch size for entropy calculation (default: 4096)')
    parser.add_argument('--max_compression_batch_size', type=int, default=4096,
                        help='Maximum batch size for compression (default: 4096)')
    parser.add_argument('--chunk_size', type=int, default=512,
                        help='Chunk size in bytes for fixed-size chunking (default: 512)')
    parser.add_argument('--base_global_quantile', type=float, default=0.9,
                        help='Base global quantile for split-point selection (default: 0.9)')
    parser.add_argument('--base_monotonic_quantile', type=float, default=0.9,
                        help='Base monotonic quantile for split-point selection (default: 0.9)')
    parser.add_argument('--apply_line_split', action='store_true', default=False,
                        help='Split samples on line boundaries instead of fixed-size chunks')
    parser.add_argument('--debug', action='store_true', default=False,
                        help='Debug mode (default: False)')
    parser.add_argument('--firstbyte_prob_path', type=str, default=None,
                        help='Path to a JSON file with the probability distribution of the first byte of each window (default: None)')
    parser.add_argument('--num_workers', type=int, default=1,
                        help='Number of workers for CPU jobs (default: 1)')
    parser.add_argument('--process_id', type=int, default=0,
                        help='Process ID for distributed processing (default: 0)')
    parser.add_argument('--num_processes', type=int, default=1,
                        help='Number of processes for distributed processing (default: 1)')
    args = parser.parse_args()

    mp.set_start_method('spawn', force=True)
    gc_freq = 100
    dump_freq = 25

    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    model, _, _ = load_m1_model_and_tokenizer(args.entropy_model_path)
    batched_predict_fn = batched_m1_compress_predict_fn(model)

    if args.firstbyte_prob_path is not None:
        with open(args.firstbyte_prob_path, 'r', encoding='utf-8') as f:
            first_byte_prob = json.load(f)
        logger.info(f"Loaded first-byte probabilities from {args.firstbyte_prob_path}")
        first_byte_prob = torch.tensor(first_byte_prob, dtype=torch.float32, device="cuda").unsqueeze(0).unsqueeze(0)
    else:
        # Uniform prior over the byte alphabet when no first-byte distribution is provided.
        first_byte_prob = torch.ones((1, 1, ALPHABET_SIZE), dtype=torch.float32, device="cuda") / ALPHABET_SIZE

    dataset = InterleavedJsonlDataset(
        file_path=args.input_file,
        rank=args.process_id,
        world_size=args.num_processes,
    )
    dataloader = DataLoader(
        dataset,
        batch_size=args.data_batch_size,
        shuffle=False,
        collate_fn=lambda x: x
    )

    input_file = Path(args.input_file)
    logger.info(f"Processing file: {input_file}")

    output_file = output_dir / f"{input_file.stem}_out_{args.process_id}.jsonl"

    logger.info("Data loaded. Start processing...")

    write_queue = mp.Queue(maxsize=200)
    writer_process = mp.Process(
        target=writer_consumer,
        args=(write_queue, output_file, dump_freq)
    )
    writer_process.start()

    try:
        for batch_idx, batch in enumerate(dataloader):
            split_points_results = calculate_entropy_and_split_points_fn(
                batch,
                batched_predict_fn,
                chunk_size=args.chunk_size,
                base_global_quantile=args.base_global_quantile,
                base_monotonic_quantile=args.base_monotonic_quantile,
                unigram_probs=first_byte_prob,
                max_m1_batch_size=args.max_entropy_batch_size,
                line_split=args.apply_line_split,
                debug=args.debug,
            )
            logger.info(f"Processed batch {batch_idx}")
            write_queue.put(split_points_results)

            if batch_idx % gc_freq == 0:
                gc.collect()
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()

        # Sentinel tells the writer process to flush its buffer and exit.
        write_queue.put(None)

    except Exception as e:
        logger.error(f"Error during processing: {e}")
        # Best effort: unblock the writer so it can exit before re-raising.
        try:
            write_queue.put(None)
        except Exception:
            pass
        raise
    finally:
        writer_process.join()
        if writer_process.exitcode != 0:
            logger.error(f"Writer process failed with exit code: {writer_process.exitcode}")

    logger.info(f"Completed processing successfully, output written to {output_file}")

if __name__ == "__main__":
    main()
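
# Example invocation (illustrative script name and paths; adjust to your setup):
#   python compute_split_points.py \
#       --input_file data/shard_00.jsonl \
#       --output_dir outputs/ \
#       --entropy_model_path checkpoints/m1_entropy.pt \
#       --compression_model_path checkpoints/m1_compression.pt \
#       --apply_line_split --debug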