File size: 5,877 Bytes
77d636f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
import torch
import torch.optim as optim
from transformers import AutoTokenizer
from tqdm import tqdm

from src.config import ModelConfig, TrainConfig
from src.models.autoencoder import SphericalAutoencoder
from src.models.dit import PatchedFlowDiT
from src.trainer import Trainer
from src.utils.data_utils import prepare_data
from src.utils.sandbox import SafeSandbox
from src.search import DiffuMCTS

def inference(ae, flow, src_ids, src_mask, device, steps=10):
    """Repair a source sequence by Euler-integrating the learned flow field.

    Encodes the (buggy) source tokens into a latent, integrates the velocity
    field for `steps` fixed-size Euler steps conditioned on the original
    latent, re-normalizes onto the unit sphere, and greedily decodes token
    ids from the autoencoder logits.
    """
    ae.eval()
    flow.eval()
    with torch.no_grad():
        # Encode source (buggy) tokens -> z_0; keep a frozen copy as condition.
        latent = ae.encode(src_ids, src_mask)
        condition = latent.clone()

        step_size = 1.0 / steps
        for step_idx in range(steps):
            t_vec = torch.full(
                (latent.shape[0],), step_idx / steps, device=device
            )
            velocity = flow(latent, t_vec, condition=condition).float()
            latent = latent + velocity * step_size

        # Project back onto the hypersphere before decoding.
        latent = torch.nn.functional.normalize(latent, p=2, dim=-1)
        token_logits = ae.decode(latent)
        return torch.argmax(token_logits, dim=-1)

def evaluate_on_humaneval(ae, flow, tokenizer, device, num_samples=20):
    """Evaluate pass@1 on HumanEvalPack by actually executing generated code.

    Runs flow inference on each buggy sample, then executes the generated
    code against the task's test harness inside a sandbox.  Prints an
    aggregate pass@1 at the end.

    Args:
        ae: trained SphericalAutoencoder (encode/decode).
        flow: trained PatchedFlowDiT velocity model.
        tokenizer: tokenizer matching the autoencoder's vocabulary.
        device: torch device the models live on.
        num_samples: cap on how many test items to evaluate (saves time).
    """
    print("\n>>> Starting Evaluation on HumanEvalPack (Real Execution)...")
    loader = prepare_data("humanevalpack", tokenizer, 512, 1, split="test")
    sandbox = SafeSandbox()

    passed = 0
    total = 0

    # Only evaluate the first num_samples items to keep runtime reasonable.
    for i, batch in enumerate(tqdm(loader, total=num_samples)):
        if i >= num_samples:
            break

        src = batch['src_ids'].to(device)
        mask = batch['src_mask'].to(device)
        test_code = batch['test_code'][0]
        entry_point = batch['entry_point'][0]

        # 1. Flow inference: buggy tokens -> candidate fixed tokens.
        out_ids = inference(ae, flow, src, mask, device)
        gen_code = tokenizer.decode(out_ids[0], skip_special_tokens=True)

        # 2. Sandboxed execution against the task's tests.
        is_pass, msg = sandbox.run(gen_code, test_code, entry_point)

        if is_pass:
            passed += 1
        total += 1

        # Log the first case for a qualitative sanity check.
        if i == 0:
            print(f"\n[Case 0] Pass: {is_pass}")
            print(f"Error: {msg}")
            print(f"Generated:\n{gen_code[:200]}...")

    # Guard against division by zero when the loader yields no samples
    # (e.g. num_samples <= 0 or an empty split).
    pass_rate = (passed / total * 100) if total else 0.0
    print("\n=== Eval Result ===")
    print(f"Pass@1: {passed}/{total} = {pass_rate:.2f}%")

