import torch
import torch.optim as optim
from transformers import AutoTokenizer
from tqdm import tqdm
import torch.nn.functional as F
import os
import argparse
import sacrebleu

from src.config import ModelConfig, TrainConfig
from src.models.autoencoder import ReshapedAutoencoder
from src.models.dit import PatchedFlowDiT
from src.trainer import Trainer
from src.utils.data_utils import prepare_data


# --- Helper Functions for Inference (copied here so this script runs standalone) ---
def _pick_stop_id(tokenizer):
    """Return the tokenizer's EOS token id, or its SEP token id when EOS is unset.

    The result is ``None`` when the tokenizer defines neither token.
    """
    if tokenizer.eos_token_id is not None:
        return tokenizer.eos_token_id
    return tokenizer.sep_token_id


def _first_pos(x_1d, token_id, default):
    """Return the index of the first element of ``x_1d`` equal to ``token_id``.

    Args:
        x_1d: 1-D tensor of token ids.
        token_id: Id to search for; ``None`` means "no stop token available".
        default: Value returned when ``token_id`` is ``None`` or never occurs.
    """
    # Comparing a tensor with None yields a plain bool, which has no
    # ``nonzero``; treat a missing stop id as "not found" instead of crashing.
    if token_id is None:
        return default
    idx = (x_1d == token_id).nonzero(as_tuple=True)[0]
    return idx[0].item() if idx.numel() > 0 else default


def _flatten_for_tsv(text):
    """Replace tab, carriage-return and newline characters with single spaces."""
    return text.replace("\t", " ").replace("\r", " ").replace("\n", " ")


def calculate_metrics(sources, predictions, references):
    """Compute corpus-level metrics for generated text.

    Args:
        sources: Source strings, aligned with ``predictions``.
        predictions: Generated strings.
        references: One reference string per prediction.

    Returns:
        A dict with the keys ``"SARI"``, ``"BLEU"`` and ``"Compression Ratio"``
        (mean of ``len(prediction) / len(source)`` measured in characters).
    """
    bleu = sacrebleu.corpus_bleu(predictions, [references])

    # NOTE(review): sacrebleu does not appear to expose ``corpus_sari``; look it
    # up defensively and say so, rather than silently reporting 0.0.
    sari_score = 0.0
    corpus_sari = getattr(sacrebleu, "corpus_sari", None)
    if corpus_sari is None:
        print("Warning: sacrebleu.corpus_sari is unavailable; reporting SARI as 0.0.")
    else:
        try:
            sari_score = corpus_sari(sources, predictions, [references]).score
        except Exception as exc:  # best-effort metric: report the failure, keep going
            print(f"Warning: SARI computation failed ({exc!r}); reporting SARI as 0.0.")

    ratios = [len(p) / len(s) if s else 0 for p, s in zip(predictions, sources)]
    avg_ratio = sum(ratios) / len(ratios) if ratios else 0
    return {"SARI": sari_score, "BLEU": bleu.score, "Compression Ratio": avg_ratio}


