Update app.py
app.py CHANGED
@@ -7,17 +7,17 @@ tokenizer_splade = None
 model_splade = None
 tokenizer_splade_lexical = None
 model_splade_lexical = None
-tokenizer_splade_doc = None
-model_splade_doc = None
+tokenizer_splade_doc = None
+model_splade_doc = None

 # Load SPLADE v3 model (original)
 try:
     tokenizer_splade = AutoTokenizer.from_pretrained("naver/splade-cocondenser-selfdistil")
     model_splade = AutoModelForMaskedLM.from_pretrained("naver/splade-cocondenser-selfdistil")
     model_splade.eval()  # Set to evaluation mode for inference
-    print("SPLADE…
+    print("SPLADE-cocondenser-distil model loaded successfully!")
 except Exception as e:
-    print(f"Error loading SPLADE…
+    print(f"Error loading SPLADE-cocondenser-distil model: {e}")
     print("Please ensure you have accepted any user access agreements on the Hugging Face Hub page for 'naver/splade-cocondenser-selfdistil'.")

 # Load SPLADE v3 Lexical model
@@ -26,24 +26,24 @@ try:
     tokenizer_splade_lexical = AutoTokenizer.from_pretrained(splade_lexical_model_name)
     model_splade_lexical = AutoModelForMaskedLM.from_pretrained(splade_lexical_model_name)
     model_splade_lexical.eval()  # Set to evaluation mode for inference
-    print(f"SPLADE…
+    print(f"SPLADE-v3-Lexical model '{splade_lexical_model_name}' loaded successfully!")
 except Exception as e:
-    print(f"Error loading SPLADE…
+    print(f"Error loading SPLADE-v3-Lexical model: {e}")
     print(f"Please ensure '{splade_lexical_model_name}' is accessible (check Hugging Face Hub for potential agreements).")

-# Load SPLADE v3 Doc model
+# Load SPLADE v3 Doc model
 try:
     splade_doc_model_name = "naver/splade-v3-doc"
     tokenizer_splade_doc = AutoTokenizer.from_pretrained(splade_doc_model_name)
     model_splade_doc = AutoModelForMaskedLM.from_pretrained(splade_doc_model_name)
     model_splade_doc.eval()  # Set to evaluation mode for inference
-    print(f"SPLADE…
+    print(f"SPLADE-v3-Doc model '{splade_doc_model_name}' loaded successfully!")
 except Exception as e:
-    print(f"Error loading SPLADE…
+    print(f"Error loading SPLADE-v3-Doc model: {e}")
     print(f"Please ensure '{splade_doc_model_name}' is accessible (check Hugging Face Hub for potential agreements).")


-# --- Helper function for lexical mask…
+# --- Helper function for lexical mask ---
 def create_lexical_bow_mask(input_ids, vocab_size, tokenizer):
     """
     Creates a binary bag-of-words mask from input_ids,
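The body of `create_lexical_bow_mask` is cut off in this hunk (the diff stops inside its docstring). As a hedged sketch only, a helper consistent with how it is called later in the file, taking `input_ids`, `vocab_size`, and a `tokenizer` and returning a binary bag-of-words vector over the vocabulary, might look like this; the actual implementation in app.py may differ:

```python
import torch

def create_lexical_bow_mask(input_ids, vocab_size, tokenizer):
    # Sketch only: scatter 1.0 at every token id present in the input,
    # then zero out special tokens so they never appear in the mask.
    bow_mask = torch.zeros(input_ids.size(0), vocab_size, device=input_ids.device)
    bow_mask.scatter_(1, input_ids, 1.0)
    for special_id in tokenizer.all_special_ids:
        if special_id is not None and 0 <= special_id < vocab_size:
            bow_mask[:, special_id] = 0.0
    return bow_mask  # shape (batch_size, vocab_size), values in {0.0, 1.0}
```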
@@ -69,9 +69,9 @@ def create_lexical_bow_mask(input_ids, vocab_size, tokenizer):

 # --- Core Representation Functions ---

-def get_splade_representation(text):
+def get_splade_cocondenser_representation(text):
     if tokenizer_splade is None or model_splade is None:
-        return "SPLADE…
+        return "SPLADE-cocondenser-distil model is not loaded. Please check the console for loading errors."

     inputs = tokenizer_splade(text, return_tensors="pt", padding=True, truncation=True)
     inputs = {k: v.to(model_splade.device) for k, v in inputs.items()}
@@ -80,12 +80,13 @@ def get_splade_representation(text):
         output = model_splade(**inputs)

     if hasattr(output, 'logits'):
+        # Standard SPLADE calculation for learned weighting and expansion
         splade_vector = torch.max(
             torch.log(1 + torch.relu(output.logits)) * inputs['attention_mask'].unsqueeze(-1),
             dim=1
         )[0].squeeze()
     else:
-        return "Model output structure not as expected for SPLADE…
+        return "Model output structure not as expected for SPLADE-cocondenser-distil. 'logits' not found."

     indices = torch.nonzero(splade_vector).squeeze().cpu().tolist()
     if not isinstance(indices, list):
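The expression above is the standard SPLADE pooling: ReLU and log-saturation of the masked-LM logits, zeroed at padding positions, then max-pooled over the sequence so each vocabulary term keeps its strongest activation. A self-contained sketch of the same computation, using the model name from this file (the function name `splade_vector`, single-example batching, and omitted device placement are illustrative choices, not the app's code):

```python
import torch
from transformers import AutoModelForMaskedLM, AutoTokenizer

def splade_vector(text, model_name="naver/splade-cocondenser-selfdistil"):
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForMaskedLM.from_pretrained(model_name)
    model.eval()
    inputs = tokenizer(text, return_tensors="pt", truncation=True)
    with torch.no_grad():
        logits = model(**inputs).logits               # (1, seq_len, vocab_size)
    mask = inputs["attention_mask"].unsqueeze(-1)     # ignore padding positions
    weights = torch.log1p(torch.relu(logits)) * mask  # log(1 + relu(logits))
    return torch.max(weights, dim=1).values.squeeze(0)  # (vocab_size,)
```

Nonzero entries of the returned vector include expansion terms that never appear in the input text, which is what distinguishes this variant from the lexical and doc variants below.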
@@ -102,7 +103,7 @@ def get_splade_representation(text):

     sorted_representation = sorted(meaningful_tokens.items(), key=lambda item: item[1], reverse=True)

-    formatted_output = "SPLADE…
+    formatted_output = "SPLADE-cocondenser-distil Representation (Weighting and Expansion):\n"
     if not sorted_representation:
         formatted_output += "No significant terms found for this input.\n"
     else:
@@ -118,7 +119,7 @@ def get_splade_representation(text):

 def get_splade_lexical_representation(text):
     if tokenizer_splade_lexical is None or model_splade_lexical is None:
-        return "SPLADE…
+        return "SPLADE-v3-Lexical model is not loaded. Please check the console for loading errors."

     inputs = tokenizer_splade_lexical(text, return_tensors="pt", padding=True, truncation=True)
     inputs = {k: v.to(model_splade_lexical.device) for k, v in inputs.items()}
@@ -132,15 +133,14 @@ def get_splade_lexical_representation(text):
             dim=1
         )[0].squeeze()
     else:
-        return "Model output structure not as expected for SPLADE…
+        return "Model output structure not as expected for SPLADE-v3-Lexical. 'logits' not found."

-    # …
+    # Always apply lexical mask for this model's specific behavior
     vocab_size = tokenizer_splade_lexical.vocab_size
     bow_mask = create_lexical_bow_mask(
         inputs['input_ids'], vocab_size, tokenizer_splade_lexical
     ).squeeze()
     splade_vector = splade_vector * bow_mask
-    # --- End Lexical Mask Logic ---

     indices = torch.nonzero(splade_vector).squeeze().cpu().tolist()
     if not isinstance(indices, list):
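Multiplying the SPLADE vector by the bag-of-words mask zeroes every vocabulary dimension that does not correspond to an input token, so this variant keeps the learned term weights but drops all expansion terms. A toy illustration of the effect (hand-written tensors, not model output):

```python
import torch

splade_vector = torch.tensor([0.0, 1.2, 0.0, 0.7, 0.3])  # learned weights over a 5-term vocab
bow_mask      = torch.tensor([0.0, 1.0, 0.0, 0.0, 1.0])  # terms actually present in the input
print(splade_vector * bow_mask)  # tensor([0.0000, 1.2000, 0.0000, 0.0000, 0.3000])
```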
@@ -157,7 +157,7 @@ def get_splade_lexical_representation(text):

     sorted_representation = sorted(meaningful_tokens.items(), key=lambda item: item[1], reverse=True)

-    formatted_output = "SPLADE…
+    formatted_output = "SPLADE-v3-Lexical Representation (Weighting):\n"
     if not sorted_representation:
         formatted_output += "No significant terms found for this input.\n"
     else:
@@ -171,10 +171,10 @@ def get_splade_lexical_representation(text):
     return formatted_output


-# …
+# Function for SPLADE-v3-Doc representation (Binary Sparse - Lexical Only)
 def get_splade_doc_representation(text):
     if tokenizer_splade_doc is None or model_splade_doc is None:
-        return "SPLADE…
+        return "SPLADE-v3-Doc model is not loaded. Please check the console for loading errors."

     inputs = tokenizer_splade_doc(text, return_tensors="pt", padding=True, truncation=True)
     inputs = {k: v.to(model_splade_doc.device) for k, v in inputs.items()}
@@ -183,33 +183,22 @@ def get_splade_doc_representation(text):
         output = model_splade_doc(**inputs)

     if not hasattr(output, "logits"):
-        return "SPLADE…
+        return "SPLADE-v3-Doc model output structure not as expected. 'logits' not found."

-    # For SPLADE-v3-Doc, …
-    # We will …
-    # …
-    # …
-
-    # Option 1: Binarize based on softplus output and threshold (similar to UNICOIL)
-    # This might still activate some "expanded" terms if the model predicts them strongly.
-    # transformed_scores = torch.log(1 + torch.exp(output.logits)) # Softplus
-    # splade_vector_raw = torch.max(transformed_scores * inputs['attention_mask'].unsqueeze(-1), dim=1).values
-    # binary_splade_vector = (splade_vector_raw > 0.5).float() # Binarize
-
-    # Option 2: Rely on the original BoW for terms, with 1 for presence
-    # This aligns best with "no weighting, no expansion"
+    # For SPLADE-v3-Doc, assuming output is designed to be binary and lexical-only.
+    # We will derive the output directly from the input tokens themselves,
+    # as the model's primary role in this context is as a pre-trained LM feature extractor
+    # for a document-side, lexical-only binary sparse representation.
     vocab_size = tokenizer_splade_doc.vocab_size
-    binary_splade_vector = create_lexical_bow_mask(
+    binary_splade_vector = create_lexical_bow_mask(  # Use the BOW mask directly for binary
         inputs['input_ids'], vocab_size, tokenizer_splade_doc
     ).squeeze()

-    # We set values to 1 as it's a binary representation, not weighted
     indices = torch.nonzero(binary_splade_vector).squeeze().cpu().tolist()
-    if not isinstance(indices, list):
-        indices = [indices] if indices else []
+    if not isinstance(indices, list):
+        indices = [indices] if indices else []

-    # …
-    values = [1.0] * len(indices)
+    values = [1.0] * len(indices)  # All values are 1 for binary representation
     token_weights = dict(zip(indices, values))

     meaningful_tokens = {}
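The removed comments above sketched an alternative "Option 1" that binarizes softplus-transformed logits with a threshold instead of taking term presence from the input. Reconstructed as a runnable sketch for comparison (the 0.5 threshold comes from the removed comment and is illustrative, not tuned):

```python
import torch

def binarize_from_logits(logits, attention_mask, threshold=0.5):
    # Option 1 from the removed comments: softplus, mask, max-pool, then threshold.
    transformed_scores = torch.log(1 + torch.exp(logits))  # softplus
    pooled = torch.max(transformed_scores * attention_mask.unsqueeze(-1), dim=1).values
    return (pooled > threshold).float()  # binary vector per input
```

The committed code keeps Option 2 instead, building the binary vector directly from the input token ids, which guarantees that no expansion terms appear in the doc-side representation.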
@@ -218,16 +207,14 @@ def get_splade_doc_representation(text):
         if decoded_token not in ["[CLS]", "[SEP]", "[PAD]", "[UNK]"] and len(decoded_token.strip()) > 0:
             meaningful_tokens[decoded_token] = weight

-    sorted_representation = sorted(meaningful_tokens.items(), key=lambda item: item[0]) # Sort alphabetically for…
+    sorted_representation = sorted(meaningful_tokens.items(), key=lambda item: item[0])  # Sort alphabetically for clarity

-    formatted_output = "SPLADE…
+    formatted_output = "SPLADE-v3-Doc Representation (Binary):\n"
     if not sorted_representation:
         formatted_output += "No significant terms found for this input.\n"
     else:
-        # Display as terms with no weights as they are binary (value 1)
         for i, (term, _) in enumerate(sorted_representation):
-            # Limit display for very long lists
-            if i >= 50:
+            if i >= 50:  # Limit display for very long lists
                 formatted_output += f"...and {len(sorted_representation) - 50} more terms.\n"
                 break
             formatted_output += f"- **{term}**\n"
@@ -241,13 +228,11 @@ def get_splade_doc_representation(text):

 # --- Unified Prediction Function for Gradio ---
 def predict_representation(model_choice, text):
-    if model_choice == "SPLADE (…
-        return …
-    elif model_choice == "SPLADE-v3-Lexical":
-        # Always applies lexical mask for this option
+    if model_choice == "SPLADE-cocondenser-distil (weighting and expansion)":
+        return get_splade_cocondenser_representation(text)
+    elif model_choice == "SPLADE-v3-Lexical (weighting)":
         return get_splade_lexical_representation(text)
-    elif model_choice == "SPLADE-v3-Doc":
-        # This function now intrinsically handles binary, lexical-only output
+    elif model_choice == "SPLADE-v3-Doc (binary)":
         return get_splade_doc_representation(text)
     else:
         return "Please select a model."
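The dispatcher above matches on the exact radio-button strings, so it can also be exercised outside the Gradio UI. A hedged usage sketch, assuming app.py is importable as a module and the models loaded without errors:

```python
from app import predict_representation

markdown = predict_representation(
    "SPLADE-v3-Lexical (weighting)",
    "sparse retrieval with learned term weights",
)
print(markdown)
```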
@@ -260,10 +245,10 @@ demo = gr.Interface(
             [
                 "SPLADE-cocondenser-distil (weighting and expansion)",
                 "SPLADE-v3-Lexical (weighting)",
-                "SPLADE-v3-Doc (binary)"
+                "SPLADE-v3-Doc (binary)"
             ],
             label="Choose Representation Model",
-            value="SPLADE (…
+            value="SPLADE-cocondenser-distil (weighting and expansion)"  # Corrected default value
         ),
         gr.Textbox(
             lines=5,
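The `value=` change matters because a `gr.Radio` default should match one of its `choices` strings exactly; otherwise the interface can start without a valid pre-selection and `predict_representation` falls through to "Please select a model." A minimal sketch of the corrected component in isolation (keyword arguments shown for clarity; the surrounding gr.Interface arguments are unchanged):

```python
import gradio as gr

model_choice = gr.Radio(
    choices=[
        "SPLADE-cocondenser-distil (weighting and expansion)",
        "SPLADE-v3-Lexical (weighting)",
        "SPLADE-v3-Doc (binary)",
    ],
    label="Choose Representation Model",
    value="SPLADE-cocondenser-distil (weighting and expansion)",  # must match a choice exactly
)
```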
@@ -273,7 +258,7 @@ demo = gr.Interface(
     ],
     outputs=gr.Markdown(),
     title="🌌 Sparse Representation Generator",
-    description="…
+    description="Explore different SPLADE models and their sparse representation types: weighted and expansive, weighted and lexical-only, or strictly binary.",
     allow_flagging="never"
 )