diff --git a/src/gt4py/next/iterator/embedded.py b/src/gt4py/next/iterator/embedded.py index e3dd18a871..a8b011ce85 100644 --- a/src/gt4py/next/iterator/embedded.py +++ b/src/gt4py/next/iterator/embedded.py @@ -17,6 +17,7 @@ from __future__ import annotations import abc +import contextvars as cvars import copy import dataclasses import itertools @@ -182,9 +183,10 @@ def field_setitem(self, indices: FieldIndexOrIndices, value: Any) -> None: ... -# TODO see https://github.com/GridTools/gt4py/pull/1120 -_column_range: Optional[range] = None -_offset_provider: Optional[OffsetProvider] = None +context_column_range: cvars.ContextVar[range] = cvars.ContextVar("context_column_range") +context_offset_provider: cvars.ContextVar[OffsetProvider] = cvars.ContextVar( + "context_offset_provider" +) class Column(np.lib.mixins.NDArrayOperatorsMixin): @@ -197,7 +199,8 @@ class Column(np.lib.mixins.NDArrayOperatorsMixin): def __init__(self, kstart: int, data: np.ndarray | Scalar) -> None: self.kstart = kstart assert isinstance(data, (np.ndarray, Scalar)) # type: ignore # mypy bug - self.data = data if isinstance(data, np.ndarray) else np.full(len(_column_range), data) # type: ignore[arg-type] + _column_range = context_column_range.get() + self.data = data if isinstance(data, np.ndarray) else np.full(len(_column_range), data) def __getitem__(self, i: int) -> Any: result = self.data[i - self.kstart] @@ -720,6 +723,7 @@ def _make_tuple( *, column_axis: Optional[Tag] = None, ) -> Column | npt.DTypeLike | tuple[tuple | Column | npt.DTypeLike, ...]: + _column_range = context_column_range.get() if isinstance(field_or_tuple, tuple): if column_axis is not None: assert _column_range @@ -768,6 +772,7 @@ class MDIterator: def shift(self, *offsets: OffsetPart) -> MDIterator: complete_offsets = group_offsets(*offsets) + _offset_provider = context_offset_provider.get() assert _offset_provider is not None return MDIterator( self.field, @@ -792,6 +797,7 @@ def deref(self) -> Any: if not all(axis.value in shifted_pos.keys() for axis in axes if axis is not None): raise IndexError("Iterator position doesn't point to valid location for its field.") slice_column = dict[Tag, range]() + _column_range = context_column_range.get() if self.column_axis is not None: assert _column_range is not None k_pos = shifted_pos.pop(self.column_axis) @@ -834,6 +840,7 @@ def make_in_iterator( init = [None] * sparse_dimensions.count(sparse_dim) new_pos[sparse_dim] = init # type: ignore[assignment] # looks like mypy is confused if column_axis is not None: + _column_range = context_column_range.get() # if we deal with column stencil the column position is just an offset by which the whole column needs to be shifted assert _column_range is not None new_pos[column_axis] = _column_range.start @@ -1083,6 +1090,7 @@ def __getitem__(self, _): def neighbors(offset: runtime.Offset, it: ItIterator) -> _List: offset_str = offset.value if isinstance(offset, runtime.Offset) else offset assert isinstance(offset_str, str) + _offset_provider = context_offset_provider.get() assert _offset_provider is not None connectivity = _offset_provider[offset_str] assert isinstance(connectivity, common.Connectivity) @@ -1141,6 +1149,7 @@ class SparseListIterator: offsets: Sequence[OffsetPart] = dataclasses.field(default_factory=list, kw_only=True) def deref(self) -> Any: + _offset_provider = context_offset_provider.get() assert _offset_provider is not None connectivity = _offset_provider[self.list_offset] assert isinstance(connectivity, common.Connectivity) @@ -1290,6 +1299,7 @@ def _column_dtype(elem: Any) -> np.dtype: @builtins.scan.register(EMBEDDED) def scan(scan_pass, is_forward: bool, init): def impl(*iters: ItIterator): + _column_range = context_column_range.get() if _column_range is None: raise RuntimeError("Column range is not defined, cannot scan.") @@ -1321,7 +1331,6 @@ def fendef_embedded(fun: Callable[..., None], *args: Any, **kwargs: Any): if "offset_provider" not in kwargs: raise RuntimeError("offset_provider not provided") - global _offset_provider _offset_provider = kwargs["offset_provider"] @runtime.closure.register(EMBEDDED) @@ -1336,7 +1345,7 @@ def closure( if not (is_located_field(out) or can_be_tuple_field(out)): raise TypeError("Out needs to be a located field.") - global _column_range + _column_range = None column: Optional[ColumnDescriptor] = None if kwargs.get("column_axis") and kwargs["column_axis"].value in domain: column_axis = kwargs["column_axis"] @@ -1347,31 +1356,37 @@ def closure( out = as_tuple_field(out) if can_be_tuple_field(out) else out - for pos in _domain_iterator(domain): - promoted_ins = [promote_scalars(inp) for inp in ins] - ins_iters = list( - make_in_iterator( - inp, - pos, - column_axis=column.axis if column else None, + def _closure_runner(): + # Set context variables before executing the closure + context_column_range.set(_column_range) + context_offset_provider.set(_offset_provider) + + for pos in _domain_iterator(domain): + promoted_ins = [promote_scalars(inp) for inp in ins] + ins_iters = list( + make_in_iterator( + inp, + pos, + column_axis=column.axis if column else None, + ) + for inp in promoted_ins ) - for inp in promoted_ins - ) - res = sten(*ins_iters) + res = sten(*ins_iters) - if column is None: - assert _is_concrete_position(pos) - ordered_indices = get_ordered_indices(out.axes, pos) - out.field_setitem(ordered_indices, res) - else: - col_pos = pos.copy() - for k in column.col_range: - col_pos[column.axis] = k - assert _is_concrete_position(col_pos) - ordered_indices = get_ordered_indices(out.axes, col_pos) - out.field_setitem(ordered_indices, res[k]) - - _column_range = None + if column is None: + assert _is_concrete_position(pos) + ordered_indices = get_ordered_indices(out.axes, pos) + out.field_setitem(ordered_indices, res) + else: + col_pos = pos.copy() + for k in column.col_range: + col_pos[column.axis] = k + assert _is_concrete_position(col_pos) + ordered_indices = get_ordered_indices(out.axes, col_pos) + out.field_setitem(ordered_indices, res[k]) + + ctx = cvars.copy_context() + ctx.run(_closure_runner) fun(*args)