Skip to content

Commit

Permalink
Replace helm-summarize flag --schema-file with --schema-path (#2520)
Browse files Browse the repository at this point in the history
  • Loading branch information
yifanmai committed Apr 3, 2024
1 parent 8a091bf commit ee43575
Show file tree
Hide file tree
Showing 5 changed files with 27 additions and 20 deletions.
4 changes: 4 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,10 @@ nav
tags
notes.otl

# HELM config files
run_entries*.conf
schema*.yaml

# For Macs
.DS_Store

Expand Down
9 changes: 6 additions & 3 deletions src/helm/benchmark/presentation/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,10 +218,13 @@ def __post_init__(self):
self.name_to_run_group = {run_group.name: run_group for run_group in self.run_groups}


def read_schema(filename: str) -> Schema:
def get_default_schema_path() -> str:
return resources.files(SCHEMA_YAML_PACKAGE).joinpath(SCHEMA_CLASSIC_YAML_FILENAME)


def read_schema(schema_path: str) -> Schema:
# TODO: merge in model metadata from `model_metadata.yaml`
schema_path = resources.files(SCHEMA_YAML_PACKAGE).joinpath(filename)
hlog(f"Reading schema file {schema_path}...")
with schema_path.open("r") as f:
with open(schema_path, "r") as f:
raw = yaml.safe_load(f)
return dacite.from_dict(Schema, raw)
24 changes: 12 additions & 12 deletions src/helm/benchmark/presentation/summarize.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@

from tqdm import tqdm
from helm.benchmark.model_deployment_registry import get_model_deployment

from helm.benchmark.model_metadata_registry import get_unknown_model_metadata
from helm.common.general import (
write,
Expand All @@ -49,7 +48,7 @@
RunGroup,
Field,
read_schema,
SCHEMA_CLASSIC_YAML_FILENAME,
get_default_schema_path,
BY_GROUP,
THIS_GROUP_ONLY,
NO_GROUPS,
Expand Down Expand Up @@ -295,7 +294,7 @@ def __init__(
release: Optional[str],
suites: Optional[List[str]],
suite: Optional[str],
schema_file: str,
schema_path: str,
output_path: str,
verbose: bool,
num_threads: int,
Expand All @@ -315,7 +314,7 @@ def __init__(
self.suites: List[str]
self.run_suite_paths: List[str]
self.suite: Optional[str] = None
self.schema_file = schema_file
self.schema_path = schema_path
self.release: Optional[str] = None
if suite:
self.suite = suite
Expand All @@ -333,7 +332,7 @@ def __init__(

ensure_directory_exists(self.run_release_path)

self.schema = read_schema(schema_file)
self.schema = read_schema(schema_path)

def read_run(self, run_path: str) -> Run:
"""Load the `Run` object from `run_path`."""
Expand Down Expand Up @@ -361,7 +360,7 @@ def filter_runs_by_visibility(self, runs: List[Run], group: RunGroup) -> List[Ru
if run_group_name not in self.schema.name_to_run_group:
hlog(
f"WARNING: group {run_group_name} mentioned in run spec {run.run_spec.name} "
f"but undefined in {self.schema_file}, skipping"
f"but undefined in {self.schema_path}, skipping"
)
continue
run_group = self.schema.name_to_run_group[run_group_name]
Expand Down Expand Up @@ -629,7 +628,7 @@ def check_metrics_defined(self):
for metric_name, run_spec_names in metric_name_to_run_spec_names.items():
if metric_name not in defined_metric_names:
hlog(
f"WARNING: metric name {metric_name} undefined in {self.schema_file} "
f"WARNING: metric name {metric_name} undefined in {self.schema_path} "
f"but appears in {len(run_spec_names)} run specs, including {run_spec_names[0]}"
)

Expand Down Expand Up @@ -906,7 +905,7 @@ def create_group_table(
matcher = replace(matcher, sub_split=sub_split)
header_field = self.schema.name_to_metric.get(matcher.name)
if header_field is None:
hlog(f"WARNING: metric name {matcher.name} undefined in {self.schema_file}, skipping")
hlog(f"WARNING: metric name {matcher.name} undefined in {self.schema_path}, skipping")
continue
metadata = {
"metric": header_field.get_short_display_name(),
Expand Down Expand Up @@ -1368,10 +1367,9 @@ def main():
"-o", "--output-path", type=str, help="Where the benchmarking output lives", default="benchmark_output"
)
parser.add_argument(
"--schema-file",
"--schema-path",
type=str,
help="File name of the schema to read (e.g., schema_classic.yaml).",
default=SCHEMA_CLASSIC_YAML_FILENAME,
help="Path to the schema file (e.g., schema_classic.yaml).",
)
parser.add_argument(
"--suite",
Expand Down Expand Up @@ -1438,6 +1436,8 @@ def main():
else:
raise ValueError("Exactly one of --release or --suite must be specified.")

schema_path = args.schema_path if args.schema_path else get_default_schema_path()

register_builtin_configs_from_helm_package()
register_configs_from_directory(args.local_path)

Expand All @@ -1446,7 +1446,7 @@ def main():
release=release,
suites=suites,
suite=suite,
schema_file=args.schema_file,
schema_path=schema_path,
output_path=args.output_path,
verbose=args.debug,
num_threads=args.num_threads,
Expand Down
4 changes: 2 additions & 2 deletions src/helm/benchmark/presentation/test_contamination.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from helm.benchmark.presentation.schema import read_schema, SCHEMA_CLASSIC_YAML_FILENAME
from helm.benchmark.presentation.schema import read_schema, get_default_schema_path
from helm.benchmark.presentation.contamination import read_contamination, validate_contamination


def test_contamination_schema():
schema = read_schema(SCHEMA_CLASSIC_YAML_FILENAME)
schema = read_schema(get_default_schema_path())
contamination = read_contamination()
validate_contamination(contamination, schema)

Expand Down
6 changes: 3 additions & 3 deletions src/helm/benchmark/presentation/test_summarize.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import tempfile

from helm.benchmark.presentation.summarize import Summarizer
from helm.benchmark.presentation.schema import SCHEMA_CLASSIC_YAML_FILENAME
from helm.benchmark.presentation.schema import get_default_schema_path
from helm.common.general import ensure_directory_exists


Expand All @@ -13,7 +13,7 @@ def test_summarize_suite():
release=None,
suites=None,
suite="test_suite",
schema_file=SCHEMA_CLASSIC_YAML_FILENAME,
schema_path=get_default_schema_path(),
output_path=output_path,
verbose=False,
num_threads=4,
Expand All @@ -31,7 +31,7 @@ def test_summarize_release():
release="test_release",
suites=["test_suite_1", "test_suite_2"],
suite=None,
schema_file=SCHEMA_CLASSIC_YAML_FILENAME,
schema_path=get_default_schema_path(),
output_path=output_path,
verbose=False,
num_threads=4,
Expand Down

0 comments on commit ee43575

Please sign in to comment.