diff --git a/.github/workflows/code_test_and_deploy.yml b/.github/workflows/code_test_and_deploy.yml index 49054c6b..fe1ffdd4 100644 --- a/.github/workflows/code_test_and_deploy.yml +++ b/.github/workflows/code_test_and_deploy.yml @@ -55,8 +55,25 @@ jobs: conda uninstall datashuttle --force python -m pip install --upgrade pip pip install .[dev] - - name: Test - run: pytest +# - name: Test + # run: pytest + # - name: Set up Google Drive secrets + # run: | + # printf '%s' '${{ secrets.GDRIVE_SERVICE_ACCOUNT_JSON }}' > "$HOME/gdrive.json" + # echo "GDRIVE_SERVICE_ACCOUNT_FILE=$HOME/gdrive.json" >> $GITHUB_ENV + # echo "GDRIVE_ROOT_FOLDER_ID=${{ secrets.GDRIVE_ROOT_FOLDER_ID }}" >> $GITHUB_ENV + + # - name: Run Google Drive tests + # run: pytest -q -k test_gdrive_connection + + - name: Set up AWS secrets + run: | + echo "AWS_ACCESS_KEY_ID=${{ secrets.AWS_ACCESS_KEY_ID }}" >> $GITHUB_ENV + echo "AWS_ACCESS_KEY_ID_SECRET=${{ secrets.AWS_ACCESS_KEY_ID_SECRET }}" >> $GITHUB_ENV + echo "AWS_REGION=${{ secrets.AWS_REGION }}" >> $GITHUB_ENV + + - name: Run AWS tests + run: pytest -q -k test_aws_connection build_sdist_wheels: name: Build source distribution diff --git a/datashuttle/configs/aws_regions.py b/datashuttle/configs/aws_regions.py new file mode 100644 index 00000000..84576030 --- /dev/null +++ b/datashuttle/configs/aws_regions.py @@ -0,0 +1,43 @@ +from typing import Dict, List + +# ----------------------------------------------------------------------------- +# AWS regions +# ----------------------------------------------------------------------------- + +# These functions are used for type checking and providing intellisense to the developer + + +def get_aws_regions() -> Dict[str, str]: + aws_regions = { + "US_EAST_1": "us-east-1", + "US_EAST_2": "us-east-2", + "US_WEST_1": "us-west-1", + "US_WEST_2": "us-west-2", + "CA_CENTRAL_1": "ca-central-1", + "EU_WEST_1": "eu-west-1", + "EU_WEST_2": "eu-west-2", + "EU_WEST_3": "eu-west-3", + "EU_NORTH_1": "eu-north-1", + 
"EU_SOUTH_1": "eu-south-1", + "EU_CENTRAL_1": "eu-central-1", + "AP_SOUTHEAST_1": "ap-southeast-1", + "AP_SOUTHEAST_2": "ap-southeast-2", + "AP_NORTHEAST_1": "ap-northeast-1", + "AP_NORTHEAST_2": "ap-northeast-2", + "AP_NORTHEAST_3": "ap-northeast-3", + "AP_SOUTH_1": "ap-south-1", + "AP_EAST_1": "ap-east-1", + "SA_EAST_1": "sa-east-1", + "IL_CENTRAL_1": "il-central-1", + "ME_SOUTH_1": "me-south-1", + "AF_SOUTH_1": "af-south-1", + "CN_NORTH_1": "cn-north-1", + "CN_NORTHWEST_1": "cn-northwest-1", + "US_GOV_EAST_1": "us-gov-east-1", + "US_GOV_WEST_1": "us-gov-west-1", + } + return aws_regions + + +def get_aws_regions_list() -> List[str]: + return list(get_aws_regions().values()) diff --git a/datashuttle/configs/canonical_configs.py b/datashuttle/configs/canonical_configs.py index b65ad6c6..b0d6d8ef 100644 --- a/datashuttle/configs/canonical_configs.py +++ b/datashuttle/configs/canonical_configs.py @@ -39,9 +39,15 @@ def get_canonical_configs() -> dict: canonical_configs = { "local_path": Union[str, Path], "central_path": Optional[Union[str, Path]], - "connection_method": Optional[Literal["ssh", "local_filesystem"]], + "connection_method": Optional[ + Literal["ssh", "local_filesystem", "gdrive", "aws"] + ], "central_host_id": Optional[str], "central_host_username": Optional[str], + "gdrive_client_id": Optional[str], + "gdrive_root_folder_id": Optional[str], + "aws_access_key_id": Optional[str], + "aws_region": Optional[str], } return canonical_configs @@ -128,6 +134,29 @@ def check_dict_values_raise_on_fail(config_dict: Configs) -> None: ConfigError, ) + # Check gdrive settings + elif config_dict["connection_method"] == "gdrive": + if not config_dict["gdrive_root_folder_id"]: + utils.log_and_raise_error( + "'gdrive_root_folder_id' is required if 'connection_method' " + "is 'gdrive'.", + ConfigError, + ) + + if not config_dict["gdrive_client_id"]: + utils.log_and_message( + "`gdrive_client_id` not found in config. default rlcone client will be used (slower)." 
+ ) + + # Check AWS settings + elif config_dict["connection_method"] == "aws" and ( + not config_dict["aws_access_key_id"] or not config_dict["aws_region"] + ): + utils.log_and_raise_error( + "Both aws_access_key_id and aws_region must be present for AWS connection.", + ConfigError, + ) + # Initialise the local project folder utils.print_message_to_user( f"Making project folder at: {config_dict['local_path']}" diff --git a/datashuttle/configs/config_class.py b/datashuttle/configs/config_class.py index 562c3310..43457ac4 100644 --- a/datashuttle/configs/config_class.py +++ b/datashuttle/configs/config_class.py @@ -116,7 +116,10 @@ def dump_to_file(self) -> None: def load_from_file(self) -> None: """ Load a config dict saved at .yaml file. Note this will - not automatically check the configs are valid, this + do a minimal backwards compatibility check and add config + keys to ensure backwards compatibility with new connection + methods added to Datashuttle. + But this will not automatically check the configs are valid, this requires calling self.check_dict_values_raise_on_fail() """ with open(self.file_path, "r") as config_file: @@ -124,8 +127,35 @@ def load_from_file(self) -> None: load_configs.convert_str_and_pathlib_paths(config_dict, "str_to_path") + self.update_config_for_backward_compatability_if_required(config_dict) + self.data = config_dict + def update_config_for_backward_compatability_if_required( + self, config_dict: Dict + ): + canonical_config_keys_to_add = [ + "gdrive_client_id", + "gdrive_root_folder_id", + "aws_access_key_id", + "aws_region", + ] + + # All keys shall be missing for a backwards compatibility update + if not ( + all( + key in config_dict.keys() + for key in canonical_config_keys_to_add + ) + ): + assert not any( + key in config_dict.keys() + for key in canonical_config_keys_to_add + ) + + for key in canonical_config_keys_to_add: + config_dict[key] = None + # ------------------------------------------------------------------------- # 
Utils # ------------------------------------------------------------------------- diff --git a/datashuttle/datashuttle_class.py b/datashuttle/datashuttle_class.py index 9322323f..6e5fb742 100644 --- a/datashuttle/datashuttle_class.py +++ b/datashuttle/datashuttle_class.py @@ -36,9 +36,11 @@ from datashuttle.configs.config_class import Configs from datashuttle.datashuttle_functions import _format_top_level_folder from datashuttle.utils import ( + aws, ds_logger, folders, formatting, + gdrive, getters, rclone, ssh, @@ -53,6 +55,7 @@ from datashuttle.utils.decorators import ( # noqa check_configs_set, check_is_not_local_project, + requires_aws_configs, requires_ssh_configs, ) @@ -892,6 +895,90 @@ def write_public_key(self, filepath: str) -> None: public.write(key.get_base64()) public.close() + # ------------------------------------------------------------------------- + # Google Drive + # ------------------------------------------------------------------------- + + @check_configs_set + def setup_google_drive_connection(self) -> None: + """ + Setup a connection to Google Drive using the provided credentials. + Assumes `gdrive_root_folder_id` is set in configs. + + First, the user will be prompted to enter their Google Drive client + secret if `gdrive_client_id` is set in the configs. + + Next, the user will be asked if their machine has access to a browser. + If not, they will be prompted to input a config_token after running an + rclone command displayed to the user on a machine with access to a browser. + + Next, with the provided credentials, the final setup will be done. This + opens up a browser if the user confirmed access to a browser. 
+ """ + self._start_log( + "setup-google-drive-connection-to-central-server", + local_vars=locals(), + ) + + if self.cfg["gdrive_client_id"]: + gdrive_client_secret = gdrive.get_client_secret() + else: + gdrive_client_secret = None + + browser_available = gdrive.ask_user_for_browser(log=True) + + if not browser_available: + service_account_filepath = utils.get_user_input( + "Please input the path to your credentials json" + ) # TODO: add more explanation + # config_token = gdrive.prompt_and_get_config_token( + # self.cfg, + # gdrive_client_secret, + # self.cfg.get_rclone_config_name("gdrive"), + # log=True, + # ) + else: + service_account_filepath = None + + self._setup_rclone_gdrive_config( + gdrive_client_secret, service_account_filepath, log=True + ) + + rclone.check_successful_connection_and_raise_error_on_fail(self.cfg) + + ds_logger.close_log_filehandler() + + # ------------------------------------------------------------------------- + # AWS S3 + # ------------------------------------------------------------------------- + + @requires_aws_configs + @check_configs_set + def setup_aws_connection(self) -> None: + """ + Setup a connection to AWS S3 buckets using the provided credentials. + Assumes `aws_access_key_id` and `aws_region` are set in configs. + + First, the user will be prompted to input their AWS secret access key. + + Next, with the provided credentials, the final connection setup will be done. 
+ """ + self._start_log( + "setup-aws-connection-to-central-server", + local_vars=locals(), + ) + + aws_secret_access_key = aws.get_aws_secret_access_key() + + self._setup_rclone_aws_config(aws_secret_access_key, log=True) + + rclone.check_successful_connection_and_raise_error_on_fail(self.cfg) + aws.raise_if_bucket_absent(self.cfg) + + utils.log_and_message("AWS Connection Successful.") + + ds_logger.close_log_filehandler() + # ------------------------------------------------------------------------- # Configs # ------------------------------------------------------------------------- @@ -903,6 +990,10 @@ def make_config_file( connection_method: str | None = None, central_host_id: Optional[str] = None, central_host_username: Optional[str] = None, + gdrive_client_id: Optional[str] = None, + gdrive_root_folder_id: Optional[str] = None, + aws_access_key_id: Optional[str] = None, + aws_region: Optional[str] = None, ) -> None: """ Initialise the configurations for datashuttle to use on the @@ -967,6 +1058,10 @@ def make_config_file( "connection_method": connection_method, "central_host_id": central_host_id, "central_host_username": central_host_username, + "gdrive_client_id": gdrive_client_id, + "gdrive_root_folder_id": gdrive_root_folder_id, + "aws_access_key_id": aws_access_key_id, + "aws_region": aws_region, }, ) @@ -1470,6 +1565,30 @@ def _setup_rclone_central_local_filesystem_config(self) -> None: self.cfg.get_rclone_config_name("local_filesystem"), ) + def _setup_rclone_gdrive_config( + self, + gdrive_client_secret: str | None, + service_account_filepath: str | None, + log: bool, + ) -> None: + rclone.setup_rclone_config_for_gdrive( + self.cfg, + self.cfg.get_rclone_config_name("gdrive"), + gdrive_client_secret, + service_account_filepath, + log=log, + ) + + def _setup_rclone_aws_config( + self, aws_secret_access_key: str, log: bool + ) -> None: + rclone.setup_rclone_config_for_aws( + self.cfg, + self.cfg.get_rclone_config_name("aws"), + aws_secret_access_key, + 
log=log, + ) + # Persistent settings # ------------------------------------------------------------------------- diff --git a/datashuttle/tui/css/tui_menu.tcss b/datashuttle/tui/css/tui_menu.tcss index 00cb9458..e32a7b55 100644 --- a/datashuttle/tui/css/tui_menu.tcss +++ b/datashuttle/tui/css/tui_menu.tcss @@ -64,6 +64,12 @@ SettingsScreen { GetHelpScreen { align: center middle; } +SetupGdriveScreen { + align: center middle; +} +SetupAwsScreen { + align: center middle; +} #get_help_label { align: center middle; text-align: center; @@ -114,6 +120,69 @@ MessageBox:light > #messagebox_top_container { align: center middle; } +#setup_gdrive_screen_container { + height: 75%; + width: 80%; + align: center middle; + border: thick $panel-lighten-1; +} + +#gdrive_setup_messagebox_message_container { + height: 70%; + align: center middle; + overflow: hidden auto; + margin: 0 1; +} + +#gdrive_setup_messagebox_message { + text-align: center; + padding: 0 2; +} + +#setup_gdrive_ok_button { + align: center bottom; + height: 3; +} + +#setup_gdrive_cancel_button { + align: center bottom; +} + +#setup_gdrive_buttons_horizontal { + align: center middle; +} + +#setup_aws_screen_container { + height: 75%; + width: 80%; + align: center middle; + border: thick $panel-lighten-1; +} + +#setup_aws_messagebox_message_container { + align: center middle; + overflow: hidden auto; + margin: 0 1; +} + +#setup_aws_messagebox_message { + text-align: center; + padding: 0 2; +} + +#setup_aws_ok_button { + align: center bottom; + height: 3; +} + +#setup_aws_cancel_button { + align: center bottom; +} + +#setup_aws_buttons_horizontal { + align: center middle; +} + /* Configs Content ----------------------------------------------------------------- */ #configs_container { @@ -161,11 +230,15 @@ MessageBox:light > #messagebox_top_container { padding: 0 4 0 4; width: 26; color: $success; /* unsure about this */ + dock: left; } -#configs_setup_ssh_connection_button { +#setup_buttons_container > Button { 
margin: 2 1 0 0; + dock: left; } + + #configs_go_to_project_screen_button { margin: 2 1 0 0; } @@ -204,6 +277,10 @@ MessageBox:light > #messagebox_top_container { padding: 0 0 2 0; } +#configs_aws_region_select { + width: 70%; +} + /* This Horizontal wrapper container is necessary to make the config label and button align center */ diff --git a/datashuttle/tui/interface.py b/datashuttle/tui/interface.py index e9520bb0..c62e1d8f 100644 --- a/datashuttle/tui/interface.py +++ b/datashuttle/tui/interface.py @@ -11,7 +11,7 @@ from datashuttle import DataShuttle from datashuttle.configs import load_configs -from datashuttle.utils import ssh +from datashuttle.utils import aws, gdrive, rclone, ssh class Interface: @@ -493,3 +493,54 @@ def setup_key_pair_and_rclone_config( except BaseException as e: return False, str(e) + + # Setup Google Drive + # ---------------------------------------------------------------------------------- + + def setup_google_drive_connection( + self, + gdrive_client_secret: Optional[str] = None, + config_token: Optional[str] = None, + ) -> InterfaceOutput: + try: + self.project._setup_rclone_gdrive_config( + gdrive_client_secret, config_token, log=False + ) + rclone.check_successful_connection_and_raise_error_on_fail( + self.project.cfg + ) + return True, None + except BaseException as e: + return False, str(e) + + def get_rclone_message_for_gdrive_without_browser( + self, gdrive_client_secret: Optional[str] = None + ) -> InterfaceOutput: + try: + output = gdrive.preliminary_for_setup_without_browser( + self.project.cfg, + gdrive_client_secret, + self.project.cfg.get_rclone_config_name("gdrive"), + log=False, + ) + return True, output + except BaseException as e: + return False, str(e) + + # Setup AWS + # ---------------------------------------------------------------------------------- + + def setup_aws_connection( + self, aws_secret_access_key: str + ) -> InterfaceOutput: + try: + self.project._setup_rclone_aws_config( + aws_secret_access_key, 
log=False + ) + rclone.check_successful_connection_and_raise_error_on_fail( + self.project.cfg + ) + aws.raise_if_bucket_absent(self.project.cfg) + return True, None + except BaseException as e: + return False, str(e) diff --git a/datashuttle/tui/screens/setup_aws.py b/datashuttle/tui/screens/setup_aws.py new file mode 100644 index 00000000..8f429635 --- /dev/null +++ b/datashuttle/tui/screens/setup_aws.py @@ -0,0 +1,101 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from textual.app import ComposeResult + + from datashuttle.tui.interface import Interface + +from textual.containers import Container, Horizontal +from textual.screen import ModalScreen +from textual.widgets import Button, Input, Static + + +class SetupAwsScreen(ModalScreen): + """ + This dialog window handles the TUI equivalent of API's + `setup_aws_connection()`. This asks the user for confirmation + to proceed with the setup, and then prompts the user for the AWS Secret Access Key. + The secret access key is then used to setup rclone config for AWS S3. + Works similar to `SetupSshScreen`. + """ + + def __init__(self, interface: Interface) -> None: + super(SetupAwsScreen, self).__init__() + + self.interface = interface + self.stage = 0 + + def compose(self) -> ComposeResult: + yield Container( + Horizontal( + Static( + "Ready to setup AWS connection. 
Press OK to proceed", + id="setup_aws_messagebox_message", + ), + id="setup_aws_messagebox_message_container", + ), + Input(password=True, id="setup_aws_secret_access_key_input"), + Horizontal( + Button("OK", id="setup_aws_ok_button"), + Button("Cancel", id="setup_aws_cancel_button"), + id="setup_aws_buttons_horizontal", + ), + id="setup_aws_screen_container", + ) + + def on_mount(self) -> None: + self.query_one("#setup_aws_secret_access_key_input").visible = False + + def on_button_pressed(self, event: Button.Pressed) -> None: + """ """ + if event.button.id == "setup_aws_cancel_button": + self.dismiss() + + if event.button.id == "setup_aws_ok_button": + if self.stage == 0: + self.prompt_user_for_aws_secret_access_key() + + elif self.stage == 1: + self.use_secret_access_key_to_setup_aws_connection() + + elif self.stage == 2: + self.dismiss() + + def prompt_user_for_aws_secret_access_key(self) -> None: + message = "Please Enter your AWS Secret Access Key" + + self.query_one("#setup_aws_messagebox_message").update(message) + self.query_one("#setup_aws_secret_access_key_input").visible = True + + self.stage += 1 + + def use_secret_access_key_to_setup_aws_connection(self) -> None: + secret_access_key = self.query_one( + "#setup_aws_secret_access_key_input" + ).value + + success, output = self.interface.setup_aws_connection( + secret_access_key + ) + + if success: + message = "AWS Connection Successful!" + self.query_one("#setup_aws_secret_access_key_input").visible = ( + False + ) + + else: + message = ( + f"AWS setup failed. 
Please check your configs and secret access key" + f"\n\n Traceback: {output}" + ) + self.query_one("#setup_aws_secret_access_key_input").disabled = ( + True + ) + + self.query_one("#setup_aws_ok_button").label = "Finish" + self.query_one("#setup_aws_messagebox_message").update(message) + self.query_one("#setup_aws_cancel_button").disabled = True + self.stage += 1 diff --git a/datashuttle/tui/screens/setup_gdrive.py b/datashuttle/tui/screens/setup_gdrive.py new file mode 100644 index 00000000..cc11035f --- /dev/null +++ b/datashuttle/tui/screens/setup_gdrive.py @@ -0,0 +1,288 @@ +from __future__ import annotations + +import asyncio +from typing import TYPE_CHECKING, Optional + +if TYPE_CHECKING: + from textual.app import ComposeResult + from textual.worker import Worker + + from datashuttle.tui.interface import Interface + from datashuttle.utils.custom_types import InterfaceOutput + +from textual import work +from textual.containers import Container, Horizontal +from textual.screen import ModalScreen +from textual.widgets import ( + Button, + Input, + Static, +) + + +class SetupGdriveScreen(ModalScreen): + """ + This dialog window handles the TUI equivalent of API's `setup_google_drive_connection`. + If the config contains a "gdrive_client_id", the user is prompted to enter a client secret. + Then, the user is asked if their machine has access to a browser. If yes, a + google drive authentication page on the user's default browser opens up. Else, + the user is asked to run an rclone command on a machine with access to a browser and + input the config token generated by rclone. 
+ """ + + def __init__(self, interface: Interface) -> None: + super(SetupGdriveScreen, self).__init__() + + self.interface = interface + self.stage: float = 0 + self.setup_worker: Worker | None = None + self.gdrive_client_secret: Optional[str] = None + + self.input_box: Input = Input( + id="setup_gdrive_generic_input_box", + placeholder="Enter value here", + ) + + def compose(self) -> ComposeResult: + yield Container( + Horizontal( + Static( + "Ready to setup Google Drive. " "Press OK to proceed", + id="gdrive_setup_messagebox_message", + ), + id="gdrive_setup_messagebox_message_container", + ), + Horizontal( + Button("OK", id="setup_gdrive_ok_button"), + Button("Cancel", id="setup_gdrive_cancel_button"), + id="setup_gdrive_buttons_horizontal", + ), + id="setup_gdrive_screen_container", + ) + + def on_button_pressed(self, event: Button.Pressed) -> None: + """ + This dialog window operates using 6 buttons. + + 1) "ok" button : Starts the connection setup process. It has an intermediate + step that asks the user for client secret if "gdrive_client_id" is present + in configs. And then proceeds to ask the user for browser availability. + This is done via a `stage` variable. If asking the user for client secret, + the stage is incremented by 0.5 on two steps. Else, `ask_user_for_browser` + increments the stage directly by 1. + + 2) "yes" button : A "yes" answer to the availability of browser question. On click, + proceeds to a browser authentication. + + 3) "no" button : A "no" answer to the availability of browser question. On click, + prompts the user to enter a config token by running an rclone command. + + 4) "enter" button : To enter the config token generated by rclone. + + 5) "finish" button : To finish the setup. + + 6) "cancel" button : To cancel the setup at any step before completion. 
+ """ + if ( + event.button.id == "setup_gdrive_cancel_button" + or event.button.id == "setup_gdrive_finish_button" + ): + if self.setup_worker and self.setup_worker.is_running: + self.setup_worker.cancel() # fix + self.dismiss() + + elif event.button.id == "setup_gdrive_ok_button": + if self.stage == 0: + if self.interface.project.cfg["gdrive_client_id"]: + self.ask_user_for_gdrive_client_secret() + else: + self.ask_user_for_browser() + + elif self.stage == 0.5: + self.gdrive_client_secret = ( + self.input_box.value.strip() + if self.input_box.value.strip() + else None + ) + self.ask_user_for_browser() + + elif event.button.id == "setup_gdrive_yes_button": + self.open_browser_and_setup_gdrive_connection() + + elif event.button.id == "setup_gdrive_no_button": + self.prompt_user_for_config_token() + + elif event.button.id == "setup_gdrive_enter_button": + self.setup_gdrive_connection_using_config_token() + + def ask_user_for_gdrive_client_secret(self) -> None: + """ + Asks the user for google drive client secret. Only called if + the datashuttle config has a gdrive_client_id. + """ + message = ( + "Please provide the client secret for Google Drive. " + "You can find it in your Google Cloud Console." + ) + self.update_message_box_message(message) + + ok_button = self.query_one("#setup_gdrive_ok_button") + ok_button.label = "Enter" + + self.mount_input_box_before_buttons() + + self.stage += 0.5 + + def ask_user_for_browser(self) -> None: + """ + Asks the user if their machine has access to a browser. + """ + message = ( + "Are you running Datashuttle on a machine " + "that can open a web browser?" 
+ ) + self.update_message_box_message(message) + + self.query_one("#setup_gdrive_ok_button").remove() + + # Remove the input box if it was mounted previously + if self.input_box.is_mounted: + self.input_box.remove() + + # Mount the Yes and No buttons + yes_button = Button("Yes", id="setup_gdrive_yes_button") + no_button = Button("No", id="setup_gdrive_no_button") + + # Mount a cancel button + self.query_one("#setup_gdrive_buttons_horizontal").mount( + yes_button, no_button, before="#setup_gdrive_cancel_button" + ) + + self.stage += 0.5 if self.stage == 0.5 else 1 + + def open_browser_and_setup_gdrive_connection(self) -> None: + """ + This removes the "yes" and "no" buttons that were asked during the + browser question and starts an asyncio task that sets up google drive + connection and updates the UI with success/failure. The connection setup + is asynchronous so that the user is able to cancel the setup if anything + goes wrong without quitting datashuttle altogether. + """ + message = "Please authenticate through browser." + self.update_message_box_message(message) + + # Remove the Yes and No buttons + self.query_one("#setup_gdrive_yes_button").remove() + self.query_one("#setup_gdrive_no_button").remove() + + asyncio.create_task(self.setup_gdrive_connection_and_update_ui()) + + def prompt_user_for_config_token(self) -> None: + + self.query_one("#setup_gdrive_yes_button").remove() + self.query_one("#setup_gdrive_no_button").remove() + + success, message = ( + self.interface.get_rclone_message_for_gdrive_without_browser( + self.gdrive_client_secret + ) + ) + + if not success: + self.display_failed(message) + return + + self.update_message_box_message( + message + "\nPress shift+click to copy." 
+ ) + + enter_button = Button("Enter", id="setup_gdrive_enter_button") + self.input_box.value = "" + + self.query_one("#setup_gdrive_buttons_horizontal").mount( + enter_button, before="#setup_gdrive_cancel_button" + ) + self.mount_input_box_before_buttons() + + def setup_gdrive_connection_using_config_token(self) -> None: + """ + Disables the enter button and starts an asyncio task to setup + google drive connection and update the UI with success/failure. + """ + + self.input_box.disabled = True + + enter_button = self.query_one("#setup_gdrive_enter_button") + enter_button.disabled = True + + config_token = self.input_box.value.strip() + + asyncio.create_task( + self.setup_gdrive_connection_and_update_ui(config_token) + ) + + async def setup_gdrive_connection_and_update_ui( + self, config_token: Optional[str] = None + ) -> None: + """ + This starts the google drive connection setup in a separate thread (required + to have asynchronous processing) and awaits for its completion. After completion, + it displays a success / failure screen. + """ + worker = self.setup_gdrive_connection(config_token) + self.setup_worker = worker + if worker.is_running: + await worker.wait() + + if config_token: + enter_button = self.query_one("#setup_gdrive_enter_button") + enter_button.disabled = True + + success, output = worker.result + if success: + self.show_finish_screen() + else: + self.display_failed(output) + + @work(exclusive=True, thread=True) + def setup_gdrive_connection( + self, config_token: Optional[str] = None + ) -> Worker[InterfaceOutput]: + """ + This function runs in a worker thread to setup google drive connection. + If the user had access to a browser, the underlying rclone commands called + by this function are responsible for opening google's auth page to authenticate + with google drive. 
+ """ + success, output = self.interface.setup_google_drive_connection( + self.gdrive_client_secret, config_token + ) + return success, output + + # ---------------------------------------------------------------------------------- + # UI Update Methods + # ---------------------------------------------------------------------------------- + + def show_finish_screen(self) -> None: + message = "Setup Complete!" + self.query_one("#setup_gdrive_cancel_button").remove() + + self.update_message_box_message(message) + self.query_one("#setup_gdrive_buttons_horizontal").mount( + Button("Finish", id="setup_gdrive_finish_button") + ) + + def display_failed(self, output) -> None: + message = ( + f"Google Drive setup failed. Please check your configs and client secret" + f"\n\n Traceback: {output}" + ) + self.update_message_box_message(message) + + def update_message_box_message(self, message: str) -> None: + self.query_one("#gdrive_setup_messagebox_message").update(message) + + def mount_input_box_before_buttons(self) -> None: + self.query_one("#setup_gdrive_screen_container").mount( + self.input_box, before="#setup_gdrive_buttons_horizontal" + ) diff --git a/datashuttle/tui/shared/configs_content.py b/datashuttle/tui/shared/configs_content.py index 974ee08a..8ea5ea8a 100644 --- a/datashuttle/tui/shared/configs_content.py +++ b/datashuttle/tui/shared/configs_content.py @@ -20,12 +20,19 @@ Label, RadioButton, RadioSet, + Select, Static, ) +from datashuttle.configs.aws_regions import get_aws_regions_list from datashuttle.tui.custom_widgets import ClickableInput from datashuttle.tui.interface import Interface -from datashuttle.tui.screens import modal_dialogs, setup_ssh +from datashuttle.tui.screens import ( + modal_dialogs, + setup_aws, + setup_gdrive, + setup_ssh, +) from datashuttle.tui.tooltips import get_tooltip @@ -90,6 +97,35 @@ def compose(self) -> ComposeResult: ), ] + self.config_gdrive_widgets = [ + Label("Client ID", id="configs_gdrive_client_id_label"), + 
ClickableInput( + self.parent_class.mainwindow, + placeholder="Google Drive Client ID", + id="configs_gdrive_client_id_input", + ), + Label("Root Folder ID", id="configs_gdrive_root_folder_id_label"), + ClickableInput( + self.parent_class.mainwindow, + placeholder="Google Drive Root Folder ID", + id="configs_gdrive_root_folder_id", + ), + ] + + self.config_aws_widgets = [ + Label("AWS Access Key ID", id="configs_aws_access_key_id_label"), + ClickableInput( + self.parent_class.mainwindow, + placeholder="AWS Access Key ID eg. EJIBCLSIP2K2PQK3CDON", + id="configs_aws_access_key_id_input", + ), + Label("AWS S3 Region", id="configs_aws_region_label"), + Select( + ((region, region) for region in get_aws_regions_list()), + id="configs_aws_region_select", + ), + ] + config_screen_widgets = [ Label("Local Path", id="configs_local_path_label"), Horizontal( @@ -103,18 +139,32 @@ def compose(self) -> ComposeResult: ), Label("Connection Method", id="configs_connect_method_label"), RadioSet( + RadioButton( + "No connection (local only)", + id="configs_local_only_radiobutton", + ), RadioButton( "Local Filesystem", - id="configs_local_filesystem_radiobutton", + id=self.radiobutton_id_from_connection_method( + "local_filesystem" + ), ), - RadioButton("SSH", id="configs_ssh_radiobutton"), RadioButton( - "No connection (local only)", - id="configs_local_only_radiobutton", + "SSH", id=self.radiobutton_id_from_connection_method("ssh") + ), + RadioButton( + "Google Drive", + id=self.radiobutton_id_from_connection_method("gdrive"), + ), + RadioButton( + "AWS S3", + id=self.radiobutton_id_from_connection_method("aws"), ), id="configs_connect_method_radioset", ), *self.config_ssh_widgets, + *self.config_gdrive_widgets, + *self.config_aws_widgets, Label("Central Path", id="configs_central_path_label"), Horizontal( ClickableInput( @@ -127,9 +177,12 @@ def compose(self) -> ComposeResult: ), Horizontal( Button("Save", id="configs_save_configs_button"), - Button( - "Setup SSH Connection", - 
id="configs_setup_ssh_connection_button", + Horizontal( + Button( + "Setup Button", + id="configs_setup_connection_button", + ), + id="setup_buttons_container", ), # Below button is always hidden when accessing # configs from project manager screen @@ -184,14 +237,16 @@ def on_mount(self) -> None: self.query_one("#configs_go_to_project_screen_button").visible = False if self.interface: self.fill_widgets_with_project_configs() + self.setup_widgets_to_display( + connection_method=self.interface.get_configs()[ + "connection_method" + ] + ) else: self.query_one("#configs_local_filesystem_radiobutton").value = ( True ) - self.switch_ssh_widgets_display(display_ssh=False) - self.query_one("#configs_setup_ssh_connection_button").visible = ( - False - ) + self.setup_widgets_to_display(connection_method="local_filesystem") # Setup tooltips if not self.interface: @@ -218,6 +273,8 @@ def on_mount(self) -> None: "#configs_local_only_radiobutton", "#configs_central_host_username_input", "#configs_central_host_id_input", + "#configs_gdrive_client_id_input", + "#configs_gdrive_root_folder_id", ]: self.query_one(id).tooltip = get_tooltip(id) @@ -233,29 +290,63 @@ def on_radio_set_changed(self, event: RadioSet.Changed) -> None: disabled. """ label = str(event.pressed.label) + radiobutton_id = event.pressed.id + + connection_method = self.connection_method_from_radiobutton_id( + radiobutton_id + ) + assert label in [ "SSH", "Local Filesystem", "No connection (local only)", + "Google Drive", + "AWS S3", ], "Unexpected label." 
- if label == "No connection (local only)": - self.query_one("#configs_central_path_input").value = "" - self.query_one("#configs_central_path_input").disabled = True - self.query_one("#configs_central_path_select_button").disabled = ( - True - ) - display_ssh = False - else: - self.query_one("#configs_central_path_input").disabled = False - self.query_one("#configs_central_path_select_button").disabled = ( - False - ) - display_ssh = True if label == "SSH" else False + connection_method = self.connection_method_from_radiobutton_id( + radiobutton_id + ) + display_ssh = ( + True if connection_method == "ssh" else False + ) # temporarily, for tooltips + + if self.interface: + self.fill_inputs_with_project_configs() + + self.setup_widgets_to_display(connection_method) - self.switch_ssh_widgets_display(display_ssh) self.set_central_path_input_tooltip(display_ssh) + def radiobutton_id_from_connection_method( + self, connection_method: str + ) -> str: + return f"configs_{connection_method}_radiobutton" + + def connection_method_from_radiobutton_id( + self, radiobutton_id: str + ) -> str | None: + """ + Get the connection method from the radiobutton id. 
+ """ + assert radiobutton_id.startswith("configs_") + assert radiobutton_id.endswith("_radiobutton") + + connection_string = radiobutton_id[ + len("configs_") : -len("_radiobutton") + ] + return ( + connection_string + if connection_string + in [ + "ssh", + "gdrive", + "aws", + "local_filesystem", + ] + else None + ) + def set_central_path_input_tooltip(self, display_ssh: bool) -> None: """ Use a different tooltip depending on whether connection method @@ -305,19 +396,6 @@ def switch_ssh_widgets_display(self, display_ssh: bool) -> None: for widget in self.config_ssh_widgets: widget.display = display_ssh - self.query_one("#configs_central_path_select_button").display = ( - not display_ssh - ) - - if self.interface is None: - self.query_one("#configs_setup_ssh_connection_button").visible = ( - False - ) - else: - self.query_one("#configs_setup_ssh_connection_button").visible = ( - display_ssh - ) - if not self.query_one("#configs_central_path_input").value: if display_ssh: placeholder = f"e.g. 
{self.get_platform_dependent_example_paths('central', ssh=True)}" @@ -327,6 +405,14 @@ def switch_ssh_widgets_display(self, display_ssh: bool) -> None: placeholder ) + def switch_gdrive_widgets_display(self, display_gdrive: bool) -> None: + for widget in self.config_gdrive_widgets: + widget.display = display_gdrive + + def switch_aws_widgets_display(self, display_aws: bool) -> None: + for widget in self.config_aws_widgets: + widget.display = display_aws + def on_button_pressed(self, event: Button.Pressed) -> None: """ Enables the Create Folders button to read out current input values @@ -338,8 +424,21 @@ def on_button_pressed(self, event: Button.Pressed) -> None: else: self.setup_configs_for_an_existing_project() - elif event.button.id == "configs_setup_ssh_connection_button": - self.setup_ssh_connection() + elif event.button.id == "configs_setup_connection_button": + assert ( + self.interface is not None + ), "type narrow flexible `interface`" + + connection_method = self.interface.get_configs()[ + "connection_method" + ] + + if connection_method == "ssh": + self.setup_ssh_connection() + elif connection_method == "gdrive": + self.setup_gdrive_connection() + elif connection_method == "aws": + self.setup_aws_connection() elif event.button.id == "configs_go_to_project_screen_button": self.parent_class.dismiss(self.interface) @@ -409,6 +508,37 @@ def setup_ssh_connection(self) -> None: setup_ssh.SetupSshScreen(self.interface) ) + def setup_gdrive_connection(self) -> None: + """ + Set up the `SetupGdriveScreen` screen, + """ + assert self.interface is not None, "type narrow flexible `interface`" + + if not self.widget_configs_match_saved_configs(): + self.parent_class.mainwindow.show_modal_error_dialog( + "The values set above must equal the datashuttle settings. " + "Either press 'Save' or reload this page." 
+ ) + return + + self.parent_class.mainwindow.push_screen( + setup_gdrive.SetupGdriveScreen(self.interface) + ) + + def setup_aws_connection(self) -> None: + assert self.interface is not None, "type narrow flexible `interface`" + + if not self.widget_configs_match_saved_configs(): + self.parent_class.mainwindow.show_modal_error_dialog( + "The values set above must equal the datashuttle settings. " + "Either press 'Save' or reload this page." + ) + return + + self.parent_class.mainwindow.push_screen( + setup_aws.SetupAwsScreen(self.interface) + ) + def widget_configs_match_saved_configs(self): """ Check that the configs currently stored in the widgets @@ -460,23 +590,29 @@ def setup_configs_for_a_new_project(self) -> None: True ) + # A message template to display custom message to user according to the chosen connection method + message_template = ( + "A datashuttle project has now been created.\n\n " + "Next, setup the {method_name} connection. Once complete, navigate to the " + "'Main Menu' and proceed to the project page, where you will be " + "able to create and transfer project folders." + ) + # Could not find a neater way to combine the push screen # while initiating the callback in one case but not the other. - if cfg_kwargs["connection_method"] == "ssh": + connection_method = cfg_kwargs["connection_method"] - self.query_one( - "#configs_setup_ssh_connection_button" - ).visible = True - self.query_one( - "#configs_setup_ssh_connection_button" - ).disabled = False + # To trigger the appearance of "Setup connection" button + self.setup_widgets_to_display(connection_method) - message = ( - "A datashuttle project has now been created.\n\n " - "Next, setup the SSH connection. Once complete, navigate to the " - "'Main Menu' and proceed to the project page, where you will be " - "able to create and transfer project folders." 
- ) + if connection_method == "ssh": + message = message_template.format(method_name="SSH") + + elif connection_method == "gdrive": + message = message_template.format(method_name="Google Drive") + + elif connection_method == "aws": + message = message_template.format(method_name="AWS") else: message = ( @@ -505,7 +641,7 @@ def setup_configs_for_an_existing_project(self) -> None: # Handle the edge case where connection method is changed after # saving on the 'Make New Project' screen. - self.query_one("#configs_setup_ssh_connection_button").visible = True + # self.query_one("#configs_setup_ssh_connection_button").visible = True cfg_kwargs = self.get_datashuttle_inputs_from_widgets() @@ -520,6 +656,8 @@ def setup_configs_for_an_existing_project(self) -> None: ), lambda unused: self.post_message(self.ConfigsSaved()), ) + # To trigger the appearance of "Setup connection" button + self.setup_widgets_to_display(cfg_kwargs["connection_method"]) else: self.parent_class.mainwindow.show_modal_error_dialog(output) @@ -529,24 +667,15 @@ def fill_widgets_with_project_configs(self) -> None: widgets with the current project configs. This in some instances requires recasting to a new type of changing the value. - In the case of the `connection_method` widget, the associated - "ssh" widgets are hidden / displayed based on the current setting, - in `self.switch_ssh_widgets_display()`. + In the case of the `connection_method` widget, the associated connection + method radio button is hidden / displayed based on the current settings. + This change of radio button triggers `on_radio_set_changed` which displays + the appropriate connection method widgets. 
""" assert self.interface is not None, "type narrow flexible `interface`" cfg_to_load = self.interface.get_textual_compatible_project_configs() - # Local Path - input = self.query_one("#configs_local_path_input") - input.value = cfg_to_load["local_path"] - - # Central Path - input = self.query_one("#configs_central_path_input") - input.value = ( - cfg_to_load["central_path"] if cfg_to_load["central_path"] else "" - ) - # Connection Method # Make a dict of radiobutton: is on bool to easily find # how to set radiobuttons and associated configs @@ -556,6 +685,10 @@ def fill_widgets_with_project_configs(self) -> None: cfg_to_load["connection_method"] == "ssh", "configs_local_filesystem_radiobutton": cfg_to_load["connection_method"] == "local_filesystem", + "configs_gdrive_radiobutton": + cfg_to_load["connection_method"] == "gdrive", + "configs_aws_radiobutton": + cfg_to_load["connection_method"] == "aws", "configs_local_only_radiobutton": cfg_to_load["connection_method"] is None, } @@ -564,8 +697,26 @@ def fill_widgets_with_project_configs(self) -> None: for id, value in what_radiobuton_is_on.items(): self.query_one(f"#{id}").value = value - self.switch_ssh_widgets_display( - display_ssh=what_radiobuton_is_on["configs_ssh_radiobutton"] + self.fill_inputs_with_project_configs() + + def fill_inputs_with_project_configs(self) -> None: + """ + This fills the input widgets with the current project configs. It is + used while setting up widgets for the project while mounting the current + tab and also to repopulate input widgets when the radio buttons change. 
+ """ + assert self.interface is not None, "type narrow flexible `interface`" + + cfg_to_load = self.interface.get_textual_compatible_project_configs() + + # Local Path + input = self.query_one("#configs_local_path_input") + input.value = cfg_to_load["local_path"] + + # Central Path + input = self.query_one("#configs_central_path_input") + input.value = ( + cfg_to_load["central_path"] if cfg_to_load["central_path"] else "" ) # Central Host ID @@ -586,6 +737,113 @@ def fill_widgets_with_project_configs(self) -> None: ) input.value = value + # Google Drive Client ID + input = self.query_one("#configs_gdrive_client_id_input") + value = ( + "" + if cfg_to_load["gdrive_client_id"] is None + else cfg_to_load["gdrive_client_id"] + ) + input.value = value + + # Google Drive Root Folder ID + input = self.query_one("#configs_gdrive_root_folder_id") + value = ( + "" + if cfg_to_load["gdrive_root_folder_id"] is None + else cfg_to_load["gdrive_root_folder_id"] + ) + input.value = value + + # AWS Access Key ID + input = self.query_one("#configs_aws_access_key_id_input") + value = ( + "" + if cfg_to_load["aws_access_key_id"] is None + else cfg_to_load["aws_access_key_id"] + ) + input.value = value + + # AWS S3 Region + select = self.query_one("#configs_aws_region_select") + value = ( + Select.BLANK + if cfg_to_load["aws_region"] is None + else cfg_to_load["aws_region"] + ) + select.value = value + + def setup_widgets_to_display(self, connection_method: str | None) -> None: + """ + Sets up widgets to display based on the chosen `connection_method` on the + radiobutton. The widgets pertaining to the chosen connection method will be + displayed. This is done by dedicated functions for each connection method + which display widgets on receiving a `True` flag. + + Also, this function handles other TUI changes like displaying "setup connection" + button, disabling central path input in a local only project, etc. 
+ + Called on mount, on radiobuttons' switch and upon saving project configs. + """ + if connection_method: + assert connection_method in [ + "local_filesystem", + "ssh", + "gdrive", + "aws", + ], "Unexpected Connection Method" + + # Connection specific widgets + connection_widget_display_functions = { + "ssh": self.switch_ssh_widgets_display, + "gdrive": self.switch_gdrive_widgets_display, + "aws": self.switch_aws_widgets_display, + } + + for name, widget_func in connection_widget_display_functions.items(): + if connection_method == name: + widget_func(True) + else: + widget_func(False) + + has_connection_method = connection_method is not None + + # Central path input + self.query_one("#configs_central_path_input").disabled = ( + not has_connection_method + ) + self.query_one("#configs_central_path_select_button").disabled = ( + not has_connection_method + ) + + # Local only project + if not has_connection_method: + self.query_one("#configs_central_path_input").value = "" + + setup_connection_button = self.query_one( + "#configs_setup_connection_button" + ) + + # fmt: off + # Setup connection button + if ( + not connection_method + or connection_method == "local_filesystem" + or not self.interface + or connection_method != self.interface.get_configs()["connection_method"] + ): + setup_connection_button.visible = False + # fmt: on + else: + setup_connection_button.visible = True + + if connection_method == "ssh": + setup_connection_button.label = "Setup SSH Connection" + elif connection_method == "gdrive": + setup_connection_button.label = "Setup Google Drive Connection" + elif connection_method == "aws": + setup_connection_button.label = "Setup AWS Connection" + def get_datashuttle_inputs_from_widgets(self) -> Dict: """ Get the configs to pass to `make_config_file()` from @@ -605,30 +863,67 @@ def get_datashuttle_inputs_from_widgets(self) -> Dict: else: cfg_kwargs["central_path"] = Path(central_path_value) - if self.query_one("#configs_ssh_radiobutton").value: - 
connection_method = "ssh" + for id in [ + "configs_local_filesystem_radiobutton", + "configs_ssh_radiobutton", + "configs_gdrive_radiobutton", + "configs_aws_radiobutton", + "configs_local_only_radiobutton", + ]: + if self.query_one("#" + id).value: + connection_method = self.connection_method_from_radiobutton_id( + id + ) + break - elif self.query_one("#configs_local_filesystem_radiobutton").value: - connection_method = "local_filesystem" + cfg_kwargs["connection_method"] = connection_method - elif self.query_one("#configs_local_only_radiobutton").value: - connection_method = None + # SSH specific + if connection_method == "ssh": + cfg_kwargs["central_host_id"] = ( + self.get_config_value_from_input_value( + "#configs_central_host_id_input" + ) + ) - cfg_kwargs["connection_method"] = connection_method + cfg_kwargs["central_host_username"] = ( + self.get_config_value_from_input_value( + "#configs_central_host_username_input" + ) + ) - central_host_id = self.query_one( - "#configs_central_host_id_input" - ).value - cfg_kwargs["central_host_id"] = ( - None if central_host_id == "" else central_host_id - ) + # Google Drive specific + elif connection_method == "gdrive": + cfg_kwargs["gdrive_client_id"] = ( + self.get_config_value_from_input_value( + "#configs_gdrive_client_id_input" + ) + ) - central_host_username = self.query_one( - "#configs_central_host_username_input" - ).value + cfg_kwargs["gdrive_root_folder_id"] = ( + self.get_config_value_from_input_value( + "#configs_gdrive_root_folder_id" + ) + ) - cfg_kwargs["central_host_username"] = ( - None if central_host_username == "" else central_host_username - ) + # AWS specific + elif connection_method == "aws": + cfg_kwargs["aws_access_key_id"] = ( + self.get_config_value_from_input_value( + "#configs_aws_access_key_id_input" + ) + ) + + aws_region = self.query_one("#configs_aws_region_select").value + cfg_kwargs["aws_region"] = ( + None if aws_region == Select.BLANK else aws_region + ) return cfg_kwargs + + def 
get_config_value_from_input_value( + self, input_box_selector: str + ) -> str | None: + input_value = self.query_one(input_box_selector).value + + return None if input_value == "" else input_value diff --git a/datashuttle/tui/tooltips.py b/datashuttle/tui/tooltips.py index e8d0c771..287b7226 100644 --- a/datashuttle/tui/tooltips.py +++ b/datashuttle/tui/tooltips.py @@ -65,6 +65,23 @@ def get_tooltip(id: str) -> str: "to a project folder, possibly on a mounted drive.\n\n" ) + # Google Drive configs + # ------------------------------------------------------------------------- + + # Google Drive Client ID + elif id == "#configs_gdrive_client_id_input": + tooltip = ( + "The Google Drive Client ID to use for authentication.\n\n" + "It can be obtained by creating an OAuth 2.0 client in the Google Cloud Console.\n\n" + "Can be left empty to use rclone's default client (slower)" + ) + + elif id == "#configs_gdrive_root_folder_id": + tooltip = ( + "The Google Drive root folder ID to use for transfer.\n\n" + "It can be obtained by navigating to the folder in Google Drive and copying the ID from the URL.\n\n" + ) + # Settings # ------------------------------------------------------------------------- diff --git a/datashuttle/utils/aws.py b/datashuttle/utils/aws.py new file mode 100644 index 00000000..aebaf225 --- /dev/null +++ b/datashuttle/utils/aws.py @@ -0,0 +1,49 @@ +import json + +from datashuttle.configs.config_class import Configs +from datashuttle.utils import rclone, utils +from datashuttle.utils.custom_exceptions import ConfigError + + +def check_if_aws_bucket_exists(cfg: Configs) -> bool: + output = rclone.call_rclone( + f"lsjson {cfg.get_rclone_config_name()}:", pipe_std=True + ) + + files_and_folders = json.loads(output.stdout) + + names = list(map(lambda x: x.get("Name", None), files_and_folders)) + + bucket_name = cfg["central_path"].as_posix().strip("/").split("/")[0] + + if bucket_name not in names: + return False + + return True + + +def 
raise_if_bucket_absent(cfg: Configs) -> None: + if not check_if_aws_bucket_exists(cfg): + bucket_name = cfg["central_path"].as_posix().strip("/").split("/")[0] + utils.log_and_raise_error( + f'The bucket "{bucket_name}" does not exist.\n' + f"For data transfer to happen, the bucket must exist.\n" + f"Please change the bucket name in the `central_path`.", + ConfigError, + ) + + +# ----------------------------------------------------------------------------- +# For Python API +# ----------------------------------------------------------------------------- + + +def get_aws_secret_access_key(log: bool = True) -> str: + aws_secret_access_key = utils.get_connection_secret_from_user( + connection_method_name="AWS", + key_name_full="AWS secret access key", + key_name_short="secret key", + log_status=log, + ) + + return aws_secret_access_key.strip() diff --git a/datashuttle/utils/decorators.py b/datashuttle/utils/decorators.py index cacf5491..0bb64c20 100644 --- a/datashuttle/utils/decorators.py +++ b/datashuttle/utils/decorators.py @@ -28,6 +28,27 @@ def wrapper(*args, **kwargs): return wrapper +def requires_aws_configs(func): + + @wraps(func) + def wrapper(*args, **kwargs): + if ( + not args[0].cfg["aws_access_key_id"] + or not args[0].cfg["aws_region"] + ): + log_and_raise_error( + "Cannot setup AWS connection, 'aws_access_key_id' " + "or 'aws_region' is not set in the " + "configuration file", + ConfigError, + ) + + else: + return func(*args, **kwargs) + + return wrapper + + def check_configs_set(func): """ Check that configs have been loaded (i.e. 
diff --git a/datashuttle/utils/folders.py b/datashuttle/utils/folders.py index 56852640..b42478b0 100644 --- a/datashuttle/utils/folders.py +++ b/datashuttle/utils/folders.py @@ -1,5 +1,6 @@ from __future__ import annotations +import json from typing import ( TYPE_CHECKING, Any, @@ -20,7 +21,7 @@ from pathlib import Path from datashuttle.configs import canonical_folders, canonical_tags -from datashuttle.utils import ssh, utils, validation +from datashuttle.utils import rclone, ssh, utils, validation from datashuttle.utils.custom_exceptions import NeuroBlueprintError # ----------------------------------------------------------------------------- @@ -515,14 +516,27 @@ def search_for_folders( verbose : If `True`, when a search folder cannot be found, a message will be printed with the missing path. """ - if local_or_central == "central" and cfg["connection_method"] == "ssh": - all_folder_names, all_filenames = ssh.search_ssh_central_for_folders( - search_path, - search_prefix, - cfg, - verbose, - return_full_path, - ) + if local_or_central == "central" and cfg["connection_method"] in [ + "ssh", + "gdrive", + "aws", + ]: + if cfg["connection_method"] == "ssh": + all_folder_names, all_filenames = ( + ssh.search_ssh_central_for_folders( + search_path, + search_prefix, + cfg, + verbose, + return_full_path, + ) + ) + + else: + all_folder_names, all_filenames = search_gdrive_or_aws_for_folders( + search_path, search_prefix, cfg, return_full_path + ) + else: if not search_path.exists(): if verbose: @@ -537,6 +551,58 @@ def search_for_folders( return all_folder_names, all_filenames +def search_gdrive_or_aws_for_folders( + search_path: Path, + search_prefix: str, + cfg: Configs, + return_full_path: bool = False, +) -> Tuple[List[Any], List[Any]]: + """ + Searches for files and folders in central path using `rclone lsjson` command. + This command lists all the files and folders in the central path in a json format. 
+ The json contains file/folder info about each file/folder like name, type, etc. + """ + + output = rclone.call_rclone( + "lsjson " + f"{cfg.get_rclone_config_name()}:{search_path.as_posix()} " + f'--include "{search_prefix}"', + pipe_std=True, + ) + + all_folder_names: List[str] = [] + all_filenames: List[str] = [] + + if output.returncode != 0: + utils.log_and_message( + f"Error searching files at {search_path.as_posix()} \n {output.stderr.decode('utf-8') if output.stderr else ''}" + ) + return all_folder_names, all_filenames + + files_and_folders = json.loads(output.stdout) + + try: + for file_or_folder in files_and_folders: + name = file_or_folder["Name"] + is_dir = file_or_folder.get("IsDir", False) + + to_append = ( + (search_path / name).as_posix() if return_full_path else name + ) + + if is_dir: + all_folder_names.append(to_append) + else: + all_filenames.append(to_append) + + except Exception: + utils.log_and_message( + f"Error searching files at {search_path.as_posix()}" + ) + + return all_folder_names, all_filenames + + # Actual function implementation def search_filesystem_path_for_folders( search_path_with_prefix: Path, return_full_path: bool = False diff --git a/datashuttle/utils/gdrive.py b/datashuttle/utils/gdrive.py new file mode 100644 index 00000000..bd8df990 --- /dev/null +++ b/datashuttle/utils/gdrive.py @@ -0,0 +1,128 @@ +from __future__ import annotations + +from typing import ( + TYPE_CHECKING, +) + +if TYPE_CHECKING: + from datashuttle.configs.config_class import Configs + +import json + +from datashuttle.utils import rclone, utils + +# ----------------------------------------------------------------------------- +# Helper Functions +# ----------------------------------------------------------------------------- + +# These functions are used by both API and TUI for setting up connections to google drive. 
+ + +def preliminary_for_setup_without_browser( + cfg: Configs, + gdrive_client_secret: str | None, + rclone_config_name: str, + log: bool = True, +) -> str: + """ + This function prepares the rclone configuration for Google Drive without using a browser. + + The `config_is_local=false` flag tells rclone that the configuration process is being run + on a headless machine which does not have access to a browser. + + The `--non-interactive` flag is used to control rclone's behaviour while running it through + external applications. An `rclone config create` command would assume default values for config + variables in an interactive mode. If the `--non-interactive` flag is provided and rclone needs + the user to input some detail, a JSON blob will be returned with the question in it. For this + particular setup, rclone outputs a command for user to run on a machine with a browser. + + This function runs `rclone config create` with the user credentials and returns the rclone's output info. + This output info is presented to the user while asking for a `config_token`. + + Next, the user will run rclone's given command, authenticate with google drive and input the + config token given by rclone for datashuttle to proceed with the setup. 
+ """ + client_id_key_value = ( + f"client_id {cfg['gdrive_client_id']} " + if cfg["gdrive_client_id"] + else " " + ) + client_secret_key_value = ( + f"client_secret {gdrive_client_secret} " + if gdrive_client_secret + else "" + ) + output = rclone.call_rclone( + f"config create " + f"{rclone_config_name} " + f"drive " + f"{client_id_key_value}" + f"{client_secret_key_value}" + f"scope drive " + f"root_folder_id {cfg['gdrive_root_folder_id']} " + f"config_is_local=false " + f"--non-interactive", + pipe_std=True, + ) + + # TODO: make this more robust + # extracting rclone's message from the json + output_json = json.loads(output.stdout) + message = output_json["Option"]["Help"] + + if log: + utils.log(message) + + return message + + +# ----------------------------------------------------------------------------- +# Python API +# ----------------------------------------------------------------------------- + + +def ask_user_for_browser(log: bool = True) -> bool: + message = "Are you running Datashuttle on a machine with access to a web browser? (y/n): " + input_ = utils.get_user_input(message).lower() + + while input_ not in ["y", "n"]: + utils.print_message_to_user("Invalid input. Press either 'y' or 'n'.") + input_ = utils.get_user_input(message).lower() + + answer = input_ == "y" + + if log: + utils.log(message) + utils.log(f"User answer: {answer}") + + return answer + + +def prompt_and_get_config_token( + cfg: Configs, + gdrive_client_secret: str | None, + rclone_config_name: str, + log: bool = True, +) -> str: + """ + This function presents the rclone's output/message to ask the user to run a command, authenticate + with google drive and input the `config_token` generated by rclone. The `config_token` is + then used to complete rclone's config setup for google drive. 
+ """ + message = preliminary_for_setup_without_browser( + cfg, gdrive_client_secret, rclone_config_name, log=log + ) + input_ = utils.get_user_input(message).strip() + + return input_ + + +def get_client_secret(log: bool = True) -> str: + gdrive_client_secret = utils.get_connection_secret_from_user( + connection_method_name="Google Drive", + key_name_full="Google Drive client secret", + key_name_short="secret key", + log_status=log, + ) + + return gdrive_client_secret.strip() diff --git a/datashuttle/utils/rclone.py b/datashuttle/utils/rclone.py index 49d7da82..35123599 100644 --- a/datashuttle/utils/rclone.py +++ b/datashuttle/utils/rclone.py @@ -1,14 +1,20 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Dict, List, Literal, Optional + +if TYPE_CHECKING: + from pathlib import Path + + from datashuttle.configs.config_class import Configs + from datashuttle.utils.custom_types import TopLevelFolder + import os import platform import subprocess import tempfile -from pathlib import Path from subprocess import CompletedProcess -from typing import Dict, List, Literal -from datashuttle.configs.config_class import Configs from datashuttle.utils import utils -from datashuttle.utils.custom_types import TopLevelFolder def call_rclone(command: str, pipe_std: bool = False) -> CompletedProcess: @@ -150,6 +156,139 @@ def setup_rclone_config_for_ssh( log_rclone_config_output() +def setup_rclone_config_for_gdrive( + cfg: Configs, + rclone_config_name: str, + gdrive_client_secret: str | None, + service_account_filepath: Optional[str] = None, + log: bool = True, +): + """ + Sets up rclone config for connections to Google Drive. This must + contain the `gdrive_root_folder_id` and optionally a `gdrive_client_id` + which also mandates for the presence of a Google Drive client secret. + + Parameters + ---------- + + cfg : Configs + datashuttle configs UserDict. 
+ + rclone_config_name : rclone config name + canonical config name, generated by + datashuttle.cfg.get_rclone_config_name() + + gdrive_client_secret : Google Drive client secret, mandatory when + using a Google Drive client. + + service_account_filepath : optional path to a Google service account file; + when given, it is passed to rclone as `service_account_file`. + + log : whether to log, if True logger must already be initialised. + """ + client_id_key_value = ( + f"client_id {cfg['gdrive_client_id']} " + if cfg["gdrive_client_id"] + else " " + ) + client_secret_key_value = ( + f"client_secret {gdrive_client_secret} " + if gdrive_client_secret + else "" + ) + + # extra_args = ( + # f"config_is_local=false config_token={config_token}" + # if config_token + # else "" + # ) + + service_account_filepath_arg = ( + "" + if service_account_filepath is None + else f"service_account_file {service_account_filepath} " + ) + output = call_rclone( + f"config create " + f"{rclone_config_name} " + f"drive " + f"{client_id_key_value}" + f"{client_secret_key_value}" + f"scope drive " + f"root_folder_id {cfg['gdrive_root_folder_id']} " + f"{service_account_filepath_arg}", + # f"{extra_args}", + pipe_std=True, + ) + + if output.returncode != 0: + utils.log_and_raise_error( + output.stderr.decode("utf-8"), ConnectionError + ) + + if log: + log_rclone_config_output() + + +def setup_rclone_config_for_aws( + cfg: Configs, + rclone_config_name: str, + aws_secret_access_key: str, + log: bool = True, +): + """ + Sets up rclone config for connections to AWS S3 buckets. This must + contain the `aws_access_key_id` and `aws_region`. + + Parameters + ---------- + + cfg : Configs + datashuttle configs UserDict. + + rclone_config_name : rclone config name + canonical config name, generated by + datashuttle.cfg.get_rclone_config_name() + + aws_secret_access_key : the aws secret access key provided by the user. + + log : whether to log, if True logger must already be initialised. 
+ """ + output = call_rclone( + "config create " + f"{rclone_config_name} " + "s3 provider AWS " + f"access_key_id {cfg['aws_access_key_id']} " + f"secret_access_key {aws_secret_access_key} " + f"region {cfg['aws_region']} " + f"location_constraint {cfg['aws_region']}", + pipe_std=True, + ) + + if output.returncode != 0: + utils.log_and_raise_error( + output.stderr.decode("utf-8"), ConnectionError + ) + + if log: + log_rclone_config_output() + + +def check_successful_connection_and_raise_error_on_fail(cfg: Configs) -> None: + """ + Check for a successful connection by executing an `ls` command. It pings the + central host to list files and folders in the root directory. + If the command fails, it raises a ConnectionError with the error message. + """ + + output = call_rclone(f"ls {cfg.get_rclone_config_name()}:", pipe_std=True) + + if output.returncode != 0: + utils.log_and_raise_error( + output.stderr.decode("utf-8"), ConnectionError + ) + + def log_rclone_config_output(): output = call_rclone("config file", pipe_std=True) utils.log( diff --git a/datashuttle/utils/utils.py b/datashuttle/utils/utils.py index 87a39e5c..862e7bb4 100644 --- a/datashuttle/utils/utils.py +++ b/datashuttle/utils/utils.py @@ -1,6 +1,8 @@ from __future__ import annotations +import getpass import re +import sys import traceback import warnings from typing import TYPE_CHECKING, Any, List, Literal, Union, overload @@ -87,6 +89,44 @@ def get_user_input(message: str) -> str: return input_ +def get_connection_secret_from_user( + connection_method_name: str, + key_name_full: str, + key_name_short: str, + log_status: bool, +) -> str: + if not sys.stdin.isatty(): + proceed = input( + f"\nWARNING!\nThe next step is to enter a {key_name_full}, but it is not possible\n" + f"to hide your {key_name_short} while entering it in the current terminal.\n" + f"This can occur if running the command in an IDE.\n\n" + f"Press 'y' to proceed to {key_name_short} entry. 
" + f"The characters will not be hidden!\n" + f"Alternatively, run {connection_method_name} setup after starting Python in your " + f"system terminal \nrather than through an IDE: " + ) + if proceed != "y": + print_message_to_user( + f"Quitting {connection_method_name} setup as 'y' not pressed." + ) + log_and_raise_error( + f"{connection_method_name} setup aborted by user.", + ConnectionAbortedError, + ) + + input_ = input( + f"Please enter your {key_name_full}. Characters will not be hidden: " + ) + + else: + input_ = getpass.getpass(f"Please enter your {key_name_full}: ") + + if log_status: + log(f"{key_name_full} entered by user.") + + return input_ + + # ----------------------------------------------------------------------------- # Paths # ----------------------------------------------------------------------------- diff --git a/tests/tests_integration/test_aws_connection.py b/tests/tests_integration/test_aws_connection.py new file mode 100644 index 00000000..e50ce9c8 --- /dev/null +++ b/tests/tests_integration/test_aws_connection.py @@ -0,0 +1,64 @@ +import builtins +import os +import random +import string + +import pytest +import test_utils +from base import BaseTest + +from datashuttle.configs.canonical_configs import get_broad_datatypes +from datashuttle.utils import rclone + + +@pytest.mark.skipif(os.getenv("CI") is None, reason="Only runs in CI") +class TestGoogleDriveGithubCI(BaseTest): + + def test_google_drive_connection(self, no_cfg_project, tmp_path): + + central_path = f"test-datashuttle/test-id-{''.join(random.choices(string.digits, k=15))}" + + aws_access_key_id = os.environ["AWS_ACCESS_KEY_ID"] + aws_access_key_id_secret = os.environ["AWS_ACCESS_KEY_ID_SECRET"] + aws_region = os.environ["AWS_REGION"] + + no_cfg_project.make_config_file( + local_path=str(tmp_path), # any temp location TODO UPDATE + connection_method="aws", + central_path=central_path, + aws_access_key_id=aws_access_key_id, + aws_region=aws_region, + ) + + state = {"first": True} 
@pytest.mark.skipif(os.getenv("CI") is None, reason="Only runs in CI")
class TestGoogleDriveGithubCI(BaseTest):
    """Integration test for the Google Drive connection method (CI only).

    The whole class is skipped unless the ``CI`` environment variable is
    set, because it needs the ``GDRIVE_ROOT_FOLDER_ID`` and
    ``GDRIVE_SERVICE_ACCOUNT_FILE`` environment variables.
    """

    def test_google_drive_connection(self, no_cfg_project, tmp_path):
        """Set up a Google Drive connection, upload a project, then purge it.

        Parameters
        ----------
        no_cfg_project
            Project fixture on which ``make_config_file`` has not yet been
            called (presumably provided by ``BaseTest``; verify).
        tmp_path
            Pytest temporary directory, used as the project ``local_path``.
        """
        # Random suffix so that runs do not share a central folder.
        test_id = "".join(random.choices(string.digits, k=15))
        central_path = f"test-id-{test_id}"

        root_id = os.environ["GDRIVE_ROOT_FOLDER_ID"]
        sa_path = os.environ["GDRIVE_SERVICE_ACCOUNT_FILE"]

        no_cfg_project.make_config_file(
            local_path=str(tmp_path),  # any temp location TODO UPDATE
            connection_method="gdrive",
            central_path=central_path,
            gdrive_root_folder_id=root_id,
            gdrive_client_id=None,  # keep None
        )

        state = {"first": True}

        def mock_input(_: str) -> str:
            # The first prompt is answered with "n"; every later prompt is
            # answered with the service-account file path.
            # NOTE(review): presumably "n" declines an interactive
            # (browser-based) authentication prompt - confirm against the
            # setup code.
            if state["first"]:
                state["first"] = False
                return "n"
            return sa_path

        original_input = builtins.input
        builtins.input = mock_input
        try:
            no_cfg_project.setup_google_drive_connection()
        finally:
            # Always restore the real `input`, even if setup raises, so a
            # failure here cannot leave the patched function in place for
            # the rest of the test session.
            builtins.input = original_input

        subs, sessions = test_utils.get_default_sub_sessions_to_test()

        test_utils.make_and_check_local_project_folders(
            no_cfg_project, "rawdata", subs, sessions, get_broad_datatypes()
        )

        no_cfg_project.upload_entire_project()

        # TODO: get a list of files on gdrive and check they are as
        # expected; include the test id in the assertion message if it
        # fails.

        # Only tidy up if everything above succeeded, otherwise the folder
        # is left in place so it can be inspected and deleted manually.
        rclone.call_rclone(
            f"purge central_{no_cfg_project.project_name}_gdrive:{central_path}"
        )