diff --git a/sdk/python/kfp/cli/cli_test.py b/sdk/python/kfp/cli/cli_test.py
index 361db73a14e..bcec7798a64 100644
--- a/sdk/python/kfp/cli/cli_test.py
+++ b/sdk/python/kfp/cli/cli_test.py
@@ -166,6 +166,82 @@ def test_deprecation_warning(self):
             res.stdout.decode('utf-8'))
 
 
+class TestKfpDslCompile(unittest.TestCase):
+
+    def invoke(self, args):
+        starting_args = ['dsl', 'compile']
+        args = starting_args + args
+        runner = testing.CliRunner()
+        return runner.invoke(
+            cli=cli.cli, args=args, catch_exceptions=False, obj={})
+
+    def test_compile_with_caching_flag_enabled(self):
+        with tempfile.NamedTemporaryFile(suffix='.py') as temp_pipeline:
+            # Write the pipeline function to a temporary file
+            temp_pipeline.write(b"""
+from kfp import dsl
+
+@dsl.component
+def my_component():
+    pass
+
+@dsl.pipeline(name="tiny-pipeline")
+def my_pipeline():
+    my_component_task = my_component()
+""")
+            temp_pipeline.flush()  # Ensure the data is written to disk
+
+            # Invoke the CLI command with the temporary file
+            result = self.invoke(
+                ['--py', temp_pipeline.name, '--output', 'test_output.yaml'])
+            print(result.output)  # Print the command output
+            self.assertEqual(result.exit_code, 0)
+
+    def test_compile_with_caching_flag_disabled(self):
+        with tempfile.NamedTemporaryFile(suffix='.py') as temp_pipeline:
+            temp_pipeline.write(b"""
+from kfp import dsl
+
+@dsl.component
+def my_component():
+    pass
+
+@dsl.pipeline(name="tiny-pipeline")
+def my_pipeline():
+    my_component_task = my_component()
+""")
+            temp_pipeline.flush()
+
+            result = self.invoke([
+                '--py', temp_pipeline.name, '--output', 'test_output.yaml',
+                '--disable-execution-caching-by-default'
+            ])
+            print(result.output)
+            self.assertEqual(result.exit_code, 0)
+
+    def test_compile_with_caching_disabled_env_var(self):
+        with tempfile.NamedTemporaryFile(suffix='.py') as temp_pipeline:
+            temp_pipeline.write(b"""
+from kfp import dsl
+
+@dsl.component
+def my_component():
+    pass
+
+@dsl.pipeline(name="tiny-pipeline")
+def my_pipeline():
+    my_component_task = my_component()
+""")
+            temp_pipeline.flush()
+
+            os.environ['KFP_DISABLE_EXECUTION_CACHING_BY_DEFAULT'] = 'true'
+            result = self.invoke(
+                ['--py', temp_pipeline.name, '--output', 'test_output.yaml'])
+            print(result.output)
+            self.assertEqual(result.exit_code, 0)
+            del os.environ['KFP_DISABLE_EXECUTION_CACHING_BY_DEFAULT']
+
+
 info_dict = cli.cli.to_info_dict(ctx=click.Context(cli.cli))
 commands_dict = {
     command: list(body.get('commands', {}).keys())
diff --git a/sdk/python/kfp/cli/compile_.py b/sdk/python/kfp/cli/compile_.py
index f4b2c3b4570..e1fc28a8328 100644
--- a/sdk/python/kfp/cli/compile_.py
+++ b/sdk/python/kfp/cli/compile_.py
@@ -26,6 +26,7 @@
 from kfp.dsl import graph_component
 from kfp.dsl.pipeline_context import Pipeline
 
+
 def is_pipeline_func(func: Callable) -> bool:
     """Checks if a function is a pipeline function.
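Review note (not part of the diff): `test_compile_with_caching_disabled_env_var` sets `KFP_DISABLE_EXECUTION_CACHING_BY_DEFAULT` directly in `os.environ` and only deletes it after the final assertion, so a failing assertion would leak the variable into later tests. Below is a minimal, self-contained sketch of an alternative using `unittest.mock.patch.dict`; the test class and pipeline-source constant names are hypothetical, while the CLI arguments are the same ones exercised above.

```python
import os
import tempfile
import unittest
from unittest import mock

from click import testing

from kfp.cli import cli

# Same tiny pipeline source the tests above write to a temp file.
_PIPELINE_SRC = b"""
from kfp import dsl

@dsl.component
def my_component():
    pass

@dsl.pipeline(name="tiny-pipeline")
def my_pipeline():
    my_component_task = my_component()
"""


class TestCachingEnvVarIsolated(unittest.TestCase):
    # Hypothetical test, not part of the diff.

    # patch.dict restores os.environ when the test exits, even if an
    # assertion fails, so the flag cannot leak into other tests.
    @mock.patch.dict(os.environ,
                     {'KFP_DISABLE_EXECUTION_CACHING_BY_DEFAULT': 'true'})
    def test_env_var_disables_caching_by_default(self):
        with tempfile.NamedTemporaryFile(suffix='.py') as temp_pipeline:
            temp_pipeline.write(_PIPELINE_SRC)
            temp_pipeline.flush()

            runner = testing.CliRunner()
            result = runner.invoke(
                cli=cli.cli,
                args=[
                    'dsl', 'compile', '--py', temp_pipeline.name, '--output',
                    'test_output.yaml'
                ],
                catch_exceptions=False,
                obj={})
            self.assertEqual(result.exit_code, 0)
```

`patch.dict` also puts back any value the variable already had before the test, which the manual `del` cannot do.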
diff --git a/sdk/python/kfp/compiler/compiler_test.py b/sdk/python/kfp/compiler/compiler_test.py
index 7f0cfd4b98a..16ebe9e655d 100644
--- a/sdk/python/kfp/compiler/compiler_test.py
+++ b/sdk/python/kfp/compiler/compiler_test.py
@@ -910,6 +910,59 @@ def my_pipeline() -> NamedTuple('Outputs', [
         task = print_and_return(text='Hello')
 
 
+class TestCompilePipelineCaching(unittest.TestCase):
+
+    def test_compile_pipeline_with_caching_enabled(self):
+        """Test pipeline compilation with caching enabled."""
+
+        @dsl.component
+        def my_component():
+            pass
+
+        @dsl.pipeline(name='tiny-pipeline')
+        def my_pipeline():
+            my_task = my_component()
+            my_task.set_caching_options(True)
+
+        with tempfile.TemporaryDirectory() as tempdir:
+            output_yaml = os.path.join(tempdir, 'pipeline.yaml')
+            compiler.Compiler().compile(
+                pipeline_func=my_pipeline, package_path=output_yaml)
+
+            with open(output_yaml, 'r') as f:
+                pipeline_spec = yaml.safe_load(f)
+
+            task_spec = pipeline_spec['root']['dag']['tasks']['my-component']
+            caching_options = task_spec['cachingOptions']
+
+            self.assertTrue(caching_options['enableCache'])
+
+    def test_compile_pipeline_with_caching_disabled(self):
+        """Test pipeline compilation with caching disabled."""
+
+        @dsl.component
+        def my_component():
+            pass
+
+        @dsl.pipeline(name='tiny-pipeline')
+        def my_pipeline():
+            my_task = my_component()
+            my_task.set_caching_options(False)
+
+        with tempfile.TemporaryDirectory() as tempdir:
+            output_yaml = os.path.join(tempdir, 'pipeline.yaml')
+            compiler.Compiler().compile(
+                pipeline_func=my_pipeline, package_path=output_yaml)
+
+            with open(output_yaml, 'r') as f:
+                pipeline_spec = yaml.safe_load(f)
+
+            task_spec = pipeline_spec['root']['dag']['tasks']['my-component']
+            caching_options = task_spec.get('cachingOptions', {})
+
+            self.assertEqual(caching_options, {})
+
+
 class V2NamespaceAliasTest(unittest.TestCase):
     """Test that imports of both modules and objects are aliased (e.g. all
     import path variants work)."""
diff --git a/sdk/python/kfp/dsl/base_component.py b/sdk/python/kfp/dsl/base_component.py
index 089a1111637..e1693ece3e3 100644
--- a/sdk/python/kfp/dsl/base_component.py
+++ b/sdk/python/kfp/dsl/base_component.py
@@ -101,9 +101,10 @@ def __call__(self, *args, **kwargs) -> pipeline_task.PipelineTask:
         return pipeline_task.PipelineTask(
             component_spec=self.component_spec,
             args=task_inputs,
-            execute_locally=pipeline_context.Pipeline.get_default_pipeline() is
-            None,
-            execution_caching_default=pipeline_context.Pipeline.get_execution_caching_default(),
+            execute_locally=pipeline_context.Pipeline.get_default_pipeline()
+            is None,
+            execution_caching_default=pipeline_context.Pipeline
+            .get_execution_caching_default(),
         )
 
     @property
diff --git a/sdk/python/kfp/dsl/pipeline_context.py b/sdk/python/kfp/dsl/pipeline_context.py
index f9a45d8676c..4d0bbbaa840 100644
--- a/sdk/python/kfp/dsl/pipeline_context.py
+++ b/sdk/python/kfp/dsl/pipeline_context.py
@@ -14,6 +14,7 @@
 """Definition for Pipeline."""
 
 import functools
+import os
 from typing import Callable, Optional
 
 from kfp.dsl import component_factory
@@ -21,8 +22,6 @@
 from kfp.dsl import tasks_group
 from kfp.dsl import utils
 
-import os
-
 
 def pipeline(func: Optional[Callable] = None,
              *,
@@ -107,7 +106,9 @@ def get_default_pipeline():
     # or the env var KFP_DISABLE_EXECUTION_CACHING_BY_DEFAULT.
     # align with click's treatment of env vars for boolean flags.
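Review note (not part of the diff): the compiler tests above pin down `set_caching_options(True)` and `set_caching_options(False)` one task at a time. A standalone sketch, using only APIs that already appear in this diff, of how the pipeline-wide default and an explicit per-task override can be compared in a single compiled spec; the `root.dag.tasks.*.cachingOptions.enableCache` layout is taken from the assertions above rather than re-derived from the spec.

```python
import os
import tempfile

import yaml

from kfp import compiler
from kfp import dsl


@dsl.component
def my_component():
    pass


@dsl.pipeline(name='tiny-pipeline')
def my_pipeline():
    my_component()  # left at the pipeline-wide caching default
    no_cache_task = my_component()
    no_cache_task.set_caching_options(False)  # explicit per-task override


if __name__ == '__main__':
    with tempfile.TemporaryDirectory() as tempdir:
        output_yaml = os.path.join(tempdir, 'pipeline.yaml')
        compiler.Compiler().compile(
            pipeline_func=my_pipeline, package_path=output_yaml)

        with open(output_yaml) as f:
            tasks = yaml.safe_load(f)['root']['dag']['tasks']

        # Expected, per the assertions above: the untouched task carries
        # cachingOptions.enableCache when the default is caching-on, while
        # the overridden task compiles with cachingOptions empty or absent.
        for name, task in tasks.items():
            print(name, task.get('cachingOptions', {}))
```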
     # per click doc, "1", "true", "t", "yes", "y", and "on" are all converted to True
-    _execution_caching_default = not str(os.getenv('KFP_DISABLE_EXECUTION_CACHING_BY_DEFAULT')).strip().lower() in {"1", "true", "t", "yes", "y", "on"}
+    _execution_caching_default = not str(
+        os.getenv('KFP_DISABLE_EXECUTION_CACHING_BY_DEFAULT')).strip().lower(
+        ) in {'1', 'true', 't', 'yes', 'y', 'on'}
 
     @staticmethod
     def get_execution_caching_default():
diff --git a/sdk/python/kfp/dsl/pipeline_task.py b/sdk/python/kfp/dsl/pipeline_task.py
index 752a6cbdf84..46f0679eeda 100644
--- a/sdk/python/kfp/dsl/pipeline_task.py
+++ b/sdk/python/kfp/dsl/pipeline_task.py
@@ -162,13 +162,14 @@ def validate_placeholder_types(
         self.pipeline_spec = self.component_spec.implementation.graph
 
         self._outputs = {
-            output_name: pipeline_channel.create_pipeline_channel(
-                name=output_name,
-                channel_type=output_spec.type,
-                task_name=self._task_spec.name,
-                is_artifact_list=output_spec.is_artifact_list,
-            ) for output_name, output_spec in (
-                component_spec.outputs or {}).items()
+            output_name:
+                pipeline_channel.create_pipeline_channel(
+                    name=output_name,
+                    channel_type=output_spec.type,
+                    task_name=self._task_spec.name,
+                    is_artifact_list=output_spec.is_artifact_list,
+                ) for output_name, output_spec in (
+                    component_spec.outputs or {}).items()
         }
 
         self._inputs = args
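Review note (not part of the diff): in the rewritten `_execution_caching_default` expression, `not x in y` binds as `not (x in y)`, so the logic is correct, although `x not in y` reads more directly; note also that an unset variable goes through `str(None)`, which is not in the truthy set, so caching stays enabled by default. A small hypothetical helper spelling out the same click-style truthiness rules:

```python
import os

# Click-style truthy strings: "1", "true", "t", "yes", "y", "on"
# (case-insensitive, surrounding whitespace ignored).
_TRUTHY = {'1', 'true', 't', 'yes', 'y', 'on'}


def execution_caching_default(env=os.environ) -> bool:
    """Hypothetical helper mirroring the class attribute above."""
    raw = str(env.get('KFP_DISABLE_EXECUTION_CACHING_BY_DEFAULT'))
    # An unset variable becomes str(None) == 'None', which is not truthy,
    # so caching stays enabled unless the flag is explicitly turned on.
    return raw.strip().lower() not in _TRUTHY


assert execution_caching_default({}) is True
assert execution_caching_default(
    {'KFP_DISABLE_EXECUTION_CACHING_BY_DEFAULT': 'on'}) is False
assert execution_caching_default(
    {'KFP_DISABLE_EXECUTION_CACHING_BY_DEFAULT': '0'}) is True
```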