Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

IR: Automatic sanitisation of tuples in IR constructors #350

Merged
merged 5 commits into from
Jul 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 39 additions & 24 deletions loki/ir/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from pymbolic.primitives import Expression

from pydantic.dataclasses import dataclass as dataclass_validated
from pydantic import model_validator

from loki.scope import Scope
from loki.tools import flatten, as_tuple, is_iterable, truncate_string, CaseInsensitiveDict
Expand Down Expand Up @@ -50,6 +51,14 @@
# Using this decorator, we can force strict validation
dataclass_strict = partial(dataclass_validated, config=dataclass_validation_config)


def _sanitize_tuple(t):
"""
Small helper method to ensure non-nested tuples without ``None``.
"""
return tuple(n for n in flatten(as_tuple(t)) if n is not None)


# Abstract base classes

@dataclass_strict(frozen=True)
Expand Down Expand Up @@ -224,7 +233,14 @@ def uses_symbols(self):


@dataclass_strict(frozen=True)
class InternalNode(Node):
class _InternalNode():
""" Type definitions for :any:`InternalNode` node type. """

body: Tuple[Union[Node, Scope], ...] = ()


@dataclass_strict(frozen=True)
class InternalNode(Node, _InternalNode):
"""
Internal representation of a control flow node that has a traversable
`body` property.
Expand All @@ -235,14 +251,19 @@ class InternalNode(Node):
The nodes that make up the body.
"""

# Certain Node types may contain Module / Subroutine objects
body: Tuple[Any, ...] = None

_traversable = ['body']

def __post_init__(self):
super().__post_init__()
assert self.body is None or isinstance(self.body, tuple)
@model_validator(mode='before')
@classmethod
def pre_init(cls, values):
""" Ensure non-nested tuples for body. """
if values.kwargs and 'body' in values.kwargs:
values.kwargs['body'] = _sanitize_tuple(values.kwargs['body'])
if values.args:
# ArgsKwargs are immutable, so we need to force it a little
new_args = (_sanitize_tuple(values.args[0]),) + values.args[1:]
values = type(values)(args=new_args, kwargs=values.kwargs)
return values

def __repr__(self):
raise NotImplementedError
Expand Down Expand Up @@ -315,24 +336,13 @@ def __setstate__(self, s):
class _SectionBase():
""" Type definitions for :any:`Section` node type. """

# Sections may contain Module / Subroutine objects
body: Tuple[Any, ...] = ()


@dataclass_strict(frozen=True)
class Section(InternalNode, _SectionBase):
"""
Internal representation of a single code region.
"""

def __post_init__(self):
super().__post_init__()
assert self.body is None or isinstance(self.body, tuple)

# Ensure we have no nested tuples in the body
if not all(not isinstance(n, tuple) for n in as_tuple(self.body)):
self._update(body=as_tuple(flatten(self.body)))

def append(self, node):
"""
Append the given node(s) to the section's body.
Expand Down Expand Up @@ -382,7 +392,7 @@ class _AssociateBase():


@dataclass_strict(frozen=True)
class Associate(ScopedNode, Section, _AssociateBase):
class Associate(ScopedNode, Section, _AssociateBase): # pylint: disable=too-many-ancestors
"""
Internal representation of a code region in which names are associated
with expressions or variables.
Expand Down Expand Up @@ -560,7 +570,7 @@ class _ConditionalBase():

condition: Expression
body: Tuple[Node, ...]
else_body: Optional[Tuple[Node, ...]] = None
else_body: Optional[Tuple[Node, ...]] = ()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is that a change in behaviour for the default value of else_body?
I think so far we used else_body=None as indicator that there isn't an else branch in the conditional - as opposed to an empty else branch (which has the same result but different on a string level).

As in, will ir.Conditional(condition=<...>, body=<...>) now be represented as the following?

IF (<...>) THEN
  <...>
ELSE
ENDIF

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nope, we're doing a plain old if o.else_body: which evaluates to the same thing:
https://github.com/ecmwf-ifs/loki/blob/main/loki/backend/fgen.py#L633

inline: bool = False
has_elseif: bool = False
name: Optional[str] = None
Expand Down Expand Up @@ -596,14 +606,19 @@ class Conditional(InternalNode, _ConditionalBase):

_traversable = ['condition', 'body', 'else_body']

@model_validator(mode='before')
@classmethod
def pre_init(cls, values):
values = super().pre_init(values)
# Ensure non-nested tuples for else_body
if 'else_body' in values.kwargs:
values.kwargs['else_body'] = _sanitize_tuple(values.kwargs['else_body'])
return values

def __post_init__(self):
super().__post_init__()
assert self.condition is not None

if self.body is not None:
assert isinstance(self.body, tuple)
assert all(isinstance(c, Node) for c in self.body) # pylint: disable=not-an-iterable

if self.has_elseif:
assert len(self.else_body) == 1
assert isinstance(self.else_body[0], Conditional) # pylint: disable=unsubscriptable-object
Expand Down
83 changes: 79 additions & 4 deletions loki/ir/tests/test_ir_nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from loki.expression import symbols as sym, parse_expr
from loki.ir import nodes as ir
from loki.scope import Scope
from loki.subroutine import Subroutine


@pytest.fixture(name='scope')
Expand All @@ -36,6 +37,10 @@ def fixture_n(scope):
def fixture_a_i(scope, i):
return sym.Array('a', dimensions=(i,), scope=scope)

