Skip to content

Commit

Permalink
Merge pull request #350 from ecmwf-ifs/naml-ir-autocast-constructors
Browse files Browse the repository at this point in the history
IR: Automatic sanitisation of tuples in IR constructors
  • Loading branch information
reuterbal authored Jul 26, 2024
2 parents 190bdfa + 561d748 commit 2b3b206
Show file tree
Hide file tree
Showing 3 changed files with 119 additions and 29 deletions.
63 changes: 39 additions & 24 deletions loki/ir/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from pymbolic.primitives import Expression

from pydantic.dataclasses import dataclass as dataclass_validated
from pydantic import model_validator

from loki.scope import Scope
from loki.tools import flatten, as_tuple, is_iterable, truncate_string, CaseInsensitiveDict
Expand Down Expand Up @@ -50,6 +51,14 @@
# Using this decorator, we can force strict validation
dataclass_strict = partial(dataclass_validated, config=dataclass_validation_config)


def _sanitize_tuple(t):
"""
Small helper method to ensure non-nested tuples without ``None``.
"""
return tuple(n for n in flatten(as_tuple(t)) if n is not None)


# Abstract base classes

@dataclass_strict(frozen=True)
Expand Down Expand Up @@ -224,7 +233,14 @@ def uses_symbols(self):


@dataclass_strict(frozen=True)
class InternalNode(Node):
class _InternalNode():
""" Type definitions for :any:`InternalNode` node type. """

body: Tuple[Union[Node, Scope], ...] = ()


@dataclass_strict(frozen=True)
class InternalNode(Node, _InternalNode):
"""
Internal representation of a control flow node that has a traversable
`body` property.
Expand All @@ -235,14 +251,19 @@ class InternalNode(Node):
The nodes that make up the body.
"""

# Certain Node types may contain Module / Subroutine objects
body: Tuple[Any, ...] = None

_traversable = ['body']

def __post_init__(self):
super().__post_init__()
assert self.body is None or isinstance(self.body, tuple)
@model_validator(mode='before')
@classmethod
def pre_init(cls, values):
""" Ensure non-nested tuples for body. """
if values.kwargs and 'body' in values.kwargs:
values.kwargs['body'] = _sanitize_tuple(values.kwargs['body'])
if values.args:
# ArgsKwargs are immutable, so we need to force it a little
new_args = (_sanitize_tuple(values.args[0]),) + values.args[1:]
values = type(values)(args=new_args, kwargs=values.kwargs)
return values

def __repr__(self):
raise NotImplementedError
Expand Down Expand Up @@ -315,24 +336,13 @@ def __setstate__(self, s):
class _SectionBase():
""" Type definitions for :any:`Section` node type. """

# Sections may contain Module / Subroutine objects
body: Tuple[Any, ...] = ()


@dataclass_strict(frozen=True)
class Section(InternalNode, _SectionBase):
"""
Internal representation of a single code region.
"""

def __post_init__(self):
super().__post_init__()
assert self.body is None or isinstance(self.body, tuple)

# Ensure we have no nested tuples in the body
if not all(not isinstance(n, tuple) for n in as_tuple(self.body)):
self._update(body=as_tuple(flatten(self.body)))

def append(self, node):
"""
Append the given node(s) to the section's body.
Expand Down Expand Up @@ -382,7 +392,7 @@ class _AssociateBase():


@dataclass_strict(frozen=True)
class Associate(ScopedNode, Section, _AssociateBase):
class Associate(ScopedNode, Section, _AssociateBase): # pylint: disable=too-many-ancestors
"""
Internal representation of a code region in which names are associated
with expressions or variables.
Expand Down Expand Up @@ -560,7 +570,7 @@ class _ConditionalBase():

condition: Expression
body: Tuple[Node, ...]
else_body: Optional[Tuple[Node, ...]] = None
else_body: Optional[Tuple[Node, ...]] = ()
inline: bool = False
has_elseif: bool = False
name: Optional[str] = None
Expand Down Expand Up @@ -596,14 +606,19 @@ class Conditional(InternalNode, _ConditionalBase):

_traversable = ['condition', 'body', 'else_body']

@model_validator(mode='before')
@classmethod
def pre_init(cls, values):
values = super().pre_init(values)
# Ensure non-nested tuples for else_body
if 'else_body' in values.kwargs:
values.kwargs['else_body'] = _sanitize_tuple(values.kwargs['else_body'])
return values

def __post_init__(self):
super().__post_init__()
assert self.condition is not None

if self.body is not None:
assert isinstance(self.body, tuple)
assert all(isinstance(c, Node) for c in self.body) # pylint: disable=not-an-iterable

if self.has_elseif:
assert len(self.else_body) == 1
assert isinstance(self.else_body[0], Conditional) # pylint: disable=unsubscriptable-object
Expand Down
83 changes: 79 additions & 4 deletions loki/ir/tests/test_ir_nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from loki.expression import symbols as sym, parse_expr
from loki.ir import nodes as ir
from loki.scope import Scope
from loki.subroutine import Subroutine


@pytest.fixture(name='scope')
Expand All @@ -36,6 +37,10 @@ def fixture_n(scope):
def fixture_a_i(scope, i):
return sym.Array('a', dimensions=(i,), scope=scope)

@pytest.fixture(name='a_n')
def fixture_a_n(scope, n):
return sym.Array('a', dimensions=(n,), scope=scope)


