Commit 7d21953 · Parent: 556cc72
jedick committed: Handle tool calls with thinking enabled

Files changed:
- graph.py +15 -14
- main.py +2 -2
- prompts.py +16 -12
- requirements.txt +1 -1
graph.py
CHANGED
@@ -9,7 +9,7 @@ import os
 
 # Local modules
 from retriever import BuildRetriever
-from prompts import retrieve_prompt, answer_prompt, gemma_tools_template
+from prompts import query_prompt, generate_prompt, gemma_tools_template
 from mods.tool_calling_llm import ToolCallingLLM
 
 # Local modules

@@ -49,13 +49,14 @@ def print_message_summaries(messages, header):
 def normalize_messages(messages):
     """Normalize messages to sequence of types expected by chat templates"""
     # Copy the most recent HumanMessage to the end
-    # (avoids
+    # (avoids SmolLM and Qwen ValueError: Last message must be a HumanMessage!)
     if not type(messages[-1]) is HumanMessage:
         for msg in reversed(messages):
             if type(msg) is HumanMessage:
                 messages.append(msg)
+                break
     # Convert tool output (ToolMessage) to AIMessage
-    # (avoids
+    # (avoids SmolLM and Qwen ValueError: Unknown message type: <class 'langchain_core.messages.tool.ToolMessage'>)
     messages = [
         AIMessage(msg.content) if type(msg) is ToolMessage else msg for msg in messages
     ]

@@ -75,7 +76,7 @@ def ToolifyHF(chat_model, system_message, system_message_suffix="", think=False)
     Get a Hugging Face model ready for bind_tools().
     """
 
-    ## Add /no_think flag to turn off thinking mode (SmolLM3)
+    ## Add /no_think flag to turn off thinking mode (SmolLM3 and Qwen)
     # if not think:
     #     system_message = "/no_think\n" + system_message
 

@@ -203,14 +204,12 @@ def BuildGraph(
     # Add tools to the local or remote chat model
     is_local = hasattr(chat_model, "model_id")
     if is_local:
-        # For local
+        # For local models (ChatHuggingFace with SmolLM, Gemma, or Qwen)
         query_model = ToolifyHF(
-            chat_model, retrieve_prompt(compute_mode), "", think_retrieve
+            chat_model, query_prompt(compute_mode), "", think_retrieve
         ).bind_tools([retrieve_emails])
-        # Don't use answer_with_citations tool
-        generate_model = ToolifyHF(
-            chat_model, answer_prompt(with_tools=False), "", think_generate
-        )
+        # Don't use answer_with_citations tool because responses with it are sometimes unparseable
+        generate_model = chat_model
     else:
         # For remote model (OpenAI API)
         query_model = chat_model.bind_tools([retrieve_emails])

@@ -228,9 +227,7 @@ def BuildGraph(
             messages = normalize_messages(messages)
             print_message_summaries(messages, "--- query: after normalization ---")
         else:
-            messages = [SystemMessage(retrieve_prompt(compute_mode))] + state[
-                "messages"
-            ]
+            messages = [SystemMessage(query_prompt(compute_mode))] + state["messages"]
         response = query_model.invoke(messages)
 
         return {"messages": response}

@@ -241,9 +238,13 @@ def BuildGraph(
             messages = state["messages"]
             print_message_summaries(messages, "--- generate: before normalization ---")
             messages = normalize_messages(messages)
+            # Add the system message here because we're not using tools
+            messages = [
+                SystemMessage(generate_prompt(with_tools=False, think=False))
+            ] + messages
             print_message_summaries(messages, "--- generate: after normalization ---")
         else:
-            messages = [SystemMessage(answer_prompt())] + state["messages"]
+            messages = [SystemMessage(generate_prompt())] + state["messages"]
         response = generate_model.invoke(messages)
 
         return {"messages": response}
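The key fix in normalize_messages is the added break, which stops the reversed scan after copying only the most recent HumanMessage instead of appending every earlier one. A minimal runnable sketch of the patched helper, assuming only langchain-core is installed (the message contents and the trailing return are illustrative, not copied from the repo):

from langchain_core.messages import AIMessage, HumanMessage, ToolMessage

def normalize_messages(messages):
    """Normalize messages to the sequence of types expected by chat templates."""
    # Copy the most recent HumanMessage to the end
    if not type(messages[-1]) is HumanMessage:
        for msg in reversed(messages):
            if type(msg) is HumanMessage:
                messages.append(msg)
                break  # the fix: only the most recent HumanMessage is copied
    # Convert tool output (ToolMessage) to AIMessage
    messages = [
        AIMessage(msg.content) if type(msg) is ToolMessage else msg
        for msg in messages
    ]
    return messages

msgs = [HumanMessage("Find emails about lm()"), ToolMessage("3 emails found", tool_call_id="1")]
print([type(m).__name__ for m in normalize_messages(msgs)])
# ['HumanMessage', 'AIMessage', 'HumanMessage']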
main.py
CHANGED
@@ -23,7 +23,7 @@ from langchain_huggingface import ChatHuggingFace, HuggingFacePipeline
 from index import ProcessFile
 from retriever import BuildRetriever, db_dir
 from graph import BuildGraph
-from prompts import answer_prompt
+from prompts import generate_prompt
 
 # -----------
 # R-help-chat

@@ -200,7 +200,7 @@ def RunChain(
     chat_model = GetChatModel(compute_mode)
 
     # Control thinking for SmolLM3
-    system_prompt = answer_prompt()
+    system_prompt = generate_prompt()
     if hasattr(chat_model, "model_id") and not think:
         system_prompt = f"/no_think\n{system_prompt}"
 
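The /no_think prefix is applied only for local Hugging Face models, which are detected by the presence of a model_id attribute (remote OpenAI chat models lack it). A small sketch of that gate; the DummyLocal class and model id are illustrative assumptions, not part of the codebase:

think = False
system_prompt = "You are a helpful RAG chatbot."

class DummyLocal:
    # ChatHuggingFace instances expose model_id; remote models do not
    model_id = "HuggingFaceTB/SmolLM3-3B"  # assumed id, for illustration only

chat_model = DummyLocal()
if hasattr(chat_model, "model_id") and not think:
    # SmolLM3-style flag that disables thinking mode in the chat template
    system_prompt = f"/no_think\n{system_prompt}"

print(system_prompt)  # -> "/no_think\nYou are a helpful RAG chatbot."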
prompts.py
CHANGED
@@ -3,7 +3,7 @@ from util import get_sources, get_start_end_months
 import re
 
 
-def retrieve_prompt(compute_mode):
+def query_prompt(compute_mode):
     """Return system prompt for query step
 
     Args:

@@ -13,11 +13,11 @@ def retrieve_prompt(compute_mode):
     # Get start and end months from database
     start, end = get_start_end_months(get_sources())
 
-    retrieve_prompt = (
+    query_prompt = (
         f"Today Date: {date.today()}."
         "You are a helpful RAG chatbot designed to answer questions about R programming based on the R-help mailing list."
         "Do not ask the user for more information, but retrieve emails from the R-help mailing list archives."
-        # gpt-4o-mini
+        # gpt-4o-mini thinks last two months aren't available with this: "Emails from from {start} to {end} are available for retrieval."
         f"The emails available for retrieval are from {start} to {end}."
         "Write a search query based on the user's question, but do not answer the question just yet."
         "For questions about differences or comparison between X and Y, retrieve emails about X and Y."

@@ -25,19 +25,20 @@ def retrieve_prompt(compute_mode):
         "For specific questions, use retrieve_emails(search_query=<specific topic>)."
         "For questions about years, use retrieve_emails(search_query=, start_year=, end_year=) (this month is this year)."
         "For questions about months, use 3-letter abbreviations (Jan..Dec) for the 'month' argument."
-        "If you decide not to retrieve emails, tell the user why and suggest how to improve their question to chat with the R-help mailing list."
+        "Even if retrieved emails are already available, you should retrieve *more* emails to answer the most recent question."  # Qwen
+        # "If you decide not to retrieve emails, tell the user why and suggest how to improve their question to chat with the R-help mailing list."
     )
     # A sanity check that we don't have unassigned variables
     # (this causes KeyError in parsing by ToolCallingLLM)
-    matches = re.findall(r"\{.*?\}", " ".join(retrieve_prompt))
+    matches = re.findall(r"\{.*?\}", " ".join(query_prompt))
     if matches:
         raise ValueError(f"Unassigned variables in prompt: {' '.join(matches)}")
-    return retrieve_prompt
+    return query_prompt
 
 
-def answer_prompt(with_tools=True):
+def generate_prompt(with_tools=True, think=True):
     """Return system prompt for generate step"""
-    answer_prompt = (
+    generate_prompt = (
         f"Today Date: {date.today()}."
         "You are a helpful RAG chatbot designed to answer questions about R programming based on the R-help mailing list."
         "Summarize the retrieved emails from the R-help mailing list archives to answer the user's question or query."

@@ -45,17 +46,20 @@ def answer_prompt(with_tools=True):
         "Tell the user if there are no retrieved emails or if you are unable to answer the question based on the information in the emails."
         "Do not give an answer based on your own knowledge or memory, and do not include examples that aren't based on the retrieved emails."
         "Example: For a question about writing formulas for lm(), make your answer about formulas for lm() from the retrieved emails."
-        "Do not respond with packages that are only listed under sessionInfo, session info, or other attached packages."
+        # "Do not respond with packages that are only listed under sessionInfo, session info, or other attached packages."
+        "Summarize the content of the emails rather than copying the headers."  # Qwen
         "Include inline citations (email senders and dates) in your response."
         "Only answer general questions about R if the answer is given in the retrieved emails."
         "Respond with 300 words maximum and 30 lines of code maximum and include any relevant URLs from the retrieved emails."
     )
     if with_tools:
-        answer_prompt += "Use answer_with_citations to provide the complete answer and all citations used."
-    matches = re.findall(r"\{.*?\}", " ".join(answer_prompt))
+        generate_prompt += "Use answer_with_citations to provide the complete answer and all citations used."
+    if not think:
+        generate_prompt += "/no_think"
+    matches = re.findall(r"\{.*?\}", " ".join(generate_prompt))
     if matches:
         raise ValueError(f"Unassigned variables in prompt: {' '.join(matches)}")
-    return answer_prompt
+    return generate_prompt
 
 
 # Prompt template for SmolLM3 with tools
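These prompts are built by adjacent-string concatenation, so a {placeholder} left in a plain (non-f) string would survive into the prompt and later cause a KeyError when ToolCallingLLM parses its template; the re.findall check catches that early. A minimal sketch of the check in isolation, simplified to run findall on the string directly (the prompt text here is made up):

import re

# Adjacent string literals concatenate into one str; the second literal
# is missing its f-prefix, so {topic} is never substituted
prompt = (
    "You are a helpful RAG chatbot."
    "Write a search query about {topic}."
)
matches = re.findall(r"\{.*?\}", prompt)
if matches:
    raise ValueError(f"Unassigned variables in prompt: {' '.join(matches)}")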
requirements.txt
CHANGED
@@ -13,7 +13,7 @@ torch==2.5.1
 # Gemma 3: transformers>=4.50
 # Gemma 3 with transformers==4.54.0 gives:
 # ValueError: Max cache length is not consistent across layers
-transformers==4.54.0
+transformers==4.50.0
 # Commented because we have local modifications
 #tool-calling-llm==0.1.2
 bm25s==0.2.12
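The pin keeps transformers inside the window the comments describe: new enough for Gemma 3 (>=4.50) but before the 4.54.0 cache error. A quick generic check that the installed version matches the pin (not part of the repo):

import transformers

# Gemma 3 needs >=4.50, but 4.54.0 raises "Max cache length is not
# consistent across layers"; stay on the known-good pinned version
assert transformers.__version__ == "4.50.0", transformers.__version__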