Skip to content

Commit

Permalink
docs: improve docstrings
Browse files Browse the repository at this point in the history
  • Loading branch information
jokasimr committed Sep 19, 2024
1 parent 2aa181f commit e32ba38
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 33 deletions.
66 changes: 43 additions & 23 deletions src/ess/reflectometry/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
# Copyright (c) 2023 Scipp contributors (https://github.com/scipp)
from collections.abc import Sequence
from itertools import chain
from typing import Literal

import numpy as np
import scipp as sc
Expand Down Expand Up @@ -180,17 +179,30 @@ def _interpolate_on_qgrid(f, curves, grid):

def scale_reflectivity_curves_to_overlap(
curves: Sequence[sc.DataArray],
qgrid: sc.Variable | None = None,
return_scaling_factors=False,
):
'''Stitches the curves by scaling each except the first by a factor.
) -> list[sc.DataArray] | list[sc.scalar]:
'''Make the curves overlap by scaling all except the first by a factor.
The scaling factors are determined by a maximum likelihood estimate
(assuming the errors are normal distributed).
All curves must be have the same unit for data and the Q-coordinate.
Parameters:
curves: the reflectivity curves that should be scaled together
return_scaling_factor:
If True the return value of the function
is a list of the scaling factors that should be applied.
If False (default) the function returns the scaled curves.
Returns:
A list of scaled reflectivity curves or a list of scaling factors.
'''
if qgrid is None:
qgrid = _create_qgrid_where_overlapping([c.coords['Q'] for c in curves])
if len({c.data.unit for c in curves}) != 1:
raise ValueError('The reflectivity curves must have the same unit')
if len({c.coords['Q'].unit for c in curves}) != 1:
raise ValueError('The Q-coordinates must have the same unit for each curve')

qgrid = _create_qgrid_where_overlapping([c.coords['Q'] for c in curves])

r = _interpolate_on_qgrid(lambda c: c.data.values, curves, qgrid)
v = _interpolate_on_qgrid(lambda c: c.data.variances, curves, qgrid)
Expand Down Expand Up @@ -219,8 +231,7 @@ def cost(scaling_factors):
def combine_curves(
curves: Sequence[sc.DataArray],
qgrid: sc.Variable | None = None,
how: Literal['mean'] = 'mean',
):
) -> sc.DataArray:
'''Combines the given curves by interpolating them
on a grid and merging them by the requested method.
The default method is a weighted mean where the weights
Expand All @@ -230,22 +241,31 @@ def combine_curves(
need to be scaled using :func:`stitch_reflectivity_curves`.
All curves must be have the same unit for data and the Q-coordinate.
Parameters:
curves: the reflectivity curves that should be combined
qgrid: the Q-grid of the resulting combined reflectivity curve
Returns:
A data array representing the combined reflectivity curve
'''
if len({c.data.unit for c in curves}) != 1:
raise ValueError('The reflectivity curves must have the same unit')
if len({c.coords['Q'].unit for c in curves}) != 1:
raise ValueError('The Q-coordinates must have the same unit for each curve')

r = _interpolate_on_qgrid(lambda c: c.data.values, curves, qgrid)
v = _interpolate_on_qgrid(lambda c: c.data.variances, curves, qgrid)

if how == 'mean':
v[v == 0] = np.nan
inv_v = 1.0 / v
r_avg = np.nansum(r * inv_v, axis=0) / np.nansum(inv_v, axis=0)
v_avg = 1 / np.nansum(inv_v, axis=0)
return sc.DataArray(
data=sc.array(
dims='Q',
values=r_avg,
variances=v_avg,
unit=next(iter(curves)).data.unit,
),
coords={'Q': qgrid},
)
return NotImplementedError
v[v == 0] = np.nan
inv_v = 1.0 / v
r_avg = np.nansum(r * inv_v, axis=0) / np.nansum(inv_v, axis=0)
v_avg = 1 / np.nansum(inv_v, axis=0)
return sc.DataArray(
data=sc.array(
dims='Q',
values=r_avg,
variances=v_avg,
unit=next(iter(curves)).data.unit,
),
coords={'Q': qgrid},
)
19 changes: 9 additions & 10 deletions tests/tools_test.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,15 @@
# SPDX-License-Identifier: BSD-3-Clause
# Copyright (c) 2023 Scipp contributors (https://github.com/scipp)
import scipp as sc
from ess.reflectometry.tools import combine_curves, stitch_reflectivity_curves
from ess.reflectometry.tools import combine_curves, scale_reflectivity_curves_to_overlap
from scipp.testing import assert_allclose


def curve(d, qmin, qmax):
return sc.DataArray(data=d, coords={'Q': sc.linspace('Q', qmin, qmax, len(d) + 1)})


def test_curve_stitching():
qgrid = sc.midpoints(sc.linspace('Q', 0, 1, 21))
def test_reflectivity_curve_scaling():
data = sc.concat(
(
sc.ones(dims=['Q'], shape=[10], with_variances=True),
Expand All @@ -20,17 +19,16 @@ def test_curve_stitching():
)
data.variances[:] = 0.1

curves = stitch_reflectivity_curves(
curves = scale_reflectivity_curves_to_overlap(
(curve(data, 0, 0.3), curve(0.8 * data, 0.2, 0.7), curve(0.1 * data, 0.6, 1.0)),
qgrid,
)

assert_allclose(curves[0].data, data, rtol=sc.scalar(1e-5))
assert_allclose(curves[1].data, 0.5 * data, rtol=sc.scalar(1e-5))
assert_allclose(curves[2].data, 0.25 * data, rtol=sc.scalar(1e-5))


def test_curve_stitching_default_qgrid():
def test_reflectivity_curve_scaling_return_factors():
data = sc.concat(
(
sc.ones(dims=['Q'], shape=[10], with_variances=True),
Expand All @@ -40,13 +38,14 @@ def test_curve_stitching_default_qgrid():
)
data.variances[:] = 0.1

curves = stitch_reflectivity_curves(
factors = scale_reflectivity_curves_to_overlap(
(curve(data, 0, 0.3), curve(0.8 * data, 0.2, 0.7), curve(0.1 * data, 0.6, 1.0)),
return_scaling_factors=True,
)

assert_allclose(curves[0].data, data, rtol=sc.scalar(1e-5))
assert_allclose(curves[1].data, 0.5 * data, rtol=sc.scalar(1e-5))
assert_allclose(curves[2].data, 0.25 * data, rtol=sc.scalar(1e-5))
assert_allclose(factors[0], sc.scalar(1.0), rtol=sc.scalar(1e-5))
assert_allclose(factors[1], sc.scalar(0.5 / 0.8), rtol=sc.scalar(1e-5))
assert_allclose(factors[2], sc.scalar(0.25 / 0.1), rtol=sc.scalar(1e-5))


def test_combined_curves():
Expand Down

0 comments on commit e32ba38

Please sign in to comment.