diff --git a/.dict_custom.txt b/.dict_custom.txt index ae99f31ed4..5d99e21194 100644 --- a/.dict_custom.txt +++ b/.dict_custom.txt @@ -118,3 +118,4 @@ datatyping datatypes indexable traceback +GPUs diff --git a/CHANGELOG.md b/CHANGELOG.md index 1d99c60127..7c1dcffc55 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,8 @@ All notable changes to this project will be documented in this file. - #32 : Add support for `nvcc` Compiler and `cuda` language as a possible option. - #48 : Fix incorrect handling of imports in `cuda`. +- #42 : Add support for custom kernel in`cuda`. +- #42 : Add Cuda module to Pyccel. Add support for `cuda.synchronize` function. ## \[UNRELEASED\] diff --git a/docs/cuda.md b/docs/cuda.md new file mode 100644 index 0000000000..de30d52b80 --- /dev/null +++ b/docs/cuda.md @@ -0,0 +1,23 @@ +# Getting started GPU + +Pyccel now supports NVIDIA CUDA, empowering users to accelerate numerical computations on GPUs seamlessly. With Pyccel's high-level syntax and automatic code generation, harnessing the power of CUDA becomes effortless. This documentation provides a quick guide to enabling CUDA in Pyccel + +## Cuda Decorator + +### kernel + +The kernel decorator allows the user to declare a CUDA kernel. The kernel can be defined in Python, and the syntax is similar to that of Numba. + +```python +from pyccel.decorators import kernel + +@kernel +def my_kernel(): + pass + +blockspergrid = 1 +threadsperblock = 1 +# Call your kernel function +my_kernel[blockspergrid, threadsperblock]() + +``` \ No newline at end of file diff --git a/pyccel/ast/core.py b/pyccel/ast/core.py index 013f206dd6..f0e5cc67f1 100644 --- a/pyccel/ast/core.py +++ b/pyccel/ast/core.py @@ -73,6 +73,7 @@ 'If', 'IfSection', 'Import', + 'IndexedFunctionCall', 'InProgram', 'InlineFunctionDef', 'Interface', @@ -2065,6 +2066,42 @@ def _ignore(cls, c): """ return c is None or isinstance(c, (FunctionDef, *cls._ignored_types)) +class IndexedFunctionCall(FunctionCall): + """ + Represents an indexed function call in the code. + + Class representing indexed function calls, encapsulating all + relevant information for such calls within the code base. + + Parameters + ---------- + func : FunctionDef + The function being called. + + args : iterable of FunctionCallArgument + The arguments passed to the function. + + indexes : iterable of TypedAstNode + The indexes of the function call. + + current_function : FunctionDef, optional + The function where the call takes place. + """ + __slots__ = ('_indexes',) + _attribute_nodes = FunctionCall._attribute_nodes + ('_indexes',) + def __init__(self, func, args, indexes, current_function = None): + self._indexes = indexes + super().__init__(func, args, current_function) + + @property + def indexes(self): + """ + Indexes of function call. + + Represents the indexes of the function call + """ + return self._indexes + class ConstructorCall(FunctionCall): """ diff --git a/pyccel/ast/cuda.py b/pyccel/ast/cuda.py new file mode 100644 index 0000000000..f1e50ef7f0 --- /dev/null +++ b/pyccel/ast/cuda.py @@ -0,0 +1,65 @@ +# -*- coding: utf-8 -*- +#------------------------------------------------------------------------------------------# +# This file is part of Pyccel which is released under MIT License. See the LICENSE file or # +# go to https://github.com/pyccel/pyccel/blob/master/LICENSE for full license details. # +#------------------------------------------------------------------------------------------# +""" +CUDA Module +This module provides a collection of classes and utilities for CUDA programming. +""" +from pyccel.ast.core import FunctionCall + +__all__ = ( + 'KernelCall', +) + +class KernelCall(FunctionCall): + """ + Represents a kernel function call in the code. + + The class serves as a representation of a kernel + function call within the codebase. + + Parameters + ---------- + func : FunctionDef + The definition of the function being called. + + args : iterable of FunctionCallArgument + The arguments passed to the function. + + num_blocks : TypedAstNode + The number of blocks. These objects must have a primitive type of `PrimitiveIntegerType`. + + tp_block : TypedAstNode + The number of threads per block. These objects must have a primitive type of `PrimitiveIntegerType`. + + current_function : FunctionDef, optional + The function where the call takes place. + """ + __slots__ = ('_num_blocks','_tp_block') + _attribute_nodes = (*FunctionCall._attribute_nodes, '_num_blocks', '_tp_block') + + def __init__(self, func, args, num_blocks, tp_block, current_function = None): + self._num_blocks = num_blocks + self._tp_block = tp_block + super().__init__(func, args, current_function) + + @property + def num_blocks(self): + """ + The number of blocks in the kernel being called. + + The number of blocks in the kernel being called. + """ + return self._num_blocks + + @property + def tp_block(self): + """ + The number of threads per block. + + The number of threads per block. + """ + return self._tp_block + diff --git a/pyccel/ast/cudaext.py b/pyccel/ast/cudaext.py new file mode 100644 index 0000000000..b540f20993 --- /dev/null +++ b/pyccel/ast/cudaext.py @@ -0,0 +1,42 @@ +#!/usr/bin/python +# -*- coding: utf-8 -*- +#------------------------------------------------------------------------------------------# +# This file is part of Pyccel which is released under MIT License. See the LICENSE file or # +# go to https://github.com/pyccel/pyccel/blob/master/LICENSE for full license details. # +#------------------------------------------------------------------------------------------# +""" +CUDA Extension Module +Provides CUDA functionality for code generation. +""" +from .internals import PyccelFunction + +from .datatypes import VoidType +from .core import Module, PyccelFunctionDef + +__all__ = ( + 'CudaSynchronize', +) + +class CudaSynchronize(PyccelFunction): + """ + Represents a call to Cuda.synchronize for code generation. + + This class serves as a representation of the Cuda.synchronize method. + """ + __slots__ = () + _attribute_nodes = () + _shape = None + _class_type = VoidType() + def __init__(self): + super().__init__() + +cuda_funcs = { + 'synchronize' : PyccelFunctionDef('synchronize' , CudaSynchronize), +} + +cuda_mod = Module('cuda', + variables=[], + funcs=cuda_funcs.values(), + imports=[] +) + diff --git a/pyccel/ast/utilities.py b/pyccel/ast/utilities.py index 1e6c0422ab..e5cd77b168 100644 --- a/pyccel/ast/utilities.py +++ b/pyccel/ast/utilities.py @@ -25,6 +25,7 @@ from .literals import LiteralInteger, LiteralEllipsis, Nil from .mathext import math_mod from .sysext import sys_mod +from .cudaext import cuda_mod from .numpyext import (NumpyEmpty, NumpyArray, numpy_mod, NumpyTranspose, NumpyLinspace) @@ -49,7 +50,8 @@ decorators_mod = Module('decorators',(), funcs = [PyccelFunctionDef(d, PyccelFunction) for d in pyccel_decorators.__all__]) pyccel_mod = Module('pyccel',(),(), - imports = [Import('decorators', decorators_mod)]) + imports = [Import('decorators', decorators_mod), + Import('cuda', cuda_mod)]) # TODO add documentation builtin_import_registry = Module('__main__', diff --git a/pyccel/codegen/printing/cucode.py b/pyccel/codegen/printing/cucode.py index 277d2a3a6a..cd26843017 100644 --- a/pyccel/codegen/printing/cucode.py +++ b/pyccel/codegen/printing/cucode.py @@ -9,11 +9,12 @@ enabling the direct translation of high-level Pyccel expressions into CUDA code. """ -from pyccel.codegen.printing.ccode import CCodePrinter, c_library_headers +from pyccel.codegen.printing.ccode import CCodePrinter -from pyccel.ast.core import Import, Module +from pyccel.ast.core import Import, Module +from pyccel.ast.literals import Nil -from pyccel.errors.errors import Errors +from pyccel.errors.errors import Errors errors = Errors() @@ -61,6 +62,44 @@ def _print_Module(self, expr): self.exit_scope() return code + def function_signature(self, expr, print_arg_names = True): + """ + Get the Cuda representation of the function signature. + + Extract from the function definition `expr` all the + information (name, input, output) needed to create the + function signature and return a string describing the + function. + This is not a declaration as the signature does not end + with a semi-colon. + + Parameters + ---------- + expr : FunctionDef + The function definition for which a signature is needed. + + print_arg_names : bool, default : True + Indicates whether argument names should be printed. + + Returns + ------- + str + Signature of the function. + """ + cuda_decorater = '__global__' if 'kernel' in expr.decorators else '' + c_function_signature = super().function_signature(expr, print_arg_names) + return f'{cuda_decorater} {c_function_signature}' + + def _print_KernelCall(self, expr): + func = expr.funcdef + args = [a.value or Nil() for a in expr.args] + + args = ', '.join(self._print(a) for a in args) + return f"{func.name}<<<{expr.num_blocks}, {expr.tp_block}>>>({args});\n" + + def _print_CudaSynchronize(self, expr): + return 'cudaDeviceSynchronize();\n' + def _print_ModuleHeader(self, expr): self.set_scope(expr.module.scope) self._in_header = True @@ -87,6 +126,7 @@ def _print_ModuleHeader(self, expr): }}\n' return '\n'.join((f"#ifndef {name.upper()}_H", f"#define {name.upper()}_H", + imports, global_variables, function_declaration, "#endif // {name.upper()}_H\n")) diff --git a/pyccel/cuda/__init__.py b/pyccel/cuda/__init__.py new file mode 100644 index 0000000000..e8542ad5d5 --- /dev/null +++ b/pyccel/cuda/__init__.py @@ -0,0 +1,10 @@ +#------------------------------------------------------------------------------------------# +# This file is part of Pyccel which is released under MIT License. See the LICENSE file or # +# go to https://github.com/pyccel/pyccel/blob/master/LICENSE for full license details. # +#------------------------------------------------------------------------------------------# +""" + This module is for exposing the CudaSubmodule functions. +""" +from .cuda_sync_primitives import synchronize + +__all__ = ['synchronize'] diff --git a/pyccel/cuda/cuda_sync_primitives.py b/pyccel/cuda/cuda_sync_primitives.py new file mode 100644 index 0000000000..f3442fe9e2 --- /dev/null +++ b/pyccel/cuda/cuda_sync_primitives.py @@ -0,0 +1,16 @@ +#------------------------------------------------------------------------------------------# +# This file is part of Pyccel which is released under MIT License. See the LICENSE file or # +# go to https://github.com/pyccel/pyccel/blob/master/LICENSE for full license details. # +#------------------------------------------------------------------------------------------# +""" +This submodule contains CUDA methods for Pyccel. +""" + + +def synchronize(): + """ + Synchronize CUDA device execution. + + Synchronize CUDA device execution. + """ + diff --git a/pyccel/decorators.py b/pyccel/decorators.py index 1f640043db..77717a991f 100644 --- a/pyccel/decorators.py +++ b/pyccel/decorators.py @@ -19,6 +19,7 @@ 'sympy', 'template', 'types', + 'kernel' ) @@ -109,3 +110,34 @@ def allow_negative_index(f,*args): def identity(f): return f return identity + +def kernel(f): + """ + Decorator for marking a Python function as a kernel. + + This class serves as a decorator to mark a Python function + as a kernel function, typically used for GPU computations. + This allows the function to be indexed with the number of blocks and threads. + + Parameters + ---------- + f : function + The function to which the decorator is applied. + + Returns + ------- + KernelAccessor + A class representing the kernel function. + """ + class KernelAccessor: + """ + Class representing the kernel function. + + Class representing the kernel function. + """ + def __init__(self, f): + self._f = f + def __getitem__(self, args): + return self._f + + return KernelAccessor(f) diff --git a/pyccel/errors/messages.py b/pyccel/errors/messages.py index 79eccc1df2..09966d810c 100644 --- a/pyccel/errors/messages.py +++ b/pyccel/errors/messages.py @@ -162,3 +162,11 @@ WRONG_LINSPACE_ENDPOINT = 'endpoint argument must be boolean' NON_LITERAL_KEEP_DIMS = 'keep_dims argument must be a literal, otherwise rank is unknown' NON_LITERAL_AXIS = 'axis argument must be a literal, otherwise pyccel cannot determine which dimension to operate on' +MISSING_KERNEL_CONFIGURATION = 'Kernel launch configuration not specified' +INVALID_KERNEL_LAUNCH_CONFIG = 'Expected exactly 2 parameters for kernel launch' +INVALID_KERNEL_CALL_BP_GRID = 'Invalid Block per grid parameter for Kernel call' +INVALID_KERNEL_CALL_TP_BLOCK = 'Invalid Thread per Block parameter for Kernel call' + + + + diff --git a/pyccel/parser/semantic.py b/pyccel/parser/semantic.py index e94b9c8413..fde10d6317 100644 --- a/pyccel/parser/semantic.py +++ b/pyccel/parser/semantic.py @@ -116,6 +116,8 @@ from pyccel.ast.variable import IndexedElement, AnnotatedPyccelSymbol from pyccel.ast.variable import DottedName, DottedVariable +from pyccel.ast.cuda import KernelCall + from pyccel.errors.errors import Errors from pyccel.errors.errors import PyccelSemanticError @@ -133,7 +135,9 @@ PYCCEL_RESTRICTION_LIST_COMPREHENSION_LIMITS, PYCCEL_RESTRICTION_LIST_COMPREHENSION_SIZE, UNUSED_DECORATORS, UNSUPPORTED_POINTER_RETURN_VALUE, PYCCEL_RESTRICTION_OPTIONAL_NONE, PYCCEL_RESTRICTION_PRIMITIVE_IMMUTABLE, PYCCEL_RESTRICTION_IS_ISNOT, - FOUND_DUPLICATED_IMPORT, UNDEFINED_WITH_ACCESS, MACRO_MISSING_HEADER_OR_FUNC) + FOUND_DUPLICATED_IMPORT, UNDEFINED_WITH_ACCESS, MACRO_MISSING_HEADER_OR_FUNC, PYCCEL_RESTRICTION_INHOMOG_SET, + MISSING_KERNEL_CONFIGURATION, + INVALID_KERNEL_LAUNCH_CONFIG, INVALID_KERNEL_CALL_BP_GRID, INVALID_KERNEL_CALL_TP_BLOCK) from pyccel.parser.base import BasicParser from pyccel.parser.syntactic import SyntaxParser @@ -1139,6 +1143,67 @@ def _handle_function(self, expr, func, args, *, is_method = False, use_build_fun return new_expr + def _handle_kernel(self, expr, func, args): + """ + Create the node representing the kernel function call. + + Create a FunctionCall or an instance of a PyccelInternalFunction + from the function information and arguments. + + Parameters + ---------- + expr : IndexedFunctionCall + Node has all the information about the function call. + + func : FunctionDef | Interface | PyccelInternalFunction type + The function being called. + + args : iterable of FunctionCallArgument + The arguments passed to the function. + + Returns + ------- + Pyccel.ast.cuda.KernelCall + The semantic representation of the kernel call. + """ + if len(expr.indexes) != 2: + errors.report(INVALID_KERNEL_LAUNCH_CONFIG, + symbol=expr, + severity='fatal') + if len(func.results): + errors.report(f"cuda kernel function '{func.name}' returned a value in violation of the laid-down specification", + symbol=expr, + severity='fatal') + if isinstance(func, FunctionDef) and len(args) != len(func.arguments): + errors.report(f"{len(args)} argument types given, but function takes {len(func.arguments)} arguments", + symbol=expr, + severity='fatal') + if not isinstance(expr.indexes[0], (LiteralInteger)): + if isinstance(expr.indexes[0], PyccelSymbol): + num_blocks = self.get_variable(expr.indexes[0]) + + if not isinstance(num_blocks.dtype, PythonNativeInt): + errors.report(INVALID_KERNEL_CALL_BP_GRID, + symbol = expr, + severity='fatal') + else: + errors.report(INVALID_KERNEL_CALL_BP_GRID, + symbol = expr, + severity='fatal') + if not isinstance(expr.indexes[1], (LiteralInteger)): + if isinstance(expr.indexes[1], PyccelSymbol): + tp_block = self.get_variable(expr.indexes[1]) + if not isinstance(tp_block.dtype, PythonNativeInt): + errors.report(INVALID_KERNEL_CALL_TP_BLOCK, + symbol = expr, + severity='fatal') + else: + errors.report(INVALID_KERNEL_CALL_TP_BLOCK, + symbol = expr, + severity='fatal') + new_expr = KernelCall(func, args, expr.indexes[0], expr.indexes[1]) + return new_expr + def _sort_function_call_args(self, func_args, args): """ Sort and add the missing call arguments to match the arguments in the function definition. @@ -2815,6 +2880,23 @@ def _visit_Lambda(self, expr): expr = Lambda(tuple(expr.variables), expr_new) return expr + def _visit_IndexedFunctionCall(self, expr): + name = expr.funcdef + name = self.scope.get_expected_name(name) + func = self.scope.find(name, 'functions') + args = self._handle_function_args(expr.args) + + if func is None: + return errors.report(UNDEFINED_FUNCTION, symbol=expr.funcdef, + bounding_box=(self.current_ast_node.lineno, self.current_ast_node.col_offset), + severity='fatal') + + func = self._annotate_the_called_function_def(func) + if 'kernel' in func.decorators : + return self._handle_kernel(expr, func, args) + else: + return errors.report("Unknown function type", + symbol=expr, severity='fatal') def _visit_FunctionCall(self, expr): name = expr.funcdef try: diff --git a/pyccel/parser/syntactic.py b/pyccel/parser/syntactic.py index 2967f4999b..3af7f0728a 100644 --- a/pyccel/parser/syntactic.py +++ b/pyccel/parser/syntactic.py @@ -64,6 +64,8 @@ from pyccel.ast.type_annotations import SyntacticTypeAnnotation, UnionTypeAnnotation +from pyccel.ast.core import IndexedFunctionCall + from pyccel.parser.base import BasicParser from pyccel.parser.extend_tree import extend_tree from pyccel.parser.utilities import get_default_path @@ -1102,6 +1104,8 @@ def _visit_Call(self, stmt): elif isinstance(func, DottedName): func_attr = FunctionCall(func.name[-1], args) func = DottedName(*func.name[:-1], func_attr) + elif isinstance(func,IndexedElement): + func = IndexedFunctionCall(func.base, args, func.indices) else: raise NotImplementedError(f' Unknown function type {type(func)}') diff --git a/tests/conftest.py b/tests/conftest.py index a5082ef6e8..4e74d1ec7a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -59,6 +59,15 @@ def pytest_runtest_teardown(item, nextitem): def pytest_addoption(parser): parser.addoption("--developer-mode", action="store_true", default=github_debugging, help="Show tracebacks when pyccel errors are raised") + parser.addoption("--gpu_available", action="store_true", + default=False, help="enable GPU tests") + +def pytest_generate_tests(metafunc): + if "gpu_available" in metafunc.fixturenames: + if metafunc.config.getoption("gpu_available"): + metafunc.parametrize("gpu_available", [True]) + else: + metafunc.parametrize("gpu_available", [False]) def pytest_sessionstart(session): # setup_stuff diff --git a/tests/cuda/test_kernel_semantic.py b/tests/cuda/test_kernel_semantic.py new file mode 100644 index 0000000000..00b74c3bea --- /dev/null +++ b/tests/cuda/test_kernel_semantic.py @@ -0,0 +1,176 @@ +# pylint: disable=missing-function-docstring, missing-module-docstring +import pytest + +from pyccel import epyccel +from pyccel.decorators import kernel +from pyccel.errors.errors import Errors, PyccelSemanticError +from pyccel.errors.messages import (INVALID_KERNEL_CALL_TP_BLOCK, + INVALID_KERNEL_CALL_BP_GRID, + INVALID_KERNEL_LAUNCH_CONFIG) + + +@pytest.mark.cuda +def test_invalid_block_number(): + def invalid_block_number(): + @kernel + def kernel_call(): + pass + + blocks_per_grid = 1.0 + threads_per_block = 1 + kernel_call[blocks_per_grid, threads_per_block]() + + errors = Errors() + + with pytest.raises(PyccelSemanticError): + epyccel(invalid_block_number, language="cuda") + + assert errors.has_errors() + + assert errors.num_messages() == 1 + + error_info = [*errors.error_info_map.values()][0][0] + assert error_info.symbol.funcdef == 'kernel_call' + assert INVALID_KERNEL_CALL_BP_GRID == error_info.message + + +@pytest.mark.cuda +def test_invalid_thread_per_block(): + def invalid_thread_per_block(): + @kernel + def kernel_call(): + pass + + blocks_per_grid = 1 + threads_per_block = 1.0 + kernel_call[blocks_per_grid, threads_per_block]() + + errors = Errors() + + with pytest.raises(PyccelSemanticError): + epyccel(invalid_thread_per_block, language="cuda") + assert errors.has_errors() + assert errors.num_messages() == 1 + error_info = [*errors.error_info_map.values()][0][0] + assert error_info.symbol.funcdef == 'kernel_call' + assert INVALID_KERNEL_CALL_TP_BLOCK == error_info.message + + +@pytest.mark.cuda +def test_invalid_launch_config_high(): + def invalid_launch_config_high(): + @kernel + def kernel_call(): + pass + + blocks_per_grid = 1 + threads_per_block = 1 + third_param = 1 + kernel_call[blocks_per_grid, threads_per_block, third_param]() + + errors = Errors() + + with pytest.raises(PyccelSemanticError): + epyccel(invalid_launch_config_high, language="cuda") + + assert errors.has_errors() + assert errors.num_messages() == 1 + + error_info = [*errors.error_info_map.values()][0][0] + assert error_info.symbol.funcdef == 'kernel_call' + assert INVALID_KERNEL_LAUNCH_CONFIG == error_info.message + + +@pytest.mark.cuda +def test_invalid_launch_config_low(): + def invalid_launch_config_low(): + @kernel + def kernel_call(): + pass + + blocks_per_grid = 1 + kernel_call[blocks_per_grid]() + + errors = Errors() + + with pytest.raises(PyccelSemanticError): + epyccel(invalid_launch_config_low, language="cuda") + + assert errors.has_errors() + assert errors.num_messages() == 1 + + error_info = [*errors.error_info_map.values()][0][0] + assert error_info.symbol.funcdef == 'kernel_call' + assert INVALID_KERNEL_LAUNCH_CONFIG == error_info.message + + +@pytest.mark.cuda +def test_invalid_arguments_for_kernel_call(): + def invalid_arguments(): + @kernel + def kernel_call(arg : int): + pass + + blocks_per_grid = 1 + threads_per_block = 1 + kernel_call[blocks_per_grid, threads_per_block]() + + errors = Errors() + + with pytest.raises(PyccelSemanticError): + epyccel(invalid_arguments, language="cuda") + + assert errors.has_errors() + assert errors.num_messages() == 1 + + error_info = [*errors.error_info_map.values()][0][0] + assert error_info.symbol.funcdef == 'kernel_call' + assert "0 argument types given, but function takes 1 arguments" == error_info.message + + +@pytest.mark.cuda +def test_invalid_arguments_for_kernel_call_2(): + def invalid_arguments_(): + @kernel + def kernel_call(): + pass + + blocks_per_grid = 1 + threads_per_block = 1 + kernel_call[blocks_per_grid, threads_per_block](1) + + errors = Errors() + + with pytest.raises(PyccelSemanticError): + epyccel(invalid_arguments_, language="cuda") + + assert errors.has_errors() + assert errors.num_messages() == 1 + + error_info = [*errors.error_info_map.values()][0][0] + assert error_info.symbol.funcdef == 'kernel_call' + assert "1 argument types given, but function takes 0 arguments" == error_info.message + + +@pytest.mark.cuda +def test_kernel_return(): + def kernel_return(): + @kernel + def kernel_call(): + return 7 + + blocks_per_grid = 1 + threads_per_block = 1 + kernel_call[blocks_per_grid, threads_per_block](1) + + errors = Errors() + + with pytest.raises(PyccelSemanticError): + epyccel(kernel_return, language="cuda") + + assert errors.has_errors() + assert errors.num_messages() == 1 + + error_info = [*errors.error_info_map.values()][0][0] + assert error_info.symbol.funcdef == 'kernel_call' + assert "cuda kernel function 'kernel_call' returned a value in violation of the laid-down specification" == error_info.message diff --git a/tests/pyccel/scripts/kernel/hello_kernel.py b/tests/pyccel/scripts/kernel/hello_kernel.py new file mode 100644 index 0000000000..b6901b25a1 --- /dev/null +++ b/tests/pyccel/scripts/kernel/hello_kernel.py @@ -0,0 +1,19 @@ +# pylint: disable=missing-function-docstring, missing-module-docstring +from pyccel.decorators import kernel +from pyccel import cuda + +@kernel +def say_hello(its_morning : bool): + if(its_morning): + print("Hello and Good morning") + else: + print("Hello and Good afternoon") + +def f(): + its_morning = True + say_hello[1,1](its_morning) + cuda.synchronize() + +if __name__ == '__main__': + f() + diff --git a/tests/pyccel/scripts/kernel/kernel_name_collision.py b/tests/pyccel/scripts/kernel/kernel_name_collision.py new file mode 100644 index 0000000000..ac7abe25ae --- /dev/null +++ b/tests/pyccel/scripts/kernel/kernel_name_collision.py @@ -0,0 +1,8 @@ +# pylint: disable=missing-function-docstring, missing-module-docstring +from pyccel.decorators import kernel + +@kernel +def do(): + pass + +do[1,1]() diff --git a/tests/pyccel/test_pyccel.py b/tests/pyccel/test_pyccel.py index ec1e846549..b4757a3c31 100644 --- a/tests/pyccel/test_pyccel.py +++ b/tests/pyccel/test_pyccel.py @@ -294,7 +294,7 @@ def compare_pyth_fort_output( p_output, f_output, dtype=float, language=None): #------------------------------------------------------------------------------ def pyccel_test(test_file, dependencies = None, compile_with_pyccel = True, cwd = None, pyccel_commands = "", output_dtype = float, - language = None, output_dir = None): + language = None, output_dir = None, execute_code = True): """ Run pyccel and compare the output to ensure that the results are equivalent @@ -394,13 +394,14 @@ def pyccel_test(test_file, dependencies = None, compile_with_pyccel = True, compile_fortran(cwd, output_test_file, dependencies) elif language == 'c': compile_c(cwd, output_test_file, dependencies) - - lang_out = get_lang_output(output_test_file, language) - compare_pyth_fort_output(pyth_out, lang_out, output_dtype, language) + if execute_code: + lang_out = get_lang_output(output_test_file, language) + compare_pyth_fort_output(pyth_out, lang_out, output_dtype, language) #============================================================================== # UNIT TESTS #============================================================================== + def test_relative_imports_in_project(language): base_dir = os.path.dirname(os.path.realpath(__file__)) @@ -728,6 +729,19 @@ def test_multiple_results(language): def test_elemental(language): pyccel_test("scripts/decorators_elemental.py", language = language) +#------------------------------------------------------------------------------ +@pytest.mark.cuda +def test_hello_kernel(gpu_available): + types = str + pyccel_test("scripts/kernel/hello_kernel.py", + language="cuda", output_dtype=types , execute_code=gpu_available) + +#------------------------------------------------------------------------------ +@pytest.mark.cuda +def test_kernel_collision(gpu_available): + pyccel_test("scripts/kernel/kernel_name_collision.py", + language="cuda", execute_code=gpu_available) + #------------------------------------------------------------------------------ def test_print_strings(language): types = str