Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

handle modulo operator/function for c-like-backends #383

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 19 additions & 1 deletion loki/backend/cgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,10 @@

from loki.logging import warning
from loki.tools import as_tuple
from loki.ir import Import, Stringifier, FindNodes
from loki.ir import (
Import, Stringifier, FindNodes,
FindVariables, FindRealLiterals
)
from loki.expression import (
LokiStringifyMapper, Array, symbolic_op, Literal,
symbols as sym
Expand Down Expand Up @@ -142,10 +145,25 @@ def map_c_dereference(self, expr, enclosing_prec, *args, **kwargs):
return self.format(' (*%s)', self.rec(expr.expression, PREC_NONE, *args, **kwargs))

def map_inline_call(self, expr, enclosing_prec, *args, **kwargs):
if expr.function.name.lower() == 'mod':
parameters = [self.rec(param, PREC_NONE, *args, **kwargs) for param in expr.parameters]
# TODO: this check is not quite correct, as it should evaluate the
# expression(s) of both arguments/parameters and choose the integer version of modulo ('%')
# instead of the floating-point version ('fmod')
# whenever the mentioned evaluations result in being of kind 'integer' ...
# as an example: 'celing(3.1415)' got an floating point value in it, however it evaluates/returns
# an integer, in that case the wrong modulo function/operation is chosen
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[no action] This is indeed a tricky corner case. Solution looks good for now, but might need some more thought when this gets problematic.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Likely the correct approach would be an ExpressionTypeMapper that attempts to determine the return type of an expression (similar to ExpressionDimensionsMapper) - but it will require baking in a lot of knowledge about Fortran intrinsics. Fparser might provide some of that, so it could be feasible.

Definitely way beyond the scope of this PR, though.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes! We could also think about implementing a C++ (templated) function mymod() which either calls fmod or uses the modulo operator in dependence of the arguments types.

if any(var.type.dtype != BasicType.INTEGER for var in FindVariables().visit(expr.parameters)) or\
FindRealLiterals().visit(expr.parameters):
return f'fmod({parameters[0]}, {parameters[1]})'
return f'({parameters[0]})%({parameters[1]})'

if expr.function.name.lower() == 'present':
return self.format('true /*ATTENTION: present({%s})*/', expr.parameters[0].name)

return super().map_inline_call(expr, enclosing_prec, *args, **kwargs)


class CCodegen(Stringifier):
"""
Tree visitor to generate standardized C code from IR.
Expand Down
14 changes: 12 additions & 2 deletions loki/ir/expr_visitors.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,9 @@

__all__ = [
'FindExpressions', 'FindVariables', 'FindTypedSymbols',
'FindInlineCalls', 'FindLiterals', 'SubstituteExpressions',
'SubstituteStringExpressions', 'ExpressionFinder', 'AttachScopes'
'FindInlineCalls', 'FindLiterals', 'FindRealLiterals',
'SubstituteExpressions', 'SubstituteStringExpressions',
'ExpressionFinder', 'AttachScopes'
]


Expand Down Expand Up @@ -201,6 +202,15 @@ class FindLiterals(ExpressionFinder):
FloatLiteral, IntLiteral, LogicLiteral, StringLiteral, IntrinsicLiteral
)))

class FindRealLiterals(ExpressionFinder):
"""
A visitor to collect all real/float literals (which includes :any:`FloatLiteral`)
used in an IR tree.

See :class:`ExpressionFinder`
"""
retriever = ExpressionRetriever(lambda e: isinstance(e, FloatLiteral))


class SubstituteExpressions(Transformer):
"""
Expand Down
43 changes: 41 additions & 2 deletions loki/ir/tests/test_expr_visitors.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,11 @@

from loki import Sourcefile, Subroutine
from loki.expression import symbols as sym, parse_expr
from loki.frontend import available_frontends
from loki.frontend import available_frontends, OMNI
from loki.ir import (
nodes as ir, FindNodes, FindVariables, FindTypedSymbols,
SubstituteExpressions, SubstituteStringExpressions
SubstituteExpressions, SubstituteStringExpressions,
FindLiterals, FindRealLiterals
)


Expand Down Expand Up @@ -123,6 +124,44 @@ def test_find_variables(frontend, tmp_path):
assert len(body_vars) == 10
assert all(v in body_vars for v in expected)

@pytest.mark.parametrize('frontend', available_frontends())
def test_find_literals(frontend):
"""
Test that :any:`FindLiterals` finds all literals
and :any:`FindRealLiterals` all real/float literals.
"""
fcode = """
subroutine test_find_literals()
implicit none
integer :: n, n1
real(kind=8) :: x

n = 1 + 5 + 42
x = 1.0 / 10.5
n1 = int(B'00000')
if (.TRUE.) then
call some_func(x, some_string='string_kwarg')
endif

end subroutine test_find_literals
"""
expected_int_literals = ('1', '5', '42')
expected_real_literals = ('1.0', '10.5')
# Omni evaluates BOZ constants, so it creates IntegerLiteral instead...
expected_intrinsic_literals = ("B'00000'",) if frontend != OMNI else ('0',)
expected_logic_literals = ('True',)
expected_string_literals = ('string_kwarg',)
expected_literals = expected_int_literals + expected_real_literals +\
expected_intrinsic_literals + expected_logic_literals +\
expected_string_literals
routine = Subroutine.from_source(fcode, frontend=frontend)
literals = FindLiterals().visit(routine.body)
assert sorted(list(expected_literals)) == sorted([str(literal.value) for literal in literals])
real_literals = FindRealLiterals().visit(routine.body)
assert sorted(list(expected_real_literals)) == sorted([str(literal.value) for literal in real_literals])
real_literals_isinstance = [literal for literal in literals if isinstance(literal, sym.FloatLiteral)]
assert sorted(list(expected_real_literals)) == sorted([str(literal.value) for literal in real_literals_isinstance])


@pytest.mark.parametrize('frontend', available_frontends())
def test_substitute_expressions(frontend):
Expand Down
75 changes: 74 additions & 1 deletion loki/transformations/transpile/tests/test_transpile.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
# nor does it submit to any jurisdiction.

from pathlib import Path
# from shutil import rmtree
import pytest
import numpy as np

Expand Down Expand Up @@ -1278,6 +1277,80 @@ def test_transpile_multiconditional(tmp_path, builder, frontend):
assert out_var == val[1]



@pytest.mark.parametrize('frontend', available_frontends())
@pytest.mark.parametrize('dtype', ('integer', 'real',))
@pytest.mark.parametrize('add_float', (False, True))
def test_transpile_special_functions(tmp_path, builder, frontend, dtype, add_float):
"""
A simple test to verify multiconditionals/select case statements.
"""
if dtype == 'real':
decl_type = f'{dtype}(kind=real64)'
kind = '._real64'
else:
decl_type = dtype
kind = ''

fcode = f"""
subroutine transpile_special_functions(in, out)
use iso_fortran_env, only: real64
implicit none
{decl_type}, intent(in) :: in
{decl_type}, intent(inout) :: out
if (mod(in{'+ 2._real64' if add_float else ''}, 2{kind}{'+ 0._real64' if add_float else ''}) .eq. 0) then
out = 42{kind}
else
out = 11{kind}
endif
end subroutine transpile_special_functions
""".strip()

def init_var(dtype, val=0):
if dtype == 'real':
return np.float64([val])
return np.int_([val])

# for testing purposes
in_var = init_var(dtype)
test_vals = [2, 10, 5, 3]
expected_results = [42, 42, 11, 11]
out_var = init_var(dtype)

# compile original Fortran version
routine = Subroutine.from_source(fcode, frontend=frontend)
filepath = tmp_path/f'{routine.name}_{frontend!s}.f90'
function = jit_compile(routine, filepath=filepath, objname=routine.name)
# test Fortran version
for i, val in enumerate(test_vals):
in_var = val
function(in_var, out_var)
assert out_var == expected_results[i]

clean_test(filepath)

# apply F2C trafo
f2c = FortranCTransformation()
f2c.apply(source=routine, path=tmp_path)

# check whether correct modulo was inserted
ccode = Path(f2c.c_path).read_text()
if dtype == 'integer' and not add_float:
assert '%' in ccode
if dtype == 'real' or add_float:
assert 'fmod' in ccode

# compile C version
libname = f'fc_{routine.name}_{frontend}'
c_kernel = jit_compile_lib([f2c.wrapperpath, f2c.c_path], path=tmp_path, name=libname, builder=builder)
fc_function = c_kernel.transpile_special_functions_fc_mod.transpile_special_functions_fc
# test C version
for i, val in enumerate(test_vals):
in_var = val
fc_function(in_var, out_var)
assert int(out_var) == expected_results[i]


@pytest.fixture(scope='module', name='horizontal')
def fixture_horizontal():
return Dimension(name='horizontal', size='nlon', index='jl', bounds=('start', 'iend'))
Expand Down
Loading