Hugging Face Space (status: Sleeping)
Commit: Update app.py — Browse files
File changed: app.py
@@ -430,18 +430,73 @@ def calculate_dot_product_and_representations_independent(query_model_choice, do
|
|
| 430 |
if query_vector is None or doc_vector is None:
|
| 431 |
return "Failed to generate one or both vectors. Please check model loading and input text.", ""
|
| 432 |
|
| 433 |
-
# Calculate dot product
|
| 434 |
dot_product = float(torch.dot(query_vector.cpu(), doc_vector.cpu()).item())
|
| 435 |
|
| 436 |
-
# Format representations
|
| 437 |
query_main_rep_str, query_info_str = format_sparse_vector_output(query_vector, query_tokenizer, query_is_binary)
|
| 438 |
doc_main_rep_str, doc_info_str = format_sparse_vector_output(doc_vector, doc_tokenizer, doc_is_binary)
|
| 439 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 440 |
# Combine output into a single string for the Markdown component
|
| 441 |
-
|
| 442 |
-
full_output = f"### Dot Product Score: {dot_product:.6f}\n\n"
|
| 443 |
full_output += "---\n\n"
|
| 444 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 445 |
# Query Representation
|
| 446 |
full_output += f"#### Query Representation ({query_model_name_display}):\n" # Smaller heading for sub-section
|
| 447 |
full_output += f"> {query_main_rep_str}\n" # Using blockquote for the sparse list
|
|
|
|
| 430 |
if query_vector is None or doc_vector is None:
|
| 431 |
return "Failed to generate one or both vectors. Please check model loading and input text.", ""
|
| 432 |
|
| 433 |
+
# Calculate overall dot product
|
| 434 |
dot_product = float(torch.dot(query_vector.cpu(), doc_vector.cpu()).item())
|
| 435 |
|
| 436 |
+
# Format representations for display
|
| 437 |
query_main_rep_str, query_info_str = format_sparse_vector_output(query_vector, query_tokenizer, query_is_binary)
|
| 438 |
doc_main_rep_str, doc_info_str = format_sparse_vector_output(doc_vector, doc_tokenizer, doc_is_binary)
|
| 439 |
|
| 440 |
+
# --- NEW FEATURE: Calculate dot product of overlapping terms ---
|
| 441 |
+
overlapping_terms_dot_products = {}
|
| 442 |
+
query_indices = torch.nonzero(query_vector).squeeze().cpu()
|
| 443 |
+
doc_indices = torch.nonzero(doc_vector).squeeze().cpu()
|
| 444 |
+
|
| 445 |
+
# Handle cases where vectors are empty or single element
|
| 446 |
+
if query_indices.dim() == 0 and query_indices.numel() == 1:
|
| 447 |
+
query_indices = query_indices.unsqueeze(0)
|
| 448 |
+
if doc_indices.dim() == 0 and doc_indices.numel() == 1:
|
| 449 |
+
doc_indices = doc_indices.unsqueeze(0)
|
| 450 |
+
|
| 451 |
+
# Convert indices to sets for efficient intersection
|
| 452 |
+
query_index_set = set(query_indices.tolist())
|
| 453 |
+
doc_index_set = set(doc_indices.tolist())
|
| 454 |
+
|
| 455 |
+
common_indices = sorted(list(query_index_set.intersection(doc_index_set)))
|
| 456 |
+
|
| 457 |
+
if common_indices:
|
| 458 |
+
for idx in common_indices:
|
| 459 |
+
query_weight = query_vector[idx].item()
|
| 460 |
+
doc_weight = doc_vector[idx].item()
|
| 461 |
+
term = query_tokenizer.decode([idx]) # Tokenizers should be the same for this purpose
|
| 462 |
+
if term not in ["[CLS]", "[SEP]", "[PAD]", "[UNK]"] and len(term.strip()) > 0:
|
| 463 |
+
overlapping_terms_dot_products[term] = query_weight * doc_weight
|
| 464 |
+
|
| 465 |
+
sorted_overlapping_dot_products = sorted(
|
| 466 |
+
overlapping_terms_dot_products.items(),
|
| 467 |
+
key=lambda item: item[1],
|
| 468 |
+
reverse=True
|
| 469 |
+
)
|
| 470 |
+
# --- End NEW FEATURE ---
|
| 471 |
+
|
| 472 |
# Combine output into a single string for the Markdown component
|
| 473 |
+
full_output = f"### Overall Dot Product Score: {dot_product:.6f}\n\n"
|
|
|
|
| 474 |
full_output += "---\n\n"
|
| 475 |
|
| 476 |
+
# Overlapping Terms Dot Products
|
| 477 |
+
if sorted_overlapping_dot_products:
|
| 478 |
+
full_output += "### Dot Products of Overlapping Terms:\n"
|
| 479 |
+
full_output += "*(Term: Query_Weight x Document_Weight = Product)*\n\n"
|
| 480 |
+
overlap_list = []
|
| 481 |
+
for term, product_val in sorted_overlapping_dot_products:
|
| 482 |
+
# Get individual weights for display
|
| 483 |
+
query_weight = query_vector[query_tokenizer.encode(term, add_special_tokens=False)[0]].item()
|
| 484 |
+
doc_weight = doc_vector[doc_tokenizer.encode(term, add_special_tokens=False)[0]].item()
|
| 485 |
+
|
| 486 |
+
if query_is_binary and doc_is_binary:
|
| 487 |
+
overlap_list.append(f"**{term}**: 1.0000 x 1.0000 = {product_val:.4f}")
|
| 488 |
+
elif query_is_binary:
|
| 489 |
+
overlap_list.append(f"**{term}**: 1.0000 x {doc_weight:.4f} = {product_val:.4f}")
|
| 490 |
+
elif doc_is_binary:
|
| 491 |
+
overlap_list.append(f"**{term}**: {query_weight:.4f} x 1.0000 = {product_val:.4f}")
|
| 492 |
+
else:
|
| 493 |
+
overlap_list.append(f"**{term}**: {query_weight:.4f} x {doc_weight:.4f} = {product_val:.4f}")
|
| 494 |
+
full_output += ", ".join(overlap_list) + ".\n\n"
|
| 495 |
+
full_output += "---\n\n"
|
| 496 |
+
else:
|
| 497 |
+
full_output += "### No Overlapping Terms Found.\n\n"
|
| 498 |
+
full_output += "---\n\n"
|
| 499 |
+
|
| 500 |
# Query Representation
|
| 501 |
full_output += f"#### Query Representation ({query_model_name_display}):\n" # Smaller heading for sub-section
|
| 502 |
full_output += f"> {query_main_rep_str}\n" # Using blockquote for the sparse list
|