Diff-Refine / tests /test_models.py
2ira's picture
Add files using upload-large-folder tool
77d636f verified
import unittest
import torch
import sys
import os
# 确保能导入 src
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
from src.config import ModelConfig
from src.models.autoencoder import LatentAutoencoder
from src.models.dit import FlowDiT
class TestModels(unittest.TestCase):
    """Smoke tests for LatentAutoencoder and FlowDiT.

    Each test builds a model from a small ModelConfig, runs one forward pass
    on random inputs, and checks the output shape (plus that the values are
    finite). All expected shapes are derived from the class constants below,
    which are also what the config is built from, so the assertions cannot
    drift out of sync with the configuration.
    """

    # Small dimensions keep the tests fast.
    BATCH_SIZE = 2
    SEQ_LEN = 32
    LATENT_DIM = 128
    # Expected vocabulary dimension of the logits for encoder "roberta-base".
    ROBERTA_VOCAB_SIZE = 50265

    def setUp(self):
        # Seed the global RNG so random inputs/weights (and any failure) are
        # reproducible across runs.
        torch.manual_seed(0)
        # Small config for testing; 2 decoder/DiT layers keeps it quick.
        self.cfg = ModelConfig(
            encoder_name="roberta-base",
            latent_dim=self.LATENT_DIM,
            max_seq_len=self.SEQ_LEN,
            decoder_layers=2,
            dit_layers=2,
        )

    def _random_latent(self):
        """Return a random tensor of shape (BATCH_SIZE, SEQ_LEN, LATENT_DIM)."""
        return torch.randn(self.BATCH_SIZE, self.SEQ_LEN, self.LATENT_DIM)

    def _assert_latent_output(self, out):
        """Assert *out* has the latent shape and contains only finite values."""
        self.assertEqual(out.shape, (self.BATCH_SIZE, self.SEQ_LEN, self.LATENT_DIM))
        self.assertTrue(torch.isfinite(out).all())

    def test_ae_shape(self):
        print("\nTesting Autoencoder Shape...")
        model = LatentAutoencoder(self.cfg)
        model.eval()
        input_ids = torch.randint(0, 100, (self.BATCH_SIZE, self.SEQ_LEN))
        mask = torch.ones((self.BATCH_SIZE, self.SEQ_LEN))
        # Shape checks need no gradients; avoid building an autograd graph.
        with torch.no_grad():
            logits, z = model(input_ids, mask)
        self._assert_latent_output(z)
        self.assertEqual(
            logits.shape,
            (self.BATCH_SIZE, self.SEQ_LEN, self.ROBERTA_VOCAB_SIZE),
        )
        self.assertTrue(torch.isfinite(logits).all())
        print("AE Shape Check Passed.")

    def test_dit_shape(self):
        print("\nTesting DiT Shape...")
        model = FlowDiT(self.cfg)
        model.eval()
        x = self._random_latent()
        t = torch.rand(self.BATCH_SIZE)  # one timestep per batch element
        # NOTE(review): assumes the condition lives in the same latent space
        # as x (both are LATENT_DIM wide here) — confirm against FlowDiT.
        cond = self._random_latent()
        with torch.no_grad():
            out = model(x, t, condition=cond)
        self._assert_latent_output(out)
        print("DiT Shape Check Passed.")

    def test_cfg_forward(self):
        print("\nTesting CFG Forward...")
        model = FlowDiT(self.cfg)
        model.eval()
        x = self._random_latent()
        t = torch.rand(self.BATCH_SIZE)
        cond = self._random_latent()
        with torch.no_grad():
            out = model.forward_with_cfg(x, t, cond, cfg_scale=3.0)
        self._assert_latent_output(out)
        print("CFG Check Passed.")
if __name__ == "__main__":
    # Executed directly as a script: discover and run every TestCase defined
    # in this module (module="__main__" is unittest.main's default target).
    unittest.main(module="__main__")