# Diff-Refine / run_repair_flow.py
# Training + evaluation driver for the spherical-latent flow-matching code-repair model.
import torch
import torch.optim as optim
from transformers import AutoTokenizer
from tqdm import tqdm
from src.config import ModelConfig, TrainConfig
from src.models.autoencoder import SphericalAutoencoder
from src.models.dit import PatchedFlowDiT
from src.trainer import Trainer
from src.utils.data_utils import prepare_data
from src.utils.sandbox import SafeSandbox
from src.search import DiffuMCTS
def inference(ae, flow, src_ids, src_mask, device, steps=10):
    """Repair a buggy program with one pass of Euler-integrated flow inference.

    Encodes the source tokens into a latent z_0, then integrates the learned
    velocity field for `steps` uniform Euler steps, conditioning every step on
    the initial latent and re-projecting onto the unit hypersphere after each
    update. The final latent is decoded greedily to token ids.

    Returns the argmax token-id tensor produced by the autoencoder's decoder.
    """
    ae.eval()
    flow.eval()
    with torch.no_grad():
        # Encode the (buggy) source into the starting latent z_0.
        latent = ae.encode(src_ids, src_mask)
        cond = latent.clone()  # conditioning latent stays fixed across steps
        step_size = 1.0 / steps
        for step in range(steps):
            t_batch = torch.full((latent.shape[0],), step / steps, device=device)
            velocity = flow(latent, t_batch, condition=cond).float()
            latent = latent + velocity * step_size
            # Keep the trajectory on the unit sphere (spherical latent space).
            latent = torch.nn.functional.normalize(latent, p=2, dim=-1)
        logits = ae.decode(latent)
    return torch.argmax(logits, dim=-1)
def evaluate_on_humaneval(ae, flow, tokenizer, device, num_samples=20):
    """Run real-execution evaluation on HumanEvalPack.

    For each of the first `num_samples` problems: generate a repair with
    single-pass flow inference, execute the generated code against the task's
    unit tests inside a sandbox, and report pass@1.
    """
    print("\n>>> Starting Evaluation on HumanEvalPack (Real Execution)...")
    loader = prepare_data("humanevalpack", tokenizer, 512, 1, split="test")
    sandbox = SafeSandbox()
    passed = 0
    total = 0
    # Only evaluate the first num_samples problems to save time.
    for i, batch in enumerate(tqdm(loader, total=num_samples)):
        if i >= num_samples:
            break
        src = batch['src_ids'].to(device)
        mask = batch['src_mask'].to(device)
        test_code = batch['test_code'][0]
        entry_point = batch['entry_point'][0]
        # 1. Flow inference -> candidate repair.
        out_ids = inference(ae, flow, src, mask, device)
        gen_code = tokenizer.decode(out_ids[0], skip_special_tokens=True)
        # 2. Sandbox execution against the reference tests.
        is_pass, msg = sandbox.run(gen_code, test_code, entry_point)
        if is_pass:
            passed += 1
        total += 1
        # Print the first case for a quick qualitative look.
        if i == 0:
            print(f"\n[Case 0] Pass: {is_pass}")
            print(f"Error: {msg}")
            print(f"Generated:\n{gen_code[:200]}...")
    print(f"\n=== Eval Result ===")
    # BUGFIX: guard against total == 0 (empty loader or num_samples == 0),
    # which previously raised ZeroDivisionError.
    if total == 0:
        print("Pass@1: no samples evaluated")
    else:
        print(f"Pass@1: {passed}/{total} = {passed/total*100:.2f}%")
