|
|
import torch |
|
|
import torch.optim as optim |
|
|
from transformers import AutoTokenizer |
|
|
from tqdm import tqdm |
|
|
|
|
|
from src.config import ModelConfig, TrainConfig |
|
|
from src.models.autoencoder import SphericalAutoencoder |
|
|
from src.models.dit import PatchedFlowDiT |
|
|
from src.trainer import Trainer |
|
|
from src.utils.data_utils import prepare_data |
|
|
from src.utils.sandbox import SafeSandbox |
|
|
from src.search import DiffuMCTS |
|
|
|
|
|
def inference(ae, flow, src_ids, src_mask, device, steps=10):
    """Denoise a latent via Euler integration of the flow field and decode tokens.

    Encodes the source tokens, integrates the learned velocity field from
    t=0 to t=1 in `steps` fixed Euler steps (conditioned on the initial
    latent), projects the result back onto the unit sphere, and returns the
    argmax token ids from the decoder logits.

    Args:
        ae: autoencoder exposing `encode(ids, mask)` and `decode(z)`.
        flow: velocity model called as `flow(z, t, condition=...)`.
        src_ids: source token id tensor.
        src_mask: attention mask for `src_ids`.
        device: device on which the timestep tensor is created.
        steps: number of Euler integration steps.

    Returns:
        Tensor of predicted token ids (argmax over the decoder's vocab dim).
    """
    ae.eval()
    flow.eval()
    with torch.no_grad():
        latent = ae.encode(src_ids, src_mask)
        # Condition on the (fixed) initial encoding throughout integration.
        condition = latent.clone()

        step_size = 1.0 / steps
        for step_idx in range(steps):
            timestep = torch.full(
                (latent.shape[0],), step_idx / steps, device=device
            )
            velocity = flow(latent, timestep, condition=condition).float()
            latent = latent + velocity * step_size

        # Project back onto the unit hypersphere before decoding.
        latent = torch.nn.functional.normalize(latent, p=2, dim=-1)
        logits = ae.decode(latent)
        return torch.argmax(logits, dim=-1)
|
|
|
|
|
def evaluate_on_humaneval(ae, flow, tokenizer, device, num_samples=20):
    """
    Run real-execution evaluation on the HumanEvalPack test split.

    Each sample is denoised via `inference`, decoded back to text, and
    executed in a sandbox against the dataset's test code.

    Args:
        ae: trained autoencoder used by `inference`.
        flow: trained flow model used by `inference`.
        tokenizer: tokenizer for decoding generated token ids.
        device: device to move the input tensors to.
        num_samples: maximum number of test samples to evaluate.

    Returns:
        float: pass@1 rate in [0, 1]; 0.0 when no samples were evaluated.
    """
    print("\n>>> Starting Evaluation on HumanEvalPack (Real Execution)...")
    loader = prepare_data("humanevalpack", tokenizer, 512, 1, split="test")
    sandbox = SafeSandbox()

    passed = 0
    total = 0

    for i, batch in enumerate(tqdm(loader, total=num_samples)):
        if i >= num_samples:
            break

        src = batch['src_ids'].to(device)
        mask = batch['src_mask'].to(device)
        # Batch size is 1, so index 0 selects the sample's metadata.
        test_code = batch['test_code'][0]
        entry_point = batch['entry_point'][0]

        out_ids = inference(ae, flow, src, mask, device)
        gen_code = tokenizer.decode(out_ids[0], skip_special_tokens=True)

        is_pass, msg = sandbox.run(gen_code, test_code, entry_point)

        if is_pass:
            passed += 1
        total += 1

        # Print one concrete example for quick manual inspection.
        if i == 0:
            print(f"\n[Case 0] Pass: {is_pass}")
            print(f"Error: {msg}")
            print(f"Generated:\n{gen_code[:200]}...")

    print(f"\n=== Eval Result ===")
    # Guard against division by zero when no samples were evaluated
    # (e.g. num_samples == 0 or an empty test split).
    if total == 0:
        print("Pass@1: 0/0 (no samples evaluated)")
        return 0.0
    print(f"Pass@1: {passed}/{total} = {passed/total*100:.2f}%")
    return passed / total
