Skip to content

Commit

Permalink
feat[next]: use context vars instead of global state in embedded iter…
Browse files Browse the repository at this point in the history
…ator execution (#1120)

Use context vars in the iterator embedded execution for `column_range` and `offset_provider` variables to isolate fencil execution contexts. Previously global variables were used, which could fail when executing several fencils simultaneously.

---------

Co-authored-by: Enrique G. Paredes <enriqueg@cscs.ch>
  • Loading branch information
havogt and egparedes committed May 8, 2023
1 parent ba907d9 commit 9241243
Show file tree
Hide file tree
Showing 2 changed files with 165 additions and 97 deletions.
117 changes: 66 additions & 51 deletions src/gt4py/next/iterator/embedded.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from __future__ import annotations

import abc
import contextvars as cvars
import copy
import dataclasses
import itertools
Expand Down Expand Up @@ -182,9 +183,10 @@ def field_setitem(self, indices: FieldIndexOrIndices, value: Any) -> None:
...


# TODO see https://github.com/GridTools/gt4py/pull/1120
_column_range: Optional[range] = None
_offset_provider: Optional[OffsetProvider] = None
#: Column range used in column mode (`column_axis != None`) in the current closure execution context.
column_range_cvar: cvars.ContextVar[range] = cvars.ContextVar("column_range")
#: Offset provider dict in the current closure execution context.
offset_provider_cvar: cvars.ContextVar[OffsetProvider] = cvars.ContextVar("offset_provider")


class Column(np.lib.mixins.NDArrayOperatorsMixin):
Expand All @@ -197,7 +199,8 @@ class Column(np.lib.mixins.NDArrayOperatorsMixin):
def __init__(self, kstart: int, data: np.ndarray | Scalar) -> None:
self.kstart = kstart
assert isinstance(data, (np.ndarray, Scalar)) # type: ignore # mypy bug
self.data = data if isinstance(data, np.ndarray) else np.full(len(_column_range), data) # type: ignore[arg-type]
column_range = column_range_cvar.get()
self.data = data if isinstance(data, np.ndarray) else np.full(len(column_range), data)

def __getitem__(self, i: int) -> Any:
result = self.data[i - self.kstart]
Expand Down Expand Up @@ -720,22 +723,23 @@ def _make_tuple(
*,
column_axis: Optional[Tag] = None,
) -> Column | npt.DTypeLike | tuple[tuple | Column | npt.DTypeLike, ...]:
column_range = column_range_cvar.get()
if isinstance(field_or_tuple, tuple):
if column_axis is not None:
assert _column_range
assert column_range
# construct a Column of tuples
column_axis_idx = _axis_idx(_get_axes(field_or_tuple), column_axis)
if column_axis_idx is None:
column_axis_idx = -1 # field doesn't have the column index, e.g. ContantField
first = tuple(
_make_tuple(f, _single_vertical_idx(indices, column_axis_idx, _column_range.start))
_make_tuple(f, _single_vertical_idx(indices, column_axis_idx, column_range.start))
for f in field_or_tuple
)
col = Column(
_column_range.start, np.zeros(len(_column_range), dtype=_column_dtype(first))
column_range.start, np.zeros(len(column_range), dtype=_column_dtype(first))
)
col[0] = first
for i in _column_range[1:]:
for i in column_range[1:]:
col[i] = tuple(
_make_tuple(f, _single_vertical_idx(indices, column_axis_idx, i))
for f in field_or_tuple
Expand All @@ -747,8 +751,8 @@ def _make_tuple(
data = field_or_tuple.field_getitem(indices)
if column_axis is not None:
# wraps a vertical slice of an input field into a `Column`
assert _column_range is not None
return Column(_column_range.start, data)
assert column_range is not None
return Column(column_range.start, data)
else:
return data

Expand All @@ -768,10 +772,11 @@ class MDIterator:

def shift(self, *offsets: OffsetPart) -> MDIterator:
complete_offsets = group_offsets(*offsets)
assert _offset_provider is not None
offset_provider = offset_provider_cvar.get()
assert offset_provider is not None
return MDIterator(
self.field,
shift_position(self.pos, *complete_offsets, offset_provider=_offset_provider),
shift_position(self.pos, *complete_offsets, offset_provider=offset_provider),
column_axis=self.column_axis,
)

Expand All @@ -792,13 +797,14 @@ def deref(self) -> Any:
if not all(axis.value in shifted_pos.keys() for axis in axes if axis is not None):
raise IndexError("Iterator position doesn't point to valid location for its field.")
slice_column = dict[Tag, range]()
column_range = column_range_cvar.get()
if self.column_axis is not None:
assert _column_range is not None
assert column_range is not None
k_pos = shifted_pos.pop(self.column_axis)
assert isinstance(k_pos, int)
# the following range describes a range in the field
# (negative values are relative to the origin, not relative to the size)
slice_column[self.column_axis] = range(k_pos, k_pos + len(_column_range))
slice_column[self.column_axis] = range(k_pos, k_pos + len(column_range))

assert _is_concrete_position(shifted_pos)
ordered_indices = get_ordered_indices(
Expand Down Expand Up @@ -834,9 +840,10 @@ def make_in_iterator(
init = [None] * sparse_dimensions.count(sparse_dim)
new_pos[sparse_dim] = init # type: ignore[assignment] # looks like mypy is confused
if column_axis is not None:
column_range = column_range_cvar.get()
# if we deal with column stencil the column position is just an offset by which the whole column needs to be shifted
assert _column_range is not None
new_pos[column_axis] = _column_range.start
assert column_range is not None
new_pos[column_axis] = column_range.start
it = MDIterator(
inp,
new_pos,
Expand Down Expand Up @@ -1083,8 +1090,9 @@ def __getitem__(self, _):
def neighbors(offset: runtime.Offset, it: ItIterator) -> _List:
offset_str = offset.value if isinstance(offset, runtime.Offset) else offset
assert isinstance(offset_str, str)
assert _offset_provider is not None
connectivity = _offset_provider[offset_str]
offset_provider = offset_provider_cvar.get()
assert offset_provider is not None
connectivity = offset_provider[offset_str]
assert isinstance(connectivity, common.Connectivity)
return _List(
shifted.deref()
Expand Down Expand Up @@ -1141,8 +1149,9 @@ class SparseListIterator:
offsets: Sequence[OffsetPart] = dataclasses.field(default_factory=list, kw_only=True)

def deref(self) -> Any:
assert _offset_provider is not None
connectivity = _offset_provider[self.list_offset]
offset_provider = offset_provider_cvar.get()
assert offset_provider is not None
connectivity = offset_provider[self.list_offset]
assert isinstance(connectivity, common.Connectivity)
return _List(
shifted.deref()
Expand Down Expand Up @@ -1290,13 +1299,14 @@ def _column_dtype(elem: Any) -> np.dtype:
@builtins.scan.register(EMBEDDED)
def scan(scan_pass, is_forward: bool, init):
def impl(*iters: ItIterator):
if _column_range is None:
column_range = column_range_cvar.get()
if column_range is None:
raise RuntimeError("Column range is not defined, cannot scan.")

column_range = _column_range if is_forward else reversed(_column_range)
sorted_column_range = column_range if is_forward else reversed(column_range)
state = init
col = Column(_column_range.start, np.zeros(len(_column_range), dtype=_column_dtype(init)))
for i in column_range:
col = Column(column_range.start, np.zeros(len(column_range), dtype=_column_dtype(init)))
for i in sorted_column_range:
state = scan_pass(state, *map(shifted_scan_arg(i), iters))
col[i] = state

Expand All @@ -1321,8 +1331,7 @@ def fendef_embedded(fun: Callable[..., None], *args: Any, **kwargs: Any):
if "offset_provider" not in kwargs:
raise RuntimeError("offset_provider not provided")

global _offset_provider
_offset_provider = kwargs["offset_provider"]
offset_provider = kwargs["offset_provider"]

@runtime.closure.register(EMBEDDED)
def closure(
Expand All @@ -1336,42 +1345,48 @@ def closure(
if not (is_located_field(out) or can_be_tuple_field(out)):
raise TypeError("Out needs to be a located field.")

global _column_range
column_range = None
column: Optional[ColumnDescriptor] = None
if kwargs.get("column_axis") and kwargs["column_axis"].value in domain:
column_axis = kwargs["column_axis"]
column = ColumnDescriptor(column_axis.value, domain[column_axis.value])
del domain[column_axis.value]

_column_range = column.col_range
column_range = column.col_range

out = as_tuple_field(out) if can_be_tuple_field(out) else out

for pos in _domain_iterator(domain):
promoted_ins = [promote_scalars(inp) for inp in ins]
ins_iters = list(
make_in_iterator(
inp,
pos,
column_axis=column.axis if column else None,
def _closure_runner():
# Set context variables before executing the closure
column_range_cvar.set(column_range)
offset_provider_cvar.set(offset_provider)

for pos in _domain_iterator(domain):
promoted_ins = [promote_scalars(inp) for inp in ins]
ins_iters = list(
make_in_iterator(
inp,
pos,
column_axis=column.axis if column else None,
)
for inp in promoted_ins
)
for inp in promoted_ins
)
res = sten(*ins_iters)
res = sten(*ins_iters)

if column is None:
assert _is_concrete_position(pos)
ordered_indices = get_ordered_indices(out.axes, pos)
out.field_setitem(ordered_indices, res)
else:
col_pos = pos.copy()
for k in column.col_range:
col_pos[column.axis] = k
assert _is_concrete_position(col_pos)
ordered_indices = get_ordered_indices(out.axes, col_pos)
out.field_setitem(ordered_indices, res[k])

_column_range = None
if column is None:
assert _is_concrete_position(pos)
ordered_indices = get_ordered_indices(out.axes, pos)
out.field_setitem(ordered_indices, res)
else:
col_pos = pos.copy()
for k in column.col_range:
col_pos[column.axis] = k
assert _is_concrete_position(col_pos)
ordered_indices = get_ordered_indices(out.axes, col_pos)
out.field_setitem(ordered_indices, res[k])

ctx = cvars.copy_context()
ctx.run(_closure_runner)

fun(*args)

Expand Down
Loading

0 comments on commit 9241243

Please sign in to comment.