Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replace helm-summarize flag --schema-file with --schema-path #2520

Merged
merged 1 commit into from
Apr 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,10 @@ nav
tags
notes.otl

# HELM config files
run_entries*.conf
schema*.yaml

# For Macs
.DS_Store

Expand Down
9 changes: 6 additions & 3 deletions src/helm/benchmark/presentation/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,10 +215,13 @@ def __post_init__(self):
self.name_to_run_group = {run_group.name: run_group for run_group in self.run_groups}


def read_schema(filename: str) -> Schema:
def get_default_schema_path() -> str:
return resources.files(SCHEMA_YAML_PACKAGE).joinpath(SCHEMA_CLASSIC_YAML_FILENAME)


def read_schema(schema_path: str) -> Schema:
# TODO: merge in model metadata from `model_metadata.yaml`
schema_path = resources.files(SCHEMA_YAML_PACKAGE).joinpath(filename)
hlog(f"Reading schema file {schema_path}...")
with schema_path.open("r") as f:
with open(schema_path, "r") as f:
raw = yaml.safe_load(f)
return dacite.from_dict(Schema, raw)
24 changes: 12 additions & 12 deletions src/helm/benchmark/presentation/summarize.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@

from tqdm import tqdm
from helm.benchmark.model_deployment_registry import get_model_deployment

from helm.benchmark.model_metadata_registry import get_unknown_model_metadata
from helm.common.general import (
write,
Expand All @@ -49,7 +48,7 @@
RunGroup,
Field,
read_schema,
SCHEMA_CLASSIC_YAML_FILENAME,
get_default_schema_path,
BY_GROUP,
THIS_GROUP_ONLY,
NO_GROUPS,
Expand Down Expand Up @@ -295,7 +294,7 @@ def __init__(
release: Optional[str],
suites: Optional[List[str]],
suite: Optional[str],
schema_file: str,
schema_path: str,
output_path: str,
verbose: bool,
num_threads: int,
Expand All @@ -315,7 +314,7 @@ def __init__(
self.suites: List[str]
self.run_suite_paths: List[str]
self.suite: Optional[str] = None
self.schema_file = schema_file
self.schema_path = schema_path
self.release: Optional[str] = None
if suite:
self.suite = suite
Expand All @@ -333,7 +332,7 @@ def __init__(

ensure_directory_exists(self.run_release_path)

self.schema = read_schema(schema_file)
self.schema = read_schema(schema_path)

def read_run(self, run_path: str) -> Run:
"""Load the `Run` object from `run_path`."""
Expand Down Expand Up @@ -361,7 +360,7 @@ def filter_runs_by_visibility(self, runs: List[Run], group: RunGroup) -> List[Ru
if run_group_name not in self.schema.name_to_run_group:
hlog(
f"WARNING: group {run_group_name} mentioned in run spec {run.run_spec.name} "
f"but undefined in {self.schema_file}, skipping"
f"but undefined in {self.schema_path}, skipping"
)
continue
run_group = self.schema.name_to_run_group[run_group_name]
Expand Down Expand Up @@ -629,7 +628,7 @@ def check_metrics_defined(self):
for metric_name, run_spec_names in metric_name_to_run_spec_names.items():
if metric_name not in defined_metric_names:
hlog(
f"WARNING: metric name {metric_name} undefined in {self.schema_file} "
f"WARNING: metric name {metric_name} undefined in {self.schema_path} "
f"but appears in {len(run_spec_names)} run specs, including {run_spec_names[0]}"
)

Expand Down Expand Up @@ -906,7 +905,7 @@ def create_group_table(
matcher = replace(matcher, sub_split=sub_split)
header_field = self.schema.name_to_metric.get(matcher.name)
if header_field is None:
hlog(f"WARNING: metric name {matcher.name} undefined in {self.schema_file}, skipping")
hlog(f"WARNING: metric name {matcher.name} undefined in {self.schema_path}, skipping")
continue
metadata = {
"metric": header_field.get_short_display_name(),
Expand Down Expand Up @@ -1368,10 +1367,9 @@ def main():
"-o", "--output-path", type=str, help="Where the benchmarking output lives", default="benchmark_output"
)
parser.add_argument(
"--schema-file",
"--schema-path",
type=str,
help="File name of the schema to read (e.g., schema_classic.yaml).",
default=SCHEMA_CLASSIC_YAML_FILENAME,
help="Path to the schema file (e.g., schema_classic.yaml).",
)
parser.add_argument(
"--suite",
Expand Down Expand Up @@ -1438,6 +1436,8 @@ def main():
else:
raise ValueError("Exactly one of --release or --suite must be specified.")

schema_path = args.schema_path if args.schema_path else get_default_schema_path()

register_builtin_configs_from_helm_package()
register_configs_from_directory(args.local_path)

Expand All @@ -1446,7 +1446,7 @@ def main():
release=release,
suites=suites,
suite=suite,
schema_file=args.schema_file,
schema_path=schema_path,
output_path=args.output_path,
verbose=args.debug,
num_threads=args.num_threads,
Expand Down
4 changes: 2 additions & 2 deletions src/helm/benchmark/presentation/test_contamination.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from helm.benchmark.presentation.schema import read_schema, SCHEMA_CLASSIC_YAML_FILENAME
from helm.benchmark.presentation.schema import read_schema, get_default_schema_path
from helm.benchmark.presentation.contamination import read_contamination, validate_contamination


def test_contamination_schema():
schema = read_schema(SCHEMA_CLASSIC_YAML_FILENAME)
schema = read_schema(get_default_schema_path())
contamination = read_contamination()
validate_contamination(contamination, schema)

Expand Down
6 changes: 3 additions & 3 deletions src/helm/benchmark/presentation/test_summarize.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import tempfile

from helm.benchmark.presentation.summarize import Summarizer
from helm.benchmark.presentation.schema import SCHEMA_CLASSIC_YAML_FILENAME
from helm.benchmark.presentation.schema import get_default_schema_path
from helm.common.general import ensure_directory_exists


Expand All @@ -13,7 +13,7 @@ def test_summarize_suite():
release=None,
suites=None,
suite="test_suite",
schema_file=SCHEMA_CLASSIC_YAML_FILENAME,
schema_path=get_default_schema_path(),
output_path=output_path,
verbose=False,
num_threads=4,
Expand All @@ -31,7 +31,7 @@ def test_summarize_release():
release="test_release",
suites=["test_suite_1", "test_suite_2"],
suite=None,
schema_file=SCHEMA_CLASSIC_YAML_FILENAME,
schema_path=get_default_schema_path(),
output_path=output_path,
verbose=False,
num_threads=4,
Expand Down