@torch.no_grad()
def inference_batch(ae, flow, loader, tokenizer, device, steps=10, save_path="results.txt", use_oneshot=True):
    """Run the flow model over ``loader``, decode with ``ae`` and write a TSV file.

    Args:
        ae: Autoencoder providing ``encode(ids, mask)`` and
            ``decode(z, attention_mask=...)``.
        flow: Flow model called as ``flow(z, t, condition=z_cond)``.
        loader: Data loader yielding dicts with ``src_ids``, ``src_mask`` and
            ``tgt_ids`` tensors.
        tokenizer: Tokenizer used to turn ids back into text.
        device: Device the batch tensors are moved to.
        steps: Number of integration steps when ``use_oneshot`` is False.
        save_path: Output path of the tab-separated results file.
        use_oneshot: When True, call the flow model once at ``t = 0``.

    Returns:
        Three lists of strings: sources, targets and generated texts.

    Raises:
        ValueError: If ``use_oneshot`` is False and ``steps`` is less than 1.
    """
    if not use_oneshot and steps < 1:
        # Otherwise this surfaces as a ZeroDivisionError (steps == 0) or an
        # unbound ``pred_z1`` (steps < 0) in the middle of the first batch.
        raise ValueError(f"steps must be >= 1 for multi-step sampling, got {steps}")

    ae.eval()
    flow.eval()

    stop_id = _pick_stop_id(tokenizer)
    pad_id = tokenizer.pad_token_id

    print(f"\n>>> Running Inference on {len(loader.dataset)} examples...")

    all_sources, all_targets, all_generated = [], [], []

    with open(save_path, "w", encoding="utf-8") as f:
        f.write("Source\tTarget\tGenerated\n")

        for batch in tqdm(loader, desc="Inferencing"):
            src_ids = batch['src_ids'].to(device)
            src_mask = batch['src_mask'].to(device)
            tgt_ids = batch['tgt_ids'].to(device)
            B, L = src_ids.shape

            # Encode; the source latent is both the start point and the condition.
            z_curr = ae.encode(src_ids, src_mask)
            z_cond = z_curr.clone()

            # Flow Sampling
            if use_oneshot:
                t0 = torch.zeros(B, device=device)
                z_curr = flow(z_curr, t0, condition=z_cond).float()
            else:
                dt = 1.0 / steps
                for i in range(steps):
                    t_val = i / steps
                    if t_val >= 0.999:
                        break
                    t = torch.ones(B, device=device) * t_val
                    # The model output is used as a prediction of the endpoint;
                    # the velocity is derived from it (epsilon avoids a zero
                    # denominator as t approaches 1).
                    pred_z1 = flow(z_curr, t, condition=z_cond).float()
                    v = (pred_z1 - z_curr) / (1.0 - t_val + 1e-4)
                    z_curr = z_curr + v * dt
                # Finish on the last predicted endpoint rather than the
                # integrated state.
                z_curr = pred_z1

            # Decode (Pass 1: Detect Length)
            full_mask = torch.ones(B, L, device=device)
            logits1 = ae.decode(z_curr, attention_mask=full_mask)
            ids1 = logits1.argmax(dim=-1)
            stop_pos = [_first_pos(ids1[i], stop_id, default=L - 1) for i in range(B)]

            # Decode (Pass 2: Clean Decode) with everything after the stop
            # token masked out.
            gen_mask = torch.zeros(B, L, device=device)
            for i, pos in enumerate(stop_pos):
                gen_mask[i, : pos + 1] = 1.0
            logits2 = ae.decode(z_curr, attention_mask=gen_mask)
            ids2 = logits2.argmax(dim=-1).masked_fill(gen_mask == 0, pad_id)

            # Convert to Text
            src_texts = tokenizer.batch_decode(src_ids, skip_special_tokens=True)
            tgt_texts = tokenizer.batch_decode(tgt_ids, skip_special_tokens=True)
            gen_texts = [
                tokenizer.decode(ids2[i, : pos + 1], skip_special_tokens=True)
                for i, pos in enumerate(stop_pos)
            ]

            for s, t, g in zip(src_texts, tgt_texts, gen_texts):
                # Tabs and line breaks inside a field would corrupt the TSV row.
                s_c = _flatten_for_tsv(s)
                t_c = _flatten_for_tsv(t)
                g_c = _flatten_for_tsv(g)
                f.write(f"{s_c}\t{t_c}\t{g_c}\n")
                all_sources.append(s_c)
                all_targets.append(t_c)
                all_generated.append(g_c)

    return all_sources, all_targets, all_generated


