Skip to content

Commit

Permalink
Pass the shape of ArrayArg to loopy
Browse files Browse the repository at this point in the history
  • Loading branch information
thilinarmtb committed Dec 10, 2023
1 parent 77f4f2c commit c16ce57
Showing 1 changed file with 24 additions and 19 deletions.
43 changes: 24 additions & 19 deletions python/loopy_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -427,6 +427,28 @@ def set_tf_results(
return true_result, false_result


def check_and_parse_decl(expr: cindex.CursorKind):
"""Check and parse a C variable declaration."""
name = expr.spelling
shape = ()
init = None
children = list(expr.get_children())

if expr.type.kind in _ARRAY_TYPES:
dims = []
for child in children:
if child.kind == cindex.CursorKind.INIT_LIST_EXPR:
init = child
else:
dims.append(child)
shape = tuple(CToLoopyExpressionMapper()(dim) for dim in dims)
return (name, shape, init)
if isinstance(expr.type.kind, cindex.TypeKind):
if len(children) == 1:
init = children[0]
return (name, shape, init)


class CToLoopyMapper(IdentityMapper):
"""Map C expressions and statemements to Loopy expressions."""

Expand Down Expand Up @@ -508,25 +530,6 @@ def map_var_decl(
) -> CToLoopyMapperAccumulator:
"""Maps a C variable declaration."""

def check_and_parse_decl(expr: cindex.CursorKind):
name = expr.spelling
children = list(expr.get_children())

init = None
if expr.type.kind in _ARRAY_TYPES:
dims = []
for child in children:
if child.kind == cindex.CursorKind.INIT_LIST_EXPR:
init = child
else:
dims.append(child)
shape = tuple(CToLoopyExpressionMapper()(dim) for dim in dims)
return (name, shape, init)
if isinstance(expr.type.kind, cindex.TypeKind):
if len(children) == 1:
init = children[0]
return (name, (), init)

(name, shape, init) = check_and_parse_decl(expr)

if name.startswith(NOMP_VAR_PREFIX):
Expand Down Expand Up @@ -733,11 +736,13 @@ def get_knl_args(self) -> list[lp.KernelArgument]:
for arg in self.var_to_decl.values():
dtype = _get_dtype_from_decl_type(arg.type)
if arg.type.kind in _ARRAY_TYPES_W_PTR:
(_, shape, _) = check_and_parse_decl(arg)
knl_args.append(
lp.ArrayArg(
arg.spelling,
dtype=dtype,
address_space=AddressSpace.GLOBAL,
shape=shape if shape else None,
)
)
elif isinstance(arg.type.kind, cindex.TypeKind):
Expand Down

0 comments on commit c16ce57

Please sign in to comment.