# Config list for running locally using ollama.
# A single OpenAI-compatible endpoint served by a local Ollama instance.
# NOTE(review): "ollama" looks like a placeholder api_key, and price [0, 0]
# presumably marks the local model as free — confirm against the autogen docs.
config_list = [
    dict(
        model="llama3.2",
        base_url="http://localhost:11434/v1",
        api_key="ollama",
        price=[0, 0],
    )
]
class AnswerIdx(BaseModel):
    """Structured-output schema: the single option letter chosen for a question."""

    # Restricting the field to a Literal makes any other value fail validation.
    idx: Literal["A", "B", "C", "D", "E"]

    def format(self) -> str:
        """Return the chosen option letter as plain text.

        NOTE(review): presumably invoked by autogen to render the parsed
        structured response as message text — confirm against the installed
        autogen/ag2 version.
        """
        return self.idx


# Same local Ollama endpoint as config_list, plus a structured-output schema.
# AnswerIdx must be defined *before* this literal is evaluated; referencing it
# earlier raises NameError at import time.
config_list_output_format = [
    {
        "model": "llama3.2",
        "base_url": "http://localhost:11434/v1",
        "api_key": "ollama",
        "price": [0, 0],
        "response_format": AnswerIdx,
    }
]
class QA_Agent(ABC):
    """Interface for agents that answer a multiple-choice question.

    Implementations return the list of ``ag.ChatResult`` objects produced
    while answering; the evaluation loop reads the last result's ``summary``.
    """

    @abstractmethod
    def answer_question(self, question: str) -> List[ag.ChatResult]:
        """Answer *question* and return the chat results that produced the answer."""
# Trouble extracting the multiple-choice answer from the discussion
# (1 message · page 1 of 1, latest)
class SimpleAgent(QA_Agent):
    """Two-step QA pipeline: an expert answers, then a reporter restates the letter.

    ``user_proxy`` first sends the question to ``medical_qa_agent`` for one
    turn, then asks ``answer_reporter`` (configured with
    ``config_list_output_format``, i.e. the ``AnswerIdx`` response format) to
    report the option that was chosen.
    """

    # Shared by the reporter's system message and the second chat's message.
    _REPORT_PROMPT = (
        "Report the answer to the multiple choice question "
        "(e.g. A, B, C, D, or E) that was decided on previously."
    )

    def __init__(self):
        # Drives both chats; never asks a human and never executes code.
        self.user_proxy = ag.UserProxyAgent(
            "user_proxy",
            human_input_mode="NEVER",
            code_execution_config=False,
        )

        expert_prompt = (
            "You are an expert medical QA agent. Give a detailed answer to the "
            "multiple choice medical question. Please clearly state which answer "
            "index (e.g. A, B, C, D, or E) the correct answer corresponds to."
        )
        self.medical_qa_agent = ag.AssistantAgent(
            "medical_qa_agent",
            llm_config={"config_list": config_list},
            system_message=expert_prompt,
        )

        self.answer_reporter = ag.AssistantAgent(
            "answer_reporter",
            llm_config={"config_list": config_list_output_format},
            system_message=self._REPORT_PROMPT,
        )

    def answer_question(self, question: str) -> List[ag.ChatResult]:
        """Run the expert chat, then the reporter chat, and return both results.

        NOTE(review): the reporter only sees the expert's answer if
        ``initiate_chats`` carries earlier chat summaries over to later chats —
        confirm against the installed autogen version.
        """
        expert_chat = dict(
            recipient=self.medical_qa_agent,
            message=question,
            max_turns=1,
            silent=True,
        )
        report_chat = dict(
            recipient=self.answer_reporter,
            message=self._REPORT_PROMPT,
            max_turns=1,
            silent=True,
        )
        return self.user_proxy.initiate_chats([expert_chat, report_chat])
def _extract_answer_idx(summary) -> Optional[str]:
    """Return the first standalone option letter (A-E) in *summary*, or ``None``.

    Handles a bare letter (``"B"``), a JSON payload such as ``'{"idx": "B"}'``
    and free text such as ``"The answer is B"``. A letter counts only when it
    is not part of a longer word, so ``"Based"`` does not yield ``"B"``.
    Caveat: free text starting with the article "A" (e.g. "A patient ...")
    would yield ``"A"``. This is acceptable because the reporter is
    configured for structured output and normally returns just the letter.
    """
    import re

    if not isinstance(summary, str):
        return None
    match = re.search(r"(?<![A-Za-z])([A-E])(?![A-Za-z])", summary)
    return match.group(1) if match else None


# Run the agent on the dataset
def test_agent(agent: QA_Agent, data: pd.DataFrame, max_samples: Optional[int] = None) -> pd.DataFrame:
    """Run *agent* over *data* and print how often it picks the right option.

    Args:
        agent: the QA agent under evaluation.
        data: rows with ``question``, ``options`` (a mapping of option letter
            to option text, iterated with ``.items()``) and ``answer_idx``.
        max_samples: if not ``None``, only the first ``max_samples`` rows are
            evaluated (``0`` evaluates none).

    Returns:
        One row per evaluated question, indexed like *data*, with columns
        ``question``, ``options``, ``answer_idx``, ``chat_history`` and
        ``predicted_answer`` (``None`` when no option letter was found).
        A Ctrl-C stops the run early and returns the rows finished so far.
    """
    def create_question(row: pd.Series) -> str:
        question = row["question"]
        options = "\n".join(f"{k}: {v}" for k, v in row["options"].items())
        return f"{question}\n\nYour options are...\n{options}"

    columns = ["question", "options", "answer_idx", "chat_history", "predicted_answer"]
    # `is not None` so that max_samples=0 means "no rows", not "all rows".
    if max_samples is not None:
        data = data.head(max_samples)

    # Collect plain records and build the frame once instead of growing it
    # row by row with .loc (which reallocates on every enlargement).
    records = []
    indices = []
    for index, row in tqdm(data.iterrows(), total=len(data)):
        try:
            chat_results = agent.answer_question(create_question(row))
            record = {
                "question": row["question"],
                "options": str(row["options"]),
                "answer_idx": row["answer_idx"],
                "chat_history": str([cr.chat_history for cr in chat_results]),
                # summary[0] took the first *character* ("{" for JSON, "T" for
                # "The answer is B") and crashed on an empty summary.
                "predicted_answer": _extract_answer_idx(chat_results[-1].summary),
            }
        except KeyboardInterrupt:
            print(f"Stopped at index {index}")
            break
        indices.append(index)
        records.append(record)

    results = pd.DataFrame(records, index=indices, columns=columns)

    total = len(results)
    correct = int((results["predicted_answer"] == results["answer_idx"]).sum())
    print("Correct answers: ", correct)
    print("Incorrect answers: ", total - correct)
    # Guard against an empty run (empty data, max_samples=0, or Ctrl-C on the
    # first row), which previously raised ZeroDivisionError.
    accuracy = correct / total if total else float("nan")
    print("Accuracy: ", accuracy)
    return results
# Evaluate the two-step SimpleAgent on every row of the test set.
results = test_agent(SimpleAgent(), data=test_data)
# This simple agent is only a test run before I move on to more complex group-chat setups.