def evaluate_with_mcts(ae, flow, tokenizer, device, num_samples=20):
    """Evaluate on HumanEvalPack with Diffu-MCTS search on top of the flow.

    Each problem's buggy code is handed to the MCTS solver (which branches over
    multiple flow trajectories and verifies candidates in a sandbox); reports
    pass@1 with search.
    """
    print(f"\n>>> Starting Diffu-MCTS Evaluation (samples={num_samples})...")
    # 1. Data and execution components.
    loader = prepare_data("humanevalpack", tokenizer, 512, 1, split="test")
    sandbox = SafeSandbox(timeout=2.0)  # 2s timeout to kill infinite loops
    # 2. Search component.
    mcts = DiffuMCTS(ae, flow, tokenizer, sandbox, device, config=None)
    mcts.num_branches = 8  # branching factor K=8
    passed = 0
    total = 0
    # 3. Evaluation loop.
    for i, batch in enumerate(tqdm(loader, total=num_samples)):
        if i >= num_samples:
            break
        # MCTS tokenizes internally, so it needs raw text. The data loader only
        # yields tensors, so we recover the buggy source by decoding the ids
        # (alternatively, prepare_data could be changed to return raw text).
        src_ids = batch['src_ids'].to(device)
        buggy_code = tokenizer.decode(src_ids[0], skip_special_tokens=True)
        test_code = batch['test_code'][0]
        entry_point = batch['entry_point'][0]
        # --- Run the MCTS solver ---
        fixed_code, is_success = mcts.solve(buggy_code, test_code, entry_point)
        if is_success:
            passed += 1
        total += 1
        # Log the first sample.
        if i == 0:
            print(f"\n[Case 0]")
            print(f"Buggy:\n{buggy_code[:100]}...")
            print(f"Fixed:\n{fixed_code[:100]}...")
            print(f"Result: {'✅ PASS' if is_success else '❌ FAIL'}")
    print(f"\n=== MCTS Results ===")
    # BUGFIX: guard against total == 0 (empty loader or num_samples == 0),
    # which previously raised ZeroDivisionError.
    if total == 0:
        print(f"Pass@1 (with Search K={mcts.num_branches}): no samples evaluated")
    else:
        print(f"Pass@1 (with Search K={mcts.num_branches}): {passed}/{total} = {passed/total*100:.2f}%")
def main():
    """End-to-end pipeline: train the AE and the flow on CodeXGLUE, then evaluate."""
    m_cfg = ModelConfig()
    t_cfg = TrainConfig(batch_size=8, grad_accum_steps=4)

    tokenizer = AutoTokenizer.from_pretrained(m_cfg.encoder_name, trust_remote_code=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # 1. CodeXGLUE training data.
    train_loader = prepare_data(
        "codexglue", tokenizer, m_cfg.max_seq_len, t_cfg.batch_size, split="train"
    )

    ae = SphericalAutoencoder(m_cfg).to(t_cfg.device).float()
    # Make sure the encoder's config carries a pad token id.
    if ae.encoder.config.pad_token_id is None:
        ae.encoder.config.pad_token_id = tokenizer.pad_token_id
    flow = PatchedFlowDiT(m_cfg).to(t_cfg.device).float()
    trainer = Trainer(ae, flow, t_cfg, train_loader)

    # --- Stage 1: autoencoder training ---
    print("\n>>> Training AE on CodeXGLUE...")
    opt_ae = optim.AdamW(filter(lambda p: p.requires_grad, ae.parameters()), lr=t_cfg.lr_ae)
    for epoch in range(t_cfg.num_epochs_ae):
        loss = trainer.train_ae(opt_ae)
        print(f"AE Epoch {epoch}: Loss {loss:.4f}")

    # --- Stage 2: flow-matching training ---
    print("\n>>> Training Flow Matching on CodeXGLUE...")
    opt_flow = optim.AdamW(flow.parameters(), lr=t_cfg.lr_flow)
    for epoch in range(t_cfg.num_epochs_flow):
        loss = trainer.train_flow(opt_flow)
        print(f"Flow Epoch {epoch}: Loss {loss:.4f}")

    # --- Evaluation: plain flow inference, then MCTS-boosted search ---
    evaluate_on_humaneval(ae, flow, tokenizer, t_cfg.device)
    evaluate_with_mcts(ae, flow, tokenizer, t_cfg.device, num_samples=50)
# Script entry point: run the full train + evaluate pipeline.
if __name__ == "__main__":
    main()