Spaces: Running on Zero

jedick committed
Commit · 7e18a82
1 Parent(s): f42e9e5

Enable FlashAttention

Files changed:
- graph.py         +42 -22
- main.py           +6  -1
- pipeline.py       +1  -0
- requirements.txt  +3  -0
- retriever.py      +2  -0
graph.py CHANGED

@@ -43,8 +43,18 @@ def print_message_summaries(messages, header):
         print(f"{type_txt}: {summary_txt}")


-def normalize_messages(messages):
-    """
+def normalize_messages(messages, summaries_for=None):
+    """
+    Normalize messages to sequence of types expected by chat models
+
+    Args:
+        messages (list): message list
+        summaries_for (str): "query" or "answer" to print messages summaries or None for no summaries
+    """
+    if summaries_for:
+        print_message_summaries(
+            messages, f"--- {summaries_for}: before normalization ---"
+        )
     # Copy the most recent HumanMessage to the end
     # - Avoids SmolLM and Qwen ValueError: Last message must be a HumanMessage!
     if not type(messages[-1]) is HumanMessage:

@@ -88,6 +98,10 @@ def normalize_messages(messages):
         if not hasattr(msg, "tool_calls")
         or (hasattr(msg, "tool_calls") and not msg.tool_calls)
     ]
+    if summaries_for:
+        print_message_summaries(
+            messages, f"--- {summaries_for}: after normalization ---"
+        )
     return messages


@@ -118,6 +132,7 @@ def BuildGraph(
     top_k=6,
     think_query=False,
     think_answer=False,
+    local_citations=True,
     embedding_ckpt_dir=None,
 ):
     """

@@ -128,8 +143,10 @@ def BuildGraph(
         compute_mode: remote or local (for retriever)
         search_type: dense, sparse, or hybrid (for retriever)
         top_k: number of documents to retrieve
-        think_query: Whether to use thinking mode for the query
-        think_answer: Whether to use thinking mode for the answer
+        think_query: Whether to use thinking mode for the query (local model)
+        think_answer: Whether to use thinking mode for the answer (local model)
+        local_citations: Whether to use answer_with_citations() tool (local model)
+        embedding_ckpt_dir: Directory for embedding model checkpoint

     Based on:
     https://python.langchain.com/docs/how_to/qa_sources

@@ -175,10 +192,10 @@ def BuildGraph(
         Use optional "months" argument to search by month.

         Args:
-            search_query: Search query
-
-
-
+            search_query (str): Search query
+            start_year (int, optional): Starting year for emails
+            end_year (int, optional): Ending year for emails
+            months (str, optional): One or more months separated by spaces
         """
         retriever = BuildRetriever(
             compute_mode, search_type, top_k, start_year, end_year, embedding_ckpt_dir

@@ -208,8 +225,8 @@ def BuildGraph(
         An answer to the question, with citations of the emails used (senders and dates).

         Args:
-            answer: An answer to the question
-            citations: Citations of emails used to answer the question, e.g. Jane Doe, 2025-07-04; John Smith, 2020-01-01
+            answer (str): An answer to the question
+            citations (str): Citations of emails used to answer the question, e.g. Jane Doe, 2025-07-04; John Smith, 2020-01-01
         """
         return answer, citations

@@ -220,8 +237,14 @@ def BuildGraph(
         query_model = ToolifyHF(
             chat_model, query_prompt(chat_model, think=think_query)
         ).bind_tools([retrieve_emails])
-
-
+        if local_citations:
+            answer_model = ToolifyHF(
+                chat_model,
+                answer_prompt(chat_model, think=think_answer, with_tools=True),
+            ).bind_tools([answer_with_citations])
+        else:
+            # Don't use answer_with_citations tool because responses with are sometimes unparseable
+            answer_model = chat_model
     else:
         # For remote model (OpenAI API)
         query_model = chat_model.bind_tools([retrieve_emails])

@@ -235,9 +258,7 @@ def BuildGraph(
         if is_local:
             # Don't include the system message here because it's defined in ToolCallingLLM
             messages = state["messages"]
-
-            messages = normalize_messages(messages)
-            # print_message_summaries(messages, "--- query: after normalization ---")
+            messages = normalize_messages(messages, "query")
         else:
             messages = [SystemMessage(query_prompt(chat_model))] + state["messages"]
         response = query_model.invoke(messages)

@@ -248,13 +269,12 @@ def BuildGraph(
         """Generates an answer with the chat model"""
         if is_local:
             messages = state["messages"]
-
-
-
-
-
-
-            # print_message_summaries(messages, "--- answer: after normalization ---")
+            messages = normalize_messages(messages, "answer")
+            if not local_citations:
+                # Add the system message here if we're not using tools
+                messages = [
+                    SystemMessage(answer_prompt(chat_model, think=think_answer))
+                ] + messages
         else:
             messages = [
                 SystemMessage(answer_prompt(chat_model, with_tools=True))
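For readers following the graph.py changes, below is a minimal sketch (not part of the commit) of how the new summaries_for argument to normalize_messages() might be exercised. It assumes the Space's graph.py is importable and that the message classes come from langchain_core; the example messages are placeholders.

# Sketch only: exercising the summaries_for hook added in this commit.
from langchain_core.messages import AIMessage, HumanMessage

from graph import normalize_messages  # assumes graph.py is on the import path

messages = [
    HumanMessage("Who emailed about the server migration?"),
    AIMessage("Let me check the email archive."),
]

# Prints "--- query: before normalization ---" and "--- query: after normalization ---"
# summaries around the cleanup; passing summaries_for=None keeps it silent.
normalized = normalize_messages(messages, summaries_for="query")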
main.py CHANGED

@@ -154,6 +154,10 @@ def GetChatModel(compute_mode, ckpt_dir=None):
         id_or_dir,
         # We need this to load the model in BF16 instead of fp32 (torch.float)
         torch_dtype=torch.bfloat16,
+        # Enable FlashAttention (requires pip install flash-attn)
+        # https://huggingface.co/docs/transformers/en/attention_interface
+        # https://huggingface.co/docs/transformers/perf_infer_gpu_one#flashattention-2
+        attn_implementation="flash_attention_2",
     )

     # Use MyTextGenerationPipeline with custom preprocess() method

@@ -164,7 +168,8 @@ def GetChatModel(compute_mode, ckpt_dir=None):
         return_full_text=False,
         # It seems that max_new_tokens has to be specified here, not in .invoke()
         max_new_tokens=2000,
-        # Use padding for
+        # Use padding for proper alignment for FlashAttention
+        # Part of fix for: "RuntimeError: p.attn_bias_ptr is not correctly aligned"
         # https://github.com/google-deepmind/gemma/issues/169
         padding="longest",
     )
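For context, the arguments added above can be tried in a standalone Transformers load. The sketch below is not from the repository; it assumes a CUDA GPU with flash-attn installed and uses an illustrative model id rather than the one the Space actually loads.

# Sketch: load a causal LM in BF16 with FlashAttention-2 enabled,
# mirroring the from_pretrained() arguments added in this commit.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "Qwen/Qwen3-0.6B"  # placeholder; not necessarily the Space's model
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,               # BF16 instead of fp32
    attn_implementation="flash_attention_2",  # requires pip install flash-attn
    device_map="auto",
)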
pipeline.py CHANGED

@@ -8,6 +8,7 @@ class MyTextGenerationPipeline(TextGenerationPipeline):
     This subclass overrides the preprocess method to add pad_to_multiple_of=8 to tokenizer_kwargs.
     Fix for: "RuntimeError: p.attn_bias_ptr is not correctly aligned"
     https://github.com/google-deepmind/gemma/issues/169
+    NOTE: we also need padding="longest", which is set during class instantiation
     """

     def preprocess(
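The body of the overridden preprocess() is outside this hunk. As a rough illustration of what pad_to_multiple_of=8 together with padding="longest" does at the tokenizer level (the tokenizer choice is a placeholder, not the Space's), consider:

# Sketch: padding="longest" pads to the longest prompt in the batch, and
# pad_to_multiple_of=8 rounds that length up to a multiple of 8, which is the
# alignment fix discussed in the linked Gemma issue.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # placeholder tokenizer
tokenizer.pad_token = tokenizer.eos_token  # gpt2 defines no pad token by default
batch = tokenizer(
    ["short prompt", "a somewhat longer prompt for the email assistant"],
    padding="longest",
    pad_to_multiple_of=8,
    return_tensors="pt",
)
print(batch["input_ids"].shape)  # sequence length is rounded up to a multiple of 8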
requirements.txt CHANGED

@@ -4,6 +4,9 @@ chromadb==0.6.3
 # NOTE: chromadb==1.0.13 was giving intermittent error:
 # ValueError('Could not connect to tenant default_tenant. Are you sure it exists?')

+# FlashAttention
+flash-attn==2.8.2
+
 # Stated requirements:
 # Gemma 3: transformers>=4.50
 # Qwen3: transformers>=4.51
retriever.py CHANGED

@@ -49,6 +49,7 @@ def BuildRetriever(
         top_k: Number of documents to retrieve for "dense" and "sparse"
         start_year: Start year (optional)
         end_year: End year (optional)
+        embedding_ckpt_dir: Directory for embedding model checkpoint
     """
     if search_type == "dense":
         if not (start_year or end_year):

@@ -134,6 +135,7 @@ def BuildRetrieverDense(compute_mode: str, top_k=6, embedding_ckpt_dir=None):
     Args:
         compute_mode: Compute mode for embeddings (remote or local)
         top_k: Number of documents to retrieve
+        embedding_ckpt_dir: Directory for embedding model checkpoint
     """

     # Don't try to use local models without a GPU