Update app.py
app.py CHANGED
@@ -43,33 +43,28 @@ except Exception as e:
     print(f"Please ensure '{splade_doc_model_name}' is accessible (check Hugging Face Hub for potential agreements).")
 
 
-# --- Helper function for lexical mask ---
+# --- Helper function for lexical mask (still needed for splade-v3-lexical) ---
 def create_lexical_bow_mask(input_ids, vocab_size, tokenizer):
     """
     Creates a binary bag-of-words mask from input_ids,
     zeroing out special tokens and padding.
     """
-    # Initialize a zero vector for the entire vocabulary
     bow_mask = torch.zeros(vocab_size, device=input_ids.device)
-
-    # Get unique token IDs from the input, excluding special tokens
-    # input_ids is typically [batch_size, seq_len], we assume batch_size=1
     meaningful_token_ids = []
     for token_id in input_ids.squeeze().tolist():
         if token_id not in [
             tokenizer.pad_token_id,
             tokenizer.cls_token_id,
             tokenizer.sep_token_id,
             tokenizer.mask_token_id,
             tokenizer.unk_token_id
         ]:
             meaningful_token_ids.append(token_id)
 
-    # Set 1 for tokens present in the original input
     if meaningful_token_ids:
         bow_mask[list(set(meaningful_token_ids))] = 1
 
     return bow_mask.unsqueeze(0)
 
 
 # --- Core Representation Functions ---
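The helper above only marks vocabulary slots whose token IDs actually occur in the input. A minimal sketch of how it behaves, assuming a BERT-style tokenizer (`bert-base-uncased` is an illustrative choice here, not necessarily the checkpoint the Space loads):

```python
import torch
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
inputs = tokenizer("sparse retrieval with splade", return_tensors="pt")

mask = create_lexical_bow_mask(inputs["input_ids"], tokenizer.vocab_size, tokenizer)
print(mask.shape)       # torch.Size([1, 30522]) -- one slot per vocabulary entry
print(int(mask.sum()))  # number of distinct, non-special input tokens

# Only token IDs present in the input are set to 1:
for tid in torch.nonzero(mask.squeeze()).flatten().tolist():
    print(tid, tokenizer.convert_ids_to_tokens(tid))
```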
@@ -140,15 +135,10 @@ def get_splade_lexical_representation(text):
         return "Model output structure not as expected for SPLADE v3 Lexical. 'logits' not found."
 
     # --- Apply Lexical Mask (always applied for this function now) ---
-    # Get the vocabulary size from the tokenizer
     vocab_size = tokenizer_splade_lexical.vocab_size
-
-    # Create the Bag-of-Words mask
     bow_mask = create_lexical_bow_mask(
         inputs['input_ids'], vocab_size, tokenizer_splade_lexical
     ).squeeze()
-
-    # Multiply the SPLADE vector by the BoW mask to zero out expanded terms
     splade_vector = splade_vector * bow_mask
     # --- End Lexical Mask Logic ---
 
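The element-wise multiply at the end of this hunk is what strips SPLADE's term expansion: a weight survives only where the bag-of-words mask is 1. A toy illustration, assuming a six-entry vocabulary:

```python
import torch

splade_vector = torch.tensor([0.0, 1.2, 0.0, 0.7, 0.3, 0.0])  # model weights over the vocab
bow_mask      = torch.tensor([0.0, 1.0, 0.0, 1.0, 0.0, 0.0])  # 1 only for tokens in the input

print(splade_vector * bow_mask)
# tensor([0.0000, 1.2000, 0.0000, 0.7000, 0.0000, 0.0000])
# index 4 carried weight 0.3 purely from expansion, so the mask drops it
```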
@@ -181,8 +171,8 @@ def get_splade_lexical_representation(text):
     return formatted_output
 
 
-# NEW: Function for SPLADE-v3-Doc representation
-def get_splade_doc_representation(text, apply_lexical_mask: bool):
+# NEW: Function for SPLADE-v3-Doc representation (Binary Sparse)
+def get_splade_doc_representation(text):
     if tokenizer_splade_doc is None or model_splade_doc is None:
         return "SPLADE v3 Doc model is not loaded. Please check the console for loading errors."
 
@@ -192,28 +182,34 @@ def get_splade_doc_representation(text, apply_lexical_mask: bool):
     with torch.no_grad():
         output = model_splade_doc(**inputs)
 
-    if hasattr(output, …
-    …
-    values …
+    if not hasattr(output, "logits"):
+        return "SPLADE v3 Doc model output structure not as expected. 'logits' not found."
+
+    # For SPLADE-v3-Doc, the output is often a binary sparse vector.
+    # We will assume a simple binarization based on a threshold or selecting active tokens.
+    # A common way to get "binary" is to use softplus and then binarize, or directly binarize max logits.
+    # Given the "no weighting, no expansion" request, we'll aim for a strict presence check.
+
+    # Option 1: Binarize based on softplus output and threshold (similar to UNICOIL)
+    # This might still activate some "expanded" terms if the model predicts them strongly.
+    # transformed_scores = torch.log(1 + torch.exp(output.logits)) # Softplus
+    # splade_vector_raw = torch.max(transformed_scores * inputs['attention_mask'].unsqueeze(-1), dim=1).values
+    # binary_splade_vector = (splade_vector_raw > 0.5).float() # Binarize
+
+    # Option 2: Rely on the original BoW for terms, with 1 for presence
+    # This aligns best with "no weighting, no expansion"
+    vocab_size = tokenizer_splade_doc.vocab_size
+    binary_splade_vector = create_lexical_bow_mask(
+        inputs['input_ids'], vocab_size, tokenizer_splade_doc
+    ).squeeze()
+
+    # We set values to 1 as it's a binary representation, not weighted
+    indices = torch.nonzero(binary_splade_vector).squeeze().cpu().tolist()
+    if not isinstance(indices, list): # Handle case where only one non-zero index
+        indices = [indices] if indices else [] # Ensure it's a list even if empty or single
+
+    # Values are all 1 for binary representation
+    values = [1.0] * len(indices)
     token_weights = dict(zip(indices, values))
 
     meaningful_tokens = {}
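The commented-out "Option 1" in this hunk is easy to try in isolation. A sketch, with the caveats that the checkpoint name (`naver/splade-v3-doc`) and the 0.5 threshold are assumptions rather than values read from the Space's config, and that `F.softplus` stands in for the overflow-prone `log(1 + exp(x))` spelled out in the comment:

```python
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForMaskedLM

name = "naver/splade-v3-doc"  # assumed checkpoint; the Space reads splade_doc_model_name
tokenizer = AutoTokenizer.from_pretrained(name)
model = AutoModelForMaskedLM.from_pretrained(name)

inputs = tokenizer("sparse retrieval with splade", return_tensors="pt")
with torch.no_grad():
    logits = model(**inputs).logits  # shape [1, seq_len, vocab_size]

scores = F.softplus(logits)  # same quantity as log(1 + exp(logits)), numerically safe
pooled = torch.max(scores * inputs["attention_mask"].unsqueeze(-1), dim=1).values
binary = (pooled.squeeze() > 0.5).float()  # arbitrary threshold
print(int(binary.sum()), "activated terms")  # unlike Option 2, may include expanded terms
```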
@@ -222,17 +218,22 @@ def get_splade_doc_representation(text, apply_lexical_mask: bool):
         if decoded_token not in ["[CLS]", "[SEP]", "[PAD]", "[UNK]"] and len(decoded_token.strip()) > 0:
             meaningful_tokens[decoded_token] = weight
 
-    sorted_representation = sorted(meaningful_tokens.items(), key=lambda item: item[1], reverse=True)
+    sorted_representation = sorted(meaningful_tokens.items(), key=lambda item: item[0]) # Sort alphabetically for binary
 
-    formatted_output = "SPLADE v3 Doc Representation (…
+    formatted_output = "SPLADE v3 Doc Representation (Binary Sparse - Lexical Only):\n"
     if not sorted_representation:
         formatted_output += "No significant terms found for this input.\n"
     else:
-        …
+        # Display as terms with no weights as they are binary (value 1)
+        for i, (term, _) in enumerate(sorted_representation):
+            # Limit display for very long lists for readability
+            if i >= 50:
+                formatted_output += f"...and {len(sorted_representation) - 50} more terms.\n"
+                break
+            formatted_output += f"- **{term}**\n"
+
+    formatted_output += "\n--- Raw Binary Sparse Vector Info ---\n"
+    formatted_output += f"Total activated terms: {len(indices)}\n"
     formatted_output += f"Sparsity: {1 - (len(indices) / tokenizer_splade_doc.vocab_size):.2%}\n"
 
     return formatted_output
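The sparsity figure is simply the fraction of the vocabulary left inactive. A worked example, assuming a 30,522-entry BERT-style vocabulary and 12 activated terms (both illustrative numbers):

```python
vocab_size = 30522  # assumed BERT-style vocabulary
active = 12         # illustrative count of activated terms
print(f"Sparsity: {1 - (active / vocab_size):.2%}")  # -> Sparsity: 99.96%
```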
@@ -243,12 +244,11 @@ def predict_representation(model_choice, text):
     if model_choice == "SPLADE (cocondenser)":
         return get_splade_representation(text)
     elif model_choice == "SPLADE-v3-Lexical":
         # Always applies lexical mask for this option
         return get_splade_lexical_representation(text)
-    elif model_choice == "SPLADE-v3-Doc…
-        …
-        return get_splade_doc_representation(text, apply_lexical_mask=True)
+    elif model_choice == "SPLADE-v3-Doc": # Simplified to a single option
+        # This function now intrinsically handles binary, lexical-only output
+        return get_splade_doc_representation(text)
     else:
         return "Please select a model."
 
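Dispatch happens purely on the radio label, so the function is easy to exercise outside Gradio once the models have loaded; the first string below mirrors a choice wired up in the interface, the second is deliberately unknown:

```python
print(predict_representation("SPLADE-v3-Doc", "the cat sat on the mat"))
# -> "SPLADE v3 Doc Representation (Binary Sparse - Lexical Only): ..." with one
#    alphabetical bullet per distinct input term (all weights are implicitly 1)

print(predict_representation("BM25", "anything"))  # label not among the radio choices
# -> "Please select a model."
```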
@@ -259,9 +259,8 @@ demo = gr.Interface(
         gr.Radio(
             [
                 "SPLADE (cocondenser)",
                 "SPLADE-v3-Lexical",
-                "SPLADE-v3-Doc…
-                "SPLADE-v3-Doc (lexical-only)" # Option with lexical mask applied
+                "SPLADE-v3-Doc" # Only one option for Doc model
             ],
             label="Choose Representation Model",
             value="SPLADE (cocondenser)" # Default selection
@@ -274,7 +273,7 @@ demo = gr.Interface(
     ],
     outputs=gr.Markdown(),
     title="🌌 Sparse Representation Generator",
-    description="Enter any text to see its …
+    description="Enter any text to see its sparse vector representation.", # Simplified description
     allow_flagging="never"
 )
 