Add experimental call center summarization metrics (#2961)
yifanmai committed Aug 30, 2024
1 parent eb3313b commit bc9f007
Showing 5 changed files with 383 additions and 0 deletions.
160 changes: 160 additions & 0 deletions src/helm/benchmark/annotation/call_center_annotator.py
@@ -85,3 +85,163 @@ def annotate(self, request_state: RequestState) -> Any:
if expected_key not in annotator_response_parsed:
raise Exception(f"Malformed annotator response: {annotator_response_text}")
return annotator_response_parsed


class CallCenterSummarizationPairwiseComparisonAnnotator(Annotator):
"""Annotator for call center summarization with pairwise comparison."""

name = "call_center_summarization_pairwise_comparison"

PROMPT_TEMPLATE = """\
Given a call transcript and two different summaries of the call transcript, select the summary you prefer according to the criteria below; this preference can be subjective. Also provide a one-sentence reasoning for your selection.
### Criteria
Faithfulness: Can all the information expressed by the summary be inferred from the source?
Relevance: To what extent does the summary include only important information from the source?
Coherence: Does the summary organize the relevant information into a well-structured summary?
### Call Transcript
{{CALL_TRANSCRIPT}}
### Summary A
{{SUMMARY_A}}
### Summary B
{{SUMMARY_B}}
### Task
Output only a JSON object with the following format:
{"reasoning": "Reasoning", "selected": "A" or "B"}
""" # noqa: E501

def __init__(self, auto_client: AutoClient):
super().__init__()
self._auto_client = auto_client

def annotate(self, request_state: RequestState) -> Any:
assert request_state.result
assert len(request_state.result.completions) == 1
call_transcript = request_state.instance.input.text
summary = request_state.result.completions[0].text.strip()
assert len(request_state.instance.all_correct_references) == 1
reference_summary = request_state.instance.all_correct_references[0].output.text
        if not summary:
            hlog("Returning 0 score due to empty response")
            return {"reasoning": "Empty response", "score": 0.0}
annotator_prompt = (
textwrap.dedent(CallCenterSummarizationPairwiseComparisonAnnotator.PROMPT_TEMPLATE)
.replace("{{CALL_TRANSCRIPT}}", call_transcript)
.replace("{{SUMMARY_B}}", reference_summary)
.replace("{{SUMMARY_A}}", summary)
)
print(annotator_prompt)
annotator_request = Request(
model="openai/gpt-4o-2024-08-06",
model_deployment="openai/gpt-4o-2024-08-06",
prompt=annotator_prompt,
temperature=0.0,
max_tokens=256,
)
annotator_response = self._auto_client.make_request(annotator_request)
if not annotator_response.success:
raise Exception(f"Annotation request failed: {annotator_response.error}")
assert len(annotator_response.completions) == 1
annotator_response_text = annotator_response.completions[0].text
# OpenAI models like to surround JSON objects with ```json and ``` Markdown formatting.
# This strips everything outside the outermost {} brackets.
json_start_index = annotator_response_text.find("{")
json_end_index = annotator_response_text.rfind("}")
if json_start_index < 0 or json_end_index < 0:
raise Exception(f"Malformed annotator response: {annotator_response_text}")
annotator_response_json = annotator_response_text[json_start_index : json_end_index + 1]
try:
annotator_response_parsed = json.loads(annotator_response_json)
except JSONDecodeError:
raise Exception(f"Malformed annotator response: {annotator_response_text}")
for expected_key in ["reasoning", "selected"]:
if expected_key not in annotator_response_parsed:
raise Exception(f"Malformed annotator response: {annotator_response_text}")
score = 0.0
print(annotator_response_parsed)
selected = annotator_response_parsed["selected"].strip()
if selected == "B":
score = 0.0
elif selected == "A":
score = 1.0
else:
raise Exception(f"Malformed annotator response: {annotator_response_text}")
return {"reasoning": annotator_response_parsed["reasoning"], "score": score}


class CallCenterSummarizationKeyPointsRecallAnnotator(Annotator):
"""Annotator for call center summarization with key point recall."""

name = "call_center_summarization_key_points_recall"

PROMPT_TEMPLATE = """\
Given a call transcript, a list of key points and a summary, determine which key points are included in the summary.
### Call Transcript
{{CALL_TRANSCRIPT}}
### Key Points
{{KEY_POINTS}}
### Summary
{{SUMMARY}}
### Task
Output only a JSON array of booleans, where each boolean indicates if the corresponding key point was included in the summary.
""" # noqa: E501

def __init__(self, auto_client: AutoClient):
super().__init__()
self._auto_client = auto_client

def annotate(self, request_state: RequestState) -> Any:
assert request_state.result
assert len(request_state.result.completions) == 1
call_transcript = request_state.instance.input.text
summary = request_state.result.completions[0].text.strip()
key_points = "\n".join(
[f"- {reference.output.text}" for reference in request_state.instance.all_correct_references]
)
        if not summary:
            hlog("Returning 0 score due to empty response")
            return {"key_points_found": "[]", "score": 0.0}
annotator_prompt = (
textwrap.dedent(CallCenterSummarizationKeyPointsRecallAnnotator.PROMPT_TEMPLATE)
.replace("{{CALL_TRANSCRIPT}}", call_transcript)
.replace("{{KEY_POINTS}}", key_points)
.replace("{{SUMMARY}}", summary)
)
print(annotator_prompt)
annotator_request = Request(
model="openai/gpt-4o-2024-08-06",
model_deployment="openai/gpt-4o-2024-08-06",
prompt=annotator_prompt,
temperature=0.0,
max_tokens=256,
)
annotator_response = self._auto_client.make_request(annotator_request)
if not annotator_response.success:
raise Exception(f"Annotation request failed: {annotator_response.error}")
assert len(annotator_response.completions) == 1
annotator_response_text = annotator_response.completions[0].text
# OpenAI models like to surround JSON objects with ```json and ``` Markdown formatting.
# This strips everything outside the outermost [] brackets.
json_start_index = annotator_response_text.find("[")
json_end_index = annotator_response_text.rfind("]")
if json_start_index < 0 or json_end_index < 0:
raise Exception(f"Malformed annotator response: {annotator_response_text}")
annotator_response_json = annotator_response_text[json_start_index : json_end_index + 1]
try:
annotator_response_parsed = json.loads(annotator_response_json)
except JSONDecodeError:
raise Exception(f"Malformed annotator response: {annotator_response_text}")
if not len(annotator_response_parsed):
raise Exception(f"Malformed annotator response: {annotator_response_text}")
score = sum([1.0 if elem else 0.0 for elem in annotator_response_parsed]) / len(annotator_response_parsed)
print(annotator_response_parsed)
return {"key_points_found": json.dumps(annotator_response_parsed), "score": score}
22 changes: 22 additions & 0 deletions src/helm/benchmark/metrics/annotation_metrics.py
@@ -49,6 +49,28 @@ def evaluate_generation(
return stats


class AnnotationNumericMetric(Metric):
"""Numeric metric for numbers produced by annotators.
Expects the annotation with the given annotator name and key to be a number."""

def __init__(self, annotator_name: str, key: str):
super().__init__()
self.annotator_name = annotator_name
self.key = key

def evaluate_generation(
self,
adapter_spec: AdapterSpec,
request_state: RequestState,
metric_service: MetricService,
eval_cache_path: str,
) -> List[Stat]:
assert request_state.annotations
score = request_state.annotations[self.annotator_name][self.key]
return [Stat(MetricName(f"annotation_{self.annotator_name}_{self.key}")).add(score)]


class AnnotationLikertScaleMetric(Metric):
"""Numeric metric for labels produced by annotators.
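
For context, AnnotationNumericMetric only looks up one numeric field in the annotation dictionary that the annotator attached to the request state and reports it under a derived metric name. A small illustration with assumed (not committed) values:

annotations = {
    "call_center_summarization_pairwise_comparison": {
        "reasoning": "Summary A also covers the refund request that Summary B omits.",
        "score": 1.0,
    }
}
annotator_name, key = "call_center_summarization_pairwise_comparison", "score"
value = annotations[annotator_name][key]
metric_name = f"annotation_{annotator_name}_{key}"
# -> Stat "annotation_call_center_summarization_pairwise_comparison_score" with value 1.0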
88 changes: 88 additions & 0 deletions src/helm/benchmark/run_specs/call_center_run_specs.py
@@ -62,3 +62,91 @@ def get_call_center_summarization_spec(subset: str = "summarization") -> RunSpec
annotators=annotator_specs,
groups=[group],
)


@run_spec_function("call_center_summarization_pairwise_comparison")
def get_call_center_summarization_pairwise_comparison_spec() -> RunSpec:
scenario_spec = ScenarioSpec(
class_name="helm.benchmark.scenarios.call_center_scenario.CallCenterSummarizationPairwiseComparisonScenario",
)

instructions = "Summarize the call transcript in under 10 sentences."

adapter_spec = AdapterSpec(
method=ADAPT_GENERATION,
instructions=instructions,
input_prefix="### Call Transcript\n",
input_suffix="",
output_prefix="",
output_suffix="",
max_train_instances=0,
temperature=0.0,
max_tokens=512,
num_outputs=1,
)

    annotator_specs = [
AnnotatorSpec(
class_name="helm.benchmark.annotation.call_center_annotator.CallCenterSummarizationPairwiseComparisonAnnotator" # noqa: E501
)
]

metric_specs = get_basic_metric_specs([]) + [
MetricSpec(
class_name="helm.benchmark.metrics.annotation_metrics.AnnotationNumericMetric",
args={"annotator_name": "call_center_summarization_pairwise_comparison", "key": "score"},
)
]

return RunSpec(
name="call_center_summarization_pairwise_comparison",
scenario_spec=scenario_spec,
adapter_spec=adapter_spec,
metric_specs=metric_specs,
annotators=annotator_specs,
groups=["call_center_summarization_pairwise_comparison"],
)


@run_spec_function("call_center_summarization_key_points_recall")
def get_call_center_summarization_key_points_recall_spec() -> RunSpec:
scenario_spec = ScenarioSpec(
class_name="helm.benchmark.scenarios.call_center_scenario.CallCenterSummarizationKeyPointsRecallScenario",
)

instructions = "Summarize the call transcript in under 10 sentences."

adapter_spec = AdapterSpec(
method=ADAPT_GENERATION,
instructions=instructions,
input_prefix="### Call Transcript\n",
input_suffix="",
output_prefix="",
output_suffix="",
max_train_instances=0,
temperature=0.0,
max_tokens=512,
num_outputs=1,
)

    annotator_specs = [
AnnotatorSpec(
class_name="helm.benchmark.annotation.call_center_annotator.CallCenterSummarizationKeyPointsRecallAnnotator"
)
]

metric_specs = get_basic_metric_specs([]) + [
MetricSpec(
class_name="helm.benchmark.metrics.annotation_metrics.AnnotationNumericMetric",
args={"annotator_name": "call_center_summarization_key_points_recall", "key": "score"},
)
]

return RunSpec(
name="call_center_summarization_key_points_recall",
scenario_spec=scenario_spec,
adapter_spec=adapter_spec,
metric_specs=metric_specs,
annotators=annotator_specs,
groups=["call_center_summarization_key_points_recall"],
)
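
Both run specs use the same zero-shot generation adapter, so the prompt each evaluated model sees is, roughly, the instructions followed by the prefixed transcript. A simplified sketch with a hypothetical transcript, not the adapter's exact rendering:

instructions = "Summarize the call transcript in under 10 sentences."
input_prefix = "### Call Transcript\n"
transcript = "Agent: Thank you for calling, how can I help you today? ..."  # hypothetical

prompt = f"{instructions}\n{input_prefix}{transcript}"
# The completion (temperature 0.0, at most 512 tokens) is then judged by the matching
# annotator, and AnnotationNumericMetric reports the judge's "score" field.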
50 changes: 50 additions & 0 deletions src/helm/benchmark/scenarios/call_center_scenario.py
@@ -3,6 +3,9 @@
from typing import List

from helm.benchmark.scenarios.scenario import (
CORRECT_TAG,
Output,
Reference,
Scenario,
Instance,
TEST_SPLIT,
@@ -32,3 +35,50 @@ def get_instances(self, output_path: str) -> List[Instance]:
instance = Instance(input=input, references=[], split=TEST_SPLIT)
instances.append(instance)
return instances


class CallCenterSummarizationPairwiseComparisonScenario(Scenario):
"""Call center summarization."""

name = "call_center_summarization_pairwise_comparison"
description = "Call center summarization."
tags = ["call_center"]

def get_instances(self, output_path: str) -> List[Instance]:
cache_dir = os.path.join(output_path, "data")
ensure_directory_exists(cache_dir)
dataset = datasets.load_dataset(
"yifanmai/call-center", "summarization_with_annotations", split="test", cache_dir=cache_dir
)
instances: List[Instance] = []
for row in dataset:
input = Input(text=row["transcript"])
reference = Reference(output=Output(text=row["gpt-4o-mini-2024-07-18_summary"]), tags=[CORRECT_TAG])
instance = Instance(input=input, references=[reference], split=TEST_SPLIT)
instances.append(instance)
return instances


class CallCenterSummarizationKeyPointsRecallScenario(Scenario):
"""Call center summarization."""

name = "call_center_summarization_key_points_recall"
description = "Call center summarization."
tags = ["call_center"]

def get_instances(self, output_path: str) -> List[Instance]:
cache_dir = os.path.join(output_path, "data")
ensure_directory_exists(cache_dir)
dataset = datasets.load_dataset(
"yifanmai/call-center", "summarization_with_annotations", split="test", cache_dir=cache_dir
)
instances: List[Instance] = []
for row in dataset:
input = Input(text=row["transcript"])
references = [
Reference(output=Output(text=key_point), tags=[CORRECT_TAG])
for key_point in row["gpt-4o-mini-2024-07-18_key_points"]
]
instance = Instance(input=input, references=references, split=TEST_SPLIT)
instances.append(instance)
return instances
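
To make the mapping concrete, here is how a single row of the "summarization_with_annotations" subset (field names taken from the code above, values hypothetical) becomes instances in the two new scenarios:

from helm.benchmark.scenarios.scenario import CORRECT_TAG, Input, Instance, Output, Reference, TEST_SPLIT

row = {
    "transcript": "Agent: Thanks for calling, how can I help? ...",
    "gpt-4o-mini-2024-07-18_summary": "The customer called about a delayed refund ...",
    "gpt-4o-mini-2024-07-18_key_points": [
        "Customer asked for a refund status update.",
        "Agent confirmed the refund was issued.",
    ],
}

# Pairwise comparison: one CORRECT reference holding the baseline summary (used as "Summary B").
pairwise_instance = Instance(
    input=Input(text=row["transcript"]),
    references=[Reference(output=Output(text=row["gpt-4o-mini-2024-07-18_summary"]), tags=[CORRECT_TAG])],
    split=TEST_SPLIT,
)

# Key points recall: one CORRECT reference per key point; the judge reports which are covered.
key_points_instance = Instance(
    input=Input(text=row["transcript"]),
    references=[
        Reference(output=Output(text=key_point), tags=[CORRECT_TAG])
        for key_point in row["gpt-4o-mini-2024-07-18_key_points"]
    ],
    split=TEST_SPLIT,
)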