# scripts/eval_ae_consistency.py
"""Evaluate autoencoder reconstruction and latent cycle consistency:

    z0 = encoder(x)      # encode ground-truth tokens
    x1 = decoder(z0)     # reconstruct tokens (argmax over logits)
    z1 = encoder(x1)     # re-encode the reconstruction

Reports masked per-token CE, masked token accuracy, cosine/L2 between
z0 and z1, and EOS detection statistics.
"""
import argparse

import torch
import torch.nn.functional as F
from tqdm import tqdm
from transformers import AutoTokenizer

from src.config import ModelConfig, TrainConfig
from src.models.autoencoder import ReshapedAutoencoder
from src.utils.data_utils import prepare_data


def pick_stop_id(tokenizer):
    # Prefer the EOS id; fall back to SEP for tokenizers without one.
    return tokenizer.eos_token_id if tokenizer.eos_token_id is not None else tokenizer.sep_token_id


def masked_mean(x, mask, eps=1e-6):
    # x: [B,L] (or [B,L,D] already reduced), mask: [B,L]. Currently unused.
    denom = mask.sum().clamp(min=eps)
    return (x * mask).sum() / denom


@torch.no_grad()
def main():
    ap = argparse.ArgumentParser()
    ap.add_argument("--dataset", type=str, default="wiki")
    ap.add_argument("--split", type=str, default="test")
    ap.add_argument("--max_seq_len", type=int, default=128)
    ap.add_argument("--batch_size", type=int, default=16)
    ap.add_argument("--ckpt", type=str,
                    default="/mnt/hdfs/user/lixinyu.222/CodeFlow/residual_robust_checkpoints/ae_best.pt",
                    help="path to ae.state_dict()")
    ap.add_argument("--max_batches", type=int, default=0, help="0 means full eval")
    ap.add_argument("--print_n", type=int, default=8)
    args = ap.parse_args()

    # configs
    m_cfg = ModelConfig(
        encoder_name='../jina-embeddings-v2-base-code',
        latent_dim=512,
        max_seq_len=args.max_seq_len,
    )
    t_cfg = TrainConfig(batch_size=args.batch_size)
    device = t_cfg.device

    tokenizer = AutoTokenizer.from_pretrained(m_cfg.encoder_name, local_files_only=True, trust_remote_code=False)
    stop_id = pick_stop_id(tokenizer)

    loader = prepare_data(args.dataset, tokenizer, m_cfg.max_seq_len, t_cfg.batch_size, split=args.split)
    # test_loader = prepare_data("wiki", tokenizer, m_cfg.max_seq_len, t_cfg.batch_size, split="test")

    ae = ReshapedAutoencoder(m_cfg).to(device).float()
    if args.ckpt:
        sd = torch.load(args.ckpt, map_location="cpu")
        ae.load_state_dict(sd, strict=True)
    ae.eval()

    total_ce = 0.0
    total_acc = 0.0
    total_tokens = 0.0

    eos_found = 0
    eos_pos_err = 0.0
    eos_count = 0

    total_cos = 0.0
    total_l2 = 0.0
    total_lat_tokens = 0.0

    printed = 0

    for bi, batch in enumerate(tqdm(loader, desc="Eval AE")):
        if args.max_batches and bi >= args.max_batches:
            break

        ids = batch["tgt_ids"].to(device)
        mask = batch["tgt_mask"].to(device)

        # --- forward ---
        z0 = ae.encode(ids, mask)                    # [B,L,D]
        logits = ae.decode(z0, attention_mask=mask)  # [B,L,V]
        pred = logits.argmax(dim=-1)                 # [B,L]

        # --- masked CE ---
        labels = ids.masked_fill(mask == 0, -100)
        ce = F.cross_entropy(logits.view(-1, logits.size(-1)), labels.view(-1),
                             ignore_index=-100, reduction="sum")
        total_ce += ce.item()

        # --- token acc (masked) ---
        correct = ((pred == ids) & (mask.bool())).sum().item()
        tok = mask.sum().item()
        total_acc += correct
        total_tokens += tok

        # --- EOS stats ---
        # first occurrence of the stop token in the true vs. predicted sequence
        B, L = ids.shape
        for i in range(B):
            # only search within valid tokens
            valid_len = int(mask[i].sum().item())
            true_seq = ids[i, :valid_len]
            pred_seq = pred[i, :valid_len]

            true_pos = (true_seq == stop_id).nonzero(as_tuple=True)[0]
            pred_pos = (pred_seq == stop_id).nonzero(as_tuple=True)[0]

            if pred_pos.numel() > 0:
                eos_found += 1
            if true_pos.numel() > 0:
                eos_count += 1
                tpos = int(true_pos[0].item())
                ppos = int(pred_pos[0].item()) if pred_pos.numel() > 0 else valid_len - 1
                eos_pos_err += abs(ppos - tpos)

        # --- latent cycle: z0 -> tokens -> z1 ---
        z1 = ae.encode(pred, mask)
        cos = F.cosine_similarity(z0, z1, dim=-1)  # [B,L]
        l2 = (z0 - z1).pow(2).mean(dim=-1)         # [B,L]
        total_cos += (cos * mask).sum().item()
        total_l2 += (l2 * mask).sum().item()
        total_lat_tokens += mask.sum().item()

        # --- print a few examples ---
        if printed < args.print_n:
            s = tokenizer.decode(ids[0], skip_special_tokens=True)
            # NOTE: no stop-position truncation is applied here; the commented
            # block below would cut the reconstruction at the first stop token
            # (it relies on a _first_pos helper not defined in this script).
            # valid_len = int(mask[0].sum().item())
            # pred_seq = pred[0, :valid_len]
            # # find the stop token (eos/sep)
            # end = _first_pos(pred_seq, stop_id, default=valid_len - 1) + 1
            # g = tokenizer.decode(pred_seq[:end], skip_special_tokens=True)
            g = tokenizer.decode(pred[0], skip_special_tokens=True)
            print("\n--- Example ---")
            print("GT :", s)
            print("REC:", g)
            printed += 1

    avg_ce = total_ce / max(total_tokens, 1.0)
    avg_acc = total_acc / max(total_tokens, 1.0)
    avg_cos = total_cos / max(total_lat_tokens, 1.0)
    avg_l2 = total_l2 / max(total_lat_tokens, 1.0)
    # total_tokens / max_seq_len is a rough estimate of the number of sequences
    eos_found_rate = eos_found / max(total_tokens / args.max_seq_len, 1.0)
    eos_mae = eos_pos_err / max(eos_count, 1)

    print("\n===== AE Metrics =====")
    print(f"Masked CE per-token: {avg_ce:.4f}")
    print(f"Token Acc (masked): {avg_acc:.4f}")
    print(f"Latent cycle cosine(z0,z1): {avg_cos:.4f}")
    print(f"Latent cycle l2(z0,z1): {avg_l2:.6f}")
    print(f"EOS found rate (rough): {eos_found_rate:.4f}")
    print(f"EOS position MAE (only where GT has EOS): {eos_mae:.2f}")


if __name__ == "__main__":
    main()
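# ---------------------------------------------------------------------------
# Example invocation (a sketch: the flags are the ones defined above, but the
# checkpoint path is illustrative; the default --ckpt points at a cluster path
# you will likely need to override):
#
#   python scripts/eval_ae_consistency.py \
#       --dataset wiki --split test \
#       --batch_size 16 --max_batches 10 --print_n 4 \
#       --ckpt ./checkpoints/ae_best.pt
#
# --max_batches 10 gives a quick smoke test; leave it at 0 for a full pass
# over the split.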