Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feature[next] enable field origin in GTFN backend #1277

Merged
merged 8 commits into from
Jul 5, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions src/gt4py/next/iterator/embedded.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,10 @@ def axes(self) -> tuple[common.Dimension, ...]:
def field_getitem(self, indices: FieldIndexOrIndices) -> Any:
...

@property
def __gt_origin__(self) -> tuple[int, ...]:
return tuple([0] * len(self.axes))


class MutableLocatedField(LocatedField, Protocol):
"""A LocatedField with write access."""
Expand Down Expand Up @@ -887,12 +891,14 @@ def __init__(
*,
setter: Callable[[FieldIndexOrIndices, Any], None],
array: Callable[[], npt.NDArray],
origin: Optional[dict[common.Dimension, int]] = None,
):
self.getter = getter
self._axes = axes
self.setter = setter
self.array = array
self.dtype = dtype
self.origin = origin

def __getitem__(self, indices: ArrayIndexOrIndices) -> Any:
return self.array()[indices]
Expand All @@ -911,6 +917,14 @@ def field_setitem(self, indices: FieldIndexOrIndices, value: Any):
def __array__(self) -> np.ndarray:
return self.array()

@property
def __gt_origin__(self) -> tuple[int, ...]:
if not self.origin:
return tuple([0] * len(self.axes))
return cast(
tuple[int], get_ordered_indices(self.axes, {k.value: v for k, v in self.origin.items()})
)

@property
def shape(self):
if self.array is None:
Expand Down Expand Up @@ -1027,6 +1041,7 @@ def getter(indices):
dtype=a.dtype,
setter=setter,
array=a.__array__,
origin=origin,
)

return _maker
Expand Down
26 changes: 14 additions & 12 deletions src/gt4py/next/otf/binding/pybind.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,10 @@ def _type_string(type_: ts.TypeSpec) -> str:
if isinstance(type_, ts.TupleType):
return f"std::tuple<{','.join(_type_string(t) for t in type_.types)}>"
elif isinstance(type_, ts.FieldType):
return "pybind11::buffer"
ndims = len(type_.dims)
buffer_t = "pybind11::buffer"
origin_t = f"std::tuple<{', '.join(['ptrdiff_t'] * ndims)}>"
return f"std::pair<{buffer_t}, {origin_t}>"
elif isinstance(type_, ts.ScalarType):
return cpp_interface.render_scalar_type(type_)
else:
Expand Down Expand Up @@ -142,17 +145,16 @@ def visit_FunctionCall(self, call: FunctionCall):
return cpp_interface.render_function_call(call.target, args)

def visit_BufferSID(self, sid: BufferSID, **kwargs):
return self.generic_visit(
sid, rendered_scalar_type=cpp_interface.render_scalar_type(sid.scalar_type)
)

BufferSID = as_jinja(
"""gridtools::sid::rename_numbered_dimensions<{{", ".join(dimensions)}}>(
gridtools::as_sid<{{rendered_scalar_type}},\
{{dimensions.__len__()}},\
gridtools::sid::unknown_kind>({{source_buffer}})
)"""
)
pybuffer = f"{sid.source_buffer}.first"
dims = [self.visit(dim) for dim in sid.dimensions]
origin = f"{sid.source_buffer}.second"

as_sid = f"gridtools::as_sid<{cpp_interface.render_scalar_type(sid.scalar_type)},\
{sid.dimensions.__len__()},\
gridtools::sid::unknown_kind>({pybuffer})"
shifted = f"gridtools::sid::shift_sid_origin({as_sid}, {origin})"
renamed = f"gridtools::sid::rename_numbered_dimensions<{', '.join(dims)}>({shifted})"
return renamed

