Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Better datetime axis compatibility #312

Merged
merged 4 commits into from
Feb 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/plopp/backends/common.py
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add a test?

Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ def union(self, other: BoundingBox) -> BoundingBox:
"""
Return the union of this bounding box with another one.
"""

return BoundingBox(
xmin=_none_min(self.xmin, other.xmin),
xmax=_none_max(self.xmax, other.xmax),
Expand Down
11 changes: 9 additions & 2 deletions src/plopp/backends/matplotlib/canvas.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import matplotlib.pyplot as plt
import numpy as np
import scipp as sc
from matplotlib import dates as mdates
from matplotlib.collections import QuadMesh
from mpl_toolkits.axes_grid1 import make_axes_locatable

Expand All @@ -14,6 +15,10 @@
from .utils import fig_to_bytes, is_sphinx_build, make_figure


def _to_floats(x):
return mdates.date2num(x) if np.issubdtype(x.dtype, np.datetime64) else x


def _none_if_not_finite(x: Union[float, int, None]) -> Union[float, int, None]:
if x is None:
return None
Expand Down Expand Up @@ -149,9 +154,11 @@ def autoscale(self):
lines = [line for line in self.ax.lines if hasattr(line, '_plopp_mask')]
for line in lines:
line_mask = sc.array(dims=['x'], values=line._plopp_mask)
line_x = sc.DataArray(data=sc.array(dims=['x'], values=line.get_xdata()))
line_x = sc.DataArray(
data=sc.array(dims=['x'], values=_to_floats(line.get_xdata()))
)
line_y = sc.DataArray(
data=sc.array(dims=['x'], values=line.get_ydata()),
data=sc.array(dims=['x'], values=_to_floats(line.get_ydata())),
masks={'mask': line_mask},
)
line_bbox = BoundingBox(
Expand Down
27 changes: 27 additions & 0 deletions tests/backends/matplotlib/mpl_figure_test.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
# SPDX-License-Identifier: BSD-3-Clause
# Copyright (c) 2023 Scipp contributors (https://github.com/scipp)
import numpy as np
import scipp as sc

from plopp.backends.matplotlib import MatplotlibBackend
from plopp.backends.matplotlib.interactive import InteractiveFig
from plopp.backends.matplotlib.static import StaticFig
from plopp.core import Node
from plopp.graphics.imageview import ImageView
from plopp.graphics.lineview import LineView

Expand All @@ -30,3 +33,27 @@ def test_create_interactive_fig2d(use_ipympl):
b = MatplotlibBackend()
fig = b.figure2d(View=ImageView)
assert isinstance(fig, InteractiveFig)


def test_datetime_compatibility_between_1d_and_2d_figures():
b = MatplotlibBackend()
# 2d data
t = np.arange(
np.datetime64('2017-03-16T20:58:17'), np.datetime64('2017-03-16T21:15:17'), 20
)
time = sc.array(dims=['time'], values=t)
z = sc.arange('z', 50.0, unit='m')
v = 10 * np.random.random(z.shape + time.shape)
da2d = sc.DataArray(
data=sc.array(dims=['z', 'time'], values=v), coords={'time': time, 'z': z}
)
fig = b.figure2d(ImageView, Node(da2d))
assert len(fig.ax.lines) == 0
assert len(fig.ax.collections) == 1

# 1d data
v = np.random.rand(time.sizes['time'])
da1d = sc.DataArray(data=sc.array(dims=['time'], values=v), coords={'time': time})
b.figure1d(LineView, Node(da1d), ax=fig.ax)
assert len(fig.ax.lines) > 0
assert len(fig.ax.collections) == 1
2 changes: 1 addition & 1 deletion tests/backends/matplotlib/mpl_lineview_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def test_grid():


def test_ax():
fig, ax = plt.subplots()
_, ax = plt.subplots()
assert len(ax.lines) == 0
da = data_array(ndim=1)
_ = LineView(Node(da), ax=ax)
Expand Down
Loading