"""Shape smoke tests for the latent autoencoder and the flow DiT."""

import os
import sys
import unittest

import torch

# Make the repository root importable so that `src` resolves when this file is
# run directly. Insert at the front so the local `src` package takes precedence
# over any unrelated `src` package that may already be on the path.
_REPO_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
if _REPO_ROOT not in sys.path:
    sys.path.insert(0, _REPO_ROOT)

from src.config import ModelConfig  # noqa: E402
from src.models.autoencoder import LatentAutoencoder  # noqa: E402
from src.models.dit import FlowDiT  # noqa: E402

# Vocabulary size of the `roberta-base` tokenizer. The autoencoder's logits are
# expected to span this vocabulary; update this if `encoder_name` changes.
ROBERTA_BASE_VOCAB_SIZE = 50265


class TestModels(unittest.TestCase):
    """Check that model forward passes produce tensors of the expected shape."""

    def setUp(self):
        """Build a small config; its dimensions also drive inputs/expectations."""
        # Single source of truth for the test dimensions: the same values are
        # passed to the config and used to build inputs and expected shapes,
        # so they cannot drift apart.
        self.batch_size = 2
        self.seq_len = 32
        self.latent_dim = 128
        self.cfg = ModelConfig(
            encoder_name="roberta-base",
            latent_dim=self.latent_dim,
            max_seq_len=self.seq_len,
            decoder_layers=2,  # shallow decoder keeps the test fast
            dit_layers=2,
        )

    def _latent_shape(self):
        """Return the (batch, seq_len, latent_dim) shape used for latents."""
        return (self.batch_size, self.seq_len, self.latent_dim)

    def test_ae_shape(self):
        """Autoencoder returns per-token vocab logits and per-token latents."""
        print("\nTesting Autoencoder Shape...")
        model = LatentAutoencoder(self.cfg)
        input_ids = torch.randint(0, 100, (self.batch_size, self.seq_len))
        mask = torch.ones((self.batch_size, self.seq_len))

        # Shape checks only: no gradients needed.
        with torch.no_grad():
            logits, z = model(input_ids, mask)

        self.assertEqual(z.shape, self._latent_shape())
        self.assertEqual(
            logits.shape,
            (self.batch_size, self.seq_len, ROBERTA_BASE_VOCAB_SIZE),
        )
        print("AE Shape Check Passed.")

    def test_dit_shape(self):
        """DiT output has the same shape as its latent input."""
        print("\nTesting DiT Shape...")
        model = FlowDiT(self.cfg)
        x = torch.randn(*self._latent_shape())
        t = torch.rand(self.batch_size)  # one timestep per batch element
        cond = torch.randn(*self._latent_shape())

        with torch.no_grad():
            out = model(x, t, condition=cond)

        self.assertEqual(out.shape, self._latent_shape())
        print("DiT Shape Check Passed.")

    def test_cfg_forward(self):
        """Classifier-free-guidance forward keeps the latent input shape."""
        print("\nTesting CFG Forward...")
        model = FlowDiT(self.cfg)
        x = torch.randn(*self._latent_shape())
        t = torch.rand(self.batch_size)
        cond = torch.randn(*self._latent_shape())

        with torch.no_grad():
            out = model.forward_with_cfg(x, t, cond, cfg_scale=3.0)

        self.assertEqual(out.shape, self._latent_shape())
        print("CFG Check Passed.")


if __name__ == "__main__":
    unittest.main()