MultiMedQA medical evaluation suite #2524

Merged: 8 commits, Apr 2, 2024
Changes from 6 commits
25 changes: 25 additions & 0 deletions src/helm/benchmark/presentation/run_specs_multimed_qa.conf
@@ -0,0 +1,25 @@
entries: [
# MedQA: USMLE QA multiple choice questions with 4-5 choices
{description: "med_qa:model=text,max_train_instances=0", priority: 1}

# MedMCQA: AIIMS/NEET QA multiple choice questions with 4 choices
{description: "med_mcqa:model=text,max_train_instances=0", priority: 1}

# PubMedQA: biomedical literature Q + Context + A yes/no/maybe + long answer questions
{description: "pubmed_qa:model=text,max_train_instances=0", priority: 1}

# MMLU: exam questions QA multiple choice with 4 choices
# NOTE: this is the subset used in the MultiMedQA dataset in the MedPaLM works
{description: "mmlu:model=text,subject=anatomy,max_train_instances=0", priority: 1}
{description: "mmlu:model=text,subject=clinical_knowledge,max_train_instances=0", priority: 1}
{description: "mmlu:model=text,subject=college_medicine,max_train_instances=0", priority: 1}
{description: "mmlu:model=text,subject=medical_genetics,max_train_instances=0", priority: 1}
{description: "mmlu:model=text,subject=professional_medicine,max_train_instances=0", priority: 1}
{description: "mmlu:model=text,subject=college_biology,max_train_instances=0", priority: 1}

# LiveQA: consumer health questions with librarian-generated reference answers. QA long answer
{description: "live_qa:model=text", priority: 1}

# MedicationQA: consumer medication questions with reference answers. QA long answer
{description: "medication_qa:model=text", priority: 1}
]
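Each `description` above uses HELM's run-entry format, `<run spec name>:<arg>=<value>,...`: the name selects the corresponding `@run_spec_function` and the key-value pairs become its arguments, with `model=text` expanded over the models under evaluation. A rough sketch of how such a description splits into name and arguments (illustration only; HELM's own parser additionally handles quoting, model expansion, and validation):

```python
def split_description(description: str):
    """Split a run-entry description into its spec name and key=value arguments.

    Minimal sketch of the format only, not HELM's actual parser.
    """
    name, _, arg_string = description.partition(":")
    args = dict(pair.split("=", 1) for pair in arg_string.split(",")) if arg_string else {}
    return name, args


print(split_description("med_qa:model=text,max_train_instances=0"))
# ('med_qa', {'model': 'text', 'max_train_instances': '0'})
```

Once this conf file is in place, the whole suite can be launched through the usual HELM entry point (e.g. `helm-run --conf-paths run_specs_multimed_qa.conf --suite <suite>`); the exact flags depend on the HELM version in use.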
51 changes: 49 additions & 2 deletions src/helm/benchmark/run_specs/classic_run_specs.py
@@ -46,6 +46,7 @@
from helm.benchmark.runner import get_benchmark_output_path
from helm.benchmark.scenarios.scenario import ScenarioSpec, get_scenario_cache_path
from helm.common.hierarchical_logger import hlog, htrack
from helm.common.gpu_utils import get_torch_device_name


@run_spec_function("bbq")
@@ -1116,7 +1117,7 @@ def get_med_mcqa_spec() -> RunSpec:
scenario_spec=scenario_spec,
adapter_spec=adapter_spec,
metric_specs=get_exact_match_metric_specs(),
groups=["MedMCQA"],
groups=["MedMCQA", "MultiMedQA"],
)


@@ -1158,7 +1159,53 @@ def get_pubmed_qa_spec() -> RunSpec:
scenario_spec=scenario_spec,
adapter_spec=adapter_spec,
metric_specs=get_exact_match_metric_specs(),
groups=["pubmed_qa"],
groups=["pubmed_qa", "MultiMedQA"],
)


@run_spec_function("live_qa")
def get_liveqa_spec() -> RunSpec:
scenario_spec = ScenarioSpec(class_name="helm.benchmark.scenarios.liveqa_scenario.LiveQAScenario")

adapter_spec = get_generation_adapter_spec(
instructions="Please answer the following consumer health question.",
input_noun="Question",
output_noun="Answer",
max_train_instances=0,
max_tokens=512,
)

return RunSpec(
name="live_qa",
scenario_spec=scenario_spec,
adapter_spec=adapter_spec,
metric_specs=get_summarization_metric_specs(
{"task": "live_qa", "device": get_torch_device_name()},
),
groups=["LiveQA", "MultiMedQA"],
)


@run_spec_function("medication_qa")
def get_medicationqa_spec() -> RunSpec:
scenario_spec = ScenarioSpec(class_name="helm.benchmark.scenarios.medication_qa_scenario.MedicationQAScenario")

adapter_spec = get_generation_adapter_spec(
instructions="Please answer the following consumer health question.",
input_noun="Question",
output_noun="Answer",
max_train_instances=0,
max_tokens=512,
)

return RunSpec(
name="medication_qa",
scenario_spec=scenario_spec,
adapter_spec=adapter_spec,
metric_specs=get_summarization_metric_specs(
{"task": "medication_qa", "device": get_torch_device_name()},
),
groups=["MedicationQA", "MultiMedQA"],
)


4 changes: 2 additions & 2 deletions src/helm/benchmark/run_specs/lite_run_specs.py
@@ -130,7 +130,7 @@ def get_mmlu_spec(subject: str, method: str = ADAPT_MULTIPLE_CHOICE_JOINT) -> RunSpec:
scenario_spec=scenario_spec,
adapter_spec=adapter_spec,
metric_specs=get_exact_match_metric_specs(),
groups=["mmlu"],
groups=["mmlu", "MultiMedQA"],
)


@@ -271,7 +271,7 @@ def get_med_qa_spec() -> RunSpec:
scenario_spec=scenario_spec,
adapter_spec=adapter_spec,
metric_specs=get_exact_match_metric_specs(),
groups=["med_qa"],
groups=["med_qa", "MultiMedQA"],
)


94 changes: 94 additions & 0 deletions src/helm/benchmark/scenarios/live_qa_scenario.py
@@ -0,0 +1,94 @@
import os
from typing import List
from xml.etree.ElementTree import Element
import xml.etree.ElementTree as ET

from helm.common.general import ensure_file_downloaded
from .scenario import CORRECT_TAG, TEST_SPLIT, Input, Instance, Output, Reference, Scenario


class LiveQAScenario(Scenario):
"""
TREC-2017 LiveQA: Medical Question Answering Task

The LiveQA'17 medical task focuses on consumer health question answering.
Please refer to the original paper for more information about the constructed datasets and the LiveQA Track:
https://trec.nist.gov/pubs/trec26/papers/Overview-QA.pdf

Paper citation:

@inproceedings{LiveMedQA2017,
author = {Asma {Ben Abacha} and Eugene Agichtein and Yuval Pinter and Dina Demner{-}Fushman},
title = {Overview of the Medical Question Answering Task at TREC 2017 LiveQA},
booktitle = {TREC 2017},
year = {2017}
}
"""

SOURCE_REPO_URL = "https://github.com/abachaa/LiveQA_MedicalTask_TREC2017/raw/master/TestDataset/"
FILENAME = "TREC-2017-LiveQA-Medical-Test-Questions-w-summaries.xml"

name = "live_qa"
description = "TREC-2017 LiveQA: Medical Question Answering Task"
tags = ["knowledge", "generation", "question_answering", "biomedical"]

def download_liveqa(self, path: str):
"""Download the XML file containing the questions & reference answers"""
ensure_file_downloaded(
source_url=os.path.join(self.SOURCE_REPO_URL, self.FILENAME),
target_path=os.path.join(path, self.FILENAME),
unpack=False,
)

@staticmethod
def remove_whitespace(s: str) -> str:
"""Just remove all whitespace from a string"""
return " ".join(s.strip().split())

@staticmethod
def _extract_question_id(element: Element):
return element.attrib["qid"]

@classmethod
def _extract_question(cls, element: Element) -> str:
"""Given an XML Element representing a question, extract just the question as text"""
return cls.remove_whitespace(element.find("NLM-Summary").text) # type: ignore

@classmethod
def _extract_answers(cls, element: Element) -> List[str]:
"""Given an XML Element representing a question, extract the reference answers"""
answers = []
for answer in element.iter("ANSWER"):
answers.append(cls.remove_whitespace(answer.text)) # type: ignore

return answers

def process_xml(self, base_path: str) -> List[Instance]:
"""Parse the XMLs into question-answer(s) pairs"""
xml_path = os.path.join(base_path, self.FILENAME)
tree = ET.parse(xml_path)
root = tree.getroot()

instances = []
for question_root in root:
# get the actual question and question ID
id = self._extract_question_id(question_root)
question = Input(self._extract_question(question_root))

# parse out the reference answers
answers = self._extract_answers(question_root)
references = [Reference(Output(answer), tags=[CORRECT_TAG]) for answer in answers]

# stitch it all together
instances.append(Instance(question, references, split=TEST_SPLIT, id=id))

return instances

def get_instances(self, output_path: str) -> List[Instance]:
"""entrypoint to creating this scenario's instances"""
# get the dataset
self.download_liveqa(output_path)

# get the instances by parsing the XML
instances = self.process_xml(output_path)
return instances
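For reference, a minimal sketch of the XML shape this parser assumes: question elements carrying a `qid` attribute, an `NLM-Summary` child, and one or more nested `ANSWER` elements. The sample document below is hypothetical (tag nesting and content are illustrative, not copied from the TREC file) and only exercises the extraction helpers:

```python
import xml.etree.ElementTree as ET

from helm.benchmark.scenarios.live_qa_scenario import LiveQAScenario

# Hypothetical document mirroring the structure the parser relies on.
sample = """
<NLM-QUESTIONS>
  <NLM-QUESTION qid="Q1">
    <NLM-Summary>What are the side effects of ibuprofen?</NLM-Summary>
    <ReferenceAnswers>
      <ANSWER>Common side effects include stomach upset and heartburn.</ANSWER>
    </ReferenceAnswers>
  </NLM-QUESTION>
</NLM-QUESTIONS>
"""

question_element = ET.fromstring(sample)[0]
print(LiveQAScenario._extract_question_id(question_element))  # Q1
print(LiveQAScenario._extract_question(question_element))     # the summary text, whitespace-collapsed
print(LiveQAScenario._extract_answers(question_element))      # list with one reference answer
```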
60 changes: 60 additions & 0 deletions src/helm/benchmark/scenarios/medication_qa_scenario.py
@@ -0,0 +1,60 @@
import os
from typing import List

import pandas as pd

from helm.common.general import ensure_file_downloaded

from .scenario import CORRECT_TAG, TEST_SPLIT, Input, Instance, Output, Reference, Scenario


class MedicationQAScenario(Scenario):
"""
The gold standard corpus for medication question answering introduced in the MedInfo 2019 paper
"Bridging the Gap between Consumers’ Medication Questions and Trusted Answers":
http://ebooks.iospress.nl/publication/51941

This dataset has consumer questions, as opposed to very clinical questions.

Paper citation:

@inproceedings{BenAbacha:MEDINFO19,
author = {Asma {Ben Abacha} and Yassine Mrabet and Mark Sharp and
Travis Goodwin and Sonya E. Shooshan and Dina Demner{-}Fushman},
title = {Bridging the Gap between Consumers’ Medication Questions and Trusted Answers},
booktitle = {MEDINFO 2019},
year = {2019},
}
"""

SOURCE_REPO_URL = "https://github.com/abachaa/Medication_QA_MedInfo2019/raw/master/"
FILENAME = "MedInfo2019-QA-Medications.xlsx"

name = "medication_qa"
description = "MedInfo 2019 MedicationQA medication question answering task"
tags = ["knowledge", "generation", "question_answering", "biomedical"]

def download_medication_qa(self, path: str):
"""download the .xlsx spreadsheet containing the question-answer pairs"""
ensure_file_downloaded(
source_url=os.path.join(self.SOURCE_REPO_URL, self.FILENAME),
target_path=os.path.join(path, self.FILENAME),
unpack=False,
)

def get_instances(self, output_path: str) -> List[Instance]:
self.download_medication_qa(output_path)
data_path = os.path.join(output_path, self.FILENAME)

data = pd.read_excel(data_path)
data = data[~data.Answer.isna()] # remove rows missing answers
instances = [
Instance(
input=Input(row.Question),
references=[Reference(Output(row.Answer), tags=[CORRECT_TAG])],
split=TEST_SPLIT,
)
for _, row in data.iterrows()
]

return instances
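A possible local smoke test for this scenario (not part of the PR): download the spreadsheet into a scratch directory and check that rows parse into instances. The output path is arbitrary, and `pandas.read_excel` is assumed to have an xlsx engine such as `openpyxl` available:

```python
from helm.benchmark.scenarios.medication_qa_scenario import MedicationQAScenario

scenario = MedicationQAScenario()
# Downloads MedInfo2019-QA-Medications.xlsx, drops rows without an answer, and
# builds one test-split Instance per remaining (Question, Answer) row.
instances = scenario.get_instances("/tmp/medication_qa")  # hypothetical scratch path
print(len(instances))
print(instances[0].input.text)
print(instances[0].references[0].output.text)
```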