Skip to content

Commit

Permalink
Delete model metadata from schema_*.yaml (#2195)
Browse files Browse the repository at this point in the history
  • Loading branch information
yifanmai committed Jan 5, 2024
1 parent 69c583f commit 0c08f9d
Show file tree
Hide file tree
Showing 6 changed files with 72 additions and 1,340 deletions.
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ nodeenv==1.7.0
numba==0.56.4
numpy==1.23.3
openai==0.27.8
opencv-python==4.8.1.78
openpyxl==3.0.10
outcome==1.2.0
packaging==21.3
Expand Down
15 changes: 8 additions & 7 deletions src/helm/benchmark/presentation/create_plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,10 @@
import numpy as np
from scipy.stats import pearsonr

from helm.benchmark.config_registry import register_builtin_configs_from_helm_package
from helm.common.hierarchical_logger import hlog
from helm.common.optional_dependencies import handle_module_not_found_error
from helm.benchmark.presentation.schema import read_schema, SCHEMA_CLASSIC_YAML_FILENAME
from helm.benchmark.model_metadata_registry import MODEL_NAME_TO_MODEL_METADATA
from helm.benchmark.presentation.summarize import AGGREGATE_WIN_RATE_COLUMN

try:
Expand Down Expand Up @@ -133,9 +134,6 @@ def __init__(self, base_path: str, save_path: str, plot_format: str):
self.plot_format = plot_format
self._tables_cache: Dict[str, Dict[str, Table]] = {}

schema = read_schema(SCHEMA_CLASSIC_YAML_FILENAME)
self.model_metadata = {model_field.display_name: model_field for model_field in schema.models}

def get_group_tables(self, group_name: str) -> Dict[str, Table]:
"""Reads and parses group tables. Uses _tables_cache to avoid reprocessing the same table multiple times."""
if group_name in self._tables_cache:
Expand Down Expand Up @@ -338,14 +336,14 @@ def create_all_accuracy_v_model_property_plots(self):

def get_model_release_date(model_name: str) -> Optional[date]:
"""Maps a model name to the month of model release."""
release_date = self.model_metadata[model_name].release_date
release_date = MODEL_NAME_TO_MODEL_METADATA[model_name].release_date
if release_date is None:
return None
return release_date.replace(day=1)

def get_model_size(model_name: str) -> Optional[int]:
"""Maps a model name to the number of parameters, rounding to the nearest leading digit."""
size = self.model_metadata[model_name].num_parameters
size = MODEL_NAME_TO_MODEL_METADATA[model_name].num_parameters
if size is None:
return None
grain = 10 ** (len(str(size)) - 1)
Expand Down Expand Up @@ -401,7 +399,9 @@ def create_accuracy_v_access_bar_plot(self):

for i, access_level in enumerate(access_levels):
model_indices: List[int] = [
idx for idx, model in enumerate(table.adapters) if self.model_metadata[model].access == access_level
idx
for idx, model in enumerate(table.adapters)
if MODEL_NAME_TO_MODEL_METADATA[model].access == access_level
]
best_model_index = model_indices[table.mean_win_rates[model_indices].argmax()]

Expand Down Expand Up @@ -611,6 +611,7 @@ def main():
parser.add_argument("--suite", type=str, help="Name of the suite that we are plotting", required=True)
parser.add_argument("--plot-format", help="Format for saving plots", default="png", choices=["png", "pdf"])
args = parser.parse_args()
register_builtin_configs_from_helm_package()
base_path = os.path.join(args.output_path, "runs", args.suite)
if not os.path.exists(os.path.join(base_path, "groups")):
hlog(f"ERROR: Could not find `groups` directory under {base_path}. Did you run `summarize.py` first?")
Expand Down
33 changes: 0 additions & 33 deletions src/helm/benchmark/presentation/schema.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from dataclasses import dataclass, field
from datetime import date
from typing import List, Optional, Dict
import dacite
import mako.template
Expand Down Expand Up @@ -46,34 +45,6 @@ def get_short_display_name(self) -> str:
return name


# Note: also see Model from `models.py`.
@dataclass(frozen=True)
class ModelField(Field):
# Organization that originally created the model (e.g. "EleutherAI")
# Note that this may be different from group or the prefix of the model `name`
# ("together" in "together/gpt-j-6b") as the hosting organization
# may be different from the creator organization. We also capitalize
# this field properly to later display in the UI.
# TODO: in the future, we want to cleanup the naming in the following ways:
# - make the creator_organization an identifier with a separate display name
# - have a convention like <hosting_organization><creator_organization>/<model_name>
creator_organization: Optional[str] = None

# How this model is available (e.g., limited)
access: Optional[str] = None

# Whether we have yet to evaluate this model
todo: bool = False

# When was the model released
release_date: Optional[date] = None

# The number of parameters
# This should be a string as the number of parameters is usually a round number (175B),
# but we set it as an int for plotting purposes.
num_parameters: Optional[int] = None


@dataclass(frozen=True)
class MetricNameMatcher:
"""
Expand Down Expand Up @@ -222,9 +193,6 @@ class RunGroup(Field):
class Schema:
"""Specifies information about what to display on the frontend."""

# Models
models: List[ModelField]

# Adapter fields (e.g., temperature)
adapter: List[Field]

Expand All @@ -241,7 +209,6 @@ class Schema:
run_groups: List[RunGroup]

def __post_init__(self):
self.name_to_model = {model.name: model for model in self.models}
self.name_to_metric = {metric.name: metric for metric in self.metrics}
self.name_to_perturbation = {perturbation.name: perturbation for perturbation in self.perturbations}
self.name_to_metric_group = {metric_group.name: metric_group for metric_group in self.metric_groups}
Expand Down
74 changes: 63 additions & 11 deletions src/helm/benchmark/presentation/summarize.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
write,
ensure_directory_exists,
asdict_without_nones,
serialize_dates,
parallel_map,
singleton,
unique_simplification,
Expand All @@ -47,6 +46,7 @@
from helm.benchmark.presentation.schema import (
MetricNameMatcher,
RunGroup,
Field,
read_schema,
SCHEMA_CLASSIC_YAML_FILENAME,
BY_GROUP,
Expand All @@ -62,7 +62,7 @@
)
from helm.benchmark.config_registry import register_builtin_configs_from_helm_package, register_configs_from_directory
from helm.benchmark.presentation.run_display import write_run_display_json
from helm.benchmark.model_metadata_registry import ModelMetadata, get_model_metadata
from helm.benchmark.model_metadata_registry import ModelMetadata, get_model_metadata, get_all_models


OVERLAP_N_COUNT = 13
Expand Down Expand Up @@ -172,7 +172,7 @@ def get_model_metadata_for_adapter_spec(adapter_spec: AdapterSpec) -> ModelMetad
except ValueError:
pass

# Return a placeholder "unknoown model" model metadata.
# Return a placeholder "unknown model" model metadata.
return get_unknown_model_metadata(adapter_spec.model)


Expand Down Expand Up @@ -433,11 +433,61 @@ def group_runs(self):
self.group_adapter_to_runs[group_name][adapter_spec].append(run)
self.group_scenario_adapter_to_runs[group_name][scenario_spec][adapter_spec].append(run)

def write_schema(self):
@dataclass(frozen=True)
class _ModelField(Field):
"""The frontend version of ModelMetadata.
    The frontend expects schema.json to contain a field under "model" that contains a list of `ModelField`s.
All attributes have the same meaning as in ModelMetadata."""

# TODO: Migrate frontend to use ModelMetadata instead of ModelField and delete this.
creator_organization: Optional[str] = None
access: Optional[str] = None
todo: bool = False
release_date: Optional[str] = None
num_parameters: Optional[int] = None

def get_model_field_dicts(self) -> List[Dict]:
        """Get a list of `ModelField` dicts that will be written to schema.json.
        The frontend expects schema.json to contain a field under "model" that contains a list of `ModelField`s.
This is populated by reading the `ModelMetadata` configs and filtering down to models that were
actually used, and converting each `ModelMetadata` to a `ModelField`."""
# TODO: Migrate frontend to use ModelMetadata instead of ModelField and delete this.
used_model_names: Set[str] = set()
for run in self.runs:
used_model_names.add(get_model_metadata_for_adapter_spec(run.run_spec.adapter_spec).name)

model_field_dicts: List[Dict] = []
for model_name in get_all_models():
if model_name not in used_model_names:
continue
model_metadata = get_model_metadata(model_name)
model_field = Summarizer._ModelField(
name=model_metadata.name,
display_name=model_metadata.display_name,
short_display_name=model_metadata.display_name,
description=model_metadata.description,
creator_organization=model_metadata.creator_organization_name,
access=model_metadata.access,
todo=False,
release_date=model_metadata.release_date.isoformat() if model_metadata.release_date else None,
num_parameters=model_metadata.num_parameters,
)
model_field_dicts.append(asdict_without_nones(model_field))
return model_field_dicts

def write_schema(self) -> None:
"""Write the schema file to benchmark_output so the frontend knows about it."""
# Manually add the model metadata to the schema.json, where the frontend expects it.
# TODO: Move model metadata out of schema.json into its own model_metadata.json file.
raw_schema = asdict_without_nones(self.schema)
raw_schema["models"] = self.get_model_field_dicts()
write(
os.path.join(self.run_release_path, "schema.json"),
json.dumps(asdict_without_nones(self.schema), indent=2, default=serialize_dates),
json.dumps(raw_schema, indent=2),
)

def read_runs(self):
Expand Down Expand Up @@ -921,10 +971,10 @@ def run_spec_names_to_url(run_spec_names: List[str]) -> str:

adapter_specs: List[AdapterSpec] = list(adapter_to_runs.keys())
if sort_by_model_order:
# Sort models by the order defined in the schema.
# Models not defined in the schema will be sorted alphabetically and
# placed before models in defined the schema.
model_order = [model.name for model in self.schema.models]
            # Sort models by the order defined in the model metadata config.
            # Models not defined in the model metadata config will be sorted alphabetically and
            # placed before models defined in the model metadata config.
model_order = get_all_models()

def _adapter_spec_sort_key(spec):
index = model_order.index(spec.model_deployment) if spec.model_deployment in model_order else -1
Expand Down Expand Up @@ -1304,8 +1354,6 @@ def symlink_latest(self) -> None:

def run_pipeline(self, skip_completed: bool, num_instances: int) -> None:
"""Run the entire summarization pipeline."""
self.write_schema()

self.read_runs()
self.group_runs()
self.check_metrics_defined()
Expand All @@ -1320,6 +1368,10 @@ def run_pipeline(self, skip_completed: bool, num_instances: int) -> None:
# because it uses self.scenario_spec_instance_id_dict
self.read_overlap_stats()

# Must happen after self.read_runs()
# because it uses self.runs
self.write_schema()

self.write_executive_summary()
self.write_runs()
self.write_run_specs()
Expand Down
Loading

0 comments on commit 0c08f9d

Please sign in to comment.