Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -107,7 +107,7 @@ def get_splade_representation(text):
|
|
| 107 |
return formatted_output
|
| 108 |
|
| 109 |
|
| 110 |
-
def get_splade_lexical_representation(text, apply_lexical_mask: bool):
|
| 111 |
if tokenizer_splade_lexical is None or model_splade_lexical is None:
|
| 112 |
return "SPLADE v3 Lexical model is not loaded. Please check the console for loading errors."
|
| 113 |
|
|
@@ -125,18 +125,17 @@ def get_splade_lexical_representation(text, apply_lexical_mask: bool): # Added p
|
|
| 125 |
else:
|
| 126 |
return "Model output structure not as expected for SPLADE v3 Lexical. 'logits' not found."
|
| 127 |
|
| 128 |
-
# --- Apply Lexical Mask
|
| 129 |
-
|
| 130 |
-
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
|
| 134 |
-
|
| 135 |
-
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
|
| 139 |
-
splade_vector = splade_vector * bow_mask
|
| 140 |
# --- End Lexical Mask Logic ---
|
| 141 |
|
| 142 |
indices = torch.nonzero(splade_vector).squeeze().cpu().tolist()
|
|
@@ -172,12 +171,8 @@ def get_splade_lexical_representation(text, apply_lexical_mask: bool): # Added p
|
|
| 172 |
def predict_representation(model_choice, text):
|
| 173 |
if model_choice == "SPLADE (cocondenser)":
|
| 174 |
return get_splade_representation(text)
|
| 175 |
-
elif model_choice == "SPLADE-v3-Lexical
|
| 176 |
-
#
|
| 177 |
-
return get_splade_lexical_representation(text, apply_lexical_mask=False)
|
| 178 |
-
elif model_choice == "SPLADE-v3-Lexical (lexical-only)":
|
| 179 |
-
# Call the lexical function applying the mask
|
| 180 |
-
return get_splade_lexical_representation(text, apply_lexical_mask=True)
|
| 181 |
else:
|
| 182 |
return "Please select a model."
|
| 183 |
|
|
@@ -188,8 +183,7 @@ demo = gr.Interface(
|
|
| 188 |
gr.Radio(
|
| 189 |
[
|
| 190 |
"SPLADE (cocondenser)",
|
| 191 |
-
"SPLADE-v3-Lexical
|
| 192 |
-
"SPLADE-v3-Lexical (lexical-only)" # Option with lexical mask applied
|
| 193 |
],
|
| 194 |
label="Choose Representation Model",
|
| 195 |
value="SPLADE (cocondenser)" # Default selection
|
|
@@ -202,7 +196,7 @@ demo = gr.Interface(
|
|
| 202 |
],
|
| 203 |
outputs=gr.Markdown(),
|
| 204 |
title="🌌 Sparse Representation Generator",
|
| 205 |
-
description="Enter any text to see its SPLADE sparse vector.
|
| 206 |
allow_flagging="never"
|
| 207 |
)
|
| 208 |
|
|
|
|
| 107 |
return formatted_output
|
| 108 |
|
| 109 |
|
| 110 |
+
def get_splade_lexical_representation(text): # Removed apply_lexical_mask parameter
|
| 111 |
if tokenizer_splade_lexical is None or model_splade_lexical is None:
|
| 112 |
return "SPLADE v3 Lexical model is not loaded. Please check the console for loading errors."
|
| 113 |
|
|
|
|
| 125 |
else:
|
| 126 |
return "Model output structure not as expected for SPLADE v3 Lexical. 'logits' not found."
|
| 127 |
|
| 128 |
+
# --- Apply Lexical Mask (always applied for this function now) ---
|
| 129 |
+
# Get the vocabulary size from the tokenizer
|
| 130 |
+
vocab_size = tokenizer_splade_lexical.vocab_size
|
| 131 |
+
|
| 132 |
+
# Create the Bag-of-Words mask
|
| 133 |
+
bow_mask = create_lexical_bow_mask(
|
| 134 |
+
inputs['input_ids'], vocab_size, tokenizer_splade_lexical
|
| 135 |
+
).squeeze()
|
| 136 |
+
|
| 137 |
+
# Multiply the SPLADE vector by the BoW mask to zero out expanded terms
|
| 138 |
+
splade_vector = splade_vector * bow_mask
|
|
|
|
| 139 |
# --- End Lexical Mask Logic ---
|
| 140 |
|
| 141 |
indices = torch.nonzero(splade_vector).squeeze().cpu().tolist()
|
|
|
|
| 171 |
def predict_representation(model_choice, text):
|
| 172 |
if model_choice == "SPLADE (cocondenser)":
|
| 173 |
return get_splade_representation(text)
|
| 174 |
+
elif model_choice == "SPLADE-v3-Lexical": # Simplified choice
|
| 175 |
+
return get_splade_lexical_representation(text) # Always applies lexical mask
|
|
|
|
|
|
|
|
|
|
|
|
|
| 176 |
else:
|
| 177 |
return "Please select a model."
|
| 178 |
|
|
|
|
| 183 |
gr.Radio(
|
| 184 |
[
|
| 185 |
"SPLADE (cocondenser)",
|
| 186 |
+
"SPLADE-v3-Lexical" # Simplified option
|
|
|
|
| 187 |
],
|
| 188 |
label="Choose Representation Model",
|
| 189 |
value="SPLADE (cocondenser)" # Default selection
|
|
|
|
| 196 |
],
|
| 197 |
outputs=gr.Markdown(),
|
| 198 |
title="🌌 Sparse Representation Generator",
|
| 199 |
+
description="Enter any text to see its SPLADE sparse vector.", # Simplified description
|
| 200 |
allow_flagging="never"
|
| 201 |
)
|
| 202 |
|