Skip to content

Commit

Permalink
Update PR to current main
Browse files Browse the repository at this point in the history
  • Loading branch information
egparedes committed Apr 28, 2023
1 parent 0c08638 commit 7499384
Showing 1 changed file with 44 additions and 29 deletions.
73 changes: 44 additions & 29 deletions src/gt4py/next/iterator/embedded.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from __future__ import annotations

import abc
import contextvars as cvars
import copy
import dataclasses
import itertools
Expand Down Expand Up @@ -182,9 +183,10 @@ def field_setitem(self, indices: FieldIndexOrIndices, value: Any) -> None:
...


# TODO see https://github.com/GridTools/gt4py/pull/1120
_column_range: Optional[range] = None
_offset_provider: Optional[OffsetProvider] = None
context_column_range: cvars.ContextVar[range] = cvars.ContextVar("context_column_range")
context_offset_provider: cvars.ContextVar[OffsetProvider] = cvars.ContextVar(
"context_offset_provider"
)


class Column(np.lib.mixins.NDArrayOperatorsMixin):
Expand All @@ -197,7 +199,8 @@ class Column(np.lib.mixins.NDArrayOperatorsMixin):
def __init__(self, kstart: int, data: np.ndarray | Scalar) -> None:
self.kstart = kstart
assert isinstance(data, (np.ndarray, Scalar)) # type: ignore # mypy bug
self.data = data if isinstance(data, np.ndarray) else np.full(len(_column_range), data) # type: ignore[arg-type]
_column_range = context_column_range.get()
self.data = data if isinstance(data, np.ndarray) else np.full(len(_column_range), data)

def __getitem__(self, i: int) -> Any:
result = self.data[i - self.kstart]
Expand Down Expand Up @@ -720,6 +723,7 @@ def _make_tuple(
*,
column_axis: Optional[Tag] = None,
) -> Column | npt.DTypeLike | tuple[tuple | Column | npt.DTypeLike, ...]:
_column_range = context_column_range.get()
if isinstance(field_or_tuple, tuple):
if column_axis is not None:
assert _column_range
Expand Down Expand Up @@ -768,6 +772,7 @@ class MDIterator:

def shift(self, *offsets: OffsetPart) -> MDIterator:
complete_offsets = group_offsets(*offsets)
_offset_provider = context_offset_provider.get()
assert _offset_provider is not None
return MDIterator(
self.field,
Expand All @@ -792,6 +797,7 @@ def deref(self) -> Any:
if not all(axis.value in shifted_pos.keys() for axis in axes if axis is not None):
raise IndexError("Iterator position doesn't point to valid location for its field.")
slice_column = dict[Tag, range]()
_column_range = context_column_range.get()
if self.column_axis is not None:
assert _column_range is not None
k_pos = shifted_pos.pop(self.column_axis)
Expand Down Expand Up @@ -834,6 +840,7 @@ def make_in_iterator(
init = [None] * sparse_dimensions.count(sparse_dim)
new_pos[sparse_dim] = init # type: ignore[assignment] # looks like mypy is confused
if column_axis is not None:
_column_range = context_column_range.get()
# if we deal with column stencil the column position is just an offset by which the whole column needs to be shifted
assert _column_range is not None
new_pos[column_axis] = _column_range.start
Expand Down Expand Up @@ -1083,6 +1090,7 @@ def __getitem__(self, _):
def neighbors(offset: runtime.Offset, it: ItIterator) -> _List:
offset_str = offset.value if isinstance(offset, runtime.Offset) else offset
assert isinstance(offset_str, str)
_offset_provider = context_offset_provider.get()
assert _offset_provider is not None
connectivity = _offset_provider[offset_str]
assert isinstance(connectivity, common.Connectivity)
Expand Down Expand Up @@ -1141,6 +1149,7 @@ class SparseListIterator:
offsets: Sequence[OffsetPart] = dataclasses.field(default_factory=list, kw_only=True)

def deref(self) -> Any:
_offset_provider = context_offset_provider.get()
assert _offset_provider is not None
connectivity = _offset_provider[self.list_offset]
assert isinstance(connectivity, common.Connectivity)
Expand Down Expand Up @@ -1290,6 +1299,7 @@ def _column_dtype(elem: Any) -> np.dtype:
@builtins.scan.register(EMBEDDED)
def scan(scan_pass, is_forward: bool, init):
def impl(*iters: ItIterator):
_column_range = context_column_range.get()
if _column_range is None:
raise RuntimeError("Column range is not defined, cannot scan.")

Expand Down Expand Up @@ -1321,7 +1331,6 @@ def fendef_embedded(fun: Callable[..., None], *args: Any, **kwargs: Any):
if "offset_provider" not in kwargs:
raise RuntimeError("offset_provider not provided")

global _offset_provider
_offset_provider = kwargs["offset_provider"]

@runtime.closure.register(EMBEDDED)
Expand All @@ -1336,7 +1345,7 @@ def closure(
if not (is_located_field(out) or can_be_tuple_field(out)):
raise TypeError("Out needs to be a located field.")

global _column_range
_column_range = None
column: Optional[ColumnDescriptor] = None
if kwargs.get("column_axis") and kwargs["column_axis"].value in domain:
column_axis = kwargs["column_axis"]
Expand All @@ -1347,31 +1356,37 @@ def closure(

out = as_tuple_field(out) if can_be_tuple_field(out) else out

for pos in _domain_iterator(domain):
promoted_ins = [promote_scalars(inp) for inp in ins]
ins_iters = list(
make_in_iterator(
inp,
pos,
column_axis=column.axis if column else None,
def _closure_runner():
# Set context variables before executing the closure
context_column_range.set(_column_range)
context_offset_provider.set(_offset_provider)

for pos in _domain_iterator(domain):
promoted_ins = [promote_scalars(inp) for inp in ins]
ins_iters = list(
make_in_iterator(
inp,
pos,
column_axis=column.axis if column else None,
)
for inp in promoted_ins
)
for inp in promoted_ins
)
res = sten(*ins_iters)
res = sten(*ins_iters)

if column is None:
assert _is_concrete_position(pos)
ordered_indices = get_ordered_indices(out.axes, pos)
out.field_setitem(ordered_indices, res)
else:
col_pos = pos.copy()
for k in column.col_range:
col_pos[column.axis] = k
assert _is_concrete_position(col_pos)
ordered_indices = get_ordered_indices(out.axes, col_pos)
out.field_setitem(ordered_indices, res[k])

_column_range = None
if column is None:
assert _is_concrete_position(pos)
ordered_indices = get_ordered_indices(out.axes, pos)
out.field_setitem(ordered_indices, res)
else:
col_pos = pos.copy()
for k in column.col_range:
col_pos[column.axis] = k
assert _is_concrete_position(col_pos)
ordered_indices = get_ordered_indices(out.axes, col_pos)
out.field_setitem(ordered_indices, res[k])

ctx = cvars.copy_context()
ctx.run(_closure_runner)

fun(*args)

Expand Down

0 comments on commit 7499384

Please sign in to comment.