def evaluate_with_mcts(ae, flow, tokenizer, device, num_samples=20):
    """Evaluate repair pass@1 on HumanEvalPack using Diffu-MCTS search.

    For each task the raw buggy source is recovered by decoding the
    tokenized input (the data loader only yields tensors), then
    ``mcts.solve`` searches for a fix that passes the task's tests in the
    sandbox.  Prints an aggregate pass@1 at the end.

    Args:
        ae: trained SphericalAutoencoder.
        flow: trained PatchedFlowDiT velocity model.
        tokenizer: tokenizer used both for decoding inputs and by MCTS.
        device: torch device the models live on.
        num_samples: cap on how many test items to evaluate.
    """
    print(f"\n>>> Starting Diffu-MCTS Evaluation (samples={num_samples})...")

    # Data and execution sandbox; 2s timeout guards against infinite loops.
    loader = prepare_data("humanevalpack", tokenizer, 512, 1, split="test")
    sandbox = SafeSandbox(timeout=2.0)

    # Search driver with branch factor K=8.
    mcts = DiffuMCTS(ae, flow, tokenizer, sandbox, device, config=None)
    mcts.num_branches = 8

    passed = 0
    total = 0

    for i, batch in enumerate(tqdm(loader, total=num_samples)):
        if i >= num_samples:
            break

        # The loader drops the raw text, so reconstruct the buggy source by
        # decoding its token ids (MCTS handles tokenization internally).
        src_ids = batch['src_ids'].to(device)
        buggy_code = tokenizer.decode(src_ids[0], skip_special_tokens=True)

        test_code = batch['test_code'][0]
        entry_point = batch['entry_point'][0]

        # Run the search; returns the best candidate and whether it passed.
        fixed_code, is_success = mcts.solve(buggy_code, test_code, entry_point)

        if is_success:
            passed += 1
        total += 1

        # Log the first sample for a qualitative sanity check.
        if i == 0:
            print(f"\n[Case 0]")
            print(f"Buggy:\n{buggy_code[:100]}...")
            print(f"Fixed:\n{fixed_code[:100]}...")
            print(f"Result: {'✅ PASS' if is_success else '❌ FAIL'}")

    # Guard against division by zero when no samples were evaluated.
    pass_rate = (passed / total * 100) if total else 0.0
    print("\n=== MCTS Results ===")
    print(f"Pass@1 (with Search K={mcts.num_branches}): {passed}/{total} = {pass_rate:.2f}%")

def main():
    """Train the autoencoder and flow model on CodeXGLUE, then evaluate.

    Pipeline: load tokenizer and data -> stage 1 AE training -> stage 2
    flow-matching training -> plain inference eval -> Diffu-MCTS eval.
    """
    model_cfg = ModelConfig()
    train_cfg = TrainConfig(batch_size=8, grad_accum_steps=4)

    tokenizer = AutoTokenizer.from_pretrained(
        model_cfg.encoder_name, trust_remote_code=True
    )
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # 1. CodeXGLUE training data.
    train_loader = prepare_data(
        "codexglue",
        tokenizer,
        model_cfg.max_seq_len,
        train_cfg.batch_size,
        split="train",
    )

    autoencoder = SphericalAutoencoder(model_cfg).to(train_cfg.device).float()
    # Ensure the encoder knows the pad token (some checkpoints ship without one).
    if autoencoder.encoder.config.pad_token_id is None:
        autoencoder.encoder.config.pad_token_id = tokenizer.pad_token_id

    flow_model = PatchedFlowDiT(model_cfg).to(train_cfg.device).float()

    trainer = Trainer(autoencoder, flow_model, train_cfg, train_loader)

    # --- Stage 1: autoencoder ---
    print("\n>>> Training AE on CodeXGLUE...")
    trainable_ae_params = filter(lambda p: p.requires_grad, autoencoder.parameters())
    ae_optimizer = optim.AdamW(trainable_ae_params, lr=train_cfg.lr_ae)
    for epoch in range(train_cfg.num_epochs_ae):
        epoch_loss = trainer.train_ae(ae_optimizer)
        print(f"AE Epoch {epoch}: Loss {epoch_loss:.4f}")

    # --- Stage 2: flow matching ---
    print("\n>>> Training Flow Matching on CodeXGLUE...")
    flow_optimizer = optim.AdamW(flow_model.parameters(), lr=train_cfg.lr_flow)
    for epoch in range(train_cfg.num_epochs_flow):
        epoch_loss = trainer.train_flow(flow_optimizer)
        print(f"Flow Epoch {epoch}: Loss {epoch_loss:.4f}")

    # --- Evaluation: plain inference first, then MCTS-augmented search ---
    evaluate_on_humaneval(autoencoder, flow_model, tokenizer, train_cfg.device)
    evaluate_with_mcts(
        autoencoder, flow_model, tokenizer, train_cfg.device, num_samples=50
    )

if __name__ == "__main__":
    main()