From b25f4cd8b9a3134075a6426c08f24e1aa56bc0c7 Mon Sep 17 00:00:00 2001 From: Michael Staneker Date: Wed, 25 Sep 2024 16:19:19 +0200 Subject: [PATCH 1/4] [TRANSPILATION] handle modulo operator/function for c-like-backends --- loki/backend/cgen.py | 22 +++++- .../transpile/tests/test_transpile.py | 73 +++++++++++++++++++ 2 files changed, 94 insertions(+), 1 deletion(-) diff --git a/loki/backend/cgen.py b/loki/backend/cgen.py index 57ecb99f1..2648f15cc 100644 --- a/loki/backend/cgen.py +++ b/loki/backend/cgen.py @@ -14,7 +14,8 @@ from loki.ir import Import, Stringifier, FindNodes from loki.expression import ( LokiStringifyMapper, Array, symbolic_op, Literal, - symbols as sym + symbols as sym, FindVariables, ExpressionFinder, + ExpressionRetriever ) from loki.types import BasicType, SymbolAttributes, DerivedType @@ -140,6 +141,25 @@ def map_c_reference(self, expr, enclosing_prec, *args, **kwargs): def map_c_dereference(self, expr, enclosing_prec, *args, **kwargs): return self.format(' (*%s)', self.rec(expr.expression, PREC_NONE, *args, **kwargs)) + def map_inline_call(self, expr, enclosing_prec, *args, **kwargs): + + class FindFloatLiterals(ExpressionFinder): + retriever = ExpressionRetriever(lambda e: isinstance(e, sym.FloatLiteral)) + + if expr.function.name.lower() == 'mod': + parameters = [self.rec(param, PREC_NONE, *args, **kwargs) for param in expr.parameters] + # TODO: this check is not quite correct, as it should evaluate the + # expression(s) of both arguments/parameters and choose the integer version of modulo ('%') + # instead of the floating-point version ('fmod') + # whenever the mentioned evaluations result in being of kind 'integer' ... + # as an example: 'celing(3.1415)' got an floating point value in it, however it evaluates/returns + # an integer, in that case the wrong modulo function/operation is chosen + if any(var.type.dtype != BasicType.INTEGER for var in FindVariables().visit(expr.parameters)) or\ + FindFloatLiterals().visit(expr.parameters): + return f'fmod({parameters[0]}, {parameters[1]})' + return f'({parameters[0]})%({parameters[1]})' + return super().map_inline_call(expr, enclosing_prec, *args, **kwargs) + class CCodegen(Stringifier): """ diff --git a/loki/transformations/transpile/tests/test_transpile.py b/loki/transformations/transpile/tests/test_transpile.py index d0c1bdc8f..9a8962e28 100644 --- a/loki/transformations/transpile/tests/test_transpile.py +++ b/loki/transformations/transpile/tests/test_transpile.py @@ -19,6 +19,8 @@ from loki.transformations.transpile import FortranCTransformation from loki.transformations.single_column import SCCLowLevelHoist, SCCLowLevelParametrise +# pylint: disable=too-many-lines + @pytest.fixture(scope='function', name='builder') def fixture_builder(tmp_path): yield Builder(source_dirs=tmp_path, build_dir=tmp_path) @@ -1249,6 +1251,77 @@ def test_transpile_multiconditional_range(tmp_path, frontend): with pytest.raises(NotImplementedError): f2c.apply(source=routine, path=tmp_path) + +@pytest.mark.parametrize('frontend', available_frontends()) +@pytest.mark.parametrize('dtype', ('integer', 'real',)) +@pytest.mark.parametrize('add_float', (False, True)) +def test_transpile_special_functions(tmp_path, builder, frontend, dtype, add_float): + """ + A simple test to verify multiconditionals/select case statements. + """ + + fcode = f""" +subroutine transpile_special_functions(in, out) + use iso_fortran_env, only: real64 + implicit none + {dtype}{'(kind=real64)' if dtype == 'real' else ''}, intent(in) :: in + {dtype}{'(kind=real64)' if dtype == 'real' else ''}, intent(inout) :: out + + if (mod(in{'+ 2._real64' if add_float else ''}, 2{'._real64' if dtype == 'real' else ''}{'+ 0._real64' if add_float else ''}) .eq. 0) then + out = 42{'._real64' if dtype == 'real' else ''} + else + out = 11{'._real64' if dtype == 'real' else ''} + endif +end subroutine transpile_special_functions +""".strip() + + def init_var(dtype, val=0): + if dtype == 'real': + return np.float64([val]) + return np.int_([val]) + + # for testing purposes + in_var = init_var(dtype) # np.float64([0]) # 0 + test_vals = [2, 10, 5, 3] + expected_results = [42, 42, 11, 11] + # out_var = np.int_([0]) + out_var = init_var(dtype) # np.float64([0]) + + # compile original Fortran version + routine = Subroutine.from_source(fcode, frontend=frontend) + filepath = tmp_path/f'{routine.name}_{frontend!s}.f90' + function = jit_compile(routine, filepath=filepath, objname=routine.name) + # test Fortran version + for i, val in enumerate(test_vals): + in_var = val + function(in_var, out_var) + assert out_var == expected_results[i] + + clean_test(filepath) + + # apply F2C trafo + f2c = FortranCTransformation() + f2c.apply(source=routine, path=tmp_path) + + # check whether correct modulo was inserted + with open(f2c.c_path, 'r') as f: + ccode = f.read() + if dtype == 'integer' and not add_float: + assert '%' in ccode + if dtype == 'real' or add_float: + assert 'fmod' in ccode + + # compile C version + libname = f'fc_{routine.name}_{frontend}' + c_kernel = jit_compile_lib([f2c.wrapperpath, f2c.c_path], path=tmp_path, name=libname, builder=builder) + fc_function = c_kernel.transpile_special_functions_fc_mod.transpile_special_functions_fc + # test C version + for i, val in enumerate(test_vals): + in_var = val + fc_function(in_var, out_var) + assert int(out_var) == expected_results[i] + + @pytest.fixture(scope='module', name='horizontal') def fixture_horizontal(): return Dimension(name='horizontal', size='nlon', index='jl', bounds=('start', 'iend')) From 1bc47060c0e872429dcff0f32807546ee7388425 Mon Sep 17 00:00:00 2001 From: Michael Staneker Date: Sat, 5 Oct 2024 14:20:38 +0200 Subject: [PATCH 2/4] improve readability, fix imports, ... --- loki/backend/cgen.py | 8 +++--- .../transpile/tests/test_transpile.py | 26 ++++++++++--------- 2 files changed, 19 insertions(+), 15 deletions(-) diff --git a/loki/backend/cgen.py b/loki/backend/cgen.py index 2648f15cc..14cc2d955 100644 --- a/loki/backend/cgen.py +++ b/loki/backend/cgen.py @@ -11,11 +11,13 @@ ) from loki.tools import as_tuple -from loki.ir import Import, Stringifier, FindNodes +from loki.ir import ( + Import, Stringifier, FindNodes, + FindVariables, ExpressionFinder +) from loki.expression import ( LokiStringifyMapper, Array, symbolic_op, Literal, - symbols as sym, FindVariables, ExpressionFinder, - ExpressionRetriever + symbols as sym, ExpressionRetriever ) from loki.types import BasicType, SymbolAttributes, DerivedType diff --git a/loki/transformations/transpile/tests/test_transpile.py b/loki/transformations/transpile/tests/test_transpile.py index 9a8962e28..9a8a072f2 100644 --- a/loki/transformations/transpile/tests/test_transpile.py +++ b/loki/transformations/transpile/tests/test_transpile.py @@ -6,7 +6,6 @@ # nor does it submit to any jurisdiction. from pathlib import Path -# from shutil import rmtree import pytest import numpy as np @@ -1259,18 +1258,23 @@ def test_transpile_special_functions(tmp_path, builder, frontend, dtype, add_flo """ A simple test to verify multiconditionals/select case statements. """ + if dtype == 'real': + decl_type = f'{dtype}(kind=real64)' + kind = '._real64' + else: + decl_type = dtype + kind = '' fcode = f""" subroutine transpile_special_functions(in, out) use iso_fortran_env, only: real64 implicit none - {dtype}{'(kind=real64)' if dtype == 'real' else ''}, intent(in) :: in - {dtype}{'(kind=real64)' if dtype == 'real' else ''}, intent(inout) :: out - - if (mod(in{'+ 2._real64' if add_float else ''}, 2{'._real64' if dtype == 'real' else ''}{'+ 0._real64' if add_float else ''}) .eq. 0) then - out = 42{'._real64' if dtype == 'real' else ''} + {decl_type}, intent(in) :: in + {decl_type}, intent(inout) :: out + if (mod(in{'+ 2._real64' if add_float else ''}, 2{kind}{'+ 0._real64' if add_float else ''}) .eq. 0) then + out = 42{kind} else - out = 11{'._real64' if dtype == 'real' else ''} + out = 11{kind} endif end subroutine transpile_special_functions """.strip() @@ -1281,11 +1285,10 @@ def init_var(dtype, val=0): return np.int_([val]) # for testing purposes - in_var = init_var(dtype) # np.float64([0]) # 0 + in_var = init_var(dtype) test_vals = [2, 10, 5, 3] expected_results = [42, 42, 11, 11] - # out_var = np.int_([0]) - out_var = init_var(dtype) # np.float64([0]) + out_var = init_var(dtype) # compile original Fortran version routine = Subroutine.from_source(fcode, frontend=frontend) @@ -1304,8 +1307,7 @@ def init_var(dtype, val=0): f2c.apply(source=routine, path=tmp_path) # check whether correct modulo was inserted - with open(f2c.c_path, 'r') as f: - ccode = f.read() + ccode = Path(f2c.c_path).read_text() if dtype == 'integer' and not add_float: assert '%' in ccode if dtype == 'real' or add_float: From 2ce21a22eaadb57a5a9dcb1b278fec7d6ae5e78f Mon Sep 17 00:00:00 2001 From: Michael Staneker Date: Mon, 7 Oct 2024 11:17:42 +0200 Subject: [PATCH 3/4] Move 'FindRealLiterals' (formerly 'FindFloatLiterals') to expr_visitors.py and add corresponding tests --- loki/backend/cgen.py | 9 +++---- loki/ir/expr_visitors.py | 14 ++++++++-- loki/ir/tests/test_expr_visitors.py | 40 ++++++++++++++++++++++++++++- 3 files changed, 54 insertions(+), 9 deletions(-) diff --git a/loki/backend/cgen.py b/loki/backend/cgen.py index 14cc2d955..b0abce9d8 100644 --- a/loki/backend/cgen.py +++ b/loki/backend/cgen.py @@ -13,11 +13,11 @@ from loki.tools import as_tuple from loki.ir import ( Import, Stringifier, FindNodes, - FindVariables, ExpressionFinder + FindVariables, FindRealLiterals ) from loki.expression import ( LokiStringifyMapper, Array, symbolic_op, Literal, - symbols as sym, ExpressionRetriever + symbols as sym ) from loki.types import BasicType, SymbolAttributes, DerivedType @@ -145,9 +145,6 @@ def map_c_dereference(self, expr, enclosing_prec, *args, **kwargs): def map_inline_call(self, expr, enclosing_prec, *args, **kwargs): - class FindFloatLiterals(ExpressionFinder): - retriever = ExpressionRetriever(lambda e: isinstance(e, sym.FloatLiteral)) - if expr.function.name.lower() == 'mod': parameters = [self.rec(param, PREC_NONE, *args, **kwargs) for param in expr.parameters] # TODO: this check is not quite correct, as it should evaluate the @@ -157,7 +154,7 @@ class FindFloatLiterals(ExpressionFinder): # as an example: 'celing(3.1415)' got an floating point value in it, however it evaluates/returns # an integer, in that case the wrong modulo function/operation is chosen if any(var.type.dtype != BasicType.INTEGER for var in FindVariables().visit(expr.parameters)) or\ - FindFloatLiterals().visit(expr.parameters): + FindRealLiterals().visit(expr.parameters): return f'fmod({parameters[0]}, {parameters[1]})' return f'({parameters[0]})%({parameters[1]})' return super().map_inline_call(expr, enclosing_prec, *args, **kwargs) diff --git a/loki/ir/expr_visitors.py b/loki/ir/expr_visitors.py index f2ed2a089..d4516e066 100644 --- a/loki/ir/expr_visitors.py +++ b/loki/ir/expr_visitors.py @@ -25,8 +25,9 @@ __all__ = [ 'FindExpressions', 'FindVariables', 'FindTypedSymbols', - 'FindInlineCalls', 'FindLiterals', 'SubstituteExpressions', - 'SubstituteStringExpressions', 'ExpressionFinder', 'AttachScopes' + 'FindInlineCalls', 'FindLiterals', 'FindRealLiterals', + 'SubstituteExpressions', 'SubstituteStringExpressions', + 'ExpressionFinder', 'AttachScopes' ] @@ -201,6 +202,15 @@ class FindLiterals(ExpressionFinder): FloatLiteral, IntLiteral, LogicLiteral, StringLiteral, IntrinsicLiteral ))) +class FindRealLiterals(ExpressionFinder): + """ + A visitor to collect all real/float literals (which includes :any:`FloatLiteral`) + used in an IR tree. + + See :class:`ExpressionFinder` + """ + retriever = ExpressionRetriever(lambda e: isinstance(e, FloatLiteral)) + class SubstituteExpressions(Transformer): """ diff --git a/loki/ir/tests/test_expr_visitors.py b/loki/ir/tests/test_expr_visitors.py index 149242f35..b5e3a3187 100644 --- a/loki/ir/tests/test_expr_visitors.py +++ b/loki/ir/tests/test_expr_visitors.py @@ -12,7 +12,8 @@ from loki.frontend import available_frontends from loki.ir import ( nodes as ir, FindNodes, FindVariables, FindTypedSymbols, - SubstituteExpressions, SubstituteStringExpressions + SubstituteExpressions, SubstituteStringExpressions, + FindLiterals, FindRealLiterals ) @@ -123,6 +124,43 @@ def test_find_variables(frontend, tmp_path): assert len(body_vars) == 10 assert all(v in body_vars for v in expected) +@pytest.mark.parametrize('frontend', available_frontends()) +def test_find_literals(frontend): + """ + Test that :any:`FindLiterals` finds all literals + and :any:`FindRealLiterals` all real/float literals. + """ + fcode = """ +subroutine test_find_literals() + implicit none + integer :: n, n1 + real(kind=8) :: x + + n = 1 + 5 + 42 + x = 1.0 / 10.5 + n1 = int(B'00000') + if (.TRUE.) then + call some_func(x, some_string='string_kwarg') + endif + +end subroutine test_find_literals +""" + expected_int_literals = ('1', '5', '42') + expected_real_literals = ('1.0', '10.5') + expected_intrinsic_literals = ("B'00000'",) + expected_logic_literals = ('True',) + expected_string_literals = ('string_kwarg',) + expected_literals = expected_int_literals + expected_real_literals +\ + expected_intrinsic_literals + expected_logic_literals +\ + expected_string_literals + routine = Subroutine.from_source(fcode, frontend=frontend) + literals = FindLiterals().visit(routine.body) + assert sorted(list(expected_literals)) == sorted([str(literal.value) for literal in literals]) + real_literals = FindRealLiterals().visit(routine.body) + assert sorted(list(expected_real_literals)) == sorted([str(literal.value) for literal in real_literals]) + real_literals_isinstance = [literal for literal in literals if isinstance(literal, sym.FloatLiteral)] + assert sorted(list(expected_real_literals)) == sorted([str(literal.value) for literal in real_literals_isinstance]) + @pytest.mark.parametrize('frontend', available_frontends()) def test_substitute_expressions(frontend): From a428a3b085b93dd0cea8993e6d7dcad4d0c8d270 Mon Sep 17 00:00:00 2001 From: Michael Staneker Date: Mon, 7 Oct 2024 12:29:44 +0200 Subject: [PATCH 4/4] fix test for OMNI (evaluation of BOZ constants) --- loki/ir/tests/test_expr_visitors.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/loki/ir/tests/test_expr_visitors.py b/loki/ir/tests/test_expr_visitors.py index b5e3a3187..f0c947ae2 100644 --- a/loki/ir/tests/test_expr_visitors.py +++ b/loki/ir/tests/test_expr_visitors.py @@ -9,7 +9,7 @@ from loki import Sourcefile, Subroutine from loki.expression import symbols as sym, parse_expr -from loki.frontend import available_frontends +from loki.frontend import available_frontends, OMNI from loki.ir import ( nodes as ir, FindNodes, FindVariables, FindTypedSymbols, SubstituteExpressions, SubstituteStringExpressions, @@ -147,7 +147,8 @@ def test_find_literals(frontend): """ expected_int_literals = ('1', '5', '42') expected_real_literals = ('1.0', '10.5') - expected_intrinsic_literals = ("B'00000'",) + # Omni evaluates BOZ constants, so it creates IntegerLiteral instead... + expected_intrinsic_literals = ("B'00000'",) if frontend != OMNI else ('0',) expected_logic_literals = ('True',) expected_string_literals = ('string_kwarg',) expected_literals = expected_int_literals + expected_real_literals +\