Update app.py

app.py CHANGED
@@ -79,10 +79,10 @@ def create_lexical_bow_mask(input_ids_batch, vocab_size, tokenizer):
 
 
 # --- Core Representation Functions (Return Formatted Strings - for Explorer Tab) ---
-# These functions
+# These functions now return a tuple: (main_representation_str, info_str)
 def get_splade_cocondenser_representation(text):
     if tokenizer_splade is None or model_splade is None:
-        return "SPLADE-cocondenser-distil model is not loaded. Please check the console for loading errors."
+        return "SPLADE-cocondenser-distil model is not loaded. Please check the console for loading errors.", ""
 
     inputs = tokenizer_splade(text, return_tensors="pt", padding=True, truncation=True)
     inputs = {k: v.to(model_splade.device) for k, v in inputs.items()}
@@ -96,7 +96,7 @@ def get_splade_cocondenser_representation(text):
             dim=1
         )[0].squeeze() # Squeeze is fine here as it's a single input
     else:
-        return "Model output structure not as expected for SPLADE-cocondenser-distil. 'logits' not found."
+        return "Model output structure not as expected for SPLADE-cocondenser-distil. 'logits' not found.", ""
 
     indices = torch.nonzero(splade_vector).squeeze().cpu().tolist()
     if not isinstance(indices, list):
@@ -120,16 +120,16 @@ def get_splade_cocondenser_representation(text):
     for term, weight in sorted_representation:
         formatted_output += f"- **{term}**: {weight:.4f}\n"
 
-
-
-
+    info_output = f"--- Sparse Vector Info ---\n"
+    info_output += f"Total non-zero terms in vector: {len(indices)}\n"
+    info_output += f"Sparsity: {1 - (len(indices) / tokenizer_splade.vocab_size):.2%}\n"
 
-    return formatted_output
+    return formatted_output, info_output
 
 
 def get_splade_lexical_representation(text):
     if tokenizer_splade_lexical is None or model_splade_lexical is None:
-        return "SPLADE-v3-Lexical model is not loaded. Please check the console for loading errors."
+        return "SPLADE-v3-Lexical model is not loaded. Please check the console for loading errors.", ""
 
     inputs = tokenizer_splade_lexical(text, return_tensors="pt", padding=True, truncation=True)
     inputs = {k: v.to(model_splade_lexical.device) for k, v in inputs.items()}
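The sparsity figure introduced in this hunk is just one minus the share of vocabulary slots holding a non-zero weight. A minimal sketch of the same bookkeeping, with a hand-built tensor and a hard-coded vocabulary size standing in for the app's actual tokenizer and model:

    import torch

    # Hypothetical SPLADE-style vector: one weight per vocabulary term.
    vocab_size = 30522  # stand-in for tokenizer_splade.vocab_size
    splade_vector = torch.zeros(vocab_size)
    splade_vector[torch.tensor([2054, 2003, 6090])] = torch.tensor([1.2, 0.7, 0.3])

    indices = torch.nonzero(splade_vector).squeeze().cpu().tolist()
    if not isinstance(indices, list):  # a single non-zero term comes back as a bare int
        indices = [indices]

    print(f"Total non-zero terms in vector: {len(indices)}")   # 3
    print(f"Sparsity: {1 - (len(indices) / vocab_size):.2%}")  # 99.99%

The isinstance guard mirrors the one already in the app: tolist() on a squeezed single-element result returns an int, not a list.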
@@ -143,7 +143,7 @@ def get_splade_lexical_representation(text):
             dim=1
         )[0].squeeze() # Squeeze is fine here
     else:
-        return "Model output structure not as expected for SPLADE-v3-Lexical. 'logits' not found."
+        return "Model output structure not as expected for SPLADE-v3-Lexical. 'logits' not found.", ""
 
     # Always apply lexical mask for this model's specific behavior
     vocab_size = tokenizer_splade_lexical.vocab_size
@@ -175,16 +175,16 @@ def get_splade_lexical_representation(text):
     for term, weight in sorted_representation:
         formatted_output += f"- **{term}**: {weight:.4f}\n"
 
-
-
-
+    info_output = f"--- Raw Sparse Vector Info ---\n"
+    info_output += f"Total non-zero terms in vector: {len(indices)}\n"
+    info_output += f"Sparsity: {1 - (len(indices) / tokenizer_splade_lexical.vocab_size):.2%}\n"
 
-    return formatted_output
+    return formatted_output, info_output
 
 
 def get_splade_doc_representation(text):
     if tokenizer_splade_doc is None: # No longer need model_splade_doc to be loaded for 'logits'
-        return "SPLADE-v3-Doc tokenizer is not loaded. Please check the console for loading errors."
+        return "SPLADE-v3-Doc tokenizer is not loaded. Please check the console for loading errors.", ""
 
     inputs = tokenizer_splade_doc(text, return_tensors="pt", padding=True, truncation=True)
     inputs = {k: v.to(torch.device("cpu")) for k, v in inputs.items()} # Ensure on CPU for direct mask creation
@@ -220,11 +220,11 @@ def get_splade_doc_representation(text):
             break
         formatted_output += f"- **{term}**\n"
 
-
-
-
+    info_output = f"--- Raw Binary Bag-of-Words Vector Info ---\n" # Changed title
+    info_output += f"Total activated terms: {len(indices)}\n"
+    info_output += f"Sparsity: {1 - (len(indices) / tokenizer_splade_doc.vocab_size):.2%}\n"
 
-    return formatted_output
+    return formatted_output, info_output
 
 
 # --- Unified Prediction Function for the Explorer Tab ---
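Unlike the two SPLADE encoders, the Binary Bag-of-Words path needs no model forward pass: every token id the tokenizer emits is switched on with weight 1. A rough sketch of that idea; the model name is illustrative, and the app's create_lexical_bow_mask presumably also filters special tokens, which this sketch does not:

    import torch
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")  # illustrative choice
    inputs = tokenizer("why is padua the nicest city in italy", return_tensors="pt")

    # Switch on every vocabulary slot whose token id appears in the input.
    bow_vector = torch.zeros(tokenizer.vocab_size)
    bow_vector[inputs["input_ids"][0]] = 1.0  # note: this also sets [CLS]/[SEP]

    indices = torch.nonzero(bow_vector).squeeze().tolist()
    terms = [tokenizer.decode([i]) for i in indices]
    print(f"Total activated terms: {len(terms)}")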
@@ -236,7 +236,7 @@ def predict_representation_explorer(model_choice, text):
     elif model_choice == "Binary Bag-of-Words": # Changed name
         return get_splade_doc_representation(text)
     else:
-        return "Please select a model."
+        return "Please select a model.", "" # Return two empty strings for consistency
 
 # --- Core Representation Functions (Return RAW TENSORS - for Dot Product Tab) ---
 # These functions remain unchanged from the previous iteration, as they return the raw tensors.
@@ -339,10 +339,10 @@ def format_sparse_vector_output(splade_vector, tokenizer, is_binary=False):
         else:
             formatted_output += f"- **{term}**: {weight:.4f}\n"
 
-
-
+    info_output = f"\nTotal non-zero terms: {len(indices)}\n"
+    info_output += f"Sparsity: {1 - (len(indices) / tokenizer.vocab_size):.2%}\n"
 
-    return formatted_output
+    return formatted_output, info_output # Now returns two strings
 
 
 # --- NEW/MODIFIED: Helper to get the correct vector function, tokenizer, and binary flag ---
@@ -376,11 +376,16 @@ def calculate_dot_product_and_representations_independent(query_model_choice, do
     dot_product = float(torch.dot(query_vector.cpu(), doc_vector.cpu()).item())
 
     # Format representations
+    # These functions now return two strings (main_output, info_output)
+    query_main_rep_str, query_info_str = format_sparse_vector_output(query_vector, query_tokenizer, query_is_binary)
+    doc_main_rep_str, doc_info_str = format_sparse_vector_output(doc_vector, doc_tokenizer, doc_is_binary)
+
+
     query_rep_str = f"Query Representation ({query_model_name_display}):\n"
-    query_rep_str += format_sparse_vector_output(query_vector, query_tokenizer, query_is_binary)
+    query_rep_str += query_main_rep_str + "\n" + query_info_str
 
     doc_rep_str = f"Document Representation ({doc_model_name_display}):\n"
-    doc_rep_str += format_sparse_vector_output(doc_vector, doc_tokenizer, doc_is_binary)
+    doc_rep_str += doc_main_rep_str + "\n" + doc_info_str
 
     # Combine output
     full_output = f"### Dot Product Score: {dot_product:.6f}\n\n"
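Both encodings live in the same vocabulary-sized space, so the score computed above is a plain dot product: only terms with a non-zero weight in both vectors contribute. A toy check with hypothetical term ids and weights:

    import torch

    vocab_size = 30522  # stand-in value
    query_vector = torch.zeros(vocab_size)
    doc_vector = torch.zeros(vocab_size)
    query_vector[torch.tensor([2054, 2003])] = torch.tensor([0.9, 0.4])
    doc_vector[torch.tensor([2003, 7592])] = torch.tensor([0.5, 1.1])

    # Only term id 2003 overlaps, so the score is 0.4 * 0.5 = 0.2.
    dot_product = float(torch.dot(query_vector.cpu(), doc_vector.cpu()).item())
    print(f"Dot Product Score: {dot_product:.6f}")  # 0.200000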
@@ -397,30 +402,50 @@ with gr.Blocks(title="SPLADE Demos") as demo:
 
     with gr.Tabs():
         with gr.TabItem("Sparse Representation"):
-            gr.Markdown("### Produce a Sparse Representation of
-            gr.
-
-
-            gr.Radio(
+            gr.Markdown("### Produce a Sparse Representation of an Input Text")
+            with gr.Row():
+                with gr.Column(scale=1): # Left column for inputs and info
+                    model_radio = gr.Radio(
                         [
                             "MLM encoder (SPLADE-cocondenser-distil)",
                             "MLP encoder (SPLADE-v3-lexical)",
-                            "Binary Bag-of-Words"
+                            "Binary Bag-of-Words"
                         ],
                         label="Choose Sparse Encoder",
                         value="MLM encoder (SPLADE-cocondenser-distil)"
-            )
-            gr.Textbox(
+                    )
+                    input_text = gr.Textbox(
                         lines=5,
                         label="Enter your query or document text here:",
                         placeholder="e.g., Why is Padua the nicest city in Italy?"
                     )
-
-
-
-
+                    # New Markdown component for the info output
+                    info_output_display = gr.Markdown(
+                        value="",
+                        label="Vector Information",
+                        elem_id="info_output_display" # Add an ID for potential CSS if needed
+                    )
+                with gr.Column(scale=2): # Right column for the main representation output
+                    main_representation_output = gr.Markdown()
+
+            # Connect the interface elements
+            model_radio.change(
+                fn=predict_representation_explorer,
+                inputs=[model_radio, input_text],
+                outputs=[main_representation_output, info_output_display]
             )
-
+            input_text.change(
+                fn=predict_representation_explorer,
+                inputs=[model_radio, input_text],
+                outputs=[main_representation_output, info_output_display]
+            )
+
+            # Initial call to populate on load (optional, but good for demo)
+            demo.load(
+                fn=lambda: predict_representation_explorer(model_radio.value, input_text.value),
+                outputs=[main_representation_output, info_output_display]
+            )
+
        with gr.TabItem("Compare Encoders"): # NEW TAB
            gr.Markdown("### Calculate Dot Product Similarity between Query and Document")
            gr.Markdown("Select **independent** SPLADE models to encode your query and document, then see their sparse representations and their similarity score.")
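This hunk swaps the tab's implicit layout for explicit Blocks event wiring so that one handler can feed two output components; each .change call simply lists both outputs, which is why the representation functions now return two strings. A self-contained sketch of the same pattern with placeholder names:

    import gradio as gr

    def describe(choice, text):
        # One function, two return values -> two output components.
        return f"**{choice}**\n\n{text}", f"characters: {len(text)}"

    with gr.Blocks() as demo:
        with gr.Row():
            with gr.Column(scale=1):
                choice = gr.Radio(["a", "b"], label="Encoder", value="a")
                text = gr.Textbox(lines=2, label="Text")
                info = gr.Markdown()
            with gr.Column(scale=2):
                main = gr.Markdown()
        # Re-run on either input changing, as in the diff above.
        for comp in (choice, text):
            comp.change(fn=describe, inputs=[choice, text], outputs=[main, info])

    demo.launch()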
@@ -429,7 +454,7 @@ with gr.Blocks(title="SPLADE Demos") as demo:
             model_choices = [
                 "MLM encoder (SPLADE-cocondenser-distil)",
                 "MLP encoder (SPLADE-v3-lexical)",
-                "Binary Bag-of-Words"
+                "Binary Bag-of-Words"
             ]
 
             gr.Interface(