Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Faster direct beam #63

Merged
merged 8 commits into from
Feb 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 42 additions & 17 deletions src/esssans/direct_beam.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,19 @@
import scipp as sc
from sciline import Pipeline

from .types import BackgroundSubtractedIofQ, DirectBeam, WavelengthBands
from .types import (
BackgroundRun,
BackgroundSubtractedIofQ,
CleanMonitor,
CleanSummedQ,
DirectBeam,
Incident,
Numerator,
SampleRun,
SolidAngle,
TransmissionFraction,
WavelengthBands,
)


def _compute_efficiency_correction(
Expand Down Expand Up @@ -85,37 +97,51 @@ def direct_beam(pipeline: Pipeline, I0: sc.Variable, niter: int = 5) -> List[dic
"""

direct_beam_function = None
wavelength_bands = pipeline.compute(WavelengthBands)
band_dim = (set(wavelength_bands.dims) - {'wavelength'}).pop()
bands = pipeline.compute(WavelengthBands)
band_dim = (set(bands.dims) - {'wavelength'}).pop()

full_wavelength_range = sc.concat(
[wavelength_bands.min(), wavelength_bands.max()], dim='wavelength'
)
full_wavelength_range = sc.concat([bands.min(), bands.max()], dim='wavelength')

pipeline_bands = pipeline.copy()
pipeline_full = pipeline_bands.copy()
pipeline_full[WavelengthBands] = full_wavelength_range
pipeline = pipeline.copy()
# Append full wavelength range as extra band. This allows for running only a
# single pipeline to compute both the I(Q) in bands and the I(Q) for the full
# wavelength range.
pipeline[WavelengthBands] = sc.concat([bands, full_wavelength_range], dim=band_dim)

results = []

# Compute checkpoints to avoid recomputing the same things in every iteration
checkpoints = (
TransmissionFraction[SampleRun],
TransmissionFraction[BackgroundRun],
SolidAngle[SampleRun],
SolidAngle[BackgroundRun],
CleanMonitor[SampleRun, Incident],
CleanMonitor[BackgroundRun, Incident],
CleanSummedQ[SampleRun, Numerator],
CleanSummedQ[BackgroundRun, Numerator],
)

for key, result in pipeline.compute(checkpoints).items():
pipeline[key] = result

for it in range(niter):
print("Iteration", it)

# The first time we compute I(Q), the direct beam function is not in the
# parameters, nor given by any providers, so it will be considered flat.
# TODO: Should we have a check that DirectBeam cannot be computed from the
# pipeline?
iofq_full = pipeline_full.compute(BackgroundSubtractedIofQ)
iofq_bands = pipeline_bands.compute(BackgroundSubtractedIofQ)
iofq = pipeline.compute(BackgroundSubtractedIofQ)
iofq_full = iofq['band', -1]
iofq_bands = iofq['band', :-1]

if direct_beam_function is None:
# Make a flat direct beam
dims = [dim for dim in iofq_bands.dims if dim != 'Q']
direct_beam_function = sc.DataArray(
data=sc.ones(sizes={dim: iofq_bands.sizes[dim] for dim in dims}),
coords={
band_dim: sc.midpoints(wavelength_bands, dim='wavelength').squeeze()
},
coords={band_dim: sc.midpoints(bands, dim='wavelength').squeeze()},
).rename({band_dim: 'wavelength'})

direct_beam_function *= _compute_efficiency_correction(
Expand All @@ -125,9 +151,8 @@ def direct_beam(pipeline: Pipeline, I0: sc.Variable, niter: int = 5) -> List[dic
I0=I0,
)

# Insert new direct beam function into pipelines
pipeline_bands[DirectBeam] = direct_beam_function
pipeline_full[DirectBeam] = direct_beam_function
# Insert new direct beam function into pipeline
pipeline[DirectBeam] = direct_beam_function

results.append(
{
Expand Down
35 changes: 33 additions & 2 deletions src/esssans/i_of_q.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@
# Copyright (c) 2023 Scipp contributors (https://github.com/scipp)

from typing import Dict, List, Optional, Union
from uuid import uuid4

import scipp as sc
from scipp.core.concepts import irreducible_mask
from scipp.scipy.interpolate import interp1d

from .common import mask_range
Expand Down Expand Up @@ -276,9 +278,38 @@ def _dense_merge_spectra(
edges = _to_q_bins(q_bins)
bands = []
band_dim = (set(wavelength_bands.dims) - {'wavelength'}).pop()

# We want to flatten data to make histogramming cheaper (avoiding allocation of
# large output before summing). We strip unnecessary content since it makes
# flattening more expensive.
stripped = data_q.copy(deep=False)
for name, coord in data_q.coords.items():
if name not in ['Q', 'wavelength'] and any(
[dim in dims_to_reduce for dim in coord.dims]
):
del stripped.coords[name]
to_flatten = [dim for dim in data_q.dims if dim in dims_to_reduce]

dummy_dim = str(uuid4())
flat = stripped.flatten(dims=to_flatten, to=dummy_dim)

# Apply masks once, to avoid repeated work when iterating over bands
mask = irreducible_mask(flat, dummy_dim)
# When not all dims are reduced there may be extra dims in the mask and it is not
# possible to select data based on it. In this case the masks will be applied
# in the loop below, which is slightly slower.
if mask.ndim == 1:
flat = flat.drop_masks(
[name for name, mask in flat.masks.items() if dummy_dim in mask.dims]
)
flat = flat[~mask]

dims_to_reduce = tuple(dim for dim in dims_to_reduce if dim not in to_flatten)
for wav_range in sc.collapse(wavelength_bands, keep='wavelength').values():
band = data_q['wavelength', wav_range[0] : wav_range[1]]
bands.append(band.hist(**edges).sum(dims_to_reduce))
band = flat['wavelength', wav_range[0] : wav_range[1]]
# By flattening before histogramming we avoid allocating a large output array,
# which would then require summing over all pixels.
bands.append(band.flatten(dims=(dummy_dim, 'Q'), to='Q').hist(**edges))
return sc.concat(bands, band_dim)


Expand Down
5 changes: 5 additions & 0 deletions src/esssans/normalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,11 @@ def iofq_norm_wavelength_term(
"""
out = incident_monitor * transmission_fraction
if direct_beam is not None:
# Make wavelength the inner dim
dims = list(direct_beam.dims)
dims.remove('wavelength')
dims.append('wavelength')
direct_beam = direct_beam.transpose(dims)
broadcast = _broadcasters[uncertainties]
out = direct_beam * broadcast(out, sizes=direct_beam.sizes)
# Convert wavelength coordinate to midpoints for future histogramming
Expand Down
Loading