Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -53,17 +53,19 @@ document_texts = {} # Stores {doc_id: doc_text}
|
|
| 53 |
initial_doc_model_for_indexing = "SPLADE-cocondenser-distil" # Fixed for initial demo index
|
| 54 |
|
| 55 |
|
| 56 |
-
# --- Load
|
| 57 |
-
|
|
|
|
| 58 |
global document_texts
|
| 59 |
-
print("Loading
|
| 60 |
try:
|
| 61 |
-
dataset
|
| 62 |
-
|
|
|
|
| 63 |
document_texts[doc.doc_id] = doc.text.strip()
|
| 64 |
-
print(f"Loaded {len(document_texts)} documents from
|
| 65 |
except Exception as e:
|
| 66 |
-
print(f"Error loading
|
| 67 |
print("Please ensure 'ir_datasets' is installed and your internet connection is stable.")
|
| 68 |
|
| 69 |
|
|
@@ -88,8 +90,6 @@ def create_lexical_bow_mask(input_ids, vocab_size, tokenizer):
|
|
| 88 |
|
| 89 |
|
| 90 |
# --- Core Representation Functions (Return Formatted Strings - for Explorer Tab) ---
|
| 91 |
-
# These are your original functions, re-added.
|
| 92 |
-
|
| 93 |
def get_splade_cocondenser_representation(text):
|
| 94 |
if tokenizer_splade is None or model_splade is None:
|
| 95 |
return "SPLADE-cocondenser-distil model is not loaded. Please check the console for loading errors."
|
|
@@ -254,8 +254,6 @@ def predict_representation_explorer(model_choice, text):
|
|
| 254 |
|
| 255 |
|
| 256 |
# --- Internal Core Representation Functions (Return Raw Vectors - for Retrieval Tab) ---
|
| 257 |
-
# These are the ones ending with _internal, as previously defined.
|
| 258 |
-
|
| 259 |
def get_splade_cocondenser_representation_internal(text, tokenizer, model):
|
| 260 |
if tokenizer is None or model is None: return None
|
| 261 |
inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True)
|
|
@@ -400,7 +398,8 @@ def predict_retrieval_gradio(query_text, query_model_choice, selected_doc_model_
|
|
| 400 |
|
| 401 |
# --- Initial Load and Indexing Calls ---
|
| 402 |
# This part runs once when the app starts.
|
| 403 |
-
|
|
|
|
| 404 |
|
| 405 |
if initial_doc_model_for_indexing == "SPLADE-cocondenser-distil" and model_splade is not None:
|
| 406 |
index_documents(initial_doc_model_for_indexing)
|
|
@@ -440,13 +439,11 @@ with gr.Blocks(title="SPLADE Demos") as demo:
|
|
| 440 |
],
|
| 441 |
outputs=gr.Markdown(),
|
| 442 |
allow_flagging="never",
|
| 443 |
-
#
|
| 444 |
-
# Setting live=True might be slow for complex models on every keystroke
|
| 445 |
-
# live=True
|
| 446 |
)
|
| 447 |
|
| 448 |
with gr.TabItem("Document Retrieval Demo"):
|
| 449 |
-
gr.Markdown("### Retrieve Documents from
|
| 450 |
gr.Interface(
|
| 451 |
fn=predict_retrieval_gradio,
|
| 452 |
inputs=[
|
|
|
|
| 53 |
initial_doc_model_for_indexing = "SPLADE-cocondenser-distil" # Fixed for initial demo index
|
| 54 |
|
| 55 |
|
| 56 |
+
# --- Load Cranfield Corpus using ir_datasets ---
|
| 57 |
+
# Loads the Cranfield corpus into `document_texts`; the startup code below calls it by this name.
|
| 58 |
+
def load_cranfield_corpus_ir_datasets():
|
| 59 |
global document_texts
|
| 60 |
+
print("Loading Cranfield corpus using ir_datasets...")
|
| 61 |
try:
|
| 62 |
+
# Load the 'cranfield' dataset via ir_datasets.
|
| 63 |
+
dataset = ir_datasets.load("cranfield")
|
| 64 |
+
for doc in tqdm(dataset.docs_iter(), desc="Loading Cranfield documents"):
|
| 65 |
document_texts[doc.doc_id] = doc.text.strip()
|
| 66 |
+
print(f"Loaded {len(document_texts)} documents from Cranfield corpus.")
|
| 67 |
except Exception as e:
|
| 68 |
+
print(f"Error loading Cranfield corpus with ir_datasets: {e}")
|
| 69 |
print("Please ensure 'ir_datasets' is installed and your internet connection is stable.")
|
| 70 |
|
| 71 |
|
|
|
|
| 90 |
|
| 91 |
|
| 92 |
# --- Core Representation Functions (Return Formatted Strings - for Explorer Tab) ---
|
|
|
|
|
|
|
| 93 |
def get_splade_cocondenser_representation(text):
|
| 94 |
if tokenizer_splade is None or model_splade is None:
|
| 95 |
return "SPLADE-cocondenser-distil model is not loaded. Please check the console for loading errors."
|
|
|
|
| 254 |
|
| 255 |
|
| 256 |
# --- Internal Core Representation Functions (Return Raw Vectors - for Retrieval Tab) ---
|
|
|
|
|
|
|
| 257 |
def get_splade_cocondenser_representation_internal(text, tokenizer, model):
|
| 258 |
if tokenizer is None or model is None: return None
|
| 259 |
inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True)
|
|
|
|
| 398 |
|
| 399 |
# --- Initial Load and Indexing Calls ---
|
| 400 |
# This part runs once when the app starts.
|
| 401 |
+
# Load the Cranfield corpus before building the initial index.
|
| 402 |
+
load_cranfield_corpus_ir_datasets()
|
| 403 |
|
| 404 |
if initial_doc_model_for_indexing == "SPLADE-cocondenser-distil" and model_splade is not None:
|
| 405 |
index_documents(initial_doc_model_for_indexing)
|
|
|
|
| 439 |
],
|
| 440 |
outputs=gr.Markdown(),
|
| 441 |
allow_flagging="never",
|
| 442 |
+
# live=True # Setting live=True might be slow for complex models on every keystroke
|
|
|
|
|
|
|
| 443 |
)
|
| 444 |
|
| 445 |
with gr.TabItem("Document Retrieval Demo"):
|
| 446 |
+
gr.Markdown("### Retrieve Documents from Cranfield Collection") # Changed title
|
| 447 |
gr.Interface(
|
| 448 |
fn=predict_retrieval_gradio,
|
| 449 |
inputs=[
|