Skip to content

Commit

Permalink
Scatter plot analysis (#2744)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #2744

Plotly Scatter plot for any two metrics. Each arm is represented by a single point, whose color indicates the arm's trial index. Optionally, the Pareto frontier can bes hown. This plot is useful for understanding the relationship and/or tradeoff between two metrics.

This will replace the pareto frontier plot, but also allow us to use the same codepath to plot the tradeoff between ANY two metrics, not just metrics on the optimization config. **I foresee this, parallel coordinates, and parameter importance to be the most important Analyses in typical OSS Ax usage.**

Differential Revision: D62207324
  • Loading branch information
mpolson64 authored and facebook-github-bot committed Sep 5, 2024
1 parent 66aed70 commit 0f5c16a
Show file tree
Hide file tree
Showing 4 changed files with 226 additions and 3 deletions.
8 changes: 7 additions & 1 deletion ax/analysis/plotly/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,5 +7,11 @@

from ax.analysis.plotly.parallel_coordinates import ParallelCoordinatesPlot
from ax.analysis.plotly.plotly_analysis import PlotlyAnalysis, PlotlyAnalysisCard
from ax.analysis.plotly.scatter import ScatterPlot

__all__ = ["PlotlyAnalysis", "PlotlyAnalysisCard", "ParallelCoordinatesPlot"]
__all__ = [
"PlotlyAnalysis",
"PlotlyAnalysisCard",
"ParallelCoordinatesPlot",
"ScatterPlot",
]
161 changes: 161 additions & 0 deletions ax/analysis/plotly/scatter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

from typing import Optional

import pandas as pd
from ax.analysis.analysis import AnalysisCardLevel

from ax.analysis.plotly.plotly_analysis import PlotlyAnalysis, PlotlyAnalysisCard
from ax.core.experiment import Experiment
from ax.core.generation_strategy_interface import GenerationStrategyInterface
from ax.exceptions.core import UserInputError
from plotly import express as px, graph_objects as go, io as pio


class ScatterPlot(PlotlyAnalysis):
"""
Plotly Scatter plot for any two metrics. Each arm is represented by a single point,
whose color indicates the arm's trial index. Optionally, the Pareto frontier can be
shown. This plot is useful for understanding the relationship and/or tradeoff
between two metrics.
The DataFrame computed will contain one row per arm and the following columns:
- arm_name: The name of the arm
- trial_index: The trial index of the arm
- X_METRIC_NAME: The observed mean of the metric specified
- Y_METRIC_NAME: The observed mean of the metric specified
- is_optimal: Whether the arm is on the Pareto frontier
"""

def __init__(
self, x_metric_name: str, y_metric_name: str, show_pareto_frontier: bool = False
) -> None:
"""
Args:
x_metric_name: The name of the metric to plot on the X axis.
y_metric_name: The name of the metric to plot on the Y axis.
show_pareto_frontier: Whether to show the Pareto frontier for the two
metrics. Optimization direction is inferred from the Experiment.
"""

self.x_metric_name = x_metric_name
self.y_metric_name = y_metric_name

self.show_pareto_frontier = show_pareto_frontier

def compute(
self,
experiment: Optional[Experiment] = None,
generation_strategy: Optional[GenerationStrategyInterface] = None,
) -> PlotlyAnalysisCard:
if experiment is None:
raise UserInputError("ScatterPlot requires an Experiment")

df = _prepare_data(
experiment=experiment,
x_metric_name=self.x_metric_name,
y_metric_name=self.y_metric_name,
)
fig = _prepare_plot(
df=df,
x_metric_name=self.x_metric_name,
y_metric_name=self.y_metric_name,
show_pareto_frontier=self.show_pareto_frontier,
x_lower_is_better=experiment.metrics[self.x_metric_name].lower_is_better
or False,
)

return PlotlyAnalysisCard(
name=str(self),
title=f"Observed {self.x_metric_name} vs. {self.y_metric_name}",
subtitle="Compare arms by their observed metric values",
level=AnalysisCardLevel.HIGH,
df=df,
blob=pio.to_json(fig),
)


def _prepare_data(
experiment: Experiment, x_metric_name: str, y_metric_name: str
) -> pd.DataFrame:
# Lookup the data that has already been fetched and attached to the experiment
data = experiment.lookup_data().df

# Filter for only rows with the relevant metric names
metric_name_mask = data["metric_name"].isin([x_metric_name, y_metric_name])
filtered = data[metric_name_mask][
["arm_name", "trial_index", "metric_name", "mean"]
]

# Pivot the data so that each row is an arm and the columns are the metric names
pivoted: pd.DataFrame = filtered.pivot_table(
index=["arm_name", "trial_index"], columns="metric_name", values="mean"
)
pivoted.reset_index(inplace=True)
pivoted.columns.name = None

# Add a column indicating whether the arm is on the Pareto frontier. This is
# calculated by comparing each arm to all other arms in the experiment and
# creating a mask.
# If directional guidance is not specified, we assume that we intendt to maximize
# the metric.
x_lower_is_better: bool = experiment.metrics[x_metric_name].lower_is_better or False
y_lower_is_better: bool = experiment.metrics[y_metric_name].lower_is_better or False

def is_optimal(row: pd.Series) -> bool:
x_mask = (
(pivoted[x_metric_name] < row[x_metric_name])
if x_lower_is_better
else (pivoted[x_metric_name] > row[x_metric_name])
)
y_mask = (
(pivoted[y_metric_name] < row[y_metric_name])
if y_lower_is_better
else (pivoted[y_metric_name] > row[y_metric_name])
)
return not (x_mask & y_mask).any()

pivoted["is_optimal"] = pivoted.apply(
is_optimal,
axis=1,
)

return pivoted


def _prepare_plot(
df: pd.DataFrame,
x_metric_name: str,
y_metric_name: str,
show_pareto_frontier: bool,
x_lower_is_better: bool,
) -> go.Figure:
fig = px.scatter(
df,
x=x_metric_name,
y=y_metric_name,
color="trial_index",
hover_data=["trial_index", "arm_name", x_metric_name, y_metric_name],
)

if show_pareto_frontier:
# Must sort to ensure we draw the line through optimal points in the correct
# order.
frontier_df = df[df["is_optimal"]].sort_values(by=x_metric_name)

fig.add_trace(
go.Scatter(
x=frontier_df[x_metric_name],
y=frontier_df[y_metric_name],
mode="lines",
line_shape="hv" if x_lower_is_better else "vh",
showlegend=False,
)
)

return fig
48 changes: 48 additions & 0 deletions ax/analysis/plotly/tests/test_scatter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

from ax.analysis.analysis import AnalysisCardLevel
from ax.analysis.plotly.scatter import ScatterPlot
from ax.exceptions.core import UserInputError
from ax.utils.common.testutils import TestCase
from ax.utils.testing.core_stubs import get_branin_experiment_with_multi_objective


class TestParallelCoordinatesPlot(TestCase):
def test_compute(self) -> None:
analysis = ScatterPlot(
x_metric_name="branin_a",
y_metric_name="branin_b",
show_pareto_frontier=True,
)
experiment = get_branin_experiment_with_multi_objective(
with_completed_trial=True
)

with self.assertRaisesRegex(UserInputError, "requires an Experiment"):
analysis.compute()

card = analysis.compute(experiment=experiment)
self.assertEqual(
card.name,
(
"ScatterPlot(x_metric_name=branin_a, y_metric_name=branin_b, "
"show_pareto_frontier=True)"
),
)
self.assertEqual(card.title, "Observed branin_a vs. branin_b")
self.assertEqual(
card.subtitle,
"Compare arms by their observed metric values",
)
self.assertEqual(card.level, AnalysisCardLevel.HIGH)
self.assertEqual(
{*card.df.columns},
{"arm_name", "trial_index", "branin_a", "branin_b", "is_optimal"},
)
self.assertIsNotNone(card.blob)
self.assertEqual(card.blob_annotation, "plotly")
12 changes: 10 additions & 2 deletions sphinx/source/analysis.rst
Original file line number Diff line number Diff line change
Expand Up @@ -32,17 +32,25 @@ Plotly Analysis
:show-inheritance:

Healthcheck Analysis
~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~~~~

.. automodule:: ax.analysis.healthcheck.healthcheck_analysis
:members:
:undoc-members:
:show-inheritance:

Parallel Coordinates Analysis
~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. automodule:: ax.analysis.plotly.parallel_coordinates
:members:
:undoc-members:
:show-inheritance:

Scatter Plot Analysis
~~~~~~~~~~~~~~~~~~~~~

.. automodule:: ax.analysis.plotly.scatter
:members:
:undoc-members:
:show-inheritance:

0 comments on commit 0f5c16a

Please sign in to comment.