-
Notifications
You must be signed in to change notification settings - Fork 12
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
handle modulo operator/function for c-like-backends #383
Open
MichaelSt98
wants to merge
5
commits into
main
Choose a base branch
from
nams-cgen-modulo-handling
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
+146
−6
Open
Changes from 3 commits
Commits
Show all changes
5 commits
Select commit
Hold shift + click to select a range
b25f4cd
[TRANSPILATION] handle modulo operator/function for c-like-backends
MichaelSt98 1bc4706
improve readability, fix imports, ...
MichaelSt98 2ce21a2
Move 'FindRealLiterals' (formerly 'FindFloatLiterals') to expr_visito…
MichaelSt98 a428a3b
fix test for OMNI (evaluation of BOZ constants)
MichaelSt98 c613469
Merge branch 'main' into nams-cgen-modulo-handling
reuterbal File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -6,7 +6,6 @@ | |
# nor does it submit to any jurisdiction. | ||
|
||
from pathlib import Path | ||
# from shutil import rmtree | ||
import pytest | ||
import numpy as np | ||
|
||
|
@@ -19,6 +18,8 @@ | |
from loki.transformations.transpile import FortranCTransformation | ||
from loki.transformations.single_column import SCCLowLevelHoist, SCCLowLevelParametrise | ||
|
||
# pylint: disable=too-many-lines | ||
|
||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We should start splitting the tests into logical units (but that can be a separate PR) |
||
@pytest.fixture(scope='function', name='builder') | ||
def fixture_builder(tmp_path): | ||
yield Builder(source_dirs=tmp_path, build_dir=tmp_path) | ||
|
@@ -1249,6 +1250,80 @@ def test_transpile_multiconditional_range(tmp_path, frontend): | |
with pytest.raises(NotImplementedError): | ||
f2c.apply(source=routine, path=tmp_path) | ||
|
||
|
||
@pytest.mark.parametrize('frontend', available_frontends()) | ||
@pytest.mark.parametrize('dtype', ('integer', 'real',)) | ||
@pytest.mark.parametrize('add_float', (False, True)) | ||
def test_transpile_special_functions(tmp_path, builder, frontend, dtype, add_float): | ||
""" | ||
A simple test to verify multiconditionals/select case statements. | ||
""" | ||
if dtype == 'real': | ||
decl_type = f'{dtype}(kind=real64)' | ||
kind = '._real64' | ||
else: | ||
decl_type = dtype | ||
kind = '' | ||
|
||
fcode = f""" | ||
subroutine transpile_special_functions(in, out) | ||
use iso_fortran_env, only: real64 | ||
implicit none | ||
{decl_type}, intent(in) :: in | ||
{decl_type}, intent(inout) :: out | ||
if (mod(in{'+ 2._real64' if add_float else ''}, 2{kind}{'+ 0._real64' if add_float else ''}) .eq. 0) then | ||
out = 42{kind} | ||
else | ||
out = 11{kind} | ||
endif | ||
end subroutine transpile_special_functions | ||
""".strip() | ||
|
||
def init_var(dtype, val=0): | ||
if dtype == 'real': | ||
return np.float64([val]) | ||
return np.int_([val]) | ||
|
||
# for testing purposes | ||
in_var = init_var(dtype) | ||
test_vals = [2, 10, 5, 3] | ||
expected_results = [42, 42, 11, 11] | ||
out_var = init_var(dtype) | ||
|
||
# compile original Fortran version | ||
routine = Subroutine.from_source(fcode, frontend=frontend) | ||
filepath = tmp_path/f'{routine.name}_{frontend!s}.f90' | ||
function = jit_compile(routine, filepath=filepath, objname=routine.name) | ||
# test Fortran version | ||
for i, val in enumerate(test_vals): | ||
in_var = val | ||
function(in_var, out_var) | ||
assert out_var == expected_results[i] | ||
|
||
clean_test(filepath) | ||
|
||
# apply F2C trafo | ||
f2c = FortranCTransformation() | ||
f2c.apply(source=routine, path=tmp_path) | ||
|
||
# check whether correct modulo was inserted | ||
ccode = Path(f2c.c_path).read_text() | ||
if dtype == 'integer' and not add_float: | ||
assert '%' in ccode | ||
if dtype == 'real' or add_float: | ||
assert 'fmod' in ccode | ||
|
||
# compile C version | ||
libname = f'fc_{routine.name}_{frontend}' | ||
c_kernel = jit_compile_lib([f2c.wrapperpath, f2c.c_path], path=tmp_path, name=libname, builder=builder) | ||
fc_function = c_kernel.transpile_special_functions_fc_mod.transpile_special_functions_fc | ||
# test C version | ||
for i, val in enumerate(test_vals): | ||
in_var = val | ||
fc_function(in_var, out_var) | ||
assert int(out_var) == expected_results[i] | ||
|
||
|
||
@pytest.fixture(scope='module', name='horizontal') | ||
def fixture_horizontal(): | ||
return Dimension(name='horizontal', size='nlon', index='jl', bounds=('start', 'iend')) | ||
|
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[no action] This is indeed a tricky corner case. Solution looks good for now, but might need some more thought when this gets problematic.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Likely the correct approach would be an
ExpressionTypeMapper
that attempts to determine the return type of an expression (similar toExpressionDimensionsMapper
) - but it will require baking in a lot of knowledge about Fortran intrinsics. Fparser might provide some of that, so it could be feasible.Definitely way beyond the scope of this PR, though.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes! We could also think about implementing a C++ (templated) function
mymod()
which either callsfmod
or uses the modulo operator in dependence of the arguments types.