diff --git a/MANIFEST.in b/MANIFEST.in index 9a6845ec5..f8d096895 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1,2 +1 @@ include runhouse/builtins/* -include runhouse/resources/hardware/sagemaker/* diff --git a/README.md b/README.md index cc59ffcb9..70fa07b6a 100644 --- a/README.md +++ b/README.md @@ -110,7 +110,6 @@ Please reach out (first name at run.house) if you don't see your favorite comput - Amazon Web Services (AWS) - EC2 - **Supported** - EKS - **Supported** - - SageMaker - **Supported** - Lambda - **Alpha** - Google Cloud Platform (GCP) - GCE - **Supported** diff --git a/docs/api/python/cluster.rst b/docs/api/python/cluster.rst index 5695870d8..e9ae26f06 100644 --- a/docs/api/python/cluster.rst +++ b/docs/api/python/cluster.rst @@ -1,9 +1,9 @@ Cluster ======= A Cluster is a Runhouse primitive used for abstracting a particular hardware configuration. -This can be either an :ref:`on-demand cluster ` (requires valid cloud credentials), a -:ref:`BYO (bring-your-own) cluster ` (requires IP address and ssh creds), or a -:ref:`SageMaker cluster ` (requires an ARN role). +This can be either an :ref:`on-demand cluster ` (requires valid cloud credentials or a +local Kube config if launching on Kubernetes), or a +:ref:`BYO (bring-your-own) cluster ` (requires IP address and ssh creds). A cluster is assigned a name, through which it can be accessed and reused later on. @@ -14,8 +14,6 @@ Cluster Factory Methods .. autofunction:: runhouse.ondemand_cluster -.. autofunction:: runhouse.sagemaker_cluster - Cluster Class ~~~~~~~~~~~~~ @@ -75,141 +73,6 @@ See the `SkyPilot docs ` or an :ref:`on-demand cluster `. - Runhouse will handle launching the SageMaker compute and creating the SSH connection - to the cluster. - -- **Dedicated training jobs**: You can use a SageMakerCluster class to run a training job on SageMaker compute. - To do so, you will need to provide an - `estimator `__. - -.. note:: - - Runhouse requires an AWS IAM role (either name or full ARN) whose credentials have adequate permissions to - create create SageMaker endpoints and access AWS resources. - - Please see :ref:`SageMaker Hardware Setup` for more specific instructions and - requirements for providing the role and setting up the cluster. - -.. autoclass:: runhouse.SageMakerCluster - :members: - :exclude-members: - - .. automethod:: __init__ - -SageMaker Hardware Setup ------------------------- - -IAM Role -^^^^^^^^ - -SageMaker clusters require `AWS CLI V2 `__ and -configuring the SageMaker IAM role with the -`AWS Systems Manager `__. - - -In order to launch a cluster, you must grant SageMaker the necessary permissions with an IAM role, which -can be provided either by name or by full ARN. You can also specify a profile explicitly or -with the :code:`AWS_PROFILE` environment variable. - -For example, let's say your local :code:`~/.aws/config` file contains: - -.. code-block:: ini - - [profile sagemaker] - role_arn = arn:aws:iam::123456789:role/service-role/AmazonSageMaker-ExecutionRole-20230717T192142 - region = us-east-1 - source_profile = default - -There are several ways to provide the necessary credentials when :ref:`initializing the cluster `: - -- Providing the AWS profile name: :code:`profile="sagemaker"` -- Providing the AWS Role ARN directly: :code:`role="arn:aws:iam::123456789:role/service-role/AmazonSageMaker-ExecutionRole-20230717T192142"` -- Environment Variable: setting :code:`AWS_PROFILE` to :code:`"sagemaker"` - -.. 
note:: - - If no role or profile is provided, Runhouse will try using the :code:`default` profile. Note if this default AWS - identity is not a role, then you will need to provide the :code:`role` or :code:`profile` explicitly. - -.. tip:: - - If you are providing an estimator, you must provide the role ARN explicitly as part of the estimator object. - More info on estimators `here `__. - -Please see the `AWS docs `__ for further -instructions on creating and configuring an ARN Role. - - -AWS CLI V2 -^^^^^^^^^^ - -The SageMaker SDK uses AWS CLI V2, which must be installed on your local machine. Doing so requires one of two steps: - -- `Migrate from V1 to V2 `_ - -- `Install V2 `_ - - -To confirm the installation succeeded, run ``aws --version`` in the command line. You should see something like: - -.. code-block:: - - $ aws-cli/2.13.8 Python/3.11.4 Darwin/21.3.0 source/arm64 prompt/off - -If you are still seeing the V1 version, first try uninstalling V1 in case it is still present -(e.g. ``pip uninstall awscli``). - -You may also need to add the V2 executable to the PATH of your python environment. For example, if you are using conda, -it’s possible the conda env will try using its own version of the AWS CLI located at a different -path (e.g. ``/opt/homebrew/anaconda3/bin/aws``), while the system wide installation of AWS CLI is located somewhere -else (e.g. ``/opt/homebrew/bin/aws``). - -To find the global AWS CLI path: - -.. code-block:: - - $ which aws - -To ensure that the global AWS CLI version is used within your python environment, you’ll need to adjust the -PATH environment variable so that it prioritizes the global AWS CLI path. - -.. code-block:: - - $ export PATH=/opt/homebrew/bin:$PATH - - -SSM Setup -^^^^^^^^^ -The AWS Systems Manager service is used to create SSH tunnels with the SageMaker cluster. - -To install the AWS Session Manager Plugin, please see the `AWS docs `_ -or `SageMaker SSH Helper `__. The SSH Helper package -simplifies the process of creating SSH tunnels with SageMaker clusters. It is installed by default if -you are installing Runhouse with the SageMaker dependency: :code:`pip install runhouse[sagemaker]`. - -You can also install the Session Manager by running the CLI command: - -.. code-block:: - - $ sm-local-configure - -To configure your SageMaker IAM role with the AWS Systems Manager, please -refer to `these instructions `__. - - Cluster Authentication & Verification ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Runhouse provides a couple of options to manage the connection to the Runhouse API server running on a cluster. @@ -228,10 +91,6 @@ be started on the cluster on port :code:`32300`. - ``none``: Does not use any port forwarding or enforce any authentication. Connects to the cluster with HTTP by default on port :code:`80`. This is useful when connecting to a cluster within a VPC, or creating a tunnel manually on the side with custom settings. -- ``aws_ssm``: Uses the - `AWS Systems Manager `__ to - create an SSH tunnel to the cluster, by default on port :code:`32300`. *Note: this is currently only relevant - for SageMaker Clusters.* .. note:: diff --git a/docs/docker-setup.rst b/docs/docker-setup.rst index f930736e5..b20b4acf5 100644 --- a/docs/docker-setup.rst +++ b/docs/docker-setup.rst @@ -17,8 +17,7 @@ is automatically built and set up remotely on the cluster. The Runhouse server will start directly inside the remote container. **NOTE:** This guide details the setup and usage for on-demand clusters -only. 
Docker container is also supported for Sagemaker clusters, and it -is not yet supported for static clusters. +only. It is not yet supported for static clusters. Cluster & Docker Setup ---------------------- diff --git a/docs/requirements.txt b/docs/requirements.txt index 7617a198a..a112c1cb4 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -4,7 +4,6 @@ pint==0.20.1 pyarrow==9.0.0 pydata-sphinx-theme==0.13.3 ray>=2.2.0 -sagemaker sentry-sdk==1.28.1 sphinx-book-theme==1.0.1 sphinx-click==4.3.0 diff --git a/docs/tutorials/api-clusters.rst b/docs/tutorials/api-clusters.rst index 04c11b9cf..905fcbd8a 100644 --- a/docs/tutorials/api-clusters.rst +++ b/docs/tutorials/api-clusters.rst @@ -95,7 +95,7 @@ remotely on your AWS instance. On-Demand Clusters within Existing Cloud VPC -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ If you would like to launch on-demand clusters using existing VPCs, you can easily set it up by configuring SkyPilot. Without setting VPC, we launch in the default VPC in the region of the cluster. If you do diff --git a/docs/tutorials/api-resources.rst b/docs/tutorials/api-resources.rst index dce5440cf..2585cead5 100644 --- a/docs/tutorials/api-resources.rst +++ b/docs/tutorials/api-resources.rst @@ -276,16 +276,16 @@ to notify them. INFO | 2024-08-18 06:51:39.797150 | Saving config for aws-cpu-ssh-secret to Den INFO | 2024-08-18 06:51:39.972763 | Saving secrets for aws-cpu-ssh-secret to Vault - INFO | 2024-08-18 06:51:40.190996 | Saving config to RNS: {'name': '/jlewitt1/aws-cpu_default_env', 'resource_type': 'env', 'resource_subtype': 'Env', 'provenance': None, 'visibility': 'private', 'env_vars': {}, 'env_name': 'aws-cpu_default_env', 'compute': {}, 'reqs': ['ray==2.30.0'], 'working_dir': None} - INFO | 2024-08-18 06:51:40.368442 | Saving config to RNS: {'name': '/jlewitt1/aws-cpu', 'resource_type': 'cluster', 'resource_subtype': 'OnDemandCluster', 'provenance': None, 'visibility': 'private', 'ips': ['3.14.144.103'], 'server_port': 32300, 'server_connection_type': 'ssh', 'den_auth': False, 'ssh_port': 22, 'client_port': 32300, 'creds': '/jlewitt1/aws-cpu-ssh-secret', 'api_server_url': 'https://api.run.house', 'default_env': '/jlewitt1/aws-cpu_default_env', 'instance_type': 'CPU:2+', 'provider': 'aws', 'open_ports': [], 'use_spot': False, 'image_id': 'docker:nvcr.io/nvidia/pytorch:23.10-py3', 'region': 'us-east-2', 'stable_internal_external_ips': [('172.31.5.134', '3.14.144.103')], 'sky_kwargs': {'launch': {'retry_until_up': True}}, 'launched_properties': {'cloud': 'aws', 'instance_type': 'm6i.large', 'region': 'us-east-2', 'cost_per_hour': 0.096, 'docker_user': 'root'}, 'autostop_mins': -1} + INFO | 2024-08-18 06:51:40.190996 | Saving config to RNS: {'name': '/jlewitt1/aws-cpu_default_env', 'resource_type': 'env', 'resource_subtype': 'Env', 'visibility': 'private', 'env_vars': {}, 'env_name': 'aws-cpu_default_env', 'compute': {}, 'reqs': ['ray==2.30.0'], 'working_dir': None} + INFO | 2024-08-18 06:51:40.368442 | Saving config to RNS: {'name': '/jlewitt1/aws-cpu', 'resource_type': 'cluster', 'resource_subtype': 'OnDemandCluster', 'visibility': 'private', 'ips': ['3.14.144.103'], 'server_port': 32300, 'server_connection_type': 'ssh', 'den_auth': False, 'ssh_port': 22, 'client_port': 32300, 'creds': '/jlewitt1/aws-cpu-ssh-secret', 'api_server_url': 'https://api.run.house', 'default_env': '/jlewitt1/aws-cpu_default_env', 'instance_type': 'CPU:2+', 'provider': 'aws', 'open_ports': [], 'use_spot': False, 'image_id': 
'docker:nvcr.io/nvidia/pytorch:23.10-py3', 'region': 'us-east-2', 'stable_internal_external_ips': [('172.31.5.134', '3.14.144.103')], 'sky_kwargs': {'launch': {'retry_until_up': True}}, 'launched_properties': {'cloud': 'aws', 'instance_type': 'm6i.large', 'region': 'us-east-2', 'cost_per_hour': 0.096, 'docker_user': 'root'}, 'autostop_mins': -1} INFO | 2024-08-18 06:51:40.548233 | Sharing cluster credentials, which enables the recipient to SSH into the cluster. INFO | 2024-08-18 06:51:40.551277 | Saving config for aws-cpu-ssh-secret to Den INFO | 2024-08-18 06:51:40.728345 | Saving secrets for aws-cpu-ssh-secret to Vault - INFO | 2024-08-18 06:51:41.150745 | Saving config to RNS: {'name': '/jlewitt1/aws-cpu_default_env', 'resource_type': 'env', 'resource_subtype': 'Env', 'provenance': None, 'visibility': 'private', 'env_vars': {}, 'env_name': 'aws-cpu_default_env', 'compute': {}, 'reqs': ['ray==2.30.0'], 'working_dir': None} + INFO | 2024-08-18 06:51:41.150745 | Saving config to RNS: {'name': '/jlewitt1/aws-cpu_default_env', 'resource_type': 'env', 'resource_subtype': 'Env', 'visibility': 'private', 'env_vars': {}, 'env_name': 'aws-cpu_default_env', 'compute': {}, 'reqs': ['ray==2.30.0'], 'working_dir': None} INFO | 2024-08-18 06:51:42.006030 | Saving config for aws-cpu-ssh-secret to Den INFO | 2024-08-18 06:51:42.504070 | Saving secrets for aws-cpu-ssh-secret to Vault - INFO | 2024-08-18 06:51:42.728653 | Saving config to RNS: {'name': '/jlewitt1/aws-cpu_default_env', 'resource_type': 'env', 'resource_subtype': 'Env', 'provenance': None, 'visibility': 'private', 'env_vars': {}, 'env_name': 'aws-cpu_default_env', 'compute': {}, 'reqs': ['ray==2.30.0'], 'working_dir': None} - INFO | 2024-08-18 06:51:42.906615 | Saving config to RNS: {'name': '/jlewitt1/aws-cpu', 'resource_type': 'cluster', 'resource_subtype': 'OnDemandCluster', 'provenance': None, 'visibility': 'private', 'ips': ['3.14.144.103'], 'server_port': 32300, 'server_connection_type': 'ssh', 'den_auth': False, 'ssh_port': 22, 'client_port': 32300, 'creds': '/jlewitt1/aws-cpu-ssh-secret', 'api_server_url': 'https://api.run.house', 'default_env': '/jlewitt1/aws-cpu_default_env', 'instance_type': 'CPU:2+', 'provider': 'aws', 'open_ports': [], 'use_spot': False, 'image_id': 'docker:nvcr.io/nvidia/pytorch:23.10-py3', 'region': 'us-east-2', 'stable_internal_external_ips': [('172.31.5.134', '3.14.144.103')], 'sky_kwargs': {'launch': {'retry_until_up': True}}, 'launched_properties': {'cloud': 'aws', 'instance_type': 'm6i.large', 'region': 'us-east-2', 'cost_per_hour': 0.096, 'docker_user': 'root'}, 'autostop_mins': -1} + INFO | 2024-08-18 06:51:42.728653 | Saving config to RNS: {'name': '/jlewitt1/aws-cpu_default_env', 'resource_type': 'env', 'resource_subtype': 'Env', 'visibility': 'private', 'env_vars': {}, 'env_name': 'aws-cpu_default_env', 'compute': {}, 'reqs': ['ray==2.30.0'], 'working_dir': None} + INFO | 2024-08-18 06:51:42.906615 | Saving config to RNS: {'name': '/jlewitt1/aws-cpu', 'resource_type': 'cluster', 'resource_subtype': 'OnDemandCluster', 'visibility': 'private', 'ips': ['3.14.144.103'], 'server_port': 32300, 'server_connection_type': 'ssh', 'den_auth': False, 'ssh_port': 22, 'client_port': 32300, 'creds': '/jlewitt1/aws-cpu-ssh-secret', 'api_server_url': 'https://api.run.house', 'default_env': '/jlewitt1/aws-cpu_default_env', 'instance_type': 'CPU:2+', 'provider': 'aws', 'open_ports': [], 'use_spot': False, 'image_id': 'docker:nvcr.io/nvidia/pytorch:23.10-py3', 'region': 'us-east-2', 'stable_internal_external_ips': 
[('172.31.5.134', '3.14.144.103')], 'sky_kwargs': {'launch': {'retry_until_up': True}}, 'launched_properties': {'cloud': 'aws', 'instance_type': 'm6i.large', 'region': 'us-east-2', 'cost_per_hour': 0.096, 'docker_user': 'root'}, 'autostop_mins': -1} diff --git a/requirements.txt b/requirements.txt index ec33bcfd9..c7790fa3f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -15,3 +15,4 @@ wheel apispec httpx pydantic >=2.5.0 +pynvml diff --git a/runhouse/__init__.py b/runhouse/__init__.py index 1e87674af..8ae83f2f8 100644 --- a/runhouse/__init__.py +++ b/runhouse/__init__.py @@ -1,5 +1,4 @@ from runhouse.resources.asgi import Asgi, asgi -from runhouse.resources.blobs import blob, Blob, file, File from runhouse.resources.envs import conda_env, CondaEnv, env, Env from runhouse.resources.folders import Folder, folder, GCSFolder, S3Folder from runhouse.resources.functions.aws_lambda import LambdaFunction @@ -12,8 +11,6 @@ kubernetes_cluster, ondemand_cluster, OnDemandCluster, - sagemaker_cluster, - SageMakerCluster, ) # WARNING: Any built-in module that is imported here must be capitalized followed by all lowercase, or we will @@ -26,7 +23,6 @@ package, Package, ) -from runhouse.resources.provenance import capture_stdout, Run, run, RunStatus, RunType from runhouse.resources.resource import Resource from runhouse.resources.secrets import provider_secret, ProviderSecret, Secret, secret @@ -63,4 +59,4 @@ def __getattr__(name): raise AttributeError(f"module {__name__!r} has no attribute {name!r}") -__version__ = "0.0.33" +__version__ = "0.0.34" diff --git a/runhouse/constants.py b/runhouse/constants.py index 8be842b63..ba0bd41c1 100644 --- a/runhouse/constants.py +++ b/runhouse/constants.py @@ -10,6 +10,7 @@ LOCALHOST: str = "127.0.0.1" LOCAL_HOSTS: List[str] = ["localhost", LOCALHOST] TUNNEL_TIMEOUT = 5 +NUM_PORTS_TO_TRY = 10 LOGS_DIR = ".rh/logs" RH_LOGFILE_PATH = Path.home() / LOGS_DIR @@ -73,11 +74,23 @@ # Constants for the status check DOUBLE_SPACE_UNICODE = "\u00A0\u00A0" BULLET_UNICODE = "\u2022" +SECOND = 1 MINUTE = 60 HOUR = 3600 DEFAULT_STATUS_CHECK_INTERVAL = 1 * MINUTE INCREASED_STATUS_CHECK_INTERVAL = 1 * HOUR -STATUS_CHECK_DELAY = 1 * MINUTE +GPU_COLLECTION_INTERVAL = 5 * SECOND + +# We collect gpu every GPU_COLLECTION_INTERVAL. +# Meaning that in one minute we collect (MINUTE / GPU_COLLECTION_INTERVAL) gpu stats. +# Currently, we save gpu info of the last 10 minutes or less. +MAX_GPU_INFO_LEN = (MINUTE / GPU_COLLECTION_INTERVAL) * 10 + +# If we just collect the gpu stats (and not send them to den), the gpu_info dictionary *will not* be reseted by the servlets. +# Therefore, we need to cut the gpu_info size, so it doesn't consume too much cluster memory. +# Currently, we reduce the size by half, meaning we only keep the gpu_info of the last (MAX_GPU_INFO_LEN / 2) minutes. 
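(Aside, not part of the patch: the constant added on the next line caps this buffer. Below is a minimal sketch of the trimming behavior those comments describe, assuming a hypothetical ``record_sample`` helper and a module-level ``gpu_info`` list standing in for the servlet's internal state.)

.. code-block:: python

    SECOND = 1
    MINUTE = 60

    GPU_COLLECTION_INTERVAL = 5 * SECOND                        # one sample every 5 seconds
    MAX_GPU_INFO_LEN = (MINUTE / GPU_COLLECTION_INTERVAL) * 10  # ~10 minutes of samples
    REDUCED_GPU_INFO_LEN = MAX_GPU_INFO_LEN / 2                 # ~5 minutes kept when not sending to Den

    gpu_info: list = []

    def record_sample(sample: dict, sending_to_den: bool) -> None:
        """Append one GPU utilization sample and trim the buffer as described above."""
        gpu_info.append(sample)
        if not sending_to_den and len(gpu_info) >= MAX_GPU_INFO_LEN:
            # Nothing resets the buffer in this mode, so keep only the newest ~5 minutes of samples.
            del gpu_info[: -int(REDUCED_GPU_INFO_LEN)]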
+REDUCED_GPU_INFO_LEN = MAX_GPU_INFO_LEN / 2 + # Constants Surfacing Logs to Den DEFAULT_LOG_SURFACING_INTERVAL = 2 * MINUTE diff --git a/runhouse/main.py b/runhouse/main.py index 0566d25d6..f9a65e603 100644 --- a/runhouse/main.py +++ b/runhouse/main.py @@ -366,14 +366,13 @@ def _print_envs_info( total_gpu_memory = math.ceil( float(env_gpu_info.get("total_memory")) / (1024**3) ) - gpu_util_percent = round(float(env_gpu_info.get("utilization_percent")), 2) used_gpu_memory = round( float(env_gpu_info.get("used_memory")) / (1024**3), 2 ) gpu_memory_usage_percent = round( float(used_gpu_memory / total_gpu_memory) * 100, 2 ) - gpu_usage_summery = f"{DOUBLE_SPACE_UNICODE}GPU: {gpu_util_percent}% | Memory: {used_gpu_memory} / {total_gpu_memory} Gb ({gpu_memory_usage_percent}%)" + gpu_usage_summery = f"{DOUBLE_SPACE_UNICODE}GPU Memory: {used_gpu_memory} / {total_gpu_memory} Gb ({gpu_memory_usage_percent}%)" console.print(gpu_usage_summery) resources_in_env = [ @@ -408,6 +407,8 @@ def _print_status(status_data: dict, current_cluster: Cluster) -> None: if "name" in cluster_config.keys(): console.print(cluster_config.get("name")) + has_cuda: bool = cluster_config.get("has_cuda") + # print headline daemon_headline_txt = ( "\N{smiling face with horns} Runhouse Daemon is running \N{Runner}" @@ -420,6 +421,22 @@ def _print_status(status_data: dict, current_cluster: Cluster) -> None: # Print relevant info from cluster config. _print_cluster_config(cluster_config) + # print general cpu and gpu utilization + cluster_gpu_utilization: float = status_data.get("server_gpu_utilization") + + # cluster_gpu_utilization can be none, if the cluster was not using its GPU at the moment cluster.status() was invoked. + if cluster_gpu_utilization is None and has_cuda: + cluster_gpu_utilization: float = 0.0 + + cluster_cpu_utilization: float = status_data.get("server_cpu_utilization") + + server_util_info = ( + f"CPU Utilization: {round(cluster_cpu_utilization, 2)}% | GPU Utilization: {round(cluster_gpu_utilization,2)}%" + if has_cuda + else f"CPU Utilization: {round(cluster_cpu_utilization, 2)}%" + ) + console.print(server_util_info) + # print the environments in the cluster, and the resources associated with each environment. _print_envs_info(env_servlet_processes, current_cluster) diff --git a/runhouse/resources/blobs/__init__.py b/runhouse/resources/blobs/__init__.py deleted file mode 100644 index b4fbc13eb..000000000 --- a/runhouse/resources/blobs/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from .blob import Blob, blob -from .file import File, file diff --git a/runhouse/resources/blobs/blob.py b/runhouse/resources/blobs/blob.py deleted file mode 100644 index 3fff84eda..000000000 --- a/runhouse/resources/blobs/blob.py +++ /dev/null @@ -1,210 +0,0 @@ -from pathlib import Path -from typing import Any, Optional, Union - -from runhouse.resources.envs.env import Env -from runhouse.resources.envs.utils import _get_env_from -from runhouse.resources.hardware import _current_cluster, _get_cluster_from, Cluster - -from runhouse.resources.module import Module -from runhouse.rns.utils.names import _generate_default_path -from runhouse.utils import generate_default_name - - -class Blob(Module): - RESOURCE_TYPE = "blob" - DEFAULT_FOLDER_PATH = "/runhouse-blob" - DEFAULT_CACHE_FOLDER = ".cache/runhouse/blobs" - - def __init__( - self, - name: Optional[str] = None, - system: Union[Cluster, str] = None, - env: Optional[Env] = None, - dryrun: bool = False, - **kwargs, - ): - """ - Runhouse Blob object - - .. 
note:: - To build a Blob, please use the factory method :func:`blob`. - """ - self.data = None - super().__init__(name=name, system=system, env=env, dryrun=dryrun, **kwargs) - - def to( - self, - system: Union[str, Cluster], - env: Optional[Union[str, Env]] = None, - path: Optional[str] = None, - ): - """Return a copy of the blob on the destination system, and optionally path. - - Example: - >>> local_blob = rh.blob(data) - >>> s3_blob = blob.to("s3") - >>> cluster_blob = blob.to(my_cluster) - """ - if system == "here": - if not path: - current_cluster_config = _current_cluster(key="config") - if current_cluster_config: - system = Cluster.from_config(current_cluster_config) - else: - system = None - else: - system = "file" - - system = _get_cluster_from(system) - if (not system or isinstance(system, Cluster)) and not path: - self.name = self.name or generate_default_name(prefix="blob") - # TODO [DG] if system is the same, bounces off the laptop for no reason. Change to write through a - # call_module_method rpc (and same for similar file cases) - return super().to(system, env) - - from runhouse import Folder - - path = str( - path or Folder.default_path(self.rns_address, system) - ) # Make sure it's a string and not a Path - - from runhouse.resources.blobs.file import file - - new_blob = file(path=path, system=system) - new_blob.write(self.fetch()) - return new_blob - - # TODO delete - def write(self, data): - """Save the underlying blob to its cluster's store. - - Example: - >>> rh.blob(data).write() - """ - self.data = data - - def rm(self): - """Delete the blob from wherever it's stored. - - Example: - >>> blob = rh.blob(data) - >>> blob.rm() - """ - self.data = None - - def exists_in_system(self): - """Check whether the blob exists in the file system - - Example: - >>> blob = rh.blob(data) - >>> blob.exists_in_system() - """ - if self.data is not None: - return True - - def resolved_state(self, _state_dict=None): - """Return the resolved state of the blob, which is the data. - - Primarily used to define the behavior of the ``fetch`` method. - - Example: - >>> blob = rh.blob(data) - >>> blob.resolved_state() - """ - return self.data - - -def blob( - data: [Any] = None, - name: Optional[str] = None, - path: Optional[Union[str, Path]] = None, - system: Optional[str] = None, - env: Optional[Union[str, Env]] = None, - load_from_den: bool = True, - dryrun: bool = False, -): - """Returns a Blob object, which can be used to interact with the resource at the given path - - Args: - data: Blob data. The data to persist either on the cluster or in the filesystem. - name (Optional[str]): Name to give the blob object, to be reused later on. - path (Optional[str or Path]): Path (or path) to the blob object. Specfying a path will force the blob to be - saved to the filesystem rather than persist in the cluster's object store. - system (Optional[str or Cluster]): File system or cluster name. If providing a file system this must be one of: - [``file``, ``s3``, ``gs``]. - We are working to add additional file system support. If providing a cluster, this must be a cluster object - or name, and whether the data is saved to the object store or filesystem depends on whether a path is - specified. - env (Optional[Env or str]): Environment for the blob. If left empty, defaults to base environment. - (Default: ``None``) - load_from_den (bool): Whether to try to load the Blob resource from Den. (Default: ``True``) - dryrun (bool): Whether to create the Blob if it doesn't exist, or load a Blob object as a dryrun. 
- (Default: ``False``) - - Returns: - Blob: The resulting blob. - - Example: - >>> import runhouse as rh - >>> import json - >>> - >>> data = list(range(50) - >>> serialized_data = json.dumps(data) - >>> - >>> # Local blob with name and no path (saved to Runhouse object store) - >>> rh.blob(name="@/my-blob", data=data) - >>> - >>> # Remote blob with name and no path (saved to cluster's Runhouse object store) - >>> rh.blob(name="@/my-blob", data=data, system=my_cluster) - >>> - >>> # Remote blob with name, filesystem, and no path (saved to filesystem with default path) - >>> rh.blob(name="@/my-blob", data=serialized_data, system="s3") - >>> - >>> # Remote blob with name and path (saved to remote filesystem) - >>> rh.blob(name='@/my-blob', data=serialized_data, path='/runhouse-tests/my_blob.pickle', system='s3') - >>> - >>> # Local blob with path and no system (saved to local filesystem) - >>> rh.blob(data=serialized_data, path=str(Path.cwd() / "my_blob.pickle")) - - >>> # Loading a blob - >>> my_local_blob = rh.blob(name="~/my_blob") - >>> my_s3_blob = rh.blob(name="@/my_blob") - """ - if name and not any([data is not None, path, system]): - # Try reloading existing blob - try: - return Blob.from_name(name, load_from_den=load_from_den, dryrun=dryrun) - except ValueError: - # This is a rare instance where passing no constructor params is actually valid - # (e.g. rh.blob(name=key).write(data)), so if we don't find the name, we still want to - # create a new blob. - pass - - system = _get_cluster_from(system or _current_cluster(key="config"), dryrun=dryrun) - env = env or _get_env_from(env) - - if (not system or isinstance(system, Cluster)) and not path: - # Blobs must be named, or we don't have a key for the kv store - name = name or generate_default_name(prefix="blob") - new_blob = Blob(name=name, dryrun=dryrun).to(system, env) - if data is not None: - new_blob.data = data - return new_blob - - path = str(path or _generate_default_path(Blob, name, system)) - - from runhouse.resources.blobs.file import File - - name = name or generate_default_name(prefix="file") - new_blob = File( - name=name, - path=path, - system=system, - env=env, - dryrun=dryrun, - ) - if isinstance(system, Cluster): - system.put_resource(new_blob) - if data is not None: - new_blob.write(data) - return new_blob diff --git a/runhouse/resources/blobs/file.py b/runhouse/resources/blobs/file.py deleted file mode 100644 index 5605f18f7..000000000 --- a/runhouse/resources/blobs/file.py +++ /dev/null @@ -1,218 +0,0 @@ -import pickle -from pathlib import Path -from typing import Optional, Union - -from runhouse.resources.blobs.blob import Blob, blob -from runhouse.resources.envs import _get_env_from, Env -from runhouse.resources.folders import Folder, folder -from runhouse.resources.hardware import _current_cluster, _get_cluster_from, Cluster -from runhouse.utils import generate_default_name - - -class File(Blob): - def __init__( - self, - path: Optional[str] = None, - name: Optional[str] = None, - system: Optional[str] = Folder.DEFAULT_FS, - env: Optional[Env] = None, - dryrun: bool = False, - **kwargs, - ): - """ - Runhouse File object - - .. note:: - To build a File, please use the factory method :func:`file`. 
- """ - self._filename = str(Path(path).name) if path else name - # Use factory method so correct subclass for system is returned - self._folder = folder( - path=str(Path(path).parent) if path is not None else path, - system=system, - dryrun=dryrun, - ) - super().__init__(name=name, dryrun=dryrun, system=system, env=env, **kwargs) - - def config(self, condensed=True): - config = super().config(condensed) - file_config = { - "path": self.path, # pair with data source to create the physical URL - } - config.update(file_config) - return config - - @staticmethod - def from_config(config: dict, dryrun=False, _resolve_children=True): - return Blob(**config, dryrun=dryrun) - - @property - def system(self): - return self._folder.system - - @system.setter - def system(self, new_system): - self._folder.system = new_system - - @property - def path(self): - return self._folder.path + "/" + self._filename - - @path.setter - def path(self, new_path): - self._folder.path = str(Path(new_path).parent) - self._filename = str(Path(new_path).name) - - @property - def fsspec_url(self): - return self._folder.fsspec_url + "/" + self._filename - - def open(self, mode: str = "rb"): - """Get a file-like (OpenFile container object) of the file data. - User must close the file, or use this method inside of a with statement. - - Example: - >>> with my_file.open(mode="wb") as f: - >>> f.write(data) - >>> - >>> obj = my_file.open() - """ - return self._folder.open(self._filename, mode=mode) - - def to( - self, system, env: Optional[Union[str, Env]] = None, path: Optional[str] = None - ): - """Return a copy of the file on the destination system and path. - - Example: - >>> local_file = rh.file(data) - >>> s3_file = file.to("s3") - >>> cluster_file = file.to(my_cluster) - """ - if system == "here": - if not path: - current_cluster_config = _current_cluster(key="config") - if current_cluster_config: - system = Cluster.from_config(current_cluster_config) - else: - system = None - else: - system = "file" - - system = _get_cluster_from(system) - env = _get_env_from(env or self.env) - - if (not system or isinstance(system, Cluster)) and not path: - name = self.name or generate_default_name(prefix="blob") - data_backup = self.fetch() - new_blob = Blob(name=name).to(system, env) - new_blob.data = data_backup - return new_blob - - new_file = file(path=path, system=system) - try: - new_file.write( - self.fetch(mode="r", deserialize=False), serialize=False, mode="w" - ) - except UnicodeDecodeError: - new_file.write(self.fetch()) - - return new_file - - def resolved_state(self, deserialize: bool = True, mode: str = "rb"): - """Return the data for the user to deserialize. Primarily used to define the behavior of the ``fetch`` method. - - Example: - >>> data = file.fetch() - """ - data = self._folder.get(self._filename, mode=mode) - if deserialize: - return pickle.loads(data) - return data - - def _save_sub_resources(self, folder: str = None): - if isinstance(self.system, Cluster): - self.system.save(folder=folder) - - def write(self, data, serialize: bool = True, mode: str = "wb"): - """Save the underlying file to its specified fsspec URL. - - Example: - >>> rh.file(system="s3", path="path/to/save").write(data) - """ - self._folder.mkdir() - if serialize: - data = pickle.dumps(data) - with self.open(mode=mode) as f: - f.write(data) - return self - - def rm(self): - """Delete the file and the folder it lives in from the file system. 
- - Example: - >>> file = rh.file(data, path="saved/path") - >>> file.rm() - """ - self._folder.rm(contents=[self._filename], recursive=False) - - def exists_in_system(self): - """Check whether the file exists in the file system - - Example: - >>> file = rh.file(data, path="saved/path") - >>> file.exists_in_system() - """ - return self._folder.exists_in_system() - - -def file( - data=None, - name: Optional[str] = None, - path: Optional[str] = None, - system: Optional[str] = None, - dryrun: bool = False, -): - """Returns a File object, which can be used to interact with the resource at the given path - - Args: - data: File data. This should be a serializable object. - name (Optional[str]): Name to give the file object, to be reused later on. - path (Optional[str]): Path (or path) of the file object. - system (Optional[str or Cluster]): File system or cluster name. If providing a file system this must be one of: - [``file``, ``s3``, ``gs``]. - We are working to add additional file system support. - dryrun (bool): Whether to create the File if it doesn't exist, or load a File object as a dryrun. - (Default: ``False``) - - Returns: - File: The resulting file. - - Example: - >>> import runhouse as rh - >>> import json - >>> data = json.dumps(list(range(50)) - >>> - >>> # Remote file with name and no path (saved to bucket called runhouse/blobs/my-file) - >>> rh.file(name="@/my-file", data=data, system='s3').write() - >>> - >>> # Remote file with name and path - >>> rh.file(name='@/my-file', path='/runhouse-tests/my_file.pickle', system='s3').save() - >>> - >>> # Local file with name and path, save to local filesystem - >>> rh.file(data=data, path=str(Path.cwd() / "my_file.pickle")).write() - >>> - >>> # Local file with name and no path (saved to ~/.cache/blobs/my-file) - >>> rh.file(name="~/my-file", data=data).write().save() - - >>> # Loading a file - >>> my_local_file = rh.file(name="~/my_file") - >>> my_s3_file = rh.file(name="@/my_file") - """ - return blob( - name=name, - data=data, - path=path, - system=system, - dryrun=dryrun, - ) diff --git a/runhouse/resources/envs/env.py b/runhouse/resources/envs/env.py index c995ab5fb..41e011f40 100644 --- a/runhouse/resources/envs/env.py +++ b/runhouse/resources/envs/env.py @@ -51,7 +51,6 @@ def env_name(self): @staticmethod def from_config(config: dict, dryrun: bool = False, _resolve_children: bool = True): - """Create an Env object from a config dict""" config["reqs"] = [ Package.from_config(req, dryrun=True, _resolve_children=_resolve_children) if isinstance(req, dict) @@ -79,7 +78,9 @@ def _set_env_vars(env_vars): for k, v in env_vars.items(): os.environ[k] = v - def add_env_var(self, key, value): + def add_env_var(self, key: str, value: str): + """Add an env var to the environment. Environment must be re-installed to propagate new + environment variables if it already lives on a cluster.""" self.env_vars.update({key: value}) def config(self, condensed=True): @@ -168,7 +169,14 @@ def _run_setup_cmds(self, cluster: Cluster = None, setup_cmds: List = None): ) def install(self, force: bool = False, cluster: Cluster = None): - """Locally install packages and run setup commands.""" + """Locally install packages and run setup commands. + + Args: + force (bool, optional): Whether to setup the installation again if the env already exists + on the cluster. (Default: ``False``) + cluster (Clsuter, optional): Cluster to install the env on. If not provided, env is installed + on the current cluster. 
(Default: ``None``) + """ # Hash the config_for_rns to check if we need to install env_config = self.config() # Remove the name because auto-generated names will be different, but the installed components are the same @@ -197,13 +205,21 @@ def _run_command(self, command: str, **kwargs): def to( self, system: Union[str, Cluster], - node_idx=None, - path=None, - force_install=False, + node_idx: int = None, + path: str = None, + force_install: bool = False, ): """ - Send environment to the system (Cluster or file system). - This includes installing packages and running setup commands if system is a cluster. + Send environment to the system, and set it up if on a cluster. + + Args: + system (str or Cluster): Cluster or file system to send the env to. + node_idx (int, optional): Node index of the cluster to send the env to. If not specified, + uses the head node. (Default: ``None``) + path (str, optional): Path on the cluster to sync the env's working dir to. Uses a default + path if not specified. (Default: ``None``) + force_install (bool, optional): Whether to setup the installation again if the env already + exists on the cluster. (Default: ``False``) Example: >>> env = rh.env(reqs=["numpy", "pip"]) diff --git a/runhouse/resources/folders/folder.py b/runhouse/resources/folders/folder.py index 3af9b455c..9c87b2de0 100644 --- a/runhouse/resources/folders/folder.py +++ b/runhouse/resources/folders/folder.py @@ -4,7 +4,7 @@ import shutil import subprocess from pathlib import Path -from typing import List, Optional, Union +from typing import Any, Dict, List, Optional, Tuple, Union from runhouse.globals import rns_client from runhouse.logger import get_logger @@ -80,8 +80,7 @@ def default_path(cls, rns_address, system): # ---------------------------------- @staticmethod - def from_config(config: dict, dryrun=False, _resolve_children=True): - """Load config values into the object.""" + def from_config(config: Dict, dryrun: bool = False, _resolve_children: bool = True): if _resolve_children: config = Folder._check_for_child_configs(config) @@ -100,7 +99,7 @@ def from_config(config: dict, dryrun=False, _resolve_children=True): return Folder(**config, dryrun=dryrun) @classmethod - def _check_for_child_configs(cls, config: dict): + def _check_for_child_configs(cls, config: Dict): """Overload by child resources to load any resources they hold internally.""" system = config.get("system") if isinstance(system, str) or isinstance(system, dict): @@ -149,9 +148,20 @@ def _use_http_endpoint(self): """Whether to use system APIs to perform folder operations on the cluster via HTTP.""" return isinstance(self.system, Cluster) and not self.system.on_this_cluster() - def mv(self, system, path: Optional[str] = None, overwrite: bool = True) -> None: + def mv( + self, + system: Union[str, Cluster], + path: Optional[str], + overwrite: bool = True, + ) -> None: """Move the folder to a new filesystem or cluster. + Args: + system (str or Cluster): Filesystem or cluster to move the folder to. + path (str): Path to move the folder to. + overwrite (bool, optional): Whether to override if a file already exists at the destination path. + (Default: ``True``) + Example: >>> folder = rh.folder(path="local/path") >>> folder.mv(my_cluster) @@ -199,9 +209,13 @@ def to( system: Union[str, "Cluster"], path: Optional[Union[str, Path]] = None, ): - """Copy the folder to a new filesystem. + """Copy the folder to a new filesystem or cluster. Currently supported: ``here``, ``file``, ``gs``, ``s3``, or a cluster. 
+ Args: + system (str or Cluster): Filesystem or cluster to move the folder to. + path (str, optional): Path to move the folder to. + Example: >>> local_folder = rh.folder(path="/my/local/folder") >>> s3_folder = local_folder.to("s3") @@ -445,7 +459,7 @@ def _download_command(self, src, dest): """CLI command for downloading folder from remote bucket. Needed when downloading a folder to a cluster.""" raise NotImplementedError - def config(self, condensed=True): + def config(self, condensed: bool = True): config = super().config(condensed) if self.system == Folder.DEFAULT_FS: @@ -472,7 +486,7 @@ def _save_sub_resources(self, folder: str = None): self.system.save(folder=folder) @staticmethod - def _path_relative_to_rh_workdir(path): + def _path_relative_to_rh_workdir(path: str): rh_workdir = Path(locate_working_dir()) try: return str(Path(path).relative_to(rh_workdir)) @@ -480,7 +494,7 @@ def _path_relative_to_rh_workdir(path): return path @staticmethod - def _path_absolute_to_rh_workdir(path): + def _path_absolute_to_rh_workdir(path: str): return ( path if Path(path).expanduser().is_absolute() @@ -551,6 +565,9 @@ def ls(self, full_paths: bool = True, sort: bool = False) -> List: def resources(self, full_paths: bool = False): """List the resources in the folder. + Args: + full_paths (bool, optional): Whether to list the full path or relative path. (Default: ``False``) + Example: >>> resources = my_folder.resources() """ @@ -581,9 +598,6 @@ def resources(self, full_paths: bool = False): @property def rns_address(self): - """Traverse up the filesystem until reaching one of the directories in rns_base_folders, - then compute the relative path to that. - """ # TODO Maybe later, account for folders along the path with a different RNS name. if self.name is None: # Anonymous folders have no rns address @@ -617,8 +631,11 @@ def rns_address(self): relative_path = str(Path(self.path).relative_to(base_folder.path)) return base_folder_path + "/" + relative_path - def contains(self, name_or_path) -> bool: - """Whether path of a Folder exists locally. + def contains(self, name_or_path: str) -> bool: + """Whether path exists locally inside a folder. + + Args: + name_or_path (str): Name or path of folder to check if it exists inside the folder. Example: >>> my_folder = rh.folder("local/folder/path") @@ -627,9 +644,12 @@ def contains(self, name_or_path) -> bool: path, _ = self.locate(name_or_path) return path is not None - def locate(self, name_or_path) -> (str, str): + def locate(self, name_or_path) -> Tuple[str, str]: """Locate the local path of a Folder given an rns path. + Args: + name_or_path (str): Name or path of folder to locate inside the folder. + Example: >>> my_folder = rh.folder("local/folder/path") >>> local_path = my_folder.locate("file_name") @@ -680,10 +700,15 @@ def locate(self, name_or_path) -> (str, str): return None, None - def open(self, name, mode="rb", encoding=None): + def open(self, name, mode: str = "rb", encoding: Optional[str] = None): """Returns the specified file as a stream (`botocore.response.StreamingBody`), which must be used as a content manager to be opened. + Args: + name (str): Name of file inside the folder to open. + mode (str, optional): Mode for opening the file. (Default: ``"rb"``) + encoding (str, optional): Encoding for opening the file. 
(Default: ``None``) + Example: >>> with my_folder.open('obj_name') as my_file: >>> pickle.load(my_file) @@ -698,9 +723,14 @@ def open(self, name, mode="rb", encoding=None): return open(file_path, mode=mode, encoding=encoding) - def get(self, name, mode="rb", encoding=None): + def get(self, name, mode: str = "rb", encoding: Optional[str] = None): """Returns the contents of a file as a string or bytes. + Args: + name (str): Name of file to get contents of. + mode (str, optional): Mode for opening the file. (Default: ``"rb"``) + encoding (str, optional): Encoding for opening the file. (Default: ``None``) + Example: >>> contents = my_folder.get(file_name) """ @@ -725,18 +755,19 @@ def exists_in_system(self): >>> exists_on_system = my_folder.exists_in_system() """ if self._use_http_endpoint: - return self.system._folder_exists(self.path) + return self.system._folder_exists(path=self.path) else: full_path = Path(self.path).expanduser() return full_path.exists() and full_path.is_dir() - def rm(self, contents: list = None, recursive: bool = True): + def rm(self, contents: List = None, recursive: bool = True): """Delete a folder from the file system. Optionally provide a list of folder contents to delete. Args: - contents (Optional[List]): Specific contents to delete in the folder. - recursive (bool): Delete the folder itself (including all its contents). - Defaults to ``True``. + contents (Optional[List]): Specific contents to delete in the folder. If None, removes + the entire folder. (Default: ``None``) + recursive (bool, optional): Delete the folder itself (including all its contents). + (Default: ``True``) Example: >>> my_folder.rm() @@ -762,15 +793,15 @@ def rm(self, contents: list = None, recursive: bool = True): else: folder_path.unlink() - def put(self, contents, overwrite=False, mode: str = "wb"): + def put(self, contents: Dict[str, Any], overwrite: bool = False, mode: str = "wb"): """Put given contents in folder. Args: contents (Dict[str, Any] or Resource or List[Resource]): Contents to put in folder. Must be a dict with keys being the file names (without full paths) and values being the file-like objects to write, or a Resource object, or a list of Resources. - overwrite (bool): Whether to overwrite the existing file if it exists. Defaults to ``False``. - mode (Optional(str)): Write mode to use. Defaults to ``wb``. + overwrite (bool, optional): Whether to overwrite the existing file if it exists. (Default: ``False``) + mode (str, optional): Write mode to use. (Default: ``wb``). 
Example: >>> my_folder.put(contents={"filename.txt": data}) diff --git a/runhouse/resources/folders/gcs_folder.py b/runhouse/resources/folders/gcs_folder.py index d616f78f0..5f269ed2c 100644 --- a/runhouse/resources/folders/gcs_folder.py +++ b/runhouse/resources/folders/gcs_folder.py @@ -2,7 +2,7 @@ import shutil import subprocess from pathlib import Path -from typing import List, Optional +from typing import Dict, List, Optional from runhouse.logger import get_logger @@ -23,7 +23,7 @@ def __init__(self, dryrun: bool, **kwargs): self._urlpath = "gs://" @staticmethod - def from_config(config: dict, dryrun=False, _resolve_children=True): + def from_config(config: Dict, dryrun: bool = False, _resolve_children: bool = True): """Load config values into the object.""" return GCSFolder(**config, dryrun=dryrun) diff --git a/runhouse/resources/folders/s3_folder.py b/runhouse/resources/folders/s3_folder.py index 869449851..71465ac8d 100644 --- a/runhouse/resources/folders/s3_folder.py +++ b/runhouse/resources/folders/s3_folder.py @@ -3,7 +3,7 @@ import subprocess import time from pathlib import Path -from typing import List, Optional +from typing import Dict, List, Optional, Union from runhouse.logger import get_logger @@ -28,7 +28,7 @@ def __init__(self, dryrun: bool, **kwargs): self._urlpath = "s3://" @staticmethod - def from_config(config: dict, dryrun=False, _resolve_children=True): + def from_config(config: Dict, dryrun: bool = False, _resolve_children: bool = True): """Load config values into the object.""" return S3Folder(**config, dryrun=dryrun) @@ -115,7 +115,12 @@ def _s3_copy(self, new_path): Key=new_key, ) - def put(self, contents, overwrite=False, mode: str = "wb"): + def put( + self, + contents: Union["S3Folder", Dict], + overwrite: bool = False, + mode: str = "wb", + ): """Put given contents in folder.""" self.mkdir() if isinstance(contents, list): @@ -131,7 +136,7 @@ def put(self, contents, overwrite=False, mode: str = "wb"): contents.folder_path = key + "/" + contents.folder_path return - if not isinstance(contents, dict): + if not isinstance(contents, Dict): raise TypeError( "`contents` argument must be a dict mapping filenames to file-like objects" ) diff --git a/runhouse/resources/functions/aws_lambda.py b/runhouse/resources/functions/aws_lambda.py index b51375943..9418710ea 100644 --- a/runhouse/resources/functions/aws_lambda.py +++ b/runhouse/resources/functions/aws_lambda.py @@ -8,7 +8,7 @@ import warnings import zipfile from pathlib import Path -from typing import Any, List, Optional +from typing import Any, Dict, List, Optional, Union try: import boto3 @@ -142,8 +142,6 @@ def __init__( def from_config( cls, config: dict, dryrun: bool = False, _resolve_children: bool = True ): - """Create an AWS Lambda object from a config dictionary.""" - if "resource_subtype" in config.keys(): config.pop("resource_subtype", None) if "system" in config.keys(): @@ -171,10 +169,10 @@ def from_config( def from_name( cls, name, - load_from_den=True, - dryrun=False, - alt_options=None, - _resolve_children=True, + load_from_den: bool = True, + dryrun: bool = False, + _alt_options: Dict = None, + _resolve_children: bool = True, ): config = rns_client.load_config(name=name, load_from_den=load_from_den) if not config: @@ -187,7 +185,7 @@ def from_handler_file( paths_to_code: List[str], handler_function_name: str, name: Optional[str] = None, - env: Optional[dict or List[str] or Env] = None, + env: Optional[Union[Dict, List[str], Env]] = None, runtime: Optional[str] = None, timeout: Optional[int] = 
None, memory_size: Optional[int] = None, @@ -223,13 +221,13 @@ def from_handler_file( 3. An instance of Runhouse Env class. By default, ``runhouse`` package will be installed, and env_vars will include ``{HOME: /tmp/home}``. - timeout: Optional[int]: The maximum amount of time (in seconds) during which the Lambda will run in AWS + timeout (Optional[int]) The maximum amount of time (in seconds) during which the Lambda will run in AWS without timing-out. (Default: ``900``, Min: ``3``, Max: ``900``) - memory_size: Optional[int], The amount of memory (in MB) to be allocated to the Lambda. + memory_size (Optional[int]), The amount of memory (in MB) to be allocated to the Lambda. (Default: ``10240``, Min: ``128``, Max: ``10240``) - tmp_size: Optional[int], This size of the /tmp folder in the aws lambda file system. + tmp_size (Optional[int]), This size of the /tmp folder in the aws lambda file system. (Default: ``10240``, Min: ``512``, Max: ``10240``). - retention_time: Optional[int] The time (in days) the Lambda execution logs will be saved in AWS + retention_time (Optional[int]) The time (in days) the Lambda execution logs will be saved in AWS cloudwatch. After that, they will be deleted. (Default: ``30`` days) dryrun (bool): Whether to create the Function if it doesn't exist, or load the Function object as a dryrun. (Default: ``False``). @@ -825,7 +823,6 @@ def map(self, *args, **kwargs): return [self._invoke(*args, **kwargs) for args in zip(*args)] - # def starmap(self, args_lists, **kwargs): """Like :func:`map` except that the elements of the iterable are expected to be iterables that are unpacked as arguments. An iterable of [(1,2), (3, 4)] results in [func(1,2), func(3,4)]. diff --git a/runhouse/resources/functions/function.py b/runhouse/resources/functions/function.py index d46f575b3..b9e0e5530 100644 --- a/runhouse/resources/functions/function.py +++ b/runhouse/resources/functions/function.py @@ -1,6 +1,6 @@ import inspect from pathlib import Path -from typing import Any, List, Optional, Tuple, Union +from typing import Any, Iterable, List, Optional, Tuple, Union from runhouse import globals from runhouse.logger import get_logger @@ -40,8 +40,9 @@ def __init__( # ----------------- Constructor helper methods ----------------- @classmethod - def from_config(cls, config: dict, dryrun: bool = False, _resolve_children=True): - """Create a Function object from a config dictionary.""" + def from_config( + cls, config: dict, dryrun: bool = False, _resolve_children: bool = True + ): if isinstance(config["system"], dict): config["system"] = Cluster.from_config( config["system"], dryrun=dryrun, _resolve_children=_resolve_children @@ -74,11 +75,18 @@ def to( name: Optional[str] = None, force_install: bool = False, ): - """to(system: str | Cluster | None = None, env: List[str] | Env = [], force_install: bool = False) + """ + Send the function to the specified env on the cluster. This will sync over relevant code and packages + onto the cluster, and set up the environment if it does not yet exist on the cluster. - Set up a Function and Env on the given system. - If the function is sent to AWS, the system should be ``aws_lambda`` - See the args of the factory method :func:`function` for more information. + Args: + system (str or Cluster): The system to setup the function and env on. + env (str, List[str], or Env, optional): The environment where the function lives on in the cluster, + or the set of requirements necessary to run the function. 
(Default: ``None``) + name (Optional[str], optional): Name to give to the function resource, if you wish to rename it. + (Default: ``None``) + force_install (bool, optional): Whether to re-install and perform the environment setup steps, even + if it may already exist on the cluster. (Defualt: ``False``) Example: >>> rh.function(fn=local_fn).to(gpu_cluster) @@ -103,12 +111,12 @@ def __call__(self, *args, **kwargs) -> Any: """Call the function on its system Args: - *args: Optional args for the Function - stream_logs (bool): Whether to stream the logs from the Function's execution. - Defaults to ``True``. - run_name (Optional[str]): Name of the Run to create. If provided, a Run will be created - for this function call, which will be executed synchronously on the cluster before returning its result - **kwargs: Optional kwargs for the Function + *args: Optional args for the Function + stream_logs (bool): Whether to stream the logs from the Function's execution. + Defaults to ``True``. + run_name (Optional[str]): Name of the Run to create. If provided, a Run will be created + for this function call, which will be executed synchronously on the cluster before returning its result + **kwargs: Optional kwargs for the Function Returns: The Function's return value @@ -144,12 +152,15 @@ def map(self, *args, **kwargs): ray_wrapped_fn = ray.remote(fn) return ray.get([ray_wrapped_fn.remote(*args, **kwargs) for args in zip(*args)]) - def starmap(self, args_lists, **kwargs): + def starmap(self, args_lists: List[Iterable], **kwargs): """Like :func:`map` except that the elements of the iterable are expected to be iterables that are unpacked as arguments. An iterable of [(1,2), (3, 4)] results in [func(1,2), func(3,4)]. + Args: + arg_lists (List[Iterable]): List containing iterbles of arguments to be passed into the function. + Example: - >>> arg_list = [(1,2), (3, 4)] + >>> arg_list = [(1, 2), (3, 4)] >>> # runs the function twice, once with args (1, 2) and once with args (3, 4) >>> remote_fn.starmap(arg_list) """ @@ -178,6 +189,12 @@ def get(self, run_key): return self.system.get(run_key) def config(self, condensed=True): + """The config of the function. + + Args: + condensed (bool, optional): Whether to return the condensed config without expanding children subresources, + or return the whole expanded config. (Default: ``True``) + """ config = super().config(condensed) config.update( { @@ -189,6 +206,10 @@ def config(self, condensed=True): def send_secrets(self, providers: Optional[List[str]] = None): """Send secrets to the system. + Args: + providers (List[str], optional): List of secret names to send over to the system. If none are provided, + syncs over all locally detected provider secrets. (Default: ``None``) + Example: >>> remote_fn.send_secrets(providers=["aws", "lambda"]) """ @@ -212,16 +233,15 @@ def notebook(self, persist=False, sync_package_on_close=None, port_forward=8888) ) def get_or_call( - self, run_name: str, load_from_den=True, local=True, *args, **kwargs + self, run_name: str, load_from_den: bool = True, *args, **kwargs ) -> Any: """Check if object already exists on cluster or rns, and if so return the result. If not, run the function. Keep in mind this can be called with any of the usual method call modifiers - `remote=True`, `run_async=True`, `stream_logs=False`, etc. Args: - run_name (Optional[str]): Name of a particular run for this function. - If not provided will use the function's name. - load_from_den (bool): Whether to try loading the run name from Den. 
+ run_name (str): Name of a particular run for this function. + load_from_den (bool, optional): Whether to try loading the run name from Den. (Default: ``True``) *args: Arguments to pass to the function for the run (relevant if creating a new run). **kwargs: Keyword arguments to pass to the function for the run (relevant if creating a new run). @@ -253,7 +273,11 @@ def keep_warm( self, autostop_mins=None, ): - """Keep the system warm for autostop_mins. If autostop_mins is ``None`` or -1, keep warm indefinitely. + """Keep the system warm for autostop_mins. + + Args: + autostop_mins (int): Keep the cluster warm for this amount of time. + If ``None`` or -1, keep warm indefinitely. Example: >>> # keep gpu warm for 30 mins diff --git a/runhouse/resources/functions/function_factory.py b/runhouse/resources/functions/function_factory.py index 07f925ea8..7e09ba3e3 100644 --- a/runhouse/resources/functions/function_factory.py +++ b/runhouse/resources/functions/function_factory.py @@ -26,17 +26,17 @@ def function( Args: fn (Optional[str or Callable]): The function to execute on the remote system when the function is called. - name (Optional[str]): Name of the Function to create or retrieve. - This can be either from a local config or from the RNS. - env (Optional[List[str] or Env or str]): List of requirements to install on the remote cluster, or path to the - requirements.txt file, or Env object or string name of an Env object. - load_from_den (bool): Whether to try loading the function from Den. (Default: ``True``) - dryrun (bool): Whether to create the Function if it doesn't exist, or load the Function object as a dryrun. - (Default: ``False``) - load_secrets (bool): Whether or not to send secrets; only applicable if `dryrun` is set to ``False``. - (Default: ``False``) - serialize_notebook_fn (bool): If function is of a notebook setting, whether or not to serialized the function. + name (Optional[str], optional): Name of the Function to create or retrieve. + This can be either from a local config or from the RNS. (Default: ``None``) + env (Optional[List[str] or Env or str], optional): List of requirements to install on the remote cluster, + or path to the requirements.txt file, or Env object or string name of an Env object. (Default: ``None``) + load_from_den (bool, optional): Whether to try loading the function from Den. (Default: ``True``) + dryrun (bool, optional): Whether to create the Function if it doesn't exist, or load the Function object as + a dryrun. (Default: ``False``) + load_secrets (bool, optional): Whether or not to send secrets; only applicable if `dryrun` is set to ``False``. (Default: ``False``) + serialize_notebook_fn (bool, optional): If function is of a notebook setting, whether or not to serialized the + function. (Default: ``False``) Returns: Function: The resulting Function object. 
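(Aside, not part of the patch: a hedged usage sketch tying together the ``Function`` docstrings updated above. The cluster name, the ``summer`` function, and the package list are illustrative; ``to``, ``starmap``, and ``get_or_call`` follow the signatures documented in this diff.)

.. code-block:: python

    import runhouse as rh

    def summer(a: int, b: int) -> int:
        return a + b

    # Assumes an on-demand AWS cluster can be launched with local cloud credentials.
    cluster = rh.ondemand_cluster(name="rh-cpu", instance_type="CPU:2+", provider="aws").up_if_not()

    # Sync the function and its requirements over, setting up the env on the cluster if needed.
    remote_summer = rh.function(summer).to(cluster, env=["numpy"])

    print(remote_summer(2, 3))                      # runs remotely, returns 5
    print(remote_summer.starmap([(1, 2), (3, 4)]))  # [3, 7]

    # Reuse a previous run's result if one with this name exists, otherwise run the function.
    result = remote_summer.get_or_call("summer-run-1", a=5, b=6)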
diff --git a/runhouse/resources/hardware/__init__.py b/runhouse/resources/hardware/__init__.py index d51be4d52..de63e5912 100644 --- a/runhouse/resources/hardware/__init__.py +++ b/runhouse/resources/hardware/__init__.py @@ -1,12 +1,6 @@ from .cluster import Cluster -from .cluster_factory import ( - cluster, - kubernetes_cluster, - ondemand_cluster, - sagemaker_cluster, -) +from .cluster_factory import cluster, kubernetes_cluster, ondemand_cluster from .on_demand_cluster import OnDemandCluster -from .sagemaker.sagemaker_cluster import SageMakerCluster from .utils import ( _current_cluster, _default_env_if_on_cluster, diff --git a/runhouse/resources/hardware/cluster.py b/runhouse/resources/hardware/cluster.py index fed565dd5..b7e5e096c 100644 --- a/runhouse/resources/hardware/cluster.py +++ b/runhouse/resources/hardware/cluster.py @@ -39,6 +39,7 @@ DEFAULT_STATUS_CHECK_INTERVAL, EMPTY_DEFAULT_ENV_NAME, LOCALHOST, + NUM_PORTS_TO_TRY, RESERVED_SYSTEM_NAMES, ) from runhouse.globals import configs, obj_store, rns_client @@ -81,7 +82,7 @@ def __init__( ssl_certfile: str = None, domain: str = None, den_auth: bool = False, - dryrun=False, + dryrun: bool = False, **kwargs, # We have this here to ignore extra arguments when calling from from_config ): """ @@ -197,15 +198,16 @@ def default_env(self, env): def from_name( cls, name, - load_from_den=True, - dryrun=False, - alt_options=None, - _resolve_children=True, + load_from_den: bool = True, + dryrun: bool = False, + _alt_options: Dict = None, + _resolve_children: bool = True, ): cluster = super().from_name( name=name, + load_from_den=load_from_den, dryrun=dryrun, - alt_options=alt_options, + _alt_options=_alt_options, _resolve_children=_resolve_children, ) if hasattr(cluster, "_update_from_sky_status"): @@ -236,6 +238,13 @@ def save_config_to_cluster( def save(self, name: str = None, overwrite: bool = True, folder: str = None): """Overrides the default resource save() method in order to also update the cluster config on the cluster itself. + + Args: + name (str, optional): Name to save the cluster as, if different from its existing name. (Default: ``None``) + overwrite (bool, optional): Whether to overwrite the existing saved resource, if it exists. + (Default: ``True``) + folder (str, optional): Folder to save the config in, if saving locally. If None and saving locally, + will be saved in the ``~/.rh`` directory. 
(Default: ``None``)
         """
         on_this_cluster = self.on_this_cluster()
@@ -260,6 +269,7 @@ def save(self, name: str = None, overwrite: bool = True, folder: str = None):
         return self
 
     def delete_configs(self):
+        """Delete configs for the cluster"""
         if self._creds:
             logger.debug(
                 f"Attempting to delete creds associated with cluster {self.name}"
@@ -295,7 +305,9 @@ def _save_sub_resources(self, folder: str = None):
             self._default_env.save(folder=folder)
 
     @classmethod
-    def from_config(cls, config: dict, dryrun=False, _resolve_children=True):
+    def from_config(
+        cls, config: Dict, dryrun: bool = False, _resolve_children: bool = True
+    ):
         resource_subtype = config.get("resource_subtype")
         if _resolve_children:
             config = cls._check_for_child_configs(config)
@@ -306,14 +318,10 @@ def from_config(cls, config: dict, dryrun=False, _resolve_children=True):
             from .on_demand_cluster import OnDemandCluster
 
             return OnDemandCluster(**config, dryrun=dryrun)
-        elif resource_subtype == "SageMakerCluster":
-            from .sagemaker.sagemaker_cluster import SageMakerCluster
-
-            return SageMakerCluster(**config, dryrun=dryrun)
         else:
             raise ValueError(f"Unknown cluster type {resource_subtype}")
 
-    def config(self, condensed=True):
+    def config(self, condensed: bool = True):
         config = super().config(condensed)
         self.save_attrs_to_config(
             config,
@@ -351,12 +359,15 @@ def config(self, condensed=True):
 
         return config
 
-    def endpoint(self, external=False):
-        """Endpoint for the cluster's Daemon server. If external is True, will only return the external url,
-        and will return None otherwise (e.g. if a tunnel is required). If external is False, will either return
-        the external url if it exists, or will set up the connection (based on connection_type) and return
-        the internal url (including the local connected port rather than the sever port). If cluster is not up,
-        returns None.
+    def endpoint(self, external: bool = False):
+        """Endpoint for the cluster's Daemon server.
+
+        Args:
+            external (bool, optional): If ``True``, will only return the external url, and will return ``None``
+                otherwise (e.g. if a tunnel is required). If set to ``False``, will either return the external url
+                if it exists, or will set up the connection (based on connection_type) and return the internal url
+                (including the local connected port rather than the server port). If cluster is not up, returns
+                ``None``. (Default: ``False``)
         """
         if not self.address or self.on_this_cluster():
             return None
@@ -384,10 +395,7 @@ def endpoint(self, external=False):
         if external:
             return None
 
-        if self.server_connection_type in [
-            ServerConnectionType.SSH,
-            ServerConnectionType.AWS_SSM,
-        ]:
+        if self.server_connection_type == ServerConnectionType.SSH:
             self.client.check_server()
             return f"http://{LOCALHOST}:{client_port}"
 
@@ -515,7 +523,7 @@ def _sync_default_env_to_cluster(self):
     def _sync_runhouse_to_cluster(
         self,
         _install_url: Optional[str] = None,
-        env=None,
+        env: "Env" = None,
         local_rh_package_path: Optional[Path] = None,
     ):
         if self.on_this_cluster():
@@ -581,7 +589,7 @@ def install_packages(
         """Install the given packages on the cluster.
 
         Args:
-            reqs (List[Package or str): List of packages to install on cluster and env
+            reqs (List[Package or str]): List of packages to install on cluster and env.
             env (Env or str): Environment to install package on. If left empty, defaults to base environment.
                 (Default: ``None``)
@@ -594,8 +602,15 @@ def install_packages(
         env.to(self)
 
     def get(self, key: str, default: Any = None, remote=False):
-        """Get the result for a given key from the cluster's object store. To raise an error if the key is not found,
-        use `cluster.get(key, default=KeyError)`."""
+        """Get the result for a given key from the cluster's object store.
+
+        Args:
+            key (str): Key to get from the cluster's object store.
+            default (Any, optional): What to return if the key is not found. To raise an error, pass in
+                ``KeyError``. (Default: ``None``)
+            remote (bool, optional): Whether to get the remote object, rather than the object in full.
+                (Default: ``False``)
+        """
         if self.on_this_cluster():
             return obj_store.get(key, default=default, remote=remote)
         try:
@@ -612,12 +627,14 @@ def get(self, key: str, default: Any = None, remote=False):
                 return default
         return res
 
-    # TODO deprecate
-    def get_run(self, run_name: str, folder_path: str = None):
-        return self.get(run_name, remote=True).provenance
+    def put(self, key: str, obj: Any, env: str = None):
+        """Put the given object on the cluster's object store at the given key.
 
-    def put(self, key: str, obj: Any, env=None):
-        """Put the given object on the cluster's object store at the given key."""
+        Args:
+            key (str): Key to assign the object in the object store.
+            obj (Any): Object to put in the object store.
+            env (str, optional): Env of the object store to put the object in. (Default: ``None``)
+        """
         if self.on_this_cluster():
             return obj_store.put(key, obj, env=env)
         return self.call_client_method(
@@ -625,9 +642,20 @@ def put(self, key: str, obj: Any, env=None):
         )
 
     def put_resource(
-        self, resource: Resource, state: Dict = None, dryrun: bool = False, env=None
+        self,
+        resource: Resource,
+        state: Dict = None,
+        dryrun: bool = False,
+        env: Union[str, "Env"] = None,
     ):
-        """Put the given resource on the cluster's object store. Returns the key (important if name is not set)."""
+        """Put the given resource on the cluster's object store. Returns the key (important if name is not set).
+
+        Args:
+            resource (Resource): Resource to put on the cluster's object store.
+            state (Dict, optional): Dict of resource attributes to override. (Default: ``None``)
+            dryrun (bool, optional): Whether to put the resource in dryrun mode or not. (Default: ``False``)
+            env (str, optional): Env of the object store to put the object in. (Default: ``None``)
+        """
         if resource.RESOURCE_TYPE == "env" and not resource.name:
             resource.name = self.default_env.name
 
@@ -665,20 +693,33 @@ def put_resource(
         )
 
     def rename(self, old_key: str, new_key: str):
-        """Rename a key in the cluster's object store."""
+        """Rename a key in the cluster's object store.
+
+        Args:
+            old_key (str): Original key to rename.
+            new_key (str): Name to reassign the object.
+        """
         if self.on_this_cluster():
             return obj_store.rename(old_key, new_key)
         return self.call_client_method("rename_object", old_key, new_key)
 
-    def keys(self, env=None):
-        """List all keys in the cluster's object store."""
+    def keys(self, env: str = None):
+        """List all keys in the cluster's object store.
+
+        Args:
+            env (str, optional): Env in which to list the keys. (Default: ``None``)
+        """
         if self.on_this_cluster():
             return obj_store.keys()
         res = self.call_client_method("keys", env=env)
         return res
 
     def delete(self, keys: Union[None, str, List[str]]):
-        """Delete the given items from the cluster's object store. To delete all items, use `cluster.clear()`"""
+        """Delete the given items from the cluster's object store. To delete all items, use `cluster.clear()`
+
+        Args:
+            keys (str or List[str]): Key or list of keys to delete from the object store.
+        """
         if isinstance(keys, str):
             keys = [keys]
         if self.on_this_cluster():
@@ -724,17 +765,14 @@ def connect_tunnel(self, force_reconnect=False):
             self._rpc_tunnel = self.ssh_tunnel(
                 local_port=self.server_port,
                 remote_port=self.server_port,
-                num_ports_to_try=10,
+                num_ports_to_try=NUM_PORTS_TO_TRY,
             )
 
     def connect_server_client(self, force_reconnect=False):
         if not self.address:
             raise ValueError(f"No address set for cluster <{self.name}>. Is it up?")
 
-        if self.server_connection_type in [
-            ServerConnectionType.SSH,
-            ServerConnectionType.AWS_SSM,
-        ]:
+        if self.server_connection_type == ServerConnectionType.SSH:
             # For a password cluster, the 'ssh_tunnel' command assumes a Control Master is already set up with
             # an authenticated password.
             # TODO: I wonder if this authentication ever goes dry, and our SSH tunnel would need to be
@@ -786,7 +824,13 @@ def connect_server_client(self, force_reconnect=False):
         )
 
     def status(self, resource_address: str = None, send_to_den: bool = False):
-        """Load the status of the Runhouse daemon running on a cluster."""
+        """Load the status of the Runhouse daemon running on a cluster.
+
+        Args:
+            resource_address (str, optional): Address of the resource, used to construct the cluster subtoken
+                when calling from outside the cluster. (Default: ``None``)
+            send_to_den (bool, optional): Whether to send and update the status in Den. Only applies to
+                clusters that are saved to Den. (Default: ``False``)
+        """
 
         # Note: If running outside a local cluster need to include a resource address to construct the cluster subtoken
         # Allow for specifying a resource address explicitly in case the resource has no rns address yet
@@ -898,7 +942,6 @@ def restart_server(
         Args:
             resync_rh (bool): Whether to resync runhouse. Specifying False will not sync Runhouse under any circumstance. If it is None, then it will sync if Runhouse is not installed on the cluster or if locally it is installed as editable. (Default: ``None``)
             restart_ray (bool): Whether to restart Ray. (Default: ``True``)
-            env (str or Env, optional): Specified environment to restart the server on. (Default: ``None``)
             restart_proxy (bool): Whether to restart Caddy on the cluster, if configured. (Default: ``False``)
 
         Example:
@@ -1062,7 +1105,7 @@ def stop_server(self, stop_ray: bool = True, env: Union[str, "Env"] = None):
         """Stop the RPC server.
 
         Args:
-            stop_ray (bool): Whether to stop Ray. (Default: `True`)
+            stop_ray (bool, optional): Whether to stop Ray. (Default: ``True``)
             env (str or Env, optional): Specified environment to stop the server on. (Default: ``None``)
         """
         cmd = CLI_STOP_CMD if stop_ray else f"{CLI_STOP_CMD} --no-stop-ray"
@@ -1072,20 +1115,20 @@ def stop_server(self, stop_ray: bool = True, env: Union[str, "Env"] = None):
 
     @contextlib.contextmanager
     def pause_autostop(self):
-        """Context manager to temporarily pause autostop. Mainly for OnDemand clusters, for BYO cluster
-        there is no autostop."""
+        """Context manager to temporarily pause autostop. Only for OnDemand clusters. There is no autostop
+        for static clusters."""
         pass
 
     def call(
         self,
-        module_name,
-        method_name,
+        module_name: str,
+        method_name: str,
         *args,
-        stream_logs=True,
-        run_name=None,
-        remote=False,
-        run_async=False,
-        save=False,
+        stream_logs: bool = True,
+        run_name: str = None,
+        remote: bool = False,
+        run_async: bool = False,
+        save: bool = False,
         **kwargs,
     ):
         """Call a method on a module that is in the cluster's object store.
 
         Args:
             module_name (str): Name of the module saved on system.
method_name (str): Name of the method. - stream_logs (bool): Whether to stream logs from the method call. - run_name (str): Name for the run. - remote (bool): Return a remote object from the function, rather than the result proper. - run_async (bool): Run the method asynchronously and return an awaitable. + stream_logs (bool, optional): Whether to stream logs from the method call. (Default: ``True``) + run_name (str, optional): Name for the run. (Default: ``None``) + remote (bool, optional): Return a remote object from the function, rather than the result proper. + (Default: ``False``) + run_async (bool, optional): Run the method asynchronously and return an awaitable. (Default: ``False``) + save (bool, optional): Whether or not to save the call. (Default: ``False``) *args: Positional arguments to pass to the method. **kwargs: Keyword arguments to pass to the method. @@ -1172,16 +1217,16 @@ def rsync( dest (str): The target path. up (bool): The direction of the sync. If ``True``, will rsync from local to cluster. If ``False`` will rsync from cluster to local. - node (Optional[str]): Specific cluster node to rsync to. If not specified will use the address of the - cluster's head node. - contents (Optional[bool]): Whether the contents of the source directory or the directory itself should - be copied to destination. + node (Optional[str], optional): Specific cluster node to rsync to. If not specified will use the + address of the cluster's head node. + contents (Optional[bool], optional): Whether the contents of the source directory or the directory + itself should be copied to destination. If ``True`` the contents of the source directory are copied to the destination, and the source directory itself is not created at the destination. If ``False`` the source directory along with its contents are copied ot the destination, creating an additional directory layer at the destination. (Default: ``False``). - filter_options (Optional[str]): The filter options for rsync. - stream_logs (Optional[bool]): Whether to stream logs to the stdout/stderr. (Default: ``False``). + filter_options (Optional[str], optional): The filter options for rsync. + stream_logs (Optional[bool], optional): Whether to stream logs to the stdout/stderr. (Default: ``False``). .. note:: Ending ``source`` with a slash will copy the contents of the directory into dest, @@ -1350,15 +1395,25 @@ def _copy_certs_to_cluster(self): def run( self, - commands: List[str], + commands: Union[str, List[str]], env: Union["Env", str] = None, stream_logs: bool = True, require_outputs: bool = True, node: Optional[str] = None, _ssh_mode: str = "interactive", # Note, this only applies for non-password SSH ) -> List: - """Run a list of shell commands on the cluster. If `run_name` is provided, the commands will be - sent over to the cluster before being executed and a Run object will be created. + """Run a list of shell commands on the cluster. + + Args: + commands (str or List[str]): Command or list of commands to run on the cluster. + env (Env or str, optional): Env on the cluster to run the command in. If not provided, + will be run in the default env. (Default: ``None``) + stream_logs (bool, optional): Whether to stream log output as the command runs. + (Default: ``True``) + require_outputs (bool, optional): If ``True``, returns a Tuple (returncode, stdout, stderr). + If ``False``, returns just the returncode. (Default: ``True``) + node (str, optional): Node to run the commands on. If not provided, runs on head node. 
+ (Default: ``None``) Example: >>> cpu.run(["pip install numpy"]) @@ -1532,6 +1587,12 @@ def run_python( ): """Run a list of python commands on the cluster, or a specific cluster node if its IP is provided. + Args: + commands (List[str]): List of commands to run. + env (Env or str, optional): Env to run the commands in. (Default: ``None``) + stream_logs (bool, optional): Whether to stream logs. (Default: ``True``) + node (str, optional): Node to run commands on. If not specified, runs on head node. (Default: ``None``) + Example: >>> cpu.run_python(['import numpy', 'print(numpy.__version__)']) >>> cpu.run_python(["print('hello')"]) @@ -1568,8 +1629,9 @@ def sync_secrets( """Send secrets for the given providers. Args: - providers(List[str] or None): List of providers to send secrets for. - If `None`, all providers configured in the environment will by sent. + providers(List[str] or None, optional): List of providers to send secrets for. + If `None`, all providers configured in the environment will by sent. (Default: ``None``) + env (str, Env, optional): Env to sync secrets into. (Default: ``None``) Example: >>> cpu.sync_secrets(secrets=["aws", "lambda"]) @@ -1618,7 +1680,7 @@ def notebook( tunnel = self.ssh_tunnel( local_port=port_forward, - num_ports_to_try=10, + num_ports_to_try=NUM_PORTS_TO_TRY, ) port_fwd = tunnel.local_bind_port @@ -1647,6 +1709,10 @@ def remove_conda_env( ): """Remove conda env from the cluster. + Args: + env (str or Env): Name of conda env to remove from the cluster, or Env resource + representing the environment. + Example: >>> rh.ondemand_cluster("rh-cpu").remove_conda_env("my_conda_env") """ @@ -1660,8 +1726,12 @@ def download_cert(self): f"Latest TLS certificate for {self.name} saved to local path: {self.cert_config.cert_path}" ) - def enable_den_auth(self, flush=True): - """Enable Den auth on the cluster.""" + def enable_den_auth(self, flush: bool = True): + """Enable Den auth on the cluster. + + Args: + flush (bool, optional): Whether to flush the auth cache. (Default: ``True``) + """ if self.on_this_cluster(): raise ValueError("Cannot toggle Den Auth live on the cluster.") else: diff --git a/runhouse/resources/hardware/cluster_factory.py b/runhouse/resources/hardware/cluster_factory.py index fd0cf1d6a..ead958707 100644 --- a/runhouse/resources/hardware/cluster_factory.py +++ b/runhouse/resources/hardware/cluster_factory.py @@ -4,16 +4,14 @@ from typing import Dict, List, Optional, Union -from runhouse.constants import DEFAULT_SERVER_PORT, LOCAL_HOSTS, RESERVED_SYSTEM_NAMES +from runhouse.constants import RESERVED_SYSTEM_NAMES from runhouse.globals import rns_client from runhouse.logger import get_logger from runhouse.resources.hardware.utils import ServerConnectionType -from runhouse.rns.utils.api import relative_file_path from .cluster import Cluster from .on_demand_cluster import OnDemandCluster -from .sagemaker.sagemaker_cluster import SageMakerCluster logger = get_logger(__name__) @@ -33,28 +31,30 @@ def cluster( load_from_den: bool = True, dryrun: bool = False, **kwargs, -) -> Union[Cluster, OnDemandCluster, SageMakerCluster]: +) -> Union[Cluster, OnDemandCluster]: """ Builds an instance of :class:`Cluster`. Args: name (str): Name for the cluster, to re-use later on. host (str or List[str], optional): Hostname (e.g. domain or name in .ssh/config), IP address, or list of IP - addresses for the cluster (the first of which is the head node). + addresses for the cluster (the first of which is the head node). (Default: ``None``). 
ssh_creds (dict or str, optional): SSH credentials, passed as dictionary or the name of an `SSHSecret` object. - Example: ``ssh_creds={'ssh_user': '...', 'ssh_private_key':''}`` + Example: ``ssh_creds={'ssh_user': '...', 'ssh_private_key':''}`` (Default: ``None``). server_port (bool, optional): Port to use for the server. If not provided will use 80 for a ``server_connection_type`` of ``none``, 443 for ``tls`` and ``32300`` for all other SSH connection types. server_host (bool, optional): Host from which the server listens for traffic (i.e. the --host argument `runhouse start` run on the cluster). Defaults to "0.0.0.0" unless connecting to the server with an SSH - connection, in which case ``localhost`` is used. + connection, in which case ``localhost`` is used. (Default: ``None``). server_connection_type (ServerConnectionType or str, optional): Type of connection to use for the Runhouse API server. ``ssh`` will use start with server via an SSH tunnel. ``tls`` will start the server with HTTPS on port 443 using TLS certs without an SSH tunnel. ``none`` will start the server with HTTP - without an SSH tunnel. ``aws_ssm`` will start the server with HTTP using AWS SSM port forwarding. + without an SSH tunnel. (Default: ``None``). ssl_keyfile(str, optional): Path to SSL key file to use for launching the API server with HTTPS. + (Default: ``None``). ssl_certfile(str, optional): Path to SSL certificate file to use for launching the API server with HTTPS. - domain(str, optional): Domain name for the cluster. Relevant if enabling HTTPs on the cluster. + (Default: ``None``). + domain(str, optional): Domain name for the cluster. Relevant if enabling HTTPs on the cluster. (Default: ``None``). den_auth (bool, optional): Whether to use Den authorization on the server. If ``True``, will validate incoming requests with a Runhouse token provided in the auth headers of the request with the format: ``{"Authorization": "Bearer "}``. (Default: ``None``). @@ -66,7 +66,7 @@ def cluster( (Default: ``False``) Returns: - Union[Cluster, OnDemandCluster, SageMakerCluster]: The resulting cluster. + Union[Cluster, OnDemandCluster]: The resulting cluster. Example: >>> # using private key @@ -113,7 +113,7 @@ def cluster( name, load_from_den=load_from_den, dryrun=dryrun, - alt_options=alt_options, + _alt_options=alt_options, ) if c: c.set_connection_defaults() @@ -147,40 +147,6 @@ def cluster( **kwargs, ) - if any( - k in kwargs.keys() - for k in [ - "role", - "estimator", - "instance_type", - "connection_wait_time", - "num_instances", - ] - ): - warnings.warn( - "The `cluster` factory is intended to be used for static clusters. " - "If you would like to create a sagemaker cluster, please use `rh.sagemaker_cluster()` instead." - ) - return sagemaker_cluster( - name=name, - ssh_creds=ssh_creds, - server_port=server_port, - server_host=server_host, - server_connection_type=server_connection_type, - ssl_keyfile=ssl_keyfile, - ssl_certfile=ssl_certfile, - domain=domain, - den_auth=den_auth, - default_env=default_env, - dryrun=dryrun, - **kwargs, - ) - - if server_connection_type == ServerConnectionType.AWS_SSM: - raise ValueError( - f"Cluster does not support server connection type of {server_connection_type}" - ) - if isinstance(host, str): host = [host] @@ -372,7 +338,7 @@ def ondemand_cluster( server_connection_type (ServerConnectionType or str, optional): Type of connection to use for the Runhouse API server. ``ssh`` will use start with server via an SSH tunnel. 
``tls`` will start the server with HTTPS on port 443 using TLS certs without an SSH tunnel. ``none`` will start the server with HTTP - without an SSH tunnel. ``aws_ssm`` will start the server with HTTP using AWS SSM port forwarding. + without an SSH tunnel. ssl_keyfile(str, optional): Path to SSL key file to use for launching the API server with HTTPS. ssl_certfile(str, optional): Path to SSL certificate file to use for launching the API server with HTTPS. domain(str, optional): Domain name for the cluster. Relevant if enabling HTTPs on the cluster. @@ -472,7 +438,7 @@ def ondemand_cluster( name, load_from_den=load_from_den, dryrun=dryrun, - alt_options=alt_options, + _alt_options=alt_options, ) if c: c.set_connection_defaults() @@ -516,213 +482,3 @@ def ondemand_cluster( c.save() return c - - -def sagemaker_cluster( - name: str, - role: str = None, - profile: str = None, - ssh_key_path: str = None, - instance_id: str = None, - instance_type: str = None, - num_instances: int = None, - image_uri: str = None, - autostop_mins: int = None, - connection_wait_time: int = None, - estimator: Union["sagemaker.estimator.EstimatorBase", Dict] = None, - job_name: str = None, - server_port: int = None, - server_host: int = None, - server_connection_type: Union[ServerConnectionType, str] = None, - ssl_keyfile: str = None, - ssl_certfile: str = None, - domain: str = None, - den_auth: bool = None, - default_env: Union["Env", str] = None, - load_from_den: bool = True, - dryrun: bool = False, - **kwargs, -) -> SageMakerCluster: - """ - Builds an instance of :class:`SageMakerCluster`. See SageMaker Hardware Setup section for more specific - instructions and requirements for providing the role and setting up the cluster. - - Args: - name (str): Name for the cluster, to re-use later on. - role (str, optional): An AWS IAM role (either name or full ARN). - Can be passed in explicitly as an argument or provided via an estimator. If not specified will try - using the ``profile`` attribute or environment variable ``AWS_PROFILE`` to extract the relevant role ARN. - More info on configuring an IAM role for SageMaker - `here `__. - profile (str, optional): AWS profile to use for the cluster. If provided instead of a ``role``, will lookup - the role ARN associated with the profile in the local AWS credentials. - If not provided, will use the ``default`` profile. - ssh_key_path (str, optional): Path (relative or absolute) to private SSH key to use for connecting to - the cluster. If not provided, will look for the key in path ``~/.ssh/sagemaker-ssh-gw``. - If not found will generate new keys and upload the public key to the default s3 bucket for the Role ARN. - instance_id (str, optional): ID of the AWS instance to use for the cluster. SageMaker does not expose - IP addresses of its instance, so we use an instance ID as a unique identifier for the cluster. - instance_type (str, optional): Type of AWS instance to use for the cluster. More info on supported - instance options `here `__. - (Default: ``ml.m5.large``.) - num_instances (int, optional): Number of instances to use for the cluster. - (Default: ``1``.) - image_uri (str, optional): Image to use for the cluster instead of using the default SageMaker image which - will be based on the framework_version and py_version. Can be an ECR url or dockerhub image and tag. - estimator (Union[str, sagemaker.estimator.EstimatorBase], optional): Estimator to use for a dedicated - training job. Leave as ``None`` if launching the compute without running a dedicated job. 
- More info on creating an estimator `here - `__. - autostop_mins (int, optional): Number of minutes to keep the cluster up after inactivity, - or ``-1`` to keep cluster up indefinitely. *Note: this will keep the cluster up even if a dedicated - training job has finished running or failed*. - connection_wait_time (int, optional): Amount of time to wait inside the SageMaker cluster before - continuing with normal execution. Useful if you want to connect before a dedicated job starts - (e.g. training). If you don't want to wait, set it to ``0``. - If no estimator is provided, will default to ``0``. - job_name (str, optional): Name to provide for a training job. If not provided will generate a default name - based on the image name and current timestamp (e.g. ``pytorch-training-2023-08-28-20-57-55-113``). - server_port (bool, optional): Port to use for the server (Default: ``32300``). - server_host (bool, optional): Host from which the server listens for traffic (i.e. the --host argument - `runhouse start` run on the cluster). - *Note: For SageMaker, since we connect to the Runhouse API server via an SSH tunnel, the only valid - host is localhost.* - server_connection_type (ServerConnectionType or str, optional): Type of connection to use for the Runhouse - API server. *Note: For SageMaker, only ``aws_ssm`` is currently valid as the server connection type.* - ssl_keyfile(str, optional): Path to SSL key file to use for launching the API server with HTTPS. - ssl_certfile(str, optional): Path to SSL certificate file to use for launching the API server with HTTPS. - domain(str, optional): Domain name for the cluster. Relevant if enabling HTTPs on the cluster. - den_auth (bool, optional): Whether to use Den authorization on the server. If ``True``, will validate incoming - requests with a Runhouse token provided in the auth headers of the request with the format: - ``{"Authorization": "Bearer "}``. (Default: ``None``). - default_env (Env or str, optional): Environment that the Runhouse server is started on in the cluster. Used to - specify an isolated environment (e.g. conda env) or any setup and requirements prior to starting the Runhouse - server. (Default: ``None``) - load_from_den (bool): Whether to try loading the SageMakerCluster resource from Den. (Default: ``True``) - dryrun (bool): Whether to create the SageMakerCluster if it doesn't exist, or load a SageMakerCluster object - as a dryrun. - (Default: ``False``) - - Returns: - SageMakerCluster: The resulting cluster. - - Example: - >>> import runhouse as rh - >>> # Launch a new SageMaker instance and keep it up indefinitely. - >>> # Note: This will use Role ARN associated with the "sagemaker" profile defined in the local aws credentials - >>> c = rh.sagemaker_cluster(name='sm-cluster', profile="sagemaker").save() - - >>> # Running a training job with a provided Estimator - >>> c = rh.sagemaker_cluster(name='sagemaker-cluster', - >>> estimator=PyTorch(entry_point='train.py', - >>> role='arn:aws:iam::123456789012:role/MySageMakerRole', - >>> source_dir='/Users/myuser/dev/sagemaker', - >>> framework_version='1.8.1', - >>> py_version='py36', - >>> instance_type='ml.p3.2xlarge'), - >>> ).save() - - >>> # Load cluster from above - >>> reloaded_cluster = rh.sagemaker_cluster(name="sagemaker-cluster") - """ - if ( - "aws-cli/2." - not in subprocess.run( - ["aws", "--version"], capture_output=True, text=True - ).stdout - ): - raise RuntimeError( - "SageMaker SDK requires AWS CLI v2. 
You may also need to run `pip uninstall awscli` to ensure the right " - "version is being used. For more info: https://www.run.house/docs/api/python/cluster#id9" - ) - - ssh_key_path = relative_file_path(ssh_key_path) if ssh_key_path else None - - if ( - server_connection_type is not None - and server_connection_type != ServerConnectionType.AWS_SSM - ): - raise ValueError( - "SageMaker Cluster currently requires a server connection type of `aws_ssm`." - ) - server_connection_type = ServerConnectionType.AWS_SSM.value - - if server_host and server_host not in LOCAL_HOSTS: - raise ValueError( - "SageMaker Cluster currently requires a server host of `localhost` or `127.0.0.1`" - ) - - server_port = server_port or DEFAULT_SERVER_PORT - - if name: - alt_options = dict( - role=role, - profile=profile, - ssh_key_path=ssh_key_path, - instance_id=instance_id, - image_uri=image_uri, - estimator=estimator, - instance_type=instance_type, - job_name=job_name, - num_instances=num_instances, - server_host=server_host, - server_port=server_port, - server_connection_type=server_connection_type, - ssl_keyfile=ssl_keyfile, - ssl_certfile=ssl_certfile, - domain=domain, - den_auth=den_auth, - default_env=default_env, - ) - # Filter out None/default values - alt_options = {k: v for k, v in alt_options.items() if v is not None} - try: - c = SageMakerCluster.from_name( - name, - load_from_den=load_from_den, - dryrun=dryrun, - alt_options=alt_options, - ) - if c: - c.set_connection_defaults() - return c - except ValueError as e: - if not alt_options: - raise e - - if name in RESERVED_SYSTEM_NAMES: - raise ValueError( - f"Cluster name {name} is a reserved name. Please use a different name which is not one of " - f"{RESERVED_SYSTEM_NAMES}." - ) - - sm = SageMakerCluster( - name=name, - role=role, - profile=profile, - ssh_key_path=ssh_key_path, - estimator=estimator, - job_name=job_name, - instance_id=instance_id, - instance_type=instance_type, - num_instances=num_instances, - image_uri=image_uri, - autostop_mins=autostop_mins, - connection_wait_time=connection_wait_time, - server_host=server_host, - server_port=server_port, - server_connection_type=server_connection_type, - den_auth=den_auth, - ssl_keyfile=ssl_keyfile, - ssl_certfile=ssl_certfile, - domain=domain, - default_env=default_env, - dryrun=dryrun, - **kwargs, - ) - sm.set_connection_defaults() - - if den_auth or rns_client.autosave_resources(): - sm.save() - - return sm diff --git a/runhouse/resources/hardware/on_demand_cluster.py b/runhouse/resources/hardware/on_demand_cluster.py index 71a8e9d9f..6c8253864 100644 --- a/runhouse/resources/hardware/on_demand_cluster.py +++ b/runhouse/resources/hardware/on_demand_cluster.py @@ -4,7 +4,7 @@ import time import warnings from pathlib import Path -from typing import Any, Dict +from typing import Any, Dict, List, Union import requests @@ -46,21 +46,21 @@ def __init__( num_instances: int = None, provider: str = None, default_env: "Env" = None, - dryrun=False, - autostop_mins=None, - use_spot=False, - image_id=None, - memory=None, - disk_size=None, - open_ports=None, - server_host: str = None, + dryrun: bool = False, + autostop_mins: int = None, + use_spot: bool = False, + image_id: str = None, + memory: Union[int, str] = None, + disk_size: Union[int, str] = None, + open_ports: Union[int, str, List[int]] = None, + server_host: int = None, server_port: int = None, server_connection_type: str = None, ssl_keyfile: str = None, ssl_certfile: str = None, domain: str = None, den_auth: bool = False, - region=None, + region: 
str = None,
         sky_kwargs: Dict = None,
         **kwargs,  # We have this here to ignore extra arguments when calling from from_config
     ):
@@ -190,7 +190,7 @@ def config(self, condensed=True):
             config["autostop_mins"] = self._autostop_mins
         return config
 
-    def endpoint(self, external=False):
+    def endpoint(self, external: bool = False):
         if not self.address or self.on_this_cluster():
             return None
 
@@ -219,13 +219,6 @@ def relative_yaml_path(yaml_path):
         return yaml_path
 
     def set_connection_defaults(self):
-        if self.server_connection_type in [
-            ServerConnectionType.AWS_SSM,
-        ]:
-            raise ValueError(
-                f"OnDemandCluster does not support server connection type {self.server_connection_type}"
-            )
-
         if not self.server_connection_type:
             if self.ssl_keyfile or self.ssl_certfile:
                 self.server_connection_type = ServerConnectionType.TLS
@@ -451,6 +444,7 @@ def _update_from_sky_status(self, dryrun: bool = False):
         self._populate_connection_from_status_dict(cluster_dict)
 
     def get_instance_type(self):
+        """Returns the instance type of the cluster."""
         if self.instance_type and "--" in self.instance_type:  # K8s specific syntax
             return self.instance_type
         elif (
@@ -463,6 +457,7 @@ def get_instance_type(self):
         return None
 
     def accelerators(self):
+        """Returns the accelerator type, or None if it is a CPU-only cluster."""
         if (
             self.instance_type
             and ":" in self.instance_type
@@ -473,6 +468,7 @@ def accelerators(self):
         return None
 
     def num_cpus(self):
+        """Returns the number of CPUs for a CPU cluster."""
         if (
             self.instance_type
             and ":" in self.instance_type
@@ -626,9 +622,12 @@ def pause_autostop(self):
     # ----------------- SSH Methods ----------------- #
 
     @staticmethod
-    def cluster_ssh_key(path_to_file):
+    def cluster_ssh_key(path_to_file: Path):
         """Retrieve SSH key for the cluster.
 
+        Args:
+            path_to_file (Path): Path of the private key associated with the cluster.
+
         Example:
             >>> ssh_priv_key = rh.ondemand_cluster("rh-cpu").cluster_ssh_key("~/.ssh/id_rsa")
         """
@@ -640,7 +639,11 @@ def cluster_ssh_key(path_to_file):
             raise Exception(f"File with ssh key not found in: {path_to_file}")
 
     def ssh(self, node: str = None):
-        """SSH into the cluster. If no node is specified, will SSH onto the head node.
+        """SSH into the cluster.
+
+        Args:
+            node (str, optional): Node to SSH into. If no node is specified, will SSH onto the head node.
+ (Default: ``None``) Example: >>> rh.ondemand_cluster("rh-cpu").ssh() diff --git a/runhouse/resources/hardware/sagemaker/__init__.py b/runhouse/resources/hardware/sagemaker/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/runhouse/resources/hardware/sagemaker/launch_instance.py b/runhouse/resources/hardware/sagemaker/launch_instance.py deleted file mode 100644 index 69cd6e831..000000000 --- a/runhouse/resources/hardware/sagemaker/launch_instance.py +++ /dev/null @@ -1,120 +0,0 @@ -"""Script used to keep the SageMaker cluster up, pending the autostop time provided in the cluster's config.""" -import json -import os -import subprocess -import time -import warnings -from pathlib import Path -from typing import Dict - -DEFAULT_AUTOSTOP = -1 -MAIN_DIR = "/opt/ml/code" -OUT_FILE = "sm_cluster.out" - -# ---------Configure SSH Helper---------- -# https://github.com/aws-samples/sagemaker-ssh-helper#step-3-modify-your-training-script -import sagemaker_ssh_helper - -sagemaker_ssh_helper.setup_and_start_ssh() - - -def run_training_job(path_to_job: str, num_attempts: int): - job_succeeded = False - - try: - # Execute the script as a separate process and capture stdout and stderr - completed_process = subprocess.run( - ["python", path_to_job], - check=True, - stdout=subprocess.PIPE, - stderr=subprocess.STDOUT, # Redirect stderr to stdout - text=True, - ) - - # Access combined stdout and stderr - combined_output = completed_process.stdout - - if combined_output: - print(combined_output) - with open(f"{MAIN_DIR}/{OUT_FILE}", "w") as f: - f.write(combined_output) - - job_succeeded = True - - except subprocess.CalledProcessError as e: - warnings.warn( - f"({e.returncode}) Error executing training script " - f"(already made {num_attempts} attempts): {e.stderr}" - ) - num_attempts += 1 - - except Exception as e: - num_attempts += 1 - warnings.warn( - f"Error executing training script (already made {num_attempts} attempts): {e}" - ) - - finally: - return job_succeeded, num_attempts - - -def read_cluster_config() -> Dict: - """Read the autostop from the cluster's config - this will get populated when the cluster is created, - or via the autostop APIs (e.g. `pause_autostop` or `keep_warm`)""" - try: - # Note: Runhouse has not yet been installed at this stage on the cluster, - # so we can't import CLUSTER_CONFIG_PATH, we just need to hardcode it. 
- with open(os.path.expanduser("~/.rh/cluster_config.json"), "r") as f: - cluster_config = json.load(f) - except FileNotFoundError: - cluster_config = {} - - return cluster_config - - -if __name__ == "__main__": - print("Launching instance from script") - server_log = Path("~/.rh/server.log").expanduser() - server_log.parent.mkdir(parents=True, exist_ok=True) - server_log.touch() - - last_active = time.time() - last_autostop_value = None - training_job_completed = False - path_to_job = None - num_attempts = 0 - - while True: - last_active = server_log.stat().st_mtime - config = read_cluster_config() - - autostop = int(config.get("autostop_mins", DEFAULT_AUTOSTOP)) - - if autostop != -1: - time_to_autostop = last_active + (autostop * 60) - current_time = time.time() - - if current_time >= time_to_autostop: - print("Autostop time reached, stopping instance") - break - - # Reset launch time if autostop was updated - if last_autostop_value is not None and autostop != last_autostop_value: - print(f"Resetting autostop from {last_autostop_value} to {autostop}") - last_active = current_time - - last_autostop_value = autostop - - estimator_entry_point = config.get("estimator_entry_point") - if estimator_entry_point: - # Update the path to the custom estimator job which was rsynced to the cluster - # and now lives in path: /opt/ml/code - path_to_job = f"{MAIN_DIR}/{estimator_entry_point}" - - if not training_job_completed and path_to_job: - print(f"Running training job specified in path: {path_to_job}") - training_job_completed, num_attempts = run_training_job( - path_to_job, num_attempts - ) - - time.sleep(20) diff --git a/runhouse/resources/hardware/sagemaker/refresh-ssm-session.sh b/runhouse/resources/hardware/sagemaker/refresh-ssm-session.sh deleted file mode 100755 index 044c379d7..000000000 --- a/runhouse/resources/hardware/sagemaker/refresh-ssm-session.sh +++ /dev/null @@ -1,46 +0,0 @@ -#!/bin/bash - -# Note: Adapted from: https://github.com/aws-samples/sagemaker-ssh-helper/blob/main/sagemaker_ssh_helper/sm-connect-ssh-proxy -# This skips creating SSH keys and adding the public key to the authorized keys list in S3, which happened when the -# cluster was initially upped and does not need to be repeated when reconnecting with the cluster. - -INSTANCE_ID="$1" -SSH_KEY="$2" -CURRENT_REGION="$3" -shift 3 -PORT_FWD_ARGS=$* - -instance_status=$(aws ssm describe-instance-information --filters Key=InstanceIds,Values="$INSTANCE_ID" --query 'InstanceInformationList[0].PingStatus' --output text) - -echo "Cluster status: $instance_status" - -if [[ "$instance_status" != "Online" ]]; then - echo "Error: Cluster is offline." - exit 1 -fi - -AWS_CLI_VERSION=$(aws --version) - -# Check if the AWS CLI version contains "aws-cli/2." -if [[ $AWS_CLI_VERSION == *"aws-cli/2."* ]]; then - echo "AWS CLI version: $AWS_CLI_VERSION" -else - echo "Error: AWS CLI version must be v2. Please update your AWS CLI version." 
- exit 1 -fi - -echo "Starting SSH over SSM proxy" - -proxy_command="aws ssm start-session\ - --reason 'Local user started SageMaker SSH Helper'\ - --region '${CURRENT_REGION}'\ - --target '${INSTANCE_ID}'\ - --document-name AWS-StartSSHSession\ - --parameters portNumber=22" - -ssh -4 -T -o User=root -o IdentityFile="${SSH_KEY}" -o IdentitiesOnly=yes \ - -o ProxyCommand="$proxy_command" \ - -o ServerAliveInterval=15 -o ServerAliveCountMax=3 \ - -o PasswordAuthentication=no \ - -o StrictHostKeyChecking=no -o UserKnownHostsFile=/dev/null \ - $PORT_FWD_ARGS "$INSTANCE_ID" diff --git a/runhouse/resources/hardware/sagemaker/sagemaker_cluster.py b/runhouse/resources/hardware/sagemaker/sagemaker_cluster.py deleted file mode 100644 index 2486373b0..000000000 --- a/runhouse/resources/hardware/sagemaker/sagemaker_cluster.py +++ /dev/null @@ -1,1526 +0,0 @@ -import configparser -import contextlib -import getpass -import importlib -import os -import pty -import re -import select -import shlex -import socket -import subprocess -import sys -import textwrap -import threading -import time -import warnings -from pathlib import Path -from typing import Dict, Optional, Union - -try: - import boto3 - import paramiko - import sagemaker - from sagemaker.estimator import EstimatorBase - from sagemaker.mxnet import MXNet - from sagemaker.pytorch import PyTorch - from sagemaker.tensorflow import TensorFlow - from sagemaker.xgboost import XGBoost -except ImportError: - pass - -from runhouse.constants import LOCAL_HOSTS - -from runhouse.globals import configs, rns_client -from runhouse.logger import get_logger - -from runhouse.resources.hardware.cluster import Cluster -from runhouse.resources.hardware.utils import ServerConnectionType -from runhouse.rns.utils.api import ( - is_jsonable, - relative_file_path, - resolve_absolute_path, -) -from runhouse.utils import generate_default_name - - -logger = get_logger(__name__) -#################################################################################################### -# Caching mechanisms for SSHTunnelForwarder -#################################################################################################### - -ssh_tunnel_cache = {} - - -def get_open_ssh_tunnel(address: str, ssh_port: int) -> Optional["SSHTunnelForwarder"]: - if (address, ssh_port) in ssh_tunnel_cache: - - ssh_tunnel = ssh_tunnel_cache[(address, ssh_port)] - # Initializes tunnel_is_up dictionary - ssh_tunnel.check_tunnels() - - if ( - ssh_tunnel.is_active - and ssh_tunnel.tunnel_is_up[ssh_tunnel.local_bind_address] - ): - return ssh_tunnel - - else: - # If the tunnel is no longer active or up, pop it from the global cache - ssh_tunnel_cache.pop((address, ssh_port)) - else: - return None - - -def cache_open_ssh_tunnel( - address: str, - ssh_port: str, - ssh_tunnel: "SSHTunnelForwarder", -): - ssh_tunnel_cache[(address, ssh_port)] = ssh_tunnel - - -class SageMakerCluster(Cluster): - DEFAULT_SERVER_HOST = "localhost" - DEFAULT_INSTANCE_TYPE = "ml.m5.large" - DEFAULT_REGION = "us-east-1" - DEFAULT_USER = "root" - # https://github.com/aws/deep-learning-containers/blob/master/available_images.md - BASE_ECR_URL = "763104351884.dkr.ecr.us-east-1.amazonaws.com" - - # Default path for any estimator source code copied onto the cluster - ESTIMATOR_SRC_CODE_PATH = "/opt/ml/code" - ESTIMATOR_LOG_FILE = "sm_cluster.out" - SSH_KEY_FILE_NAME = "sagemaker-ssh-gw" - - DEFAULT_SSH_PORT = 11022 - DEFAULT_CONNECTION_WAIT_TIME = 60 # seconds - - def __init__( - self, - name: str, - role: str = None, - profile: str = 
None, - region: str = None, - ssh_key_path: str = None, - instance_id: str = None, - instance_type: str = None, - num_instances: int = None, - image_uri: str = None, - autostop_mins: int = None, - connection_wait_time: int = None, - estimator: Union["EstimatorBase", Dict] = None, - job_name: str = None, - server_host: str = None, - server_port: int = None, - domain: str = None, - server_connection_type: str = None, - ssl_keyfile: str = None, - ssl_certfile: str = None, - den_auth: bool = False, - dryrun=False, - **kwargs, - ): - """ - The Runhouse SageMaker cluster abstraction. This is where you can use SageMaker as a compute backend, just as - you would an on-demand cluster (i.e. cloud VMs) or a BYO (i.e. on-prem) cluster. Additionally supports running - dedicated training jobs using SageMaker Estimators. - - .. note:: - To build a cluster, please use the factory method :func:`sagemaker_cluster`. - """ - super().__init__( - name=name, - ssh_creds=kwargs.pop("ssh_creds", {}), - ssh_port=kwargs.pop("ssh_port", self.DEFAULT_SSH_PORT), - server_host=server_host, - server_port=server_port, - server_connection_type=server_connection_type, - ssl_certfile=ssl_certfile, - ssl_keyfile=ssl_keyfile, - domain=domain, - den_auth=den_auth, - dryrun=dryrun, - **kwargs, - ) - self._connection_wait_time = connection_wait_time - self._instance_type = instance_type - self._num_instances = num_instances - self._ssh_key_path = ssh_key_path - - # SSHEstimatorWrapper to facilitate the SSH connection to the cluster - self._ssh_wrapper = None - - # Note: Relevant only if an estimator is explicitly provided - self._estimator_entry_point = kwargs.get("estimator_entry_point") - self._estimator_source_dir = kwargs.get("estimator_source_dir") - self._estimator_framework = kwargs.get("estimator_framework") - - self.job_name = job_name - - # Set initial region - may be overwritten depending on the profile used - self.region = region or self.DEFAULT_REGION - - # Set a default sessions initially - may overwrite depending on the profile loaded below - self._set_boto_session() - self._set_sagemaker_session() - - # Either use the user-provided instance_id, or look it up from the job_name - self.instance_id = instance_id or ( - self._cluster_instance_id() if self.job_name else None - ) - - self._autostop_mins = ( - autostop_mins - if autostop_mins is not None - else configs.get("default_autostop") - ) - - self.estimator = self._load_estimator(estimator) - - self.role, self.profile = ( - self._load_role_and_profile(role, profile) - if not dryrun - else (role, profile) - ) - logger.info( - f"Using SageMaker execution role: `{self.role}` and profile: `{self.profile}`" - ) - - self.image_uri = self._load_image_uri(image_uri) - - # Note: Setting instance ID as cluster IP for compatibility with Cluster parent class methods - self.address = self.instance_id - - def config(self, condensed=True): - config = super().config(condensed) - config.update( - { - "instance_id": self.instance_id, - "role": self.role, - "region": self.region, - "profile": self.profile, - "ssh_key_path": self.ssh_key_path, - "job_name": self.job_name, - "instance_type": self.instance_type, - "num_instances": self.num_instances, - "image_uri": self.image_uri, - "autostop_mins": self._autostop_mins, - "connection_wait_time": self.connection_wait_time, - } - ) - - # If running a dedicated job on the cluster, add the estimator config - if self.estimator and ( - self._estimator_source_dir and self._estimator_entry_point - ): - config.update( - { - 
"estimator_entry_point": self._estimator_entry_point, - "estimator_source_dir": str(self._estimator_source_dir), - } - ) - - if isinstance(self.estimator, EstimatorBase): - # Serialize the estimator before saving it down in the config - selected_attrs = { - key: value - for key, value in self.estimator.__dict__.items() - if is_jsonable(value) - } - # Estimator types: mxnet, tensorflow, keras, pytorch, onnx, xgboost - self._estimator_framework = type(self.estimator).__name__ - config.update( - { - "estimator": selected_attrs, - "estimator_framework": self._estimator_framework, - } - ) - - if isinstance(self.estimator, dict): - config.update({"estimator": self.estimator}) - - return config - - @property - def hosts_path(self): - return Path("~/.ssh/known_hosts").expanduser() - - @property - def ssh_config_file(self): - return Path("~/.ssh/config").expanduser() - - @property - def ssh_key_path(self): - """Relative path to the private SSH key used to connect to the cluster.""" - if self._ssh_key_path: - return relative_file_path(self._ssh_key_path) - - # Default relative path - return f"~/.ssh/{self.SSH_KEY_FILE_NAME}" - - @ssh_key_path.setter - def ssh_key_path(self, ssh_key_path): - self._ssh_key_path = ssh_key_path - - @property - def num_instances(self): - if self._num_instances: - return self._num_instances - elif self.estimator: - return self.estimator.instance_count - else: - return 1 - - @num_instances.setter - def num_instances(self, num_instances): - self._num_instances = num_instances - - @property - def connection_wait_time(self): - """Amount of time the SSH helper will wait inside SageMaker before it continues normal execution""" - if self._connection_wait_time is not None: - return self._connection_wait_time - elif self.estimator and ( - self._estimator_source_dir and self._estimator_entry_point - ): - # Allow for connecting to the instance before the job starts (e.g. training) - return self.DEFAULT_CONNECTION_WAIT_TIME - else: - # For inference and others, always up and running - return 0 - - @connection_wait_time.setter - def connection_wait_time(self, connection_wait_time): - self._connection_wait_time = connection_wait_time - - @property - def instance_type(self): - if self._instance_type: - return self._instance_type - elif self.estimator: - return self.estimator.instance_type - else: - return self.DEFAULT_INSTANCE_TYPE - - @instance_type.setter - def instance_type(self, instance_type): - self._instance_type = instance_type - - @property - def default_bucket(self): - """Default bucket to use for storing the cluster's authorized public keys.""" - return self._sagemaker_session.default_bucket() - - @property - def _use_https(self) -> bool: - # Note: Since always connecting via SSM no need for HTTPS - return False - - @property - def _extra_ssh_args(self): - """Extra SSH arguments to be used when connecting to the cluster.""" - # Note - port 12345 can be used for Python Debug Server: "-R localhost:12345:localhost:12345" - # https://github.com/aws-samples/sagemaker-ssh-helper#remote-debugging-with-pycharm-debug-server-over-ssh - return f"-L localhost:{self.ssh_port}:localhost:22" - - @property - def _s3_keys_path(self): - """Path to public key stored for the cluster on S3. 
When initializing the cluster, the public key - is copied by default to an authorized keys file in this location.""" - return f"s3://{self.default_bucket}/ssh-authorized-keys/" - - @property - def _ssh_public_key_path(self): - return f"{self._abs_ssh_key_path}.pub" - - @property - def _abs_ssh_key_path(self): - return resolve_absolute_path(self.ssh_key_path) - - @property - def _ssh_key_comment(self): - """Username and hostname to be used as the comment for the public key.""" - return f"{getpass.getuser()}@{socket.gethostname()}" - - @property - def _s3_client(self): - if self._boto_session is None: - self._set_boto_session() - - return self._boto_session.client("s3") - - def _get_env_activate_cmd(self, env=None): - """Prefix for commands run on the cluster. Ensure we are running all commands in the conda environment - and not the system default python.""" - # TODO [JL] Can SageMaker handle this for us? - if env: - from runhouse.resources.envs import _get_env_from - - return _get_env_from(env)._activate_cmd - return "source /opt/conda/bin/activate" - - def _set_boto_session(self, profile_name: str = None): - self._boto_session = boto3.Session( - region_name=self.region, profile_name=profile_name - ) - - def _set_sagemaker_session(self): - """Create a SageMaker session required for using the SageMaker APIs.""" - self._sagemaker_session = sagemaker.Session(boto_session=self._boto_session) - - def set_connection_defaults(self): - if ( - "aws-cli/2." - not in subprocess.run( - ["aws", "--version"], capture_output=True, text=True - ).stdout - ): - raise RuntimeError( - "SageMaker SDK requires AWS CLI v2. You may also need to run `pip uninstall awscli` to ensure " - "the right version is being used. For more info: https://www.run.house/docs/api/python/cluster#id9" - ) - - if self.ssh_key_path: - self.ssh_key_path = relative_file_path(self.ssh_key_path) - else: - self.ssh_key_path = None - - if ( - self.server_connection_type is not None - and self.server_connection_type != ServerConnectionType.AWS_SSM - ): - raise ValueError( - "SageMaker Cluster currently requires a server connection type of `aws_ssm`." - ) - self.server_connection_type = ServerConnectionType.AWS_SSM.value - - if self.server_host and self.server_host not in LOCAL_HOSTS: - raise ValueError( - "SageMaker Cluster currently requires a server host of `localhost` or `127.0.0.1`" - ) - - # ------------------------------------------------------- - # Cluster State & Lifecycle Methods - # ------------------------------------------------------- - def restart_server( - self, - _rh_install_url: str = None, - resync_rh: bool = True, - restart_ray: bool = True, - env: Union[str, "Env"] = None, - restart_proxy: bool = False, - ): - """Restart the RPC server on the SageMaker instance. - - Args: - resync_rh (bool): Whether to resync runhouse. (Default: ``True``) - restart_ray (bool): Whether to restart Ray. (Default: ``True``) - env (str or Env): Env to restart the server from. If not provided - will use default env on the cluster. - restart_proxy (bool): Whether to restart nginx on the cluster, if configured. 
(Default: ``False``) - Example: - >>> rh.sagemaker_cluster("sagemaker-cluster").restart_server() - """ - return super().restart_server( - _rh_install_url, resync_rh, restart_ray, env, restart_proxy - ) - - def check_server(self, restart_server=True): - if self.on_this_cluster(): - return - - if not self.instance_id or not self.is_up(): - logger.info(f"Cluster {self.name} is not up, bringing it up now.") - self.up_if_not() - - if not self._http_client: - try: - self.connect_server_client() - logger.info( - f"Checking server {self.name} with instance ID: {self.instance_id}" - ) - - self.client.check_server() - logger.info(f"Server {self.instance_id} is up.") - except: - if restart_server: - logger.info( - f"Server {self.instance_id} is up, but the API server may not be up." - ) - self.run( - [ - "sudo apt-get install screen -y " - "&& sudo apt-get install rsync -y" - ] - ) - # Restart the server inside the base conda env - self.restart_server( - resync_rh=True, - restart_ray=True, - env=None, - ) - logger.info(f"Checking server {self.instance_id} again.") - - self.client.check_server() - else: - raise ValueError( - f"Could not connect to SageMaker instance {self.instance_id}" - ) - - def up(self): - """Up the cluster. - - Example: - >>> rh.sagemaker_cluster("sagemaker-cluster").up() - """ - logger.info("Preparing to launch a new SageMaker cluster") - self._launch_new_cluster() - - if rns_client.autosave_resources(): - self.save() - - return self - - def up_if_not(self): - """Bring up the cluster if it is not up. No-op if cluster is already up. - - Example: - >>> rh.sagemaker_cluster("sagemaker-cluster").up_if_not() - """ - if not self.is_up(): - self.address = None - self.job_name = None - self.instance_id = None - self.up() - return self - - def is_up(self) -> bool: - """Check if the cluster is up. - - Example: - >>> rh.sagemaker_cluster("sagemaker-cluster").is_up() - """ - try: - resp: dict = self.status() - status = resp.get("TrainingJobStatus") - # Up if the instance is in progress - return status == "InProgress" - except: - return False - - def teardown(self): - """Teardown the SageMaker instance. - - Example: - >>> rh.sagemaker_cluster(name="sagemaker-cluster").teardown() - """ - self._stop_instance(delete_configs=False) - - def teardown_and_delete(self): - """Teardown the SageMaker instance and delete from RNS configs. - - Example: - >>> rh.sagemaker_cluster(name="sagemaker-cluster").teardown_and_delete() - """ - self._stop_instance() - - def keep_warm(self, autostop_mins: int = -1): - """Keep the cluster warm for given number of minutes after inactivity. - - Args: - autostop_mins (int): Amount of time (in minutes) to keep the cluster warm after inactivity. - If set to ``-1``, keep cluster warm indefinitely. (Default: ``-1``) - """ - self._update_autostop(autostop_mins) - - return self - - def __getstate__(self): - """Delete non-serializable elements (e.g. sagemaker session object) before pickling.""" - state = self.__dict__.copy() - state["_sagemaker_session"] = None - state["_http_client"] = None - state["_rpc_tunnel"] = None - return state - - @contextlib.contextmanager - def pause_autostop(self): - """Context manager to temporarily pause autostop.""" - self._update_autostop(autostop_mins=-1) - yield - self._update_autostop(self._autostop_mins) - - def status(self) -> dict: - """ - Get status of SageMaker cluster. 
- - Example: - >>> status = rh.sagemaker_cluster("sagemaker-cluster").status() - """ - try: - return self._sagemaker_session.describe_training_job(self.job_name) - except: - return {} - - # ------------------------------------------------------- - # SSH APIs - # ------------------------------------------------------- - def ssh_tunnel( - self, local_port, remote_port=None, num_ports_to_try: int = 0, retry=True - ) -> "SSHTunnelForwarder": - from sshtunnel import BaseSSHTunnelForwarderError, SSHTunnelForwarder - - tunnel = get_open_ssh_tunnel(self.address, self.ssh_port) - if tunnel and tunnel.local_bind_port == local_port: - logger.info( - f"SSH tunnel on ports {local_port, remote_port} already created with the cluster" - ) - return tunnel - - try: - remote_bind_addresses = ("127.0.0.1", local_port) - local_bind_addresses = ("", local_port) - - ssh_tunnel = SSHTunnelForwarder( - self.DEFAULT_SERVER_HOST, - ssh_username=self.DEFAULT_USER, - ssh_pkey=self._abs_ssh_key_path, - ssh_port=self.ssh_port, - remote_bind_address=remote_bind_addresses, - local_bind_address=local_bind_addresses, - set_keepalive=1800, - ) - - # Start the SSH tunnel - ssh_tunnel.start() - - # Update the SSH config for the cluster with the connected SSH port - self._add_or_update_ssh_config_entry() - - logger.info("SSH connection has been successfully created with the cluster") - - except BaseSSHTunnelForwarderError as e: - if not retry: - # Failed to create the SSH tunnel object even after successfully refreshing the SSM session - raise BaseSSHTunnelForwarderError( - f"{e} Make sure ports {self.server_port} and {self.ssh_port} are " - f"not already in use." - ) - - # Refresh the SSM session, which should bind the HTTP and SSH ports to localhost which are forwarded - # to the cluster - self._refresh_ssm_session_with_cluster(num_ports_to_try) - - # Retry creating the SSH tunnel once the session has been refreshed - return self.ssh_tunnel( - local_port, remote_port, num_ports_to_try, retry=False - ) - - cache_open_ssh_tunnel(self.address, self.ssh_port, ssh_tunnel) - return ssh_tunnel - - def ssh(self, interactive: bool = True): - """SSH into the cluster. - - Args: - interactive (bool): Whether to start an interactive shell or not (Default: ``True``). - - Example: - >>> rh.sagemaker_cluster(name="sagemaker-cluster").ssh() - """ - - if (self.address, self.ssh_port) not in ssh_tunnel_cache: - # Make sure SSM session and SSH tunnels are up before running the command - self.connect_server_client() - - if not interactive: - logger.info( - f"Created SSH tunnel with the cluster. 
To SSH into the cluster, run: `ssh {self.name}`" - ) - return - - head_fd, worker_fd = pty.openpty() - ssh_process = subprocess.Popen( - ["ssh", "-o", "StrictHostKeyChecking=no", self.name], - stdin=worker_fd, - stdout=worker_fd, - stderr=worker_fd, - universal_newlines=True, - ) - - # Close the worker_fd in the parent process as it's not needed there - os.close(worker_fd) - - # Wait for the SSH process to initialize - select.select([head_fd], [], []) - - # Interact with the SSH process through the head_fd - try: - while True: - if head_fd in select.select([head_fd], [], [], 0)[0]: - output = os.read(head_fd, 1024).decode() - print(output, end="") - - if sys.stdin in select.select([sys.stdin], [], [], 0)[0]: - user_input = sys.stdin.readline() - try: - os.write(head_fd, user_input.encode()) - except OSError: - pass - - # terminate the SSH process gracefully - if user_input.strip() == "exit": - break - except Exception as e: - raise e - finally: - # Close the head_fd and terminate the SSH process when done - os.close(head_fd) - ssh_process.terminate() - - def _run_commands_with_runner( - self, - commands: list, - cmd_prefix: str, - stream_logs: bool, - node: str = None, - port_forward: int = None, - require_outputs: bool = True, - ): - return_codes = [] - for command in commands: - if command.startswith("rsync"): - try: - result = subprocess.run( - command, - shell=True, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - text=True, - ) - return_codes.append( - (result.returncode, result.stdout, result.stderr) - ) - except subprocess.CalledProcessError as e: - return_codes.append((255, "", str(e))) - else: - # Host can be replaced with name (as reflected in the ~/.ssh/config file) - from runhouse.resources.hardware.sky_command_runner import SkySSHRunner - - runner = SkySSHRunner( - (self.name, self.ssh_port), - ssh_user=self.DEFAULT_USER, - ssh_private_key=self._abs_ssh_key_path, - ssh_control_name=f"{self.name}:{self.ssh_port}", - ) - command = f"{cmd_prefix} {command}" if cmd_prefix else command - logger.debug(f"Running command on {self.name}: {command}") - return_code, stdout, stderr = runner.run( - command, - require_outputs=require_outputs, - stream_logs=stream_logs, - port_forward=port_forward, - ) - - if ( - return_code != 0 - and "dpkg: error processing package install-info" in stdout - ): - # **NOTE**: there may be issues with some SageMaker GPUs post installation script which - # leads to an error which looks something like: "installed install-info package post-installation - # script subprocess returned error exit status 2" - # https://askubuntu.com/questions/1034961/cant-upgrade-error-etc-environment-source-not-found-and-error-processin - # /etc/environment file may also be corrupt, replacing with an empty file allows - # subsequent python commands to run - self._run_commands_with_runner( - commands=[ - "cd /var/lib/dpkg/info && sudo rm *.postinst " - "&& sudo mv /etc/environment /etc/environment_broken " - "&& sudo touch /etc/environment " - f"&& {command}" - ], - cmd_prefix=cmd_prefix, - stream_logs=stream_logs, - ) - - return_codes.append((return_code, stdout, stderr)) - - return return_codes - - # ------------------------------------------------------- - # Cluster Provisioning & Launching - # ------------------------------------------------------- - def _create_launch_estimator(self): - """Create the estimator object used for launching the cluster. If a custom estimator is provided, use that. - Otherwise, use a Runhouse default estimator to launch. 
- **Note If an estimator is provided, Runhouse will override the entry point and source dir to use the - default Runhouse entry point and source dir. This is to ensure that the connection to the cluster - can be maintained even if the job fails, after it has completed, or if autostop is enabled.**""" - # Note: these entry points must point to the existing local files - default_entry_point = "launch_instance.py" - full_module_name = f"resources/hardware/sagemaker/{default_entry_point}" - - entry_point_path = self._get_path_for_module(full_module_name) - source_dir_path = os.path.dirname(entry_point_path) - - # Set default_entry_point and default_source_dir - default_source_dir = source_dir_path - - if self.estimator: - # Save the original entry point and source dir to be used on the cluster for running the estimator - self._estimator_entry_point = self.estimator.entry_point - self._estimator_source_dir = self.estimator.source_dir - - # Update the estimator with the Runhouse custom entry point and source dir - # When the job is initialized, it will run through the Runhouse entry point, which will manage the - # running of the custom estimator - self.estimator.entry_point = default_entry_point - self.estimator.source_dir = default_source_dir - - return self.estimator - - else: - # No estimator provided, use the Runhouse custom estimator (using PyTorch by default) - estimator_dict = { - "instance_count": self.num_instances, - "role": self.role, - "image_uri": self.image_uri, - "framework_version": "2.0.1", - "py_version": "py310", - "entry_point": default_entry_point, - "source_dir": default_source_dir, - "instance_type": self.instance_type, - # https://docs.aws.amazon.com/sagemaker/latest/dg/train-warm-pools.html - "keep_alive_period_in_seconds": 3600, - } - - return PyTorch(**estimator_dict) - - def _launch_new_cluster(self): - self.estimator = self._create_launch_estimator() - - logger.info( - f"Launching a new SageMaker cluster (num instances={self.num_instances}) on instance " - f"type: {self.instance_type}" - ) - - self._create_new_instance() - - # If no name provided, use the autogenerated name - self.job_name = self.estimator.latest_training_job.name - - self.instance_id = self._cluster_instance_id() - - # For compatibility with parent Cluster class methods which use an address - self.address = self.instance_id - - logger.info(f"New SageMaker instance started with ID: {self.instance_id}") - - # Remove stale entries from the known hosts file - self._filter_known_hosts() - - logger.info("Creating session with cluster via SSM") - self._create_ssm_session_with_cluster() - - self.check_server() - - if self._estimator_source_dir and self._estimator_entry_point: - # Copy the provided estimator's code to the cluster - Runhouse will then manage running the job in order - # to preserve control over the cluster's autostop - self._sync_estimator_to_cluster() - logger.info( - f"Logs for the estimator can be viewed on the cluster in " - f"path: {self.ESTIMATOR_SRC_CODE_PATH}/{self.ESTIMATOR_LOG_FILE}" - ) - - logger.info( - f"Connection with {self.name} has been created. 
You can SSH onto " - f"the cluster with the CLI using: ``ssh {self.name}``" - ) - - def _create_new_instance(self): - from sagemaker_ssh_helper.wrapper import SSHEstimatorWrapper - - # Make sure the SSHEstimatorWrapper is being used by the estimator, this is necessary for - # enabling the SSH tunnel to the cluster - # https://github.com/aws-samples/sagemaker-ssh-helper/tree/main#step-2-modify-your-start-training-job-code - ssh_dependency_dir = SSHEstimatorWrapper.dependency_dir() - if ssh_dependency_dir not in self.estimator.dependencies: - self.estimator.dependencies.append(ssh_dependency_dir) - - # Create the SSH wrapper & run the job - self._ssh_wrapper = SSHEstimatorWrapper.create( - self.estimator, connection_wait_time_seconds=self.connection_wait_time - ) - - self._start_instance() - - def _create_ssm_session_with_cluster(self, num_ports_to_try: int = 5): - """Create a session with the cluster. Runs a bash script containing a series of commands which use existing - SSH keys or generate new ones needed to authorize the connection with the cluster via the AWS SSM. - These commands are run when the cluster is initially provisioned, or for subsequent connections if the session - is longer active. Once finished, the SSH port and HTTP port will be bound to processes on - localhost which are forwarded to the cluster.""" - # https://github.com/aws-samples/sagemaker-ssh-helper/tree/main#forwarding-tcp-ports-over-ssh-tunnel - base_command = self._load_base_command_for_ssm_session() - - connected = False - while not connected: - command = f"{base_command} {self._extra_ssh_args}" - - try: - if num_ports_to_try == 0: - raise ConnectionError( - f"Failed to create SSM session and connect to {self.name} after repeated attempts." - f"Make sure SSH keys exist in local path: {self._abs_ssh_key_path}" - ) - - logger.debug(f"Running command: {command}") - - # Define an event to signal completion of the SSH tunnel setup - tunnel_setup_complete = threading.Event() - - # Manually allocate a pseudo-terminal to prevent a "pseudo-terminal not allocated" error - head_fd, worker_fd = pty.openpty() - - def run_ssm_session_cmd(): - # Execute the command with the pseudo-terminal in a separate thread - process = subprocess.Popen( - command, - shell=True, - stdout=subprocess.PIPE, - stdin=subprocess.PIPE, - ) - - # Close the worker file descriptor as we don't need it - os.close(worker_fd) - - # Close the master file descriptor after reading the output - os.close(head_fd) - - # Wait for the process to complete and collect its return code - process.wait() - - # Signal that the tunnel setup is complete - tunnel_setup_complete.set() - - tunnel_thread = threading.Thread(target=run_ssm_session_cmd) - tunnel_thread.daemon = True # Set the thread as a daemon, so it won't block the main thread - - # Start the SSH tunnel thread - tunnel_thread.start() - - # Give time for the SSM session to start, SSH keys to be copied onto the cluster, and the SSH port - # forwarding command to run - tunnel_setup_complete.wait(timeout=30) - - if not self._ports_are_in_use(): - # Command should bind SSH port and HTTP port on localhost, if this is not the case try re-running - # the bash script with a different set of ports - # E.g. 
❯ lsof -i:11022,32300 - # COMMAND PID USER FD TYPE DEVICE SIZE/OFF NODE NAME - # ssh 97115 myuser 3u IPv4 0xcf81f230786cc9fd 0t0 TCP localhost:32300 (LISTEN) - # ssh 97115 myuser 6u IPv4 0xcf81f230786eff6d 0t0 TCP localhost:11022 (LISTEN) - raise ConnectionError - - # Update the SSH config for the cluster with the connected SSH port - self._add_or_update_ssh_config_entry() - - connected = True - - logger.info( - f"Created SSM session using ports {self.ssh_port} and {self.server_port}. " - f"All active sessions can be viewed with: ``aws ssm describe-sessions --state Active``" - ) - - except (ConnectionError, subprocess.CalledProcessError): - # Try re-running with updated the ports - possible the ports are already in use - self.server_port += 1 - self.ssh_port += 1 - num_ports_to_try -= 1 - pass - - def _start_instance(self): - """Call the SageMaker CreateTrainingJob API to start the training job on the cluster.""" - # TODO [JL] Note: Keeping private until re-running training jobs on the same cluster is supported - if not self.estimator: - logger.warning("No estimator found, cannot run job.") - return - - # NOTE: underscores not allowed for training job name - must match: ^[a-zA-Z0-9](-*[a-zA-Z0-9]){0,62} - self.estimator.fit( - wait=False, - job_name=self.job_name or generate_default_name(self.name, sep="-"), - ) - - def _load_base_command_for_ssm_session(self) -> str: - """Bash script command for creating the SSM session and uploading the SSH keys to the cluster. Will try - reusing existing keys locally if they exist, otherwise will generate new ones locally and copy them to s3.""" - private_key_path = self._abs_ssh_key_path - public_key_path = self._ssh_public_key_path - - resource_name = "resources/hardware/sagemaker/start-ssm-proxy-connection.sh" - script_path = self._get_path_for_module(resource_name) - - os.chmod(script_path, 0o755) - - s3_key_path = self._s3_keys_path - - # bash script which creates an SSM session with the cluster - base_command = ( - f'bash {script_path} "{self.instance_id}" ' - f'"{s3_key_path}" "{private_key_path}" "{self.region}"' - ) - - bucket = s3_key_path.split("/")[2] - key = "/".join(s3_key_path.split("/")[3:]) - - if Path(public_key_path).exists() and Path(private_key_path).exists(): - # If the key pair exists locally, make sure a matching public key also exists in s3 - with open(self._ssh_public_key_path, "r") as f: - public_key = f.read() - - self._add_public_key_to_authorized_keys(bucket, key, public_key) - - return base_command - - # If no private + public keys exists generate a new key pair from scratch - logger.warning( - f"No private + public keypair found in local path: {private_key_path}. Generating a new key pair " - "locally and uploading the new public key to s3" - ) - self._create_new_ssh_key_pair(bucket, key) - - return base_command - - def _refresh_ssm_session_with_cluster(self, num_ports_to_try: int = 5): - """Reconnect to the cluster via the AWS SSM. This bypasses the step of creating a new SSH key which was already - done when upping the cluster. Note: this assumes the session has previously been created, which we do when - the cluster has been upped. 
- - To view all sessions: ``aws ssm describe-sessions --state Active`` - """ - ssh_key_path = self._abs_ssh_key_path - public_key_path = self._ssh_public_key_path - - if not Path(ssh_key_path).exists() and not Path(public_key_path).exists(): - logger.warning( - f"SSH key pairs not found in paths: {ssh_key_path} and {public_key_path}" - ) - self._create_ssm_session_with_cluster() - - # https://github.com/aws-samples/sagemaker-ssh-helper/blob/main/sagemaker_ssh_helper/sm-connect-ssh-proxy - full_module_name = "resources/hardware/sagemaker/refresh-ssm-session.sh" - script_path = self._get_path_for_module(full_module_name) - - os.chmod(script_path, 0o755) - - # Remove stale entries from the known hosts file - this is import for avoiding collisions when - # subsequent clusters are created as the IP address is added to the file as: [localhost]:11022 - self._filter_known_hosts() - - num_attempts = num_ports_to_try - connected = False - - while not connected: - command = [ - script_path, - self.instance_id, - ssh_key_path, - self.region, - ] + shlex.split(self._extra_ssh_args) - - if num_ports_to_try == 0: - raise ConnectionError( - f"Failed to create connection with {self.name} after {num_attempts} attempts " - f"(cluster status=`{self.status().get('TrainingJobStatus')}`). Make sure that another SageMaker " - f"cluster is not already active, that AWS CLI V2 is installed, and that the path has been properly " - f"added to your bash profile" - f"(https://docs.aws.amazon.com/cli/latest/userguide/cli-chap-troubleshooting.html). " - f"If the error persists, try running the command to create the session " - f"manually: `bash {' '.join(command)}`" - ) - - try: - subprocess.Popen( - command, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - stdin=subprocess.PIPE, - ) - - # Give enough time for the aws ssm + ssh port forwarding commands in the script to complete - # Better to wait a few more seconds than to restart the HTTP server on the cluster unnecessarily - time.sleep(8) - - if not self._ports_are_in_use(): - # Command should bind SSH port and HTTP port on localhost, if this is not the case try re-running - # with different ports - raise socket.error - - connected = True - logger.info("Successfully refreshed SSM session") - - except socket.error: - # If the refresh didn't work try connecting with a different port - could be that the port - # is already taken - self.server_port += 1 - self.ssh_port += 1 - num_ports_to_try -= 1 - pass - - # ------------------------------------------------------- - # SSH Keys Management - # ------------------------------------------------------- - def _create_new_ssh_key_pair(self, bucket, key): - """Create a new private / public key pairing needed for SSHing into the cluster.""" - private_key_path = self._abs_ssh_key_path - ssh_key = paramiko.RSAKey.generate(bits=2048) - - ssh_key.write_private_key_file(private_key_path) - os.chmod(private_key_path, 0o600) - - # Set the comment for the public key to include username and hostname - comment = self._ssh_key_comment - public_key = f"ssh-rsa {ssh_key.get_base64()} {comment}" - - # Update the public key locally - self._write_public_key(public_key) - - # Update the public key in s3 - self._add_public_key_to_authorized_keys(bucket, key, public_key) - - def _add_public_key_to_authorized_keys( - self, bucket: str, key: str, public_key: str - ) -> None: - """Add the public key to the authorized keys file stored in S3. 
This file will get copied onto the cluster's - authorized keys file.""" - path_to_auth_keys = self._path_to_auth_keys(key) - authorized_keys: str = self._load_authorized_keys(bucket, path_to_auth_keys) - - if not authorized_keys: - # Create a new authorized keys file - logger.info( - f"No authorized keys file found in s3 path: {self._s3_keys_path}. Creating and uploading a new file." - ) - self._upload_key_to_s3(bucket, path_to_auth_keys, public_key) - return - - if public_key not in authorized_keys: - # Add the public key to the existing authorized keys saved in s3 - authorized_keys += f"\n{public_key}\n" - logger.info( - f"Adding public key to authorized keys file saved for the cluster " - f"in path: {self._s3_keys_path}" - ) - self._upload_key_to_s3(bucket, path_to_auth_keys, authorized_keys) - - def _path_to_auth_keys(self, key): - """Path to the authorized keys file stored in the s3 bucket for the role ARN associated with the cluster.""" - return key + f"{self.SSH_KEY_FILE_NAME}.pub" - - def _write_public_key(self, public_key: str): - """Update the public key stored locally.""" - with open(self._ssh_public_key_path, "w") as f: - f.write(public_key) - - def _upload_key_to_s3(self, bucket, key, body): - """Save a public key to the authorized file in the default bucket for given SageMaker role.""" - self._s3_client.put_object( - Bucket=bucket, - Key=key, - Body=body, - ) - - def _load_authorized_keys(self, bucket, auth_keys_file) -> Union[str, None]: - """Load the authorized keys file for this AWS role stored in S3. If no file exists, return None.""" - try: - response = self._s3_client.get_object(Bucket=bucket, Key=auth_keys_file) - existing_pub_keys = response["Body"].read().decode("utf-8") - return existing_pub_keys - - except self._s3_client.exceptions.NoSuchKey: - # No authorized keys file exists in s3 for this role - return None - - except Exception as e: - raise e - - # ------------------------------------------------------- - # Cluster Helpers - # ------------------------------------------------------- - def rsync(self, source: str, dest: str, up: bool, contents: bool = False): - source = source + "/" if not source.endswith("/") else source - dest = dest + "/" if not dest.endswith("/") else dest - - command = ( - f"rsync -rvh --exclude='.git' --exclude='venv*/' --exclude='dist/' --exclude='docs/' " - f"--exclude='__pycache__/' --exclude='.*' " - f"--include='.rh/' -e 'ssh -o StrictHostKeyChecking=no " - f"-i {self._abs_ssh_key_path} -p {self.ssh_port}' {source} root@localhost:{dest}" - ) - - logger.info(f"Syncing {source} to: {dest} on cluster") - return_codes = self.run([command]) - if return_codes[0][0] != 0: - logger.error(f"rsync to SageMaker cluster failed: {return_codes[0][1]}") - - def _cluster_instance_id(self): - """Get the instance ID of the cluster. 
This is the ID of the instance running the training job generated - by SageMaker.""" - if self._ssh_wrapper: - # This is a hack to effectively do list.get(0, None) - return next(iter(self._ssh_wrapper.get_instance_ids()), None) - - from sagemaker_ssh_helper.manager import SSMManager - - ssm_manager = SSMManager(region_name=self.region) - ssm_manager.redo_attempts = 0 - instance_ids = ssm_manager.get_training_instance_ids(self.job_name) - # This is a hack to effectively do list.get(0, None) - return next(iter(instance_ids), None) - - def _get_path_for_module(self, resource_name: str) -> str: - import importlib.resources as pkg_resources - - package_name = "runhouse" - script_path = str(pkg_resources.files(package_name) / resource_name) - return script_path - - def _load_role_and_profile(self, role, profile): - """Load the SageMaker role and profile used for launching and connecting to the cluster. - Role can be provided as an ARN or a name. If provided, will search for the profile containing this role - in local AWS configs. If no profile is provided, try loading from the environment variable ``AWS_PROFILE``, - otherwise default to using the ``default`` profile.""" - if self.estimator: - # If using an estimator must provide a name or full ARN - # https://sagemaker.readthedocs.io/en/stable/api/training/estimators.html - role = self.estimator.role - - if role and role.startswith("arn:aws"): - profile = profile or self._load_profile_and_region_for_role(role) - - else: - # https://docs.aws.amazon.com/sagemaker/latest/dg/sagemaker-roles.html - profile = profile or os.environ.get("AWS_PROFILE") - if profile is None: - profile = "default" - logger.warning( - f"No profile provided or environment variable set to `AWS_PROFILE`, using the {profile} profile" - ) - - try: - # Update the sessions using the profile provided - self._set_boto_session(profile_name=profile) - self._set_sagemaker_session() - - # If no role explicitly provided, use sagemaker to get it via the profile - role = role or sagemaker.get_execution_role( - sagemaker_session=self._sagemaker_session - ) - except Exception as e: - if self.on_this_cluster(): - # If we're on the cluster, we may not have the profile or role saved locally, but should still be able - # to create the cluster object (e.g. for rh.here). - pass - else: - raise e - - return role, profile - - def _load_profile_and_region_for_role(self, role_arn: str) -> Union[str, None]: - """Find the profile (and region) associated with a particular role ARN. 
If no profile is found, return None.""" - try: - aws_dir = os.path.expanduser("~/.aws") - credentials_path = os.path.join(aws_dir, "credentials") - config_path = os.path.join(aws_dir, "config") - - profiles_with_role = set() - - for path in [credentials_path, config_path]: - config = configparser.ConfigParser() - config.read(path) - - for section in config.sections(): - config_section = config[section] - config_role_arn = config_section.get("role_arn") - if config_role_arn == role_arn: - # Add just the name of the profile (not the full section heading) - profiles_with_role.add(section.split(" ")[-1]) - - # Update the region to use the one associated with this profile - profile_region = config_section.get("region") - if profile_region != self.region: - warnings.warn( - f"Updating region based on AWS config to: {profile_region}" - ) - self.region = profile_region - - if not profiles_with_role: - return None - - profiles = list(profiles_with_role) - profile = profiles[0] - - if len(profiles) > 1: - logger.warning( - f"Found multiple profiles associated with the same role. Using the first " - f"one ({profile})" - ) - - return profile - - except Exception as e: - logger.warning(f"Could not find a profile for role {role_arn}: {e}") - return None - - def _load_image_uri(self, image_uri: str = None) -> str: - """Load the docker image URI used for launching the SageMaker instance. If no image URI is provided, use - a default image based on the instance type.""" - if image_uri: - return image_uri - - if self.estimator: - return self.estimator.image_uri - - return self._base_image_uri() - - def _stop_instance(self, delete_configs=True): - """Stop the SageMaker instance. Optionally remove its config from RNS.""" - self._sagemaker_session.stop_training_job(job_name=self.job_name) - - if self.is_up(): - raise Exception(f"Failed to stop instance {self.name}") - - logger.info(f"Successfully stopped instance {self.name}") - - # Remove stale host key(s) from known hosts - self._filter_known_hosts() - - if delete_configs: - # Delete from RNS - rns_client.delete_configs(resource=self) - logger.info(f"Deleted {self.name} from configs") - - def _sync_runhouse_to_cluster(self, node: str = None, _install_url=None, env=None): - if not self.instance_id: - raise ValueError(f"No instance ID set for cluster {self.name}. 
Is it up?") - - if not self._http_client: - self.connect_server_client() - - # Sync the local ~/.rh directory to the cluster - self.rsync( - source=str(Path("~/.rh").expanduser()), - dest="~/.rh", - up=True, - contents=True, - ) - logger.info("Synced ~/.rh folder to the cluster") - - local_rh_package_path = Path(importlib.util.find_spec("runhouse").origin).parent - # local_rh_package_path = Path(pkgutil.get_loader("runhouse").path).parent - - # **Note** temp patch to handle PyYAML errors: https://github.com/yaml/pyyaml/issues/724 - base_rh_install_cmd = f'{self._get_env_activate_cmd(env=None)} && python3 -m pip install "cython<3.0.0"' - - # Check if runhouse is installed from source and has setup.py - if ( - not _install_url - and local_rh_package_path.parent.name == "runhouse" - and (local_rh_package_path.parent / "setup.py").exists() - ): - # Package is installed in editable mode - local_rh_package_path = local_rh_package_path.parent - dest_path = f"~/{local_rh_package_path.name}" - - self.rsync( - source=str(local_rh_package_path), - dest=dest_path, - up=True, - contents=True, - ) - - rh_install_cmd = ( - f"{base_rh_install_cmd} && python3 -m pip install ./runhouse[sagemaker]" - ) - else: - if not _install_url: - import runhouse - - _install_url = f"runhouse[sagemaker]=={runhouse.__version__}" - rh_install_cmd = ( - f"{base_rh_install_cmd} && python3 -m pip install {_install_url}" - ) - - status_codes = self.run([rh_install_cmd]) - - if status_codes[0][0] != 0: - raise ValueError( - f"Error installing runhouse on cluster: {status_codes[0][1]}" - ) - - def _load_estimator( - self, estimator: Union[Dict, "EstimatorBase", None] - ) -> Union[None, "EstimatorBase"]: - """Build an Estimator object from config""" - if estimator is None: - return None - - if isinstance(estimator, EstimatorBase): - return estimator - - if not isinstance(estimator, dict): - raise TypeError( - f"Unsupported estimator type. 
Expected dictionary or EstimatorBase, got {type(estimator)}" - ) - if "sagemaker_session" not in estimator: - # Estimator requires an initialized sagemaker session - # https://stackoverflow.com/questions/55869651/how-to-fix-aws-region-error-valueerror-must-setup-local-aws-configuration-with - estimator["sagemaker_session"] = self._sagemaker_session - - # Re-build the estimator object from its config - estimator_framework = self._estimator_framework - if estimator_framework == "PyTorch": - return PyTorch(**estimator) - elif estimator_framework == "TensorFlow": - return TensorFlow(**estimator) - elif estimator_framework == "MXNet": - return MXNet(**estimator) - elif estimator_framework == "XGBoost": - return XGBoost(**estimator) - else: - raise NotImplementedError( - f"Unsupported estimator framework {estimator_framework}" - ) - - def _sync_estimator_to_cluster(self): - """If providing a custom estimator sync over the estimator's source directory to the cluster""" - from runhouse import folder - - estimator_folder = folder( - path=Path(self._estimator_source_dir).expanduser() - ).to(self, path=self.ESTIMATOR_SRC_CODE_PATH) - logger.info( - f"Synced estimator source directory to the cluster in path: {estimator_folder.path}" - ) - - def _base_image_uri(self): - """Pick a default image for the cluster based on its instance type""" - # TODO [JL] Add flexibility for py version & framework_version - gpu_instance_types = ["p2", "p3", "p4", "g3", "g4", "g5"] - - image_type = ( - "gpu" - if any(prefix in self.instance_type for prefix in gpu_instance_types) - else "cpu" - ) - - cuda_version = "cu118-" if image_type == "gpu" else "" - image_url = ( - f"{self.BASE_ECR_URL}/pytorch-training:2.0.1-{image_type}-py310-{cuda_version}" - f"ubuntu20.04-sagemaker" - ) - - return image_url - - def _update_autostop(self, autostop_mins: int = None): - cluster_config = self.config() - cluster_config["autostop_mins"] = autostop_mins or -1 - if not self._http_client: - self.connect_server_client() - # Update the config on the server with the new autostop time - self.client.check_server() - - # ------------------------------------------------------- - # Port Management - # ------------------------------------------------------- - def _ports_are_in_use(self) -> bool: - """Check if the ports used for port forwarding from localhost to the cluster are in use.""" - try: - self._bind_ports_to_localhost() - # Ports are not in use - return False - except OSError: - # At least one of the ports is in use - return True - - def _bind_ports_to_localhost(self): - """Try binding the SSH and HTTP ports to localhost to check if they are in use.""" - with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s1: - s1.bind(("localhost", self.ssh_port)) - with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s2: - s2.bind(("localhost", self.server_port)) - - # ------------------------------------------------------- - # SSH config - # ------------------------------------------------------- - def _filter_known_hosts(self): - """To prevent host key collisions in the ~/.ssh/known_hosts file, remove any stale entries of localhost - using the SSH port.""" - known_hosts = self.hosts_path - if not known_hosts.exists(): - # e.g. 
in a collab or notebook environment - return - - valid_hosts = [] - with open(known_hosts, "r") as f: - for line in f: - if not line.strip().startswith( - f"[{self.DEFAULT_SERVER_HOST}]:{self.ssh_port}" - ): - valid_hosts.append(line) - - with open(known_hosts, "w") as f: - f.writelines(valid_hosts) - - def _ssh_config_entry( - self, - port: int, - name: str = None, - hostname: str = None, - identity_file: str = None, - user: str = None, - ): - return textwrap.dedent( - f""" - # Added by Runhouse for SageMaker SSH Support - Host {name or self.name} - HostName {hostname or self.DEFAULT_SERVER_HOST} - IdentityFile {identity_file or self._abs_ssh_key_path} - Port {port} - User {user or self.DEFAULT_USER} - """ - ) - - def _add_or_update_ssh_config_entry(self): - """Update the SSH config to allow for accessing the cluster via: ssh """ - connected_ssh_port = self.ssh_port - config_file = self.ssh_config_file - - with open(config_file, "r") as f: - existing_config = f.read() - - pattern = re.compile( - rf"^\s*Host\s+{re.escape(self.name)}\s*$.*?(?=^\s*Host\s+|\Z)", - re.MULTILINE | re.DOTALL, - ) - - entry_match = pattern.search(existing_config) - - if entry_match: - # If entry already exists update the port with the connected SSH port (may have changed from previous - # connection attempt) - existing_entry = entry_match.group() - updated_entry = re.sub( - r"(?<=Port )\d+", str(connected_ssh_port), existing_entry - ) - updated_config = existing_config.replace(existing_entry, updated_entry) - with open(config_file, "w") as f: - f.write(updated_config) - else: - # Otherwise, add the new entry to the config file - new_entry = self._ssh_config_entry(port=connected_ssh_port) - with open(config_file, "a") as f: - f.write(new_entry) diff --git a/runhouse/resources/hardware/sagemaker/start-ssm-proxy-connection.sh b/runhouse/resources/hardware/sagemaker/start-ssm-proxy-connection.sh deleted file mode 100755 index ea1bad8c3..000000000 --- a/runhouse/resources/hardware/sagemaker/start-ssm-proxy-connection.sh +++ /dev/null @@ -1,141 +0,0 @@ -#!/bin/bash - -# Adapted from: https://github.com/aws-samples/sagemaker-ssh-helper/blob/main/sagemaker_ssh_helper/sm-connect-ssh-proxy -# Creates an SSM session and sets up port forwarding to the cluster through localhost for an SSH port and HTTP port -# Optionally creates new SSH keys if they do not already exist - -set -e - -INSTANCE_ID="$1" -SSH_AUTHORIZED_KEYS="$2" -SSH_KEY="$3" -CURRENT_REGION="$4" -shift 4 -PORT_FWD_ARGS=$* - -echo "INSTANCE_ID: $INSTANCE_ID" -echo "SSH_AUTHORIZED_KEYS: $SSH_AUTHORIZED_KEYS" -echo "CURRENT_REGION: $CURRENT_REGION" -echo "PORT_FWD_ARGS: $PORT_FWD_ARGS" - -instance_status=$(aws ssm describe-instance-information --filters Key=InstanceIds,Values="$INSTANCE_ID" --query 'InstanceInformationList[0].PingStatus' --output text) - -echo "Cluster status: $instance_status" - -if [[ "$instance_status" != "Online" ]]; then - echo "Error: Cluster is offline." - exit 1 -fi - -AWS_CLI_VERSION=$(aws --version) - -# Check if the AWS CLI version contains "aws-cli/2." -if [[ $AWS_CLI_VERSION == *"aws-cli/2."* ]]; then - echo "AWS CLI version: $AWS_CLI_VERSION" -else - echo "Error: AWS CLI version must be v2. Please update your AWS CLI version." 
- exit 1 -fi - -echo "Running SSM commands at region ${CURRENT_REGION} to copy public key to ${INSTANCE_ID}" - -# Copy the public key from the s3 bucket to the authorized_keys.d directory on the cluster -cp_command="aws s3 cp --recursive \"${SSH_AUTHORIZED_KEYS}\" /root/.ssh/authorized_keys.d/" - -# Copy the SSH public key onto the cluster to the root directory, then copy from the root to /etc -send_command=$(aws ssm send-command \ - --region "${CURRENT_REGION}" \ - --instance-ids "${INSTANCE_ID}" \ - --document-name "AWS-RunShellScript" \ - --comment "Copy public key for SSH helper" \ - --timeout-seconds 30 \ - --parameters "commands=[ - 'mkdir -p /root/.ssh/authorized_keys.d/', - '$cp_command', - 'ls -la /root/.ssh/authorized_keys.d/', - 'cat /root/.ssh/authorized_keys.d/* > /root/.ssh/authorized_keys', - 'cat /root/.ssh/authorized_keys' - ]" \ - --no-cli-pager --no-paginate \ - --output json) - -json_value_regexp='s/^[^"]*".*": \"\(.*\)\"[^"]*/\1/' - -cp_command="cp -r /root/.ssh/authorized_keys.d/* /etc/ssh/authorized_keys.d" -echo "Copying keys from root folder to etc folder: $cp_command" - -send_command=$(aws ssm send-command \ - --region "${CURRENT_REGION}" \ - --instance-ids "${INSTANCE_ID}" \ - --document-name "AWS-RunShellScript" \ - --comment "Copy public key to /etc/ssh folder on cluster" \ - --timeout-seconds 30 \ - --parameters "commands=[ - 'mkdir -p /etc/ssh/authorized_keys.d/', - '$cp_command', - 'ls -la /etc/ssh/authorized_keys.d/', - 'cat /etc/ssh/authorized_keys.d/* > /etc/ssh/authorized_keys', - 'ls -la /etc/ssh/authorized_keys' - ]" \ - --no-cli-pager --no-paginate \ - --output json) - -json_value_regexp='s/^[^"]*".*": \"\(.*\)\"[^"]*/\1/' - - -send_command=$(echo "$send_command" | python -m json.tool) -command_id=$(echo "$send_command" | grep "CommandId" | sed -e "$json_value_regexp") -echo "Got command ID: $command_id" - -# Wait a little bit to prevent strange InvocationDoesNotExist error -sleep 5 - -for i in $(seq 1 15); do - # Switch to unicode for AWS CLI to properly parse output - export LC_CTYPE=en_US.UTF-8 - command_output=$(aws ssm get-command-invocation \ - --instance-id "${INSTANCE_ID}" \ - --command-id "${command_id}" \ - --no-cli-pager --no-paginate \ - --output json) - command_output=$(echo "$command_output" | python -m json.tool) - command_status=$(echo "$command_output" | grep '"Status":' | sed -e "$json_value_regexp") - output_content=$(echo "$command_output" | grep '"StandardOutputContent":' | sed -e "$json_value_regexp") - error_content=$(echo "$command_output" | grep '"StandardErrorContent":' | sed -e "$json_value_regexp") - - echo "Command status: $command_status" - if [[ "$command_status" != "Pending" && "$command_status" != "InProgress" ]]; then - echo "Command output: $output_content" - if [[ "$error_content" != "" ]]; then - echo "Command error: $error_content" - fi - break - fi - sleep 1 -done - -if [[ "$command_status" != "Success" ]]; then - echo "Error: Command didn't finish successfully in time" - exit 2 -fi - -echo "Connecting to $INSTANCE_ID as proxy and starting port forwarding with the args: $PORT_FWD_ARGS" - -# We don't use AWS-StartPortForwardingSession feature of SSM here, because we need port forwarding in both directions -# with -L and -R parameters of SSH. This is useful for forwarding the PyCharm license server, which needs -R option. -# SSM allows only forwarding of ports from the server (equivalent to the -L option). 
-# shellcheck disable=SC2086 -proxy_command="aws ssm start-session\ - --reason 'Local user started SageMaker SSH Helper'\ - --region '${CURRENT_REGION}'\ - --target '${INSTANCE_ID}'\ - --document-name AWS-StartSSHSession\ - --parameters portNumber=%p" - -# shellcheck disable=SC2086 -ssh -4 -T -o User=root -o IdentityFile="${SSH_KEY}" -o IdentitiesOnly=yes \ - -o ProxyCommand="$proxy_command" \ - -o ServerAliveInterval=15 -o ServerAliveCountMax=3 \ - -o PasswordAuthentication=no \ - -o StrictHostKeyChecking=no -o UserKnownHostsFile=/dev/null \ - $PORT_FWD_ARGS "$INSTANCE_ID" diff --git a/runhouse/resources/hardware/ssh_tunnel.py b/runhouse/resources/hardware/ssh_tunnel.py index a06f874f9..c3075bc08 100644 --- a/runhouse/resources/hardware/ssh_tunnel.py +++ b/runhouse/resources/hardware/ssh_tunnel.py @@ -35,17 +35,17 @@ def __init__( """Initialize an ssh tunnel from a remote server to localhost Args: - address (str): The address of the server we are trying to port forward an address to our local machine with. - ssh_creds (Dict): A dictionary of ssh credentials used to connect to the remote server. - local_port (int): The port locally where we are attempting to bind the remote server address to. - ssh_port (int): The port on the machine where the ssh server is running. - This is generally port 22, but occasionally - we may forward a container's ssh port to a different port - on the actual machine itself (for example on a Docker VM). Defaults to 22. - remote_port (Optional[int], optional): The port of the remote server - we're attempting to port forward. Defaults to None. - num_ports_to_try (int, optional): The number of local ports to attempt to bind to, - starting at local_port and incrementing by 1 till we hit the max. Defaults to 0. + ip (str): The address of the server we are trying to port forward an address to our local machine with. + ssh_user (str, optional): The SSH username to use for connecting to the remote server. Defaults to None. + ssh_private_key (str, optional): The path to the SSH private key file. Defaults to None. + ssh_control_name (str, optional): The name for the SSH control connection. + Defaults to `"__default__"`. + ssh_proxy_command (str, optional): The SSH proxy command to use for connecting to the remote + server. Defaults to None. + ssh_port (int, optional): The port on the remote machine where the SSH server is running. Defaults to 22. + disable_control_master (bool, optional): Whether to disable SSH ControlMaster. Defaults to False. + docker_user (str, optional): The Docker username to use if connecting through Docker. Defaults to None. + cloud (str, optional): The cloud provider, if applicable. Defaults to None. """ self.ip = ip self.ssh_port = ssh_port @@ -207,10 +207,11 @@ def ssh_tunnel( This is generally port 22, but occasionally we may forward a container's ssh port to a different port on the actual machine itself (for example on a Docker VM). Defaults to 22. - remote_port (Optional[int], optional): The port of the remote server + remote_port (int, optional): The port of the remote server we're attempting to port forward. Defaults to None. num_ports_to_try (int, optional): The number of local ports to attempt to bind to, starting at local_port and incrementing by 1 till we hit the max. Defaults to 0. + docker_user (str, optional): The Docker username to use if connecting through Docker. Defaults to None. cloud (str, Optional): Cluster cloud, if an on-demand cluster. 
Returns: diff --git a/runhouse/resources/hardware/utils.py b/runhouse/resources/hardware/utils.py index bc43e08db..5105c9552 100644 --- a/runhouse/resources/hardware/utils.py +++ b/runhouse/resources/hardware/utils.py @@ -29,13 +29,11 @@ class ServerConnectionType(str, Enum): ``tls``: Do not use port forwarding and start the server with HTTPS (using custom or fresh TLS certs), by default on port 443. ``none``: Do not use port forwarding, and start the server with HTTP, by default on port 80. - ``aws_ssm``: Use AWS SSM to connect to the server, by default on port 32300. """ SSH = "ssh" TLS = "tls" NONE = "none" - AWS_SSM = "aws_ssm" class ResourceServerStatus(str, Enum): diff --git a/runhouse/resources/module.py b/runhouse/resources/module.py index abba5f7ff..d77d4c991 100644 --- a/runhouse/resources/module.py +++ b/runhouse/resources/module.py @@ -47,7 +47,6 @@ "_system", "dryrun", "_resolve", - "provenance", "_signature", "_dumb_signature_cache", ] @@ -67,13 +66,15 @@ def __init__( system: Union[Cluster, str] = None, env: Optional[Env] = None, dryrun: bool = False, - provenance: Optional[dict] = None, **kwargs, ): """ - Runhouse Module object + Runhouse Module object. + + .. note:: + To create a Module, please use the factory method :func:`module`. """ - super().__init__(name=name, dryrun=dryrun, provenance=provenance, **kwargs) + super().__init__(name=name, dryrun=dryrun, **kwargs) self._system = _get_cluster_from( system or _current_cluster(key="config"), dryrun=dryrun ) @@ -114,7 +115,7 @@ def __init__( self._resolve = False self._openapi_spec = None - def config(self, condensed=True): + def config(self, condensed: bool = True): if not self.system: raise ValueError( "Cannot save an in-memory local module to RNS. Please send the module to a local " @@ -152,7 +153,9 @@ def config(self, condensed=True): return config @classmethod - def from_config(cls, config: dict, dryrun=False, _resolve_children=True): + def from_config( + cls, config: Dict, dryrun: bool = False, _resolve_children: bool = True + ): if config.get("pointers"): config.pop("resource_subtype", None) logger.debug(f"Constructing module from pointers {config['pointers']}") @@ -222,7 +225,6 @@ def from_config(cls, config: dict, dryrun=False, _resolve_children=True): new_module._pointers = config.pop("pointers", None) new_module._signature = config.pop("signature", None) new_module.dryrun = config.pop("dryrun", False) - new_module.provenance = config.pop("provenance", None) new_module._openapi_spec = config.pop("openapi_spec", None) return new_module @@ -297,7 +299,7 @@ def signature(self, rich=False): return self._signature def method_signature(self, method): - """Extracts the properties of a method that we want to preserve when sending the method over the wire.""" + """Method signature, consisting of method properties to preserve when sending the method over the wire.""" signature = inspect.signature(method) signature_metadata = { "signature": str(signature), @@ -321,9 +323,9 @@ def endpoint(self, external: bool = False): down from a config). If not, request the endpoint from the Module's system. Args: - external: If True and getting the endpoint from the system, only return an endpoint if it's externally - accessible (i.e. not on localhost, not connected through as ssh tunnel). If False, return the endpoint - even if it's not externally accessible. + external (bool, optional): If True and getting the endpoint from the system, only return an endpoint if + it's externally accessible (i.e. 
not on localhost, not connected through as ssh tunnel). If False, + return the endpoint even if it's not externally accessible. (Default: ``False``) """ if self._endpoint: return self._endpoint @@ -440,6 +442,15 @@ def to( ): """Put a copy of the module on the destination system and env, and return the new module. + Args: + system (str or Cluster): The system to setup the module and env on. + env (str, List[str], or Env, optional): The environment where the module lives on in the cluster, + or the set of requirements necessary to run the module. (Default: ``None``) + name (Optional[str], optional): Name to give to the module resource, if you wish to rename it. + (Default: ``None``) + force_install (bool, optional): Whether to re-install and perform the environment setup steps, even + if it may already exist on the cluster. (Defualt: ``False``) + Example: >>> local_module = rh.module(my_class) >>> cluster_module = local_module.to("my_cluster") @@ -563,6 +574,13 @@ def get_or_to( """Check if the module already exists on the cluster, and if so return the module object. If not, put the module on the cluster and return the remote module. + Args: + system (str or Cluster): The system to setup the module and env on. + env (str, List[str], or Env, optional): The environment where the module lives on in the cluster, + or the set of requirements necessary to run the module. (Default: ``None``) + name (Optional[str], optional): Name to give to the module resource, if you wish to rename it. + (Default: ``None``) + Example: >>> remote_df = Model().get_or_to(my_cluster, name="remote_model") """ @@ -681,8 +699,21 @@ def refresh(self): else: return self - def replicate(self, num_replicas=1, names=None, envs=None, parallel=False): - """Replicate the module on the cluster in a new env and return the new modules.""" + def replicate( + self, + num_replicas: int = 1, + names: List[str] = None, + envs: List["Env"] = None, + parallel: bool = False, + ): + """Replicate the module on the cluster in a new env and return the new modules. + + Args: + num_relicas (int, optional): Number of replicas of the module to create. (Default: 1) + names (List[str], optional): List for the names for the replicas, if specified. (Default: ``None``) + envs (List[Env], optional): List of the envs for the replicas, if specified. (Default: ``None``) + parallel (bool, optional): Whether to create the replicas in parallel. (Default: ``False``) + """ if not self.system or not self.name: raise ValueError( "Cannot replicate a module that is not on a cluster. Please send the module to a cluster first." @@ -973,7 +1004,11 @@ def _save_sub_resources(self, folder: str = None): self.env.save(folder=folder) def rename(self, name: str): - """Rename the module.""" + """Rename the module. + + Args: + name (str): Name to rename the module to. 
+ """ if self.name == name or self.rns_address == name: return old_name = self.name @@ -992,7 +1027,6 @@ def rename(self, name: str): ) def save(self, name: str = None, overwrite: bool = True, folder: str = None): - """Register the resource and save to local working_dir config and RNS config store.""" # Need to override Resource's save to handle key changes in the obj store # Also check that this is a Module and not a File @@ -1048,7 +1082,7 @@ def _is_running_in_notebook(module_path: Union[str, None]) -> bool: @staticmethod def _extract_pointers(raw_cls_or_fn: Union[Type, Callable]): """Get the path to the module, module name, and function name to be able to import it on the server""" - if not (isinstance(raw_cls_or_fn, type) or isinstance(raw_cls_or_fn, Callable)): + if not (isinstance(raw_cls_or_fn, Type) or isinstance(raw_cls_or_fn, Callable)): raise TypeError( f"Expected Type or Callable but received {type(raw_cls_or_fn)}" ) @@ -1109,7 +1143,11 @@ def _get_local_path_containing_module( def openapi_spec(self, spec_name: Optional[str] = None): """Generate an OpenAPI spec for the module. - TODO: This breaks if the module has type annotations that are classes, and not standard library or + Args: + spec_name (str, optional): Spec name for the OpenAPI spec. + """ + + """ TODO: This breaks if the module has type annotations that are classes, and not standard library or typing types. Maybe we can do something using: https://github.com/kuimono/openapi-schema-pydantic to allow @@ -1117,7 +1155,6 @@ def openapi_spec(self, spec_name: Optional[str] = None): TODO: What happens if there is an empty function, will this work with an empty body even though it is marked as required? - """ if self._openapi_spec is not None: return self._openapi_spec @@ -1229,7 +1266,6 @@ def __init__( pointers=cls_pointers, signature=None, name=None, - provenance=None, **kwargs, ): # args and kwargs are passed to the cls's __init__ method if this is being called on a cluster. They @@ -1242,7 +1278,6 @@ def __init__( system=system, env=env, dryrun=dryrun, - provenance=provenance, ) # This allows a class which is already on the cluster to construct an instance of itself with a factory # method, e.g. my_module = MyModuleCls.factory_constructor(*args, **kwargs) @@ -1323,10 +1358,11 @@ class (e.g. ``to``, ``fetch``, etc.). Properties and private methods are not int Args: cls: The class to instantiate. - name (Optional[str]): Name to give the module object, to be reused later on. - env (Optional[str or Env]): Environment in which the module should live on the cluster, if system is cluster. - load_from_den (bool): Whether to try loading the module from Den. (Default: ``True``) - dryrun (bool): Whether to create the Module if it doesn't exist, or load a Module object as a dryrun. + name (Optional[str], optional): Name to give the module object, to be reused later on. (Default: ``None``) + env (Optional[str or Env], optional): Environment in which the module should live on the cluster, if system + is cluster. (Default: ``None``) + load_from_den (bool, optional): Whether to try loading the module from Den. (Default: ``True``) + dryrun (bool, optional): Whether to create the Module if it doesn't exist, or load a Module object as a dryrun. 
(Default: ``False``) Returns: diff --git a/runhouse/resources/packages/git_package.py b/runhouse/resources/packages/git_package.py index 7b3323f85..9540d90b0 100644 --- a/runhouse/resources/packages/git_package.py +++ b/runhouse/resources/packages/git_package.py @@ -1,6 +1,6 @@ import logging from pathlib import Path -from typing import Union +from typing import Dict, Union from runhouse.resources.envs.utils import run_setup_command @@ -36,7 +36,7 @@ def __init__( self.git_url = git_url self.revision = revision - def config(self, condensed=True): + def config(self, condensed: bool = True): # If the package is just a simple Package.from_string string, no # need to store it in rns, just give back the string. # if self.install_method in ['pip', 'conda', 'git']: @@ -83,7 +83,7 @@ def _install(self, env: Union[str, "Env"] = None, cluster: "Cluster" = None): super()._install(env, cluster=cluster) @staticmethod - def from_config(config: dict, dryrun=False, _resolve_children=True): + def from_config(config: Dict, dryrun: bool = False, _resolve_children: bool = True): return GitPackage(**config, dryrun=dryrun) @@ -100,14 +100,15 @@ def git_package( Builds an instance of :class:`GitPackage`. Args: - name (str): Name to assign the package resource. - git_url (str): The GitHub URL of the package to install. - revision (str): Version of the Git package to install. - install_method (str): Method for installing the package. If left blank, defaults to local installation. - install_str (str): Additional arguments to add to installation command. - load_from_den (bool): Whether to try loading the package from Den. (Default: ``True``) - dryrun (bool): Whether to load the Package object as a dryrun, or create the Package if it doesn't exist. - (Default: ``False``) + name (str, optional): Name to assign the package resource. + git_url (str, optional): The GitHub URL of the package to install. + revision (str, optional): Version of the Git package to install. + install_method (str, optional): Method for installing the package. If left blank, defaults to + local installation. + install_str (str, optional): Additional arguments to add to installation command. + load_from_den (bool, optional): Whether to try loading the package from Den. (Default: ``True``) + dryrun (bool, optional): Whether to load the Package object as a dryrun, or create the Package if + it doesn't exist. (Default: ``False``) Returns: GitPackage: The resulting GitHub Package. diff --git a/runhouse/resources/packages/package.py b/runhouse/resources/packages/package.py index 992a71a1f..03f44f65e 100644 --- a/runhouse/resources/packages/package.py +++ b/runhouse/resources/packages/package.py @@ -91,7 +91,7 @@ def __init__( self.install_args = install_args self.preferred_version = preferred_version - def config(self, condensed=True): + def config(self, condensed: bool = True): # If the package is just a simple Package.from_string string, no # need to store it in rns, just give back the string. # if self.install_method in ['pip', 'conda', 'git']: @@ -416,7 +416,11 @@ def to( system: Union[str, Dict, "Cluster"], path: Optional[str] = None, ): - """Copy the package onto filesystem or cluster, and return the new Package object.""" + """Copy the package onto filesystem or cluster, and return the new Package object. + + Args: + system (str, Dict, or Cluster): Cluster to send the package to. + """ if not isinstance(self.install_target, InstallTarget): raise TypeError( "`install_target` must be an InstallTarget in order to copy the package to a system." 
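For reference, the reworked ``git_package`` factory docstring above corresponds to a usage pattern roughly like the following sketch (the repository URL, revision, and resource name are placeholders, and it assumes the factory is exposed at the top level as ``rh.git_package``):

    import runhouse as rh

    # Placeholder repo URL and revision -- substitute a real repository and tag/commit.
    gh_pkg = rh.git_package(
        name="my-git-pkg",
        git_url="https://github.com/example-org/example-repo.git",
        revision="v0.1.0",
        install_method="pip",
    )
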
@@ -451,7 +455,7 @@ def split_req_install_method(req_str: str): return (splat[0], splat[1]) if len(splat) > 1 else ("", splat[0]) @staticmethod - def from_config(config: dict, dryrun=False, _resolve_children=True): + def from_config(config: Dict, dryrun: bool = False, _resolve_children: bool = True): if isinstance(config.get("install_target"), tuple): config["install_target"] = InstallTarget( local_path=config["install_target"][0], @@ -466,7 +470,7 @@ def from_config(config: dict, dryrun=False, _resolve_children=True): return Package(**config, dryrun=dryrun) @staticmethod - def from_string(specifier: str, dryrun=False): + def from_string(specifier: str, dryrun: bool = False): if specifier == "requirements.txt": specifier = "reqs:./" @@ -566,15 +570,16 @@ def package( Builds an instance of :class:`Package`. Args: - name (str): Name to assign the package resource. - install_method (str): Method for installing the package. Options: [``pip``, ``conda``, ``reqs``, ``local``] - install_str (str): Additional arguments to install. - path (str): URL of the package to install. - system (str): File system or cluster on which the package lives. Currently this must a cluster or one of: - [``file``, ``s3``, ``gs``]. - load_from_den (bool): Whether to try loading the Package from Den. (Default: ``True``) - dryrun (bool): Whether to create the Package if it doesn't exist, or load the Package object as a dryrun. - (Default: ``False``) + name (str, optional): Name to assign the package resource. + install_method (str, optional): Method for installing the package. + Options: [``pip``, ``conda``, ``reqs``, ``local``] + install_str (str, optional): Additional arguments to install. + path (str, optional): URL of the package to install. + system (str, optional): File system or cluster on which the package lives. + Currently this must a cluster or one of: [``file``, ``s3``, ``gs``]. + load_from_den (bool, optional): Whether to try loading the Package from Den. (Default: ``True``) + dryrun (bool, optional): Whether to create the Package if it doesn't exist, or load the Package + object as a dryrun. (Default: ``False``) Returns: Package: The resulting package. 
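The updated ``package`` factory and ``Package.to`` docstrings above suggest usage along these lines (a minimal sketch: the cluster and package names are hypothetical, and the call signatures follow the docstrings in this diff rather than a verified release):

    import runhouse as rh

    # Hypothetical, previously saved Runhouse cluster.
    cluster = rh.cluster(name="my-cluster")

    # Build a package from a local path and pip-install it remotely.
    # install_method options per the docstring: pip, conda, reqs, local.
    pkg = rh.package(path="./my_project", install_method="pip", name="my-project-pkg")

    # Copy the package onto the cluster and get back a Package pointing at the remote copy.
    remote_pkg = pkg.to(cluster)
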
diff --git a/runhouse/resources/provenance.py b/runhouse/resources/provenance.py deleted file mode 100644 index 7c75d27f1..000000000 --- a/runhouse/resources/provenance.py +++ /dev/null @@ -1,605 +0,0 @@ -import copy -import json -import logging -import sys -from enum import Enum -from io import StringIO -from pathlib import Path -from typing import Any, List, Optional, Union - -from runhouse.constants import LOGS_DIR -from runhouse.globals import configs, rns_client -from runhouse.logger import get_logger -from runhouse.resources.blobs import file - -# Need to alias so it doesn't conflict with the folder property -from runhouse.resources.folders import Folder, folder as folder_factory -from runhouse.resources.hardware import _current_cluster, _get_cluster_from, Cluster -from runhouse.resources.resource import Resource -from runhouse.rns.top_level_rns_fns import resolve_rns_path -from runhouse.rns.utils.api import log_timestamp, resolve_absolute_path -from runhouse.utils import StreamTee - -logger = get_logger(__name__) - - -class RunStatus(str, Enum): - NOT_STARTED = "NOT_STARTED" - RUNNING = "RUNNING" - COMPLETED = "COMPLETED" - CANCELLED = "CANCELLED" - ERROR = "ERROR" - - -class RunType(str, Enum): - CMD_RUN = "CMD" - FUNCTION_RUN = "FUNCTION" - CTX_MANAGER = "CTX_MANAGER" - - -class Run(Resource): - RESOURCE_TYPE = "run" - - LOCAL_RUN_PATH = f"{rns_client.rh_directory}/runs" - - RUN_CONFIG_FILE = "config_for_run.json" - RESULT_FILE = "result.pkl" - INPUTS_FILE = "inputs.pkl" - - def __init__( - self, - name: str = None, - fn_name: str = None, - cmds: list = None, - log_dest: str = "file", - path: str = None, - system: Union[str, Cluster] = None, - status: RunStatus = RunStatus.NOT_STARTED, - start_time: Optional[str] = None, - end_time: Optional[str] = None, - creator: Optional[str] = None, - creation_stacktrace: Optional[str] = None, - upstream_artifacts: Optional[List] = None, - downstream_artifacts: Optional[List] = None, - run_type: RunType = RunType.CMD_RUN, - error: Optional[str] = None, - error_traceback: Optional[str] = None, - overwrite: bool = False, - dryrun: bool = False, - **kwargs, - ): - """ - Runhouse Run object - - .. note:: - To load an existing Run, please use the factory method :func:`run`. 
- """ - run_name = name or str(self._current_timestamp()) - super().__init__(name=run_name, dryrun=dryrun) - - self.log_dest = log_dest - self.folder = None - if self.log_dest == "file": - folder_system = system or Folder.DEFAULT_FS - folder_path = ( - resolve_absolute_path(path) - if path - else ( - self._base_local_folder_path(self.name) - if folder_system == Folder.DEFAULT_FS - else self._base_cluster_folder_path(name=run_name) - ) - ) - - if overwrite: - # Delete the Run from the system if one already exists - self._delete_existing_run(folder_path, folder_system) - - # Create new folder which lives on the system and contains all the Run's data: - # (run config, stdout, stderr, inputs, result) - self.folder = folder_factory( - path=folder_path, - system=folder_system, - dryrun=dryrun, - ) - - self.status = status - self.start_time = start_time - self.end_time = end_time - self.creator = creator - self.creation_stacktrace = creation_stacktrace - self.upstream_artifacts = upstream_artifacts or [] - self.downstream_artifacts = downstream_artifacts or [] - self.fn_name = fn_name - self.cmds = cmds - self.run_type = run_type or self._detect_run_type() - self.error = error - self.traceback = error_traceback - # TODO string representation of inputs - - def __enter__(self): - self.status = RunStatus.RUNNING - self.start_time = self._current_timestamp() - - # Begin tracking the Run in the rns_client - this adds the current Run to the stack of active Runs - rns_client.start_run(self) - - if self.log_dest == "file": - # Capture stdout and stderr to the Run's folder - self.folder.mkdir() - # TODO fix the fact that we keep appending and then stream back the full file - sys.stdout = StreamTee(sys.stdout, [Path(self._stdout_path).open(mode="a")]) - sys.stderr = StreamTee(sys.stderr, [Path(self._stderr_path).open(mode="a")]) - - # Add the stdout and stderr handlers to the root logger - self._stdout_handler = logging.StreamHandler(sys.stdout) - logger.addHandler(self._stdout_handler) - - return self - - def __exit__(self, exc_type, exc_value, exc_traceback): - self.end_time = self._current_timestamp() - if exc_type: - self.status = RunStatus.ERROR - self.error = exc_value - self.traceback = exc_traceback - else: - self.status = RunStatus.COMPLETED - - # Pop the current Run from the stack of active Runs - rns_client.stop_run() - - # if self.run_type == RunType.CMD_RUN: - # # Save Run config to its folder on the system - this will already happen on the cluster - # # for function based Runs - # self._write_config() - # - # # For cmd runs we are using the SSH command runner to get the stdout / stderr - # return - - # TODO [DG->JL] Do we still need this? 
- # stderr = f"{type(exc_value).__name__}: {str(exc_value)}" if exc_value else "" - # self.write(data=stderr.encode(), path=self._stderr_path) - - if self.log_dest == "file": - logger.removeHandler(self._stdout_handler) - - # Flush stdout and stderr - # sys.stdout.flush() - # sys.stderr.flush() - - # Restore stdout and stderr - if hasattr(sys.stdout, "instream"): - sys.stdout = sys.stdout.instream - if hasattr(sys.stderr, "instream"): - sys.stderr = sys.stderr.instream - - # Save Run config to its folder on the system - this will already happen on the cluster - # for function based Runs - # self._write_config() - - # return False to propagate any exception that occurred inside the with block - return False - - @staticmethod - def from_config(config: dict, dryrun=False): - return Run(**config, dryrun=dryrun) - - def __getstate__(self): - """Remove the folder object from the Run before pickling it.""" - state = self.__dict__.copy() - state["folder"] = None - state["_stdout_handler"] = None - return state - - def config(self, condensed=True): - """Metadata to store in RNS for the Run.""" - config = super().config(condensed) - base_config = { - "status": self.status, - "start_time": self.start_time, - "end_time": self.end_time, - "run_type": self.run_type, - "log_dest": self.log_dest, - "creator": self.creator, - "fn_name": self.fn_name, - "cmds": self.cmds, - # NOTE: artifacts are currently only tracked in context manager based runs - "upstream_artifacts": self.upstream_artifacts, - "downstream_artifacts": self.downstream_artifacts, - "path": self.folder.path, - "system": self._resource_string_for_subconfig( - self.folder.system, condensed - ), - "error": str(self.error), - "traceback": str(self.traceback), - } - config.update(base_config) - return config - - def populate_init_provenance(self): - self.creator = configs.username - self.creation_stacktrace = "".join(self.traceback.format_stack(limit=11)[1:]) - - @property - def run_config(self): - """Config to save in the Run's dedicated folder on the system. - Note: this is different from the config saved in RNS, which is the metadata for the Run. - """ - config = { - "name": self.name, - "status": self.status, - "start_time": self.start_time, - "end_time": self.end_time, - "run_type": self.run_type, - "fn_name": self.fn_name, - "cmds": self.cmds, - # NOTE: artifacts are currently only tracked in context manager based runs - "upstream_artifacts": self.upstream_artifacts, - "downstream_artifacts": self.downstream_artifacts, - } - return config - - def save(self, name: str = None, overwrite: bool = True, folder: str = None): - """If the Run name is being overwritten (ex: initially created with auto-generated name), - update the Run config stored on the system before saving to RNS.""" - config_for_rns = self.config() - config_path = self._path_to_config() - if not config_for_rns["name"] or name: - config_for_rns["name"] = resolve_rns_path(name or self.name) - self._write_config(config=config_for_rns) - logger.debug(f"Updated Run config name in path: {config_path}") - - return super().save(name, overwrite, folder) - - def write(self, data: Any, path: str): - """Write data (ex: function inputs or result, stdout, stderr) to the Run's dedicated folder on the system.""" - file(system=self.folder.system, path=path).write(data, serialize=False) - - def to(self, system, path: Optional[str] = None): - """Send a Run to another system. - - Args: - system (Union[str or Cluster]): Name of the system or Cluster object to copy the Run to. 
- path (Optional[str]): Path to the on the system to save the Run. - Defaults to the local path for Runs (in the rh folder of the working directory). - - Returns: - Run: A copy of the Run on the destination system and path. - """ - # TODO: [JL] - support for `on_completion` (wait to copy the results to destination until async run completes) - - new_run = copy.copy(self) - - if self.run_type == RunType.FUNCTION_RUN: - results_path = self._fn_result_path() - # Pickled function result should be saved down to the Run's folder on the cluster - if results_path not in self.folder.ls(): - raise FileNotFoundError( - f"No results saved down in path: {results_path}" - ) - - for fp in [self._stdout_path, self._stderr_path]: - # Stdout and Stderr files created on a cluster can be symlinks to the files that we create via Ray - # by default - before copying them to a new system make sure they are regular files - self._convert_symlink_to_file(path=fp) - - if system == "here": - # Save to default local path if none provided - path = path or self._base_local_folder_path(self.name) - - new_run.folder = self.folder.to(system=system, path=path) - - return new_run - - def refresh(self) -> "Run": - """Reload the Run object from the system. This is useful for checking the status of a Run. - For example: ``my_run.refresh().status``""" - run_config = self._load_run_config(folder=self.folder) - # Need the metadata from RNS and the Run specific data in order to re-load the Run object - config = {**self.config(), **run_config} - return Run.from_config(config, dryrun=True) - - def inputs(self) -> bytes: - """Load the pickled function inputs saved on the system for the Run.""" - return self._load_blob_from_path(path=self._fn_inputs_path()).fetch() - - def result(self): - """Load the function result saved on the system for the Run. If the Run has failed return the stderr, - otherwise return the stdout.""" - run_status = self.refresh().status - if run_status == RunStatus.COMPLETED: - results_path = self._fn_result_path() - if results_path not in self.folder.ls(): - raise FileNotFoundError( - f"No results file found in path: {results_path}" - ) - return self._load_blob_from_path(path=results_path).fetch() - elif run_status == RunStatus.ERROR: - logger.debug("Run failed, returning stderr") - return self.stderr() - else: - logger.debug(f"Run status: {self.status}, returning stdout") - return self.stdout() - - def stdout(self) -> str: - """Read the stdout saved on the system for the Run.""" - stdout_path = self._stdout_path - logger.debug(f"Reading stdout from path: {stdout_path}") - - return self._load_blob_from_path(path=stdout_path).fetch().decode().strip() - - def stderr(self) -> str: - """Read the stderr saved on the system for the Run.""" - stderr_path = self._stderr_path - logger.debug(f"Reading stderr from path: {stderr_path}") - - return self._load_blob_from_path(stderr_path).fetch().decode().strip() - - def _fn_inputs_path(self) -> str: - """Path to the pickled inputs used for the function which are saved on the system.""" - return f"{self.folder.path}/{self.INPUTS_FILE}" - - def _fn_result_path(self) -> str: - """Path to the pickled result for the function which are saved on the system.""" - return f"{self.folder.path}/{self.RESULT_FILE}" - - def _load_blob_from_path(self, path: str): - """Load a blob from the Run's folder in the specified path. 
(ex: function inputs, result, stdout, stderr).""" - return file(path=path, system=self.folder.system) - - def _register_new_run(self): - """Log a Run once it's been triggered on the system.""" - self.start_time = self._current_timestamp() - self.status = RunStatus.RUNNING - - # Write config data for the Run to its config file on the system - logger.debug(f"Registering new Run on system in path: {self.folder.path}") - self._write_config() - - def _register_fn_run_completion(self, run_status: RunStatus): - """Update a function based Run's config after its finished running on the system.""" - self.end_time = self._current_timestamp() - self.status = run_status - - logger.debug(f"Registering a completed fn Run with status: {run_status}") - self._write_config() - - def _register_cmd_run_completion(self, return_codes: list): - """Update a cmd based Run's config and register its stderr and stdout after running on the system.""" - run_status = RunStatus.ERROR if return_codes[0][0] != 0 else RunStatus.COMPLETED - self.status = run_status - - logger.debug(f"Registering a completed cmd Run with status: {run_status}") - self._write_config() - - # Write the stdout and stderr of the commands Run to the Run's folder - self.write(data=return_codes[0][1].encode(), path=self._stdout_path) - self.write(data=return_codes[0][2].encode(), path=self._stderr_path) - - def _write_config(self, config: dict = None, overwrite: bool = True): - """Write the Run's config data to the system. - - Args: - config (Optional[Dict]): Config to write. If none is provided, the Run's config for RNS will be used. - overwrite (Optional[bool]): Overwrite the config if one is already saved down. Defaults to ``True``. - """ - config_to_write = config or self.config() - logger.debug(f"Config to save on system: {config_to_write}") - self.folder.put( - {self.RUN_CONFIG_FILE: json.dumps(config)}, - overwrite=overwrite, - mode="w", - ) - - def _detect_run_type(self): - if self.fn_name: - return RunType.FUNCTION_RUN - elif self.cmds is not None: - return RunType.CMD_RUN - else: - return RunType.CTX_MANAGER - - def _path_to_config(self) -> str: - """Path the main folder storing the metadata, inputs, and results for the Run saved on the system.""" - return f"{self.folder.path}/{self.RUN_CONFIG_FILE}" - - def _path_to_file_by_ext(self, ext: str) -> str: - """Path the file for the Run saved on the system for a provided extension (ex: ``.out`` or ``.err``).""" - existing_file = self._find_file_path_by_ext(ext=ext) - if existing_file: - # If file already exists in file (ex: with function on a Ray cluster this will already be - # generated for us) - return existing_file - - path_to_ext = f"{self.folder.path}/{self.name}" + ext - return path_to_ext - - def _convert_symlink_to_file(self, path: str): - """If the system is a Cluster and the file path is a symlink, convert it to a regular file. 
- This is necessary to allow for copying of the file between systems (ex: cluster --> s3 or cluster --> local).""" - if isinstance(self.folder.system, Cluster): - status_codes: list = self.folder.system.run( - [f"test -h {path} && echo True || echo False"], stream_logs=True - ) - if status_codes[0][1].strip() == "True": - # If it's a symlink convert it to a regular file - self.folder.system.run( - [f"cp --remove-destination `readlink {path}` {path}"] - ) - - @property - def _stdout_path(self) -> str: - """Path to the stdout file for the Run.""" - return self._path_to_file_by_ext(ext=".out") - - @property - def _stderr_path(self) -> str: - """Path to the stderr file for the Run.""" - return self._path_to_file_by_ext(ext=".err") - - def _find_file_path_by_ext(self, ext: str) -> Union[str, None]: - """Get the file path by provided extension. Needed when loading the stdout and stderr files associated - with a particular run.""" - try: - folder_contents: list = self.folder.ls(sort=True) - except FileNotFoundError: - return None - - files_with_ext = self._filter_files_by_ext(folder_contents, ext) - if not files_with_ext: - # No .out / .err file already created in the logs folder for this Run - return None - - # Return the most recent file with this extension - return files_with_ext[0] - - def _register_upstream_artifact(self, artifact_name: str): - """Track a Runhouse object loaded in the Run's context manager. This object's name - will be saved to the upstream artifact registry of the Run's config.""" - if artifact_name not in self.upstream_artifacts: - self.upstream_artifacts.append(artifact_name) - - def _register_downstream_artifact(self, artifact_name: str): - """Track a Runhouse object saved in the Run's context manager. This object's name - will be saved to the downstream artifact registry of the Run's config.""" - if artifact_name not in self.downstream_artifacts: - self.downstream_artifacts.append(artifact_name) - - @staticmethod - def _current_timestamp(): - return str(log_timestamp()) - - @staticmethod - def _filter_files_by_ext(files: list, ext: str): - return list(filter(lambda x: x.endswith(ext), files)) - - @staticmethod - def _delete_existing_run(folder_path, folder_system: str): - """Delete existing Run on the system before a new one is created.""" - existing_folder = folder_factory( - path=folder_path, - system=folder_system, - ) - - existing_folder.rm(recursive=True) - - @staticmethod - def _load_run_config(folder: Folder) -> dict: - """Load the Run config file saved for the Run in its dedicated folder on the system .""" - try: - return json.loads(folder.get(Run.RUN_CONFIG_FILE)) - except FileNotFoundError: - return {} - - @staticmethod - def _base_cluster_folder_path(name: str): - """Path to the base folder for this Run on a cluster.""" - return f"{LOGS_DIR}/{name}" - - @staticmethod - def _base_local_folder_path(name: str): - """Path to the base folder for this Run on a local system.""" - return f"{LOGS_DIR}/{name}" - - -class capture_stdout: - """Context manager for capturing stdout to a file, list, or stream, while still printing to stdout.""" - - def __init__(self, output=None): - self.output = output - self._stream = None - - def __enter__(self): - if self.output is None: - self.output = StringIO() - - if isinstance(self.output, str): - self._stream = open(self.output, "w") - else: - self._stream = self.output - sys.stdout = StreamTee(sys.stdout, [self]) - sys.stderr = StreamTee(sys.stderr, [self]) - return self - - def write(self, message): - self._stream.write(message) 
- - def flush(self): - self._stream.flush() - - @property - def stream(self): - if isinstance(self.output, str): - return open(self.output, "r") - return self._stream - - def list(self): - if isinstance(self.output, str): - return self.stream.readlines() - return (self.stream.getvalue() or "").splitlines() - - def __str__(self): - return self.stream.getvalue() - - def __exit__(self, exc_type, exc_val, exc_tb): - if hasattr(sys.stdout, "instream"): - sys.stdout = sys.stdout.instream - if hasattr(sys.stderr, "instream"): - sys.stderr = sys.stderr.instream - self._stream.close() - return False - - -def run( - name: str = None, - log_dest: str = "file", - path: str = None, - system: Union[str, Cluster] = None, - load_from_den: bool = True, - dryrun: bool = False, - **kwargs, -) -> Union["Run", None]: - """Constructs a Run object. - - Args: - name (Optional[str]): Name of the Run to load. - log_dest (Optional[str]): Whether to save the Run's logs to a file or stream them back. (Default: ``file``) - path (Optional[str]): Path to the Run's dedicated folder on the system where the Run lives. - system (Optional[str or Cluster]): File system or cluster name where the Run lives. - If providing a file system this must be one of: - [``file``, ``s3``, ``gs``]. - We are working to add additional file system support. - load_from_den (bool): Whether to try loading the run from Den. (Default: ``True``) - dryrun (bool): Whether to create the Run if it doesn't exist, or load a Blob object as a dryrun. - (Default: ``False``) - **kwargs: Optional kwargs for the Run. - - Returns: - Run: The loaded Run object. - """ - if name and not any([path, system, kwargs]): - # Try reloading existing Run from RNS - return Run.from_name(name, load_from_den=load_from_den, dryrun=dryrun) - - if name and path is None and log_dest == "file": - path = ( - Run._base_cluster_folder_path(name=name) - if isinstance(system, Cluster) - else Run._base_local_folder_path(name=name) - ) - - system = _get_cluster_from( - system or _current_cluster(key="config") or Folder.DEFAULT_FS, dryrun=dryrun - ) - - run_obj = Run( - name=name, - log_dest=log_dest, - path=path, - system=system, - dryrun=dryrun, - **kwargs, - ) - - return run_obj diff --git a/runhouse/resources/resource.py b/runhouse/resources/resource.py index 9873c10b9..71093fde5 100644 --- a/runhouse/resources/resource.py +++ b/runhouse/resources/resource.py @@ -29,28 +29,21 @@ def __init__( self, name: Optional[str] = None, dryrun: bool = False, - provenance=None, - access_level: Optional[ResourceAccess] = ResourceAccess.WRITE, - visibility: Optional[ResourceVisibility] = ResourceVisibility.PRIVATE, + access_level: ResourceAccess = ResourceAccess.WRITE, + visibility: ResourceVisibility = ResourceVisibility.PRIVATE, **kwargs, ): """ Runhouse abstraction for objects that can be saved, shared, and reused. - Runhouse currently supports the following builtin Resource types: - - - Compute Abstractions - - Cluster :py:class:`.cluster.Cluster` - - Function :py:class:`.function.Function` - - Module :py:class:`.module.Module` - - Package :py:class:`.package.Package` - - Env: :py:class:`.env.Env` - - - Data Abstractions - - Folder :py:class:`.folder.Folder` - - - Secret Abstractions - - Secret :py:class:`.secret.Secret` + Args: + name (Optional[str], optional): Name to assign the resource. (Default: None) + dryrun (bool, optional): Whether to create the resource object, or load the object as a dryrun. 
+ (Default: ``False``) + access_level (:obj:`ResourceAccess`, optional): Access level to provide for the resource. + (Default: ``ResourceAccess.WRITE``) + visibility (:obj:`ResourceVisibility`, optional): Type of visibility to provide for the resource. + (Default: ``ResourceVisibility.PRIVATE``) """ self._name, self._rns_folder = None, None if name is not None: @@ -63,16 +56,7 @@ def __init__( rns_client.resolve_rns_path(name) ) - from runhouse.resources.provenance import Run - self.dryrun = dryrun - # dryrun is true here so we don't spend time calling check on the server - # if we're just loading down the resource (e.g. with .remote) - self.provenance = ( - Run.from_config(provenance, dryrun=True) - if isinstance(provenance, Dict) - else provenance - ) self.access_level = access_level self._visibility = visibility @@ -88,7 +72,6 @@ def config(self, condensed=True): "name": self.rns_address or self.name, "resource_type": self.RESOURCE_TYPE, "resource_subtype": self.__class__.__name__, - "provenance": self.provenance.config if self.provenance else None, } self.save_attrs_to_config( config, @@ -125,10 +108,7 @@ def _resource_string_for_subconfig( @property def rns_address(self): """Traverse up the filesystem until reaching one of the directories in rns_base_folders, - then compute the relative path to that. - - Maybe later, account for folders along the path with a different RNS name.""" - + then compute the relative path to that.""" if ( self.name is None or self._rns_folder is None ): # Anonymous folders have no rns address @@ -222,7 +202,6 @@ def _compare_config_with_alt_options(cls, config, alt_options): with the options. If the child class returns a config, it's deciding to use the config and ignore the options (or somehow incorporate them, rarely). Note that if alt_options are provided and the config is not found, no error is raised, while if alt_options are not provided and the config is not found, an error is raised. - """ def str_dict_or_resource_to_str(val): @@ -255,13 +234,19 @@ def str_dict_or_resource_to_str(val): @classmethod def from_name( cls, - name, - load_from_den=True, - dryrun=False, - alt_options=None, - _resolve_children=True, + name: str, + load_from_den: bool = True, + dryrun: bool = False, + _alt_options: Dict = None, + _resolve_children: bool = True, ): - """Load existing Resource via its name.""" + """Load existing Resource via its name. + + Args: + name (str): Name of the resource to load from name. + load_from_den (bool, optional): Whether to try loading the module from Den. (Default: ``True``) + dryrun (bool, optional): Whether to construct the object or load as dryrun. (Default: ``False``) + """ # TODO is this the right priority order? from runhouse.resources.hardware.utils import _current_cluster @@ -270,8 +255,8 @@ def from_name( config = rns_client.load_config(name=name, load_from_den=load_from_den) - if alt_options: - config = cls._compare_config_with_alt_options(config, alt_options) + if _alt_options: + config = cls._compare_config_with_alt_options(config, _alt_options) if not config: return None if not config: @@ -290,7 +275,13 @@ def from_name( ) @staticmethod - def from_config(config, dryrun=False, _resolve_children=True): + def from_config(config: Dict, dryrun: bool = False, _resolve_children: bool = True): + """Load or construct resource from config. + + Args: + config (Dict): Resource config. 
+ dryrun (bool, optional): Whether to construct resource or load as dryrun (Default: ``False``) + """ resource_type = config.pop("resource_type", None) dryrun = config.pop("dryrun", False) or dryrun @@ -324,7 +315,11 @@ def unname(self): def history(self, limit: int = None) -> List[Dict]: """Return the history of the resource, including specific config fields (e.g. folder path) and which runs have overwritten it. - If ``limit`` is specified, return the last ``limit`` number of entries in the history.""" + + Args: + limit (int, optional): If specified, return the last ``limit`` number of entries in the history. + Otherwise, return the entire history. (Default: ``None``) + """ if not self.rns_address: raise ValueError("Resource must have a name in order to have a history") @@ -379,10 +374,9 @@ def share( notify_users: bool = True, headers: Optional[Dict] = None, ) -> Tuple[Dict[str, ResourceAccess], Dict[str, ResourceAccess]]: - """Grant access to the resource for a list of users (or a single user). If a user has a Runhouse account they - will receive an email notifying them of their new access. If the user does not have a Runhouse account they will - also receive instructions on creating one, after which they will be able to have access to the Resource. If - ``visibility`` is set to ``public``, users will not be notified. + """Grant access to the resource for a list of users (or a single user). By default, the user will + receive an email notification of access (if they have a Runhouse account) or instructions on creating + an account to access the resource. If ``visibility`` is set to ``public``, users will not be notified. .. note:: You can only grant access to other users if you have write access to the resource. @@ -390,14 +384,14 @@ def share( Args: users (Union[str, list], optional): Single user or list of user emails and / or runhouse account usernames. If none are provided and ``visibility`` is set to ``public``, resource will be made publicly - available to all users. + available to all users. (Default: ``None``) access_level (:obj:`ResourceAccess`, optional): Access level to provide for the resource. - Defaults to ``read``. + (Default: ``read``). visibility (:obj:`ResourceVisibility`, optional): Type of visibility to provide for the shared - resource. Defaults to ``private``. + resource. By default, the visibility is private. (Default: ``None``) notify_users (bool, optional): Whether to send an email notification to users who have been given access. - Note: This is relevant for resources which are not ``shareable``. Defaults to ``True``. - headers (dict, optional): Request headers to provide for the request to RNS. Contains the user's auth token. + Note: This is relevant for resources which are not ``shareable``. (Default: ``True``) + headers (Dict, optional): Request headers to provide for the request to RNS. Contains the user's auth token. Example: ``{"Authorization": f"Bearer {token}"}`` Returns: @@ -478,7 +472,7 @@ def revoke( Args: users (Union[str, str], optional): List of user emails and / or runhouse account usernames - (or a single user). If no users are specified will revoke access for all users. + (or a single user). If no users are specified will revoke access for all users. (Default: ``None``) headers (Optional[Dict]): Request headers to provide for the request to RNS. Contains the user's auth token. 
Example: ``{"Authorization": f"Bearer {token}"}`` """ diff --git a/runhouse/resources/secrets/provider_secrets/api_key_secret.py b/runhouse/resources/secrets/provider_secrets/api_key_secret.py index af47804dd..b80889aa4 100644 --- a/runhouse/resources/secrets/provider_secrets/api_key_secret.py +++ b/runhouse/resources/secrets/provider_secrets/api_key_secret.py @@ -1,6 +1,5 @@ from typing import Dict, Optional, Union -from runhouse.resources.blobs.file import File from runhouse.resources.envs.env import Env from runhouse.resources.envs.env_factory import env as env_factory from runhouse.resources.hardware.cluster import Cluster @@ -19,7 +18,7 @@ def write( self, file: bool = False, env: bool = False, - path: Union[str, File] = None, + path: str = None, env_vars: Dict = None, overwrite: bool = False, ): @@ -32,7 +31,7 @@ def write( def to( self, system: Union[str, Cluster], - path: Union[str, File] = None, + path: str = None, env: Union[str, Env] = None, values: bool = True, name: Optional[str] = None, diff --git a/runhouse/resources/secrets/provider_secrets/aws_secret.py b/runhouse/resources/secrets/provider_secrets/aws_secret.py index 10b750b54..c0f2bdba9 100644 --- a/runhouse/resources/secrets/provider_secrets/aws_secret.py +++ b/runhouse/resources/secrets/provider_secrets/aws_secret.py @@ -1,14 +1,12 @@ import configparser import copy -import io import os -from pathlib import Path -from typing import Dict, Union +from typing import Dict -from runhouse.resources.blobs.file import File from runhouse.resources.secrets.provider_secrets.provider_secret import ProviderSecret from runhouse.resources.secrets.utils import _check_file_for_mismatches +from runhouse.utils import create_local_dir class AWSSecret(ProviderSecret): @@ -28,9 +26,7 @@ class AWSSecret(ProviderSecret): def from_config(config: dict, dryrun: bool = False, _resolve_children: bool = True): return AWSSecret(**config, dryrun=dryrun) - def _write_to_file( - self, path: Union[str, File], values: Dict, overwrite: bool = False - ): + def _write_to_file(self, path: str, values: Dict, overwrite: bool = False): new_secret = copy.deepcopy(self) if not _check_file_for_mismatches( @@ -51,31 +47,18 @@ def _write_to_file( value=values["secret_key"], ) - if isinstance(path, File): - # TODO: may be a better way of getting config parser data? 
- with io.StringIO() as ss: - parser.write(ss) - ss.seek(0) - data = ss.read() - path.write(data, serialize=False, mode="w") - else: - full_path = os.path.expanduser(path) - Path(full_path).parent.mkdir(parents=True, exist_ok=True) - with open(full_path, "w+") as f: - parser.write(f) - new_secret._add_to_rh_config(path) + full_path = create_local_dir(path) + with open(full_path, "w+") as f: + parser.write(f) + new_secret._add_to_rh_config(path) new_secret._values = None new_secret.path = path return new_secret - def _from_path(self, path: Union[str, File]): + def _from_path(self, path: str): config = configparser.ConfigParser() - if isinstance(path, File): - if not path.exists_in_system(): - return {} - config.read_string(path.fetch(deserialize=False, mode="r")) - elif path and os.path.exists(os.path.expanduser(path)): + if path and os.path.exists(os.path.expanduser(path)): config.read(os.path.expanduser(path)) else: return {} diff --git a/runhouse/resources/secrets/provider_secrets/azure_secret.py b/runhouse/resources/secrets/provider_secrets/azure_secret.py index d619d949a..1bc823c58 100644 --- a/runhouse/resources/secrets/provider_secrets/azure_secret.py +++ b/runhouse/resources/secrets/provider_secrets/azure_secret.py @@ -1,14 +1,12 @@ import configparser import copy -import io import os -from pathlib import Path -from typing import Dict, Union +from typing import Dict -from runhouse.resources.blobs.file import File from runhouse.resources.secrets.provider_secrets.provider_secret import ProviderSecret from runhouse.resources.secrets.utils import _check_file_for_mismatches +from runhouse.utils import create_local_dir class AzureSecret(ProviderSecret): @@ -28,7 +26,7 @@ def from_config(config: dict, dryrun: bool = False, _resolve_children: bool = Tr def _write_to_file( self, - path: Union[str, File] = None, + path: str = None, values: Dict = None, overwrite: bool = False, ): @@ -47,30 +45,18 @@ def _write_to_file( value=subscription_id, ) - if isinstance(path, File): - with io.StringIO() as ss: - parser.write(ss) - ss.seek(0) - data = ss.read() - path.write(data, serialize=False, mode="w") - else: - full_path = os.path.expanduser(path) - Path(full_path).parent.mkdir(parents=True, exist_ok=True) - with open(full_path, "w") as f: - parser.write(f) - new_secret._add_to_rh_config(path) + full_path = create_local_dir(path) + with open(full_path, "w") as f: + parser.write(f) + new_secret._add_to_rh_config(path) new_secret._values = None new_secret.path = path return new_secret - def _from_path(self, path: Union[str, File]): + def _from_path(self, path: str = None): config = configparser.ConfigParser() - if isinstance(path, File): - if not path.exists_in_system(): - return {} - config.read_string(path.fetch(mode="r", deserialize=False)) - elif path and os.path.exists(os.path.expanduser(path)): + if path and os.path.exists(os.path.expanduser(path)): path = os.path.expanduser(path) config.read(path) if config and "AzureCloud" in config.sections(): diff --git a/runhouse/resources/secrets/provider_secrets/gcp_secret.py b/runhouse/resources/secrets/provider_secrets/gcp_secret.py index bd7e80f0c..2fa8066b1 100644 --- a/runhouse/resources/secrets/provider_secrets/gcp_secret.py +++ b/runhouse/resources/secrets/provider_secrets/gcp_secret.py @@ -3,9 +3,7 @@ import os from pathlib import Path -from typing import Dict, Union - -from runhouse.resources.blobs.file import File +from typing import Dict from runhouse.resources.secrets.provider_secrets.provider_secret import ProviderSecret from 
runhouse.resources.secrets.utils import _check_file_for_mismatches @@ -28,34 +26,23 @@ class GCPSecret(ProviderSecret): def from_config(config: dict, dryrun: bool = False, _resolve_children: bool = True): return GCPSecret(**config, dryrun=dryrun) - def _write_to_file( - self, path: Union[str, File], values: Dict = None, overwrite: bool = False - ): + def _write_to_file(self, path: str, values: Dict = None, overwrite: bool = False): new_secret = copy.deepcopy(self) if not _check_file_for_mismatches( path, self._from_path(path), values, overwrite ): - if isinstance(path, File): - data = json.dumps(values, indent=4) - path.write(data, serialize=False, mode="w") - else: - Path(path).parent.mkdir(parents=True, exist_ok=True) - with open(path, "w+") as f: - json.dump(values, f, indent=4) - new_secret._add_to_rh_config(path) + Path(path).parent.mkdir(parents=True, exist_ok=True) + with open(path, "w+") as f: + json.dump(values, f, indent=4) + new_secret._add_to_rh_config(path) new_secret._values = None new_secret.path = path return new_secret - def _from_path(self, path: Union[str, File]): + def _from_path(self, path: str = None): config = {} - if isinstance(path, File): - if not path.exists_in_system(): - return {} - contents = path.fetch(mode="r", deserialize=False) - config = json.loads(contents) - elif path and os.path.exists(os.path.expanduser(path)): + if path and os.path.exists(os.path.expanduser(path)): with open(os.path.expanduser(path), "r") as config_file: config = json.load(config_file) return config diff --git a/runhouse/resources/secrets/provider_secrets/github_secret.py b/runhouse/resources/secrets/provider_secrets/github_secret.py index 2372aed19..5777a3507 100644 --- a/runhouse/resources/secrets/provider_secrets/github_secret.py +++ b/runhouse/resources/secrets/provider_secrets/github_secret.py @@ -2,11 +2,10 @@ import os from pathlib import Path -from typing import Dict, Union +from typing import Dict import yaml -from runhouse.resources.blobs.file import File from runhouse.resources.secrets.provider_secrets.provider_secret import ProviderSecret from runhouse.resources.secrets.utils import _check_file_for_mismatches @@ -25,43 +24,31 @@ class GitHubSecret(ProviderSecret): def from_config(config: dict, dryrun: bool = False, _resolve_children: bool = True): return GitHubSecret(**config, dryrun=dryrun) - def _write_to_file( - self, path: Union[str, File], values: Dict = None, overwrite: bool = False - ): + def _write_to_file(self, path: str, values: Dict = None, overwrite: bool = False): new_secret = copy.deepcopy(self) if not _check_file_for_mismatches( path, self._from_path(path), values, overwrite ): config = {} - if isinstance(path, File): - if path.exists_in_system(): - config = path.fetch(deserialize=False, mode="r") - config["github.com"] = values - data = yaml.dump(config, default_flow_style=False) - path.write(data, serialize=False, mode="w") - else: - full_path = os.path.expanduser(path) - if Path(full_path).exists(): - with open(full_path, "r") as stream: - config = yaml.safe_load(stream) - config["github.com"] = values - Path(full_path).parent.mkdir(parents=True, exist_ok=True) - with open(full_path, "w") as yaml_file: - yaml.dump(config, yaml_file, default_flow_style=False) - new_secret._add_to_rh_config(path) + full_path = os.path.expanduser(path) + if Path(full_path).exists(): + with open(full_path, "r") as stream: + config = yaml.safe_load(stream) + config["github.com"] = values + + Path(full_path).parent.mkdir(parents=True, exist_ok=True) + with open(full_path, "w") 
as yaml_file: + yaml.dump(config, yaml_file, default_flow_style=False) + new_secret._add_to_rh_config(path) new_secret._values = None new_secret.path = path return new_secret - def _from_path(self, path: Union[str, File]): + def _from_path(self, path: str = None): config = {} - if isinstance(path, File): - if not path.exists_in_system(): - return {} - config = yaml.safe_load(path.fetch(mode="r", deserialize=False)) - elif path and os.path.exists(os.path.expanduser(path)): + if path and os.path.exists(os.path.expanduser(path)): with open(os.path.expanduser(path), "r") as stream: config = yaml.safe_load(stream) return config["github.com"] if config else {} diff --git a/runhouse/resources/secrets/provider_secrets/huggingface_secret.py b/runhouse/resources/secrets/provider_secrets/huggingface_secret.py index cbc225609..1b1373914 100644 --- a/runhouse/resources/secrets/provider_secrets/huggingface_secret.py +++ b/runhouse/resources/secrets/provider_secrets/huggingface_secret.py @@ -2,11 +2,11 @@ import os from pathlib import Path -from typing import Dict, Union +from typing import Dict -from runhouse.resources.blobs.file import File from runhouse.resources.secrets.provider_secrets.provider_secret import ProviderSecret from runhouse.resources.secrets.utils import _check_file_for_mismatches +from runhouse.utils import create_local_dir class HuggingFaceSecret(ProviderSecret): @@ -25,33 +25,24 @@ class HuggingFaceSecret(ProviderSecret): def from_config(config: dict, dryrun: bool = False, _resolve_children: bool = True): return HuggingFaceSecret(**config, dryrun=dryrun) - def _write_to_file( - self, path: Union[str, File], values: Dict = None, overwrite: bool = False - ): + def _write_to_file(self, path: str, values: Dict = None, overwrite: bool = False): new_secret = copy.deepcopy(self) if not _check_file_for_mismatches( path, self._from_path(path), values, overwrite ): token = values["token"] - if isinstance(path, File): - path.write(token, serialize=False, mode="w") - else: - full_path = os.path.expanduser(path) - Path(full_path).parent.mkdir(parents=True, exist_ok=True) - with open(full_path, "a") as f: - f.write(token) - new_secret._add_to_rh_config(path) + full_path = create_local_dir(path) + with open(full_path, "a") as f: + f.write(token) + new_secret._add_to_rh_config(path) new_secret._values = None new_secret.path = path return new_secret - def _from_path(self, path: Union[str, File]): + def _from_path(self, path: str = None): token = None - if isinstance(path, File): - if path.exists_in_system(): - token = path.fetch(mode="r", deserialize=False).strip("\n") - elif path and os.path.exists(os.path.expanduser(path)): + if path and os.path.exists(os.path.expanduser(path)): token = Path(os.path.expanduser(path)).read_text().strip("\n") if token: return {"token": token} diff --git a/runhouse/resources/secrets/provider_secrets/lambda_secret.py b/runhouse/resources/secrets/provider_secrets/lambda_secret.py index 50fd8aa9a..46242e460 100644 --- a/runhouse/resources/secrets/provider_secrets/lambda_secret.py +++ b/runhouse/resources/secrets/provider_secrets/lambda_secret.py @@ -1,13 +1,11 @@ import copy import os -from pathlib import Path -from typing import Dict, Union - -from runhouse.resources.blobs.file import File +from typing import Dict from runhouse.resources.secrets.provider_secrets.provider_secret import ProviderSecret from runhouse.resources.secrets.utils import _check_file_for_mismatches +from runhouse.utils import create_local_dir class LambdaSecret(ProviderSecret): @@ -24,33 +22,24 @@ 
class LambdaSecret(ProviderSecret): def from_config(config: dict, dryrun: bool = False, _resolve_children: bool = True): return LambdaSecret(**config, dryrun=dryrun) - def _write_to_file( - self, path: Union[str, File], values: Dict = None, overwrite: bool = False - ): + def _write_to_file(self, path: str, values: Dict = None, overwrite: bool = False): new_secret = copy.deepcopy(self) if not _check_file_for_mismatches( path, self._from_path(path), values, overwrite ): data = f'api_key = {values["api_key"]}\n' - if isinstance(path, File): - path.write(data, serialize=False, mode="w") - else: - full_path = os.path.expanduser(path) - Path(full_path).parent.mkdir(parents=True, exist_ok=True) - with open(full_path, "w+") as f: - f.write(data) - new_secret._add_to_rh_config(path) + full_path = create_local_dir(path) + with open(full_path, "w+") as f: + f.write(data) + new_secret._add_to_rh_config(path) new_secret._values = None new_secret.path = path return new_secret - def _from_path(self, path: Union[str, File]): + def _from_path(self, path: str = None): lines = None - if isinstance(path, File): - if path.exists_in_system(): - lines = path.fetch(mode="r", deserialize=False).split("\n") - elif path and os.path.exists(os.path.expanduser(path)): + if path and os.path.exists(os.path.expanduser(path)): with open(os.path.expanduser(path), "r") as f: lines = f.readlines() if lines: diff --git a/runhouse/resources/secrets/provider_secrets/provider_secret.py b/runhouse/resources/secrets/provider_secrets/provider_secret.py index 0f173bc54..700b7059a 100644 --- a/runhouse/resources/secrets/provider_secrets/provider_secret.py +++ b/runhouse/resources/secrets/provider_secrets/provider_secret.py @@ -5,13 +5,12 @@ from typing import Any, Dict, Optional, Union from runhouse.globals import configs, rns_client -from runhouse.resources.blobs import file -from runhouse.resources.blobs.file import File from runhouse.resources.envs.env import Env from runhouse.resources.hardware.cluster import Cluster from runhouse.resources.hardware.utils import _get_cluster_from from runhouse.resources.secrets.secret import Secret from runhouse.resources.secrets.utils import _check_file_for_mismatches +from runhouse.utils import create_local_dir class ProviderSecret(Secret): @@ -95,13 +94,8 @@ def delete(self, headers: Optional[Dict] = None, contents: bool = False): """Delete the secret config from Den and from Vault/local. 
Optionally also delete contents of secret file or env vars.""" headers = headers or rns_client.request_headers() - if self.path and contents: - if isinstance(self.path, File): - if self.path.exists_in_system(): - self.path.rm() - else: - if os.path.exists(os.path.expanduser(self.path)): - os.remove(os.path.expanduser(self.path)) + if self.path and contents and os.path.exists(os.path.expanduser(self.path)): + os.remove(os.path.expanduser(self.path)) elif self.env_vars and contents: for (_, env_var) in self.env_vars.keys(): if env_var in os.environ: @@ -110,7 +104,7 @@ def delete(self, headers: Optional[Dict] = None, contents: bool = False): def write( self, - path: Union[str, File] = None, + path: str = None, env_vars: Dict = None, file: bool = False, env: bool = False, @@ -133,7 +127,7 @@ def write( def to( self, system: Union[str, Cluster], - path: Union[str, File] = None, + path: str = None, env: Union[str, Env] = None, values: bool = None, name: Optional[str] = None, @@ -206,30 +200,21 @@ def _file_to( self, key: str, system: Union[str, Cluster], - path: Union[str, File] = None, + path: str = None, values: Any = None, ): - if isinstance(path, File): - path = path.path system.call(key, "_write_to_file", path=path, values=values) - remote_file = file(path=path, system=system) - return remote_file + return path - def _write_to_file( - self, path: Union[str, File], values: Any, overwrite: bool = False - ): + def _write_to_file(self, path: str, values: Any, overwrite: bool = False): new_secret = copy.deepcopy(self) if not _check_file_for_mismatches( path, self._from_path(path), values, overwrite ): - if isinstance(path, File): - path.write(data=values, mode="w") - else: - full_path = os.path.expanduser(path) - os.makedirs(os.path.dirname(full_path), exist_ok=True) - with open(full_path, "w") as f: - json.dump(values, f, indent=4) - self._add_to_rh_config(path) + full_path = create_local_dir(path) + with open(full_path, "w") as f: + json.dump(values, f, indent=4) + self._add_to_rh_config(path) new_secret._values = None new_secret.path = path @@ -265,26 +250,19 @@ def _from_env(self, env_vars: Dict = None): return {} return values - def _from_path(self, path: Union[str, File] = None): + def _from_path(self, path: str = None): path = path or self.path if not path: return "" - if isinstance(path, File): - contents = path.fetch(mode="r", deserialize=False) - try: - return json.loads(contents) - except json.decoder.JSONDecodeError: + path = os.path.expanduser(path) + if os.path.exists(path): + with open(path) as f: + try: + contents = json.load(f) + except json.decoder.JSONDecodeError: + contents = f.read() return contents - else: - path = os.path.expanduser(path) - if os.path.exists(path): - with open(path) as f: - try: - contents = json.load(f) - except json.decoder.JSONDecodeError: - contents = f.read() - return contents return {} @staticmethod diff --git a/runhouse/resources/secrets/provider_secrets/ssh_secret.py b/runhouse/resources/secrets/provider_secrets/ssh_secret.py index 1b3ea286c..91bd33147 100644 --- a/runhouse/resources/secrets/provider_secrets/ssh_secret.py +++ b/runhouse/resources/secrets/provider_secrets/ssh_secret.py @@ -6,7 +6,6 @@ from runhouse.globals import rns_client from runhouse.logger import get_logger -from runhouse.resources.blobs.file import File from runhouse.resources.hardware.cluster import Cluster from runhouse.resources.secrets.provider_secrets.provider_secret import ProviderSecret @@ -64,9 +63,7 @@ def save( folder=folder, ) - def _write_to_file( - self, path: 
Union[str, File], values: Dict = None, overwrite: bool = False - ): + def _write_to_file(self, path: str, values: Dict = None, overwrite: bool = False): priv_key_path = path priv_key_path = Path(os.path.expanduser(priv_key_path)) @@ -105,18 +102,10 @@ def _write_to_file( return new_secret - def _from_path(self, path: Union[str, File]): + def _from_path(self, path: str): if path == self._DEFAULT_CREDENTIALS_PATH: path = f"{self._DEFAULT_CREDENTIALS_PATH}/{self.key}" - if isinstance(path, File): - from runhouse.resources.blobs.file import file - - priv_key = path.fetch(mode="r", deserialize=False) - pub_key_file = file(path=f"{path.path}.pub", system=path.system) - pub_key = pub_key_file.fetch(mode="r", deserialize=False) - return {"public_key": pub_key, "private_key": priv_key} - return self.extract_secrets_from_path(path) @staticmethod @@ -139,16 +128,12 @@ def _file_to( path: Union[str, Path] = None, values: Any = None, ): - from runhouse.resources.blobs.file import file - if self.path: - pub_key_path = ( - f"{path.path}.pub" if isinstance(path, File) else f"{path}.pub" - ) - remote_priv_file = file(path=self.path).to(system, path=path) - file(path=pub_key_path).to(system, path=pub_key_path) + remote_priv_file = self.path + # pub_key_path = f"{path}.pub" + system.call(key, "_write_to_file", path=remote_priv_file, values=values) system.run([f"chmod 600 {path}"]) else: system.call(key, "_write_to_file", path=path, values=values) - remote_priv_file = file(path=path, system=system) + remote_priv_file = path return remote_priv_file diff --git a/runhouse/resources/secrets/secret.py b/runhouse/resources/secrets/secret.py index 189af8a28..c97086d84 100644 --- a/runhouse/resources/secrets/secret.py +++ b/runhouse/resources/secrets/secret.py @@ -93,8 +93,7 @@ def _write_shared_secret_to_local(config): ) @staticmethod - def from_config(config: dict, dryrun: bool = False, _resolve_children=True): - """Create a Secret object from a config dictionary.""" + def from_config(config: dict, dryrun: bool = False, _resolve_children: bool = True): if "provider" in config: from runhouse.resources.secrets.provider_secrets.providers import ( _get_provider_class, @@ -116,12 +115,11 @@ def from_config(config: dict, dryrun: bool = False, _resolve_children=True): def from_name( cls, name, - load_from_den=True, - dryrun=False, - alt_options=None, - _resolve_children=True, + load_from_den: bool = True, + dryrun: bool = False, + _alt_options: Dict = None, + _resolve_children: bool = True, ): - """Load existing Secret via its name.""" try: config = load_config(name, cls.USER_ENDPOINT) if config: @@ -139,7 +137,12 @@ def from_name( @classmethod def builtin_providers(cls, as_str: bool = False) -> List: - """Return list of all Runhouse providers (as class objects) supported out of the box.""" + """Return list of all Runhouse providers (as class objects) supported out of the box. + + Args: + as_str (bool, optional): Whether to return the providers as a string or as a class. + (Default: ``False``) + """ from runhouse.resources.secrets.provider_secrets.providers import ( _str_to_provider_class, ) @@ -167,6 +170,12 @@ def vault_secrets(cls, headers: Optional[Dict] = None) -> List[str]: @classmethod def local_secrets(cls, names: List[str] = None) -> Dict[str, "Secret"]: + """Get list of local secrets. + + Args: + names (List[str], optional): Specific names of local secrets to retrieve. If ``None``, returns all + locally detected secrets. 
(Default: ``None``) + """ if not os.path.exists(os.path.expanduser("~/.rh/secrets")): return {} @@ -194,6 +203,12 @@ def local_secrets(cls, names: List[str] = None) -> Dict[str, "Secret"]: @classmethod def extract_provider_secrets(cls, names: List[str] = None) -> Dict[str, "Secret"]: + """Extract secret values from providers. Returns a Dict mapping the provider name to Secret. + + Args: + names (List[str]): List of provider names to extract secrets for. If ``None``, returns + secrets for all detected providers. (Default: ``None``) + """ from runhouse.resources.secrets.provider_secrets.providers import ( _str_to_provider_class, ) @@ -237,8 +252,16 @@ def save( ): """ Save the secret config to Den. Save the secret values into Vault if the user is logged in, - or to local if not or if the resource is a local resource. If a folder is specified, save the secret - to that folder in Den (e.g. saving secrets for a cluster associated with an organization). + or to local if not or if the resource is a local resource. + + Args: + name (str, optional): Name to save the secret resource as. + save_values (str, optional): Whether to save the values of the secret to Vault in addition + to saving the metadata to Den. (Default: ``True``) + headers (Dict, optional): Request headers to provide for the request to RNS. Contains the + user's auth token. Example: ``{"Authorization": f"Bearer {token}"}`` (Default: ``None``) + folder (str, optional): If specified, save the secret to that folder in Den (e.g. saving secrets + for a cluster associated with an organization). (Default: ``None``) """ if name: self.name = name @@ -348,7 +371,8 @@ def to( Args: system (str or Cluster): Cluster to send the secret to - name (str, ooptional): Name to assign the resource on the cluster. + name (str, optional): Name to assign the resource on the cluster. + env (Env, optional): Env to send the secret to. Example: >>> secret.to(my_cluster, path=secret.path) diff --git a/runhouse/resources/secrets/secret_factory.py b/runhouse/resources/secrets/secret_factory.py index 3c7a48b19..a5a4adbb8 100644 --- a/runhouse/resources/secrets/secret_factory.py +++ b/runhouse/resources/secrets/secret_factory.py @@ -1,6 +1,5 @@ -from typing import Dict, Optional, Union +from typing import Dict, Optional -from runhouse.resources.blobs.file import File from runhouse.resources.secrets.provider_secrets.provider_secret import ProviderSecret from runhouse.resources.secrets.secret import Secret @@ -45,7 +44,7 @@ def provider_secret( provider: Optional[str] = None, name: Optional[str] = None, values: Optional[Dict] = None, - path: Union[str, File] = None, + path: Optional[str] = None, env_vars: Optional[Dict] = None, load_from_den: bool = True, dryrun: bool = False, @@ -61,11 +60,11 @@ def provider_secret( name (str, optional): Name to assign the resource. If none is provided, resource name defaults to the provider name. values (Dict, optional): Dictionary mapping of secret keys and values. - path (str or Path, optional): Path where the secret values are held. + path (str, optional): Path where the secret values are held. env_vars (Dict, optional): Dictionary mapping secret keys to the corresponding environment variable key. load_from_den (bool): Whether to try loading the secret from Den. (Default: ``True``) - dryrun (bool): Whether to creat in dryrun mode. (Default: False) + dryrun (bool): Whether to create in dryrun mode. (Default: False) Returns: ProviderSecret: The resulting provider secret object. 
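To ground the ``provider_secret`` factory documented above, a brief sketch using the Lambda provider handled in this change. The API key value and the write path are placeholders, and it assumes the factory is exposed at the top level as ``rh.provider_secret``:

    import runhouse as rh

    # LambdaSecret expects values of the form {"api_key": ...}, per lambda_secret.py above.
    lambda_secret = rh.provider_secret(provider="lambda", values={"api_key": "<API_KEY>"})

    # Write the values to a local credentials file (path shown here is illustrative).
    lambda_secret.write(path="~/.lambda_cloud/lambda_keys")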
diff --git a/runhouse/servers/cluster_servlet.py b/runhouse/servers/cluster_servlet.py index 649fb0b1d..4720702f4 100644 --- a/runhouse/servers/cluster_servlet.py +++ b/runhouse/servers/cluster_servlet.py @@ -6,13 +6,19 @@ from typing import Any, Dict, List, Optional, Set, Tuple, Union import httpx + +import psutil +import pynvml import requests import runhouse from runhouse.constants import ( DEFAULT_STATUS_CHECK_INTERVAL, + GPU_COLLECTION_INTERVAL, INCREASED_STATUS_CHECK_INTERVAL, + MAX_GPU_INFO_LEN, + REDUCED_GPU_INFO_LEN, SERVER_LOGFILE, SERVER_LOGS_FILE_NAME, ) @@ -25,7 +31,13 @@ from runhouse.rns.utils.api import ResourceAccess from runhouse.servers.autostop_helper import AutostopHelper from runhouse.servers.http.auth import AuthCache -from runhouse.utils import ColoredFormatter, sync_function +from runhouse.utils import ( + ColoredFormatter, + get_gpu_usage, + get_pid, + ServletType, + sync_function, +) logger = get_logger(__name__) @@ -47,6 +59,8 @@ async def __init__( self.cluster_config: Optional[Dict[str, Any]] = ( cluster_config if cluster_config else {} ) + self.cluster_config["has_cuda"] = detect_cuda_version_or_cpu() != "cpu" + self._initialized_env_servlet_names: Set[str] = set() self._key_to_env_servlet_name: Dict[Any, str] = {} self._auth_cache: AuthCache = AuthCache(cluster_config) @@ -65,6 +79,19 @@ async def __init__( self._api_server_url = self.cluster_config.get( "api_server_url", rns_client.api_server_url ) + self.pid = get_pid() + self.process = psutil.Process(pid=self.pid) + self.gpu_metrics = None # will be updated only if this is a gpu cluster. + self.lock = ( + threading.Lock() + ) # will be used when self.gpu_metrics will be updated by different threads. + + if self.cluster_config.get("has_cuda"): + logger.debug("Creating _periodic_gpu_check thread.") + collect_gpu_thread = threading.Thread( + target=self._periodic_gpu_check, daemon=True + ) + collect_gpu_thread.start() logger.info("Creating periodic_cluster_checks thread.") cluster_checks_thread = threading.Thread( @@ -289,7 +316,9 @@ async def aperiodic_cluster_checks(self): "Cluster has not yet been saved to Den, cannot update status or logs." ) elif status_code != 200: - logger.error("Failed to send cluster status to Den") + logger.error( + f"Failed to send cluster status to Den, status_code: {status_code}" + ) else: logger.debug("Successfully sent cluster status to Den") @@ -382,42 +411,79 @@ async def _status_for_env_servlet(self, env_servlet_name): except Exception as e: return {"env_servlet_name": env_servlet_name, "Exception": e} + async def _aperiodic_gpu_check(self): + """periodically collects cluster gpu usage""" + + pynvml.nvmlInit() # init nvidia ml info collection + + while True: + try: + + gpu_count = pynvml.nvmlDeviceGetCount() + with self.lock: + if not self.gpu_metrics: + self.gpu_metrics = {device: [] for device in range(gpu_count)} + + for gpu_index in range(gpu_count): + handle = pynvml.nvmlDeviceGetHandleByIndex(gpu_index) + util_info = pynvml.nvmlDeviceGetUtilizationRates(handle) + memory_info = pynvml.nvmlDeviceGetMemoryInfo(handle) + + total_memory = memory_info.total # in bytes + used_memory = memory_info.used # in bytes + free_memory = memory_info.free # in bytes + utilization_percent = util_info.gpu / 1.0 # make it float + + # to reduce cluster memory usage (we are saving the gpu_usage info on the cluster), + # we save only the most updated gpu usage. If for some reason the size of updated_gpu_info is + # too big, we remove the older gpu usage info. 
+ # This is relevant when using cluster.status() directly and not relying on status being sent to den. + updated_gpu_info = self.gpu_metrics[gpu_index] + if len(updated_gpu_info) + 1 > MAX_GPU_INFO_LEN: + updated_gpu_info = updated_gpu_info[REDUCED_GPU_INFO_LEN:] + updated_gpu_info.append( + { + "total_memory": total_memory, + "used_memory": used_memory, + "free_memory": free_memory, + "utilization_percent": utilization_percent, + } + ) + self.gpu_metrics[gpu_index] = updated_gpu_info + except Exception as e: + logger.error(str(e)) + pynvml.nvmlShutdown() + break + finally: + # collects gpu usage every 5 seconds. + await asyncio.sleep(GPU_COLLECTION_INTERVAL) + + def _periodic_gpu_check(self): + # This is only ever called once in its own thread, so we can do asyncio.run here instead of + # sync_function. + asyncio.run(self._aperiodic_gpu_check()) + def _get_node_gpu_usage(self, server_pid: int): - import subprocess - - gpu_general_info = ( - subprocess.run( - [ - "nvidia-smi", - "--query-gpu=memory.total,memory.used,memory.free,count,utilization.gpu", - "--format=csv,noheader,nounits", - ], - stdout=subprocess.PIPE, - ) - .stdout.decode("utf-8") - .strip() - .split(", ") + + # currently works correctly for a single node GPU. Multinode-clusters will be supported shortly. + + collected_gpus_info = copy.deepcopy(self.gpu_metrics) + + if collected_gpus_info is None or not collected_gpus_info[0]: + return None + + cluster_gpu_usage = get_gpu_usage( + collected_gpus_info=collected_gpus_info, servlet_type=ServletType.cluster ) - total_gpu_memory = int(gpu_general_info[0]) * (1024**2) # in bytes - total_used_memory = int(gpu_general_info[1]) * (1024**2) # in bytes - free_memory = int(gpu_general_info[2]) * (1024**2) # in bytes - gpu_count = int(gpu_general_info[3]) - gpu_utilization_percent = int(gpu_general_info[4]) / 100 - - return { - "total_memory": total_gpu_memory, - "used_memory": total_used_memory, - "free_memory": free_memory, - "gpu_count": gpu_count, - "utilization_percent": gpu_utilization_percent, - "server_pid": server_pid, # will be useful for multi-node clusters. - } + cluster_gpu_usage[ + "server_pid" + ] = server_pid # will be useful for multi-node clusters. + + return cluster_gpu_usage async def astatus(self, send_to_den: bool = False) -> Tuple[Dict, Optional[int]]: import psutil - from runhouse.utils import get_pid - config_cluster = copy.deepcopy(self.cluster_config) # Popping out creds because we don't want to show them in the status @@ -453,7 +519,7 @@ async def astatus(self, send_to_den: bool = False) -> Tuple[Dict, Optional[int]] env_servlet_utilization_data[env_servlet_name] = env_memory_info # TODO: decide if we need this info at all: cpu_usage, memory_usage, disk_usage - cpu_utilization = psutil.cpu_percent(interval=1) + cpu_utilization = psutil.cpu_percent(interval=0) # A dictionary that match the keys of psutil.virtual_memory()._asdict() to match the keys we expect in Den. relevant_memory_info = { @@ -472,11 +538,9 @@ async def astatus(self, send_to_den: bool = False) -> Tuple[Dict, Optional[int]] for k in relevant_memory_info.keys() } - server_pid: int = get_pid() - # get general gpu usage server_gpu_usage = ( - self._get_node_gpu_usage(server_pid) + self._get_node_gpu_usage(self.pid) if self.cluster_config.get("has_cuda", False) else None ) @@ -486,10 +550,16 @@ async def astatus(self, send_to_den: bool = False) -> Tuple[Dict, Optional[int]] else None ) + # rest the gpu_info only after the status was sent to den. 
If we should not send status to den, + # self.gpu_metrics will not be updated at all, therefore should not be reset. + if send_to_den: + with self.lock: + self.gpu_metrics = None + status_data = { "cluster_config": config_cluster, "runhouse_version": runhouse.__version__, - "server_pid": server_pid, + "server_pid": self.pid, "env_servlet_processes": env_servlet_utilization_data, "server_cpu_utilization": cpu_utilization, "server_gpu_utilization": gpu_utilization, diff --git a/runhouse/servers/env_servlet.py b/runhouse/servers/env_servlet.py index 9b5f57bd9..d15c8ce50 100644 --- a/runhouse/servers/env_servlet.py +++ b/runhouse/servers/env_servlet.py @@ -1,11 +1,28 @@ +import copy import os +import threading +import time import traceback from functools import wraps from typing import Any, Dict, Optional +import psutil +import pynvml + +from runhouse import configs + +from runhouse.constants import ( + DEFAULT_STATUS_CHECK_INTERVAL, + GPU_COLLECTION_INTERVAL, + MAX_GPU_INFO_LEN, + REDUCED_GPU_INFO_LEN, +) + from runhouse.globals import obj_store from runhouse.logger import get_logger +from runhouse.resources.hardware.utils import detect_cuda_version_or_cpu + from runhouse.servers.http.http_utils import ( deserialize_data, handle_exception_response, @@ -14,7 +31,14 @@ serialize_data, ) from runhouse.servers.obj_store import ClusterServletSetupOption -from runhouse.utils import arun_in_thread, get_node_ip + +from runhouse.utils import ( + arun_in_thread, + get_gpu_usage, + get_node_ip, + get_pid, + ServletType, +) logger = get_logger(__name__) @@ -92,6 +116,21 @@ async def __init__(self, env_name: str, *args, **kwargs): self.output_types = {} self.thread_ids = {} + self.pid = get_pid() + self.process = psutil.Process(pid=self.pid) + + self.gpu_metrics = None # will be updated only if this is a gpu cluster. + self.lock = ( + threading.Lock() + ) # will be used when self.gpu_metrics will be updated by different threads. 
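+        # gpu_metrics is appended to by the sampling thread started below and read when a
+        # status report is assembled; updates and resets are guarded by self.lock. The thread
+        # is only started when CUDA is detected, so CPU-only envs skip pynvml polling entirely.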
+ + if detect_cuda_version_or_cpu() != "cpu": + logger.debug("Creating _periodic_gpu_check thread.") + collect_gpu_thread = threading.Thread( + target=self._collect_env_gpu_usage, daemon=True + ) + collect_gpu_thread.start() + ############################################################## # Methods to disable or enable den auth ############################################################## @@ -185,15 +224,8 @@ async def aclear_local(self): def _get_env_cpu_usage(self, cluster_config: dict = None): - import psutil - - from runhouse.utils import get_pid - - cluster_config = cluster_config or obj_store.cluster_config - total_memory = psutil.virtual_memory().total node_ip = get_node_ip() - env_servlet_pid = get_pid() if not cluster_config.get("resource_subtype") == "Cluster": stable_internal_external_ips = cluster_config.get( @@ -216,9 +248,9 @@ def _get_env_cpu_usage(self, cluster_config: dict = None): node_name = f"worker_{ips.index(node_ip)} ({node_ip})" try: - env_servlet_process = psutil.Process(pid=env_servlet_pid) - memory_size_bytes = env_servlet_process.memory_full_info().uss - cpu_usage_percent = env_servlet_process.cpu_percent() + + memory_size_bytes = self.process.memory_full_info().uss + cpu_usage_percent = self.process.cpu_percent(interval=0) env_memory_usage = { "used_memory": memory_size_bytes, "utilization_percent": cpu_usage_percent, @@ -227,61 +259,70 @@ def _get_env_cpu_usage(self, cluster_config: dict = None): except psutil.NoSuchProcess: env_memory_usage = {} - return (env_memory_usage, node_name, total_memory, env_servlet_pid, node_ip) + return (env_memory_usage, node_name, total_memory, self.pid, node_ip) - def _get_env_gpu_usage(self, env_servlet_pid: int): - import subprocess + def _get_env_gpu_usage(self): + # currently works correctly for a single node GPU. Multinode-clusters will be supported shortly. 
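+        # Condenses the pynvml samples gathered by the background collection thread into a
+        # single usage dict via get_gpu_usage, replacing the previous nvidia-smi subprocess calls.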
- try: + collected_gpus_info = copy.deepcopy(self.gpu_metrics) - gpu_general_info = ( - subprocess.run( - [ - "nvidia-smi", - "--query-gpu=memory.total", - "--format=csv,noheader,nounits", - ], - stdout=subprocess.PIPE, - ) - .stdout.decode("utf-8") - .strip() - .split(", ") - ) - total_gpu_memory = int(gpu_general_info[0]) * (1024**2) # in bytes - - env_used_memory = 0 # in bytes - - env_gpu_usage = ( - subprocess.run( - [ - "nvidia-smi", - "--query-compute-apps=pid,gpu_uuid,used_memory", - "--format=csv,nounits", - ], - stdout=subprocess.PIPE, - ) - .stdout.decode("utf-8") - .strip() - .split("\n") - ) - for i in range(1, len(env_gpu_usage)): - single_env_gpu_info = env_gpu_usage[i].strip().split(", ") - if int(single_env_gpu_info[0]) == env_servlet_pid: - env_used_memory = env_used_memory + int(single_env_gpu_info[-1]) * ( - 1024**2 - ) - if env_used_memory > 0: - env_gpu_usage = { - "used_memory": env_used_memory, # in bytes - "total_memory": total_gpu_memory, # in bytes - } - else: - env_gpu_usage = {} - except subprocess.CalledProcessError as e: - logger.error(f"Failed to get GPU usage for {self.env_name}: {e}") - env_gpu_usage = {} + if collected_gpus_info is None or not collected_gpus_info[0]: + return None + + return get_gpu_usage( + collected_gpus_info=collected_gpus_info, servlet_type=ServletType.env + ) - return env_gpu_usage + def _collect_env_gpu_usage(self): + """periodically collects env gpu usage""" + + pynvml.nvmlInit() # init nvidia ml info collection + + while True: + try: + gpu_count = pynvml.nvmlDeviceGetCount() + with self.lock: + if not self.gpu_metrics: + self.gpu_metrics: Dict[int, list[Dict[str, int]]] = { + device: [] for device in range(gpu_count) + } + + for gpu_index in range(gpu_count): + handle = pynvml.nvmlDeviceGetHandleByIndex(gpu_index) + processes = pynvml.nvmlDeviceGetComputeRunningProcesses_v3( + handle + ) + memory_info = pynvml.nvmlDeviceGetMemoryInfo(handle) + if processes: + for p in processes: + if p.pid == self.pid: + used_memory = p.usedGpuMemory # in bytes + total_memory = memory_info.total # in bytes + current_gpu_metrics: list[ + Dict[str, int] + ] = self.gpu_metrics[gpu_index] + # to reduce cluster memory usage (we are saving the gpu_usage info on the cluster), + # we save only the most updated gpu usage. If for some reason the size of updated_gpu_info is + # too big, we remove the older gpu usage info. + # This is relevant when using cluster.status() directly and not relying on status being sent to den. + if len(current_gpu_metrics) + 1 > MAX_GPU_INFO_LEN: + current_gpu_metrics = current_gpu_metrics[ + REDUCED_GPU_INFO_LEN: + ] + current_gpu_metrics.append( + { + "used_memory": used_memory, + "total_memory": total_memory, + } + ) + self.gpu_metrics[gpu_index] = current_gpu_metrics + except Exception as e: + logger.error(str(e)) + pynvml.nvmlShutdown() + break + finally: + # collects gpu usage every 5 seconds. 
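+                # (GPU_COLLECTION_INTERVAL is imported from runhouse.constants.)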
+ time.sleep(GPU_COLLECTION_INTERVAL) def _status_local_helper(self): objects_in_env_servlet = obj_store.keys_with_info() @@ -297,17 +338,32 @@ def _status_local_helper(self): # Try loading GPU data (if relevant) env_gpu_usage = ( - self._get_env_gpu_usage(int(env_servlet_pid)) - if cluster_config.get("has_cuda", False) - else {} + self._get_env_gpu_usage() if cluster_config.get("has_cuda", False) else {} ) + cluster_config = obj_store.cluster_config + interval_size = cluster_config.get( + "status_check_interval", DEFAULT_STATUS_CHECK_INTERVAL + ) + + # TODO: [sb]: once introduced, we could use ClusterServlet _cluster_periodic_thread_alive() to replace the + # 'should_send_status_and_logs_to_den' logic below. + # Only if one of these is true, do we actually need to get the status from each EnvServlet + should_send_status_and_logs_to_den: bool = ( + configs.token is not None and interval_size != -1 + ) + + # reset the gpu_info only if the current env_gpu collection will be sent to den. Otherwise, keep collecting it. + if should_send_status_and_logs_to_den: + with self.lock: + self.gpu_metrics = None + env_servlet_utilization_data = { "env_gpu_usage": env_gpu_usage, "node_ip": node_ip, "node_name": node_name, - "pid": env_servlet_pid, "env_cpu_usage": env_memory_usage, + "pid": env_servlet_pid, } return objects_in_env_servlet, env_servlet_utilization_data diff --git a/runhouse/servers/http/http_utils.py b/runhouse/servers/http/http_utils.py index a6f6a7a64..47a5cffbc 100644 --- a/runhouse/servers/http/http_utils.py +++ b/runhouse/servers/http/http_utils.py @@ -94,6 +94,7 @@ class OutputType: class FolderParams(BaseModel): path: str + is_file: bool = False @field_validator("path", mode="before") def convert_path_to_string(cls, v): @@ -533,8 +534,11 @@ def folder_mv(src_path: Path, dest_path: str, overwrite: bool): def folder_exists(path: Path): + folder_exists_resp = path.exists() + if not path.is_file(): + folder_exists_resp = folder_exists_resp and path.is_dir() return Response( - data=path.exists() and path.is_dir(), + data=folder_exists_resp, output_type=OutputType.RESULT_SERIALIZED, serialization=None, ) diff --git a/runhouse/utils.py b/runhouse/utils.py index a37ab678b..833a28f4f 100644 --- a/runhouse/utils.py +++ b/runhouse/utils.py @@ -2,6 +2,7 @@ import contextvars import functools import logging +from io import StringIO try: import importlib.metadata as metadata @@ -25,6 +26,7 @@ from concurrent.futures import ThreadPoolExecutor from datetime import datetime +from enum import Enum from pathlib import Path from typing import Callable, Optional, Type, Union @@ -34,6 +36,8 @@ from runhouse.logger import get_logger, init_logger logger = get_logger(__name__) + + #################################################################################################### # Python package utilities #################################################################################################### @@ -375,6 +379,54 @@ def __getattr__(self, item): return getattr(self.instream, item) +class capture_stdout: + """Context manager for capturing stdout to a file, list, or stream, while still printing to stdout.""" + + def __init__(self, output=None): + self.output = output + self._stream = None + + def __enter__(self): + if self.output is None: + self.output = StringIO() + + if isinstance(self.output, str): + self._stream = open(self.output, "w") + else: + self._stream = self.output + sys.stdout = StreamTee(sys.stdout, [self]) + sys.stderr = StreamTee(sys.stderr, [self]) + return self + + def write(self, 
message): + self._stream.write(message) + + def flush(self): + self._stream.flush() + + @property + def stream(self): + if isinstance(self.output, str): + return open(self.output, "r") + return self._stream + + def list(self): + if isinstance(self.output, str): + return self.stream.readlines() + return (self.stream.getvalue() or "").splitlines() + + def __str__(self): + return self.stream.getvalue() + + def __exit__(self, exc_type, exc_val, exc_tb): + if hasattr(sys.stdout, "instream"): + sys.stdout = sys.stdout.instream + if hasattr(sys.stderr, "instream"): + sys.stderr = sys.stderr.instream + self._stream.close() + return False + + class LogToFolder: def __init__(self, name: str): self.name = name @@ -544,3 +596,59 @@ def format(self, output_type): self._display_title = True return system_color, reset_color + + +def create_local_dir(path: Union[str, Path]): + full_path = os.path.expanduser(path) if isinstance(path, str) else path.expanduser() + Path(full_path).parent.mkdir(parents=True, exist_ok=True) + return full_path + + +#################################################################################################### +# Status collection utils +#################################################################################################### +class ServletType(str, Enum): + env = "env" + cluster = "cluster" + + +def get_gpu_usage(collected_gpus_info: dict, servlet_type: ServletType): + + gpus_indices = list(collected_gpus_info.keys()) + + # how we retrieve total_gpu_memory: + # 1. getting the first gpu usage of the first gpu un the gpus list + # 2. getting the first gpu_info dictionary of the specific gpu (we collected the gpu info over time) + # 3. get total_memory value (it is the same across all envs) + total_gpu_memory = collected_gpus_info[gpus_indices[0]][0].get("total_memory") + total_used_memory, gpu_utilization_percent, free_memory = 0, 0, 0 + + if servlet_type == ServletType.cluster: + free_memory = collected_gpus_info[gpus_indices[0]][-1].get( + "free_memory" + ) # getting the latest free_memory value collected. 
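+    # Each GPU's sampled used_memory is averaged over the collection window below (the samples
+    # were appended by the background collection threads); e.g. samples of 2 GB and 4 GB of
+    # used_memory average out to 3 GB for that GPU. For cluster-level reporting,
+    # utilization_percent is averaged the same way.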
+ + for gpu_index in gpus_indices: + collected_gpu_info = collected_gpus_info.get(gpu_index) + sum_used_memery = sum( + [gpu_info.get("used_memory") for gpu_info in collected_gpu_info] + ) + total_used_memory = sum_used_memery / len(collected_gpu_info) # average + + if servlet_type == ServletType.cluster: + sum_cpu_util = sum( + [gpu_info.get("utilization_percent") for gpu_info in collected_gpu_info] + ) + gpu_utilization_percent = sum_cpu_util / len(collected_gpu_info) # average + + total_used_memory = total_used_memory / len(gpus_indices) + + gpu_usage = {"total_memory": total_gpu_memory, "used_memory": total_used_memory} + + if servlet_type == ServletType.cluster: + gpu_utilization_percent = round(gpu_utilization_percent / len(gpus_indices), 2) + gpu_usage["free_memory"] = free_memory + gpu_usage["gpu_count"] = len(gpus_indices) + gpu_usage["utilization_percent"] = gpu_utilization_percent + + return gpu_usage diff --git a/setup.py b/setup.py index c3fa8b020..edd99dfa0 100644 --- a/setup.py +++ b/setup.py @@ -80,6 +80,7 @@ def parse_readme(readme: str) -> str: "apispec", "httpx", "pydantic >= 2.5.0", # required for ray >= 2.9.0 (https://github.com/ray-project/ray/releases?page=2) + "pynvml", ] # NOTE: Change the templates/spot-controller.yaml.j2 file if any of the following @@ -93,7 +94,6 @@ def parse_readme(readme: str) -> str: # If you don't want to use these exact versions, you can install runhouse without the aws extras, then # install your desired versions of awscli and boto3 "pycryptodome==3.12.0", - "sshtunnel>=0.3.0", # required for sagemaker ], "azure": ["skypilot[azure]==0.6.0"], "gcp": [ @@ -101,13 +101,6 @@ def parse_readme(readme: str) -> str: "gcsfs", ], "docker": ["docker"], - "sagemaker": [ - "skypilot==0.6.0", - # https://github.com/aws-samples/sagemaker-ssh-helper - "sagemaker_ssh_helper", - "sagemaker", - "paramiko>=3.2.0", - ], "kubernetes": ["skypilot==0.6.0", "kubernetes"], } diff --git a/tests/conftest.py b/tests/conftest.py index 1cf28fd88..f2a06cdb2 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -330,20 +330,6 @@ def event_loop(): unnamed_env, # noqa: F401 ) -# ----------------- Blobs ----------------- - -from tests.test_resources.test_modules.test_blobs.conftest import ( - blob, # noqa: F401 - blob_data, # noqa: F401 - cluster_blob, # noqa: F401 - cluster_file, # noqa: F401 - file, # noqa: F401 - gcs_blob, # noqa: F401 - local_blob, # noqa: F401 - local_file, # noqa: F401 - s3_blob, # noqa: F401 -) - # ----------------- Modules ----------------- # ----------------- Functions ----------------- diff --git a/tests/test_resources/test_clusters/test_cluster.py b/tests/test_resources/test_clusters/test_cluster.py index 14bfede84..fd63d2805 100644 --- a/tests/test_resources/test_clusters/test_cluster.py +++ b/tests/test_resources/test_clusters/test_cluster.py @@ -180,7 +180,7 @@ def test_cluster_endpoint(self, cluster): return endpoint = cluster.endpoint() - if cluster.server_connection_type in ["ssh", "aws_ssm"]: + if cluster.server_connection_type == "ssh": assert cluster.endpoint(external=True) is None assert endpoint == f"http://{LOCALHOST}:{cluster.client_port}" else: @@ -206,18 +206,19 @@ def test_cluster_endpoint(self, cluster): ] # getting the first element because the endpoint returns the status + response to den. 
assert status_data["cluster_config"]["resource_type"] == "cluster" assert status_data["env_servlet_processes"] - assert status_data["server_cpu_utilization"] + assert isinstance(status_data["server_cpu_utilization"], float) assert status_data["server_memory_usage"] assert not status_data.get("server_gpu_usage", None) @pytest.mark.level("local") @pytest.mark.clustertest - def test_cluster_request_timeout(self, cluster): + def test_cluster_request_timeout(self, docker_cluster_pk_ssh_no_auth): + cluster = docker_cluster_pk_ssh_no_auth with pytest.raises(requests.exceptions.ReadTimeout): cluster._http_client.request_json( endpoint="/status", req_type="get", - timeout=0.01, + timeout=0.005, headers=rh.globals.rns_client.request_headers(), ) diff --git a/tests/test_resources/test_envs/test_env.py b/tests/test_resources/test_envs/test_env.py index c1d463b94..605b3ede0 100644 --- a/tests/test_resources/test_envs/test_env.py +++ b/tests/test_resources/test_envs/test_env.py @@ -280,7 +280,7 @@ def test_secrets_env(self, env, cluster): secret = rh.Secret.from_name(secret) if secret.path: - assert rh.file(path=secret.path, system=cluster).exists_in_system() + assert rh.folder(path=secret.path, system=cluster).exists_in_system() else: env_vars = secret.env_vars or secret._DEFAULT_ENV_VARS for _, var in env_vars.items(): diff --git a/tests/test_resources/test_modules/test_blobs/__init__.py b/tests/test_resources/test_modules/test_blobs/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/tests/test_resources/test_modules/test_blobs/conftest.py b/tests/test_resources/test_modules/test_blobs/conftest.py deleted file mode 100644 index 28ce5e5a3..000000000 --- a/tests/test_resources/test_modules/test_blobs/conftest.py +++ /dev/null @@ -1,77 +0,0 @@ -import numpy as np -import pytest - -import runhouse as rh - -from tests.conftest import init_args - - -@pytest.fixture -def blob(request): - """Parametrize over multiple blobs - useful for running the same test on multiple storage types.""" - return request.getfixturevalue(request.param) - - -@pytest.fixture -def file(request): - """Parametrize over multiple files - useful for running the same test on multiple storage types.""" - return request.getfixturevalue(request.param) - - -@pytest.fixture(scope="session") -def blob_data(): - return [np.arange(50), "test", {"a": 1, "b": 2}] - - -@pytest.fixture -def local_file(blob_data, tmp_path): - args = { - "data": blob_data, - "system": "file", - "path": str(tmp_path / "test_blob.pickle"), - } - b = rh.blob(**args) - init_args[id(b)] = args - return b - - -@pytest.fixture -def local_blob(blob_data): - return rh.blob( - data=blob_data, - ) - - -@pytest.fixture -def s3_blob(blob_data, blob_s3_bucket): - return rh.blob( - data=blob_data, - system="s3", - path=f"/{blob_s3_bucket}/test_blob.pickle", - ) - - -@pytest.fixture -def gcs_blob(blob_data, blob_gcs_bucket): - return rh.blob( - data=blob_data, - system="gs", - path=f"/{blob_gcs_bucket}/test_blob.pickle", - ) - - -@pytest.fixture -def cluster_blob(blob_data, ondemand_aws_cluster): - return rh.blob( - data=blob_data, - system=ondemand_aws_cluster, - ) - - -@pytest.fixture -def cluster_file(blob_data, ondemand_aws_cluster): - return rh.blob( - data=blob_data, - system=ondemand_aws_cluster, - path="test_blob.pickle", - ) diff --git a/tests/test_resources/test_modules/test_blobs/test_blob.py b/tests/test_resources/test_modules/test_blobs/test_blob.py deleted file mode 100644 index 6cd13e161..000000000 --- 
a/tests/test_resources/test_modules/test_blobs/test_blob.py +++ /dev/null @@ -1,103 +0,0 @@ -from pathlib import Path - -import pytest - -import runhouse as rh - -from runhouse import Cluster - - -TEMP_LOCAL_FOLDER = Path(__file__).parents[1] / "rh-blobs" - - -def test_save_local_blob_fails(local_blob, blob_data): - with pytest.raises(ValueError): - local_blob.save(name="my_local_blob") - - -@pytest.mark.parametrize( - "blob", - ["local_file", "s3_blob", "gcs_blob"], - indirect=True, -) -def test_reload_blob_with_name(blob): - name = "my_blob" - blob.save(name) - original_system = str(blob.system) - original_data_str = str(blob.fetch()) - - del blob - - reloaded_blob = rh.blob(name=name) - assert str(reloaded_blob.system) == str(original_system) - reloaded_data = reloaded_blob.fetch() - assert reloaded_data[1] == "test" - assert str(reloaded_data) == original_data_str - - # Delete metadata saved locally and / or the database for the blob - reloaded_blob.delete_configs() - - # Delete the blob - reloaded_blob.rm() - assert not reloaded_blob.exists_in_system() - - -@pytest.mark.parametrize( - "blob", ["local_file", "s3_blob", "gcs_blob", "cluster_file"], indirect=True -) -def test_reload_file_with_path(blob): - reloaded_blob = rh.blob(path=blob.path, system=blob.system) - reloaded_data = reloaded_blob.fetch() - assert reloaded_data[1] == "test" - - # Delete the blob - reloaded_blob.rm() - assert not reloaded_blob.exists_in_system() - - -@pytest.mark.parametrize("file", ["local_file", "cluster_file"], indirect=True) -def test_file_to_blob(file, cluster): - local_blob = file.to("here") - assert local_blob.system is None - fetched = local_blob.fetch() - assert fetched[1] == "test" - assert str(fetched) == str(file.fetch()) - - cluster_blob = file.to(cluster) - assert isinstance(cluster_blob.system, Cluster) - fetched = cluster_blob.fetch() - assert fetched[1] == "test" - assert str(fetched) == str(file.fetch()) - - -@pytest.mark.parametrize( - "blob", ["local_blob", "cluster_blob", "local_file"], indirect=True -) -@pytest.mark.parametrize( - "folder", - ["local_folder", "cluster_folder", "s3_folder", "gcs_folder"], - indirect=True, -) -def test_blob_to_file(blob, folder): - new_file = blob.to(system=folder.system, path=folder.path + "/test_blob.pickle") - assert new_file.system == folder.system - assert new_file.path == folder.path + "/test_blob.pickle" - fetched = new_file.fetch() - assert fetched[1] == "test" - assert str(fetched) == str(blob.fetch()) - assert "test_blob.pickle" in folder.ls(full_paths=False) - - -@pytest.mark.skip -def test_sharing_blob(cluster_blob): - pass - # TODO - - -def test_load_shared_blob(local_blob): - my_blob = rh.blob(name="@/shared_blob") - assert my_blob.exists_in_system() - - reloaded_data = my_blob.fetch() - # NOTE: we need to do the deserialization ourselves - assert str(reloaded_data) == str(local_blob.fetch()) diff --git a/tests/test_resources/test_modules/test_functions/test_function.py b/tests/test_resources/test_modules/test_functions/test_function.py index d2b22a606..5199e02f5 100644 --- a/tests/test_resources/test_modules/test_functions/test_function.py +++ b/tests/test_resources/test_modules/test_functions/test_function.py @@ -500,34 +500,6 @@ def test_keep_warm_on_demand_unittest(self, mocker): # Reset the system attribute self.function.system = None - @pytest.mark.level("unit") - def test_keep_warm_unittest_sagemaker(self, mocker): - mock_function = mocker.patch("runhouse.Function.keep_warm") - mock_function.return_value = self.function - - # Create 
a Mock instance for cluster - mock_cluster = mocker.patch("runhouse.SageMakerCluster") - sagemaker_cluster = mock_cluster(name="Sagemaker_cluster") - sagemaker_cluster.autostop_mins.return_value = 3 - - # Set the system attribute - self.function.system = sagemaker_cluster - - # Call the method under test - response_sagemaker = self.function.keep_warm(autostop_mins=3) - - # Assertions - mock_function.assert_called_once_with(autostop_mins=3) - assert ( - response_sagemaker.system.autostop_mins.return_value - == self.function.system.autostop_mins.return_value - ) - assert self.function.system.autostop_mins.return_value == 3 - assert response_sagemaker.system.autostop_mins.return_value == 3 - - # Reset the system attribute - self.function.system = None - @pytest.mark.level("unit") def test_notebook_unittest(self, mocker): mock_function = mocker.patch("runhouse.Function.notebook") diff --git a/tests/test_resources/test_modules/test_module.py b/tests/test_resources/test_modules/test_module.py index d1ac34f62..d8c8306c7 100644 --- a/tests/test_resources/test_modules/test_module.py +++ b/tests/test_resources/test_modules/test_module.py @@ -14,6 +14,7 @@ from runhouse import Package from runhouse.constants import TEST_ORG from runhouse.logger import get_logger +from runhouse.utils import capture_stdout logger = get_logger(__name__) @@ -170,7 +171,7 @@ def load_and_use_readonly_module(mod_name, cpu_count, size=3): results = [] # Capture stdout to check that it's working out = "" - with rh.capture_stdout() as stdout: + with capture_stdout() as stdout: for i, val in enumerate(remote_df.slow_iter()): assert val print(val) @@ -275,7 +276,7 @@ def test_module_from_factory(self, cluster, env): results = [] out = "" - with rh.capture_stdout() as stdout: + with capture_stdout() as stdout: for val in remote_instance.slow_iter(): assert val print(val) @@ -364,7 +365,7 @@ def test_module_from_subclass(self, cluster, env): results = [] # Capture stdout to check that it's working out = "" - with rh.capture_stdout() as stdout: + with capture_stdout() as stdout: for i, val in enumerate(remote_instance.slow_iter()): assert val print(val) @@ -429,7 +430,7 @@ async def test_module_from_subclass_async(self, cluster, env): results = [] # Capture stdout to check that it's working out = "" - with rh.capture_stdout() as stdout: + with capture_stdout() as stdout: async for val in remote_df.slow_iter_async(): assert val print(val) diff --git a/tests/test_resources/test_run.py b/tests/test_resources/test_run.py deleted file mode 100644 index 6951d0ab6..000000000 --- a/tests/test_resources/test_run.py +++ /dev/null @@ -1,448 +0,0 @@ -from pathlib import Path -from pprint import pprint - -import pytest -import runhouse as rh - -CTX_MGR_RUN = "my_run_activity" -CLI_RUN_NAME = "my_cli_run" - -PATH_TO_CTX_MGR_RUN = f"{rh.Run.LOCAL_RUN_PATH}/{CTX_MGR_RUN}" - -RUN_FILES = ( - rh.Run.INPUTS_FILE, - rh.Run.RESULT_FILE, - rh.Run.RUN_CONFIG_FILE, - ".out", - ".err", -) - - -@pytest.fixture(scope="session") -def submitted_run(summer_func): # noqa: F811 - """Initializes a Run, which will run synchronously on the cluster. Returns the function's result.""" - run_name = "synchronous_run" - res = summer_func(1, 2, run_name=run_name) - assert res == 3 - return run_name - - -@pytest.fixture(scope="session") -def submitted_async_run(summer_func): # noqa: F811 - """Execute function async on the cluster. If a run already exists, do not re-run. 
Returns a Run object.""" - run_name = "async_run" - async_run = summer_func.run(run_name=run_name, a=1, b=2) - - assert isinstance(async_run, rh.Run) - return run_name - - -# ------------------------- FUNCTION RUN ---------------------------------- - - -def test_read_prov_info(summer_func): - """Reads the stdout for the Run.""" - remote_res = summer_func.call.remote(a=1, b=2) - assert isinstance(remote_res, rh.Blob) - assert remote_res.name in summer_func.system.keys() - assert remote_res.fetch() == 3 - stdout = remote_res.provenance.stdout() - pprint(stdout) - assert "Calling method call on module summer_func" in stdout - - assert remote_res.provenance.status == rh.RunStatus.COMPLETED - - -def test_get_or_call_from_cache(summer_func): - """Cached version of synchronous run - if already completed return the result, otherwise run and wait for - completion before returning the result.""" - run_name = "my_sync_run" - summer_func.system.delete(run_name) - - run_output = summer_func.get_or_call(run_name, a=1, b=2, load_from_den=False) - assert run_output.fetch() == 3 - assert run_name in summer_func.system.keys() - - run_output = summer_func.get_or_call(run_name, a=10, b=10, load_from_den=False) - assert run_output.fetch() == 3 - - summer_func.system.delete(run_name) - # Asser than an exception is thrown if the wrong args are passed in - with pytest.raises(TypeError): - summer_func.get_or_call(run_name, a=10, b=10, c=10, load_from_den=False) - - -def test_invalid_fn_sync_run(summer_func, ondemand_aws_cluster): - """Test error handling for invalid function Run. The function expects to receive integers but - does not receive any. An error should be thrown via Ray.""" - import ray - - try: - summer_func.get_or_call(run_name="invalid_run") - except (ray.exceptions.RayTaskError, TypeError) as e: - assert ( - str(e.args[0]) - == "summer() missing 2 required positional arguments: 'a' and 'b'" - ) - - -@pytest.mark.skip("Not implemented yet.") -def test_invalid_fn_async_run(summer_func): - """Test error handling for invalid function Run. The function expects to receive integers but - does not receive any. The Run object returned should have a status of `ERROR`, and the - result should be its stderr.""" - run_obj = summer_func.get_or_run(run_name="invalid_async_run") - - assert run_obj.refresh().status == rh.RunStatus.ERROR - assert "summer() missing 2 required positional arguments" in run_obj.result() - - -@pytest.mark.skip("Not implemented yet.") -def test_get_fn_status_updates(ondemand_aws_cluster, slow_func): - """Run a function that takes a long time to run, confirming that its status changes as we refresh the Run""" - async_run = slow_func.run(run_name="my_slow_async_run", a=1, b=2) - - assert isinstance(async_run, rh.Run) - - assert async_run.status == rh.RunStatus.RUNNING - - while async_run.refresh().status != rh.RunStatus.COMPLETED: - # ... 
do something else while we wait for the run to finish - pass - - assert async_run.refresh().status == rh.RunStatus.COMPLETED - - -@pytest.mark.skip("Not implemented yet.") -def test_get_or_call_latest(summer_func): - """Cached version of synchronous run - if already completed return the result, otherwise run and wait for - completion before returning the result.""" - # Note: In this test since we are providing a name of "latest", it should return the latest cached version - run_output = summer_func.get_or_call("latest") - - assert run_output == 3 - - -@pytest.mark.skip("Not implemented yet.") -def test_send_run_to_system_on_completion(summer_func, submitted_async_run): - # Only once the run actually finishes do we send to S3 - async_run = summer_func.run(run_name=submitted_async_run, a=1, b=2).to( - "s3", on_completion=True - ) - - assert isinstance(async_run, rh.Run) - - -@pytest.mark.skip("Not implemented yet.") -def test_run_refresh(slow_func): - async_run = slow_func.get_or_run(run_name="async_get_or_run", a=1, b=2) - - while async_run.refresh().status in [ - rh.RunStatus.RUNNING, - rh.RunStatus.NOT_STARTED, - ]: - # do stuff ..... - pass - - assert async_run.refresh().status == rh.RunStatus.COMPLETED - - -@pytest.mark.skip("Not implemented yet.") -def test_get_async_run_result(summer_func, submitted_async_run): - """Read the results from an async run.""" - async_run = summer_func.get_or_run(run_name=submitted_async_run) - assert isinstance(async_run, rh.Run) - assert async_run.result() == 3 - - -@pytest.mark.skip("Not implemented yet.") -def test_get_or_run_no_cache(summer_func): - """Execute function async on the cluster. If a run already exists, do not re-run. Returns a Run object.""" - # Note: In this test since no Run exists with this name, will trigger the function async on the cluster and in the - # meantime return a Run object. - async_run = summer_func.get_or_run(run_name="new_async_run", a=1, b=2) - assert isinstance(async_run, rh.Run) - - run_result = async_run.result() - assert run_result == 3 - - -@pytest.mark.skip("Not implemented yet.") -def test_get_or_run_latest(summer_func): - """Execute function async on the cluster. If a run already exists, do not re-run. Returns a Run object.""" - # Note: In this test since we are providing "latest", will return the latest cached version. 
- async_run = summer_func.get_or_run(run_name="latest") - assert isinstance(async_run, rh.Run) - - -@pytest.mark.skip("Not implemented yet.") -def test_delete_async_run_from_system(ondemand_aws_cluster, submitted_async_run): - # Load the run from the cluster and delete its dedicated folder - async_run = ondemand_aws_cluster.get_run(submitted_async_run) - async_run.folder.rm() - assert not async_run.folder.exists_in_system() - - -@pytest.mark.skip("Not implemented yet.") -def test_save_fn_run_to_rns(ondemand_aws_cluster, submitted_run): - """Saves run config to RNS""" - # Load run that lives on the cluster - func_run = ondemand_aws_cluster.get_run(submitted_run) - assert func_run - - # Save to RNS - func_run.save(name=submitted_run) - - # Load from RNS - loaded_run = rh.run(submitted_run) - assert rh.exists(loaded_run.name, resource_type=rh.Run.RESOURCE_TYPE) - - -def test_create_anon_run_on_cluster(summer_func): - """Create a new Run without giving it an explicit name.""" - # Note: this will run synchronously and return the result - res = summer_func(1, 2) - assert res == 3 - - -@pytest.mark.skip("Not yet implemented.") -def test_latest_fn_run(summer_func): - run_output = summer_func.get_or_call(run_str="latest") - assert run_output == 3 - - -@pytest.mark.skip("Not implemented yet.") -def test_copy_fn_run_from_cluster_to_local(ondemand_aws_cluster, submitted_run): - my_run = ondemand_aws_cluster.get_run(submitted_run) - my_local_run = my_run.to("here") - assert my_local_run.folder.exists_in_system() - - # Check that all files were copied - folder_contents = my_local_run.folder.ls() - for f in folder_contents: - file_extension = Path(f).suffix - file_name = f.split("/")[-1] - assert file_extension in file_name or file_name in RUN_FILES - - -@pytest.mark.skip("Not implemented yet.") -def test_copy_fn_run_from_system_to_s3( - ondemand_aws_cluster, runs_s3_bucket, submitted_run -): - my_run = ondemand_aws_cluster.get_run(submitted_run) - my_run_on_s3 = my_run.to("s3", path=f"/{runs_s3_bucket}/my_test_run") - - assert my_run_on_s3.folder.exists_in_system() - - # Check that all files were copied - folder_contents = my_run_on_s3.folder.ls() - for f in folder_contents: - file_extension = Path(f).suffix - file_name = f.split("/")[-1] - assert file_extension in file_name or file_name in RUN_FILES - - # Delete the run from s3 - my_run_on_s3.folder.rm() - assert not my_run_on_s3.folder.exists_in_system() - - -@pytest.mark.skip("Not implemented yet.") -def test_delete_fn_run_from_rns(submitted_run): - # Load directly from RNS - loaded_run = rh.run(name=submitted_run) - - loaded_run.delete_configs() - assert not rh.exists(name=loaded_run.name, resource_type=rh.Run.RESOURCE_TYPE) - - -# ------------------------- CLI RUN ------------ ---------------------- - - -@pytest.mark.skip("Run stuff is deprecated.") -def test_create_cli_python_command_run(ondemand_aws_cluster): - # Run python commands on the specified system. Save the run results to the .rh/logs/ folder of the system. 
- return_codes = ondemand_aws_cluster.run_python( - [ - "import runhouse as rh", - "import logging", - "local_blob = rh.file(name='local_blob', data=list(range(50)))", - "logging.info(f'File path: {local_blob.path}')", - "local_blob.rm()", - ], - stream_logs=True, - ) - pprint(return_codes) - - assert return_codes[0][0] == 0, "Failed to run python commands" - assert "File path" in return_codes[0][1].strip() - - -@pytest.mark.skip("Run stuff is deprecated.") -def test_create_cli_command_run(ondemand_aws_cluster): - """Run CLI command on the specified system. - Saves the Run locally to the rh/ folder of the local file system.""" - return_codes = ondemand_aws_cluster.run(["python --version"]) - - assert return_codes[0][0] == 0, "Failed to run CLI command" - assert return_codes[0][1].strip() == "Python 3.10.6" - - -@pytest.mark.skip("Not implemented yet.") -def test_send_cli_run_to_cluster(ondemand_aws_cluster): - """Send the CLI based Run which was initially saved on the local file system to the cpu cluster.""" - # Load the run from the local file system - loaded_run = rh.run( - name=CLI_RUN_NAME, path=f"{rh.Run.LOCAL_RUN_PATH}/{CLI_RUN_NAME}" - ) - assert loaded_run.refresh().status == rh.RunStatus.COMPLETED - assert loaded_run.stdout() == "Python 3.10.6" - - # Save to default path on the cluster (~/.rh/logs/) - cluster_run = loaded_run.to( - ondemand_aws_cluster, path=rh.Run._base_cluster_folder_path(name=CLI_RUN_NAME) - ) - - assert cluster_run.folder.exists_in_system() - assert isinstance(cluster_run.folder.system, rh.Cluster) - - -@pytest.mark.skip("Not implemented yet.") -def test_load_cli_command_run_from_cluster(ondemand_aws_cluster): - # At this point the Run exists locally and on the cluster (hasn't yet been saved to RNS). - # Load from the cluster - cli_run = ondemand_aws_cluster.get_run(CLI_RUN_NAME) - assert isinstance(cli_run, rh.Run) - - -@pytest.mark.skip("Not implemented yet.") -def test_save_cli_run_to_rns(ondemand_aws_cluster): - # Load the run from the cluster - cli_run = ondemand_aws_cluster.get_run(CLI_RUN_NAME) - - # Save to RNS - cli_run.save(name=CLI_RUN_NAME) - - # Confirm Run now lives in RNS - loaded_run = rh.run(CLI_RUN_NAME) - assert loaded_run - - -@pytest.mark.skip("Not implemented yet.") -def test_read_cli_command_stdout_from_cluster(ondemand_aws_cluster): - # Read the stdout from the cluster - cli_run = ondemand_aws_cluster.get_run(CLI_RUN_NAME) - cli_stdout = cli_run.stdout() - assert cli_stdout == "Python 3.10.6" - - -def test_delete_cli_run_from_local_filesystem(): - """Delete the config where it was initially saved (in the local ``rh`` folder of the working directory)""" - # Load the run from the local file system - cli_run = rh.run(CLI_RUN_NAME, system=rh.Folder.DEFAULT_FS) - cli_run.folder.rm() - - assert not cli_run.folder.exists_in_system() - - -@pytest.mark.skip("Not implemented yet.") -def test_delete_cli_run_from_cluster(ondemand_aws_cluster): - """Delete the config where it was copied to (in the ``~/.rh/logs/`` folder of the cluster)""" - cli_run = ondemand_aws_cluster.get_run(CLI_RUN_NAME) - assert cli_run, f"Failed to load run {CLI_RUN_NAME} from cluster" - - # Update the Run's folder to point to the cluster instead of the local file system - cli_run.folder.system = ondemand_aws_cluster - cli_run.folder.path = rh.Run._base_cluster_folder_path(name=CLI_RUN_NAME) - assert cli_run.folder.exists_in_system() - - cli_run.folder.rm() - assert not cli_run.folder.exists_in_system() - - cli_run = ondemand_aws_cluster.get_run(CLI_RUN_NAME) - assert cli_run 
is None, f"Failed to delete {cli_run} on cluster" - - -@pytest.mark.skip("Not implemented yet.") -def test_delete_cli_run_from_rns(): - # Load from RNS - loaded_run = rh.run(CLI_RUN_NAME) - loaded_run.delete_configs() - assert not rh.exists(name=loaded_run.name, resource_type=rh.Run.RESOURCE_TYPE) - - -# ------------------------- CTX MANAGER RUN ---------------------------------- - - -@pytest.mark.skip("Not implemented yet.") -def test_create_local_ctx_manager_run(summer_func, ondemand_aws_cluster): - from runhouse.globals import rns_client - - ctx_mgr_func = "my_ctx_mgr_func" - - with rh.run(path=PATH_TO_CTX_MGR_RUN) as r: - # Add all Runhouse objects loaded or saved in the context manager to the Run's artifact registry - # (upstream + downstream artifacts) - summer_func.save(ctx_mgr_func) - - summer_func(1, 2, run_name="my_new_run") - - current_run = summer_func.system.get_run("my_new_run") - run_res = current_run.result() - print(f"Run result: {run_res}") - - cluster_config = rh.load(name=ondemand_aws_cluster.name, instantiate=False) - cluster = rh.Cluster.from_config(config=cluster_config, dryrun=True) - print(f"Cluster loaded: {cluster.name}") - - summer_func.delete_configs() - - r.save(name=CTX_MGR_RUN) - - print(f"Saved Run with name: {r.name} to path: {r.folder.path}") - - # Artifacts include the rns resolved name (ex: "/jlewitt1/rh-cpu") - assert r.downstream_artifacts == [ - rns_client.resolve_rns_path(ctx_mgr_func), - rns_client.resolve_rns_path(ondemand_aws_cluster.name), - ] - assert r.upstream_artifacts == [ - rns_client.resolve_rns_path(ondemand_aws_cluster.name), - ] - - -def test_load_named_ctx_manager_run(): - # Load from local file system - ctx_run = rh.run(path=PATH_TO_CTX_MGR_RUN) - assert ctx_run.folder.exists_in_system() - - -@pytest.mark.skip("Not implemented yet.") -def test_read_stdout_from_ctx_manager_run(): - # Load from local file system - ctx_run = rh.run(path=PATH_TO_CTX_MGR_RUN) - stdout = ctx_run.stdout() - pprint(stdout) - assert stdout - - -def test_save_ctx_run_to_rns(): - # Load from local file system - ctx_run = rh.run(path=PATH_TO_CTX_MGR_RUN) - ctx_run.save() - assert rh.exists(name=ctx_run.name, resource_type=rh.Run.RESOURCE_TYPE) - - -@pytest.mark.skip("Not implemented yet.") -def test_delete_ctx_run_from_rns(): - # Load from RNS - loaded_run = rh.run(name=CTX_MGR_RUN) - loaded_run.delete_configs() - - assert not rh.exists(name=loaded_run.name, resource_type=rh.Run.RESOURCE_TYPE) - - -def test_delete_ctx_run_from_local_filesystem(): - # Load from local file system - ctx_run = rh.run(path=PATH_TO_CTX_MGR_RUN) - ctx_run.folder.rm() - assert not ctx_run.folder.exists_in_system() diff --git a/tests/test_resources/test_secrets/test_secret.py b/tests/test_resources/test_secrets/test_secret.py index 7333459ba..c533ec75d 100644 --- a/tests/test_resources/test_secrets/test_secret.py +++ b/tests/test_resources/test_secrets/test_secret.py @@ -200,12 +200,12 @@ def test_sync_secrets(self, secret, cluster): secret = secret.write(path=test_path) cluster.sync_secrets([secret]) - remote_file = rh.file(path=secret.path, system=cluster) - assert remote_file.exists_in_system() - assert secret._from_path(remote_file) == secret.values + remote_folder = rh.folder(path=secret.path, system=cluster) + assert remote_folder.exists_in_system() + assert secret._from_path(remote_folder.path) == secret.values assert_delete_local(secret, contents=True) - remote_file.rm() + remote_folder.rm() else: cluster.sync_secrets([secret]) assert cluster.get(secret.name) diff --git 
a/tests/test_servers/test_http_server.py b/tests/test_servers/test_http_server.py index c53824fef..94298442c 100644 --- a/tests/test_servers/test_http_server.py +++ b/tests/test_servers/test_http_server.py @@ -1,14 +1,11 @@ import json import os -import tempfile import uuid -from pathlib import Path import pytest -import runhouse as rh - from runhouse.globals import rns_client +from runhouse.resources.resource import Resource from runhouse.servers.http.http_utils import ( DeleteObjectParams, deserialize_data, @@ -57,9 +54,9 @@ def test_check_server(self, http_client): assert response.status_code == 200 @pytest.mark.level("local") - def test_put_resource(self, http_client, blob_data, cluster): + def test_put_resource(self, http_client, cluster): state = None - resource = rh.blob(data=blob_data, system=cluster) + resource = Resource(name="test_resource3", system=cluster) data = serialize_data( (resource.config(condensed=False), state, resource.dryrun), "pickle" ) @@ -643,9 +640,9 @@ def test_check_server_with_invalid_token(self, http_client): assert response.status_code == 200 @pytest.mark.level("local") - def test_put_resource_with_invalid_token(self, http_client, blob_data, cluster): + def test_put_resource_with_invalid_token(self, http_client, cluster): state = None - resource = rh.blob(blob_data, system=cluster) + resource = Resource(name="test_resource1", system=cluster) data = serialize_data( (resource.config(condensed=False), state, resource.dryrun), "pickle" ) @@ -744,24 +741,20 @@ def test_check_server(self, client): assert response.status_code == 200 @pytest.mark.level("unit") - def test_put_resource(self, client, blob_data, local_cluster): - with tempfile.TemporaryDirectory() as temp_dir: - resource_path = Path(temp_dir, "local-blob") - local_blob = rh.blob(blob_data, path=resource_path) - resource = local_blob.to(system="file", path=resource_path) - - state = None - data = serialize_data( - (resource.config(condensed=False), state, resource.dryrun), "pickle" - ) - response = client.post( - "/resource", - json=PutResourceParams( - serialized_data=data, serialization="pickle" - ).model_dump(), - headers=rns_client.request_headers(local_cluster.rns_address), - ) - assert response.status_code == 200 + def test_put_resource(self, client, local_cluster): + resource = Resource(name="local-resource") + state = None + data = serialize_data( + (resource.config(condensed=False), state, resource.dryrun), "pickle" + ) + response = client.post( + "/resource", + json=PutResourceParams( + serialized_data=data, serialization="pickle" + ).model_dump(), + headers=rns_client.request_headers(local_cluster.rns_address), + ) + assert response.status_code == 200 @pytest.mark.level("unit") def test_put_object(self, client, local_cluster): @@ -849,26 +842,21 @@ def test_check_server_with_invalid_token(self, local_client_with_den_auth): assert response.status_code == 200 @pytest.mark.level("unit") - def test_put_resource_with_invalid_token( - self, local_client_with_den_auth, blob_data - ): - with tempfile.TemporaryDirectory() as temp_dir: - resource_path = Path(temp_dir, "local-blob") - local_blob = rh.blob(blob_data, path=resource_path) - resource = local_blob.to(system="file", path=resource_path) - state = None - data = serialize_data( - (resource.config(condensed=False), state, resource.dryrun), "pickle" - ) - resp = local_client_with_den_auth.post( - "/resource", - json=PutResourceParams( - serialized_data=data, serialization="pickle" - ).model_dump(), - headers=INVALID_HEADERS, - ) + def 
test_put_resource_with_invalid_token(self, local_client_with_den_auth): + resource = Resource(name="test_resource2") + state = None + data = serialize_data( + (resource.config(condensed=False), state, resource.dryrun), "pickle" + ) + resp = local_client_with_den_auth.post( + "/resource", + json=PutResourceParams( + serialized_data=data, serialization="pickle" + ).model_dump(), + headers=INVALID_HEADERS, + ) - assert resp.status_code == 403 + assert resp.status_code == 403 @pytest.mark.level("unit") def test_put_object_with_invalid_token(self, local_client_with_den_auth): diff --git a/tests/test_servers/test_servlet.py b/tests/test_servers/test_servlet.py index bd2dd651b..99d1eed6a 100644 --- a/tests/test_servers/test_servlet.py +++ b/tests/test_servers/test_servlet.py @@ -1,9 +1,6 @@ -import tempfile -from pathlib import Path - import pytest -import runhouse as rh +from runhouse.resources.resource import Resource from runhouse.servers.http.http_utils import deserialize_data, serialize_data from runhouse.servers.obj_store import ObjStore @@ -11,38 +8,31 @@ @pytest.mark.servertest class TestServlet: @pytest.mark.level("unit") - def test_put_resource(self, test_servlet, blob_data): - with tempfile.TemporaryDirectory() as temp_dir: - resource_path = Path(temp_dir, "local-blob") - local_blob = rh.blob(blob_data, path=resource_path) - resource = local_blob.to(system="file", path=resource_path) - - state = {} - resp = ObjStore.call_actor_method( - test_servlet, - "aput_resource_local", - data=serialize_data( - (resource.config(condensed=False), state, resource.dryrun), "pickle" - ), - serialization="pickle", - ) - - assert resp.output_type == "result_serialized" - assert deserialize_data(resp.data, resp.serialization).startswith("file_") + def test_put_resource(self, test_servlet): + resource = Resource(name="local-resource") + state = {} + resp = ObjStore.call_actor_method( + test_servlet, + "aput_resource_local", + data=serialize_data( + (resource.config(condensed=False), state, resource.dryrun), "pickle" + ), + serialization="pickle", + ) + assert resp.output_type == "result_serialized" + assert deserialize_data(resp.data, resp.serialization) == resource.name @pytest.mark.level("unit") - def test_put_obj_local(self, test_servlet, blob_data): - with tempfile.TemporaryDirectory() as temp_dir: - resource_path = Path(temp_dir, "local-blob") - resource = rh.blob(blob_data, path=resource_path) - resp = ObjStore.call_actor_method( - test_servlet, - "aput_local", - key="key1", - data=serialize_data(resource, "pickle"), - serialization="pickle", - ) - assert resp.output_type == "success" + def test_put_obj_local(self, test_servlet): + resource = Resource(name="local-resource") + resp = ObjStore.call_actor_method( + test_servlet, + "aput_local", + key="key1", + data=serialize_data(resource, "pickle"), + serialization="pickle", + ) + assert resp.output_type == "success" @pytest.mark.level("unit") def test_get_obj(self, test_servlet): @@ -55,8 +45,8 @@ def test_get_obj(self, test_servlet): remote=False, ) assert resp.output_type == "result_serialized" - blob = deserialize_data(resp.data, resp.serialization) - assert isinstance(blob, rh.Blob) + resource = deserialize_data(resp.data, resp.serialization) + assert isinstance(resource, Resource) @pytest.mark.level("unit") def test_get_obj_remote(self, test_servlet): @@ -69,8 +59,8 @@ def test_get_obj_remote(self, test_servlet): remote=True, ) assert resp.output_type == "config" - blob_config = deserialize_data(resp.data, resp.serialization) - assert 
isinstance(blob_config, dict) + resource_config = deserialize_data(resp.data, resp.serialization) + assert isinstance(resource_config, dict) @pytest.mark.level("unit") def test_get_obj_does_not_exist(self, test_servlet):