# Trouble extracting a multiple-choice answer from an agent discussion

1 message · Page 1 of 1 (latest)

forest jacinth
#
# Config list for running locally using ollama.
# Single entry pointing at the local Ollama server's /v1 endpoint; the api_key
# is a placeholder value, and price is zeroed because local inference is free
# (presumably autogen's [prompt, completion] price pair — verify for your version).
config_list = [
    dict(
        model="llama3.2",
        base_url="http://localhost:11434/v1",
        api_key="ollama",
        price=[0, 0],
    )
]

class AnswerIdx(BaseModel):
    """Structured-output schema for the answer reporter: one option letter A-E."""

    idx: Literal["A", "B", "C", "D", "E"]

    def format(self) -> str:
        """Render the parsed answer as its bare letter, e.g. ``"B"``.

        NOTE(review): autogen/ag2 is expected to call ``format()`` (when defined on
        a ``response_format`` model) to turn the structured response into message
        text — confirm against the installed autogen version.
        """
        return self.idx


# Same local Ollama endpoint(s) as ``config_list``, but constrained to return an
# ``AnswerIdx``. ``AnswerIdx`` must be defined before this list is built;
# referencing it earlier raises NameError when the script runs top to bottom.
config_list_output_format = [
    {**entry, "response_format": AnswerIdx} for entry in config_list
]

class QA_Agent(ABC):
    """Interface for agents that answer a multiple-choice question via chats."""

    @abstractmethod
    def answer_question(self, question: str) -> List[ag.ChatResult]:
        """Answer ``question`` and return the chat results produced along the way.

        Subclasses must override this. Callers (e.g. ``test_agent``) read the
        last element's ``summary`` as the final answer.
        """
#
class SimpleAgent(QA_Agent):
    """Two-step QA pipeline: an expert answers, then a reporter restates the letter.

    A non-interactive ``user_proxy`` drives both assistants through one-turn,
    silent chats (see ``answer_question``).
    """

    def __init__(self):
        # Driver agent: never prompts a human and has code execution disabled.
        self.user_proxy = ag.UserProxyAgent(
            "user_proxy",
            human_input_mode="NEVER",
            code_execution_config=False,
        )

        expert_prompt = (
            "You are an expert medical QA agent. Give a detailed answer to the "
            "multiple choice medical question. Please clearly state which answer "
            "index (e.g. A, B, C, D, or E) the correct answer corresponds to."
        )
        self.medical_qa_agent = ag.AssistantAgent(
            "medical_qa_agent",
            llm_config=dict(config_list=config_list),
            system_message=expert_prompt,
        )

        # The reporter uses the structured-output config so its reply is
        # constrained to the AnswerIdx schema.
        reporter_prompt = (
            "Report the answer to the multiple choice question "
            "(e.g. A, B, C, D, or E) that was decided on previously."
        )
        self.answer_reporter = ag.AssistantAgent(
            "answer_reporter",
            llm_config=dict(config_list=config_list_output_format),
            system_message=reporter_prompt,
        )

    def answer_question(self, question: str) -> List[ag.ChatResult]:
        """Send ``question`` to the expert, then ask the reporter for the letter.

        Both chats are limited to a single turn and run silently. The list
        returned by ``initiate_chats`` is passed back unchanged (one result per
        chat, in order). NOTE(review): autogen runs these chats sequentially and
        forwards earlier summaries as carryover — that is how the reporter sees
        the expert's answer; confirm for the installed version.
        """
        single_silent_turn = {"max_turns": 1, "silent": True}

        expert_chat = {
            "recipient": self.medical_qa_agent,
            "message": question,
            **single_silent_turn,
        }
        report_chat = {
            "recipient": self.answer_reporter,
            "message": (
                "Report the answer to the multiple choice question "
                "(e.g. A, B, C, D, or E) that was decided on previously."
            ),
            **single_silent_turn,
        }

        return self.user_proxy.initiate_chats([expert_chat, report_chat])
