Skip to content

Commit

Permalink
Add back propagating of fast register state (#887)
Browse files Browse the repository at this point in the history
Signed-off-by: Yee Hing Tong <wild-endeavor@users.noreply.github.com>
  • Loading branch information
wild-endeavor committed Mar 11, 2022
1 parent e0dcac5 commit d98b482
Showing 1 changed file with 55 additions and 6 deletions.
61 changes: 55 additions & 6 deletions flytekit/bin/entrypoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from flyteidl.core import literals_pb2 as _literals_pb2

from flytekit import PythonFunctionTask
from flytekit.configuration import SerializationSettings, StatsConfig
from flytekit.configuration import FastSerializationSettings, SerializationSettings, StatsConfig
from flytekit.core import SERIALIZED_CONTEXT_ENV_VAR
from flytekit.core import constants as _constants
from flytekit.core import utils
Expand All @@ -23,8 +23,8 @@
from flytekit.exceptions import scopes as _scoped_exceptions
from flytekit.exceptions import scopes as _scopes
from flytekit.interfaces.stats.taggable import get_stats as _get_stats
from flytekit.loggers import entrypoint_logger
from flytekit.loggers import entrypoint_logger as logger
from flytekit.loggers import user_space_logger
from flytekit.models import dynamic_job as _dynamic_job
from flytekit.models import literals as _literal_models
from flytekit.models.core import errors as _error_models
Expand Down Expand Up @@ -175,7 +175,19 @@ def setup_execution(
raw_output_data_prefix: str,
checkpoint_path: Optional[str] = None,
prev_checkpoint: Optional[str] = None,
dynamic_addl_distro: Optional[str] = None,
dynamic_dest_dir: Optional[str] = None,
):
"""
:param raw_output_data_prefix:
:param checkpoint_path:
:param prev_checkpoint:
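:param dynamic_addl_distro: Works in concert with dynamic_dest_dir. If present, fast serialization is enabled in
the serialization settings built for this execution (so that a dynamic task, were one to run, would use it), with
this value passed to FastSerializationSettings as the distribution location.
:param dynamic_dest_dir: Passed to FastSerializationSettings as the destination directory. Only used when
dynamic_addl_distro is present.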
:return:
"""
exe_project = get_one_of("FLYTE_INTERNAL_EXECUTION_PROJECT", "_F_PRJ")
exe_domain = get_one_of("FLYTE_INTERNAL_EXECUTION_DOMAIN", "_F_DM")
exe_name = get_one_of("FLYTE_INTERNAL_EXECUTION_NAME", "_F_NM")
Expand Down Expand Up @@ -222,7 +234,7 @@ def setup_execution(
"api_version": _api_version,
},
),
logging=entrypoint_logger,
logging=user_space_logger,
tmp_dir=user_workspace_dir,
raw_output_prefix=raw_output_data_prefix,
checkpoint=checkpointer,
Expand All @@ -249,6 +261,12 @@ def setup_execution(
ssb.project = exe_project
ssb.domain = exe_domain
ssb.version = tk_version
if dynamic_addl_distro:
ssb.fast_serialization_settings = FastSerializationSettings(
enabled=True,
destination_dir=dynamic_dest_dir,
distribution_location=dynamic_addl_distro,
)
cb = cb.with_serialization_settings(ssb.build())

with FlyteContextManager.with_context(cb) as ctx:
Expand Down Expand Up @@ -277,6 +295,8 @@ def _execute_task(
resolver_args: List[str],
checkpoint_path: Optional[str] = None,
prev_checkpoint: Optional[str] = None,
dynamic_addl_distro: Optional[str] = None,
dynamic_dest_dir: Optional[str] = None,
):
"""
This function should be called for new API tasks (those only available in 0.16 and later that leverage Python
Expand Down Expand Up @@ -304,7 +324,13 @@ def _execute_task(
if len(resolver_args) < 1:
raise Exception("cannot be <1")

with setup_execution(raw_output_data_prefix, checkpoint_path, prev_checkpoint) as ctx:
with setup_execution(
raw_output_data_prefix,
checkpoint_path,
prev_checkpoint,
dynamic_addl_distro,
dynamic_dest_dir,
) as ctx:
resolver_obj = load_object_from_module(resolver)
# Use the resolver to load the actual task object
_task_def = resolver_obj.load_task(loader_args=resolver_args)
Expand All @@ -327,6 +353,8 @@ def _execute_map_task(
resolver_args: List[str],
checkpoint_path: Optional[str] = None,
prev_checkpoint: Optional[str] = None,
dynamic_addl_distro: Optional[str] = None,
dynamic_dest_dir: Optional[str] = None,
):
"""
This function should be called by map task and aws-batch task
Expand All @@ -348,7 +376,9 @@ def _execute_map_task(
if len(resolver_args) < 1:
raise Exception(f"Resolver args cannot be <1, got {resolver_args}")

with setup_execution(raw_output_data_prefix, checkpoint_path, prev_checkpoint) as ctx:
with setup_execution(
raw_output_data_prefix, checkpoint_path, prev_checkpoint, dynamic_addl_distro, dynamic_dest_dir
) as ctx:
resolver_obj = load_object_from_module(resolver)
# Use the resolver to load the actual task object
_task_def = resolver_obj.load_task(loader_args=resolver_args)
Expand Down Expand Up @@ -398,6 +428,8 @@ def _pass_through():
@_click.option("--checkpoint-path", required=False)
@_click.option("--prev-checkpoint", required=False)
@_click.option("--test", is_flag=True)
@_click.option("--dynamic-addl-distro", required=False)
@_click.option("--dynamic-dest-dir", required=False)
@_click.option("--resolver", required=False)
@_click.argument(
"resolver-args",
Expand All @@ -411,6 +443,8 @@ def execute_task_cmd(
test,
prev_checkpoint,
checkpoint_path,
dynamic_addl_distro,
dynamic_dest_dir,
resolver,
resolver_args,
):
Expand All @@ -434,6 +468,8 @@ def execute_task_cmd(
test=test,
resolver=resolver,
resolver_args=resolver_args,
dynamic_addl_distro=dynamic_addl_distro,
dynamic_dest_dir=dynamic_dest_dir,
checkpoint_path=checkpoint_path,
prev_checkpoint=prev_checkpoint,
)
Expand All @@ -453,9 +489,16 @@ def fast_execute_task_cmd(additional_distribution: str, dest_dir: str, task_exec
dest_dir = os.getcwd()
_download_distribution(additional_distribution, dest_dir)

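# Insert the fast-registration options (--dynamic-addl-distro / --dynamic-dest-dir) immediately before
# --resolver, so they come ahead of the unbounded resolver args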
cmd = []
for arg in task_execute_cmd:
if arg == "--resolver":
cmd.extend(["--dynamic-addl-distro", additional_distribution, "--dynamic-dest-dir", dest_dir])
cmd.append(arg)

# Use the commandline to run the task execute command rather than calling it directly in python code
# since the current runtime bytecode references the older user code, rather than the downloaded distribution.
os.system(" ".join(task_execute_cmd))
os.system(" ".join(cmd))


@_pass_through.command("pyflyte-map-execute")
Expand All @@ -464,6 +507,8 @@ def fast_execute_task_cmd(additional_distribution: str, dest_dir: str, task_exec
@_click.option("--raw-output-data-prefix", required=False)
@_click.option("--max-concurrency", type=int, required=False)
@_click.option("--test", is_flag=True)
@_click.option("--dynamic-addl-distro", required=False)
@_click.option("--dynamic-dest-dir", required=False)
@_click.option("--resolver", required=True)
@_click.option("--checkpoint-path", required=False)
@_click.option("--prev-checkpoint", required=False)
Expand All @@ -478,6 +523,8 @@ def map_execute_task_cmd(
raw_output_data_prefix,
max_concurrency,
test,
dynamic_addl_distro,
dynamic_dest_dir,
resolver,
resolver_args,
prev_checkpoint,
Expand All @@ -495,6 +542,8 @@ def map_execute_task_cmd(
raw_output_data_prefix=raw_output_data_prefix,
max_concurrency=max_concurrency,
test=test,
dynamic_addl_distro=dynamic_addl_distro,
dynamic_dest_dir=dynamic_dest_dir,
resolver=resolver,
resolver_args=resolver_args,
checkpoint_path=checkpoint_path,
Expand Down

0 comments on commit d98b482

Please sign in to comment.