|
|
import unittest |
|
|
import torch |
|
|
import sys |
|
|
import os |
|
|
|
|
|
|
|
|
# Make the repository root (the parent of this file's directory) importable so
# the project-local ``src`` package resolves. Insert it at the FRONT of
# sys.path so an unrelated installed package that happens to be named ``src``
# cannot shadow the project's, and skip the insert when the path is already
# present so repeated imports of this module do not keep growing sys.path.
_REPO_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
if _REPO_ROOT not in sys.path:
    sys.path.insert(0, _REPO_ROOT)
|
|
|
|
|
from src.config import ModelConfig |
|
|
from src.models.autoencoder import LatentAutoencoder |
|
|
from src.models.dit import FlowDiT |
|
|
|
|
|
class TestModels(unittest.TestCase):
    """Shape smoke tests for ``LatentAutoencoder`` and ``FlowDiT``.

    Every tensor size used below is derived from the class constants that are
    also passed into ``ModelConfig``, so changing the test configuration in one
    place keeps the inputs and the expected output shapes in sync.
    """

    # Number of sequences in every synthetic batch.
    BATCH_SIZE = 2
    # Sequence length; passed to ModelConfig as ``max_seq_len``.
    SEQ_LEN = 32
    # Latent width; passed to ModelConfig as ``latent_dim``.
    LATENT_DIM = 128
    # Exclusive upper bound for random token ids, kept small so the ids are
    # valid for any reasonably sized vocabulary.
    TOKEN_ID_HIGH = 100
    # Vocabulary size of the ``roberta-base`` tokenizer; the autoencoder's
    # logits are asserted to have this as their last dimension.
    ROBERTA_BASE_VOCAB_SIZE = 50265

    def setUp(self):
        """Build the small model configuration shared by every test."""
        self.cfg = ModelConfig(
            encoder_name="roberta-base",
            latent_dim=self.LATENT_DIM,
            max_seq_len=self.SEQ_LEN,
            decoder_layers=2,
            dit_layers=2,
        )

    def _latent_shape(self):
        """Return the (batch, seq_len, latent_dim) shape used for latents."""
        return (self.BATCH_SIZE, self.SEQ_LEN, self.LATENT_DIM)

    def _random_latents(self):
        """Return a standard-normal tensor with the latent shape."""
        return torch.randn(*self._latent_shape())

    def test_ae_shape(self):
        """The autoencoder returns per-token logits and latents of the expected shape."""
        print("\nTesting Autoencoder Shape...")
        model = LatentAutoencoder(self.cfg)
        input_ids = torch.randint(
            0, self.TOKEN_ID_HIGH, (self.BATCH_SIZE, self.SEQ_LEN)
        )
        # All-ones mask: every position is treated as a real token.
        mask = torch.ones((self.BATCH_SIZE, self.SEQ_LEN))

        # Only shapes are checked, so no autograd graph is needed.
        with torch.no_grad():
            logits, z = model(input_ids, mask)

        self.assertEqual(z.shape, self._latent_shape())
        self.assertEqual(
            logits.shape,
            (self.BATCH_SIZE, self.SEQ_LEN, self.ROBERTA_BASE_VOCAB_SIZE),
        )
        print("AE Shape Check Passed.")

    def test_dit_shape(self):
        """A conditioned DiT forward pass preserves the latent shape."""
        print("\nTesting DiT Shape...")
        model = FlowDiT(self.cfg)
        x = self._random_latents()
        # One time value per batch element, drawn from [0, 1).
        t = torch.rand(self.BATCH_SIZE)
        cond = self._random_latents()

        with torch.no_grad():
            out = model(x, t, condition=cond)

        self.assertEqual(out.shape, self._latent_shape())
        print("DiT Shape Check Passed.")

    def test_cfg_forward(self):
        """``forward_with_cfg`` preserves the latent shape."""
        print("\nTesting CFG Forward...")
        model = FlowDiT(self.cfg)
        x = self._random_latents()
        t = torch.rand(self.BATCH_SIZE)
        cond = self._random_latents()

        with torch.no_grad():
            out = model.forward_with_cfg(x, t, cond, cfg_scale=3.0)

        self.assertEqual(out.shape, self._latent_shape())
        print("CFG Check Passed.")
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: discover and run the TestCase classes defined in
    # this module when it is executed directly (``module`` names the default
    # explicitly).
    unittest.main(module="__main__")