Add MultipanelVQA and POPE vision-language scenarios (#2517)
ImKeTT committed Mar 31, 2024
1 parent e74c014 commit b29fb5e
Showing 4 changed files with 357 additions and 0 deletions.
48 changes: 48 additions & 0 deletions src/helm/benchmark/run_specs/vlm_run_specs.py
@@ -368,6 +368,54 @@ def get_bingo_spec(subject: str) -> RunSpec:
)


@run_spec_function("multipanelvqa")
def get_multipanelvqa_spec(subject: str, question_type: str) -> RunSpec:
scenario_spec = ScenarioSpec(
class_name="helm.benchmark.scenarios.vision_language.multipanelvqa_scenario.MultipanelVQAScenario",
args={"subject": subject, "question_type": question_type},
)

adapter_spec: AdapterSpec
if question_type == "open":
adapter_spec = get_short_answer_generation_adapter_spec()
elif question_type == "multiple-choice":
adapter_spec = get_multiple_choice_joint_adapter_spec(
input_noun=None, output_noun="Answer", max_train_instances=0
)
else:
raise ValueError(f"Invalid question type: {question_type}")

metric_specs: List[MetricSpec] = get_exact_match_metric_specs()
run_spec_name: str = "multipanelvqa"
return RunSpec(
name=f"{run_spec_name}:subject={subject},question_type={question_type}",
scenario_spec=scenario_spec,
adapter_spec=adapter_spec,
metric_specs=metric_specs,
groups=[run_spec_name],
)


@run_spec_function("pope")
def get_pope_spec() -> RunSpec:
scenario_spec = ScenarioSpec(
class_name="helm.benchmark.scenarios.vision_language.pope_scenario.POPEScenario",
)
adapter_spec: AdapterSpec = get_multiple_choice_joint_adapter_spec(
input_noun=None, output_noun="Answer", max_train_instances=0
)
metric_specs: List[MetricSpec] = get_exact_match_metric_specs()

run_spec_name: str = "pope"
return RunSpec(
name=run_spec_name,
scenario_spec=scenario_spec,
adapter_spec=adapter_spec,
metric_specs=metric_specs,
groups=[run_spec_name],
)


@run_spec_function("heim_human_eval")
def get_heim_human_eval_spec(question_type: str) -> RunSpec:
scenario_spec = ScenarioSpec(
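For reference, a minimal sketch of how the two new run spec functions can be exercised directly. This snippet is illustrative and not part of the commit; it assumes the @run_spec_function decorator returns the wrapped function unchanged, so the functions can also be called by hand to inspect the RunSpecs they build:

from helm.benchmark.run_specs.vlm_run_specs import get_multipanelvqa_spec, get_pope_spec

# Build the RunSpec for the synthetic, multiple-choice variant of MultipanelVQA
multipanelvqa_spec = get_multipanelvqa_spec(subject="synthetic", question_type="multiple-choice")
print(multipanelvqa_spec.name)    # multipanelvqa:subject=synthetic,question_type=multiple-choice
print(multipanelvqa_spec.groups)  # ['multipanelvqa']

# POPE has a single configuration
pope_spec = get_pope_spec()
print(pope_spec.name)    # pope
print(pope_spec.groups)  # ['pope']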
169 changes: 169 additions & 0 deletions src/helm/benchmark/scenarios/vision_language/multipanelvqa_scenario.py
@@ -0,0 +1,169 @@
import os.path
from typing import Dict, List

from datasets import load_dataset
from tqdm import tqdm

from helm.benchmark.scenarios.scenario import (
CORRECT_TAG,
TEST_SPLIT,
Instance,
Input,
Output,
Reference,
Scenario,
)
from helm.common.media_object import MediaObject, MultimediaObject
from helm.common.general import ensure_directory_exists


class MultipanelVQAScenario(Scenario):
"""
Muffin or Chihuahua? Challenging Large Vision-Language Models with Multipanel VQA
We introduce Multipanel Visual Question Answering (MultipanelVQA), a novel benchmark
comprising 6,600 triplets of questions, answers, and multipanel images that specifically
challenge models in comprehending multipanel images. Our evaluation shows that questions in
the MultipanelVQA benchmark pose significant challenges to the state-of-the-art Large Vision
Language Models (LVLMs) tested, even though humans can attain approximately 99% accuracy on
these questions. There are two types of questions in two different situations in the
MultipanelVQA benchmark: multiple-choice or open-ended generation paired with real-world or
synthetic images. We use the multiple-choice metrics and the exact match metric for two
different question-answering types, respectively.
@article{fan2024muffin,
title={Muffin or Chihuahua? Challenging Large Vision-Language Models with Multipanel VQA},
author={Fan, Yue and Gu, Jing and Zhou, Kaiwen and Yan, Qianqi and Jiang, Shan and
Kuo, Ching-Chen and Guan, Xinze and Wang, Xin Eric},
journal={arXiv preprint arXiv:2401.15847},
year={2024}
}
Paper: https://arxiv.org/abs/2401.15847
"""

MULTIPANELVQA_HUGGINGFACE_DATASET_NAME: Dict[str, str] = {
"synthetic": "yfan1997/MultipanelVQA_synthetic",
"real-world": "yfan1997/MultipanelVQA_real-world",
}

SUBJECTS: List[str] = ["synthetic", "real-world"]

name = "multipanelvqa"
description = "Evaluate multimodal models on ([paper](https://arxiv.org/abs/2401.15847))."
tags = ["vision-language"]

def __init__(self, subject: str, question_type: str):
super().__init__()
assert subject in self.SUBJECTS, f"Invalid subject: {subject}"
self._subject: str = subject

assert question_type in ["multiple-choice", "open"], f"Invalid question type: {question_type}"
self._question_type: str = question_type

    def convert_text_answer_to_option(self, text_answer: str, question: str) -> str:
        option_answer: str
        # Short answers are already option letters, possibly with a ')' appended (e.g. "B)")
        if len(text_answer) <= 3:
            option_answer = text_answer[0]
        else:
            # Some examples give the full answer text instead of an option letter,
            # so find the option line that contains it and take its leading letter
            for line in question.split("\n"):
                if text_answer in line:
                    option_answer = line[0]
                    break
        return option_answer.upper()

    def split_options_and_question(self, original_question: str):
        # Options are appended to the question one per line (e.g. "a) ...", "b) ..."),
        # with "(please select one)" either on its own line or appended to the last option
        question_and_options: List[str] = [item.strip().lower() for item in original_question.split("\n")]
        last_append_phrase: str = "(please select one)"
        question: str = question_and_options[0]
        options: List[str] = []
        if len(question_and_options) >= 6:
            # Stop collecting options once the "(please select one)" line is reached
            for item in question_and_options[1:]:
                if last_append_phrase in item:
                    break
                options.append(item[3:])  # drop the leading option letter, e.g. "a) "
        elif len(question_and_options) == 5:
            # Here the phrase is appended to the last option, so strip it before keeping the option
            for item in question_and_options[1:]:
                if last_append_phrase in item:
                    item = item[: -len(last_append_phrase)]
                options.append(item[3:])  # drop the leading option letter, e.g. "a) "
        return question, options

def get_instances(self, output_path: str) -> List[Instance]:
images_path: str = os.path.join(output_path, "images")
ensure_directory_exists(images_path)

        # There is only a test split in the MultipanelVQA benchmark
        instances: List[Instance] = []
        # Process the test set: each row yields either two open-ended generation instances
        # or one multiple-choice instance, depending on the question type
for image_index, row in enumerate(
tqdm(
load_dataset(
self.MULTIPANELVQA_HUGGINGFACE_DATASET_NAME[self._subject],
split=TEST_SPLIT,
cache_dir=output_path,
)
)
):
            # Save the image locally if it has not been saved yet
image_path: str = os.path.join(images_path, f"{image_index}.png")
if not os.path.exists(image_path):
row["image"].save(image_path)

# Add the references
references: List[Reference] = []
question: str
answer: str
content: List[MediaObject]
if self._question_type == "open":
question_1: str = row["question_1"]
question_2: str = row["question_2"]
answer_1: str = row["answer_1"]
answer_2: str = row["answer_2"]
for answer, question in zip([answer_1, answer_2], [question_1, question_2]):
content = [
MediaObject(location=image_path, content_type="image/png"),
MediaObject(text=question, content_type="text/plain"),
]
instances.append(
Instance(
Input(multimedia_content=MultimediaObject(content)),
references=[Reference(Output(text=answer), tags=[CORRECT_TAG])],
split=TEST_SPLIT,
)
)
else:
options: List[str]
original_question: str = row["question_3"]
question, options = self.split_options_and_question(original_question)
answer = row["answer_3"].strip()
answer = self.convert_text_answer_to_option(answer, original_question)
# The given correct answer is a letter, but we need an index
correct_answer_index: int = ord(answer) - ord("A")
# The options are originally appended to the question

for i, option in enumerate(options):
reference: Reference
is_correct: bool = i == correct_answer_index
reference = Reference(Output(text=option), tags=[CORRECT_TAG] if is_correct else [])
references.append(reference)

content = [
MediaObject(location=image_path, content_type="image/png"),
MediaObject(text=question, content_type="text/plain"),
]
instances.append(
Instance(
Input(multimedia_content=MultimediaObject(content)),
references=references,
split=TEST_SPLIT,
)
)

return instances
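As an illustration of the parsing above, here is a small sketch (not part of the commit). The question text is hypothetical and follows the format the code assumes, i.e. options prefixed with "a) ", "b) ", and so on, with "(please select one)" on the last line:

scenario = MultipanelVQAScenario(subject="synthetic", question_type="multiple-choice")

original_question = (
    "Which panel shows a dog?\n"
    "a) top left\n"
    "b) top right\n"
    "c) bottom left\n"
    "d) bottom right (please select one)"
)

question, options = scenario.split_options_and_question(original_question)
# question -> "which panel shows a dog?" (lowercased)
# options  -> ["top left", "top right", "bottom left", "bottom right "]
#             (the "(please select one)" suffix is stripped, leaving a trailing space)

# A short answer such as "b)" is already an option letter...
assert scenario.convert_text_answer_to_option("b)", original_question) == "B"
# ...while a full-text answer is matched against the option lines to recover its letter
assert scenario.convert_text_answer_to_option("top right", original_question) == "B"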
104 changes: 104 additions & 0 deletions src/helm/benchmark/scenarios/vision_language/pope_scenario.py
@@ -0,0 +1,104 @@
from typing import List
import os

from helm.benchmark.scenarios.scenario import (
CORRECT_TAG,
TEST_SPLIT,
Instance,
Input,
Output,
Reference,
Scenario,
)
from datasets import load_dataset
from tqdm import tqdm
from helm.common.media_object import MediaObject, MultimediaObject
from helm.common.general import ensure_directory_exists


class POPEScenario(Scenario):
"""
POPE dataset
Despite the promising progress on Large Vision-Language Models (LVLMs), we find that LVLMs suffer from
the hallucination problem, i.e. they tend to generate objects that are inconsistent with the target
images in the descriptions. To investigate it, this work presents the first systematic study on object
hallucination of LVLMs based on VQAv2 benchmark. We find that: objects that frequently occur in the
visual instructions or co-occur with the image objects, are obviously prone to be hallucinated by LVLMs.
In POPE, images from VQAv2 are matched with questions asking the appearance of certain objects in the
image. We use the exact match metric for model evaluation on POPE.
@inproceedings{li2023evaluating,
title={Evaluating Object Hallucination in Large Vision-Language Models},
author={Li, Yifan and Du, Yifan and Zhou, Kun and Wang, Jinpeng and Zhao, Wayne Xin and Wen, Ji-Rong},
booktitle={Proceedings of the 2023 Conference on Empirical Methods in Natural Language Processing},
pages={292--305},
year={2023}
}
Paper: https://aclanthology.org/2023.emnlp-main.20/
"""

POPE_HUGGINGFACE_DATASET_NAME: str = "lmms-lab/POPE"

name = "pope"
description = (
"Open-ended questions about hallucination images ([paper](https://aclanthology.org/2023.emnlp-main.20/))."
)
tags = ["vision-language", "visual question answering"]
options: List[str] = ["Yes", "No"]

    def get_label_from_answer(self, answer: str) -> str:
        # Map the dataset's "yes"/"no" answer onto the option letters used for the references
        label: str
        if answer == "yes":
            label = "A"
        elif answer == "no":
            label = "B"
        else:
            raise ValueError(f"Invalid answer: {answer}")
        return label

def get_instances(self, output_path: str) -> List[Instance]:
images_path: str = os.path.join(output_path, "images")
ensure_directory_exists(images_path)
instances: List[Instance] = []
for row in tqdm(
load_dataset(
self.POPE_HUGGINGFACE_DATASET_NAME,
split=TEST_SPLIT,
cache_dir=output_path,
)
):
            image_source: str = row["image_source"]
            # Save the image locally if it has not been saved yet
            image_path: str = os.path.join(images_path, f"{image_source}.jpg")
            if not os.path.exists(image_path):
                row["image"].save(image_path)

question: str = row["question"]
answer: str = row["answer"]
references: List[Reference] = []

answer = self.get_label_from_answer(answer)
            # The gold answer is now an option letter ("A" or "B"), but we need an index into the options
            correct_answer_index: int = ord(answer) - ord("A")

for i, option in enumerate(self.options):
reference: Reference
is_correct: bool = i == correct_answer_index
reference = Reference(Output(text=option), tags=[CORRECT_TAG] if is_correct else [])
references.append(reference)

content = [
MediaObject(location=image_path, content_type="image/jpeg"),
MediaObject(text=question, content_type="text/plain"),
]
instances.append(
Instance(
Input(multimedia_content=MultimediaObject(content)),
references=references,
split=TEST_SPLIT,
)
)

return instances
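A minimal sketch (not part of the commit) of how the gold yes/no answer is turned into the binary references built above:

scenario = POPEScenario()

# The dataset's lowercase answers map onto the fixed option letters
assert scenario.get_label_from_answer("yes") == "A"
assert scenario.get_label_from_answer("no") == "B"

# Each instance carries both options as references, with CORRECT_TAG only on the gold one,
# so the multiple-choice joint adapter and the exact match metric can score the model's choice
correct_answer_index = ord(scenario.get_label_from_answer("yes")) - ord("A")
references = [
    Reference(Output(text=option), tags=[CORRECT_TAG] if i == correct_answer_index else [])
    for i, option in enumerate(scenario.options)
]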
36 changes: 36 additions & 0 deletions src/helm/benchmark/static/schema_vlm.yaml
@@ -288,6 +288,8 @@ run_groups:
- image2structure
- unicorn
- bingo
- multipanelvqa
- pope

- name: heim_human_eval
display_name: HEIM Human Eval Scenario
@@ -413,6 +415,40 @@
who: Human experts
when: "2023"
language: English, Chinese, Japanese, etc.

- name: multipanelvqa
display_name: MultipanelVQA
    description: Questions about real-world or synthetic multipanel images for evaluating multipanel image reasoning ability
metric_groups:
- accuracy
- efficiency
- general_information
environment:
main_name: exact_match
main_split: test
taxonomy:
task: short answer or multiple-choice question answering
what: Real-world or synthetic multipanel images
who: Human experts
when: "2024"
language: English

- name: pope
display_name: POPE
    description: Yes/no questions about object presence in real-world images for evaluating hallucination behaviour
metric_groups:
- accuracy
- efficiency
- general_information
environment:
main_name: exact_match
main_split: test
taxonomy:
      task: multiple-choice question answering
what: Real-world images
who: Human experts
when: "2023"
language: English

- name: image2latex
display_name: Image2LaTeX
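One way to sanity-check the schema additions (illustrative only, not part of the commit; it assumes PyYAML is available and that the snippet is run from the repository root) is to confirm that the new run group names match the groups set by the run spec functions:

import yaml

with open("src/helm/benchmark/static/schema_vlm.yaml") as schema_file:
    schema = yaml.safe_load(schema_file)

run_group_names = {run_group["name"] for run_group in schema["run_groups"]}
# Both new run specs place their runs in groups named after the run spec functions
assert {"multipanelvqa", "pope"} <= run_group_names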
