Spaces:
Sleeping
Sleeping
hopefully fixed unicoil
Browse files
app.py
CHANGED
|
@@ -86,131 +86,56 @@ def get_unicoil_binary_representation(text):
|
|
| 86 |
return "UNICOIL model is not loaded. Please check the console for loading errors."
|
| 87 |
|
| 88 |
inputs = tokenizer_unicoil(text, return_tensors="pt", padding=True, truncation=True)
|
|
|
|
|
|
|
| 89 |
inputs = {k: v.to(model_unicoil.device) for k, v in inputs.items()}
|
| 90 |
|
| 91 |
with torch.no_grad():
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
# However, looking at `castorini/unicoil-msmarco-passage`
|
| 125 |
-
# its `config.json` might give hints or the model itself is structured.
|
| 126 |
-
# Often, it uses `BertForMaskedLM` and then applies `log(1+relu)` to the logits.
|
| 127 |
-
# Let's assume it behaves similar to SPLADE for simplicity of extraction for now,
|
| 128 |
-
# or we might need to load it as `AutoModelForMaskedLM` if its internal structure
|
| 129 |
-
# is indeed like that, and then apply a binarization.
|
| 130 |
-
|
| 131 |
-
# Re-evaluating: UNICOIL typically *learns* explicit token weights.
|
| 132 |
-
# The common approach for UNICOIL with Hugging Face is indeed to load it
|
| 133 |
-
# as `AutoModelForMaskedLM` and use its `logits` output, similar to SPLADE,
|
| 134 |
-
# but with a different aggregation strategy.
|
| 135 |
-
# Let's verify the model type for 'castorini/unicoil-msmarco-passage'.
|
| 136 |
-
# Its config.json and architecture implies it's a BertForMaskedLM variant.
|
| 137 |
-
|
| 138 |
-
output = model_unicoil(**inputs) # This should be a BaseModelOutputWithPooling or similar
|
| 139 |
-
|
| 140 |
-
if not hasattr(output, 'logits'):
|
| 141 |
-
# If `model_unicoil` is an `AutoModel` without a classification head,
|
| 142 |
-
# we need to add a way to get per-token scores.
|
| 143 |
-
# This is where a custom model head or a specific model class would be needed.
|
| 144 |
-
# For `castorini/unicoil-msmarco-passage`, it *is* an MLM variant.
|
| 145 |
-
# So, `output.logits` *should* be available.
|
| 146 |
-
return "UNICOIL model output structure not as expected. 'logits' not found."
|
| 147 |
-
|
| 148 |
-
# UNICOIL's output is also typically per-token scores from the MLM head.
|
| 149 |
-
# For UNICOIL, the weights are often taken directly from the logits after pooling.
|
| 150 |
-
# Unlike SPLADE's log(1+ReLU), UNICOIL's approach can be simpler,
|
| 151 |
-
# sometimes just taking the maximum of logits (or similar pooling).
|
| 152 |
-
# A common binarization for UNICOIL is based on the sign of the re-weighted scores.
|
| 153 |
-
|
| 154 |
-
# Let's mimic a common UNICOIL interpretation for obtaining sparse weights
|
| 155 |
-
# from the logits. The weights are usually sparse and positive.
|
| 156 |
-
# We can apply a threshold for binarization.
|
| 157 |
-
|
| 158 |
-
# This is a simplification; actual UNICOIL might have specific layers.
|
| 159 |
-
# For `castorini/unicoil-msmarco-passage`, it uses the `log(1+exp(logits))` formulation
|
| 160 |
-
# followed by max pooling, then often binarization based on a threshold.
|
| 161 |
-
|
| 162 |
-
# Applying a common interpretation of UNICOIL-like score generation for sparse weights:
|
| 163 |
-
# Instead of `log(1+ReLU(logits))`, it often uses `torch.log(1 + torch.exp(output.logits))`.
|
| 164 |
-
# This is essentially the softplus function, which makes values positive and sparse.
|
| 165 |
-
|
| 166 |
-
# Get the sparse weights using the UNICOIL-like transformation
|
| 167 |
-
sparse_weights = torch.max(torch.log(1 + torch.exp(output.logits)) * inputs['attention_mask'].unsqueeze(-1), dim=1)[0].squeeze()
|
| 168 |
-
|
| 169 |
-
# --- Binarization Step for UNICOIL ---
|
| 170 |
-
# For true "binary sparse", we threshold these sparse weights.
|
| 171 |
-
# A common approach is to simply take any non-zero value as 1, and zero as 0.
|
| 172 |
-
# Or, define a small threshold for binarization if values are very small but non-zero.
|
| 173 |
-
# For simplicity, let's treat anything above a very small epsilon as 1.
|
| 174 |
-
|
| 175 |
-
# Convert to binary: 1 if weight > epsilon, else 0
|
| 176 |
-
threshold = 1e-6 # Define a small threshold for binarization
|
| 177 |
-
binary_sparse_vector = (sparse_weights > threshold).int()
|
| 178 |
-
|
| 179 |
-
# Get indices of the '1's in the binary vector
|
| 180 |
-
binary_indices = torch.nonzero(binary_sparse_vector).squeeze().cpu().tolist()
|
| 181 |
-
|
| 182 |
-
if not isinstance(binary_indices, list):
|
| 183 |
-
binary_indices = [binary_indices] if binary_indices.numel() > 0 else []
|
| 184 |
-
|
| 185 |
-
# Map token IDs back to terms for the binary representation
|
| 186 |
-
binary_terms = {}
|
| 187 |
-
for token_id in binary_indices:
|
| 188 |
-
decoded_token = tokenizer_unicoil.decode([token_id])
|
| 189 |
-
if decoded_token not in ["[CLS]", "[SEP]", "[PAD]", "[UNK]"] and len(decoded_token.strip()) > 0:
|
| 190 |
-
binary_terms[decoded_token] = 1 # Value is always 1 for binary
|
| 191 |
-
|
| 192 |
-
sorted_binary_terms = sorted(binary_terms.items(), key=lambda item: item[0]) # Sort by term for consistent display
|
| 193 |
-
|
| 194 |
-
formatted_output = "UNICOIL Binary Sparse Representation (Activated Terms):\n"
|
| 195 |
-
if not sorted_binary_terms:
|
| 196 |
-
formatted_output += "No significant terms activated for this input.\n"
|
| 197 |
-
else:
|
| 198 |
-
# Display up to 50 activated terms for readability
|
| 199 |
-
for i, (term, _) in enumerate(sorted_binary_terms):
|
| 200 |
-
if i >= 50:
|
| 201 |
-
break
|
| 202 |
-
formatted_output += f"- **{term}**\n" # Only show term, as weight is always 1
|
| 203 |
-
if len(sorted_binary_terms) > 50:
|
| 204 |
formatted_output += f"...and {len(sorted_binary_terms) - 50} more terms.\n"
|
| 205 |
-
|
| 206 |
-
|
| 207 |
-
|
| 208 |
-
|
| 209 |
-
|
|
|
|
| 210 |
|
| 211 |
return formatted_output
|
| 212 |
|
| 213 |
|
|
|
|
|
|
|
| 214 |
# --- Unified Prediction Function for Gradio ---
|
| 215 |
def predict_representation(model_choice, text):
|
| 216 |
if model_choice == "SPLADE":
|
|
|
|
| 86 |
return "UNICOIL model is not loaded. Please check the console for loading errors."
|
| 87 |
|
| 88 |
inputs = tokenizer_unicoil(text, return_tensors="pt", padding=True, truncation=True)
|
| 89 |
+
input_ids = inputs["input_ids"]
|
| 90 |
+
attention_mask = inputs["attention_mask"]
|
| 91 |
inputs = {k: v.to(model_unicoil.device) for k, v in inputs.items()}
|
| 92 |
|
| 93 |
with torch.no_grad():
|
| 94 |
+
output = model_unicoil(**inputs)
|
| 95 |
+
|
| 96 |
+
if not hasattr(output, "logits"):
|
| 97 |
+
return "UNICOIL model output structure not as expected. 'logits' not found."
|
| 98 |
+
|
| 99 |
+
logits = output.logits.squeeze(0) # [seq_len, vocab_size]
|
| 100 |
+
token_ids = input_ids.squeeze(0) # [seq_len]
|
| 101 |
+
mask = attention_mask.squeeze(0) # [seq_len]
|
| 102 |
+
|
| 103 |
+
transformed_scores = torch.log(1 + torch.exp(logits)) # softplus
|
| 104 |
+
token_scores = transformed_scores[range(len(token_ids)), token_ids] # only scores for input tokens
|
| 105 |
+
token_scores = token_scores * mask # mask out padding
|
| 106 |
+
|
| 107 |
+
# Binarize: threshold scores > 0.5 (tune as needed)
|
| 108 |
+
binary_mask = (token_scores > 0.5)
|
| 109 |
+
activated_token_ids = token_ids[binary_mask].cpu().tolist()
|
| 110 |
+
|
| 111 |
+
# Map token ids to strings
|
| 112 |
+
binary_terms = {}
|
| 113 |
+
for token_id in activated_token_ids:
|
| 114 |
+
decoded_token = tokenizer_unicoil.decode([token_id])
|
| 115 |
+
if decoded_token not in ["[CLS]", "[SEP]", "[PAD]", "[UNK]"] and len(decoded_token.strip()) > 0:
|
| 116 |
+
binary_terms[decoded_token] = 1
|
| 117 |
+
|
| 118 |
+
sorted_binary_terms = sorted(binary_terms.items(), key=lambda item: item[0])
|
| 119 |
+
|
| 120 |
+
formatted_output = "UNICOIL Binary Sparse Representation (Activated Terms):\n"
|
| 121 |
+
if not sorted_binary_terms:
|
| 122 |
+
formatted_output += "No significant terms activated for this input.\n"
|
| 123 |
+
else:
|
| 124 |
+
for i, (term, _) in enumerate(sorted_binary_terms):
|
| 125 |
+
if i >= 50:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 126 |
formatted_output += f"...and {len(sorted_binary_terms) - 50} more terms.\n"
|
| 127 |
+
break
|
| 128 |
+
formatted_output += f"- **{term}**\n"
|
| 129 |
+
|
| 130 |
+
formatted_output += "\n--- Raw Binary Sparse Vector Info ---\n"
|
| 131 |
+
formatted_output += f"Total activated terms: {len(sorted_binary_terms)}\n"
|
| 132 |
+
formatted_output += f"Sparsity: {1 - (len(sorted_binary_terms) / tokenizer_unicoil.vocab_size):.2%}\n"
|
| 133 |
|
| 134 |
return formatted_output
|
| 135 |
|
| 136 |
|
| 137 |
+
|
| 138 |
+
|
| 139 |
# --- Unified Prediction Function for Gradio ---
|
| 140 |
def predict_representation(model_choice, text):
|
| 141 |
if model_choice == "SPLADE":
|