-
Notifications
You must be signed in to change notification settings - Fork 303
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Summary: Pull Request resolved: #2744 Plotly Scatter plot for any two metrics. Each arm is represented by a single point, whose color indicates the arm's trial index. Optionally, the Pareto frontier can bes hown. This plot is useful for understanding the relationship and/or tradeoff between two metrics. This will replace the pareto frontier plot, but also allow us to use the same codepath to plot the tradeoff between ANY two metrics, not just metrics on the optimization config. **I foresee this, parallel coordinates, and parameter importance to be the most important Analyses in typical OSS Ax usage.** Differential Revision: D62207324
- Loading branch information
1 parent
66aed70
commit 0f5c16a
Showing
4 changed files
with
226 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,161 @@ | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
# | ||
# This source code is licensed under the MIT license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
# pyre-strict | ||
|
||
from typing import Optional | ||
|
||
import pandas as pd | ||
from ax.analysis.analysis import AnalysisCardLevel | ||
|
||
from ax.analysis.plotly.plotly_analysis import PlotlyAnalysis, PlotlyAnalysisCard | ||
from ax.core.experiment import Experiment | ||
from ax.core.generation_strategy_interface import GenerationStrategyInterface | ||
from ax.exceptions.core import UserInputError | ||
from plotly import express as px, graph_objects as go, io as pio | ||
|
||
|
||
class ScatterPlot(PlotlyAnalysis): | ||
""" | ||
Plotly Scatter plot for any two metrics. Each arm is represented by a single point, | ||
whose color indicates the arm's trial index. Optionally, the Pareto frontier can be | ||
shown. This plot is useful for understanding the relationship and/or tradeoff | ||
between two metrics. | ||
The DataFrame computed will contain one row per arm and the following columns: | ||
- arm_name: The name of the arm | ||
- trial_index: The trial index of the arm | ||
- X_METRIC_NAME: The observed mean of the metric specified | ||
- Y_METRIC_NAME: The observed mean of the metric specified | ||
- is_optimal: Whether the arm is on the Pareto frontier | ||
""" | ||
|
||
def __init__( | ||
self, x_metric_name: str, y_metric_name: str, show_pareto_frontier: bool = False | ||
) -> None: | ||
""" | ||
Args: | ||
x_metric_name: The name of the metric to plot on the X axis. | ||
y_metric_name: The name of the metric to plot on the Y axis. | ||
show_pareto_frontier: Whether to show the Pareto frontier for the two | ||
metrics. Optimization direction is inferred from the Experiment. | ||
""" | ||
|
||
self.x_metric_name = x_metric_name | ||
self.y_metric_name = y_metric_name | ||
|
||
self.show_pareto_frontier = show_pareto_frontier | ||
|
||
def compute( | ||
self, | ||
experiment: Optional[Experiment] = None, | ||
generation_strategy: Optional[GenerationStrategyInterface] = None, | ||
) -> PlotlyAnalysisCard: | ||
if experiment is None: | ||
raise UserInputError("ScatterPlot requires an Experiment") | ||
|
||
df = _prepare_data( | ||
experiment=experiment, | ||
x_metric_name=self.x_metric_name, | ||
y_metric_name=self.y_metric_name, | ||
) | ||
fig = _prepare_plot( | ||
df=df, | ||
x_metric_name=self.x_metric_name, | ||
y_metric_name=self.y_metric_name, | ||
show_pareto_frontier=self.show_pareto_frontier, | ||
x_lower_is_better=experiment.metrics[self.x_metric_name].lower_is_better | ||
or False, | ||
) | ||
|
||
return PlotlyAnalysisCard( | ||
name=str(self), | ||
title=f"Observed {self.x_metric_name} vs. {self.y_metric_name}", | ||
subtitle="Compare arms by their observed metric values", | ||
level=AnalysisCardLevel.HIGH, | ||
df=df, | ||
blob=pio.to_json(fig), | ||
) | ||
|
||
|
||
def _prepare_data( | ||
experiment: Experiment, x_metric_name: str, y_metric_name: str | ||
) -> pd.DataFrame: | ||
# Lookup the data that has already been fetched and attached to the experiment | ||
data = experiment.lookup_data().df | ||
|
||
# Filter for only rows with the relevant metric names | ||
metric_name_mask = data["metric_name"].isin([x_metric_name, y_metric_name]) | ||
filtered = data[metric_name_mask][ | ||
["arm_name", "trial_index", "metric_name", "mean"] | ||
] | ||
|
||
# Pivot the data so that each row is an arm and the columns are the metric names | ||
pivoted: pd.DataFrame = filtered.pivot_table( | ||
index=["arm_name", "trial_index"], columns="metric_name", values="mean" | ||
) | ||
pivoted.reset_index(inplace=True) | ||
pivoted.columns.name = None | ||
|
||
# Add a column indicating whether the arm is on the Pareto frontier. This is | ||
# calculated by comparing each arm to all other arms in the experiment and | ||
# creating a mask. | ||
# If directional guidance is not specified, we assume that we intendt to maximize | ||
# the metric. | ||
x_lower_is_better: bool = experiment.metrics[x_metric_name].lower_is_better or False | ||
y_lower_is_better: bool = experiment.metrics[y_metric_name].lower_is_better or False | ||
|
||
def is_optimal(row: pd.Series) -> bool: | ||
x_mask = ( | ||
(pivoted[x_metric_name] < row[x_metric_name]) | ||
if x_lower_is_better | ||
else (pivoted[x_metric_name] > row[x_metric_name]) | ||
) | ||
y_mask = ( | ||
(pivoted[y_metric_name] < row[y_metric_name]) | ||
if y_lower_is_better | ||
else (pivoted[y_metric_name] > row[y_metric_name]) | ||
) | ||
return not (x_mask & y_mask).any() | ||
|
||
pivoted["is_optimal"] = pivoted.apply( | ||
is_optimal, | ||
axis=1, | ||
) | ||
|
||
return pivoted | ||
|
||
|
||
def _prepare_plot( | ||
df: pd.DataFrame, | ||
x_metric_name: str, | ||
y_metric_name: str, | ||
show_pareto_frontier: bool, | ||
x_lower_is_better: bool, | ||
) -> go.Figure: | ||
fig = px.scatter( | ||
df, | ||
x=x_metric_name, | ||
y=y_metric_name, | ||
color="trial_index", | ||
hover_data=["trial_index", "arm_name", x_metric_name, y_metric_name], | ||
) | ||
|
||
if show_pareto_frontier: | ||
# Must sort to ensure we draw the line through optimal points in the correct | ||
# order. | ||
frontier_df = df[df["is_optimal"]].sort_values(by=x_metric_name) | ||
|
||
fig.add_trace( | ||
go.Scatter( | ||
x=frontier_df[x_metric_name], | ||
y=frontier_df[y_metric_name], | ||
mode="lines", | ||
line_shape="hv" if x_lower_is_better else "vh", | ||
showlegend=False, | ||
) | ||
) | ||
|
||
return fig |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
# | ||
# This source code is licensed under the MIT license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
# pyre-strict | ||
|
||
from ax.analysis.analysis import AnalysisCardLevel | ||
from ax.analysis.plotly.scatter import ScatterPlot | ||
from ax.exceptions.core import UserInputError | ||
from ax.utils.common.testutils import TestCase | ||
from ax.utils.testing.core_stubs import get_branin_experiment_with_multi_objective | ||
|
||
|
||
class TestParallelCoordinatesPlot(TestCase): | ||
def test_compute(self) -> None: | ||
analysis = ScatterPlot( | ||
x_metric_name="branin_a", | ||
y_metric_name="branin_b", | ||
show_pareto_frontier=True, | ||
) | ||
experiment = get_branin_experiment_with_multi_objective( | ||
with_completed_trial=True | ||
) | ||
|
||
with self.assertRaisesRegex(UserInputError, "requires an Experiment"): | ||
analysis.compute() | ||
|
||
card = analysis.compute(experiment=experiment) | ||
self.assertEqual( | ||
card.name, | ||
( | ||
"ScatterPlot(x_metric_name=branin_a, y_metric_name=branin_b, " | ||
"show_pareto_frontier=True)" | ||
), | ||
) | ||
self.assertEqual(card.title, "Observed branin_a vs. branin_b") | ||
self.assertEqual( | ||
card.subtitle, | ||
"Compare arms by their observed metric values", | ||
) | ||
self.assertEqual(card.level, AnalysisCardLevel.HIGH) | ||
self.assertEqual( | ||
{*card.df.columns}, | ||
{"arm_name", "trial_index", "branin_a", "branin_b", "is_optimal"}, | ||
) | ||
self.assertIsNotNone(card.blob) | ||
self.assertEqual(card.blob_annotation, "plotly") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters