Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adhere to stricter handling of constraints on typevars by Sciline #75

Merged
merged 2 commits into from
Feb 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 12 additions & 11 deletions src/esssans/conversions.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
QxyBins,
RawMonitor,
RunType,
ScatteringRunType,
WavelengthMask,
WavelengthMonitor,
)
Expand Down Expand Up @@ -204,8 +205,8 @@ def monitor_to_wavelength(


def calibrate_positions(
detector: MaskedData[RunType], beam_center: BeamCenter
) -> CalibratedMaskedData[RunType]:
detector: MaskedData[ScatteringRunType], beam_center: BeamCenter
) -> CalibratedMaskedData[ScatteringRunType]:
"""
Calibrate pixel positions.

Expand All @@ -220,17 +221,17 @@ def calibrate_positions(
# for RawData, MaskedData, ... no reason to restrict necessarily.
# Would we be fine with just choosing on option, or will this get in the way for users?
def detector_to_wavelength(
detector: CalibratedMaskedData[RunType],
detector: CalibratedMaskedData[ScatteringRunType],
graph: ElasticCoordTransformGraph,
) -> CleanWavelength[RunType, Numerator]:
return CleanWavelength[RunType, Numerator](
) -> CleanWavelength[ScatteringRunType, Numerator]:
return CleanWavelength[ScatteringRunType, Numerator](
detector.transform_coords('wavelength', graph=graph)
)


def mask_wavelength(
da: CleanWavelength[RunType, IofQPart], mask: Optional[WavelengthMask]
) -> CleanWavelengthMasked[RunType, IofQPart]:
da: CleanWavelength[ScatteringRunType, IofQPart], mask: Optional[WavelengthMask]
) -> CleanWavelengthMasked[ScatteringRunType, IofQPart]:
if mask is not None:
# If we have binned data and the wavelength coord is multi-dimensional, we need
# to make a single wavelength bin before we can mask the range.
Expand All @@ -239,19 +240,19 @@ def mask_wavelength(
if (dim in da.bins.coords) and (dim in da.coords):
da = da.bin({dim: 1})
da = mask_range(da, mask=mask)
return CleanWavelengthMasked[RunType, IofQPart](da)
return CleanWavelengthMasked[ScatteringRunType, IofQPart](da)


def compute_Q(
data: CleanWavelengthMasked[RunType, IofQPart],
data: CleanWavelengthMasked[ScatteringRunType, IofQPart],
graph: ElasticCoordTransformGraph,
compute_Qxy: Optional[QxyBins],
) -> CleanQ[RunType, IofQPart]:
) -> CleanQ[ScatteringRunType, IofQPart]:
"""
Convert a data array from wavelength to Q.
"""
YooSunYoung marked this conversation as resolved.
Show resolved Hide resolved
# Keep naming of wavelength dim, subsequent steps use a (Q[xy], wavelength) binning.
return CleanQ[RunType, IofQPart](
return CleanQ[ScatteringRunType, IofQPart](
data.transform_coords(
('Qx', 'Qy') if compute_Qxy else 'Q',
graph=graph,
Expand Down
7 changes: 4 additions & 3 deletions src/esssans/i_of_q.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
ReturnEvents,
RunType,
SampleRun,
ScatteringRunType,
UncertaintyBroadcastMode,
WavelengthBins,
WavelengthMonitor,
Expand Down Expand Up @@ -132,11 +133,11 @@ def resample_direct_beam(


def merge_spectra(
data: CleanQ[RunType, IofQPart],
data: CleanQ[ScatteringRunType, IofQPart],
q_bins: Optional[QBins],
qxy_bins: Optional[QxyBins],
dims_to_keep: Optional[DimsToKeep],
) -> CleanSummedQ[RunType, IofQPart]:
) -> CleanSummedQ[ScatteringRunType, IofQPart]:
"""
Merges all spectra:

Expand Down Expand Up @@ -210,7 +211,7 @@ def merge_spectra(
.group(*[flat.coords[dim] for dim in flat.dims if dim != helper_dim])
.hist(**edges)
)
return CleanSummedQ[RunType, IofQPart](out.squeeze())
return CleanSummedQ[ScatteringRunType, IofQPart](out.squeeze())


def subtract_background(
Expand Down
10 changes: 5 additions & 5 deletions src/esssans/isis/components.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,11 @@
import sciline
import scipp as sc

from ..types import RawData, RunType
from ..types import RawData, ScatteringRunType


class RawDataWithComponentUserOffsets(
sciline.Scope[RunType, sc.DataArray], sc.DataArray
sciline.Scope[ScatteringRunType, sc.DataArray], sc.DataArray
):
"""Raw data with applied user configuration for component positions."""

Expand All @@ -19,10 +19,10 @@ class RawDataWithComponentUserOffsets(


def apply_component_user_offsets_to_raw_data(
data: RawData[RunType],
data: RawData[ScatteringRunType],
sample_offset: SampleOffset,
detector_bank_offset: DetectorBankOffset,
) -> RawDataWithComponentUserOffsets[RunType]:
) -> RawDataWithComponentUserOffsets[ScatteringRunType]:
"""Apply user configuration to raw data.

Parameters
Expand All @@ -41,4 +41,4 @@ def apply_component_user_offsets_to_raw_data(
)
pos = data.coords['position']
data.coords['position'] = pos + detector_bank_offset.to(unit=pos.unit, copy=False)
return RawDataWithComponentUserOffsets[RunType](data)
return RawDataWithComponentUserOffsets[ScatteringRunType](data)
8 changes: 4 additions & 4 deletions src/esssans/isis/masking.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import sciline
import scipp as sc

from ..types import MaskedData, RawData, RunType, SampleRun
from ..types import MaskedData, RawData, SampleRun, ScatteringRunType
from .components import RawDataWithComponentUserOffsets
from .io import MaskedDetectorIDs

Expand All @@ -30,9 +30,9 @@ def to_pixel_mask(data: RawData[SampleRun], masked: MaskedDetectorIDs) -> PixelM


def apply_pixel_masks(
data: RawDataWithComponentUserOffsets[RunType],
data: RawDataWithComponentUserOffsets[ScatteringRunType],
masks: sciline.Series[str, PixelMask],
) -> MaskedData[RunType]:
) -> MaskedData[ScatteringRunType]:
"""Apply pixel-specific masks to raw data.

This depends on the configured raw data (which has been configured with component
Expand All @@ -49,7 +49,7 @@ def apply_pixel_masks(
data = data.copy(deep=False)
for name, mask in masks.items():
data.masks[name] = mask
return MaskedData[RunType](data)
return MaskedData[ScatteringRunType](data)


providers = (
Expand Down
19 changes: 10 additions & 9 deletions src/esssans/loki/general.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
RawData,
RawMonitor,
RunType,
ScatteringRunType,
TransformationPath,
Transmission,
)
Expand All @@ -35,10 +36,10 @@


def get_detector_data(
dg: LoadedFileContents[RunType], detector_name: NeXusDetectorName
) -> RawData[RunType]:
dg: LoadedFileContents[ScatteringRunType], detector_name: NeXusDetectorName
) -> RawData[ScatteringRunType]:
da = dg[NEXUS_INSTRUMENT_PATH][detector_name][f'{detector_name}_events']
return RawData[RunType](da)
return RawData[ScatteringRunType](da)


def get_monitor_data(
Expand All @@ -51,19 +52,19 @@ def get_monitor_data(


def detector_pixel_shape(
dg: LoadedFileContents[RunType], detector_name: NeXusDetectorName
) -> DetectorPixelShape[RunType]:
return DetectorPixelShape[RunType](
dg: LoadedFileContents[ScatteringRunType], detector_name: NeXusDetectorName
) -> DetectorPixelShape[ScatteringRunType]:
return DetectorPixelShape[ScatteringRunType](
dg[NEXUS_INSTRUMENT_PATH][detector_name]['pixel_shape']
)


def detector_lab_frame_transform(
dg: LoadedFileContents[RunType],
dg: LoadedFileContents[ScatteringRunType],
detector_name: NeXusDetectorName,
transform_path: TransformationPath,
) -> LabFrameTransform[RunType]:
return LabFrameTransform[RunType](
) -> LabFrameTransform[ScatteringRunType]:
return LabFrameTransform[ScatteringRunType](
dg[NEXUS_INSTRUMENT_PATH][detector_name][transform_path]
)

Expand Down
8 changes: 4 additions & 4 deletions src/esssans/loki/masking.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@
BeamStopRadius,
MaskedData,
RawData,
RunType,
SampleRun,
ScatteringRunType,
)

DetectorLowCountsStrawMask = NewType('DetectorLowCountsStrawMask', sc.Variable)
Expand Down Expand Up @@ -80,11 +80,11 @@ def detector_tube_edge_mask(


def mask_detectors(
da: RawData[RunType],
da: RawData[ScatteringRunType],
lowcounts_straw_mask: Optional[DetectorLowCountsStrawMask],
beam_stop_mask: Optional[DetectorBeamStopMask],
tube_edge_mask: Optional[DetectorTubeEdgeMask],
) -> MaskedData[RunType]:
) -> MaskedData[ScatteringRunType]:
"""Apply pixel-specific masks to raw data.

Parameters
Expand All @@ -105,7 +105,7 @@ def mask_detectors(
da.masks['beam_stop'] = beam_stop_mask
if tube_edge_mask is not None:
da.masks['tube_edges'] = tube_edge_mask
return MaskedData[RunType](da)
return MaskedData[ScatteringRunType](da)


providers = (
Expand Down
46 changes: 24 additions & 22 deletions src/esssans/normalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
Numerator,
ProcessedWavelengthBands,
ReturnEvents,
RunType,
ScatteringRunType,
SolidAngle,
Transmission,
TransmissionFraction,
Expand All @@ -39,10 +39,10 @@


def solid_angle(
data: CalibratedMaskedData[RunType],
pixel_shape: DetectorPixelShape[RunType],
transform: LabFrameTransform[RunType],
) -> SolidAngle[RunType]:
data: CalibratedMaskedData[ScatteringRunType],
pixel_shape: DetectorPixelShape[ScatteringRunType],
transform: LabFrameTransform[ScatteringRunType],
) -> SolidAngle[ScatteringRunType]:
"""
Solid angle for cylindrical pixels.

Expand Down Expand Up @@ -80,7 +80,7 @@ def solid_angle(
radius=radius,
length=length,
)
return SolidAngle[RunType](
return SolidAngle[ScatteringRunType](
concepts.rewrap_reduced_data(
prototype=data, data=omega, dim=set(data.dims) - set(omega.dims)
)
Expand Down Expand Up @@ -127,11 +127,13 @@ def _approximate_solid_angle_for_cylinder_shaped_pixel_of_detector(


def transmission_fraction(
sample_incident_monitor: CleanMonitor[TransmissionRun[RunType], Incident],
sample_transmission_monitor: CleanMonitor[TransmissionRun[RunType], Transmission],
sample_incident_monitor: CleanMonitor[TransmissionRun[ScatteringRunType], Incident],
sample_transmission_monitor: CleanMonitor[
TransmissionRun[ScatteringRunType], Transmission
],
direct_incident_monitor: CleanMonitor[EmptyBeamRun, Incident],
direct_transmission_monitor: CleanMonitor[EmptyBeamRun, Transmission],
) -> TransmissionFraction[RunType]:
) -> TransmissionFraction[ScatteringRunType]:
"""
Approximation based on equations in
`CalculateTransmission <https://docs.mantidproject.org/v4.0.0/algorithms/CalculateTransmission-v1.html>`_
Expand Down Expand Up @@ -160,7 +162,7 @@ def transmission_fraction(
frac = (sample_transmission_monitor / direct_transmission_monitor) * (
direct_incident_monitor / sample_incident_monitor
)
return TransmissionFraction[RunType](frac)
return TransmissionFraction[ScatteringRunType](frac)


_broadcasters = {
Expand All @@ -171,11 +173,11 @@ def transmission_fraction(


def iofq_norm_wavelength_term(
incident_monitor: CleanMonitor[RunType, Incident],
transmission_fraction: TransmissionFraction[RunType],
incident_monitor: CleanMonitor[ScatteringRunType, Incident],
transmission_fraction: TransmissionFraction[ScatteringRunType],
direct_beam: Optional[CleanDirectBeam],
uncertainties: UncertaintyBroadcastMode,
) -> NormWavelengthTerm[RunType]:
) -> NormWavelengthTerm[ScatteringRunType]:
"""
Compute the wavelength-dependent contribution to the denominator term for the I(Q)
normalization.
Expand Down Expand Up @@ -224,14 +226,14 @@ def iofq_norm_wavelength_term(
out = direct_beam * broadcast(out, sizes=direct_beam.sizes)
# Convert wavelength coordinate to midpoints for future histogramming
out.coords['wavelength'] = sc.midpoints(out.coords['wavelength'])
return NormWavelengthTerm[RunType](out)
return NormWavelengthTerm[ScatteringRunType](out)


def iofq_denominator(
wavelength_term: NormWavelengthTerm[RunType],
solid_angle: SolidAngle[RunType],
wavelength_term: NormWavelengthTerm[ScatteringRunType],
solid_angle: SolidAngle[ScatteringRunType],
uncertainties: UncertaintyBroadcastMode,
) -> CleanWavelength[RunType, Denominator]:
) -> CleanWavelength[ScatteringRunType, Denominator]:
"""
Compute the denominator term for the I(Q) normalization.

Expand Down Expand Up @@ -297,7 +299,7 @@ def iofq_denominator(
""" # noqa: E501
broadcast = _broadcasters[uncertainties]
denominator = solid_angle * broadcast(wavelength_term, sizes=solid_angle.sizes)
return CleanWavelength[RunType, Denominator](denominator)
return CleanWavelength[ScatteringRunType, Denominator](denominator)


def process_wavelength_bands(
Expand Down Expand Up @@ -337,12 +339,12 @@ def process_wavelength_bands(


def normalize(
numerator: CleanSummedQ[RunType, Numerator],
denominator: CleanSummedQ[RunType, Denominator],
numerator: CleanSummedQ[ScatteringRunType, Numerator],
denominator: CleanSummedQ[ScatteringRunType, Denominator],
return_events: ReturnEvents,
uncertainties: UncertaintyBroadcastMode,
wavelength_bands: ProcessedWavelengthBands,
) -> IofQ[RunType]:
) -> IofQ[ScatteringRunType]:
"""
Perform normalization of counts as a function of Q.
If the numerator contains events, we use the sc.lookup function to perform the
Expand Down Expand Up @@ -406,7 +408,7 @@ def _reduce(da: sc.DataArray) -> sc.DataArray:
)
elif numerator.bins is not None:
numerator = numerator.hist()
return IofQ[RunType](numerator / denominator)
return IofQ[ScatteringRunType](numerator / denominator)


providers = (
Expand Down
Loading
Loading