|
|
|
|
|
def evaluate_with_mcts(ae, flow, tokenizer, device, num_samples=20):
    """
    Run enhanced evaluation using Diffu-MCTS search.

    For each sample, the buggy source is decoded back to text and handed to
    the MCTS solver, which searches for a fix that passes the sandboxed tests.

    Args:
        ae: trained autoencoder used by the MCTS solver.
        flow: trained flow model used by the MCTS solver.
        tokenizer: tokenizer for decoding source token ids.
        device: device to move the input tensors to.
        num_samples: maximum number of test samples to evaluate.

    Returns:
        float: pass@1 rate in [0, 1]; 0.0 when no samples were evaluated.
    """
    print(f"\n>>> Starting Diffu-MCTS Evaluation (samples={num_samples})...")

    loader = prepare_data("humanevalpack", tokenizer, 512, 1, split="test")
    sandbox = SafeSandbox(timeout=2.0)

    mcts = DiffuMCTS(ae, flow, tokenizer, sandbox, device, config=None)
    mcts.num_branches = 8

    passed = 0
    total = 0

    for i, batch in enumerate(tqdm(loader, total=num_samples)):
        if i >= num_samples:
            break

        src_ids = batch['src_ids'].to(device)
        buggy_code = tokenizer.decode(src_ids[0], skip_special_tokens=True)

        # Batch size is 1, so index 0 selects the sample's metadata.
        test_code = batch['test_code'][0]
        entry_point = batch['entry_point'][0]

        fixed_code, is_success = mcts.solve(buggy_code, test_code, entry_point)

        if is_success:
            passed += 1
        total += 1

        # Print one concrete example for quick manual inspection.
        if i == 0:
            print(f"\n[Case 0]")
            print(f"Buggy:\n{buggy_code[:100]}...")
            print(f"Fixed:\n{fixed_code[:100]}...")
            print(f"Result: {'✅ PASS' if is_success else '❌ FAIL'}")

    print(f"\n=== MCTS Results ===")
    # Guard against division by zero when no samples were evaluated
    # (e.g. num_samples == 0 or an empty test split).
    if total == 0:
        print("Pass@1: 0/0 (no samples evaluated)")
        return 0.0
    print(f"Pass@1 (with Search K={mcts.num_branches}): {passed}/{total} = {passed/total*100:.2f}%")
    return passed / total
|
|
|
|
|
def main():
    """Train the autoencoder and flow model on CodeXGLUE, then evaluate both
    plain inference and Diffu-MCTS search on HumanEvalPack."""
    model_cfg = ModelConfig()
    train_cfg = TrainConfig(batch_size=8, grad_accum_steps=4)

    tokenizer = AutoTokenizer.from_pretrained(model_cfg.encoder_name, trust_remote_code=True)
    # Some tokenizers ship without a pad token; fall back to EOS.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    train_loader = prepare_data("codexglue", tokenizer, model_cfg.max_seq_len, train_cfg.batch_size, split="train")

    autoencoder = SphericalAutoencoder(model_cfg).to(train_cfg.device).float()
    # Mirror the tokenizer's pad id into the encoder config when it is unset.
    if autoencoder.encoder.config.pad_token_id is None:
        autoencoder.encoder.config.pad_token_id = tokenizer.pad_token_id

    flow_model = PatchedFlowDiT(model_cfg).to(train_cfg.device).float()

    trainer = Trainer(autoencoder, flow_model, train_cfg, train_loader)

    # Stage 1: autoencoder reconstruction training (trainable params only).
    print("\n>>> Training AE on CodeXGLUE...")
    ae_optimizer = optim.AdamW(
        (p for p in autoencoder.parameters() if p.requires_grad), lr=train_cfg.lr_ae
    )
    for epoch in range(train_cfg.num_epochs_ae):
        epoch_loss = trainer.train_ae(ae_optimizer)
        print(f"AE Epoch {epoch}: Loss {epoch_loss:.4f}")

    # Stage 2: flow-matching training on the frozen latent space.
    print("\n>>> Training Flow Matching on CodeXGLUE...")
    flow_optimizer = optim.AdamW(flow_model.parameters(), lr=train_cfg.lr_flow)
    for epoch in range(train_cfg.num_epochs_flow):
        epoch_loss = trainer.train_flow(flow_optimizer)
        print(f"Flow Epoch {epoch}: Loss {epoch_loss:.4f}")

    # Stage 3: evaluation — plain sampling first, then MCTS-guided repair.
    evaluate_on_humaneval(autoencoder, flow_model, tokenizer, train_cfg.device)

    evaluate_with_mcts(autoencoder, flow_model, tokenizer, train_cfg.device, num_samples=50)
|
|
|
|
|
# Script entry point: run the full train-then-evaluate pipeline.
if __name__ == "__main__":
    main()