def visit_CompositeSID(self, node: CompositeSID, **kwargs):
kwargs["composite_ids"] = (
Expand Down
18 changes: 11 additions & 7 deletions src/gt4py/next/program_processors/runners/gtfn_cpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,10 @@
def convert_arg(arg: Any) -> Any:
if isinstance(arg, tuple):
return tuple(convert_arg(a) for a in arg)
if hasattr(arg, "__array__"):
return np.asarray(arg)
if hasattr(arg, "__array__") and hasattr(arg, "axes"):
arr = np.asarray(arg)
petiaccja marked this conversation as resolved.
Show resolved Hide resolved
origin = getattr(arg, "__gt_origin__", tuple([0] * arr.ndim))
return arr, origin
else:
return arg

Expand All @@ -42,26 +44,28 @@ def convert_args(inp: stages.CompiledProgram) -> stages.CompiledProgram:
def decorated_program(
*args, offset_provider: dict[str, common.Connectivity | common.Dimension]
):
converted_args = [convert_arg(arg) for arg in args]
conn_args = extract_connectivity_args(offset_provider)
return inp(
*[convert_arg(arg) for arg in args],
*extract_connectivity_args(offset_provider),
*converted_args,
*conn_args,
)

return decorated_program


def extract_connectivity_args(
offset_provider: dict[str, common.Connectivity | common.Dimension]
) -> list[npt.NDArray]:
) -> list[tuple[npt.NDArray, tuple[int, ...]]]:
# note: the order here needs to agree with the order of the generated bindings
args: list[npt.NDArray] = []
args: list[tuple[npt.NDArray, tuple[int, ...]]] = []
for name, conn in offset_provider.items():
if isinstance(conn, common.Connectivity):
if not isinstance(conn, common.NeighborTable):
raise NotImplementedError(
"Only `NeighborTable` connectivities implemented at this point."
)
args.append(conn.table)
args.append((conn.table, tuple([0] * 2)))
elif isinstance(conn, common.Dimension):
pass
else:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,6 @@ def baz(baz_inp):
def test_trivial(program_processor, lift_mode):
program_processor, validate = program_processor

if program_processor == run_gtfn:
pytest.xfail("origin not yet supported in gtfn")

rng = np.random.default_rng()
inp = rng.uniform(size=(5, 7, 9))
out = np.copy(inp)
Expand Down Expand Up @@ -80,9 +77,6 @@ def stencil_shifted_arg_to_lift(inp):
def test_shifted_arg_to_lift(program_processor, lift_mode):
program_processor, validate = program_processor

if program_processor == run_gtfn:
pytest.xfail("origin not yet supported in gtfn")

if lift_mode != transforms.LiftMode.FORCE_INLINE:
pytest.xfail("shifted input arguments not supported for lift_mode != LiftMode.FORCE_INLINE")

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,11 @@ def test_gtfn_cpp_with_cmake(program_source_with_name):
),
)
compiled_program = build_the_program(example_program_source)
buf = np.zeros(shape=(6, 5), dtype=np.float32)
tup = [np.zeros(shape=(6, 5), dtype=np.float32), np.zeros(shape=(6, 5), dtype=np.float32)]
buf = (np.zeros(shape=(6, 5), dtype=np.float32), (0, 0))
tup = [
(np.zeros(shape=(6, 5), dtype=np.float32), (0, 0)),
(np.zeros(shape=(6, 5), dtype=np.float32), (0, 0)),
]
sc = np.float32(3.1415926)
res = compiled_program(buf, tup, sc)
assert math.isclose(res, 6 * 5 * 3.1415926, rel_tol=1e-4)
Expand All @@ -50,8 +53,11 @@ def test_gtfn_cpp_with_compiledb(program_source_with_name):
),
)
compiled_program = build_the_program(example_program_source)
buf = np.zeros(shape=(6, 5), dtype=np.float32)
tup = [np.zeros(shape=(6, 5), dtype=np.float32), np.zeros(shape=(6, 5), dtype=np.float32)]
buf = (np.zeros(shape=(6, 5), dtype=np.float32), (0, 0))
tup = [
(np.zeros(shape=(6, 5), dtype=np.float32), (0, 0)),
(np.zeros(shape=(6, 5), dtype=np.float32), (0, 0)),
]
sc = np.float32(3.1415926)
res = compiled_program(buf, tup, sc)
assert math.isclose(res, 6 * 5 * 3.1415926, rel_tol=1e-4)
Original file line number Diff line number Diff line change
Expand Up @@ -53,28 +53,23 @@ def shift_stencil(inp):

