# Shape sanity tests for the autoencoder and DiT models.
import unittest
import torch
import sys
import os
# Ensure the project root is on sys.path so `src` is importable.
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
from src.config import ModelConfig
from src.models.autoencoder import LatentAutoencoder
from src.models.dit import FlowDiT
class TestModels(unittest.TestCase):
    """Shape sanity checks for the latent autoencoder and the flow-matching DiT.

    Uses a deliberately tiny configuration (2 layers each) so the tests
    build and run quickly; only output shapes are verified, not values.
    """

    def setUp(self):
        # Small config shared by every test; dit_hidden is derived from
        # latent_dim by a ModelConfig property, so it needs no override here.
        self.cfg = ModelConfig(
            encoder_name="roberta-base",
            latent_dim=128,
            max_seq_len=32,
            decoder_layers=2,
            dit_layers=2,
        )

    def test_ae_shape(self):
        """Autoencoder yields (B, S, latent_dim) latents and (B, S, vocab) logits."""
        print("\nTesting Autoencoder Shape...")
        autoencoder = LatentAutoencoder(self.cfg)
        batch, seq = 2, 32
        token_ids = torch.randint(0, 100, (batch, seq))
        attn_mask = torch.ones((batch, seq))
        logits, latents = autoencoder(token_ids, attn_mask)
        self.assertEqual(latents.shape, (batch, seq, 128))
        # 50265 is the RoBERTa vocabulary size.
        self.assertEqual(logits.shape, (batch, seq, 50265))
        print("AE Shape Check Passed.")

    def test_dit_shape(self):
        """DiT forward pass preserves the (B, S, latent_dim) shape of its input."""
        print("\nTesting DiT Shape...")
        dit = FlowDiT(self.cfg)
        batch, seq, dim = 2, 32, 128
        noisy_latents = torch.randn(batch, seq, dim)
        timesteps = torch.rand(batch)
        cond_latents = torch.randn(batch, seq, dim)
        velocity = dit(noisy_latents, timesteps, condition=cond_latents)
        self.assertEqual(velocity.shape, (batch, seq, dim))
        print("DiT Shape Check Passed.")

    def test_cfg_forward(self):
        """Classifier-free-guidance forward also preserves the input shape."""
        print("\nTesting CFG Forward...")
        dit = FlowDiT(self.cfg)
        batch, seq, dim = 2, 32, 128
        noisy_latents = torch.randn(batch, seq, dim)
        timesteps = torch.rand(batch)
        cond_latents = torch.randn(batch, seq, dim)
        guided = dit.forward_with_cfg(noisy_latents, timesteps, cond_latents, cfg_scale=3.0)
        self.assertEqual(guided.shape, (batch, seq, dim))
        print("CFG Check Passed.")
if __name__ == "__main__":
unittest.main() |