import torch
import torch.optim as optim
from transformers import AutoTokenizer
from tqdm import tqdm

from src.config import ModelConfig, TrainConfig
from src.models.autoencoder import SphericalAutoencoder
from src.models.dit import PatchedFlowDiT
from src.trainer import Trainer
from src.utils.data_utils import prepare_data
from src.utils.sandbox import SafeSandbox
from src.search import DiffuMCTS


def inference(ae, flow, src_ids, src_mask, device, steps=10):
    ae.eval()
    flow.eval()
    with torch.no_grad():
        # Encode source (buggy code) -> z_0
        z_curr = ae.encode(src_ids, src_mask)
        z_cond = z_curr.clone()
        # Euler integration of the velocity field from t=0 to t=1,
        # re-projecting onto the unit sphere after each step.
        dt = 1.0 / steps
        for i in range(steps):
            t = torch.ones(z_curr.shape[0], device=device) * (i / steps)
            v = flow(z_curr, t, condition=z_cond).float()
            z_curr = z_curr + v * dt
            z_curr = torch.nn.functional.normalize(z_curr, p=2, dim=-1)
        logits = ae.decode(z_curr)
        return torch.argmax(logits, dim=-1)


def evaluate_on_humaneval(ae, flow, tokenizer, device, num_samples=20):
    """Run real execution tests on HumanEvalPack."""
    print("\n>>> Starting Evaluation on HumanEvalPack (Real Execution)...")
    loader = prepare_data("humanevalpack", tokenizer, 512, 1, split="test")
    sandbox = SafeSandbox()

    passed = 0
    total = 0

    # Only test the first num_samples cases to save time.
    for i, batch in enumerate(tqdm(loader, total=num_samples)):
        if i >= num_samples:
            break

        src = batch['src_ids'].to(device)
        mask = batch['src_mask'].to(device)
        test_code = batch['test_code'][0]
        entry_point = batch['entry_point'][0]

        # 1. Flow inference
        out_ids = inference(ae, flow, src, mask, device)
        gen_code = tokenizer.decode(out_ids[0], skip_special_tokens=True)

        # 2. Sandbox execution
        is_pass, msg = sandbox.run(gen_code, test_code, entry_point)
        if is_pass:
            passed += 1
        total += 1

        # Print the first case to inspect output quality.
        if i == 0:
            print(f"\n[Case 0] Pass: {is_pass}")
            print(f"Sandbox message: {msg}")
            print(f"Generated:\n{gen_code[:200]}...")

    print("\n=== Eval Result ===")
    # Guard against an empty loader to avoid division by zero.
    print(f"Pass@1: {passed}/{total} = {passed / max(total, 1) * 100:.2f}%")


def evaluate_with_mcts(ae, flow, tokenizer, device, num_samples=20):
    """Search-augmented evaluation using Diffu-MCTS."""
    print(f"\n>>> Starting Diffu-MCTS Evaluation (samples={num_samples})...")

    # 1. Prepare data and components
    loader = prepare_data("humanevalpack", tokenizer, 512, 1, split="test")
    sandbox = SafeSandbox(timeout=2.0)  # 2 s timeout guards against infinite loops

    # 2. Initialize the searcher
    mcts = DiffuMCTS(ae, flow, tokenizer, sandbox, device, config=None)
    mcts.num_branches = 8  # branching factor K=8

    passed = 0
    total = 0

    # 3. Evaluation loop
    for i, batch in enumerate(tqdm(loader, total=num_samples)):
        if i >= num_samples:
            break

        # Extract the raw text (MCTS handles tokenization internally).
        # batch['src_ids'] is a tensor, but we need the original string; the
        # data loader discards it, so we either decode it back here or change
        # prepare_data to return raw text. For simplicity, decode the buggy
        # code back from its token IDs.
        src_ids = batch['src_ids'].to(device)
        buggy_code = tokenizer.decode(src_ids[0], skip_special_tokens=True)
        test_code = batch['test_code'][0]
        entry_point = batch['entry_point'][0]

        # --- Invoke MCTS ---
        fixed_code, is_success = mcts.solve(buggy_code, test_code, entry_point)

        if is_success:
            passed += 1
        total += 1

        # Log the first sample
        if i == 0:
            print("\n[Case 0]")
            print(f"Buggy:\n{buggy_code[:100]}...")
            print(f"Fixed:\n{fixed_code[:100]}...")
            print(f"Result: {'✅ PASS' if is_success else '❌ FAIL'}")

    print("\n=== MCTS Results ===")
    print(f"Pass@1 (with Search K={mcts.num_branches}): "
          f"{passed}/{total} = {passed / max(total, 1) * 100:.2f}%")
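# ----------------------------------------------------------------------
# Entry point. Two-stage recipe: train the spherical autoencoder for
# reconstruction first, then train the flow-matching DiT in its latent
# space, and finally run both evaluation modes defined above (plain flow
# inference, then Diffu-MCTS search).
# ----------------------------------------------------------------------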
def main():
    m_cfg = ModelConfig()
    t_cfg = TrainConfig(batch_size=8, grad_accum_steps=4)

    tokenizer = AutoTokenizer.from_pretrained(m_cfg.encoder_name, trust_remote_code=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # 1. Load CodeXGLUE for training
    train_loader = prepare_data("codexglue", tokenizer, m_cfg.max_seq_len,
                                t_cfg.batch_size, split="train")

    ae = SphericalAutoencoder(m_cfg).to(t_cfg.device).float()
    # Patch pad token
    if ae.encoder.config.pad_token_id is None:
        ae.encoder.config.pad_token_id = tokenizer.pad_token_id

    flow = PatchedFlowDiT(m_cfg).to(t_cfg.device).float()
    trainer = Trainer(ae, flow, t_cfg, train_loader)

    # --- Training Loop ---
    # Step 1: Train the autoencoder
    print("\n>>> Training AE on CodeXGLUE...")
    opt_ae = optim.AdamW(filter(lambda p: p.requires_grad, ae.parameters()), lr=t_cfg.lr_ae)
    for epoch in range(t_cfg.num_epochs_ae):
        loss = trainer.train_ae(opt_ae)
        print(f"AE Epoch {epoch}: Loss {loss:.4f}")

    # Step 2: Train the flow model
    print("\n>>> Training Flow Matching on CodeXGLUE...")
    opt_flow = optim.AdamW(flow.parameters(), lr=t_cfg.lr_flow)
    for epoch in range(t_cfg.num_epochs_flow):
        loss = trainer.train_flow(opt_flow)
        print(f"Flow Epoch {epoch}: Loss {loss:.4f}")

    # --- Evaluation ---
    evaluate_on_humaneval(ae, flow, tokenizer, t_cfg.device)

    # After training completes, run the search-augmented MCTS evaluation.
    evaluate_with_mcts(ae, flow, tokenizer, t_cfg.device, num_samples=50)


if __name__ == "__main__":
    main()
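# NOTE: the `src.*` imports assume this script is launched from the
# repository root (e.g. `python <this_file>.py`, whatever name this file
# is saved under) so that the `src/` package is importable.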