@pytest.fixture(
params=[
# (stencil, reference_function, inp_fun (None=default), (skip_backend_fun, msg))
(add_scalar, lambda inp: np.asarray(inp) + 1.0, None, None),
(if_scalar_cond, lambda inp: np.asarray(inp), None, None),
(if_scalar_return, lambda inp: np.ones_like(inp), None, None),
# (stencil, reference_function, inp_fun (None=default)
(add_scalar, lambda inp: np.asarray(inp) + 1.0, None),
(if_scalar_cond, lambda inp: np.asarray(inp), None),
(if_scalar_return, lambda inp: np.ones_like(inp), None),
(
shift_stencil,
lambda inp: np.asarray(inp)[1:, 1:],
lambda shape: gtx.np_as_located_field(IDim, KDim)(
np.fromfunction(lambda i, k: i * 10 + k, [shape[0] + 1, shape[1] + 1])
),
None,
),
(
shift_stencil,
lambda inp: np.asarray(inp)[1:, 2:],
lambda shape: gtx.np_as_located_field(IDim, KDim, origin={IDim: 0, KDim: 1})(
np.fromfunction(lambda i, k: i * 10 + k, [shape[0] + 1, shape[1] + 2])
),
(
lambda backend: backend == run_gtfn or backend == run_gtfn_imperative,
"origin not supported in gtfn",
),
),
],
ids=lambda p: f"{p[0].__name__}",
Expand All @@ -85,11 +80,7 @@ def basic_stencils(request):

def test_basic_column_stencils(program_processor, lift_mode, basic_stencils):
program_processor, validate = program_processor
stencil, ref_fun, inp_fun, skip_backend = basic_stencils
if skip_backend is not None:
skip_backend_fun, msg = skip_backend
if skip_backend_fun(program_processor):
pytest.xfail(msg)
stencil, ref_fun, inp_fun = basic_stencils

shape = [5, 7]
inp = (
Expand Down Expand Up @@ -327,8 +318,6 @@ def sum_fencil(out, inp0, inp1, k_size):

def test_different_vertical_sizes_with_origin(program_processor):
program_processor, validate = program_processor
if program_processor in [run_gtfn, run_gtfn_imperative]:
pytest.xfail("origin not supported in gtfn")

k_size = 10
inp0 = gtx.np_as_located_field(KDim)(np.asarray(list(range(k_size))))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,10 @@ def hdiff(inp, coeff, out, x, y):
def test_hdiff(hdiff_reference, program_processor, lift_mode):
program_processor, validate = program_processor
if program_processor == run_gtfn or program_processor == run_gtfn_imperative:
pytest.xfail("origin not yet supported in gtfn")
from gt4py.next.iterator import transforms

if lift_mode != transforms.LiftMode.FORCE_INLINE:
pytest.xfail("there is an issue with temporaries that crashes the application")

inp, coeff, out = hdiff_reference
shape = (out.shape[0], out.shape[1])
Expand Down
38 changes: 23 additions & 15 deletions tests/next_tests/unit_tests/otf_tests/binding_tests/test_pybind.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def test_bindings(program_source_example):
program_source_example.language_settings,
"""\
#include "stencil.cpp.inc"

#include <gridtools/common/defs.hpp>
#include <gridtools/common/tuple_util.hpp>
#include <gridtools/fn/backend/naive.hpp>
Expand All @@ -37,28 +37,36 @@ def test_bindings(program_source_example):
#include <gridtools/storage/adapter/python_sid_adapter.hpp>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>

decltype(auto)
stencil_wrapper(pybind11::buffer buf,
std::tuple<pybind11::buffer, pybind11::buffer> tup, float sc) {

decltype(auto) stencil_wrapper(
std::pair<pybind11::buffer, std::tuple<ptrdiff_t, ptrdiff_t>> buf,
std::tuple<std::pair<pybind11::buffer, std::tuple<ptrdiff_t, ptrdiff_t>>,
std::pair<pybind11::buffer, std::tuple<ptrdiff_t, ptrdiff_t>>>
tup,
float sc) {
return stencil(
gridtools::sid::rename_numbered_dimensions<generated::I_t,
generated::J_t>(
gridtools::as_sid<float, 2, gridtools::sid::unknown_kind>(buf)),
gridtools::sid::rename_numbered_dimensions<
generated::I_t, generated::J_t>(gridtools::sid::shift_sid_origin(
gridtools::as_sid<float, 2, gridtools::sid::unknown_kind>(buf.first),
buf.second)),
gridtools::sid::composite::keys<gridtools::integral_constant<int, 0>,
gridtools::integral_constant<int, 1>>::
make_values(
gridtools::sid::rename_numbered_dimensions<generated::I_t,
generated::J_t>(
gridtools::as_sid<
float, 2, gridtools::sid::unknown_kind>(gridtools::tuple_util::get<0>(tup))),
generated::J_t>(
gridtools::sid::shift_sid_origin(
gridtools::as_sid<float, 2, gridtools::sid::unknown_kind>(
gridtools::tuple_util::get<0>(tup).first),
gridtools::tuple_util::get<0>(tup).second)),
gridtools::sid::rename_numbered_dimensions<generated::I_t,
generated::J_t>(
gridtools::as_sid<
float, 2, gridtools::sid::unknown_kind>(gridtools::tuple_util::get<1>(tup)))),
generated::J_t>(
gridtools::sid::shift_sid_origin(
gridtools::as_sid<float, 2, gridtools::sid::unknown_kind>(
gridtools::tuple_util::get<1>(tup).first),
gridtools::tuple_util::get<1>(tup).second))),
sc);
}

PYBIND11_MODULE(stencil, module) {
module.doc() = "";
module.def("stencil", &stencil_wrapper, "");
Expand Down