diff --git a/src/gt4py/next/iterator/embedded.py b/src/gt4py/next/iterator/embedded.py index e92cfb45c6..63bf11ddba 100644 --- a/src/gt4py/next/iterator/embedded.py +++ b/src/gt4py/next/iterator/embedded.py @@ -182,6 +182,10 @@ def axes(self) -> tuple[common.Dimension, ...]: def field_getitem(self, indices: FieldIndexOrIndices) -> Any: ... + @property + def __gt_origin__(self) -> tuple[int, ...]: + return tuple([0] * len(self.axes)) + class MutableLocatedField(LocatedField, Protocol): """A LocatedField with write access.""" @@ -887,12 +891,14 @@ def __init__( *, setter: Callable[[FieldIndexOrIndices, Any], None], array: Callable[[], npt.NDArray], + origin: Optional[dict[common.Dimension, int]] = None, ): self.getter = getter self._axes = axes self.setter = setter self.array = array self.dtype = dtype + self.origin = origin def __getitem__(self, indices: ArrayIndexOrIndices) -> Any: return self.array()[indices] @@ -911,6 +917,14 @@ def field_setitem(self, indices: FieldIndexOrIndices, value: Any): def __array__(self) -> np.ndarray: return self.array() + @property + def __gt_origin__(self) -> tuple[int, ...]: + if not self.origin: + return tuple([0] * len(self.axes)) + return cast( + tuple[int], get_ordered_indices(self.axes, {k.value: v for k, v in self.origin.items()}) + ) + @property def shape(self): if self.array is None: @@ -1027,6 +1041,7 @@ def getter(indices): dtype=a.dtype, setter=setter, array=a.__array__, + origin=origin, ) return _maker diff --git a/src/gt4py/next/otf/binding/pybind.py b/src/gt4py/next/otf/binding/pybind.py index 5ba229a8a8..82b06a31ae 100644 --- a/src/gt4py/next/otf/binding/pybind.py +++ b/src/gt4py/next/otf/binding/pybind.py @@ -89,7 +89,10 @@ def _type_string(type_: ts.TypeSpec) -> str: if isinstance(type_, ts.TupleType): return f"std::tuple<{','.join(_type_string(t) for t in type_.types)}>" elif isinstance(type_, ts.FieldType): - return "pybind11::buffer" + ndims = len(type_.dims) + buffer_t = "pybind11::buffer" + origin_t = f"std::tuple<{', '.join(['ptrdiff_t'] * ndims)}>" + return f"std::pair<{buffer_t}, {origin_t}>" elif isinstance(type_, ts.ScalarType): return cpp_interface.render_scalar_type(type_) else: @@ -142,17 +145,16 @@ def visit_FunctionCall(self, call: FunctionCall): return cpp_interface.render_function_call(call.target, args) def visit_BufferSID(self, sid: BufferSID, **kwargs): - return self.generic_visit( - sid, rendered_scalar_type=cpp_interface.render_scalar_type(sid.scalar_type) - ) - - BufferSID = as_jinja( - """gridtools::sid::rename_numbered_dimensions<{{", ".join(dimensions)}}>( - gridtools::as_sid<{{rendered_scalar_type}},\ - {{dimensions.__len__()}},\ - gridtools::sid::unknown_kind>({{source_buffer}}) - )""" - ) + pybuffer = f"{sid.source_buffer}.first" + dims = [self.visit(dim) for dim in sid.dimensions] + origin = f"{sid.source_buffer}.second" + + as_sid = f"gridtools::as_sid<{cpp_interface.render_scalar_type(sid.scalar_type)},\ + {sid.dimensions.__len__()},\ + gridtools::sid::unknown_kind>({pybuffer})" + shifted = f"gridtools::sid::shift_sid_origin({as_sid}, {origin})" + renamed = f"gridtools::sid::rename_numbered_dimensions<{', '.join(dims)}>({shifted})" + return renamed def visit_CompositeSID(self, node: CompositeSID, **kwargs): kwargs["composite_ids"] = ( diff --git a/src/gt4py/next/program_processors/runners/gtfn_cpu.py b/src/gt4py/next/program_processors/runners/gtfn_cpu.py index 3800380fbb..0e3d7851ce 100644 --- a/src/gt4py/next/program_processors/runners/gtfn_cpu.py +++ b/src/gt4py/next/program_processors/runners/gtfn_cpu.py @@ -32,8 +32,10 @@ def convert_arg(arg: Any) -> Any: if isinstance(arg, tuple): return tuple(convert_arg(a) for a in arg) - if hasattr(arg, "__array__"): - return np.asarray(arg) + if hasattr(arg, "__array__") and hasattr(arg, "axes"): + arr = np.asarray(arg) + origin = getattr(arg, "__gt_origin__", tuple([0] * arr.ndim)) + return arr, origin else: return arg @@ -42,9 +44,11 @@ def convert_args(inp: stages.CompiledProgram) -> stages.CompiledProgram: def decorated_program( *args, offset_provider: dict[str, common.Connectivity | common.Dimension] ): + converted_args = [convert_arg(arg) for arg in args] + conn_args = extract_connectivity_args(offset_provider) return inp( - *[convert_arg(arg) for arg in args], - *extract_connectivity_args(offset_provider), + *converted_args, + *conn_args, ) return decorated_program @@ -52,16 +56,16 @@ def decorated_program( def extract_connectivity_args( offset_provider: dict[str, common.Connectivity | common.Dimension] -) -> list[npt.NDArray]: +) -> list[tuple[npt.NDArray, tuple[int, ...]]]: # note: the order here needs to agree with the order of the generated bindings - args: list[npt.NDArray] = [] + args: list[tuple[npt.NDArray, tuple[int, ...]]] = [] for name, conn in offset_provider.items(): if isinstance(conn, common.Connectivity): if not isinstance(conn, common.NeighborTable): raise NotImplementedError( "Only `NeighborTable` connectivities implemented at this point." ) - args.append(conn.table) + args.append((conn.table, tuple([0] * 2))) elif isinstance(conn, common.Dimension): pass else: diff --git a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_trivial.py b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_trivial.py index 4fb0d2fd7b..7cc4e95949 100644 --- a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_trivial.py +++ b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_trivial.py @@ -48,9 +48,6 @@ def baz(baz_inp): def test_trivial(program_processor, lift_mode): program_processor, validate = program_processor - if program_processor == run_gtfn: - pytest.xfail("origin not yet supported in gtfn") - rng = np.random.default_rng() inp = rng.uniform(size=(5, 7, 9)) out = np.copy(inp) @@ -80,9 +77,6 @@ def stencil_shifted_arg_to_lift(inp): def test_shifted_arg_to_lift(program_processor, lift_mode): program_processor, validate = program_processor - if program_processor == run_gtfn: - pytest.xfail("origin not yet supported in gtfn") - if lift_mode != transforms.LiftMode.FORCE_INLINE: pytest.xfail("shifted input arguments not supported for lift_mode != LiftMode.FORCE_INLINE") diff --git a/tests/next_tests/integration_tests/feature_tests/otf_tests/test_pybind_build.py b/tests/next_tests/integration_tests/feature_tests/otf_tests/test_pybind_build.py index 4eee84c8cc..90b471dad1 100644 --- a/tests/next_tests/integration_tests/feature_tests/otf_tests/test_pybind_build.py +++ b/tests/next_tests/integration_tests/feature_tests/otf_tests/test_pybind_build.py @@ -34,8 +34,11 @@ def test_gtfn_cpp_with_cmake(program_source_with_name): ), ) compiled_program = build_the_program(example_program_source) - buf = np.zeros(shape=(6, 5), dtype=np.float32) - tup = [np.zeros(shape=(6, 5), dtype=np.float32), np.zeros(shape=(6, 5), dtype=np.float32)] + buf = (np.zeros(shape=(6, 5), dtype=np.float32), (0, 0)) + tup = [ + (np.zeros(shape=(6, 5), dtype=np.float32), (0, 0)), + (np.zeros(shape=(6, 5), dtype=np.float32), (0, 0)), + ] sc = np.float32(3.1415926) res = compiled_program(buf, tup, sc) assert math.isclose(res, 6 * 5 * 3.1415926, rel_tol=1e-4) @@ -50,8 +53,11 @@ def test_gtfn_cpp_with_compiledb(program_source_with_name): ), ) compiled_program = build_the_program(example_program_source) - buf = np.zeros(shape=(6, 5), dtype=np.float32) - tup = [np.zeros(shape=(6, 5), dtype=np.float32), np.zeros(shape=(6, 5), dtype=np.float32)] + buf = (np.zeros(shape=(6, 5), dtype=np.float32), (0, 0)) + tup = [ + (np.zeros(shape=(6, 5), dtype=np.float32), (0, 0)), + (np.zeros(shape=(6, 5), dtype=np.float32), (0, 0)), + ] sc = np.float32(3.1415926) res = compiled_program(buf, tup, sc) assert math.isclose(res, 6 * 5 * 3.1415926, rel_tol=1e-4) diff --git a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_column_stencil.py b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_column_stencil.py index aef6d5d8e1..e921c1e8b3 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_column_stencil.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_column_stencil.py @@ -53,17 +53,16 @@ def shift_stencil(inp): @pytest.fixture( params=[ - # (stencil, reference_function, inp_fun (None=default), (skip_backend_fun, msg)) - (add_scalar, lambda inp: np.asarray(inp) + 1.0, None, None), - (if_scalar_cond, lambda inp: np.asarray(inp), None, None), - (if_scalar_return, lambda inp: np.ones_like(inp), None, None), + # (stencil, reference_function, inp_fun (None=default) + (add_scalar, lambda inp: np.asarray(inp) + 1.0, None), + (if_scalar_cond, lambda inp: np.asarray(inp), None), + (if_scalar_return, lambda inp: np.ones_like(inp), None), ( shift_stencil, lambda inp: np.asarray(inp)[1:, 1:], lambda shape: gtx.np_as_located_field(IDim, KDim)( np.fromfunction(lambda i, k: i * 10 + k, [shape[0] + 1, shape[1] + 1]) ), - None, ), ( shift_stencil, @@ -71,10 +70,6 @@ def shift_stencil(inp): lambda shape: gtx.np_as_located_field(IDim, KDim, origin={IDim: 0, KDim: 1})( np.fromfunction(lambda i, k: i * 10 + k, [shape[0] + 1, shape[1] + 2]) ), - ( - lambda backend: backend == run_gtfn or backend == run_gtfn_imperative, - "origin not supported in gtfn", - ), ), ], ids=lambda p: f"{p[0].__name__}", @@ -85,11 +80,7 @@ def basic_stencils(request): def test_basic_column_stencils(program_processor, lift_mode, basic_stencils): program_processor, validate = program_processor - stencil, ref_fun, inp_fun, skip_backend = basic_stencils - if skip_backend is not None: - skip_backend_fun, msg = skip_backend - if skip_backend_fun(program_processor): - pytest.xfail(msg) + stencil, ref_fun, inp_fun = basic_stencils shape = [5, 7] inp = ( @@ -327,8 +318,6 @@ def sum_fencil(out, inp0, inp1, k_size): def test_different_vertical_sizes_with_origin(program_processor): program_processor, validate = program_processor - if program_processor in [run_gtfn, run_gtfn_imperative]: - pytest.xfail("origin not supported in gtfn") k_size = 10 inp0 = gtx.np_as_located_field(KDim)(np.asarray(list(range(k_size)))) diff --git a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_hdiff.py b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_hdiff.py index 13f7f11292..e144e6096a 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_hdiff.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_hdiff.py @@ -74,7 +74,10 @@ def hdiff(inp, coeff, out, x, y): def test_hdiff(hdiff_reference, program_processor, lift_mode): program_processor, validate = program_processor if program_processor == run_gtfn or program_processor == run_gtfn_imperative: - pytest.xfail("origin not yet supported in gtfn") + from gt4py.next.iterator import transforms + + if lift_mode != transforms.LiftMode.FORCE_INLINE: + pytest.xfail("there is an issue with temporaries that crashes the application") inp, coeff, out = hdiff_reference shape = (out.shape[0], out.shape[1]) diff --git a/tests/next_tests/unit_tests/otf_tests/binding_tests/test_pybind.py b/tests/next_tests/unit_tests/otf_tests/binding_tests/test_pybind.py index df3de2fc8d..783ff8bdfd 100644 --- a/tests/next_tests/unit_tests/otf_tests/binding_tests/test_pybind.py +++ b/tests/next_tests/unit_tests/otf_tests/binding_tests/test_pybind.py @@ -25,7 +25,7 @@ def test_bindings(program_source_example): program_source_example.language_settings, """\ #include "stencil.cpp.inc" - + #include #include #include @@ -37,28 +37,36 @@ def test_bindings(program_source_example): #include #include #include - - decltype(auto) - stencil_wrapper(pybind11::buffer buf, - std::tuple tup, float sc) { + + decltype(auto) stencil_wrapper( + std::pair> buf, + std::tuple>, + std::pair>> + tup, + float sc) { return stencil( - gridtools::sid::rename_numbered_dimensions( - gridtools::as_sid(buf)), + gridtools::sid::rename_numbered_dimensions< + generated::I_t, generated::J_t>(gridtools::sid::shift_sid_origin( + gridtools::as_sid(buf.first), + buf.second)), gridtools::sid::composite::keys, gridtools::integral_constant>:: make_values( gridtools::sid::rename_numbered_dimensions( - gridtools::as_sid< - float, 2, gridtools::sid::unknown_kind>(gridtools::tuple_util::get<0>(tup))), + generated::J_t>( + gridtools::sid::shift_sid_origin( + gridtools::as_sid( + gridtools::tuple_util::get<0>(tup).first), + gridtools::tuple_util::get<0>(tup).second)), gridtools::sid::rename_numbered_dimensions( - gridtools::as_sid< - float, 2, gridtools::sid::unknown_kind>(gridtools::tuple_util::get<1>(tup)))), + generated::J_t>( + gridtools::sid::shift_sid_origin( + gridtools::as_sid( + gridtools::tuple_util::get<1>(tup).first), + gridtools::tuple_util::get<1>(tup).second))), sc); } - + PYBIND11_MODULE(stencil, module) { module.doc() = ""; module.def("stencil", &stencil_wrapper, "");