def test_assignment(scope, a_i):
"""
Expand Down Expand Up @@ -73,6 +78,7 @@ def test_loop(scope, one, i, n, a_i):
assert isinstance(loop.bounds, Expression)
assert isinstance(loop.body, tuple)
assert all(isinstance(n, ir.Node) for n in loop.body)
assert loop.children == ( i, bounds, (assign,) )

# Ensure "frozen" status of node objects
with pytest.raises(FrozenInstanceError) as error:
Expand All @@ -82,9 +88,17 @@ def test_loop(scope, one, i, n, a_i):
with pytest.raises(FrozenInstanceError) as error:
loop.body = (assign, assign, assign)

# Test auto-casting of the body to tuple
loop = ir.Loop(variable=i, bounds=bounds, body=assign)
assert loop.body == (assign,)
loop = ir.Loop(variable=i, bounds=bounds, body=( (assign,), ))
assert loop.body == (assign,)
loop = ir.Loop(variable=i, bounds=bounds, body=( assign, (assign,), assign, None))
assert loop.body == (assign, assign, assign)

# Test errors for wrong contructor usage
with pytest.raises(ValidationError) as error:
ir.Loop(variable=i, bounds=bounds, body=assign)
ir.Loop(variable=i, bounds=bounds, body=n)
with pytest.raises(ValidationError) as error:
ir.Loop(variable=None, bounds=bounds, body=(assign,))
with pytest.raises(ValidationError) as error:
Expand All @@ -108,6 +122,7 @@ def test_conditional(scope, one, i, n, a_i):
assert all(isinstance(n, ir.Node) for n in cond.body)
assert isinstance(cond.else_body, tuple) and len(cond.else_body) == 1
assert all(isinstance(n, ir.Node) for n in cond.else_body)
assert cond.children == ( condition, (assign, assign), (assign,) )

with pytest.raises(FrozenInstanceError) as error:
cond.condition = parse_expr('k == 0', scope=scope)
Expand All @@ -116,8 +131,68 @@ def test_conditional(scope, one, i, n, a_i):
with pytest.raises(FrozenInstanceError) as error:
cond.else_body = (assign, assign, assign)

# Test errors for wrong contructor usage
with pytest.raises(ValidationError) as error:
ir.Conditional(condition=condition, body=assign)
# Test auto-casting of the body / else_body to tuple
cond = ir.Conditional(condition=condition, body=assign)
assert cond.body == (assign,) and cond.else_body == ()
cond = ir.Conditional(condition=condition, body=( (assign,), ))
assert cond.body == (assign,) and cond.else_body == ()
cond = ir.Conditional(condition=condition, body=( assign, (assign,), assign, None))
assert cond.body == (assign, assign, assign) and cond.else_body == ()

cond = ir.Conditional(condition=condition, body=(), else_body=assign)
assert cond.body == () and cond.else_body == (assign,)
cond = ir.Conditional(condition=condition, body=(), else_body=( (assign,), ))
assert cond.body == () and cond.else_body == (assign,)
cond = ir.Conditional(
condition=condition, body=(), else_body=( assign, (assign,), assign, None)
)
assert cond.body == () and cond.else_body == (assign, assign, assign)

# TODO: Test inline, name, has_elseif


def test_section(scope, one, i, n, a_n, a_i):
"""
Test constructors and behaviour of :any:`Section` nodes.
"""
assign = ir.Assignment(lhs=a_i, rhs=sym.Literal(42.0))
decl = ir.VariableDeclaration(symbols=(a_n,))
func = Subroutine(
name='F', is_function=True, spec=(decl,), body=(assign,)
)

# Test constructor for nodes and subroutine objects
sec = ir.Section(body=(assign, assign))
assert isinstance(sec.body, tuple) and len(sec.body) == 2
assert all(isinstance(n, ir.Node) for n in sec.body)
with pytest.raises(FrozenInstanceError) as error:
sec.body = (assign, assign)

sec = ir.Section(body=(func, func))
assert isinstance(sec.body, tuple) and len(sec.body) == 2
assert all(isinstance(n, Scope) for n in sec.body)
with pytest.raises(FrozenInstanceError) as error:
sec.body = (func, func)

sec = ir.Section((assign, assign))
assert sec.body == (assign, assign)

# Test auto-casting of the body to tuple
sec = ir.Section(body=assign)
assert sec.body == (assign,)
sec = ir.Section(body=( (assign,), ))
assert sec.body == (assign,)
sec = ir.Section(body=( assign, (assign,), assign, None))
assert sec.body == (assign, assign, assign)
sec = ir.Section((assign, (func,), assign, None))
assert sec.body == (assign, func, assign)

# Test prepend/insert/append additions
sec = ir.Section(body=func)
assert sec.body == (func,)
sec.prepend(assign)
assert sec.body == (assign, func)
sec.append((assign, assign))
assert sec.body == (assign, func, assign, assign)
sec.insert(pos=3, node=func)
assert sec.body == (assign, func, assign, func, assign)
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ dependencies = [
"coloredlogs", # optional for loki-build utility
"junit_xml", # optional for JunitXML output in loki-lint
"codetiming", # essential for scheduler and sourcefile timings
"pydantic", # type checking for IR nodes
"pydantic>=2.0", # type checking for IR nodes
]

[project.optional-dependencies]
Expand Down

0 comments on commit 2b3b206

Please sign in to comment.