Spaces:
Running
Running
| from typing import List | |
| from functools import partial | |
| import logging | |
| from langchain_core.messages import AIMessage, BaseMessage | |
| from langchain_core.output_parsers import StrOutputParser | |
| from langchain_core.prompts import ChatPromptTemplate | |
| from langchain_core.language_models.llms import LLM | |
| from langgraph.prebuilt import tools_condition, ToolNode | |
| from langgraph.graph.state import StateGraph | |
| from langgraph.constants import START, END | |
| from ask_candid.retrieval.elastic import retriever_tool | |
| from ask_candid.tools.recommendation import ( | |
| detect_intent_with_llm, | |
| determine_context, | |
| make_recommendation | |
| ) | |
| from ask_candid.tools.question_reformulation import reformulate_question_using_history | |
| from ask_candid.tools.org_seach import has_org_name, insert_org_link | |
| from ask_candid.tools.search import search_agent | |
| from ask_candid.agents.schema import AgentState | |
| from ask_candid.utils import html_format_docs_chat | |
# Module-level logging setup: each record is rendered as "[LEVEL] (time) :: message",
# and this module's logger reports INFO and above.
_LOG_FORMAT = "[%(levelname)s] (%(asctime)s) :: %(message)s"
logging.basicConfig(format=_LOG_FORMAT)
logger = logging.getLogger(name=__name__)
logger.setLevel("INFO")
def generate_with_context(state: AgentState, llm: LLM) -> AgentState:
    """Generate an answer to the user's question grounded in retrieved context.

    The last message in ``state["messages"]`` is expected to be the output of the
    retriever tool node: its ``content`` is used as the LLM context and its
    ``artifact`` (if any) holds the retrieved sources, which are rendered to HTML
    and appended to the message list before the answer is generated.

    Parameters
    ----------
    state : AgentState
        The current state; ``messages`` and ``user_input`` are read
    llm : LLM
        Language model used to write the answer

    Returns
    -------
    AgentState
        The updated state with the agent response appended to messages
    """
    logger.info("---GENERATE ANSWER---")
    messages = state["messages"]
    question = state["user_input"]
    last_message = messages[-1]
    sources_str = last_message.content
    # Raw retrieval artifact; cannot be used directly as a list of Documents.
    # getattr keeps this from raising if the last message is not a tool message.
    sources_list = getattr(last_message, "artifact", None)

    if sources_list:
        logger.info("---ADD SOURCES---")
        # Only render the HTML source list when there is something to render.
        sources_html = html_format_docs_chat(sources_list)
        # NOTE(review): this mutates the incoming state in place rather than
        # returning the message as an update; kept as-is because downstream
        # display code may depend on it — confirm against AgentState's reducer.
        state["messages"].append(BaseMessage(content=sources_html, type="HTML"))

    # Prompt
    qa_system_prompt = """
    You are an assistant for question-answering tasks in the social and philanthropic sector. \n
    Use the following pieces of retrieved context to answer the question at the end. \n
    If you don't know the answer, just say that you don't know. \n
    Keep the response professional, friendly, and as concise as possible. \n
    Question: {question}
    Context: {context}
    Answer:
    """
    qa_prompt = ChatPromptTemplate([
        ("system", qa_system_prompt),
        # Use a placeholder rather than the raw question: message tuples are parsed
        # as f-string templates, so user text containing "{" or "}" would otherwise
        # be interpreted as template variables and break (or alter) the prompt.
        ("human", "{question}"),
    ])
    rag_chain = qa_prompt | llm | StrOutputParser()
    response = rag_chain.invoke({"context": sources_str, "question": question})
    return {"messages": [AIMessage(content=response)], "user_input": question}
def add_recommendations_pipeline_(
    G: StateGraph,
    reformulation_node_name: str = "reformulate",
    search_node_name: str = "search_agent"
) -> None:
    """Wire the recommendation-engine sub-graph into ``G``; the graph is modified in place.

    After query reformulation, an intent-detection node runs. Intents ``"rfp"`` and
    ``"funder"`` are routed through context determination and recommendation
    generation, after which execution ends; any other intent continues to the
    regular search node.

    Parameters
    ----------
    G : StateGraph
        Execution graph to extend
    reformulation_node_name : str, optional
        Name of the node which reformulates input queries, by default "reformulate"
    search_node_name : str, optional
        Name of the node which executes document search + retrieval, by default "search_agent"
    """
    # Recommendation nodes, registered in flow order.
    recommendation_nodes = (
        ("detect_intent_with_llm", detect_intent_with_llm),
        ("determine_context", determine_context),
        ("make_recommendation", make_recommendation),
    )
    for node_name, node_action in recommendation_nodes:
        G.add_node(node_name, node_action)

    # Intent detection runs right after reformulation; recommendation requests
    # branch off here and run until END, everything else falls through to search.
    G.add_edge(reformulation_node_name, "detect_intent_with_llm")
    G.add_conditional_edges(
        source="detect_intent_with_llm",
        path=lambda state: search_node_name if state["intent"] not in ("rfp", "funder") else "determine_context",
        path_map={
            "determine_context": "determine_context",
            search_node_name: search_node_name,
        },
    )
    G.add_edge("determine_context", "make_recommendation")
    G.add_edge("make_recommendation", END)
def build_compute_graph(
    llm: LLM,
    indices: List[str],
    enable_recommendations: bool = False
) -> StateGraph:
    """Build the (uncompiled) execution graph for one interaction with the assistant.

    Flow: reformulate -> [optional recommendation branch] -> search_agent, which either
    calls the retriever tool (retrieve -> generate_with_context -> has_org_name) or goes
    straight to has_org_name; has_org_name then routes to insert_org_link or END.

    Parameters
    ----------
    llm : LLM
        Language model shared by the reformulation, search, generation and
        org-name detection nodes
    indices : List[str]
        Semantic index names to search over
    enable_recommendations : bool, optional
        Set to `True` to allow the flow to generate recommendations based on context, by default False

    Returns
    -------
    StateGraph
        Execution graph
    """
    candid_retriever_tool = retriever_tool(indices=indices)
    tools = [candid_retriever_tool]

    # Node name -> action, in registration order.
    node_actions = {
        "reformulate": partial(reformulate_question_using_history, llm=llm),
        "search_agent": partial(search_agent, llm=llm, tools=tools),
        "retrieve": ToolNode([candid_retriever_tool]),
        "generate_with_context": partial(generate_with_context, llm=llm),
        "has_org_name": partial(has_org_name, llm=llm),
        "insert_org_link": insert_org_link,
    }

    graph = StateGraph(AgentState)
    for node_name, action in node_actions.items():
        graph.add_node(node_name, action)

    # Between reformulation and search: either the recommendation sub-graph or a direct edge.
    if not enable_recommendations:
        graph.add_edge("reformulate", "search_agent")
    else:
        add_recommendations_pipeline_(graph, reformulation_node_name="reformulate", search_node_name="search_agent")

    graph.add_edge(START, "reformulate")

    # The search agent either requests the retriever tool or finishes without it.
    graph.add_conditional_edges(
        source="search_agent",
        path=tools_condition,
        path_map={
            "tools": "retrieve",
            END: "has_org_name",
        },
    )
    graph.add_edge("retrieve", "generate_with_context")
    graph.add_edge("generate_with_context", "has_org_name")

    # has_org_name writes its routing decision under the state's "next" key.
    graph.add_conditional_edges(
        source="has_org_name",
        path=lambda x: x["next"],
        path_map={
            "insert_org_link": "insert_org_link",
            END: END,
        },
    )
    graph.add_edge("insert_org_link", END)
    return graph