File size: 1,954 Bytes
77d636f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
import unittest
import torch
import sys
import os

# Make the project root importable so that `src` resolves
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

from src.config import ModelConfig
from src.models.autoencoder import LatentAutoencoder
from src.models.dit import FlowDiT

class TestModels(unittest.TestCase):
    """Shape smoke tests for the latent autoencoder and the flow DiT."""

    # Small dimensions keep these tests fast. Every shape assertion below is
    # derived from these constants, so the expectations cannot drift away from
    # the values passed into ModelConfig.
    BATCH_SIZE = 2
    SEQ_LEN = 32
    LATENT_DIM = 128
    # Vocabulary size of the roberta-base tokenizer; the AE logits are
    # expected to cover the full encoder vocabulary.
    ROBERTA_VOCAB_SIZE = 50265

    def setUp(self):
        """Build a small model config and seed the RNG for reproducible inputs."""
        torch.manual_seed(0)
        self.cfg = ModelConfig(
            encoder_name="roberta-base",
            latent_dim=self.LATENT_DIM,
            max_seq_len=self.SEQ_LEN,
            decoder_layers=2,  # shallow stacks keep the test fast
            dit_layers=2,
        )
        # Expected (batch, seq_len, latent_dim) shape of latents / DiT outputs.
        self.latent_shape = (self.BATCH_SIZE, self.SEQ_LEN, self.LATENT_DIM)

    def test_ae_shape(self):
        """Autoencoder returns per-token vocab logits and per-token latents."""
        model = LatentAutoencoder(self.cfg)
        input_ids = torch.randint(0, 100, (self.BATCH_SIZE, self.SEQ_LEN))
        # All-ones mask: no padding positions (assumes 1 = keep; confirm
        # against LatentAutoencoder's mask convention).
        mask = torch.ones((self.BATCH_SIZE, self.SEQ_LEN))
        # Shape checks need no autograd graph.
        with torch.no_grad():
            logits, z = model(input_ids, mask)

        self.assertEqual(z.shape, self.latent_shape)
        self.assertEqual(
            logits.shape,
            (self.BATCH_SIZE, self.SEQ_LEN, self.ROBERTA_VOCAB_SIZE),
        )

    def _dit_inputs(self):
        """Return random (x, t, cond) inputs for the DiT tests.

        ``x`` and ``cond`` share the latent shape; ``t`` holds one timestep
        per batch element drawn from [0, 1).
        """
        x = torch.randn(*self.latent_shape)
        t = torch.rand(self.BATCH_SIZE)
        # NOTE(review): condition is assumed to share x's sequence length and
        # latent width -- confirm against FlowDiT's conditioning path.
        cond = torch.randn(*self.latent_shape)
        return x, t, cond

    def test_dit_shape(self):
        """DiT forward pass preserves the latent shape."""
        model = FlowDiT(self.cfg)
        x, t, cond = self._dit_inputs()
        with torch.no_grad():
            out = model(x, t, condition=cond)
        self.assertEqual(out.shape, self.latent_shape)

    def test_cfg_forward(self):
        """Classifier-free-guidance forward preserves the latent shape."""
        model = FlowDiT(self.cfg)
        x, t, cond = self._dit_inputs()
        with torch.no_grad():
            out = model.forward_with_cfg(x, t, cond, cfg_scale=3.0)
        self.assertEqual(out.shape, self.latent_shape)

if __name__ == "__main__":
    # Script entry point: discover and run every TestCase defined in this module.
    unittest.main(module="__main__")