MartinHummel committed on
Commit
7f63349
·
1 Parent(s): e327cd6

agent prompt

Browse files
Files changed (2)
  1. agent/gaia_agent.py +34 -4
  2. local_benchmark.py +26 -0
agent/gaia_agent.py CHANGED
@@ -15,7 +15,7 @@ from langchain_community.chat_models import ChatOpenAI
15
 
16
  def create_langchain_agent() -> AgentExecutor:
17
  llm = ChatOpenAI(
18
- model_name="gpt-4o", # Or "gpt-3.5-turbo"
19
  temperature=0.1,
20
  openai_api_key=os.getenv("OPENAI_API_KEY"),
21
  )
@@ -29,14 +29,44 @@ def create_langchain_agent() -> AgentExecutor:
29
  Tool(name="vegetable_classifier_2022", func=vegetable_classifier_2022, description="Classify and extract only vegetables, excluding botanical fruits, based on a comma-separated list of food items."),
30
  Tool(name="excel_food_sales_sum", func=excel_food_sales_sum, description="Parse uploaded Excel file and return total food-related sales."),
31
  ]
32
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
  agent = initialize_agent(
34
  tools=tools,
35
  llm=llm,
36
- agent=AgentType.OPENAI_FUNCTIONS,
37
- verbose=True
 
 
 
 
38
  )
39
 
 
 
 
40
  '''
41
  agent = initialize_agent(
42
  tools=tools,
 
15
 
16
  def create_langchain_agent() -> AgentExecutor:
17
  llm = ChatOpenAI(
18
+ model_name="gpt-4o",
19
  temperature=0.1,
20
  openai_api_key=os.getenv("OPENAI_API_KEY"),
21
  )
 
29
  Tool(name="vegetable_classifier_2022", func=vegetable_classifier_2022, description="Classify and extract only vegetables, excluding botanical fruits, based on a comma-separated list of food items."),
30
  Tool(name="excel_food_sales_sum", func=excel_food_sales_sum, description="Parse uploaded Excel file and return total food-related sales."),
31
  ]
32
+
33
+ agent_kwargs = {
34
+ "prefix": (
35
+ "You are a helpful AI assistant completing GAIA benchmark tasks.\n"
36
+ "You MUST use the tools provided to answer the user's question. Do not answer from your own knowledge.\n"
37
+ "Carefully analyze the question to determine the most appropriate tool to use.\n"
38
+ "Here are guidelines for using the tools:\n"
39
+ "- Use 'wikipedia_search' to find factual information about topics, events, people, etc. (e.g., 'Use wikipedia_search to find the population of France').\n"
40
+ "- Use 'youtube_transcript' to extract transcripts from YouTube videos when the question requires understanding the video content. (e.g., 'Use youtube_transcript to summarize the key points of this video').\n"
41
+ "- Use 'audio_transcriber' to transcribe uploaded audio files. (e.g., 'Use audio_transcriber to get the text from this audio recording').\n"
42
+ "- Use 'chess_image_solver' to analyze and solve chess puzzles from images. (e.g., 'Use chess_image_solver to determine the best move in this chess position').\n"
43
+ "- Use 'file_parser' to parse and analyze data from Excel or CSV files. (e.g., 'Use file_parser to calculate the average sales from this data').\n"
44
+ "- Use 'vegetable_classifier_2022' to classify a list of food items and extract only the vegetables. (e.g., 'Use vegetable_classifier_2022 to get a list of the vegetables in this grocery list').\n"
45
+ "- Use 'excel_food_sales_sum' to extract total food sales from excel files. (e.g., 'Use excel_food_sales_sum to calculate the total food sales').\n"
46
+ "Do NOT guess or make up answers. If a tool cannot provide the answer, truthfully respond that you were unable to find the information.\n"
47
+ ),
48
+ "suffix": (
49
+ "Use the tools to research or calculate the answer.\n"
50
+ "If a tool fails, explain the reason for the failure instead of hallucinating an answer.\n"
51
+ "Provide concise and direct answers as requested in the questions. Do not add extra information unless explicitly asked for.\n"
52
+ "For example, if asked for a number, return only the number. If asked for a list, return only the list.\n"
53
+ ),
54
+ }
55
+
56
  agent = initialize_agent(
57
  tools=tools,
58
  llm=llm,
59
+ agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
60
+ handle_parsing_errors=True,
61
+ verbose=True,
62
+ max_iterations=10,
63
+ max_execution_time=60,
64
+ agent_kwargs=agent_kwargs # Place the agent_kwargs here
65
  )
66
 
67
+ return agent
68
+
69
+
70
  '''
71
  agent = initialize_agent(
72
  tools=tools,
local_benchmark.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from agent.gaia_agent import create_langchain_agent
2
+
3
+ questions = [
4
+ "How many studio albums were published by Mercedes Sosa between 2000 and 2009 (included)?",
5
+ "In the video https://www.youtube.com/watch?v=u1xXCYZ4VYM, what is the highest number of bird species to be on screen at once?",
6
+ "Reverse the string 'etisoppo eht'.",
7
+ "What country had the least number of athletes at the 1928 Summer Olympics? Return the IOC country code.",
8
+ "From the chessboard image at path 'chess_1.png', what is the best move?",
9
+ "The attached Excel file contains food and drink sales. What are the total sales for food (excluding drinks)?",
10
+ "Give me a comma-separated, alphabetized list of botanical vegetables from this: milk, eggs, flour, plums, lettuce, celery, broccoli, bell pepper, zucchini.",
11
+ "Where were the Vietnamese specimens described by Kuznetzov in Nedoshivina’s 2010 paper eventually deposited? (City name only.)",
12
+ "What is the name of the novel where a Martian child grows up on Earth and founds a church?",
13
+ "Summarize the Wikipedia page on 'Taishō Tamai'."
14
+ ]
15
+
16
+ agent = create_langchain_agent()
17
+
18
+ print("Running local benchmark...")
19
+
20
+ for idx, question in enumerate(questions):
21
+ print(f"\nQUESTION {idx + 1}: {question}")
22
+ try:
23
+ answer = agent.invoke({"input": question})
24
+ print("ANSWER:", answer)
25
+ except Exception as e:
26
+ print("❌ Error:", e)