#
# Run the agent on the dataset
def _extract_answer_idx(summary, choices):
    """Return the option letter chosen in ``summary``, or ``None`` if none is found.

    Handles structured JSON such as ``'{"idx": "B"}'``, a bare letter (``"B"``,
    ``"B."``, ``"**B**"``), and prose such as ``"The answer is B"``.

    Args:
        summary: text of the final chat's summary; may be ``None`` or empty.
        choices: collection of valid option letters, e.g. ``{"A", "B", "C"}``.
    """
    import json
    import re

    if not summary:
        return None
    text = str(summary).strip()

    # Structured output (response_format) may come back as the raw JSON payload.
    try:
        parsed = json.loads(text)
    except ValueError:
        parsed = None
    if isinstance(parsed, dict):
        idx = str(parsed.get("idx", "")).strip()
        if idx in choices:
            return idx

    # Prefer an explicit statement like "answer is B", "Answer: (B)", "answer is option B".
    match = re.search(
        r"(?i:answer)\W*(?:(?i:is)\W*)?(?:(?i:option)\W*)?\(?\b([A-Z])\b", text
    )
    if match and match.group(1) in choices:
        return match.group(1)

    # Fall back to the first standalone capital letter that is a valid option.
    for letter in re.findall(r"\b([A-Z])\b", text):
        if letter in choices:
            return letter
    return None


def test_agent(agent: QA_Agent, data: pd.DataFrame, max_samples: Optional[int] = None) -> pd.DataFrame:
    """Run ``agent`` over ``data`` and print multiple-choice accuracy.

    Args:
        agent: QA agent; the ``summary`` of the last chat result it returns is
            parsed for the predicted option letter.
        data: rows with ``question``, ``options`` (mapping letter -> text) and
            ``answer_idx`` columns.
        max_samples: if truthy, only the first ``max_samples`` rows are used.

    Returns:
        DataFrame indexed like ``data`` with columns ``question``, ``options``,
        ``answer_idx``, ``chat_history`` and ``predicted_answer``.
        ``predicted_answer`` is ``None`` when no option letter could be extracted.

    Pressing Ctrl-C stops the run early; rows completed so far are kept.
    """
    def create_question(row: pd.Series) -> str:
        question = row["question"]
        options = "\n".join([f"{k}: {v}" for k, v in row["options"].items()])
        return f"{question}\n\nYour options are...\n{options}"

    if max_samples:
        data = data.head(max_samples)

    columns = ["question", "options", "answer_idx", "chat_history", "predicted_answer"]
    # index label -> finished row; a single assignment per row, so an interrupt
    # never leaves a half-written row behind.
    records = {}

    for index, row in tqdm(data.iterrows(), total=len(data)):
        try:
            chat_results = agent.answer_question(create_question(row))
            summary = chat_results[-1].summary if chat_results else None
            choices = {str(key) for key in row["options"].keys()}
            records[index] = {
                "question": row["question"],
                "options": str(row["options"]),
                "answer_idx": row["answer_idx"],
                "chat_history": str([cr.chat_history for cr in chat_results]),
                # The summary is free text or JSON, not a bare letter, so taking
                # its first character (summary[0]) is unreliable; parse it instead.
                "predicted_answer": _extract_answer_idx(summary, choices),
            }
        except KeyboardInterrupt:
            print(f"Stopped at index {index}")
            break

    results = pd.DataFrame(list(records.values()), index=list(records.keys()), columns=columns)

    n_correct = int((results["predicted_answer"] == results["answer_idx"]).sum())
    print("Correct answers: ", n_correct)
    print("Incorrect answers: ", len(results) - n_correct)

    # Avoid ZeroDivisionError when nothing was evaluated (empty data or Ctrl-C on row one).
    accuracy = n_correct / len(results) if len(results) else float("nan")
    print("Accuracy: ", accuracy)

    return results


# Evaluate a fresh SimpleAgent on the full test set (test_data is defined elsewhere).
results = test_agent(agent=SimpleAgent(), data=test_data)
#

This simple agent is just a test setup before I move on to more complex group-chat workflows.