diff --git a/src/esssans/direct_beam.py b/src/esssans/direct_beam.py index 7f60f51b..621e298d 100644 --- a/src/esssans/direct_beam.py +++ b/src/esssans/direct_beam.py @@ -7,7 +7,19 @@ import scipp as sc from sciline import Pipeline -from .types import BackgroundSubtractedIofQ, DirectBeam, WavelengthBands +from .types import ( + BackgroundRun, + BackgroundSubtractedIofQ, + CleanMonitor, + CleanSummedQ, + DirectBeam, + Incident, + Numerator, + SampleRun, + SolidAngle, + TransmissionFraction, + WavelengthBands, +) def _compute_efficiency_correction( @@ -85,19 +97,34 @@ def direct_beam(pipeline: Pipeline, I0: sc.Variable, niter: int = 5) -> List[dic """ direct_beam_function = None - wavelength_bands = pipeline.compute(WavelengthBands) - band_dim = (set(wavelength_bands.dims) - {'wavelength'}).pop() + bands = pipeline.compute(WavelengthBands) + band_dim = (set(bands.dims) - {'wavelength'}).pop() - full_wavelength_range = sc.concat( - [wavelength_bands.min(), wavelength_bands.max()], dim='wavelength' - ) + full_wavelength_range = sc.concat([bands.min(), bands.max()], dim='wavelength') - pipeline_bands = pipeline.copy() - pipeline_full = pipeline_bands.copy() - pipeline_full[WavelengthBands] = full_wavelength_range + pipeline = pipeline.copy() + # Append full wavelength range as extra band. This allows for running only a + # single pipeline to compute both the I(Q) in bands and the I(Q) for the full + # wavelength range. + pipeline[WavelengthBands] = sc.concat([bands, full_wavelength_range], dim=band_dim) results = [] + # Compute checkpoints to avoid recomputing the same things in every iteration + checkpoints = ( + TransmissionFraction[SampleRun], + TransmissionFraction[BackgroundRun], + SolidAngle[SampleRun], + SolidAngle[BackgroundRun], + CleanMonitor[SampleRun, Incident], + CleanMonitor[BackgroundRun, Incident], + CleanSummedQ[SampleRun, Numerator], + CleanSummedQ[BackgroundRun, Numerator], + ) + + for key, result in pipeline.compute(checkpoints).items(): + pipeline[key] = result + for it in range(niter): print("Iteration", it) @@ -105,17 +132,16 @@ def direct_beam(pipeline: Pipeline, I0: sc.Variable, niter: int = 5) -> List[dic # parameters, nor given by any providers, so it will be considered flat. # TODO: Should we have a check that DirectBeam cannot be computed from the # pipeline? - iofq_full = pipeline_full.compute(BackgroundSubtractedIofQ) - iofq_bands = pipeline_bands.compute(BackgroundSubtractedIofQ) + iofq = pipeline.compute(BackgroundSubtractedIofQ) + iofq_full = iofq['band', -1] + iofq_bands = iofq['band', :-1] if direct_beam_function is None: # Make a flat direct beam dims = [dim for dim in iofq_bands.dims if dim != 'Q'] direct_beam_function = sc.DataArray( data=sc.ones(sizes={dim: iofq_bands.sizes[dim] for dim in dims}), - coords={ - band_dim: sc.midpoints(wavelength_bands, dim='wavelength').squeeze() - }, + coords={band_dim: sc.midpoints(bands, dim='wavelength').squeeze()}, ).rename({band_dim: 'wavelength'}) direct_beam_function *= _compute_efficiency_correction( @@ -125,9 +151,8 @@ def direct_beam(pipeline: Pipeline, I0: sc.Variable, niter: int = 5) -> List[dic I0=I0, ) - # Insert new direct beam function into pipelines - pipeline_bands[DirectBeam] = direct_beam_function - pipeline_full[DirectBeam] = direct_beam_function + # Insert new direct beam function into pipeline + pipeline[DirectBeam] = direct_beam_function results.append( { diff --git a/src/esssans/i_of_q.py b/src/esssans/i_of_q.py index 8c3a1b77..d7f4c817 100644 --- a/src/esssans/i_of_q.py +++ b/src/esssans/i_of_q.py @@ -2,8 +2,10 @@ # Copyright (c) 2023 Scipp contributors (https://github.com/scipp) from typing import Dict, List, Optional, Union +from uuid import uuid4 import scipp as sc +from scipp.core.concepts import irreducible_mask from scipp.scipy.interpolate import interp1d from .common import mask_range @@ -276,9 +278,38 @@ def _dense_merge_spectra( edges = _to_q_bins(q_bins) bands = [] band_dim = (set(wavelength_bands.dims) - {'wavelength'}).pop() + + # We want to flatten data to make histogramming cheaper (avoiding allocation of + # large output before summing). We strip unnecessary content since it makes + # flattening more expensive. + stripped = data_q.copy(deep=False) + for name, coord in data_q.coords.items(): + if name not in ['Q', 'wavelength'] and any( + [dim in dims_to_reduce for dim in coord.dims] + ): + del stripped.coords[name] + to_flatten = [dim for dim in data_q.dims if dim in dims_to_reduce] + + dummy_dim = str(uuid4()) + flat = stripped.flatten(dims=to_flatten, to=dummy_dim) + + # Apply masks once, to avoid repeated work when iterating over bands + mask = irreducible_mask(flat, dummy_dim) + # When not all dims are reduced there may be extra dims in the mask and it is not + # possible to select data based on it. In this case the masks will be applied + # in the loop below, which is slightly slower. + if mask.ndim == 1: + flat = flat.drop_masks( + [name for name, mask in flat.masks.items() if dummy_dim in mask.dims] + ) + flat = flat[~mask] + + dims_to_reduce = tuple(dim for dim in dims_to_reduce if dim not in to_flatten) for wav_range in sc.collapse(wavelength_bands, keep='wavelength').values(): - band = data_q['wavelength', wav_range[0] : wav_range[1]] - bands.append(band.hist(**edges).sum(dims_to_reduce)) + band = flat['wavelength', wav_range[0] : wav_range[1]] + # By flattening before histogramming we avoid allocating a large output array, + # which would then require summing over all pixels. + bands.append(band.flatten(dims=(dummy_dim, 'Q'), to='Q').hist(**edges)) return sc.concat(bands, band_dim) diff --git a/src/esssans/normalization.py b/src/esssans/normalization.py index de6d4bfa..ce9682d2 100644 --- a/src/esssans/normalization.py +++ b/src/esssans/normalization.py @@ -210,6 +210,11 @@ def iofq_norm_wavelength_term( """ out = incident_monitor * transmission_fraction if direct_beam is not None: + # Make wavelength the inner dim + dims = list(direct_beam.dims) + dims.remove('wavelength') + dims.append('wavelength') + direct_beam = direct_beam.transpose(dims) broadcast = _broadcasters[uncertainties] out = direct_beam * broadcast(out, sizes=direct_beam.sizes) # Convert wavelength coordinate to midpoints for future histogramming