JacobLinCool committed on
Commit d69b6f3 · verified · 1 Parent(s): 99afb96

Upload folder using huggingface_hub
README.md ADDED
---
tags:
- model_hub_mixin
- pytorch_model_hub_mixin
- audio
- rhythm-game
- music
---

# GameChartEvaluator (GCE4)

A neural network model for evaluating the quality of rhythm game charts relative to their corresponding music. The model predicts a quality score in [0, 1] indicating how well a chart synchronizes with the music.

## Model Architecture

The model uses an early-fusion approach with dilated convolutions for temporal analysis:

1. **Early Fusion**: Concatenates the music and chart mel spectrograms along the channel dimension (80 + 80 = 160 channels)
2. **Dilated Residual Encoder**: 4 residual blocks with increasing dilation rates (1, 2, 4, 8) to capture multi-scale temporal context while preserving ~11 ms frame resolution
3. **Error-Sensitive Scoring Head**: Combines the average of the local scores with the mean of the worst 10% of scores via a learnable mixing parameter

```
Input: (B, 80, T) music_mels + (B, 80, T) chart_mels
  ↓ Concatenate
(B, 160, T)
  ↓ Conv1D Projection
(B, 128, T)
  ↓ Dilated ResBlocks × 4
(B, 128, T)
  ↓ Linear → Sigmoid (per-frame scores)
(B, T, 1)
  ↓ Error-Sensitive Pooling
(B,) final score
```
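The error-sensitive pooling step can be sketched standalone; this mirrors the pooling in the model's `forward` (the `raw_severity` argument here is a stand-in for the learned parameter):

```python
import torch

def error_sensitive_pool(local_scores: torch.Tensor, raw_severity: float = 0.0) -> torch.Tensor:
    """Mix the mean frame score with the mean of the worst 10% of frames.

    local_scores: (B, T, 1) per-frame scores in [0, 1]
    raw_severity: sigmoid(raw_severity) is the mixing weight (learned in the model)
    """
    avg_score = local_scores.mean(dim=1)                          # (B, 1)
    k = max(1, int(local_scores.size(1) * 0.1))                   # worst 10% of frames
    worst, _ = torch.topk(local_scores, k, dim=1, largest=False)  # lowest-k scores
    worst_score = worst.mean(dim=1)                               # (B, 1)
    alpha = torch.sigmoid(torch.tensor(raw_severity))
    return (alpha * worst_score + (1 - alpha) * avg_score).squeeze(1)  # (B,)

# A chart that is perfect except for one bad frame is penalized more than
# a plain average would suggest: with alpha = sigmoid(0) = 0.5 the result
# is 0.5 * 0.0 + 0.5 * 0.9 = 0.45, not 0.9.
scores = torch.ones(1, 10, 1)
scores[0, 3, 0] = 0.0
print(error_sensitive_pool(scores).item())  # 0.45
```

This is why a single badly misaligned passage can drag the final score well below the chart's average frame quality.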

## Usage

```python
import torch
from gce4 import GameChartEvaluator

model = GameChartEvaluator.from_pretrained("JacobLinCool/gce4")
model.eval()

# Input: 80-band mel spectrograms
music_mels = torch.randn(1, 80, 1000)  # (batch, freq, time)
chart_mels = torch.randn(1, 80, 1000)

# Get the overall quality score (0-1)
with torch.no_grad():
    score = model(music_mels, chart_mels)
print(f"Quality Score: {score.item():.3f}")

# Get the per-frame quality trace for explainability
with torch.no_grad():
    trace = model.predict_trace(music_mels, chart_mels)
# trace shape: (batch, time)
```

## Input Specifications

- **music_mels**: `(Batch, 80, Time)` - Mel spectrogram of the music
- **chart_mels**: `(Batch, 80, Time)` - Mel spectrogram of the synthesized chart audio (click sounds at note positions)

Both inputs should be normalized and share the same temporal dimension.
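The normalization scheme is not specified here; per-sample zero-mean/unit-variance standardization is one common choice, shown below only as a sketch (not necessarily what training used):

```python
import torch

def normalize_mels(mels: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
    """Standardize each sample of a (Batch, 80, Time) mel spectrogram batch.

    Assumption: zero-mean/unit-variance per sample; the normalization used
    during training is not documented in this README.
    """
    mean = mels.mean(dim=(1, 2), keepdim=True)
    std = mels.std(dim=(1, 2), keepdim=True)
    return (mels - mean) / (std + eps)

music_mels = normalize_mels(torch.randn(1, 80, 1000) * 4.0 + 2.0)
```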

## Output

- **forward()**: `(Batch,)` - Single quality score per sample in the range [0, 1]
- **predict_trace()**: `(Batch, Time)` - Per-frame quality scores for interpretability

## Model Configuration

| Parameter | Default | Description |
|-----------|---------|-------------|
| `input_dim` | 80 | Mel spectrogram frequency bins |
| `d_model` | 128 | Hidden dimension |
| `n_layers` | 4 | Number of residual blocks |

## Training

The model was trained to detect misaligned or poorly synchronized rhythm game charts by comparing music-chart pairs against versions with various synthetic corruptions (time shifts, random note placement, etc.).
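As an illustration, one of the listed corruptions (a time shift) might be synthesized as follows; the helper name and zero-padding behavior are assumptions, since the actual corruption pipeline is not included in this repository:

```python
import torch

def shift_chart(chart_mels: torch.Tensor, frames: int) -> torch.Tensor:
    """Shift a (Batch, 80, Time) chart spectrogram along time, zero-padding the gap.

    Positive `frames` delays the chart relative to the music; negative values
    advance it. Hypothetical helper, not the actual training code.
    """
    shifted = torch.zeros_like(chart_mels)
    if frames > 0:
        shifted[:, :, frames:] = chart_mels[:, :, :-frames]
    elif frames < 0:
        shifted[:, :, :frames] = chart_mels[:, :, -frames:]
    else:
        shifted = chart_mels.clone()
    return shifted

# A misaligned negative example: the chart lags the music by 50 frames.
bad_chart = shift_chart(torch.randn(1, 80, 1000), frames=50)
```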
config.json ADDED
{
  "d_model": 128,
  "input_dim": 80,
  "n_layers": 4
}
gce4.py ADDED
import torch
import torch.nn as nn
import torch.nn.functional as F
from huggingface_hub import PyTorchModelHubMixin


class ResBlock1D(nn.Module):
    """
    Residual Block for extracting rhythmic features from audio spectrograms.
    Maintains temporal resolution while increasing the receptive field.
    """

    def __init__(self, channels, kernel_size=3, dilation=1):
        super().__init__()
        padding = (kernel_size - 1) * dilation // 2
        self.conv1 = nn.Conv1d(
            channels, channels, kernel_size, padding=padding, dilation=dilation
        )
        self.bn1 = nn.BatchNorm1d(channels)
        self.conv2 = nn.Conv1d(
            channels, channels, kernel_size, padding=padding, dilation=dilation
        )
        self.bn2 = nn.BatchNorm1d(channels)

    def forward(self, x):
        res = x
        x = F.gelu(self.bn1(self.conv1(x)))
        x = self.bn2(self.conv2(x))
        return F.gelu(x + res)


class GameChartEvaluator(nn.Module, PyTorchModelHubMixin):
    def __init__(self, input_dim=80, d_model=128, n_layers=4):
        super().__init__()

        # --- Early Fusion ---
        # Input is (Batch, 80 * 2, Time)
        # We stack Music (80) + Chart (80) = 160 channels
        self.input_proj = nn.Conv1d(
            input_dim * 2, d_model, kernel_size=3, stride=1, padding=1
        )

        # --- STRICT TEMPORAL ENCODER ---
        # No pooling (stride=1) to preserve ~11 ms resolution
        # Dilations widen the context without losing resolution
        self.encoder = nn.Sequential(
            ResBlock1D(d_model, kernel_size=3, dilation=1),
            ResBlock1D(d_model, kernel_size=3, dilation=2),
            ResBlock1D(d_model, kernel_size=3, dilation=4),
            ResBlock1D(d_model, kernel_size=3, dilation=8),
            # Add more layers if you need wider context (e.g. 16, 32)
        )

        # --- SCORING HEAD ---
        # Simple projection to a scalar
        self.quality_proj = nn.Linear(d_model, 1)

        # Learnable mixing between average and worst-case scores
        self.raw_severity = nn.Parameter(torch.tensor(0.0))

    def forward(self, music_mels, chart_mels):
        """
        music_mels: (Batch, 80, Time)
        chart_mels: (Batch, 80, Time)
        """
        # 1. Early Fusion: concatenate along the channel dimension
        # Shape becomes (Batch, 160, Time)
        x = torch.cat([music_mels, chart_mels], dim=1)

        # 2. Extract features (strictly local + dilated context)
        x = F.gelu(self.input_proj(x))
        x = self.encoder(x)

        # 3. Predict a score per frame
        # (Batch, Dim, Time) -> (Batch, Time, Dim)
        x = x.permute(0, 2, 1)
        local_scores = torch.sigmoid(self.quality_proj(x))  # (Batch, Time, 1)

        # 4. Error-Sensitive Pooling
        avg_score = local_scores.mean(dim=1)

        k = max(1, int(local_scores.size(1) * 0.1))
        min_vals, _ = torch.topk(local_scores, k, dim=1, largest=False)
        worst_score = min_vals.mean(dim=1)

        alpha = torch.sigmoid(self.raw_severity)
        final_score = (alpha * worst_score) + ((1 - alpha) * avg_score)

        return final_score.squeeze(1)

    def predict_trace(self, music_mels, chart_mels):
        """
        Explainability method: returns the frame-by-frame quality curve.

        Returns:
            local_scores: (Batch, Time) - The quality score at every timestep.
        """
        with torch.no_grad():
            # 1. Early Fusion: concatenate along the channel dimension
            # Shape becomes (Batch, 160, Time)
            x = torch.cat([music_mels, chart_mels], dim=1)

            # 2. Extract features (strictly local + dilated context)
            x = F.gelu(self.input_proj(x))
            x = self.encoder(x)

            # 3. Predict a score per frame
            # (Batch, Dim, Time) -> (Batch, Time, Dim)
            x = x.permute(0, 2, 1)
            local_scores = torch.sigmoid(self.quality_proj(x))  # (Batch, Time, 1)
            return local_scores.squeeze(2)


if __name__ == "__main__":
    # Sanity check
    from torchinfo import summary

    model = GameChartEvaluator()
    print(
        f"Model initialized. Learnable Severity: {torch.sigmoid(model.raw_severity).item():.2f}"
    )

    # Dummy data (Batch=2, Freq=80, Time=1000)
    m = torch.randn(2, 80, 1000)
    c = torch.randn(2, 80, 1000)

    output = model(m, c)
    print(f"Output shape: {output.shape}")  # Should be torch.Size([2])
    print(f"Scores: {output}")

    # Trace check
    trace = model.predict_trace(m, c)
    # Should be torch.Size([2, 1000]): there is no pooling, so the trace
    # keeps the full input time resolution.
    print(f"Trace shape: {trace.shape}")

    summary(model, input_data=[m, c])
logs/events.out.tfevents.1765826528.msiit232.2878790.0 ADDED
version https://git-lfs.github.com/spec/v1
oid sha256:4ffaff1b4785cecde5e9ec55c6ccb9c5bd95ccfc05e4b441020e64cbfe2f38ae
size 81061
model.safetensors ADDED
version https://git-lfs.github.com/spec/v1
oid sha256:14cdf0af591718d42744931c641f91756ac720513d3125c860548f56cca9f59d
size 1845464