@pytest.fixture(name='a_n')
def fixture_a_n(scope, n):
return sym.Array('a', dimensions=(n,), scope=scope)


def test_assignment(scope, a_i):
"""
Expand Down Expand Up @@ -73,6 +78,7 @@ def test_loop(scope, one, i, n, a_i):
assert isinstance(loop.bounds, Expression)
assert isinstance(loop.body, tuple)
assert all(isinstance(n, ir.Node) for n in loop.body)
assert loop.children == ( i, bounds, (assign,) )

# Ensure "frozen" status of node objects
with pytest.raises(FrozenInstanceError) as error:
Expand All @@ -82,9 +88,17 @@ def test_loop(scope, one, i, n, a_i):
with pytest.raises(FrozenInstanceError) as error:
loop.body = (assign, assign, assign)

# Test auto-casting of the body to tuple
loop = ir.Loop(variable=i, bounds=bounds, body=assign)
assert loop.body == (assign,)
loop = ir.Loop(variable=i, bounds=bounds, body=( (assign,), ))
assert loop.body == (assign,)
loop = ir.Loop(variable=i, bounds=bounds, body=( assign, (assign,), assign, None))
assert loop.body == (assign, assign, assign)

# Test errors for wrong contructor usage
with pytest.raises(ValidationError) as error:
ir.Loop(variable=i, bounds=bounds, body=assign)
ir.Loop(variable=i, bounds=bounds, body=n)
with pytest.raises(ValidationError) as error:
ir.Loop(variable=None, bounds=bounds, body=(assign,))
with pytest.raises(ValidationError) as error:
Expand All @@ -108,6 +122,7 @@ def test_conditional(scope, one, i, n, a_i):
assert all(isinstance(n, ir.Node) for n in cond.body)
assert isinstance(cond.else_body, tuple) and len(cond.else_body) == 1
assert all(isinstance(n, ir.Node) for n in cond.else_body)
assert cond.children == ( condition, (assign, assign), (assign,) )

with pytest.raises(FrozenInstanceError) as error:
cond.condition = parse_expr('k == 0', scope=scope)
Expand All @@ -116,8 +131,68 @@ def test_conditional(scope, one, i, n, a_i):
with pytest.raises(FrozenInstanceError) as error:
cond.else_body = (assign, assign, assign)

# Test errors for wrong contructor usage
with pytest.raises(ValidationError) as error:
ir.Conditional(condition=condition, body=assign)
# Test auto-casting of the body / else_body to tuple
cond = ir.Conditional(condition=condition, body=assign)
assert cond.body == (assign,) and cond.else_body == ()
cond = ir.Conditional(condition=condition, body=( (assign,), ))
assert cond.body == (assign,) and cond.else_body == ()
cond = ir.Conditional(condition=condition, body=( assign, (assign,), assign, None))
assert cond.body == (assign, assign, assign) and cond.else_body == ()

cond = ir.Conditional(condition=condition, body=(), else_body=assign)
assert cond.body == () and cond.else_body == (assign,)
cond = ir.Conditional(condition=condition, body=(), else_body=( (assign,), ))
assert cond.body == () and cond.else_body == (assign,)
cond = ir.Conditional(
condition=condition, body=(), else_body=( assign, (assign,), assign, None)
)
assert cond.body == () and cond.else_body == (assign, assign, assign)

# TODO: Test inline, name, has_elseif


def test_section(scope, one, i, n, a_n, a_i):
"""
Test constructors and behaviour of :any:`Section` nodes.
"""
assign = ir.Assignment(lhs=a_i, rhs=sym.Literal(42.0))
decl = ir.VariableDeclaration(symbols=(a_n,))
func = Subroutine(
name='F', is_function=True, spec=(decl,), body=(assign,)
)

# Test constructor for nodes and subroutine objects
sec = ir.Section(body=(assign, assign))
assert isinstance(sec.body, tuple) and len(sec.body) == 2
assert all(isinstance(n, ir.Node) for n in sec.body)
with pytest.raises(FrozenInstanceError) as error:
sec.body = (assign, assign)

sec = ir.Section(body=(func, func))
assert isinstance(sec.body, tuple) and len(sec.body) == 2
assert all(isinstance(n, Scope) for n in sec.body)
with pytest.raises(FrozenInstanceError) as error:
sec.body = (func, func)

sec = ir.Section((assign, assign))
assert sec.body == (assign, assign)

# Test auto-casting of the body to tuple
sec = ir.Section(body=assign)
assert sec.body == (assign,)
sec = ir.Section(body=( (assign,), ))
assert sec.body == (assign,)
sec = ir.Section(body=( assign, (assign,), assign, None))
assert sec.body == (assign, assign, assign)
sec = ir.Section((assign, (func,), assign, None))
assert sec.body == (assign, func, assign)

# Test prepend/insert/append additions
sec = ir.Section(body=func)
assert sec.body == (func,)
sec.prepend(assign)
assert sec.body == (assign, func)
sec.append((assign, assign))
assert sec.body == (assign, func, assign, assign)
sec.insert(pos=3, node=func)
assert sec.body == (assign, func, assign, func, assign)
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ dependencies = [
"coloredlogs", # optional for loki-build utility
"junit_xml", # optional for JunitXML output in loki-lint
"codetiming", # essential for scheduler and sourcefile timings
"pydantic", # type checking for IR nodes
"pydantic>=2.0", # type checking for IR nodes
]

[project.optional-dependencies]
Expand Down
Loading