# (metadata header, preserved as comments) file size: 5,877 bytes, revision 77d636f
import torch
import torch.optim as optim
from transformers import AutoTokenizer
from tqdm import tqdm
from src.config import ModelConfig, TrainConfig
from src.models.autoencoder import SphericalAutoencoder
from src.models.dit import PatchedFlowDiT
from src.trainer import Trainer
from src.utils.data_utils import prepare_data
from src.utils.sandbox import SafeSandbox
from src.search import DiffuMCTS
def inference(ae, flow, src_ids, src_mask, device, steps=10):
    """Repair a buggy sample by integrating the learned flow on the sphere.

    Encodes the source tokens into an initial latent, runs `steps` explicit
    Euler updates of the velocity field (conditioned on that initial latent),
    re-projecting onto the unit sphere after every step, then decodes and
    returns the argmax token ids.

    Args:
        ae: autoencoder exposing ``encode``/``decode`` (and ``eval``).
        flow: velocity-field model called as ``flow(z, t, condition=...)``.
        src_ids: source (buggy) token ids tensor.
        src_mask: attention mask for ``src_ids``.
        device: torch device the time tensor is created on.
        steps: number of Euler integration steps.

    Returns:
        Long tensor of predicted token ids (argmax over decoder logits).
    """
    ae.eval()
    flow.eval()
    with torch.no_grad():
        # Encode source (buggy) tokens -> z_0; keep a copy as the condition.
        latent = ae.encode(src_ids, src_mask)
        condition = latent.clone()
        step_size = 1.0 / steps
        for step in range(steps):
            t_vec = torch.full((latent.shape[0],), step / steps, device=device)
            velocity = flow(latent, t_vec, condition=condition).float()
            # Euler step followed by re-projection onto the unit sphere.
            latent = torch.nn.functional.normalize(
                latent + velocity * step_size, p=2, dim=-1
            )
        logits = ae.decode(latent)
    return torch.argmax(logits, dim=-1)
def evaluate_on_humaneval(ae, flow, tokenizer, device, num_samples=20):
    """Real-execution evaluation on HumanEvalPack.

    For each sample, flow-based inference produces a candidate fix which is
    executed against the task's unit tests inside a sandbox. Prints Pass@1.

    Args:
        ae: trained autoencoder (provides encode/decode for `inference`).
        flow: trained flow-matching model.
        tokenizer: tokenizer used to decode generated token ids.
        device: torch device for inference.
        num_samples: number of test samples to evaluate (caps runtime).
    """
    print("\n>>> Starting Evaluation on HumanEvalPack (Real Execution)...")
    loader = prepare_data("humanevalpack", tokenizer, 512, 1, split="test")
    sandbox = SafeSandbox()
    passed = 0
    total = 0
    # Only evaluate the first num_samples items to save time.
    for i, batch in enumerate(tqdm(loader, total=num_samples)):
        if i >= num_samples:
            break
        src = batch['src_ids'].to(device)
        mask = batch['src_mask'].to(device)
        test_code = batch['test_code'][0]
        entry_point = batch['entry_point'][0]
        # 1. Flow inference: decode the repaired latent back to token ids.
        out_ids = inference(ae, flow, src, mask, device)
        gen_code = tokenizer.decode(out_ids[0], skip_special_tokens=True)
        # 2. Sandbox execution against the task's tests.
        is_pass, msg = sandbox.run(gen_code, test_code, entry_point)
        if is_pass:
            passed += 1
        total += 1
        # Print the first case for a quick qualitative check.
        if i == 0:
            print(f"\n[Case 0] Pass: {is_pass}")
            print(f"Error: {msg}")
            print(f"Generated:\n{gen_code[:200]}...")
    print(f"\n=== Eval Result ===")
    # Fix: guard against an empty loader / num_samples == 0, which previously
    # raised ZeroDivisionError here.
    if total > 0:
        print(f"Pass@1: {passed}/{total} = {passed/total*100:.2f}%")
    else:
        print("Pass@1: no samples evaluated")
