diff --git a/src/ess/reflectometry/tools.py b/src/ess/reflectometry/tools.py index f0540eb..71fd1f0 100644 --- a/src/ess/reflectometry/tools.py +++ b/src/ess/reflectometry/tools.py @@ -2,7 +2,6 @@ # Copyright (c) 2023 Scipp contributors (https://github.com/scipp) from collections.abc import Sequence from itertools import chain -from typing import Literal import numpy as np import scipp as sc @@ -180,17 +179,30 @@ def _interpolate_on_qgrid(f, curves, grid): def scale_reflectivity_curves_to_overlap( curves: Sequence[sc.DataArray], - qgrid: sc.Variable | None = None, return_scaling_factors=False, -): - '''Stitches the curves by scaling each except the first by a factor. +) -> list[sc.DataArray] | list[sc.scalar]: + '''Make the curves overlap by scaling all except the first by a factor. The scaling factors are determined by a maximum likelihood estimate (assuming the errors are normal distributed). All curves must be have the same unit for data and the Q-coordinate. + + Parameters: + curves: the reflectivity curves that should be scaled together + return_scaling_factor: + If True the return value of the function + is a list of the scaling factors that should be applied. + If False (default) the function returns the scaled curves. + + Returns: + A list of scaled reflectivity curves or a list of scaling factors. ''' - if qgrid is None: - qgrid = _create_qgrid_where_overlapping([c.coords['Q'] for c in curves]) + if len({c.data.unit for c in curves}) != 1: + raise ValueError('The reflectivity curves must have the same unit') + if len({c.coords['Q'].unit for c in curves}) != 1: + raise ValueError('The Q-coordinates must have the same unit for each curve') + + qgrid = _create_qgrid_where_overlapping([c.coords['Q'] for c in curves]) r = _interpolate_on_qgrid(lambda c: c.data.values, curves, qgrid) v = _interpolate_on_qgrid(lambda c: c.data.variances, curves, qgrid) @@ -219,8 +231,7 @@ def cost(scaling_factors): def combine_curves( curves: Sequence[sc.DataArray], qgrid: sc.Variable | None = None, - how: Literal['mean'] = 'mean', -): +) -> sc.DataArray: '''Combines the given curves by interpolating them on a grid and merging them by the requested method. The default method is a weighted mean where the weights @@ -230,22 +241,31 @@ def combine_curves( need to be scaled using :func:`stitch_reflectivity_curves`. All curves must be have the same unit for data and the Q-coordinate. + + Parameters: + curves: the reflectivity curves that should be combined + qgrid: the Q-grid of the resulting combined reflectivity curve + Returns: + A data array representing the combined reflectivity curve ''' + if len({c.data.unit for c in curves}) != 1: + raise ValueError('The reflectivity curves must have the same unit') + if len({c.coords['Q'].unit for c in curves}) != 1: + raise ValueError('The Q-coordinates must have the same unit for each curve') + r = _interpolate_on_qgrid(lambda c: c.data.values, curves, qgrid) v = _interpolate_on_qgrid(lambda c: c.data.variances, curves, qgrid) - if how == 'mean': - v[v == 0] = np.nan - inv_v = 1.0 / v - r_avg = np.nansum(r * inv_v, axis=0) / np.nansum(inv_v, axis=0) - v_avg = 1 / np.nansum(inv_v, axis=0) - return sc.DataArray( - data=sc.array( - dims='Q', - values=r_avg, - variances=v_avg, - unit=next(iter(curves)).data.unit, - ), - coords={'Q': qgrid}, - ) - return NotImplementedError + v[v == 0] = np.nan + inv_v = 1.0 / v + r_avg = np.nansum(r * inv_v, axis=0) / np.nansum(inv_v, axis=0) + v_avg = 1 / np.nansum(inv_v, axis=0) + return sc.DataArray( + data=sc.array( + dims='Q', + values=r_avg, + variances=v_avg, + unit=next(iter(curves)).data.unit, + ), + coords={'Q': qgrid}, + ) diff --git a/tests/tools_test.py b/tests/tools_test.py index 93be492..bf0ea9f 100644 --- a/tests/tools_test.py +++ b/tests/tools_test.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: BSD-3-Clause # Copyright (c) 2023 Scipp contributors (https://github.com/scipp) import scipp as sc -from ess.reflectometry.tools import combine_curves, stitch_reflectivity_curves +from ess.reflectometry.tools import combine_curves, scale_reflectivity_curves_to_overlap from scipp.testing import assert_allclose @@ -9,8 +9,7 @@ def curve(d, qmin, qmax): return sc.DataArray(data=d, coords={'Q': sc.linspace('Q', qmin, qmax, len(d) + 1)}) -def test_curve_stitching(): - qgrid = sc.midpoints(sc.linspace('Q', 0, 1, 21)) +def test_reflectivity_curve_scaling(): data = sc.concat( ( sc.ones(dims=['Q'], shape=[10], with_variances=True), @@ -20,9 +19,8 @@ def test_curve_stitching(): ) data.variances[:] = 0.1 - curves = stitch_reflectivity_curves( + curves = scale_reflectivity_curves_to_overlap( (curve(data, 0, 0.3), curve(0.8 * data, 0.2, 0.7), curve(0.1 * data, 0.6, 1.0)), - qgrid, ) assert_allclose(curves[0].data, data, rtol=sc.scalar(1e-5)) @@ -30,7 +28,7 @@ def test_curve_stitching(): assert_allclose(curves[2].data, 0.25 * data, rtol=sc.scalar(1e-5)) -def test_curve_stitching_default_qgrid(): +def test_reflectivity_curve_scaling_return_factors(): data = sc.concat( ( sc.ones(dims=['Q'], shape=[10], with_variances=True), @@ -40,13 +38,14 @@ def test_curve_stitching_default_qgrid(): ) data.variances[:] = 0.1 - curves = stitch_reflectivity_curves( + factors = scale_reflectivity_curves_to_overlap( (curve(data, 0, 0.3), curve(0.8 * data, 0.2, 0.7), curve(0.1 * data, 0.6, 1.0)), + return_scaling_factors=True, ) - assert_allclose(curves[0].data, data, rtol=sc.scalar(1e-5)) - assert_allclose(curves[1].data, 0.5 * data, rtol=sc.scalar(1e-5)) - assert_allclose(curves[2].data, 0.25 * data, rtol=sc.scalar(1e-5)) + assert_allclose(factors[0], sc.scalar(1.0), rtol=sc.scalar(1e-5)) + assert_allclose(factors[1], sc.scalar(0.5 / 0.8), rtol=sc.scalar(1e-5)) + assert_allclose(factors[2], sc.scalar(0.25 / 0.1), rtol=sc.scalar(1e-5)) def test_combined_curves():