def main():
    """Train the flow model on top of a frozen autoencoder, then evaluate it."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--ae_ckpt", type=str, default="/mnt/hdfs/user/lixinyu.222/CodeFlow/residual_robust_checkpoints/ae_best.pt", help="Path to pre-trained AE checkpoint")
    parser.add_argument("--save_dir", type=str, default="residual_robust_checkpoints", help="Directory to save flow checkpoints")
    parser.add_argument("--use_oneshot", dest="use_oneshot", action="store_true", default=True, help="Use one-shot sampling for inference")
    # ``store_true`` with ``default=True`` can never be switched off, so give
    # the flag an explicit negative counterpart.
    parser.add_argument("--no_oneshot", dest="use_oneshot", action="store_false", help="Use multi-step sampling for inference instead of one-shot")
    parser.add_argument("--steps", type=int, default=10, help="Number of sampling steps when one-shot sampling is disabled")
    args = parser.parse_args()

    os.makedirs(args.save_dir, exist_ok=True)

    # --- Config ---
    m_cfg = ModelConfig(
        encoder_name='../jina-embeddings-v2-base-code',
        latent_dim=512,
        max_seq_len=128
    )
    t_cfg = TrainConfig(
        batch_size=16,
        num_epochs_flow=35,  # only the Flow epochs matter in this script
        grad_accum_steps=4,
        use_amp=False,
        lr_flow=2e-4
    )

    # --- Tokenizer & Data ---
    tokenizer = AutoTokenizer.from_pretrained(m_cfg.encoder_name, local_files_only=True, trust_remote_code=False)
    train_loader = prepare_data("wiki", tokenizer, m_cfg.max_seq_len, t_cfg.batch_size, split="train")
    test_loader = prepare_data("wiki", tokenizer, m_cfg.max_seq_len, t_cfg.batch_size, split="test")

    # --- Load AE (Pre-trained) ---
    print(f"\n>>> Loading Pre-trained Autoencoder from {args.ae_ckpt} ...")
    # Check before building the model so a bad path fails early.
    if not os.path.exists(args.ae_ckpt):
        raise FileNotFoundError(f"AE checkpoint not found at {args.ae_ckpt}. Please run train_ae.py first.")
    ae = ReshapedAutoencoder(m_cfg).to(t_cfg.device).float()
    # NOTE: torch.load unpickles the file; only load checkpoints from trusted sources.
    ae.load_state_dict(torch.load(args.ae_ckpt, map_location=t_cfg.device))

    # Freeze every AE parameter; the AE is not updated while the Flow trains.
    ae.eval()
    for param in ae.parameters():
        param.requires_grad = False
    print(">>> Autoencoder loaded and frozen.")

    if ae.encoder.config.pad_token_id is None:
        ae.encoder.config.pad_token_id = tokenizer.pad_token_id

    # --- Initialize Flow ---
    flow = PatchedFlowDiT(m_cfg).to(t_cfg.device).float()

    # --- Trainer ---
    trainer = Trainer(
        ae=ae,
        flow=flow,
        cfg=t_cfg,
        loader=train_loader,
        pad_id=tokenizer.pad_token_id,
        stop_id=_pick_stop_id(tokenizer)
    )

    # --- Optimizer ---
    opt_flow = optim.AdamW(flow.parameters(), lr=t_cfg.lr_flow)

    # --- Training Loop ---
    best_flow_loss = float('inf')
    best_flow_path = os.path.join(args.save_dir, "flow_best.pt")
    last_flow_path = os.path.join(args.save_dir, "flow_last.pt")

    print("\n>>> Start Training Flow DiT...")
    for epoch in range(t_cfg.num_epochs_flow):
        # Train the Flow for one epoch using its own optimizer.
        loss = trainer.train_flow(opt_flow)
        print(f"Flow Epoch {epoch}: Loss {loss:.4f}")

        # Save Best
        if loss < best_flow_loss:
            best_flow_loss = loss
            torch.save(flow.state_dict(), best_flow_path)

        # Save Last
        torch.save(flow.state_dict(), last_flow_path)

    print(f"Flow Training Done. Best Loss: {best_flow_loss:.4f}")

    # --- Inference / Evaluation ---
    print("\n>>> Loading Best Flow Checkpoint for Evaluation...")
    if os.path.exists(best_flow_path):
        flow.load_state_dict(torch.load(best_flow_path, map_location=t_cfg.device))
    else:
        print("Warning: Best checkpoint not found, utilizing last epoch weights.")

    print("\n--- Starting Inference ---")
    results_path = "wiki_results.tsv"
    sources, targets, gens = inference_batch(
        ae, flow, test_loader, tokenizer, t_cfg.device,
        steps=args.steps,
        save_path=results_path,
        use_oneshot=args.use_oneshot
    )

    # Metrics
    metrics = calculate_metrics(sources, gens, targets)
    print("\n=== Metrics ===")
    for k, v in metrics.items():
        print(f"{k}: {v:.4f}")

    print(f"\nResults saved to {results_path}")


if __name__ == "__main__":
    main()