diff --git a/.gitignore b/.gitignore index cb3d5237f8..718285e065 100644 --- a/.gitignore +++ b/.gitignore @@ -32,6 +32,10 @@ nav tags notes.otl +# HELM config files +run_entries*.conf +schema*.yaml + # For Macs .DS_Store diff --git a/src/helm/benchmark/presentation/schema.py b/src/helm/benchmark/presentation/schema.py index ea34d8bc8e..c295833a00 100644 --- a/src/helm/benchmark/presentation/schema.py +++ b/src/helm/benchmark/presentation/schema.py @@ -218,10 +218,13 @@ def __post_init__(self): self.name_to_run_group = {run_group.name: run_group for run_group in self.run_groups} -def read_schema(filename: str) -> Schema: +def get_default_schema_path() -> str: + return resources.files(SCHEMA_YAML_PACKAGE).joinpath(SCHEMA_CLASSIC_YAML_FILENAME) + + +def read_schema(schema_path: str) -> Schema: # TODO: merge in model metadata from `model_metadata.yaml` - schema_path = resources.files(SCHEMA_YAML_PACKAGE).joinpath(filename) hlog(f"Reading schema file {schema_path}...") - with schema_path.open("r") as f: + with open(schema_path, "r") as f: raw = yaml.safe_load(f) return dacite.from_dict(Schema, raw) diff --git a/src/helm/benchmark/presentation/summarize.py b/src/helm/benchmark/presentation/summarize.py index af6dddbe04..0b841c59a3 100644 --- a/src/helm/benchmark/presentation/summarize.py +++ b/src/helm/benchmark/presentation/summarize.py @@ -22,7 +22,6 @@ from tqdm import tqdm from helm.benchmark.model_deployment_registry import get_model_deployment - from helm.benchmark.model_metadata_registry import get_unknown_model_metadata from helm.common.general import ( write, @@ -49,7 +48,7 @@ RunGroup, Field, read_schema, - SCHEMA_CLASSIC_YAML_FILENAME, + get_default_schema_path, BY_GROUP, THIS_GROUP_ONLY, NO_GROUPS, @@ -295,7 +294,7 @@ def __init__( release: Optional[str], suites: Optional[List[str]], suite: Optional[str], - schema_file: str, + schema_path: str, output_path: str, verbose: bool, num_threads: int, @@ -315,7 +314,7 @@ def __init__( self.suites: List[str] self.run_suite_paths: List[str] self.suite: Optional[str] = None - self.schema_file = schema_file + self.schema_path = schema_path self.release: Optional[str] = None if suite: self.suite = suite @@ -333,7 +332,7 @@ def __init__( ensure_directory_exists(self.run_release_path) - self.schema = read_schema(schema_file) + self.schema = read_schema(schema_path) def read_run(self, run_path: str) -> Run: """Load the `Run` object from `run_path`.""" @@ -361,7 +360,7 @@ def filter_runs_by_visibility(self, runs: List[Run], group: RunGroup) -> List[Ru if run_group_name not in self.schema.name_to_run_group: hlog( f"WARNING: group {run_group_name} mentioned in run spec {run.run_spec.name} " - f"but undefined in {self.schema_file}, skipping" + f"but undefined in {self.schema_path}, skipping" ) continue run_group = self.schema.name_to_run_group[run_group_name] @@ -629,7 +628,7 @@ def check_metrics_defined(self): for metric_name, run_spec_names in metric_name_to_run_spec_names.items(): if metric_name not in defined_metric_names: hlog( - f"WARNING: metric name {metric_name} undefined in {self.schema_file} " + f"WARNING: metric name {metric_name} undefined in {self.schema_path} " f"but appears in {len(run_spec_names)} run specs, including {run_spec_names[0]}" ) @@ -906,7 +905,7 @@ def create_group_table( matcher = replace(matcher, sub_split=sub_split) header_field = self.schema.name_to_metric.get(matcher.name) if header_field is None: - hlog(f"WARNING: metric name {matcher.name} undefined in {self.schema_file}, skipping") + hlog(f"WARNING: metric name {matcher.name} undefined in {self.schema_path}, skipping") continue metadata = { "metric": header_field.get_short_display_name(), @@ -1368,10 +1367,9 @@ def main(): "-o", "--output-path", type=str, help="Where the benchmarking output lives", default="benchmark_output" ) parser.add_argument( - "--schema-file", + "--schema-path", type=str, - help="File name of the schema to read (e.g., schema_classic.yaml).", - default=SCHEMA_CLASSIC_YAML_FILENAME, + help="Path to the schema file (e.g., schema_classic.yaml).", ) parser.add_argument( "--suite", @@ -1438,6 +1436,8 @@ def main(): else: raise ValueError("Exactly one of --release or --suite must be specified.") + schema_path = args.schema_path if args.schema_path else get_default_schema_path() + register_builtin_configs_from_helm_package() register_configs_from_directory(args.local_path) @@ -1446,7 +1446,7 @@ def main(): release=release, suites=suites, suite=suite, - schema_file=args.schema_file, + schema_path=schema_path, output_path=args.output_path, verbose=args.debug, num_threads=args.num_threads, diff --git a/src/helm/benchmark/presentation/test_contamination.py b/src/helm/benchmark/presentation/test_contamination.py index f0ac123a48..a1948396de 100644 --- a/src/helm/benchmark/presentation/test_contamination.py +++ b/src/helm/benchmark/presentation/test_contamination.py @@ -1,9 +1,9 @@ -from helm.benchmark.presentation.schema import read_schema, SCHEMA_CLASSIC_YAML_FILENAME +from helm.benchmark.presentation.schema import read_schema, get_default_schema_path from helm.benchmark.presentation.contamination import read_contamination, validate_contamination def test_contamination_schema(): - schema = read_schema(SCHEMA_CLASSIC_YAML_FILENAME) + schema = read_schema(get_default_schema_path()) contamination = read_contamination() validate_contamination(contamination, schema) diff --git a/src/helm/benchmark/presentation/test_summarize.py b/src/helm/benchmark/presentation/test_summarize.py index 35dc7a36d3..65b1bbf654 100644 --- a/src/helm/benchmark/presentation/test_summarize.py +++ b/src/helm/benchmark/presentation/test_summarize.py @@ -2,7 +2,7 @@ import tempfile from helm.benchmark.presentation.summarize import Summarizer -from helm.benchmark.presentation.schema import SCHEMA_CLASSIC_YAML_FILENAME +from helm.benchmark.presentation.schema import get_default_schema_path from helm.common.general import ensure_directory_exists @@ -13,7 +13,7 @@ def test_summarize_suite(): release=None, suites=None, suite="test_suite", - schema_file=SCHEMA_CLASSIC_YAML_FILENAME, + schema_path=get_default_schema_path(), output_path=output_path, verbose=False, num_threads=4, @@ -31,7 +31,7 @@ def test_summarize_release(): release="test_release", suites=["test_suite_1", "test_suite_2"], suite=None, - schema_file=SCHEMA_CLASSIC_YAML_FILENAME, + schema_path=get_default_schema_path(), output_path=output_path, verbose=False, num_threads=4,