def evaluate_with_mcts(ae, flow, tokenizer, device, num_samples=20):
    """Evaluate repair quality with Diffu-MCTS search augmentation.

    Runs the MCTS searcher per sample (which verifies candidates through the
    sandbox internally) and prints Pass@1 with search branching K.

    Args:
        ae: trained autoencoder.
        flow: trained flow-matching model.
        tokenizer: tokenizer used to recover raw source text from ids.
        device: torch device for model calls.
        num_samples: number of test samples to evaluate.
    """
    print(f"\n>>> Starting Diffu-MCTS Evaluation (samples={num_samples})...")
    # 1. Prepare data and components.
    loader = prepare_data("humanevalpack", tokenizer, 512, 1, split="test")
    sandbox = SafeSandbox(timeout=2.0)  # 2s timeout guards against infinite loops
    # 2. Initialize the searcher.
    mcts = DiffuMCTS(ae, flow, tokenizer, sandbox, device, config=None)
    mcts.num_branches = 8  # branching factor K=8
    passed = 0
    total = 0
    # 3. Evaluation loop.
    for i, batch in enumerate(tqdm(loader, total=num_samples)):
        if i >= num_samples:
            break
        # MCTS tokenizes internally and needs the raw string, but the loader
        # only yields tensors, so recover the buggy source by decoding the
        # ids. (Alternatively prepare_data could return the raw text too.)
        src_ids = batch['src_ids'].to(device)
        buggy_code = tokenizer.decode(src_ids[0], skip_special_tokens=True)
        test_code = batch['test_code'][0]
        entry_point = batch['entry_point'][0]
        # --- Run the search ---
        fixed_code, is_success = mcts.solve(buggy_code, test_code, entry_point)
        if is_success:
            passed += 1
        total += 1
        # Log the first sample.
        if i == 0:
            print(f"\n[Case 0]")
            print(f"Buggy:\n{buggy_code[:100]}...")
            print(f"Fixed:\n{fixed_code[:100]}...")
            print(f"Result: {'✅ PASS' if is_success else '❌ FAIL'}")
    print(f"\n=== MCTS Results ===")
    # Fix: guard against an empty loader / num_samples == 0, which previously
    # raised ZeroDivisionError here.
    if total > 0:
        print(f"Pass@1 (with Search K={mcts.num_branches}): {passed}/{total} = {passed/total*100:.2f}%")
    else:
        print("Pass@1: no samples evaluated")
def main():
    """End-to-end pipeline: train AE + flow on CodeXGLUE, then evaluate.

    Steps:
        1. Train the spherical autoencoder (reconstruction).
        2. Train the flow-matching DiT in latent space.
        3. Evaluate with plain flow inference, then with Diffu-MCTS search.
    """
    m_cfg = ModelConfig()
    t_cfg = TrainConfig(batch_size=8, grad_accum_steps=4)
    tokenizer = AutoTokenizer.from_pretrained(m_cfg.encoder_name, trust_remote_code=True)
    # Some code tokenizers ship without a pad token; reuse EOS so padding works.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    # 1. Load CodeXGLUE for training.
    train_loader = prepare_data("codexglue", tokenizer, m_cfg.max_seq_len, t_cfg.batch_size, split="train")
    ae = SphericalAutoencoder(m_cfg).to(t_cfg.device).float()
    # Patch pad token id onto the encoder config if it is missing.
    if ae.encoder.config.pad_token_id is None:
        ae.encoder.config.pad_token_id = tokenizer.pad_token_id
    flow = PatchedFlowDiT(m_cfg).to(t_cfg.device).float()
    trainer = Trainer(ae, flow, t_cfg, train_loader)
    # --- Training Loop ---
    # Step 1: Train AE; only optimize parameters left unfrozen.
    print("\n>>> Training AE on CodeXGLUE...")
    opt_ae = optim.AdamW(filter(lambda p: p.requires_grad, ae.parameters()), lr=t_cfg.lr_ae)
    for epoch in range(t_cfg.num_epochs_ae):
        loss = trainer.train_ae(opt_ae)
        print(f"AE Epoch {epoch}: Loss {loss:.4f}")
    # Step 2: Train flow matching.
    print("\n>>> Training Flow Matching on CodeXGLUE...")
    opt_flow = optim.AdamW(flow.parameters(), lr=t_cfg.lr_flow)
    for epoch in range(t_cfg.num_epochs_flow):
        loss = trainer.train_flow(opt_flow)
        print(f"Flow Epoch {epoch}: Loss {loss:.4f}")
    # --- Evaluation ---
    evaluate_on_humaneval(ae, flow, tokenizer, t_cfg.device)
    # After training, run the MCTS-augmented evaluation.
    evaluate_with_mcts(ae, flow, tokenizer, t_cfg.device, num_samples=50)


if __name__ == "__main__":
    # Fix: removed the stray trailing '|' token after main(), which was a
    # syntax error.
    main()