From 83ac1f4d5721a679ab297a75f588f601245192a4 Mon Sep 17 00:00:00 2001 From: shrey Date: Tue, 25 Mar 2025 19:10:48 +0530 Subject: [PATCH 01/39] google drive setup via python api first draft --- datashuttle/configs/canonical_configs.py | 22 +++++++- datashuttle/datashuttle_class.py | 23 ++++++++ datashuttle/utils/folders.py | 27 ++++++--- datashuttle/utils/gdrive.py | 70 ++++++++++++++++++++++++ datashuttle/utils/rclone.py | 29 ++++++++++ 5 files changed, 161 insertions(+), 10 deletions(-) create mode 100644 datashuttle/utils/gdrive.py diff --git a/datashuttle/configs/canonical_configs.py b/datashuttle/configs/canonical_configs.py index b65ad6c6..9f1c58ae 100644 --- a/datashuttle/configs/canonical_configs.py +++ b/datashuttle/configs/canonical_configs.py @@ -39,9 +39,13 @@ def get_canonical_configs() -> dict: canonical_configs = { "local_path": Union[str, Path], "central_path": Optional[Union[str, Path]], - "connection_method": Optional[Literal["ssh", "local_filesystem"]], + "connection_method": Optional[ + Literal["ssh", "local_filesystem", "gdrive", "aws_s3"] + ], "central_host_id": Optional[str], "central_host_username": Optional[str], + "gdrive_client_id": Optional[str], + "gdrive_client_secret": Optional[str], } return canonical_configs @@ -128,6 +132,22 @@ def check_dict_values_raise_on_fail(config_dict: Configs) -> None: ConfigError, ) + # Check gdrive settings + if config_dict["connection_method"] == "gdrive" and ( + ( + config_dict["gdrive_client_id"] + and not config_dict["gdrive_client_secret"] + ) + or ( + not config_dict["gdrive_client_id"] + and config_dict["gdrive_client_secret"] + ) + ): + utils.log_and_raise_error( + "Both gdrive_client_id and gdrive_client_secret must be present together", + ConfigError, + ) + # Initialise the local project folder utils.print_message_to_user( f"Making project folder at: {config_dict['local_path']}" diff --git a/datashuttle/datashuttle_class.py b/datashuttle/datashuttle_class.py index 9322323f..39d40ad1 100644 --- a/datashuttle/datashuttle_class.py +++ b/datashuttle/datashuttle_class.py @@ -892,6 +892,20 @@ def write_public_key(self, filepath: str) -> None: public.write(key.get_base64()) public.close() + # ------------------------------------------------------------------------- + # Google Drive + # ------------------------------------------------------------------------- + + @check_configs_set + def setup_google_drive_connection(self) -> None: + self._start_log( + "setup-google-drive-connection-to-central-server", + local_vars=locals(), + ) + + self._setup_rclone_gdrive_config(log=True) + ds_logger.close_log_filehandler() + # ------------------------------------------------------------------------- # Configs # ------------------------------------------------------------------------- @@ -903,6 +917,8 @@ def make_config_file( connection_method: str | None = None, central_host_id: Optional[str] = None, central_host_username: Optional[str] = None, + gdrive_client_id: Optional[str] = None, + gdrive_client_secret: Optional[str] = None, ) -> None: """ Initialise the configurations for datashuttle to use on the @@ -967,6 +983,8 @@ def make_config_file( "connection_method": connection_method, "central_host_id": central_host_id, "central_host_username": central_host_username, + "gdrive_client_id": gdrive_client_id, + "gdrive_client_secret": gdrive_client_secret, }, ) @@ -1470,6 +1488,11 @@ def _setup_rclone_central_local_filesystem_config(self) -> None: self.cfg.get_rclone_config_name("local_filesystem"), ) + def _setup_rclone_gdrive_config(self, log: bool) -> None: + rclone.setup_rclone_config_for_gdrive( + self.cfg, self.cfg.get_rclone_config_name("gdrive"), log=log + ) + # Persistent settings # ------------------------------------------------------------------------- diff --git a/datashuttle/utils/folders.py b/datashuttle/utils/folders.py index 56852640..5e9b8e05 100644 --- a/datashuttle/utils/folders.py +++ b/datashuttle/utils/folders.py @@ -20,7 +20,7 @@ from pathlib import Path from datashuttle.configs import canonical_folders, canonical_tags -from datashuttle.utils import ssh, utils, validation +from datashuttle.utils import gdrive, ssh, utils, validation from datashuttle.utils.custom_exceptions import NeuroBlueprintError # ----------------------------------------------------------------------------- @@ -515,14 +515,23 @@ def search_for_folders( verbose : If `True`, when a search folder cannot be found, a message will be printed with the missing path. """ - if local_or_central == "central" and cfg["connection_method"] == "ssh": - all_folder_names, all_filenames = ssh.search_ssh_central_for_folders( - search_path, - search_prefix, - cfg, - verbose, - return_full_path, - ) + if local_or_central == "central": + if cfg["connection_method"] == "ssh": + all_folder_names, all_filenames = ( + ssh.search_ssh_central_for_folders( + search_path, + search_prefix, + cfg, + verbose, + return_full_path, + ) + ) + elif cfg["connection_method"] == "gdrive": + all_folder_names, all_filenames = ( + gdrive.search_gdrive_central_for_folders( + search_path, search_prefix, cfg, verbose, return_full_path + ) + ) else: if not search_path.exists(): if verbose: diff --git a/datashuttle/utils/gdrive.py b/datashuttle/utils/gdrive.py new file mode 100644 index 00000000..539f2918 --- /dev/null +++ b/datashuttle/utils/gdrive.py @@ -0,0 +1,70 @@ +from __future__ import annotations + +import json +import subprocess +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from pathlib import Path + + from datashuttle.configs.config_class import Configs + +from typing import Any, List, Tuple + +from datashuttle.utils import utils + + +# Generic function +def search_gdrive_central_for_folders( + search_path: Path, + search_prefix: str, + cfg: Configs, + verbose: bool = True, + return_full_path: bool = False, +) -> Tuple[List[Any], List[Any]]: + + command = ( + "rclone lsjson " + f"{cfg.get_rclone_config_name()}:{search_path.as_posix()} " + f'--include "{search_prefix}"', + ) + output = subprocess.run( + command, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + shell=True, + ) + + all_folder_names: List[str] = [] + all_filenames: List[str] = [] + + if output.returncode != 0: + if verbose: + utils.log_and_message( + f"Error searching files at {search_path.as_posix()} \n {output.stderr.decode("utf-8") if output.stderr else ""}" + ) + return all_folder_names, all_filenames + + files_and_folders = json.loads(output.stdout) + + try: + for file_or_folder in files_and_folders: + name = file_or_folder["Name"] + is_dir = file_or_folder.get("IsDir", False) + + to_append = ( + (search_path / name).as_posix() if return_full_path else name + ) + + if is_dir: + all_folder_names.append(to_append) + else: + all_filenames.append(to_append) + + except Exception: + if verbose: + utils.log_and_message( + f"Error searching files at {search_path.to_posix()}" + ) + + return all_folder_names, all_filenames diff --git a/datashuttle/utils/rclone.py b/datashuttle/utils/rclone.py index 49d7da82..4c22f197 100644 --- a/datashuttle/utils/rclone.py +++ b/datashuttle/utils/rclone.py @@ -150,6 +150,35 @@ def setup_rclone_config_for_ssh( log_rclone_config_output() +def setup_rclone_config_for_gdrive( + cfg: Configs, + rclone_config_name: str, + log: bool = True, +): + client_id_key_value = ( + f"client_id {cfg['gdrive_client_id']} " + if cfg["gdrive_client_id"] + else " " + ) + client_secret_key_value = ( + f"client_secret {cfg['gdrive_client_secret']} " + if cfg["gdrive_client_secret"] + else "" + ) + call_rclone( + f"config create " + f"{rclone_config_name} " + f"drive " + f"{client_id_key_value}" + f"{client_secret_key_value}" + f"scope drive", + pipe_std=True, + ) + + if log: + log_rclone_config_output() + + def log_rclone_config_output(): output = call_rclone("config file", pipe_std=True) utils.log( From e37861ead47215e88a9713cc41da04cd3885346d Mon Sep 17 00:00:00 2001 From: shrey Date: Wed, 26 Mar 2025 04:23:04 +0530 Subject: [PATCH 02/39] enable google drive config setup via tui --- datashuttle/tui/css/tui_menu.tcss | 6 +- datashuttle/tui/shared/configs_content.py | 182 +++++++++++++++++++++- 2 files changed, 179 insertions(+), 9 deletions(-) diff --git a/datashuttle/tui/css/tui_menu.tcss b/datashuttle/tui/css/tui_menu.tcss index 00cb9458..07266c20 100644 --- a/datashuttle/tui/css/tui_menu.tcss +++ b/datashuttle/tui/css/tui_menu.tcss @@ -161,11 +161,15 @@ MessageBox:light > #messagebox_top_container { padding: 0 4 0 4; width: 26; color: $success; /* unsure about this */ + dock: left; } -#configs_setup_ssh_connection_button { +#setup_buttons_container > Button { margin: 2 1 0 0; + dock: left; } + + #configs_go_to_project_screen_button { margin: 2 1 0 0; } diff --git a/datashuttle/tui/shared/configs_content.py b/datashuttle/tui/shared/configs_content.py index 974ee08a..9b728c5b 100644 --- a/datashuttle/tui/shared/configs_content.py +++ b/datashuttle/tui/shared/configs_content.py @@ -90,6 +90,24 @@ def compose(self) -> ComposeResult: ), ] + self.config_gdrive_widgets = [ + Label("Client ID", id="configs_gdrive_client_id_label"), + ClickableInput( + self.parent_class.mainwindow, + placeholder="", + id="configs_gdrive_client_id_input", + ), + Label("Client Secret", id="configs_gdrive_client_secret_label"), + # TODO: HIDE THIS + ClickableInput( + self.parent_class.mainwindow, + placeholder="", + id="configs_gdrive_client_secret_input", + ), + ] + + self.config_aws_s3_widgets = [] + config_screen_widgets = [ Label("Local Path", id="configs_local_path_label"), Horizontal( @@ -108,6 +126,8 @@ def compose(self) -> ComposeResult: id="configs_local_filesystem_radiobutton", ), RadioButton("SSH", id="configs_ssh_radiobutton"), + RadioButton("Google Drive", id="configs_gdrive_radiobutton"), + RadioButton("AWS S3", id="configs_aws_s3_radiobutton"), RadioButton( "No connection (local only)", id="configs_local_only_radiobutton", @@ -115,6 +135,8 @@ def compose(self) -> ComposeResult: id="configs_connect_method_radioset", ), *self.config_ssh_widgets, + *self.config_gdrive_widgets, + *self.config_aws_s3_widgets, Label("Central Path", id="configs_central_path_label"), Horizontal( ClickableInput( @@ -127,9 +149,20 @@ def compose(self) -> ComposeResult: ), Horizontal( Button("Save", id="configs_save_configs_button"), - Button( - "Setup SSH Connection", - id="configs_setup_ssh_connection_button", + Horizontal( + Button( + "Setup SSH Connection", + id="configs_setup_ssh_connection_button", + ), + Button( + "Setup Google Drive Connection", + id="configs_setup_gdrive_connection_button", + ), + Button( + "Setup AWS Connection", + id="configs_setup_aws_connection_button", + ), + id="setup_buttons_container", ), # Below button is always hidden when accessing # configs from project manager screen @@ -184,11 +217,17 @@ def on_mount(self) -> None: self.query_one("#configs_go_to_project_screen_button").visible = False if self.interface: self.fill_widgets_with_project_configs() + self.setup_widgets_to_display( + connection_method=self.interface.get_configs()[ + "connection_method" + ] + ) else: self.query_one("#configs_local_filesystem_radiobutton").value = ( True ) - self.switch_ssh_widgets_display(display_ssh=False) + # self.switch_ssh_widgets_display(display_ssh=False) + self.setup_widgets_to_display(connection_method="local_filesystem") self.query_one("#configs_setup_ssh_connection_button").visible = ( False ) @@ -237,6 +276,8 @@ def on_radio_set_changed(self, event: RadioSet.Changed) -> None: "SSH", "Local Filesystem", "No connection (local only)", + "Google Drive", + "AWS S3", ], "Unexpected label." if label == "No connection (local only)": @@ -246,14 +287,21 @@ def on_radio_set_changed(self, event: RadioSet.Changed) -> None: True ) display_ssh = False + display_gdrive = False + display_aws = False else: self.query_one("#configs_central_path_input").disabled = False self.query_one("#configs_central_path_select_button").disabled = ( False ) display_ssh = True if label == "SSH" else False + display_gdrive = True if label == "Google Drive" else False + display_aws = True if label == "AWS S3" else False self.switch_ssh_widgets_display(display_ssh) + self.switch_gdrive_widgets_display(display_gdrive) + self.switch_aws_widgets_display(display_aws) + self.set_central_path_input_tooltip(display_ssh) def set_central_path_input_tooltip(self, display_ssh: bool) -> None: @@ -327,6 +375,36 @@ def switch_ssh_widgets_display(self, display_ssh: bool) -> None: placeholder ) + def switch_gdrive_widgets_display(self, display_gdrive: bool) -> None: + for widget in self.config_gdrive_widgets: + widget.display = display_gdrive + + self.query_one("#configs_central_path_select_button").display = ( + not display_gdrive + ) + + # TODO: ADD SETUP FOR BUTTONS + + if self.interface is None: + self.query_one( + "#configs_setup_gdrive_connection_button" + ).visible = False + else: + self.query_one( + "#configs_setup_gdrive_connection_button" + ).visible = display_gdrive + + def switch_aws_widgets_display(self, display_aws: bool) -> None: + + if self.interface is None: + self.query_one("#configs_setup_aws_connection_button").visible = ( + False + ) + else: + self.query_one("#configs_setup_aws_connection_button").visible = ( + display_aws + ) + def on_button_pressed(self, event: Button.Pressed) -> None: """ Enables the Create Folders button to read out current input values @@ -477,7 +555,21 @@ def setup_configs_for_a_new_project(self) -> None: "'Main Menu' and proceed to the project page, where you will be " "able to create and transfer project folders." ) + elif cfg_kwargs["connection_method"] == "gdrive": + + self.query_one( + "#configs_setup_gdrive_connection_button" + ).visible = True + self.query_one( + "#configs_setup_gdrive_connection_button" + ).disabled = False + message = ( + "A datashuttle project has now been created.\n\n " + "Next, setup the Google Drive connection. Once complete, navigate to the " + "'Main Menu' and proceed to the project page, where you will be " + "able to create and transfer project folders." + ) else: message = ( "A datashuttle project has now been created.\n\n " @@ -505,7 +597,7 @@ def setup_configs_for_an_existing_project(self) -> None: # Handle the edge case where connection method is changed after # saving on the 'Make New Project' screen. - self.query_one("#configs_setup_ssh_connection_button").visible = True + # self.query_one("#configs_setup_ssh_connection_button").visible = True cfg_kwargs = self.get_datashuttle_inputs_from_widgets() @@ -556,6 +648,10 @@ def fill_widgets_with_project_configs(self) -> None: cfg_to_load["connection_method"] == "ssh", "configs_local_filesystem_radiobutton": cfg_to_load["connection_method"] == "local_filesystem", + "configs_gdrive_radiobutton": + cfg_to_load["connection_method"] == "gdrive", + "configs_aws_s3_radiobutton": + cfg_to_load["connection_method"] == "aws_s3", "configs_local_only_radiobutton": cfg_to_load["connection_method"] is None, } @@ -564,9 +660,9 @@ def fill_widgets_with_project_configs(self) -> None: for id, value in what_radiobuton_is_on.items(): self.query_one(f"#{id}").value = value - self.switch_ssh_widgets_display( - display_ssh=what_radiobuton_is_on["configs_ssh_radiobutton"] - ) + # self.switch_ssh_widgets_display( + # display_ssh=what_radiobuton_is_on["configs_ssh_radiobutton"] + # ) # Central Host ID input = self.query_one("#configs_central_host_id_input") @@ -586,6 +682,54 @@ def fill_widgets_with_project_configs(self) -> None: ) input.value = value + # Google Drive Client ID + input = self.query_one("#configs_gdrive_client_id_input") + value = ( + "" + if cfg_to_load["gdrive_client_id"] is None + else cfg_to_load["gdrive_client_id"] + ) + input.value = value + + # Google Drive Client Secret + input = self.query_one("#configs_gdrive_client_secret_input") + value = ( + "" + if cfg_to_load["gdrive_client_secret"] is None + else cfg_to_load["gdrive_client_secret"] + ) + input.value = value + + def setup_widgets_to_display(self, connection_method: str | None) -> None: + + if connection_method: + assert connection_method in [ + "local_filesystem", + "ssh", + "gdrive", + "aws_s3", + ], "Unexpected Connection Method" + + if connection_method == "ssh": + self.switch_ssh_widgets_display(True) + self.switch_gdrive_widgets_display(False) + self.switch_aws_widgets_display(False) + + elif connection_method == "gdrive": + self.switch_ssh_widgets_display(False) + self.switch_gdrive_widgets_display(True) + self.switch_aws_widgets_display(False) + + elif connection_method == "aws_s3": + self.switch_ssh_widgets_display(False) + self.switch_gdrive_widgets_display(False) + self.switch_aws_widgets_display(True) + + else: + self.switch_ssh_widgets_display(False) + self.switch_gdrive_widgets_display(False) + self.switch_aws_widgets_display(False) + def get_datashuttle_inputs_from_widgets(self) -> Dict: """ Get the configs to pass to `make_config_file()` from @@ -608,6 +752,12 @@ def get_datashuttle_inputs_from_widgets(self) -> Dict: if self.query_one("#configs_ssh_radiobutton").value: connection_method = "ssh" + elif self.query_one("#configs_gdrive_radiobutton").value: + connection_method = "gdrive" + + elif self.query_one("#configs_aws_s3_radiobutton").value: + connection_method = "aws_s3" + elif self.query_one("#configs_local_filesystem_radiobutton").value: connection_method = "local_filesystem" @@ -631,4 +781,20 @@ def get_datashuttle_inputs_from_widgets(self) -> Dict: None if central_host_username == "" else central_host_username ) + # TODO : ADD CFG FOR GDRIVE CLIENT ID AND CLIENT SECRET + + gdrive_client_id = self.query_one( + "#configs_gdrive_client_id_input" + ).value + gdrive_client_secret = self.query_one( + "#configs_gdrive_client_secret_input" + ).value + + cfg_kwargs["gdrive_client_id"] = ( + None if gdrive_client_id == "" else gdrive_client_id + ) + cfg_kwargs["gdrive_client_secret"] = ( + None if gdrive_client_secret == "" else gdrive_client_secret + ) + return cfg_kwargs From 01503f70bfa0084bb5bb954c856471766d496c41 Mon Sep 17 00:00:00 2001 From: shrey Date: Wed, 26 Mar 2025 22:21:04 +0530 Subject: [PATCH 03/39] minor compatibility and ui changes --- datashuttle/tui/shared/configs_content.py | 25 ++++++++++++++++------- datashuttle/utils/rclone.py | 4 ++-- 2 files changed, 20 insertions(+), 9 deletions(-) diff --git a/datashuttle/tui/shared/configs_content.py b/datashuttle/tui/shared/configs_content.py index 9b728c5b..28aa7522 100644 --- a/datashuttle/tui/shared/configs_content.py +++ b/datashuttle/tui/shared/configs_content.py @@ -357,7 +357,10 @@ def switch_ssh_widgets_display(self, display_ssh: bool) -> None: not display_ssh ) - if self.interface is None: + if ( + self.interface is None + or self.interface.get_configs()["connection_method"] != "ssh" + ): self.query_one("#configs_setup_ssh_connection_button").visible = ( False ) @@ -383,9 +386,10 @@ def switch_gdrive_widgets_display(self, display_gdrive: bool) -> None: not display_gdrive ) - # TODO: ADD SETUP FOR BUTTONS - - if self.interface is None: + if ( + self.interface is None + or self.interface.get_configs()["connection_method"] != "gdrive" + ): self.query_one( "#configs_setup_gdrive_connection_button" ).visible = False @@ -396,7 +400,10 @@ def switch_gdrive_widgets_display(self, display_gdrive: bool) -> None: def switch_aws_widgets_display(self, display_aws: bool) -> None: - if self.interface is None: + if ( + self.interface is None + or self.interface.get_configs()["connection_method"] != "aws_s3" + ): self.query_one("#configs_setup_aws_connection_button").visible = ( False ) @@ -419,6 +426,8 @@ def on_button_pressed(self, event: Button.Pressed) -> None: elif event.button.id == "configs_setup_ssh_connection_button": self.setup_ssh_connection() + elif event.button.id == "configs_setup_gdrive_connection_button": + self.interface.project.setup_google_drive_connection() elif event.button.id == "configs_go_to_project_screen_button": self.parent_class.dismiss(self.interface) @@ -612,6 +621,8 @@ def setup_configs_for_an_existing_project(self) -> None: ), lambda unused: self.post_message(self.ConfigsSaved()), ) + # to trigger the appearance of buttons + self.setup_widgets_to_display(cfg_kwargs["connection_method"]) else: self.parent_class.mainwindow.show_modal_error_dialog(output) @@ -686,7 +697,7 @@ def fill_widgets_with_project_configs(self) -> None: input = self.query_one("#configs_gdrive_client_id_input") value = ( "" - if cfg_to_load["gdrive_client_id"] is None + if cfg_to_load.get("gdrive_client_id", None) is None else cfg_to_load["gdrive_client_id"] ) input.value = value @@ -695,7 +706,7 @@ def fill_widgets_with_project_configs(self) -> None: input = self.query_one("#configs_gdrive_client_secret_input") value = ( "" - if cfg_to_load["gdrive_client_secret"] is None + if cfg_to_load.get("gdrive_client_secret", None) is None else cfg_to_load["gdrive_client_secret"] ) input.value = value diff --git a/datashuttle/utils/rclone.py b/datashuttle/utils/rclone.py index 4c22f197..bf3d767d 100644 --- a/datashuttle/utils/rclone.py +++ b/datashuttle/utils/rclone.py @@ -157,12 +157,12 @@ def setup_rclone_config_for_gdrive( ): client_id_key_value = ( f"client_id {cfg['gdrive_client_id']} " - if cfg["gdrive_client_id"] + if cfg.get("gdrive_client_id", None) else " " ) client_secret_key_value = ( f"client_secret {cfg['gdrive_client_secret']} " - if cfg["gdrive_client_secret"] + if cfg.get("gdrive_client_secret", None) else "" ) call_rclone( From 7380bac713c911882f8266443345b27d2c84715d Mon Sep 17 00:00:00 2001 From: shrey Date: Wed, 26 Mar 2025 22:34:00 +0530 Subject: [PATCH 04/39] protectedclient secret input box --- datashuttle/tui/shared/configs_content.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/datashuttle/tui/shared/configs_content.py b/datashuttle/tui/shared/configs_content.py index 28aa7522..98e365f8 100644 --- a/datashuttle/tui/shared/configs_content.py +++ b/datashuttle/tui/shared/configs_content.py @@ -17,6 +17,7 @@ from textual.message import Message from textual.widgets import ( Button, + Input, Label, RadioButton, RadioSet, @@ -98,10 +99,9 @@ def compose(self) -> ComposeResult: id="configs_gdrive_client_id_input", ), Label("Client Secret", id="configs_gdrive_client_secret_label"), - # TODO: HIDE THIS - ClickableInput( - self.parent_class.mainwindow, + Input( placeholder="", + password=True, id="configs_gdrive_client_secret_input", ), ] From 965e0169378f22f35a90fcec0655cfebc43751bf Mon Sep 17 00:00:00 2001 From: shrey Date: Thu, 27 Mar 2025 03:35:21 +0530 Subject: [PATCH 05/39] google drive connection setup via TUI --- datashuttle/tui/css/tui_menu.tcss | 34 +++++++ datashuttle/tui/screens/setup_aws.py | 0 datashuttle/tui/screens/setup_gdrive.py | 113 ++++++++++++++++++++++ datashuttle/tui/shared/configs_content.py | 28 ++++-- 4 files changed, 169 insertions(+), 6 deletions(-) create mode 100644 datashuttle/tui/screens/setup_aws.py create mode 100644 datashuttle/tui/screens/setup_gdrive.py diff --git a/datashuttle/tui/css/tui_menu.tcss b/datashuttle/tui/css/tui_menu.tcss index 07266c20..e8d2d9fc 100644 --- a/datashuttle/tui/css/tui_menu.tcss +++ b/datashuttle/tui/css/tui_menu.tcss @@ -64,6 +64,9 @@ SettingsScreen { GetHelpScreen { align: center middle; } +SetupGdriveScreen { + align: center middle; + } #get_help_label { align: center middle; text-align: center; @@ -114,6 +117,37 @@ MessageBox:light > #messagebox_top_container { align: center middle; } +#setup_gdrive_screen_container { + height: 75%; + width: 80%; + align: center middle; + border: thick $panel-lighten-1; +} + +#gdrive_setup_messagebox_message_container { + align: center middle; + overflow: hidden auto; + margin: 0 1; +} + +#gdrive_setup_messagebox_message { + text-align: center; + padding: 0 2; +} + +#setup_gdrive_ok_button { + align: center bottom; + height: 3; +} + +#setup_gdrive_cancel_button { + align: center bottom; +} + +#setup_gdrive_buttons_horizontal { + align: center middle; +} + /* Configs Content ----------------------------------------------------------------- */ #configs_container { diff --git a/datashuttle/tui/screens/setup_aws.py b/datashuttle/tui/screens/setup_aws.py new file mode 100644 index 00000000..e69de29b diff --git a/datashuttle/tui/screens/setup_gdrive.py b/datashuttle/tui/screens/setup_gdrive.py new file mode 100644 index 00000000..87d2c3ed --- /dev/null +++ b/datashuttle/tui/screens/setup_gdrive.py @@ -0,0 +1,113 @@ +from __future__ import annotations + +import asyncio +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from textual.app import ComposeResult + from textual.worker import Worker + + from datashuttle.tui.interface import Interface + +from textual import work +from textual.containers import Container, Horizontal +from textual.screen import ModalScreen +from textual.widgets import ( + Button, + Static, +) + + +class SetupGdriveScreen(ModalScreen): + """ """ + + def __init__(self, interface: Interface) -> None: + super(SetupGdriveScreen, self).__init__() + + self.interface = interface + self.stage = 0 + self.setup_worker: Worker | None = None + + def compose(self) -> ComposeResult: + yield Container( + Horizontal( + Static( + "Ready to setup Google Drive. " "Press OK to proceed", + id="gdrive_setup_messagebox_message", + ), + id="gdrive_setup_messagebox_message_container", + ), + # Input(), + Horizontal( + Button("OK", id="setup_gdrive_ok_button"), + Button("Cancel", id="setup_gdrive_cancel_button"), + id="setup_gdrive_buttons_horizontal", + ), + id="setup_gdrive_screen_container", + ) + + def on_button_pressed(self, event: Button.Pressed) -> None: + """ """ + if ( + event.button.id == "setup_gdrive_cancel_button" + or event.button.id == "setup_gdrive_finish_button" + ): + if self.setup_worker and self.setup_worker.is_running: + self.setup_worker.cancel() # fix + self.dismiss() + + elif event.button.id == "setup_gdrive_ok_button": + self.ask_user_for_browser() + + elif event.button.id == "setup_gdrive_yes_button": + self.open_browser_and_setup_gdrive_connection() + + def ask_user_for_browser(self) -> None: + message = ( + "Are you running Datashuttle on a machine " + "that can open a web browser?" + ) + self.query_one("#gdrive_setup_messagebox_message").update(message) + + yes_button = Button("Yes", id="setup_gdrive_yes_button") + no_button = Button("No", id="setup_gdrive_no_button") + + self.query_one("#setup_gdrive_ok_button").remove() + self.query_one("#setup_gdrive_buttons_horizontal").mount( + yes_button, no_button, before="#setup_gdrive_cancel_button" + ) + + self.stage += 1 + + def open_browser_and_setup_gdrive_connection(self) -> None: + # TODO: ADD SOME SUCCESS, OUTPUT + message = "Please authenticate through browser." + self.query_one("#gdrive_setup_messagebox_message").update(message) + + self.query_one("#setup_gdrive_yes_button").remove() + self.query_one("#setup_gdrive_no_button").remove() + + async def _setup_gdrive_and_update_ui(): + worker = self.setup_gdrive_connection() + self.setup_worker = worker + if worker.is_running: + await worker.wait() + + # TODO : check if successful + self.show_finish_screen() + + asyncio.create_task(_setup_gdrive_and_update_ui()) + + @work(exclusive=True, thread=True) + def setup_gdrive_connection(self) -> Worker: + self.interface.project.setup_google_drive_connection() + self.stage += 1 + + def show_finish_screen(self) -> None: + message = "Setup Complete!" + self.query_one("#setup_gdrive_cancel_button").remove() + + self.query_one("#gdrive_setup_messagebox_message").update(message) + self.query_one("#setup_gdrive_buttons_horizontal").mount( + Button("Finish", id="setup_gdrive_finish_button") + ) diff --git a/datashuttle/tui/shared/configs_content.py b/datashuttle/tui/shared/configs_content.py index 98e365f8..91423404 100644 --- a/datashuttle/tui/shared/configs_content.py +++ b/datashuttle/tui/shared/configs_content.py @@ -26,7 +26,7 @@ from datashuttle.tui.custom_widgets import ClickableInput from datashuttle.tui.interface import Interface -from datashuttle.tui.screens import modal_dialogs, setup_ssh +from datashuttle.tui.screens import modal_dialogs, setup_gdrive, setup_ssh from datashuttle.tui.tooltips import get_tooltip @@ -95,12 +95,12 @@ def compose(self) -> ComposeResult: Label("Client ID", id="configs_gdrive_client_id_label"), ClickableInput( self.parent_class.mainwindow, - placeholder="", + placeholder="Google Drive Client ID (leave blank to use rclone's default client (slower))", id="configs_gdrive_client_id_input", ), Label("Client Secret", id="configs_gdrive_client_secret_label"), Input( - placeholder="", + placeholder="Google Drive Client Secret (leave blank to use rclone's default client (slower))", password=True, id="configs_gdrive_client_secret_input", ), @@ -427,7 +427,8 @@ def on_button_pressed(self, event: Button.Pressed) -> None: self.setup_ssh_connection() elif event.button.id == "configs_setup_gdrive_connection_button": - self.interface.project.setup_google_drive_connection() + self.setup_gdrive_connection() + elif event.button.id == "configs_go_to_project_screen_button": self.parent_class.dismiss(self.interface) @@ -496,6 +497,23 @@ def setup_ssh_connection(self) -> None: setup_ssh.SetupSshScreen(self.interface) ) + def setup_gdrive_connection(self) -> None: + """ + Set up the `SetupGdriveScreen` screen, + """ + assert self.interface is not None, "type narrow flexible `interface`" + + if not self.widget_configs_match_saved_configs(): + self.parent_class.mainwindow.show_modal_error_dialog( + "The values set above must equal the datashuttle settings. " + "Either press 'Save' or reload this page." + ) + return + + self.parent_class.mainwindow.push_screen( + setup_gdrive.SetupGdriveScreen(self.interface) + ) + def widget_configs_match_saved_configs(self): """ Check that the configs currently stored in the widgets @@ -792,8 +810,6 @@ def get_datashuttle_inputs_from_widgets(self) -> Dict: None if central_host_username == "" else central_host_username ) - # TODO : ADD CFG FOR GDRIVE CLIENT ID AND CLIENT SECRET - gdrive_client_id = self.query_one( "#configs_gdrive_client_id_input" ).value From 6698cb02eaa9b17d65ed0fb755f4e565975f4553 Mon Sep 17 00:00:00 2001 From: shrey Date: Sat, 29 Mar 2025 03:33:23 +0530 Subject: [PATCH 06/39] add aws as remote storage via python api first draft --- datashuttle/__init__.py | 1 + datashuttle/configs/canonical_configs.py | 17 +++++- datashuttle/configs/regions.py | 30 ++++++++++ datashuttle/datashuttle_class.py | 33 +++++++++++ datashuttle/tui/shared/configs_content.py | 21 ++++--- datashuttle/utils/decorators.py | 21 +++++++ datashuttle/utils/folders.py | 68 +++++++++++++++++++++- datashuttle/utils/gdrive.py | 70 ----------------------- datashuttle/utils/rclone.py | 20 +++++++ 9 files changed, 197 insertions(+), 84 deletions(-) create mode 100644 datashuttle/configs/regions.py diff --git a/datashuttle/__init__.py b/datashuttle/__init__.py index 501cb203..c9a66f3c 100644 --- a/datashuttle/__init__.py +++ b/datashuttle/__init__.py @@ -2,6 +2,7 @@ from datashuttle.datashuttle_class import DataShuttle from datashuttle.datashuttle_functions import quick_validate_project +from datashuttle.configs.regions import AWS_REGION try: diff --git a/datashuttle/configs/canonical_configs.py b/datashuttle/configs/canonical_configs.py index 9f1c58ae..8c7d4415 100644 --- a/datashuttle/configs/canonical_configs.py +++ b/datashuttle/configs/canonical_configs.py @@ -46,6 +46,9 @@ def get_canonical_configs() -> dict: "central_host_username": Optional[str], "gdrive_client_id": Optional[str], "gdrive_client_secret": Optional[str], + "aws_access_key_id": Optional[str], + "aws_s3_region": Optional[str], + # "aws_s3_endpoint_url": Optional[str], } return canonical_configs @@ -133,7 +136,7 @@ def check_dict_values_raise_on_fail(config_dict: Configs) -> None: ) # Check gdrive settings - if config_dict["connection_method"] == "gdrive" and ( + elif config_dict["connection_method"] == "gdrive" and ( ( config_dict["gdrive_client_id"] and not config_dict["gdrive_client_secret"] @@ -144,7 +147,17 @@ def check_dict_values_raise_on_fail(config_dict: Configs) -> None: ) ): utils.log_and_raise_error( - "Both gdrive_client_id and gdrive_client_secret must be present together", + "Both gdrive_client_id and gdrive_client_secret must be present together.", + ConfigError, + ) + + # Check AWS settings + elif config_dict["connection_method"] == "aws_s3" and ( + not config_dict["aws_access_key_id"] + or not config_dict["aws_s3_region"] + ): + utils.log_and_raise_error( + "Both aws_access_key_id and aws_s3_region must be present for AWS connection.", ConfigError, ) diff --git a/datashuttle/configs/regions.py b/datashuttle/configs/regions.py new file mode 100644 index 00000000..f45c825d --- /dev/null +++ b/datashuttle/configs/regions.py @@ -0,0 +1,30 @@ +from enum import Enum + + +class AWS_REGION(Enum): + US_EAST_1 = "us-east-1" + US_EAST_2 = "us-east-2" + US_WEST_1 = "us-west-1" + US_WEST_2 = "us-west-2" + CA_CENTRAL_1 = "ca-central-1" + EU_WEST_1 = "eu-west-1" + EU_WEST_2 = "eu-west-2" + EU_WEST_3 = "eu-west-3" + EU_NORTH_1 = "eu-north-1" + EU_SOUTH_1 = "eu_south-1" + EUI_CENTRAL_1 = "eu-central-1" + AP_SOUTHEAST_1 = "ap-southeast-1" + AP_SOUTHEAST_2 = "ap-southeast-2" + API_NORTHEAST_1 = "ap-northeast-1" + AP_NORTHEAST_2 = "ap-northeast-2" + AP_NORTHEAST_3 = "ap-northeast-3" + AP_SOUTH_1 = "ap-south-1" + AP_EAST_1 = "ap-east-1" + SA_EAST_1 = "sa-east-1" + IL_CENTRAL_1 = "il-central-1" + ME_SOUTH_1 = "me-south-1" + AF_SOUTH_1 = "af-south-1" + CN_NORTH_1 = "cn-north-1" + CN_NORTHWEST_1 = "cn-northwest-1" + US_GOV_EAST_1 = "us-gov-east-1" + US_GOV_WEST_1 = "us-gov-west-1" diff --git a/datashuttle/datashuttle_class.py b/datashuttle/datashuttle_class.py index 39d40ad1..c2d3dc4e 100644 --- a/datashuttle/datashuttle_class.py +++ b/datashuttle/datashuttle_class.py @@ -32,6 +32,7 @@ canonical_configs, canonical_folders, load_configs, + regions, ) from datashuttle.configs.config_class import Configs from datashuttle.datashuttle_functions import _format_top_level_folder @@ -53,6 +54,7 @@ from datashuttle.utils.decorators import ( # noqa check_configs_set, check_is_not_local_project, + requires_aws_configs, requires_ssh_configs, ) @@ -906,6 +908,21 @@ def setup_google_drive_connection(self) -> None: self._setup_rclone_gdrive_config(log=True) ds_logger.close_log_filehandler() + # ------------------------------------------------------------------------- + # AWS S3 + # ------------------------------------------------------------------------- + + @requires_aws_configs + @check_configs_set + def setup_aws_s3_connection(self, aws_secret_access_key: str) -> None: + self._start_log( + "setup-aws-s3-connection-to-central-server", + local_vars=locals(), + ) + + self._setup_rclone_aws_config(aws_secret_access_key, log=True) + ds_logger.close_log_filehandler() + # ------------------------------------------------------------------------- # Configs # ------------------------------------------------------------------------- @@ -919,6 +936,8 @@ def make_config_file( central_host_username: Optional[str] = None, gdrive_client_id: Optional[str] = None, gdrive_client_secret: Optional[str] = None, + aws_access_key_id: Optional[str] = None, + aws_s3_region: Optional[regions.AWS_REGION] = None, ) -> None: """ Initialise the configurations for datashuttle to use on the @@ -985,6 +1004,10 @@ def make_config_file( "central_host_username": central_host_username, "gdrive_client_id": gdrive_client_id, "gdrive_client_secret": gdrive_client_secret, + "aws_access_key_id": aws_access_key_id, + "aws_s3_region": ( + aws_s3_region.value if aws_s3_region else None + ), }, ) @@ -1493,6 +1516,16 @@ def _setup_rclone_gdrive_config(self, log: bool) -> None: self.cfg, self.cfg.get_rclone_config_name("gdrive"), log=log ) + def _setup_rclone_aws_config( + self, aws_secret_access_key: str, log: bool + ) -> None: + rclone.setup_rclone_config_for_aws_s3( + self.cfg, + aws_secret_access_key, + self.cfg.get_rclone_config_name("aws_s3"), + log=log, + ) + # Persistent settings # ------------------------------------------------------------------------- diff --git a/datashuttle/tui/shared/configs_content.py b/datashuttle/tui/shared/configs_content.py index 91423404..6012af7d 100644 --- a/datashuttle/tui/shared/configs_content.py +++ b/datashuttle/tui/shared/configs_content.py @@ -280,6 +280,7 @@ def on_radio_set_changed(self, event: RadioSet.Changed) -> None: "AWS S3", ], "Unexpected label." + connection_method = None if label == "No connection (local only)": self.query_one("#configs_central_path_input").value = "" self.query_one("#configs_central_path_input").disabled = True @@ -287,20 +288,21 @@ def on_radio_set_changed(self, event: RadioSet.Changed) -> None: True ) display_ssh = False - display_gdrive = False - display_aws = False else: self.query_one("#configs_central_path_input").disabled = False self.query_one("#configs_central_path_select_button").disabled = ( False ) display_ssh = True if label == "SSH" else False - display_gdrive = True if label == "Google Drive" else False - display_aws = True if label == "AWS S3" else False - self.switch_ssh_widgets_display(display_ssh) - self.switch_gdrive_widgets_display(display_gdrive) - self.switch_aws_widgets_display(display_aws) + if label == "SSH": + connection_method = "ssh" + elif label == "Google Drive": + connection_method = "gdrive" + elif label == "AWS S3": + connection_method = "aws_s3" + + self.setup_widgets_to_display(connection_method) self.set_central_path_input_tooltip(display_ssh) @@ -740,14 +742,15 @@ def setup_widgets_to_display(self, connection_method: str | None) -> None: ], "Unexpected Connection Method" if connection_method == "ssh": - self.switch_ssh_widgets_display(True) + # order matters -> fix this self.switch_gdrive_widgets_display(False) self.switch_aws_widgets_display(False) + self.switch_ssh_widgets_display(True) elif connection_method == "gdrive": self.switch_ssh_widgets_display(False) - self.switch_gdrive_widgets_display(True) self.switch_aws_widgets_display(False) + self.switch_gdrive_widgets_display(True) elif connection_method == "aws_s3": self.switch_ssh_widgets_display(False) diff --git a/datashuttle/utils/decorators.py b/datashuttle/utils/decorators.py index cacf5491..36e7d0d5 100644 --- a/datashuttle/utils/decorators.py +++ b/datashuttle/utils/decorators.py @@ -28,6 +28,27 @@ def wrapper(*args, **kwargs): return wrapper +def requires_aws_configs(func): + + @wraps(func) + def wrapper(*args, **kwargs): + if ( + not args[0].cfg["aws_access_key_id"] + or not args[0].cfg["aws_s3_region"] + ): + log_and_raise_error( + "Cannot setup AWS connection, 'aws_access_key_id' " + "or 'aws_s3_region' is not set in the " + "configuration file", + ConfigError, + ) + + else: + return func(*args, **kwargs) + + return wrapper + + def check_configs_set(func): """ Check that configs have been loaded (i.e. diff --git a/datashuttle/utils/folders.py b/datashuttle/utils/folders.py index 5e9b8e05..13abeee4 100644 --- a/datashuttle/utils/folders.py +++ b/datashuttle/utils/folders.py @@ -1,5 +1,7 @@ from __future__ import annotations +import json +import subprocess from typing import ( TYPE_CHECKING, Any, @@ -20,7 +22,7 @@ from pathlib import Path from datashuttle.configs import canonical_folders, canonical_tags -from datashuttle.utils import gdrive, ssh, utils, validation +from datashuttle.utils import ssh, utils, validation from datashuttle.utils.custom_exceptions import NeuroBlueprintError # ----------------------------------------------------------------------------- @@ -526,12 +528,17 @@ def search_for_folders( return_full_path, ) ) - elif cfg["connection_method"] == "gdrive": + + elif ( + cfg["connection_method"] == "gdrive" + or cfg["connection_method"] == "aws_s3" + ): all_folder_names, all_filenames = ( - gdrive.search_gdrive_central_for_folders( + search_remote_central_for_folders( search_path, search_prefix, cfg, verbose, return_full_path ) ) + else: if not search_path.exists(): if verbose: @@ -546,6 +553,61 @@ def search_for_folders( return all_folder_names, all_filenames +def search_remote_central_for_folders( + search_path: Path, + search_prefix: str, + cfg: Configs, + verbose: bool = True, + return_full_path: bool = False, +) -> Tuple[List[Any], List[Any]]: + + command = ( + "rclone lsjson " + f"{cfg.get_rclone_config_name()}:{search_path.as_posix()} " + f'--include "{search_prefix}"', + ) + output = subprocess.run( + command, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + shell=True, + ) + + all_folder_names: List[str] = [] + all_filenames: List[str] = [] + + if output.returncode != 0: + if verbose: + utils.log_and_message( + f"Error searching files at {search_path.as_posix()} \n {output.stderr.decode("utf-8") if output.stderr else ""}" + ) + return all_folder_names, all_filenames + + files_and_folders = json.loads(output.stdout) + + try: + for file_or_folder in files_and_folders: + name = file_or_folder["Name"] + is_dir = file_or_folder.get("IsDir", False) + + to_append = ( + (search_path / name).as_posix() if return_full_path else name + ) + + if is_dir: + all_folder_names.append(to_append) + else: + all_filenames.append(to_append) + + except Exception: + if verbose: + utils.log_and_message( + f"Error searching files at {search_path.to_posix()}" + ) + + return all_folder_names, all_filenames + + # Actual function implementation def search_filesystem_path_for_folders( search_path_with_prefix: Path, return_full_path: bool = False diff --git a/datashuttle/utils/gdrive.py b/datashuttle/utils/gdrive.py index 539f2918..e69de29b 100644 --- a/datashuttle/utils/gdrive.py +++ b/datashuttle/utils/gdrive.py @@ -1,70 +0,0 @@ -from __future__ import annotations - -import json -import subprocess -from typing import TYPE_CHECKING - -if TYPE_CHECKING: - from pathlib import Path - - from datashuttle.configs.config_class import Configs - -from typing import Any, List, Tuple - -from datashuttle.utils import utils - - -# Generic function -def search_gdrive_central_for_folders( - search_path: Path, - search_prefix: str, - cfg: Configs, - verbose: bool = True, - return_full_path: bool = False, -) -> Tuple[List[Any], List[Any]]: - - command = ( - "rclone lsjson " - f"{cfg.get_rclone_config_name()}:{search_path.as_posix()} " - f'--include "{search_prefix}"', - ) - output = subprocess.run( - command, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - shell=True, - ) - - all_folder_names: List[str] = [] - all_filenames: List[str] = [] - - if output.returncode != 0: - if verbose: - utils.log_and_message( - f"Error searching files at {search_path.as_posix()} \n {output.stderr.decode("utf-8") if output.stderr else ""}" - ) - return all_folder_names, all_filenames - - files_and_folders = json.loads(output.stdout) - - try: - for file_or_folder in files_and_folders: - name = file_or_folder["Name"] - is_dir = file_or_folder.get("IsDir", False) - - to_append = ( - (search_path / name).as_posix() if return_full_path else name - ) - - if is_dir: - all_folder_names.append(to_append) - else: - all_filenames.append(to_append) - - except Exception: - if verbose: - utils.log_and_message( - f"Error searching files at {search_path.to_posix()}" - ) - - return all_folder_names, all_filenames diff --git a/datashuttle/utils/rclone.py b/datashuttle/utils/rclone.py index bf3d767d..4b5f637b 100644 --- a/datashuttle/utils/rclone.py +++ b/datashuttle/utils/rclone.py @@ -179,6 +179,26 @@ def setup_rclone_config_for_gdrive( log_rclone_config_output() +def setup_rclone_config_for_aws_s3( + cfg: Configs, + aws_secret_access_key: str, + rclone_config_name: str, + log: bool = True, +): + call_rclone( + "config create " + f"{rclone_config_name} " + "s3 provider AWS " + f"access_key_id {cfg['aws_access_key_id']} " + f"secret_access_key {aws_secret_access_key} " + f"region {cfg['aws_s3_region']} " + f"location_constraint {cfg['aws_s3_region']}" + ) + + if log: + log_rclone_config_output() + + def log_rclone_config_output(): output = call_rclone("config file", pipe_std=True) utils.log( From 60b77aa9f912c8c007a4063b7eb441eeb8963db4 Mon Sep 17 00:00:00 2001 From: shrey Date: Sat, 29 Mar 2025 21:49:51 +0530 Subject: [PATCH 07/39] add: logging and connection check for aws s3 --- datashuttle/datashuttle_class.py | 7 +++++ datashuttle/utils/aws.py | 50 ++++++++++++++++++++++++++++++++ datashuttle/utils/rclone.py | 3 +- 3 files changed, 59 insertions(+), 1 deletion(-) create mode 100644 datashuttle/utils/aws.py diff --git a/datashuttle/datashuttle_class.py b/datashuttle/datashuttle_class.py index c2d3dc4e..9659531d 100644 --- a/datashuttle/datashuttle_class.py +++ b/datashuttle/datashuttle_class.py @@ -37,6 +37,7 @@ from datashuttle.configs.config_class import Configs from datashuttle.datashuttle_functions import _format_top_level_folder from datashuttle.utils import ( + aws, ds_logger, folders, formatting, @@ -921,6 +922,12 @@ def setup_aws_s3_connection(self, aws_secret_access_key: str) -> None: ) self._setup_rclone_aws_config(aws_secret_access_key, log=True) + + aws.check_successful_connection(self.cfg) + utils.log_and_message("AWS Connection Successful.") + + aws.warn_if_bucket_absent(self.cfg) + ds_logger.close_log_filehandler() # ------------------------------------------------------------------------- diff --git a/datashuttle/utils/aws.py b/datashuttle/utils/aws.py new file mode 100644 index 00000000..23a0a972 --- /dev/null +++ b/datashuttle/utils/aws.py @@ -0,0 +1,50 @@ +import json + +from datashuttle.configs.config_class import Configs +from datashuttle.utils import rclone, utils + + +def check_successful_connection(cfg: Configs) -> None: + """Check for a successful connection by executing an `ls` command""" + + output = rclone.call_rclone( + f"ls {cfg.get_rclone_config_name()}:", pipe_std=True + ) + + if output.returncode != 0: + utils.log_and_raise_error( + output.stderr.decode("utf-8"), ConnectionError + ) + + +def check_if_aws_bucket_exists(cfg: Configs) -> bool: + output = rclone.call_rclone( + f"lsjson {cfg.get_rclone_config_name()}:", pipe_std=True + ) + + files_and_folders = json.loads(output.stdout) + + names = list(map(lambda x: x.get("Name", None), files_and_folders)) + + bucket_name = cfg["central_path"].as_posix().strip("/").split("/")[0] + + if bucket_name not in names: + return False + + return True + + +# ----------------------------------------------------------------------------- +# For Python API +# ----------------------------------------------------------------------------- + + +def warn_if_bucket_absent(cfg: Configs) -> None: + + if not check_if_aws_bucket_exists(cfg): + bucket_name = cfg["central_path"].as_posix().strip("/").split("/")[0] + utils.print_message_to_user( + f'WARNING: The bucket "{bucket_name}" does not exist.\n' + f"For data transfer to happen, the bucket must exist.\n" + f"Please change the bucket name in the `central_path`. " + ) diff --git a/datashuttle/utils/rclone.py b/datashuttle/utils/rclone.py index 4b5f637b..8d2a0564 100644 --- a/datashuttle/utils/rclone.py +++ b/datashuttle/utils/rclone.py @@ -192,7 +192,8 @@ def setup_rclone_config_for_aws_s3( f"access_key_id {cfg['aws_access_key_id']} " f"secret_access_key {aws_secret_access_key} " f"region {cfg['aws_s3_region']} " - f"location_constraint {cfg['aws_s3_region']}" + f"location_constraint {cfg['aws_s3_region']}", + pipe_std=True, ) if log: From d9755da6b1f898cdf84e0bffccf6f2e733961f99 Mon Sep 17 00:00:00 2001 From: shrey Date: Sun, 30 Mar 2025 04:19:25 +0530 Subject: [PATCH 08/39] update: type checking for aws regions --- datashuttle/configs/canonical_configs.py | 3 ++- datashuttle/configs/regions.py | 13 +++++++++---- datashuttle/datashuttle_class.py | 7 ++----- 3 files changed, 13 insertions(+), 10 deletions(-) diff --git a/datashuttle/configs/canonical_configs.py b/datashuttle/configs/canonical_configs.py index 8c7d4415..1b7283ed 100644 --- a/datashuttle/configs/canonical_configs.py +++ b/datashuttle/configs/canonical_configs.py @@ -27,6 +27,7 @@ import typeguard +from datashuttle.configs.regions import AWS_REGION from datashuttle.utils import folders, utils from datashuttle.utils.custom_exceptions import ConfigError @@ -47,7 +48,7 @@ def get_canonical_configs() -> dict: "gdrive_client_id": Optional[str], "gdrive_client_secret": Optional[str], "aws_access_key_id": Optional[str], - "aws_s3_region": Optional[str], + "aws_s3_region": Optional[Literal[*AWS_REGION.get_all_regions()]], # "aws_s3_endpoint_url": Optional[str], } diff --git a/datashuttle/configs/regions.py b/datashuttle/configs/regions.py index f45c825d..b42ecf30 100644 --- a/datashuttle/configs/regions.py +++ b/datashuttle/configs/regions.py @@ -1,7 +1,4 @@ -from enum import Enum - - -class AWS_REGION(Enum): +class AWS_REGION: US_EAST_1 = "us-east-1" US_EAST_2 = "us-east-2" US_WEST_1 = "us-west-1" @@ -28,3 +25,11 @@ class AWS_REGION(Enum): CN_NORTHWEST_1 = "cn-northwest-1" US_GOV_EAST_1 = "us-gov-east-1" US_GOV_WEST_1 = "us-gov-west-1" + + @classmethod + def get_all_regions(cls): + return [ + value + for key, value in vars(cls).items() + if not key.startswith("__") and isinstance(value, str) + ] diff --git a/datashuttle/datashuttle_class.py b/datashuttle/datashuttle_class.py index 9659531d..1a01dc89 100644 --- a/datashuttle/datashuttle_class.py +++ b/datashuttle/datashuttle_class.py @@ -32,7 +32,6 @@ canonical_configs, canonical_folders, load_configs, - regions, ) from datashuttle.configs.config_class import Configs from datashuttle.datashuttle_functions import _format_top_level_folder @@ -944,7 +943,7 @@ def make_config_file( gdrive_client_id: Optional[str] = None, gdrive_client_secret: Optional[str] = None, aws_access_key_id: Optional[str] = None, - aws_s3_region: Optional[regions.AWS_REGION] = None, + aws_s3_region: Optional[str] = None, ) -> None: """ Initialise the configurations for datashuttle to use on the @@ -1012,9 +1011,7 @@ def make_config_file( "gdrive_client_id": gdrive_client_id, "gdrive_client_secret": gdrive_client_secret, "aws_access_key_id": aws_access_key_id, - "aws_s3_region": ( - aws_s3_region.value if aws_s3_region else None - ), + "aws_s3_region": aws_s3_region, }, ) From f3947a737962f8a2630e89686fc42e1f8558c387 Mon Sep 17 00:00:00 2001 From: shrey Date: Sun, 30 Mar 2025 04:29:31 +0530 Subject: [PATCH 09/39] add: save aws configs via TUI --- datashuttle/tui/css/tui_menu.tcss | 4 + datashuttle/tui/shared/configs_content.py | 98 ++++++++++++++++++----- 2 files changed, 84 insertions(+), 18 deletions(-) diff --git a/datashuttle/tui/css/tui_menu.tcss b/datashuttle/tui/css/tui_menu.tcss index e8d2d9fc..494c307e 100644 --- a/datashuttle/tui/css/tui_menu.tcss +++ b/datashuttle/tui/css/tui_menu.tcss @@ -242,6 +242,10 @@ MessageBox:light > #messagebox_top_container { padding: 0 0 2 0; } +#configs_aws_s3_region_select { + width: 70%; +} + /* This Horizontal wrapper container is necessary to make the config label and button align center */ diff --git a/datashuttle/tui/shared/configs_content.py b/datashuttle/tui/shared/configs_content.py index 6012af7d..f0474f58 100644 --- a/datashuttle/tui/shared/configs_content.py +++ b/datashuttle/tui/shared/configs_content.py @@ -21,9 +21,11 @@ Label, RadioButton, RadioSet, + Select, Static, ) +from datashuttle import AWS_REGION from datashuttle.tui.custom_widgets import ClickableInput from datashuttle.tui.interface import Interface from datashuttle.tui.screens import modal_dialogs, setup_gdrive, setup_ssh @@ -106,7 +108,19 @@ def compose(self) -> ComposeResult: ), ] - self.config_aws_s3_widgets = [] + self.config_aws_s3_widgets = [ + Label("AWS Access Key ID", id="configs_aws_access_key_id_label"), + ClickableInput( + self.parent_class.mainwindow, + placeholder="AWS Access Key ID eg. EJIBCLSIP2K2PQK3CDON", + id="configs_aws_access_key_id_input", + ), + Label("AWS S3 Region", id="configs_aws_s3_region_label"), + Select( + ((region, region) for region in AWS_REGION.get_all_regions()), + id="configs_aws_s3_region_select", + ), + ] config_screen_widgets = [ Label("Local Path", id="configs_local_path_label"), @@ -402,6 +416,13 @@ def switch_gdrive_widgets_display(self, display_gdrive: bool) -> None: def switch_aws_widgets_display(self, display_aws: bool) -> None: + for widget in self.config_aws_s3_widgets: + widget.display = display_aws + + self.query_one("#configs_central_path_select_button").display = ( + not display_aws + ) + if ( self.interface is None or self.interface.get_configs()["connection_method"] != "aws_s3" @@ -566,7 +587,12 @@ def setup_configs_for_a_new_project(self) -> None: self.query_one("#configs_go_to_project_screen_button").visible = ( True ) - + message_template = ( + "A datashuttle project has now been created.\n\n " + "Next, setup the {method_name} connection. Once complete, navigate to the " + "'Main Menu' and proceed to the project page, where you will be " + "able to create and transfer project folders." + ) # Could not find a neater way to combine the push screen # while initiating the callback in one case but not the other. if cfg_kwargs["connection_method"] == "ssh": @@ -578,12 +604,8 @@ def setup_configs_for_a_new_project(self) -> None: "#configs_setup_ssh_connection_button" ).disabled = False - message = ( - "A datashuttle project has now been created.\n\n " - "Next, setup the SSH connection. Once complete, navigate to the " - "'Main Menu' and proceed to the project page, where you will be " - "able to create and transfer project folders." - ) + message = message_template.format(method_name="SSH") + elif cfg_kwargs["connection_method"] == "gdrive": self.query_one( @@ -593,12 +615,19 @@ def setup_configs_for_a_new_project(self) -> None: "#configs_setup_gdrive_connection_button" ).disabled = False - message = ( - "A datashuttle project has now been created.\n\n " - "Next, setup the Google Drive connection. Once complete, navigate to the " - "'Main Menu' and proceed to the project page, where you will be " - "able to create and transfer project folders." - ) + message = message_template.format(method_name="Google Drive") + + elif cfg_kwargs["connection_method"] == "aws_s3": + + self.query_one( + "#configs_setup_aws_connection_button" + ).visible = True + self.query_one( + "#configs_setup_aws_connection_button" + ).disabled = False + + message = message_template.format(method_name="AWS") + else: message = ( "A datashuttle project has now been created.\n\n " @@ -731,6 +760,24 @@ def fill_widgets_with_project_configs(self) -> None: ) input.value = value + # AWS Access Key ID + input = self.query_one("#configs_aws_access_key_id_input") + value = ( + "" + if cfg_to_load.get("aws_access_key_id", None) is None + else cfg_to_load["aws_access_key_id"] + ) + input.value = value + + # AWS S3 Region + select = self.query_one("#configs_aws_s3_region_select") + value = ( + Select.BLANK + if cfg_to_load.get("aws_s3_region", None) is None + else cfg_to_load["aws_s3_region"] + ) + select.value = value + def setup_widgets_to_display(self, connection_method: str | None) -> None: if connection_method: @@ -798,6 +845,7 @@ def get_datashuttle_inputs_from_widgets(self) -> Dict: cfg_kwargs["connection_method"] = connection_method + # SSH specific central_host_id = self.query_one( "#configs_central_host_id_input" ).value @@ -813,18 +861,32 @@ def get_datashuttle_inputs_from_widgets(self) -> Dict: None if central_host_username == "" else central_host_username ) + # Google Drive specific gdrive_client_id = self.query_one( "#configs_gdrive_client_id_input" ).value - gdrive_client_secret = self.query_one( - "#configs_gdrive_client_secret_input" - ).value - cfg_kwargs["gdrive_client_id"] = ( None if gdrive_client_id == "" else gdrive_client_id ) + + gdrive_client_secret = self.query_one( + "#configs_gdrive_client_secret_input" + ).value cfg_kwargs["gdrive_client_secret"] = ( None if gdrive_client_secret == "" else gdrive_client_secret ) + # AWS specific + aws_access_key_id = self.query_one( + "#configs_aws_access_key_id_input" + ).value + cfg_kwargs["aws_access_key_id"] = ( + None if aws_access_key_id == "" else aws_access_key_id + ) + + aws_s3_region = self.query_one("#configs_aws_s3_region_select").value + cfg_kwargs["aws_s3_region"] = ( + None if aws_s3_region == Select.BLANK else aws_s3_region + ) + return cfg_kwargs From b328af91520e60dd4f61093f8668660c68a4df4a Mon Sep 17 00:00:00 2001 From: shrey Date: Sun, 30 Mar 2025 17:53:48 +0530 Subject: [PATCH 10/39] add: setup aws connection via TUI --- datashuttle/tui/css/tui_menu.tcss | 36 ++++++++- datashuttle/tui/interface.py | 12 +++ datashuttle/tui/screens/setup_aws.py | 98 +++++++++++++++++++++++ datashuttle/tui/shared/configs_content.py | 24 +++++- 4 files changed, 168 insertions(+), 2 deletions(-) diff --git a/datashuttle/tui/css/tui_menu.tcss b/datashuttle/tui/css/tui_menu.tcss index 494c307e..8fcd9fd6 100644 --- a/datashuttle/tui/css/tui_menu.tcss +++ b/datashuttle/tui/css/tui_menu.tcss @@ -66,7 +66,10 @@ GetHelpScreen { } SetupGdriveScreen { align: center middle; - } +} +SetupAwsScreen { + align: center middle; +} #get_help_label { align: center middle; text-align: center; @@ -148,6 +151,37 @@ MessageBox:light > #messagebox_top_container { align: center middle; } +#setup_aws_screen_container { + height: 75%; + width: 80%; + align: center middle; + border: thick $panel-lighten-1; +} + +#setup_aws_messagebox_message_container { + align: center middle; + overflow: hidden auto; + margin: 0 1; +} + +#setup_aws_messagebox_message { + text-align: center; + padding: 0 2; +} + +#setup_aws_ok_button { + align: center bottom; + height: 3; +} + +#setup_aws_cancel_button { + align: center bottom; +} + +#setup_aws_buttons_horizontal { + align: center middle; +} + /* Configs Content ----------------------------------------------------------------- */ #configs_container { diff --git a/datashuttle/tui/interface.py b/datashuttle/tui/interface.py index e9520bb0..f34141f7 100644 --- a/datashuttle/tui/interface.py +++ b/datashuttle/tui/interface.py @@ -493,3 +493,15 @@ def setup_key_pair_and_rclone_config( except BaseException as e: return False, str(e) + + # Setup AWS + # ---------------------------------------------------------------------------------- + + def setup_aws_connection( + self, aws_secret_access_key: str + ) -> InterfaceOutput: + try: + self.project.setup_aws_s3_connection(aws_secret_access_key) + return True, None + except BaseException as e: + return False, str(e) diff --git a/datashuttle/tui/screens/setup_aws.py b/datashuttle/tui/screens/setup_aws.py index e69de29b..d216d751 100644 --- a/datashuttle/tui/screens/setup_aws.py +++ b/datashuttle/tui/screens/setup_aws.py @@ -0,0 +1,98 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from textual.app import ComposeResult + + from datashuttle.tui.interface import Interface + +from textual.containers import Container, Horizontal +from textual.screen import ModalScreen +from textual.widgets import Button, Input, Static + + +class SetupAwsScreen(ModalScreen): + """ + This dialog window handles the TUI equivalent of API's + `setup_aws_s3_connection()`. + """ + + def __init__(self, interface: Interface) -> None: + super(SetupAwsScreen, self).__init__() + + self.interface = interface + self.stage = 0 + + def compose(self) -> ComposeResult: + yield Container( + Horizontal( + Static( + "Ready to setup AWS connection. " "Press OK to proceed", + id="setup_aws_messagebox_message", + ), + id="setup_aws_messagebox_message_container", + ), + Input(password=True, id="setup_aws_secret_access_key_input"), + Horizontal( + Button("OK", id="setup_aws_ok_button"), + Button("Cancel", id="setup_aws_cancel_button"), + id="setup_aws_buttons_horizontal", + ), + id="setup_aws_screen_container", + ) + + def on_mount(self) -> None: + self.query_one("#setup_aws_secret_access_key_input").visible = False + + def on_button_pressed(self, event: Button.Pressed) -> None: + """ """ + if event.button.id == "setup_aws_cancel_button": + self.dismiss() + + if event.button.id == "setup_aws_ok_button": + if self.stage == 0: + self.prompt_user_for_aws_secret_access_key() + + elif self.stage == 1: + self.use_secret_access_key_to_setup_aws_connection() + + elif self.stage == 2: + self.dismiss() + + def prompt_user_for_aws_secret_access_key(self) -> None: + message = "Please Enter your AWS Secret Access Key" + + self.query_one("#setup_aws_messagebox_message").update(message) + self.query_one("#setup_aws_secret_access_key_input").visible = True + + self.stage += 1 + + def use_secret_access_key_to_setup_aws_connection(self) -> None: + secret_access_key = self.query_one( + "#setup_aws_secret_access_key_input" + ).value + + success, output = self.interface.setup_aws_connection( + secret_access_key + ) + + if success: + message = "AWS Connection Successful!" + self.query_one("#setup_aws_secret_access_key_input").visible = ( + False + ) + + else: + message = ( + f"AWS setup failed. Please check your configs and secret access key" + f"\n\n Traceback: {output}" + ) + self.query_one("#setup_aws_secret_access_key_input").disabled = ( + True + ) + + self.query_one("#setup_aws_ok_button").label = "Finish" + self.query_one("#setup_aws_messagebox_message").update(message) + self.query_one("#setup_aws_cancel_button").disabled = True + self.stage += 1 diff --git a/datashuttle/tui/shared/configs_content.py b/datashuttle/tui/shared/configs_content.py index f0474f58..b7cbcce5 100644 --- a/datashuttle/tui/shared/configs_content.py +++ b/datashuttle/tui/shared/configs_content.py @@ -28,7 +28,12 @@ from datashuttle import AWS_REGION from datashuttle.tui.custom_widgets import ClickableInput from datashuttle.tui.interface import Interface -from datashuttle.tui.screens import modal_dialogs, setup_gdrive, setup_ssh +from datashuttle.tui.screens import ( + modal_dialogs, + setup_aws, + setup_gdrive, + setup_ssh, +) from datashuttle.tui.tooltips import get_tooltip @@ -452,6 +457,9 @@ def on_button_pressed(self, event: Button.Pressed) -> None: elif event.button.id == "configs_setup_gdrive_connection_button": self.setup_gdrive_connection() + elif event.button.id == "configs_setup_aws_connection_button": + self.setup_aws_connection() + elif event.button.id == "configs_go_to_project_screen_button": self.parent_class.dismiss(self.interface) @@ -537,6 +545,20 @@ def setup_gdrive_connection(self) -> None: setup_gdrive.SetupGdriveScreen(self.interface) ) + def setup_aws_connection(self) -> None: + assert self.interface is not None, "type narrow flexible `interface`" + + if not self.widget_configs_match_saved_configs(): + self.parent_class.mainwindow.show_modal_error_dialog( + "The values set above must equal the datashuttle settings. " + "Either press 'Save' or reload this page." + ) + return + + self.parent_class.mainwindow.push_screen( + setup_aws.SetupAwsScreen(self.interface) + ) + def widget_configs_match_saved_configs(self): """ Check that the configs currently stored in the widgets From d2a4d835f7baa790fa8f494670474c5fb783dc65 Mon Sep 17 00:00:00 2001 From: shrey Date: Tue, 1 Apr 2025 03:59:35 +0530 Subject: [PATCH 11/39] feat: setup google drive on machines with no browser --- datashuttle/datashuttle_class.py | 17 +++++-- datashuttle/tui/css/tui_menu.tcss | 1 + datashuttle/tui/interface.py | 24 ++++++++- datashuttle/tui/screens/setup_gdrive.py | 67 +++++++++++++++++++++++-- datashuttle/utils/gdrive.py | 65 ++++++++++++++++++++++++ datashuttle/utils/rclone.py | 12 ++++- 6 files changed, 177 insertions(+), 9 deletions(-) diff --git a/datashuttle/datashuttle_class.py b/datashuttle/datashuttle_class.py index 1a01dc89..359b793a 100644 --- a/datashuttle/datashuttle_class.py +++ b/datashuttle/datashuttle_class.py @@ -40,6 +40,7 @@ ds_logger, folders, formatting, + gdrive, getters, rclone, ssh, @@ -904,8 +905,15 @@ def setup_google_drive_connection(self) -> None: "setup-google-drive-connection-to-central-server", local_vars=locals(), ) + browser_available = gdrive.ask_user_for_browser() + config_token = None - self._setup_rclone_gdrive_config(log=True) + if not browser_available: + config_token = gdrive.prompt_and_get_config_token( + self.cfg, self.cfg.get_rclone_config_name("gdrive") + ) + + self._setup_rclone_gdrive_config(config_token, log=True) ds_logger.close_log_filehandler() # ------------------------------------------------------------------------- @@ -1515,9 +1523,12 @@ def _setup_rclone_central_local_filesystem_config(self) -> None: self.cfg.get_rclone_config_name("local_filesystem"), ) - def _setup_rclone_gdrive_config(self, log: bool) -> None: + def _setup_rclone_gdrive_config(self, config_token, log: bool) -> None: rclone.setup_rclone_config_for_gdrive( - self.cfg, self.cfg.get_rclone_config_name("gdrive"), log=log + self.cfg, + self.cfg.get_rclone_config_name("gdrive"), + config_token, + log=log, ) def _setup_rclone_aws_config( diff --git a/datashuttle/tui/css/tui_menu.tcss b/datashuttle/tui/css/tui_menu.tcss index 8fcd9fd6..9e10158d 100644 --- a/datashuttle/tui/css/tui_menu.tcss +++ b/datashuttle/tui/css/tui_menu.tcss @@ -128,6 +128,7 @@ MessageBox:light > #messagebox_top_container { } #gdrive_setup_messagebox_message_container { + height: 70%; align: center middle; overflow: hidden auto; margin: 0 1; diff --git a/datashuttle/tui/interface.py b/datashuttle/tui/interface.py index f34141f7..bcafd809 100644 --- a/datashuttle/tui/interface.py +++ b/datashuttle/tui/interface.py @@ -11,7 +11,7 @@ from datashuttle import DataShuttle from datashuttle.configs import load_configs -from datashuttle.utils import ssh +from datashuttle.utils import gdrive, ssh class Interface: @@ -494,6 +494,28 @@ def setup_key_pair_and_rclone_config( except BaseException as e: return False, str(e) + # Setup Google Drive + # ---------------------------------------------------------------------------------- + + def setup_google_drive_connection( + self, config_token: Optional[str] = None + ) -> InterfaceOutput: + try: + self.project._setup_rclone_gdrive_config(config_token, log=True) + return True, None + except BaseException as e: + return False, str(e) + + def get_rclone_message_for_gdrive_without_browser(self): + try: + output = gdrive.preliminary_for_setup_without_browser( + self.project.cfg, + self.project.cfg.get_rclone_config_name("gdrive"), + ) + return True, output + except BaseException as e: + return False, str(e) + # Setup AWS # ---------------------------------------------------------------------------------- diff --git a/datashuttle/tui/screens/setup_gdrive.py b/datashuttle/tui/screens/setup_gdrive.py index 87d2c3ed..1e9020ac 100644 --- a/datashuttle/tui/screens/setup_gdrive.py +++ b/datashuttle/tui/screens/setup_gdrive.py @@ -1,7 +1,7 @@ from __future__ import annotations import asyncio -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Optional if TYPE_CHECKING: from textual.app import ComposeResult @@ -14,6 +14,7 @@ from textual.screen import ModalScreen from textual.widgets import ( Button, + Input, Static, ) @@ -62,6 +63,12 @@ def on_button_pressed(self, event: Button.Pressed) -> None: elif event.button.id == "setup_gdrive_yes_button": self.open_browser_and_setup_gdrive_connection() + elif event.button.id == "setup_gdrive_no_button": + self.prompt_user_for_config_token() + + elif event.button.id == "setup_gdrive_enter_button": + self.setup_gdrive_connection_using_config_token() + def ask_user_for_browser(self) -> None: message = ( "Are you running Datashuttle on a machine " @@ -98,9 +105,60 @@ async def _setup_gdrive_and_update_ui(): asyncio.create_task(_setup_gdrive_and_update_ui()) + def prompt_user_for_config_token(self) -> None: + + self.query_one("#setup_gdrive_yes_button").remove() + self.query_one("#setup_gdrive_no_button").remove() + + success, message = ( + self.interface.get_rclone_message_for_gdrive_without_browser() + ) + + if not success: + self.display_failed() + return + + self.query_one("#gdrive_setup_messagebox_message").update( + message + "\nPress shift+click to copy." + ) + + enter_button = Button("Enter", id="setup_gdrive_enter_button") + input_box = Input(id="setup_gdrive_config_token_input") + + self.query_one("#setup_gdrive_buttons_horizontal").mount( + enter_button, before="#setup_gdrive_cancel_button" + ) + self.query_one("#setup_gdrive_screen_container").mount( + input_box, before="#setup_gdrive_buttons_horizontal" + ) + + def setup_gdrive_connection_using_config_token(self) -> None: + + self.query_one("#setup_gdrive_config_token_input").disabled = True + + enter_button = self.query_one("#setup_gdrive_enter_button") + enter_button.disabled = True + + config_token = self.query_one("#setup_gdrive_config_token_input").value + + async def _setup_gdrive_and_update_ui(): + worker = self.setup_gdrive_connection(config_token) + self.setup_worker = worker + if worker.is_running: + await worker.wait() + + enter_button.remove() + + # TODO : check if successful + self.show_finish_screen() + + asyncio.create_task(_setup_gdrive_and_update_ui()) + @work(exclusive=True, thread=True) - def setup_gdrive_connection(self) -> Worker: - self.interface.project.setup_google_drive_connection() + def setup_gdrive_connection( + self, config_token: Optional[str] = None + ) -> Worker: + self.interface.setup_google_drive_connection(config_token) self.stage += 1 def show_finish_screen(self) -> None: @@ -111,3 +169,6 @@ def show_finish_screen(self) -> None: self.query_one("#setup_gdrive_buttons_horizontal").mount( Button("Finish", id="setup_gdrive_finish_button") ) + + def display_failed(self) -> None: + pass diff --git a/datashuttle/utils/gdrive.py b/datashuttle/utils/gdrive.py index e69de29b..5143323b 100644 --- a/datashuttle/utils/gdrive.py +++ b/datashuttle/utils/gdrive.py @@ -0,0 +1,65 @@ +import json + +from datashuttle.configs.config_class import Configs +from datashuttle.utils import rclone, utils + + +def preliminary_for_setup_without_browser( + cfg: Configs, rclone_config_name: str +): + client_id_key_value = ( + f"client_id {cfg['gdrive_client_id']} " + if cfg.get("gdrive_client_id", None) + else " " + ) + client_secret_key_value = ( + f"client_secret {cfg['gdrive_client_secret']} " + if cfg.get("gdrive_client_secret", None) + else "" + ) + output = rclone.call_rclone( + f"config create " + f"{rclone_config_name} " + f"drive " + f"{client_id_key_value}" + f"{client_secret_key_value}" + f"scope drive " + f"config_is_local=false " + f"--non-interactive", + pipe_std=True, + ) + + # TODO: make this more robust + output_json = json.loads(output.stdout) + return output_json["Option"]["Help"] + + +# ----------------------------------------------------------------------------- +# Python API +# ----------------------------------------------------------------------------- + + +def ask_user_for_browser() -> bool: + message = "Are you running Datashuttle on a machine with access to a web browser? (y/n): " + input_ = utils.get_user_input(message).lower() + + while input_ not in ["y", "n"]: + utils.print_message_to_user("Invalid input. Press either 'y' or 'n'.") + input_ = utils.get_user_input(message).lower() + + if input_ == "y": + answer = True + else: + answer = False + + # TODO: Add logging here + + return answer + + +def prompt_and_get_config_token(cfg: Configs, rclone_config_name: str) -> str: + message = preliminary_for_setup_without_browser(cfg, rclone_config_name) + + input_ = utils.get_user_input(message).strip() + + return input_ diff --git a/datashuttle/utils/rclone.py b/datashuttle/utils/rclone.py index 8d2a0564..ecadf6e2 100644 --- a/datashuttle/utils/rclone.py +++ b/datashuttle/utils/rclone.py @@ -4,7 +4,7 @@ import tempfile from pathlib import Path from subprocess import CompletedProcess -from typing import Dict, List, Literal +from typing import Dict, List, Literal, Optional from datashuttle.configs.config_class import Configs from datashuttle.utils import utils @@ -153,6 +153,7 @@ def setup_rclone_config_for_ssh( def setup_rclone_config_for_gdrive( cfg: Configs, rclone_config_name: str, + config_token: Optional[str] = None, log: bool = True, ): client_id_key_value = ( @@ -165,13 +166,20 @@ def setup_rclone_config_for_gdrive( if cfg.get("gdrive_client_secret", None) else "" ) + + extra_args = ( + f"config_is_local=false config_token={config_token}" + if config_token + else "" + ) call_rclone( f"config create " f"{rclone_config_name} " f"drive " f"{client_id_key_value}" f"{client_secret_key_value}" - f"scope drive", + f"scope drive " + f"{extra_args}", pipe_std=True, ) From d9ae8644a97ddeb2c5a6313e22f36679363f1e54 Mon Sep 17 00:00:00 2001 From: shrey Date: Tue, 1 Apr 2025 04:06:37 +0530 Subject: [PATCH 12/39] fix: minor bug --- datashuttle/utils/folders.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datashuttle/utils/folders.py b/datashuttle/utils/folders.py index 13abeee4..b91776a4 100644 --- a/datashuttle/utils/folders.py +++ b/datashuttle/utils/folders.py @@ -602,7 +602,7 @@ def search_remote_central_for_folders( except Exception: if verbose: utils.log_and_message( - f"Error searching files at {search_path.to_posix()}" + f"Error searching files at {search_path.as_posix()}" ) return all_folder_names, all_filenames From 965bb17ddfa688baf08d2054b7f06668851111fa Mon Sep 17 00:00:00 2001 From: shrey Date: Tue, 1 Apr 2025 04:27:18 +0530 Subject: [PATCH 13/39] fix: logical error --- datashuttle/utils/folders.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/datashuttle/utils/folders.py b/datashuttle/utils/folders.py index b91776a4..d16d8be4 100644 --- a/datashuttle/utils/folders.py +++ b/datashuttle/utils/folders.py @@ -517,7 +517,11 @@ def search_for_folders( verbose : If `True`, when a search folder cannot be found, a message will be printed with the missing path. """ - if local_or_central == "central": + if local_or_central == "central" and cfg["connection_method"] in [ + "ssh", + "gdrive", + "aws_s3", + ]: if cfg["connection_method"] == "ssh": all_folder_names, all_filenames = ( ssh.search_ssh_central_for_folders( From 377cea790d69732a5b4df3e78116ec1a070f9462 Mon Sep 17 00:00:00 2001 From: shrey Date: Wed, 2 Apr 2025 01:14:36 +0530 Subject: [PATCH 14/39] add: logging for google drive connections --- datashuttle/datashuttle_class.py | 6 ++++-- datashuttle/tui/interface.py | 3 ++- datashuttle/utils/gdrive.py | 23 ++++++++++++++++------- 3 files changed, 22 insertions(+), 10 deletions(-) diff --git a/datashuttle/datashuttle_class.py b/datashuttle/datashuttle_class.py index 359b793a..e96d7b15 100644 --- a/datashuttle/datashuttle_class.py +++ b/datashuttle/datashuttle_class.py @@ -905,12 +905,14 @@ def setup_google_drive_connection(self) -> None: "setup-google-drive-connection-to-central-server", local_vars=locals(), ) - browser_available = gdrive.ask_user_for_browser() + browser_available = gdrive.ask_user_for_browser(log=True) config_token = None if not browser_available: config_token = gdrive.prompt_and_get_config_token( - self.cfg, self.cfg.get_rclone_config_name("gdrive") + self.cfg, + self.cfg.get_rclone_config_name("gdrive"), + log=True, ) self._setup_rclone_gdrive_config(config_token, log=True) diff --git a/datashuttle/tui/interface.py b/datashuttle/tui/interface.py index bcafd809..19589b28 100644 --- a/datashuttle/tui/interface.py +++ b/datashuttle/tui/interface.py @@ -501,7 +501,7 @@ def setup_google_drive_connection( self, config_token: Optional[str] = None ) -> InterfaceOutput: try: - self.project._setup_rclone_gdrive_config(config_token, log=True) + self.project._setup_rclone_gdrive_config(config_token, log=False) return True, None except BaseException as e: return False, str(e) @@ -511,6 +511,7 @@ def get_rclone_message_for_gdrive_without_browser(self): output = gdrive.preliminary_for_setup_without_browser( self.project.cfg, self.project.cfg.get_rclone_config_name("gdrive"), + log=False, ) return True, output except BaseException as e: diff --git a/datashuttle/utils/gdrive.py b/datashuttle/utils/gdrive.py index 5143323b..4f73e2d9 100644 --- a/datashuttle/utils/gdrive.py +++ b/datashuttle/utils/gdrive.py @@ -5,7 +5,7 @@ def preliminary_for_setup_without_browser( - cfg: Configs, rclone_config_name: str + cfg: Configs, rclone_config_name: str, log: bool = True ): client_id_key_value = ( f"client_id {cfg['gdrive_client_id']} " @@ -31,7 +31,12 @@ def preliminary_for_setup_without_browser( # TODO: make this more robust output_json = json.loads(output.stdout) - return output_json["Option"]["Help"] + message = output_json["Option"]["Help"] + + if log: + utils.log(message) + + return message # ----------------------------------------------------------------------------- @@ -39,7 +44,7 @@ def preliminary_for_setup_without_browser( # ----------------------------------------------------------------------------- -def ask_user_for_browser() -> bool: +def ask_user_for_browser(log: bool = True) -> bool: message = "Are you running Datashuttle on a machine with access to a web browser? (y/n): " input_ = utils.get_user_input(message).lower() @@ -52,14 +57,18 @@ def ask_user_for_browser() -> bool: else: answer = False - # TODO: Add logging here + if log: + utils.log(message) return answer -def prompt_and_get_config_token(cfg: Configs, rclone_config_name: str) -> str: - message = preliminary_for_setup_without_browser(cfg, rclone_config_name) - +def prompt_and_get_config_token( + cfg: Configs, rclone_config_name: str, log: bool = True +) -> str: + message = preliminary_for_setup_without_browser( + cfg, rclone_config_name, log=log + ) input_ = utils.get_user_input(message).strip() return input_ From 409d448e6da6bb09be0427a729825b369cb3afdf Mon Sep 17 00:00:00 2001 From: shrey Date: Sat, 31 May 2025 04:52:02 +0530 Subject: [PATCH 15/39] refactor: move google drive client secret to be entered at runtime while setting up connection --- datashuttle/configs/canonical_configs.py | 16 ++---- datashuttle/datashuttle_class.py | 23 +++++++-- datashuttle/utils/aws.py | 8 ++- datashuttle/utils/folders.py | 33 ++++++------ datashuttle/utils/gdrive.py | 65 +++++++++++++++++++++--- datashuttle/utils/rclone.py | 7 +-- 6 files changed, 105 insertions(+), 47 deletions(-) diff --git a/datashuttle/configs/canonical_configs.py b/datashuttle/configs/canonical_configs.py index 1b7283ed..8b341f6e 100644 --- a/datashuttle/configs/canonical_configs.py +++ b/datashuttle/configs/canonical_configs.py @@ -46,7 +46,7 @@ def get_canonical_configs() -> dict: "central_host_id": Optional[str], "central_host_username": Optional[str], "gdrive_client_id": Optional[str], - "gdrive_client_secret": Optional[str], + # "gdrive_client_secret": Optional[str], "aws_access_key_id": Optional[str], "aws_s3_region": Optional[Literal[*AWS_REGION.get_all_regions()]], # "aws_s3_endpoint_url": Optional[str], @@ -138,18 +138,10 @@ def check_dict_values_raise_on_fail(config_dict: Configs) -> None: # Check gdrive settings elif config_dict["connection_method"] == "gdrive" and ( - ( - config_dict["gdrive_client_id"] - and not config_dict["gdrive_client_secret"] - ) - or ( - not config_dict["gdrive_client_id"] - and config_dict["gdrive_client_secret"] - ) + not config_dict["gdrive_client_id"] ): - utils.log_and_raise_error( - "Both gdrive_client_id and gdrive_client_secret must be present together.", - ConfigError, + utils.log_and_message( + "`gdrive_client_id` not found in config. default rlcone client will be used (slower)." ) # Check AWS settings diff --git a/datashuttle/datashuttle_class.py b/datashuttle/datashuttle_class.py index e96d7b15..0a41aebc 100644 --- a/datashuttle/datashuttle_class.py +++ b/datashuttle/datashuttle_class.py @@ -905,17 +905,26 @@ def setup_google_drive_connection(self) -> None: "setup-google-drive-connection-to-central-server", local_vars=locals(), ) + + if self.cfg["gdrive_client_id"]: + gdrive_client_secret = gdrive.get_client_secret() + else: + gdrive_client_secret = None + browser_available = gdrive.ask_user_for_browser(log=True) config_token = None if not browser_available: config_token = gdrive.prompt_and_get_config_token( self.cfg, + gdrive_client_secret, self.cfg.get_rclone_config_name("gdrive"), log=True, ) - self._setup_rclone_gdrive_config(config_token, log=True) + self._setup_rclone_gdrive_config( + gdrive_client_secret, config_token, log=True + ) ds_logger.close_log_filehandler() # ------------------------------------------------------------------------- @@ -932,7 +941,7 @@ def setup_aws_s3_connection(self, aws_secret_access_key: str) -> None: self._setup_rclone_aws_config(aws_secret_access_key, log=True) - aws.check_successful_connection(self.cfg) + aws.check_successful_connection_and_raise_error_on_fail(self.cfg) utils.log_and_message("AWS Connection Successful.") aws.warn_if_bucket_absent(self.cfg) @@ -951,7 +960,6 @@ def make_config_file( central_host_id: Optional[str] = None, central_host_username: Optional[str] = None, gdrive_client_id: Optional[str] = None, - gdrive_client_secret: Optional[str] = None, aws_access_key_id: Optional[str] = None, aws_s3_region: Optional[str] = None, ) -> None: @@ -1019,7 +1027,6 @@ def make_config_file( "central_host_id": central_host_id, "central_host_username": central_host_username, "gdrive_client_id": gdrive_client_id, - "gdrive_client_secret": gdrive_client_secret, "aws_access_key_id": aws_access_key_id, "aws_s3_region": aws_s3_region, }, @@ -1525,9 +1532,15 @@ def _setup_rclone_central_local_filesystem_config(self) -> None: self.cfg.get_rclone_config_name("local_filesystem"), ) - def _setup_rclone_gdrive_config(self, config_token, log: bool) -> None: + def _setup_rclone_gdrive_config( + self, + gdrive_client_secret: str | None, + config_token: str | None, + log: bool, + ) -> None: rclone.setup_rclone_config_for_gdrive( self.cfg, + gdrive_client_secret, self.cfg.get_rclone_config_name("gdrive"), config_token, log=log, diff --git a/datashuttle/utils/aws.py b/datashuttle/utils/aws.py index 23a0a972..2e6651ad 100644 --- a/datashuttle/utils/aws.py +++ b/datashuttle/utils/aws.py @@ -4,8 +4,12 @@ from datashuttle.utils import rclone, utils -def check_successful_connection(cfg: Configs) -> None: - """Check for a successful connection by executing an `ls` command""" +def check_successful_connection_and_raise_error_on_fail(cfg: Configs) -> None: + """ + Check for a successful connection by executing an `ls` command. It pings the + the central host to list files and folders in the root directory. + If the command fails, it raises a ConnectionError with the error message. + """ output = rclone.call_rclone( f"ls {cfg.get_rclone_config_name()}:", pipe_std=True diff --git a/datashuttle/utils/folders.py b/datashuttle/utils/folders.py index d16d8be4..1e1c9f21 100644 --- a/datashuttle/utils/folders.py +++ b/datashuttle/utils/folders.py @@ -533,14 +533,9 @@ def search_for_folders( ) ) - elif ( - cfg["connection_method"] == "gdrive" - or cfg["connection_method"] == "aws_s3" - ): - all_folder_names, all_filenames = ( - search_remote_central_for_folders( - search_path, search_prefix, cfg, verbose, return_full_path - ) + else: + all_folder_names, all_filenames = search_gdrive_or_aws_for_folders( + search_path, search_prefix, cfg, return_full_path ) else: @@ -557,13 +552,17 @@ def search_for_folders( return all_folder_names, all_filenames -def search_remote_central_for_folders( +def search_gdrive_or_aws_for_folders( search_path: Path, search_prefix: str, cfg: Configs, - verbose: bool = True, return_full_path: bool = False, ) -> Tuple[List[Any], List[Any]]: + """ + Searches for files and folders in central path using `rclone lsjson` command. + This command lists all the files and folders in the central path in a json format. + The json contains file/folder info about each file/folder like name, type, etc. + """ command = ( "rclone lsjson " @@ -581,10 +580,9 @@ def search_remote_central_for_folders( all_filenames: List[str] = [] if output.returncode != 0: - if verbose: - utils.log_and_message( - f"Error searching files at {search_path.as_posix()} \n {output.stderr.decode("utf-8") if output.stderr else ""}" - ) + utils.log_and_message( + f"Error searching files at {search_path.as_posix()} \n {output.stderr.decode('utf-8') if output.stderr else ""}" + ) return all_folder_names, all_filenames files_and_folders = json.loads(output.stdout) @@ -604,10 +602,9 @@ def search_remote_central_for_folders( all_filenames.append(to_append) except Exception: - if verbose: - utils.log_and_message( - f"Error searching files at {search_path.as_posix()}" - ) + utils.log_and_message( + f"Error searching files at {search_path.as_posix()}" + ) return all_folder_names, all_filenames diff --git a/datashuttle/utils/gdrive.py b/datashuttle/utils/gdrive.py index 4f73e2d9..56602897 100644 --- a/datashuttle/utils/gdrive.py +++ b/datashuttle/utils/gdrive.py @@ -1,20 +1,32 @@ +import getpass import json +import sys from datashuttle.configs.config_class import Configs from datashuttle.utils import rclone, utils +# ----------------------------------------------------------------------------- +# Helper Functions +# ----------------------------------------------------------------------------- + +# These functions are used by both API and TUI for setting up connections to google drive. + def preliminary_for_setup_without_browser( - cfg: Configs, rclone_config_name: str, log: bool = True -): + cfg: Configs, + gdrive_client_secret: str | None, + rclone_config_name: str, + log: bool = True, +) -> str: + # TODO: Add docstrings client_id_key_value = ( f"client_id {cfg['gdrive_client_id']} " - if cfg.get("gdrive_client_id", None) + if cfg["gdrive_client_id"] else " " ) client_secret_key_value = ( - f"client_secret {cfg['gdrive_client_secret']} " - if cfg.get("gdrive_client_secret", None) + f"client_secret {gdrive_client_secret} " + if gdrive_client_secret else "" ) output = rclone.call_rclone( @@ -59,16 +71,55 @@ def ask_user_for_browser(log: bool = True) -> bool: if log: utils.log(message) + utils.log(f"User answer: {answer}") return answer def prompt_and_get_config_token( - cfg: Configs, rclone_config_name: str, log: bool = True + cfg: Configs, + gdrive_client_secret: str | None, + rclone_config_name: str, + log: bool = True, ) -> str: + # TODO: Add docstrings message = preliminary_for_setup_without_browser( - cfg, rclone_config_name, log=log + cfg, gdrive_client_secret, rclone_config_name, log=log ) input_ = utils.get_user_input(message).strip() return input_ + + +def get_client_secret(log: bool = True) -> str: + if not sys.stdin.isatty(): + proceed = input( + "\nWARNING!\nThe next step is to enter a google drive client secret, but it is not possible\n" + "to hide your client secret while entering it in the current terminal.\n" + "This can occur if running the command in an IDE.\n\n" + "Press 'y' to proceed to client secret entry. " + "The characters will not be hidden!\n" + "Alternatively, run ssh setup after starting Python in your " + "system terminal \nrather than through an IDE: " + ) + if proceed != "y": + utils.print_message_to_user( + "Quitting google drive setup as 'y' not pressed." + ) + utils.log_and_raise_error( + "Google Drive setup aborted by user.", ConnectionAbortedError + ) + + gdrive_client_secret = input( + "Please enter your google drive client secret. Characters will not be hidden: " + ) + + else: + gdrive_client_secret = getpass.getpass( + "Please enter your google drive client secret: " + ) + + if log: + utils.log("Google Drive client secret entered by user.") + + return gdrive_client_secret.strip() diff --git a/datashuttle/utils/rclone.py b/datashuttle/utils/rclone.py index ecadf6e2..2db48e9b 100644 --- a/datashuttle/utils/rclone.py +++ b/datashuttle/utils/rclone.py @@ -152,18 +152,19 @@ def setup_rclone_config_for_ssh( def setup_rclone_config_for_gdrive( cfg: Configs, + gdrive_client_secret: str | None, rclone_config_name: str, config_token: Optional[str] = None, log: bool = True, ): client_id_key_value = ( f"client_id {cfg['gdrive_client_id']} " - if cfg.get("gdrive_client_id", None) + if cfg["gdrive_client_id"] else " " ) client_secret_key_value = ( - f"client_secret {cfg['gdrive_client_secret']} " - if cfg.get("gdrive_client_secret", None) + f"client_secret {gdrive_client_secret} " + if gdrive_client_secret else "" ) From 0b33b868a76795c509951dc9df5e5d19af580ba2 Mon Sep 17 00:00:00 2001 From: shrey Date: Sat, 31 May 2025 23:23:44 +0530 Subject: [PATCH 16/39] refactor: aws_regions.py; provide aws secret access key at runtime --- datashuttle/__init__.py | 2 +- datashuttle/configs/aws_regions.py | 88 +++++++++++++++++++++++ datashuttle/configs/canonical_configs.py | 4 +- datashuttle/configs/regions.py | 35 --------- datashuttle/datashuttle_class.py | 4 +- datashuttle/tui/interface.py | 3 +- datashuttle/tui/shared/configs_content.py | 4 +- datashuttle/utils/aws.py | 36 ++++++++++ datashuttle/utils/gdrive.py | 2 +- 9 files changed, 135 insertions(+), 43 deletions(-) create mode 100644 datashuttle/configs/aws_regions.py delete mode 100644 datashuttle/configs/regions.py diff --git a/datashuttle/__init__.py b/datashuttle/__init__.py index c9a66f3c..2b9121e2 100644 --- a/datashuttle/__init__.py +++ b/datashuttle/__init__.py @@ -2,7 +2,7 @@ from datashuttle.datashuttle_class import DataShuttle from datashuttle.datashuttle_functions import quick_validate_project -from datashuttle.configs.regions import AWS_REGION +from datashuttle.configs.aws_regions import AWS_REGION try: diff --git a/datashuttle/configs/aws_regions.py b/datashuttle/configs/aws_regions.py new file mode 100644 index 00000000..60e14beb --- /dev/null +++ b/datashuttle/configs/aws_regions.py @@ -0,0 +1,88 @@ +from typing import Dict, List + +# ----------------------------------------------------------------------------- +# AWS regions +# ----------------------------------------------------------------------------- + +# These function are used for type checking and providing intellisense to the developer + + +def get_aws_regions() -> Dict[str, str]: + aws_regions = { + "US_EAST_1": "us-east-1", + "US_EAST_2": "us-east-2", + "US_WEST_1": "us-west-1", + "US_WEST_2": "us-west-2", + "CA_CENTRAL_1": "ca-central-1", + "EU_WEST_1": "eu-west-1", + "EU_WEST_2": "eu-west-2", + "EU_WEST_3": "eu-west-3", + "EU_NORTH_1": "eu-north-1", + "EU_SOUTH_1": "eu-south-1", + "EU_CENTRAL_1": "eu-central-1", + "AP_SOUTHEAST_1": "ap-southeast-1", + "AP_SOUTHEAST_2": "ap-southeast-2", + "AP_NORTHEAST_1": "ap-northeast-1", + "AP_NORTHEAST_2": "ap-northeast-2", + "AP_NORTHEAST_3": "ap-northeast-3", + "AP_SOUTH_1": "ap-south-1", + "AP_EAST_1": "ap-east-1", + "SA_EAST_1": "sa-east-1", + "IL_CENTRAL_1": "il-central-1", + "ME_SOUTH_1": "me-south-1", + "AF_SOUTH_1": "af-south-1", + "CN_NORTH_1": "cn-north-1", + "CN_NORTHWEST_1": "cn-northwest-1", + "US_GOV_EAST_1": "us-gov-east-1", + "US_GOV_WEST_1": "us-gov-west-1", + } + return aws_regions + + +def get_aws_regions_list() -> List[str]: + return list(get_aws_regions().values()) + + +AWS_REGIONS_DICT = get_aws_regions() # runtime constant + + +class AWS_REGION: + """ + A class to represent AWS regions as constants. + It is used to provide intellisense for AWS regions in IDEs. + """ + + US_EAST_1 = AWS_REGIONS_DICT["US_EAST_1"] + US_EAST_2 = AWS_REGIONS_DICT["US_EAST_2"] + US_WEST_1 = AWS_REGIONS_DICT["US_WEST_1"] + US_WEST_2 = AWS_REGIONS_DICT["US_WEST_2"] + CA_CENTRAL_1 = AWS_REGIONS_DICT["CA_CENTRAL_1"] + EU_WEST_1 = AWS_REGIONS_DICT["EU_WEST_1"] + EU_WEST_2 = AWS_REGIONS_DICT["EU_WEST_2"] + EU_WEST_3 = AWS_REGIONS_DICT["EU_WEST_3"] + EU_NORTH_1 = AWS_REGIONS_DICT["EU_NORTH_1"] + EU_SOUTH_1 = AWS_REGIONS_DICT["EU_SOUTH_1"] + EU_CENTRAL_1 = AWS_REGIONS_DICT["EU_CENTRAL_1"] + AP_SOUTHEAST_1 = AWS_REGIONS_DICT["AP_SOUTHEAST_1"] + AP_SOUTHEAST_2 = AWS_REGIONS_DICT["AP_SOUTHEAST_2"] + AP_NORTHEAST_1 = AWS_REGIONS_DICT["AP_NORTHEAST_1"] + AP_NORTHEAST_2 = AWS_REGIONS_DICT["AP_NORTHEAST_2"] + AP_NORTHEAST_3 = AWS_REGIONS_DICT["AP_NORTHEAST_3"] + AP_SOUTH_1 = AWS_REGIONS_DICT["AP_SOUTH_1"] + AP_EAST_1 = AWS_REGIONS_DICT["AP_EAST_1"] + SA_EAST_1 = AWS_REGIONS_DICT["SA_EAST_1"] + IL_CENTRAL_1 = AWS_REGIONS_DICT["IL_CENTRAL_1"] + ME_SOUTH_1 = AWS_REGIONS_DICT["ME_SOUTH_1"] + AF_SOUTH_1 = AWS_REGIONS_DICT["AF_SOUTH_1"] + CN_NORTH_1 = AWS_REGIONS_DICT["CN_NORTH_1"] + CN_NORTHWEST_1 = AWS_REGIONS_DICT["CN_NORTHWEST_1"] + US_GOV_EAST_1 = AWS_REGIONS_DICT["US_GOV_EAST_1"] + US_GOV_WEST_1 = AWS_REGIONS_DICT["US_GOV_WEST_1"] + + @classmethod + def get_all_regions(cls): + return [ + value + for key, value in vars(cls).items() + if not key.startswith("__") and isinstance(value, str) + ] diff --git a/datashuttle/configs/canonical_configs.py b/datashuttle/configs/canonical_configs.py index 8b341f6e..17bf12a0 100644 --- a/datashuttle/configs/canonical_configs.py +++ b/datashuttle/configs/canonical_configs.py @@ -27,7 +27,7 @@ import typeguard -from datashuttle.configs.regions import AWS_REGION +from datashuttle.configs.aws_regions import get_aws_regions_list from datashuttle.utils import folders, utils from datashuttle.utils.custom_exceptions import ConfigError @@ -48,7 +48,7 @@ def get_canonical_configs() -> dict: "gdrive_client_id": Optional[str], # "gdrive_client_secret": Optional[str], "aws_access_key_id": Optional[str], - "aws_s3_region": Optional[Literal[*AWS_REGION.get_all_regions()]], + "aws_s3_region": Optional[Literal[*get_aws_regions_list()]], # "aws_s3_endpoint_url": Optional[str], } diff --git a/datashuttle/configs/regions.py b/datashuttle/configs/regions.py deleted file mode 100644 index b42ecf30..00000000 --- a/datashuttle/configs/regions.py +++ /dev/null @@ -1,35 +0,0 @@ -class AWS_REGION: - US_EAST_1 = "us-east-1" - US_EAST_2 = "us-east-2" - US_WEST_1 = "us-west-1" - US_WEST_2 = "us-west-2" - CA_CENTRAL_1 = "ca-central-1" - EU_WEST_1 = "eu-west-1" - EU_WEST_2 = "eu-west-2" - EU_WEST_3 = "eu-west-3" - EU_NORTH_1 = "eu-north-1" - EU_SOUTH_1 = "eu_south-1" - EUI_CENTRAL_1 = "eu-central-1" - AP_SOUTHEAST_1 = "ap-southeast-1" - AP_SOUTHEAST_2 = "ap-southeast-2" - API_NORTHEAST_1 = "ap-northeast-1" - AP_NORTHEAST_2 = "ap-northeast-2" - AP_NORTHEAST_3 = "ap-northeast-3" - AP_SOUTH_1 = "ap-south-1" - AP_EAST_1 = "ap-east-1" - SA_EAST_1 = "sa-east-1" - IL_CENTRAL_1 = "il-central-1" - ME_SOUTH_1 = "me-south-1" - AF_SOUTH_1 = "af-south-1" - CN_NORTH_1 = "cn-north-1" - CN_NORTHWEST_1 = "cn-northwest-1" - US_GOV_EAST_1 = "us-gov-east-1" - US_GOV_WEST_1 = "us-gov-west-1" - - @classmethod - def get_all_regions(cls): - return [ - value - for key, value in vars(cls).items() - if not key.startswith("__") and isinstance(value, str) - ] diff --git a/datashuttle/datashuttle_class.py b/datashuttle/datashuttle_class.py index 0a41aebc..cb849f0b 100644 --- a/datashuttle/datashuttle_class.py +++ b/datashuttle/datashuttle_class.py @@ -933,12 +933,14 @@ def setup_google_drive_connection(self) -> None: @requires_aws_configs @check_configs_set - def setup_aws_s3_connection(self, aws_secret_access_key: str) -> None: + def setup_aws_s3_connection(self) -> None: self._start_log( "setup-aws-s3-connection-to-central-server", local_vars=locals(), ) + aws_secret_access_key = aws.get_aws_secret_access_key() + self._setup_rclone_aws_config(aws_secret_access_key, log=True) aws.check_successful_connection_and_raise_error_on_fail(self.cfg) diff --git a/datashuttle/tui/interface.py b/datashuttle/tui/interface.py index 19589b28..917fd5be 100644 --- a/datashuttle/tui/interface.py +++ b/datashuttle/tui/interface.py @@ -498,7 +498,8 @@ def setup_key_pair_and_rclone_config( # ---------------------------------------------------------------------------------- def setup_google_drive_connection( - self, config_token: Optional[str] = None + self, + config_token: Optional[str] = None, ) -> InterfaceOutput: try: self.project._setup_rclone_gdrive_config(config_token, log=False) diff --git a/datashuttle/tui/shared/configs_content.py b/datashuttle/tui/shared/configs_content.py index b7cbcce5..3aaae7a2 100644 --- a/datashuttle/tui/shared/configs_content.py +++ b/datashuttle/tui/shared/configs_content.py @@ -25,7 +25,7 @@ Static, ) -from datashuttle import AWS_REGION +from datashuttle.configs.aws_regions import get_aws_regions_list from datashuttle.tui.custom_widgets import ClickableInput from datashuttle.tui.interface import Interface from datashuttle.tui.screens import ( @@ -122,7 +122,7 @@ def compose(self) -> ComposeResult: ), Label("AWS S3 Region", id="configs_aws_s3_region_label"), Select( - ((region, region) for region in AWS_REGION.get_all_regions()), + ((region, region) for region in get_aws_regions_list()), id="configs_aws_s3_region_select", ), ] diff --git a/datashuttle/utils/aws.py b/datashuttle/utils/aws.py index 2e6651ad..c43abbac 100644 --- a/datashuttle/utils/aws.py +++ b/datashuttle/utils/aws.py @@ -1,4 +1,6 @@ +import getpass import json +import sys from datashuttle.configs.config_class import Configs from datashuttle.utils import rclone, utils @@ -52,3 +54,37 @@ def warn_if_bucket_absent(cfg: Configs) -> None: f"For data transfer to happen, the bucket must exist.\n" f"Please change the bucket name in the `central_path`. " ) + + +def get_aws_secret_access_key(log: bool = True) -> str: + if not sys.stdin.isatty(): + proceed = input( + "\nWARNING!\nThe next step is to enter a AWS secret access key, but it is not possible\n" + "to hide your secret access key while entering it in the current terminal.\n" + "This can occur if running the command in an IDE.\n\n" + "Press 'y' to proceed to secret key entry. " + "The characters will not be hidden!\n" + "Alternatively, run AWS S3 setup after starting Python in your " + "system terminal \nrather than through an IDE: " + ) + if proceed != "y": + utils.print_message_to_user( + "Quitting AWS S3 setup as 'y' not pressed." + ) + utils.log_and_raise_error( + "AWS S3 setup aborted by user.", ConnectionAbortedError + ) + + aws_secret_access_key = input( + "Please enter your AWS secret access key. Characters will not be hidden: " + ) + + else: + aws_secret_access_key = getpass.getpass( + "Please enter your AWS secret access key: " + ) + + if log: + utils.log("AWS secret access key entered by user.") + + return aws_secret_access_key.strip() diff --git a/datashuttle/utils/gdrive.py b/datashuttle/utils/gdrive.py index 56602897..ed9fa2bc 100644 --- a/datashuttle/utils/gdrive.py +++ b/datashuttle/utils/gdrive.py @@ -99,7 +99,7 @@ def get_client_secret(log: bool = True) -> str: "This can occur if running the command in an IDE.\n\n" "Press 'y' to proceed to client secret entry. " "The characters will not be hidden!\n" - "Alternatively, run ssh setup after starting Python in your " + "Alternatively, run google drive setup after starting Python in your " "system terminal \nrather than through an IDE: " ) if proceed != "y": From 0733f518d542d81d6aab04f162d6ce1738c0349a Mon Sep 17 00:00:00 2001 From: shrey Date: Sun, 1 Jun 2025 00:14:23 +0530 Subject: [PATCH 17/39] add: docstrings to gdrive.py --- datashuttle/utils/gdrive.py | 25 +++++++++++++++++++++++-- 1 file changed, 23 insertions(+), 2 deletions(-) diff --git a/datashuttle/utils/gdrive.py b/datashuttle/utils/gdrive.py index ed9fa2bc..979bd9b0 100644 --- a/datashuttle/utils/gdrive.py +++ b/datashuttle/utils/gdrive.py @@ -18,7 +18,24 @@ def preliminary_for_setup_without_browser( rclone_config_name: str, log: bool = True, ) -> str: - # TODO: Add docstrings + """ + This function prepares the rclone configuration for Google Drive without using a browser. + + The `config_is_local=false` flag tells rclone that the configuration process is being run + on a headless machine which does not have access to a browser. + + The `--non-interactive` flag is used to control rclone's behaviour while running it through + external applications. An `rclone config create` command would assume default values for config + variables in an interactive mode. If the `--non-interactive` flag is provided and rclone needs + the user to input some detail, a JSON blob will be returned with the question in it. For this + particular setup, rclone outputs a command for user to run on a machine with a browser. + + This function runs `rclone config create` with the user credentials and returns the rclone's output info. + This output info is presented to the user while asking for a `config_token`. + + Next, the user will run rclone's given command, authenticate with google drive and input the + config token given by rclone for datashuttle to proceed with the setup. + """ client_id_key_value = ( f"client_id {cfg['gdrive_client_id']} " if cfg["gdrive_client_id"] @@ -82,7 +99,11 @@ def prompt_and_get_config_token( rclone_config_name: str, log: bool = True, ) -> str: - # TODO: Add docstrings + """ + This function presents the rclone's output/message to ask the user to run a command, authenticate + with google drive and input the `config_token` generated by rclone. The `config_token` is + then used to complete rclone's config setup for google drive. + """ message = preliminary_for_setup_without_browser( cfg, gdrive_client_secret, rclone_config_name, log=log ) From 023ada350de77514108dde2e71b18afe4f9b14df Mon Sep 17 00:00:00 2001 From: shrey Date: Sun, 1 Jun 2025 15:55:36 +0530 Subject: [PATCH 18/39] add: root_folder_id config to google drive; some refactor --- datashuttle/configs/canonical_configs.py | 20 ++++++++++------ datashuttle/datashuttle_class.py | 7 +++++- datashuttle/utils/aws.py | 17 -------------- datashuttle/utils/gdrive.py | 2 ++ datashuttle/utils/rclone.py | 30 ++++++++++++++++++++++-- 5 files changed, 49 insertions(+), 27 deletions(-) diff --git a/datashuttle/configs/canonical_configs.py b/datashuttle/configs/canonical_configs.py index 17bf12a0..29ab03fb 100644 --- a/datashuttle/configs/canonical_configs.py +++ b/datashuttle/configs/canonical_configs.py @@ -46,7 +46,7 @@ def get_canonical_configs() -> dict: "central_host_id": Optional[str], "central_host_username": Optional[str], "gdrive_client_id": Optional[str], - # "gdrive_client_secret": Optional[str], + "gdrive_root_folder_id": Optional[str], "aws_access_key_id": Optional[str], "aws_s3_region": Optional[Literal[*get_aws_regions_list()]], # "aws_s3_endpoint_url": Optional[str], @@ -137,12 +137,18 @@ def check_dict_values_raise_on_fail(config_dict: Configs) -> None: ) # Check gdrive settings - elif config_dict["connection_method"] == "gdrive" and ( - not config_dict["gdrive_client_id"] - ): - utils.log_and_message( - "`gdrive_client_id` not found in config. default rlcone client will be used (slower)." - ) + elif config_dict["connection_method"] == "gdrive": + if not config_dict["gdrive_root_folder_id"]: + utils.log_and_raise_error( + "'gdrive_root_folder_id' is required if 'connection_method' " + "is 'gdrive'.", + ConfigError, + ) + + if not config_dict["gdrive_client_id"]: + utils.log_and_message( + "`gdrive_client_id` not found in config. default rlcone client will be used (slower)." + ) # Check AWS settings elif config_dict["connection_method"] == "aws_s3" and ( diff --git a/datashuttle/datashuttle_class.py b/datashuttle/datashuttle_class.py index cb849f0b..397af6fc 100644 --- a/datashuttle/datashuttle_class.py +++ b/datashuttle/datashuttle_class.py @@ -925,6 +925,9 @@ def setup_google_drive_connection(self) -> None: self._setup_rclone_gdrive_config( gdrive_client_secret, config_token, log=True ) + + rclone.check_successful_connection_and_raise_error_on_fail(self.cfg) + ds_logger.close_log_filehandler() # ------------------------------------------------------------------------- @@ -943,7 +946,7 @@ def setup_aws_s3_connection(self) -> None: self._setup_rclone_aws_config(aws_secret_access_key, log=True) - aws.check_successful_connection_and_raise_error_on_fail(self.cfg) + rclone.check_successful_connection_and_raise_error_on_fail(self.cfg) utils.log_and_message("AWS Connection Successful.") aws.warn_if_bucket_absent(self.cfg) @@ -962,6 +965,7 @@ def make_config_file( central_host_id: Optional[str] = None, central_host_username: Optional[str] = None, gdrive_client_id: Optional[str] = None, + gdrive_root_folder_id: Optional[str] = None, aws_access_key_id: Optional[str] = None, aws_s3_region: Optional[str] = None, ) -> None: @@ -1029,6 +1033,7 @@ def make_config_file( "central_host_id": central_host_id, "central_host_username": central_host_username, "gdrive_client_id": gdrive_client_id, + "gdrive_root_folder_id": gdrive_root_folder_id, "aws_access_key_id": aws_access_key_id, "aws_s3_region": aws_s3_region, }, diff --git a/datashuttle/utils/aws.py b/datashuttle/utils/aws.py index c43abbac..d9e8d55e 100644 --- a/datashuttle/utils/aws.py +++ b/datashuttle/utils/aws.py @@ -6,23 +6,6 @@ from datashuttle.utils import rclone, utils -def check_successful_connection_and_raise_error_on_fail(cfg: Configs) -> None: - """ - Check for a successful connection by executing an `ls` command. It pings the - the central host to list files and folders in the root directory. - If the command fails, it raises a ConnectionError with the error message. - """ - - output = rclone.call_rclone( - f"ls {cfg.get_rclone_config_name()}:", pipe_std=True - ) - - if output.returncode != 0: - utils.log_and_raise_error( - output.stderr.decode("utf-8"), ConnectionError - ) - - def check_if_aws_bucket_exists(cfg: Configs) -> bool: output = rclone.call_rclone( f"lsjson {cfg.get_rclone_config_name()}:", pipe_std=True diff --git a/datashuttle/utils/gdrive.py b/datashuttle/utils/gdrive.py index 979bd9b0..b80f2615 100644 --- a/datashuttle/utils/gdrive.py +++ b/datashuttle/utils/gdrive.py @@ -53,12 +53,14 @@ def preliminary_for_setup_without_browser( f"{client_id_key_value}" f"{client_secret_key_value}" f"scope drive " + f"root_folder_id {cfg['gdrive_root_folder_id']} " f"config_is_local=false " f"--non-interactive", pipe_std=True, ) # TODO: make this more robust + # extracting rclone's message from the json output_json = json.loads(output.stdout) message = output_json["Option"]["Help"] diff --git a/datashuttle/utils/rclone.py b/datashuttle/utils/rclone.py index 2db48e9b..42518009 100644 --- a/datashuttle/utils/rclone.py +++ b/datashuttle/utils/rclone.py @@ -173,17 +173,23 @@ def setup_rclone_config_for_gdrive( if config_token else "" ) - call_rclone( + output = call_rclone( f"config create " f"{rclone_config_name} " f"drive " f"{client_id_key_value}" f"{client_secret_key_value}" f"scope drive " + f"root_folder_id {cfg['gdrive_root_folder_id']} " f"{extra_args}", pipe_std=True, ) + if output.returncode != 0: + utils.log_and_raise_error( + output.stderr.decode("utf-8"), ConnectionError + ) + if log: log_rclone_config_output() @@ -194,7 +200,7 @@ def setup_rclone_config_for_aws_s3( rclone_config_name: str, log: bool = True, ): - call_rclone( + output = call_rclone( "config create " f"{rclone_config_name} " "s3 provider AWS " @@ -205,10 +211,30 @@ def setup_rclone_config_for_aws_s3( pipe_std=True, ) + if output.returncode != 0: + utils.log_and_raise_error( + output.stderr.decode("utf-8"), ConnectionError + ) + if log: log_rclone_config_output() +def check_successful_connection_and_raise_error_on_fail(cfg: Configs) -> None: + """ + Check for a successful connection by executing an `ls` command. It pings the + the central host to list files and folders in the root directory. + If the command fails, it raises a ConnectionError with the error message. + """ + + output = call_rclone(f"ls {cfg.get_rclone_config_name()}:", pipe_std=True) + + if output.returncode != 0: + utils.log_and_raise_error( + output.stderr.decode("utf-8"), ConnectionError + ) + + def log_rclone_config_output(): output = call_rclone("config file", pipe_std=True) utils.log( From e869986df5357879775ff30645949a178cf1d5cc Mon Sep 17 00:00:00 2001 From: shrey Date: Mon, 2 Jun 2025 03:10:55 +0530 Subject: [PATCH 19/39] refactor: radiobuttons switch in configs.py --- datashuttle/tui/shared/configs_content.py | 157 ++++++++++++---------- 1 file changed, 88 insertions(+), 69 deletions(-) diff --git a/datashuttle/tui/shared/configs_content.py b/datashuttle/tui/shared/configs_content.py index 3aaae7a2..a08b64ff 100644 --- a/datashuttle/tui/shared/configs_content.py +++ b/datashuttle/tui/shared/configs_content.py @@ -17,7 +17,6 @@ from textual.message import Message from textual.widgets import ( Button, - Input, Label, RadioButton, RadioSet, @@ -105,11 +104,11 @@ def compose(self) -> ComposeResult: placeholder="Google Drive Client ID (leave blank to use rclone's default client (slower))", id="configs_gdrive_client_id_input", ), - Label("Client Secret", id="configs_gdrive_client_secret_label"), - Input( - placeholder="Google Drive Client Secret (leave blank to use rclone's default client (slower))", - password=True, - id="configs_gdrive_client_secret_input", + Label("Root Folder ID", id="configs_gdrive_root_folder_id_label"), + ClickableInput( + self.parent_class.mainwindow, + placeholder="Google Drive Root Folder ID (leave blank to use the topmost folder)", + id="configs_gdrive_root_folder_id", ), ] @@ -142,11 +141,21 @@ def compose(self) -> ComposeResult: RadioSet( RadioButton( "Local Filesystem", - id="configs_local_filesystem_radiobutton", + id=self.radiobutton_id_from_connection_method( + "local_filesystem" + ), + ), + RadioButton( + "SSH", id=self.radiobutton_id_from_connection_method("ssh") + ), + RadioButton( + "Google Drive", + id=self.radiobutton_id_from_connection_method("gdrive"), + ), + RadioButton( + "AWS S3", + id=self.radiobutton_id_from_connection_method("aws_s3"), ), - RadioButton("SSH", id="configs_ssh_radiobutton"), - RadioButton("Google Drive", id="configs_gdrive_radiobutton"), - RadioButton("AWS S3", id="configs_aws_s3_radiobutton"), RadioButton( "No connection (local only)", id="configs_local_only_radiobutton", @@ -245,7 +254,6 @@ def on_mount(self) -> None: self.query_one("#configs_local_filesystem_radiobutton").value = ( True ) - # self.switch_ssh_widgets_display(display_ssh=False) self.setup_widgets_to_display(connection_method="local_filesystem") self.query_one("#configs_setup_ssh_connection_button").visible = ( False @@ -291,6 +299,12 @@ def on_radio_set_changed(self, event: RadioSet.Changed) -> None: disabled. """ label = str(event.pressed.label) + radiobutton_id = event.pressed.id + + connection_method = self.connection_method_from_radiobutton_id( + radiobutton_id + ) + assert label in [ "SSH", "Local Filesystem", @@ -299,32 +313,46 @@ def on_radio_set_changed(self, event: RadioSet.Changed) -> None: "AWS S3", ], "Unexpected label." - connection_method = None - if label == "No connection (local only)": - self.query_one("#configs_central_path_input").value = "" - self.query_one("#configs_central_path_input").disabled = True - self.query_one("#configs_central_path_select_button").disabled = ( - True - ) - display_ssh = False - else: - self.query_one("#configs_central_path_input").disabled = False - self.query_one("#configs_central_path_select_button").disabled = ( - False - ) - display_ssh = True if label == "SSH" else False - - if label == "SSH": - connection_method = "ssh" - elif label == "Google Drive": - connection_method = "gdrive" - elif label == "AWS S3": - connection_method = "aws_s3" + connection_method = self.connection_method_from_radiobutton_id( + radiobutton_id + ) + display_ssh = ( + True if connection_method == "ssh" else False + ) # temporarily, for tooltips self.setup_widgets_to_display(connection_method) self.set_central_path_input_tooltip(display_ssh) + def radiobutton_id_from_connection_method( + self, connection_method: str + ) -> str: + return f"configs_{connection_method}_radiobutton" + + def connection_method_from_radiobutton_id( + self, radiobutton_id: str + ) -> str | None: + """ + Get the connection method from the radiobutton id. + """ + assert radiobutton_id.startswith("configs_") + assert radiobutton_id.endswith("_radiobutton") + + connection_string = radiobutton_id[ + len("configs_") : -len("_radiobutton") + ] + return ( + connection_string + if connection_string + in [ + "ssh", + "gdrive", + "aws_s3", + "local_filesystem", + ] + else None + ) + def set_central_path_input_tooltip(self, display_ssh: bool) -> None: """ Use a different tooltip depending on whether connection method @@ -374,10 +402,6 @@ def switch_ssh_widgets_display(self, display_ssh: bool) -> None: for widget in self.config_ssh_widgets: widget.display = display_ssh - self.query_one("#configs_central_path_select_button").display = ( - not display_ssh - ) - if ( self.interface is None or self.interface.get_configs()["connection_method"] != "ssh" @@ -403,10 +427,6 @@ def switch_gdrive_widgets_display(self, display_gdrive: bool) -> None: for widget in self.config_gdrive_widgets: widget.display = display_gdrive - self.query_one("#configs_central_path_select_button").display = ( - not display_gdrive - ) - if ( self.interface is None or self.interface.get_configs()["connection_method"] != "gdrive" @@ -424,10 +444,6 @@ def switch_aws_widgets_display(self, display_aws: bool) -> None: for widget in self.config_aws_s3_widgets: widget.display = display_aws - self.query_one("#configs_central_path_select_button").display = ( - not display_aws - ) - if ( self.interface is None or self.interface.get_configs()["connection_method"] != "aws_s3" @@ -773,12 +789,12 @@ def fill_widgets_with_project_configs(self) -> None: ) input.value = value - # Google Drive Client Secret - input = self.query_one("#configs_gdrive_client_secret_input") + # Google Drive Root Folder ID + input = self.query_one("#configs_gdrive_root_folder_id") value = ( "" - if cfg_to_load.get("gdrive_client_secret", None) is None - else cfg_to_load["gdrive_client_secret"] + if cfg_to_load.get("gdrive_root_folder_id", None) is None + else cfg_to_load["gdrive_root_folder_id"] ) input.value = value @@ -810,26 +826,29 @@ def setup_widgets_to_display(self, connection_method: str | None) -> None: "aws_s3", ], "Unexpected Connection Method" - if connection_method == "ssh": - # order matters -> fix this - self.switch_gdrive_widgets_display(False) - self.switch_aws_widgets_display(False) - self.switch_ssh_widgets_display(True) - - elif connection_method == "gdrive": - self.switch_ssh_widgets_display(False) - self.switch_aws_widgets_display(False) - self.switch_gdrive_widgets_display(True) + connection_widget_display_functions = { + "ssh": self.switch_ssh_widgets_display, + "gdrive": self.switch_gdrive_widgets_display, + "aws_s3": self.switch_aws_widgets_display, + } - elif connection_method == "aws_s3": - self.switch_ssh_widgets_display(False) - self.switch_gdrive_widgets_display(False) - self.switch_aws_widgets_display(True) + for name, widget_func in connection_widget_display_functions.items(): + if connection_method == name: + widget_func(True) + else: + widget_func(False) + if not connection_method: + self.query_one("#configs_central_path_input").value = "" + self.query_one("#configs_central_path_input").disabled = True + self.query_one("#configs_central_path_select_button").disabled = ( + True + ) else: - self.switch_ssh_widgets_display(False) - self.switch_gdrive_widgets_display(False) - self.switch_aws_widgets_display(False) + self.query_one("#configs_central_path_select_button").disabled = ( + False + ) + self.query_one("#configs_central_path_input").disabled = False def get_datashuttle_inputs_from_widgets(self) -> Dict: """ @@ -891,11 +910,11 @@ def get_datashuttle_inputs_from_widgets(self) -> Dict: None if gdrive_client_id == "" else gdrive_client_id ) - gdrive_client_secret = self.query_one( - "#configs_gdrive_client_secret_input" + gdrive_root_folder_id = self.query_one( + "#configs_gdrive_root_folder_id" ).value - cfg_kwargs["gdrive_client_secret"] = ( - None if gdrive_client_secret == "" else gdrive_client_secret + cfg_kwargs["gdrive_root_folder_id"] = ( + None if gdrive_root_folder_id == "" else gdrive_root_folder_id ) # AWS specific From 9edbc8f9194cfe6c4933a0be38ec4dca77099bf2 Mon Sep 17 00:00:00 2001 From: shrey Date: Mon, 2 Jun 2025 05:42:41 +0530 Subject: [PATCH 20/39] edit: minor changes to SetupAwsScreen for setting up aws connection --- datashuttle/tui/interface.py | 9 +++++++-- datashuttle/tui/screens/setup_aws.py | 5 ++++- datashuttle/tui/shared/configs_content.py | 1 + 3 files changed, 12 insertions(+), 3 deletions(-) diff --git a/datashuttle/tui/interface.py b/datashuttle/tui/interface.py index 917fd5be..1664e508 100644 --- a/datashuttle/tui/interface.py +++ b/datashuttle/tui/interface.py @@ -11,7 +11,7 @@ from datashuttle import DataShuttle from datashuttle.configs import load_configs -from datashuttle.utils import gdrive, ssh +from datashuttle.utils import gdrive, rclone, ssh class Interface: @@ -525,7 +525,12 @@ def setup_aws_connection( self, aws_secret_access_key: str ) -> InterfaceOutput: try: - self.project.setup_aws_s3_connection(aws_secret_access_key) + self.project._setup_rclone_aws_config( + aws_secret_access_key, log=False + ) + rclone.check_successful_connection_and_raise_error_on_fail( + self.project.cfg + ) return True, None except BaseException as e: return False, str(e) diff --git a/datashuttle/tui/screens/setup_aws.py b/datashuttle/tui/screens/setup_aws.py index d216d751..84e65137 100644 --- a/datashuttle/tui/screens/setup_aws.py +++ b/datashuttle/tui/screens/setup_aws.py @@ -15,7 +15,10 @@ class SetupAwsScreen(ModalScreen): """ This dialog window handles the TUI equivalent of API's - `setup_aws_s3_connection()`. + `setup_aws_s3_connection()`. This asks the user for confirmation + to proceed with the setup, and then prompts the user for the AWS Secret Access Key. + The secret access key is then used to setup rclone config for AWS S3. + Works similar to `SetupSshScreen`. """ def __init__(self, interface: Interface) -> None: diff --git a/datashuttle/tui/shared/configs_content.py b/datashuttle/tui/shared/configs_content.py index a08b64ff..55ad6635 100644 --- a/datashuttle/tui/shared/configs_content.py +++ b/datashuttle/tui/shared/configs_content.py @@ -839,6 +839,7 @@ def setup_widgets_to_display(self, connection_method: str | None) -> None: widget_func(False) if not connection_method: + # local only project self.query_one("#configs_central_path_input").value = "" self.query_one("#configs_central_path_input").disabled = True self.query_one("#configs_central_path_select_button").disabled = ( From 150f2ea3572755e2ea1cd94e75b32bd8dc935059 Mon Sep 17 00:00:00 2001 From: shrey Date: Tue, 3 Jun 2025 04:14:33 +0530 Subject: [PATCH 21/39] refactor: SetupGdriveScreen and handle errors --- datashuttle/tui/interface.py | 13 ++- datashuttle/tui/screens/setup_gdrive.py | 143 +++++++++++++++++------- 2 files changed, 111 insertions(+), 45 deletions(-) diff --git a/datashuttle/tui/interface.py b/datashuttle/tui/interface.py index 1664e508..cf1f0494 100644 --- a/datashuttle/tui/interface.py +++ b/datashuttle/tui/interface.py @@ -499,18 +499,27 @@ def setup_key_pair_and_rclone_config( def setup_google_drive_connection( self, + gdrive_client_secret: Optional[str] = None, config_token: Optional[str] = None, ) -> InterfaceOutput: try: - self.project._setup_rclone_gdrive_config(config_token, log=False) + self.project._setup_rclone_gdrive_config( + gdrive_client_secret, config_token, log=False + ) + rclone.check_successful_connection_and_raise_error_on_fail( + self.project.cfg + ) return True, None except BaseException as e: return False, str(e) - def get_rclone_message_for_gdrive_without_browser(self): + def get_rclone_message_for_gdrive_without_browser( + self, gdrive_client_secret: Optional[str] = None + ) -> InterfaceOutput: try: output = gdrive.preliminary_for_setup_without_browser( self.project.cfg, + gdrive_client_secret, self.project.cfg.get_rclone_config_name("gdrive"), log=False, ) diff --git a/datashuttle/tui/screens/setup_gdrive.py b/datashuttle/tui/screens/setup_gdrive.py index 1e9020ac..e4798673 100644 --- a/datashuttle/tui/screens/setup_gdrive.py +++ b/datashuttle/tui/screens/setup_gdrive.py @@ -8,6 +8,7 @@ from textual.worker import Worker from datashuttle.tui.interface import Interface + from datashuttle.utils.custom_types import InterfaceOutput from textual import work from textual.containers import Container, Horizontal @@ -26,8 +27,14 @@ def __init__(self, interface: Interface) -> None: super(SetupGdriveScreen, self).__init__() self.interface = interface - self.stage = 0 + self.stage: float = 0 self.setup_worker: Worker | None = None + self.gdrive_client_secret: Optional[str] = None + + self.input_box: Input = Input( + id="setup_gdrive_generic_input_box", + placeholder="Enter value here", + ) def compose(self) -> ComposeResult: yield Container( @@ -38,7 +45,6 @@ def compose(self) -> ComposeResult: ), id="gdrive_setup_messagebox_message_container", ), - # Input(), Horizontal( Button("OK", id="setup_gdrive_ok_button"), Button("Cancel", id="setup_gdrive_cancel_button"), @@ -58,7 +64,19 @@ def on_button_pressed(self, event: Button.Pressed) -> None: self.dismiss() elif event.button.id == "setup_gdrive_ok_button": - self.ask_user_for_browser() + if self.stage == 0: + if self.interface.project.cfg["gdrive_client_id"]: + self.ask_user_for_gdrive_client_secret() + else: + self.ask_user_for_browser() + + elif self.stage == 0.5: + self.gdrive_client_secret = ( + self.input_box.value.strip() + if self.input_box.value.strip() + else None + ) + self.ask_user_for_browser() elif event.button.id == "setup_gdrive_yes_button": self.open_browser_and_setup_gdrive_connection() @@ -69,41 +87,53 @@ def on_button_pressed(self, event: Button.Pressed) -> None: elif event.button.id == "setup_gdrive_enter_button": self.setup_gdrive_connection_using_config_token() + def ask_user_for_gdrive_client_secret(self) -> None: + message = ( + "Please provide the client secret for Google Drive. " + "You can find it in your Google Cloud Console." + ) + self.update_message_box_message(message) + + ok_button = self.query_one("#setup_gdrive_ok_button") + ok_button.label = "Enter" + + self.mount_input_box_before_buttons() + + self.stage += 0.5 + def ask_user_for_browser(self) -> None: message = ( "Are you running Datashuttle on a machine " "that can open a web browser?" ) - self.query_one("#gdrive_setup_messagebox_message").update(message) + self.update_message_box_message(message) + self.query_one("#setup_gdrive_ok_button").remove() + + # Remove the input box if it was mounted previously + if self.input_box.is_mounted: + self.input_box.remove() + + # Mount the Yes and No buttons yes_button = Button("Yes", id="setup_gdrive_yes_button") no_button = Button("No", id="setup_gdrive_no_button") - self.query_one("#setup_gdrive_ok_button").remove() + # Mount a cancel button self.query_one("#setup_gdrive_buttons_horizontal").mount( yes_button, no_button, before="#setup_gdrive_cancel_button" ) - self.stage += 1 + self.stage += 0.5 if self.stage == 0.5 else 1 def open_browser_and_setup_gdrive_connection(self) -> None: - # TODO: ADD SOME SUCCESS, OUTPUT message = "Please authenticate through browser." - self.query_one("#gdrive_setup_messagebox_message").update(message) + self.update_message_box_message(message) + # Remove the Yes and No buttons self.query_one("#setup_gdrive_yes_button").remove() self.query_one("#setup_gdrive_no_button").remove() - async def _setup_gdrive_and_update_ui(): - worker = self.setup_gdrive_connection() - self.setup_worker = worker - if worker.is_running: - await worker.wait() - - # TODO : check if successful - self.show_finish_screen() - - asyncio.create_task(_setup_gdrive_and_update_ui()) + asyncio.create_task(self.setup_gdrive_connection_and_update_ui()) def prompt_user_for_config_token(self) -> None: @@ -111,64 +141,91 @@ def prompt_user_for_config_token(self) -> None: self.query_one("#setup_gdrive_no_button").remove() success, message = ( - self.interface.get_rclone_message_for_gdrive_without_browser() + self.interface.get_rclone_message_for_gdrive_without_browser( + self.gdrive_client_secret + ) ) if not success: - self.display_failed() + self.display_failed(message) return - self.query_one("#gdrive_setup_messagebox_message").update( + self.update_message_box_message( message + "\nPress shift+click to copy." ) enter_button = Button("Enter", id="setup_gdrive_enter_button") - input_box = Input(id="setup_gdrive_config_token_input") + self.input_box.value = "" self.query_one("#setup_gdrive_buttons_horizontal").mount( enter_button, before="#setup_gdrive_cancel_button" ) - self.query_one("#setup_gdrive_screen_container").mount( - input_box, before="#setup_gdrive_buttons_horizontal" - ) + self.mount_input_box_before_buttons() def setup_gdrive_connection_using_config_token(self) -> None: - self.query_one("#setup_gdrive_config_token_input").disabled = True + self.input_box.disabled = True enter_button = self.query_one("#setup_gdrive_enter_button") enter_button.disabled = True - config_token = self.query_one("#setup_gdrive_config_token_input").value + config_token = self.input_box.value.strip() - async def _setup_gdrive_and_update_ui(): - worker = self.setup_gdrive_connection(config_token) - self.setup_worker = worker - if worker.is_running: - await worker.wait() - - enter_button.remove() + asyncio.create_task( + self.setup_gdrive_connection_and_update_ui(config_token) + ) - # TODO : check if successful + async def setup_gdrive_connection_and_update_ui( + self, config_token: Optional[str] = None + ) -> None: + worker = self.setup_gdrive_connection(config_token) + self.setup_worker = worker + if worker.is_running: + await worker.wait() + + if config_token: + enter_button = self.query_one("#setup_gdrive_enter_button") + enter_button.disabled = True + + success, output = worker.result + if success: self.show_finish_screen() - - asyncio.create_task(_setup_gdrive_and_update_ui()) + else: + self.display_failed(output) @work(exclusive=True, thread=True) def setup_gdrive_connection( self, config_token: Optional[str] = None - ) -> Worker: - self.interface.setup_google_drive_connection(config_token) - self.stage += 1 + ) -> Worker[InterfaceOutput]: + success, output = self.interface.setup_google_drive_connection( + self.gdrive_client_secret, config_token + ) + return success, output + + # ---------------------------------------------------------------------------------- + # UI Update Methods + # ---------------------------------------------------------------------------------- def show_finish_screen(self) -> None: message = "Setup Complete!" self.query_one("#setup_gdrive_cancel_button").remove() - self.query_one("#gdrive_setup_messagebox_message").update(message) + self.update_message_box_message(message) self.query_one("#setup_gdrive_buttons_horizontal").mount( Button("Finish", id="setup_gdrive_finish_button") ) - def display_failed(self) -> None: - pass + def display_failed(self, output) -> None: + message = ( + f"Google Drive setup failed. Please check your configs and client secret" + f"\n\n Traceback: {output}" + ) + self.update_message_box_message(message) + + def update_message_box_message(self, message: str) -> None: + self.query_one("#gdrive_setup_messagebox_message").update(message) + + def mount_input_box_before_buttons(self) -> None: + self.query_one("#setup_gdrive_screen_container").mount( + self.input_box, before="#setup_gdrive_buttons_horizontal" + ) From 70659cde2221c205ef999a655933cd1049e14c11 Mon Sep 17 00:00:00 2001 From: shrey Date: Tue, 3 Jun 2025 05:52:22 +0530 Subject: [PATCH 22/39] add: some tooltips for google drive configs --- datashuttle/tui/shared/configs_content.py | 9 +++++++-- datashuttle/tui/tooltips.py | 17 +++++++++++++++++ 2 files changed, 24 insertions(+), 2 deletions(-) diff --git a/datashuttle/tui/shared/configs_content.py b/datashuttle/tui/shared/configs_content.py index 55ad6635..ee97a4c7 100644 --- a/datashuttle/tui/shared/configs_content.py +++ b/datashuttle/tui/shared/configs_content.py @@ -101,13 +101,13 @@ def compose(self) -> ComposeResult: Label("Client ID", id="configs_gdrive_client_id_label"), ClickableInput( self.parent_class.mainwindow, - placeholder="Google Drive Client ID (leave blank to use rclone's default client (slower))", + placeholder="Google Drive Client ID", id="configs_gdrive_client_id_input", ), Label("Root Folder ID", id="configs_gdrive_root_folder_id_label"), ClickableInput( self.parent_class.mainwindow, - placeholder="Google Drive Root Folder ID (leave blank to use the topmost folder)", + placeholder="Google Drive Root Folder ID", id="configs_gdrive_root_folder_id", ), ] @@ -284,6 +284,8 @@ def on_mount(self) -> None: "#configs_local_only_radiobutton", "#configs_central_host_username_input", "#configs_central_host_id_input", + "#configs_gdrive_client_id_input", + "#configs_gdrive_root_folder_id", ]: self.query_one(id).tooltip = get_tooltip(id) @@ -625,12 +627,15 @@ def setup_configs_for_a_new_project(self) -> None: self.query_one("#configs_go_to_project_screen_button").visible = ( True ) + + # A message template to display custom message to user according to the chosen connection method message_template = ( "A datashuttle project has now been created.\n\n " "Next, setup the {method_name} connection. Once complete, navigate to the " "'Main Menu' and proceed to the project page, where you will be " "able to create and transfer project folders." ) + # Could not find a neater way to combine the push screen # while initiating the callback in one case but not the other. if cfg_kwargs["connection_method"] == "ssh": diff --git a/datashuttle/tui/tooltips.py b/datashuttle/tui/tooltips.py index e8d0c771..287b7226 100644 --- a/datashuttle/tui/tooltips.py +++ b/datashuttle/tui/tooltips.py @@ -65,6 +65,23 @@ def get_tooltip(id: str) -> str: "to a project folder, possibly on a mounted drive.\n\n" ) + # Google Drive configs + # ------------------------------------------------------------------------- + + # Google Drive Client ID + elif id == "#configs_gdrive_client_id_input": + tooltip = ( + "The Google Drive Client ID to use for authentication.\n\n" + "It can be obtained by creating an OAuth 2.0 client in the Google Cloud Console.\n\n" + "Can be left empty to use rclone's default client (slower)" + ) + + elif id == "#configs_gdrive_root_folder_id": + tooltip = ( + "The Google Drive root folder ID to use for transfer.\n\n" + "It can be obtained by navigating to the folder in Google Drive and copying the ID from the URL.\n\n" + ) + # Settings # ------------------------------------------------------------------------- From 985e9210fdbedac8ebef4206300b3343f14b4f6c Mon Sep 17 00:00:00 2001 From: shrey Date: Wed, 4 Jun 2025 20:56:09 +0530 Subject: [PATCH 23/39] fix: vanishing central path, radio button order, minor refactor --- datashuttle/tui/shared/configs_content.py | 149 ++++++++++++---------- 1 file changed, 81 insertions(+), 68 deletions(-) diff --git a/datashuttle/tui/shared/configs_content.py b/datashuttle/tui/shared/configs_content.py index ee97a4c7..994c1815 100644 --- a/datashuttle/tui/shared/configs_content.py +++ b/datashuttle/tui/shared/configs_content.py @@ -139,6 +139,10 @@ def compose(self) -> ComposeResult: ), Label("Connection Method", id="configs_connect_method_label"), RadioSet( + RadioButton( + "No connection (local only)", + id="configs_local_only_radiobutton", + ), RadioButton( "Local Filesystem", id=self.radiobutton_id_from_connection_method( @@ -156,10 +160,6 @@ def compose(self) -> ComposeResult: "AWS S3", id=self.radiobutton_id_from_connection_method("aws_s3"), ), - RadioButton( - "No connection (local only)", - id="configs_local_only_radiobutton", - ), id="configs_connect_method_radioset", ), *self.config_ssh_widgets, @@ -322,6 +322,7 @@ def on_radio_set_changed(self, event: RadioSet.Changed) -> None: True if connection_method == "ssh" else False ) # temporarily, for tooltips + self.fill_inputs_with_project_configs() self.setup_widgets_to_display(connection_method) self.set_central_path_input_tooltip(display_ssh) @@ -724,24 +725,15 @@ def fill_widgets_with_project_configs(self) -> None: widgets with the current project configs. This in some instances requires recasting to a new type of changing the value. - In the case of the `connection_method` widget, the associated - "ssh" widgets are hidden / displayed based on the current setting, - in `self.switch_ssh_widgets_display()`. + In the case of the `connection_method` widget, the associated connection + method radio button is hidden / displayed based on the current settings. + This change of radio button triggers `on_radio_set_changed` which displays + the appropriate connection method widgets. """ assert self.interface is not None, "type narrow flexible `interface`" cfg_to_load = self.interface.get_textual_compatible_project_configs() - # Local Path - input = self.query_one("#configs_local_path_input") - input.value = cfg_to_load["local_path"] - - # Central Path - input = self.query_one("#configs_central_path_input") - input.value = ( - cfg_to_load["central_path"] if cfg_to_load["central_path"] else "" - ) - # Connection Method # Make a dict of radiobutton: is on bool to easily find # how to set radiobuttons and associated configs @@ -763,9 +755,27 @@ def fill_widgets_with_project_configs(self) -> None: for id, value in what_radiobuton_is_on.items(): self.query_one(f"#{id}").value = value - # self.switch_ssh_widgets_display( - # display_ssh=what_radiobuton_is_on["configs_ssh_radiobutton"] - # ) + self.fill_inputs_with_project_configs() + + def fill_inputs_with_project_configs(self) -> None: + """ + This fills the input widgets with the current project configs. It is + used while setting up widgets for the project while mounting the current + tab and also to repopulate input widgets when the radio buttons change. + """ + assert self.interface is not None, "type narrow flexible `interface`" + + cfg_to_load = self.interface.get_textual_compatible_project_configs() + + # Local Path + input = self.query_one("#configs_local_path_input") + input.value = cfg_to_load["local_path"] + + # Central Path + input = self.query_one("#configs_central_path_input") + input.value = ( + cfg_to_load["central_path"] if cfg_to_load["central_path"] else "" + ) # Central Host ID input = self.query_one("#configs_central_host_id_input") @@ -875,65 +885,68 @@ def get_datashuttle_inputs_from_widgets(self) -> Dict: else: cfg_kwargs["central_path"] = Path(central_path_value) - if self.query_one("#configs_ssh_radiobutton").value: - connection_method = "ssh" - - elif self.query_one("#configs_gdrive_radiobutton").value: - connection_method = "gdrive" - - elif self.query_one("#configs_aws_s3_radiobutton").value: - connection_method = "aws_s3" - - elif self.query_one("#configs_local_filesystem_radiobutton").value: - connection_method = "local_filesystem" - - elif self.query_one("#configs_local_only_radiobutton").value: - connection_method = None + for id in [ + "configs_local_filesystem_radiobutton", + "configs_ssh_radiobutton", + "configs_gdrive_radiobutton", + "configs_aws_s3_radiobutton", + "configs_local_only_radiobutton", + ]: + if self.query_one("#" + id).value: + connection_method = self.connection_method_from_radiobutton_id( + id + ) + break cfg_kwargs["connection_method"] = connection_method # SSH specific - central_host_id = self.query_one( - "#configs_central_host_id_input" - ).value - cfg_kwargs["central_host_id"] = ( - None if central_host_id == "" else central_host_id - ) + if connection_method == "ssh": + central_host_id = self.query_one( + "#configs_central_host_id_input" + ).value + cfg_kwargs["central_host_id"] = ( + None if central_host_id == "" else central_host_id + ) - central_host_username = self.query_one( - "#configs_central_host_username_input" - ).value + central_host_username = self.query_one( + "#configs_central_host_username_input" + ).value - cfg_kwargs["central_host_username"] = ( - None if central_host_username == "" else central_host_username - ) + cfg_kwargs["central_host_username"] = ( + None if central_host_username == "" else central_host_username + ) # Google Drive specific - gdrive_client_id = self.query_one( - "#configs_gdrive_client_id_input" - ).value - cfg_kwargs["gdrive_client_id"] = ( - None if gdrive_client_id == "" else gdrive_client_id - ) + elif connection_method == "gdrive": + gdrive_client_id = self.query_one( + "#configs_gdrive_client_id_input" + ).value + cfg_kwargs["gdrive_client_id"] = ( + None if gdrive_client_id == "" else gdrive_client_id + ) - gdrive_root_folder_id = self.query_one( - "#configs_gdrive_root_folder_id" - ).value - cfg_kwargs["gdrive_root_folder_id"] = ( - None if gdrive_root_folder_id == "" else gdrive_root_folder_id - ) + gdrive_root_folder_id = self.query_one( + "#configs_gdrive_root_folder_id" + ).value + cfg_kwargs["gdrive_root_folder_id"] = ( + None if gdrive_root_folder_id == "" else gdrive_root_folder_id + ) # AWS specific - aws_access_key_id = self.query_one( - "#configs_aws_access_key_id_input" - ).value - cfg_kwargs["aws_access_key_id"] = ( - None if aws_access_key_id == "" else aws_access_key_id - ) + elif connection_method == "aws_s3": + aws_access_key_id = self.query_one( + "#configs_aws_access_key_id_input" + ).value + cfg_kwargs["aws_access_key_id"] = ( + None if aws_access_key_id == "" else aws_access_key_id + ) - aws_s3_region = self.query_one("#configs_aws_s3_region_select").value - cfg_kwargs["aws_s3_region"] = ( - None if aws_s3_region == Select.BLANK else aws_s3_region - ) + aws_s3_region = self.query_one( + "#configs_aws_s3_region_select" + ).value + cfg_kwargs["aws_s3_region"] = ( + None if aws_s3_region == Select.BLANK else aws_s3_region + ) return cfg_kwargs From d7f13d4d9c6c819ec72572012ad932bc203e8b31 Mon Sep 17 00:00:00 2001 From: shrey Date: Thu, 5 Jun 2025 00:23:22 +0530 Subject: [PATCH 24/39] fix: minor bug --- datashuttle/tui/shared/configs_content.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/datashuttle/tui/shared/configs_content.py b/datashuttle/tui/shared/configs_content.py index 994c1815..8af1296c 100644 --- a/datashuttle/tui/shared/configs_content.py +++ b/datashuttle/tui/shared/configs_content.py @@ -322,7 +322,9 @@ def on_radio_set_changed(self, event: RadioSet.Changed) -> None: True if connection_method == "ssh" else False ) # temporarily, for tooltips - self.fill_inputs_with_project_configs() + if self.interface: + self.fill_inputs_with_project_configs() + self.setup_widgets_to_display(connection_method) self.set_central_path_input_tooltip(display_ssh) From f7807d1aa68550b61d416e25fae1900f852672c5 Mon Sep 17 00:00:00 2001 From: shrey Date: Thu, 5 Jun 2025 00:48:51 +0530 Subject: [PATCH 25/39] refactor: single button for setup connection --- datashuttle/tui/shared/configs_content.py | 143 ++++++++-------------- 1 file changed, 52 insertions(+), 91 deletions(-) diff --git a/datashuttle/tui/shared/configs_content.py b/datashuttle/tui/shared/configs_content.py index 8af1296c..5f6f625d 100644 --- a/datashuttle/tui/shared/configs_content.py +++ b/datashuttle/tui/shared/configs_content.py @@ -179,16 +179,8 @@ def compose(self) -> ComposeResult: Button("Save", id="configs_save_configs_button"), Horizontal( Button( - "Setup SSH Connection", - id="configs_setup_ssh_connection_button", - ), - Button( - "Setup Google Drive Connection", - id="configs_setup_gdrive_connection_button", - ), - Button( - "Setup AWS Connection", - id="configs_setup_aws_connection_button", + "Setup Button", + id="configs_setup_connection_button", ), id="setup_buttons_container", ), @@ -255,9 +247,6 @@ def on_mount(self) -> None: True ) self.setup_widgets_to_display(connection_method="local_filesystem") - self.query_one("#configs_setup_ssh_connection_button").visible = ( - False - ) # Setup tooltips if not self.interface: @@ -407,18 +396,6 @@ def switch_ssh_widgets_display(self, display_ssh: bool) -> None: for widget in self.config_ssh_widgets: widget.display = display_ssh - if ( - self.interface is None - or self.interface.get_configs()["connection_method"] != "ssh" - ): - self.query_one("#configs_setup_ssh_connection_button").visible = ( - False - ) - else: - self.query_one("#configs_setup_ssh_connection_button").visible = ( - display_ssh - ) - if not self.query_one("#configs_central_path_input").value: if display_ssh: placeholder = f"e.g. {self.get_platform_dependent_example_paths('central', ssh=True)}" @@ -432,35 +409,10 @@ def switch_gdrive_widgets_display(self, display_gdrive: bool) -> None: for widget in self.config_gdrive_widgets: widget.display = display_gdrive - if ( - self.interface is None - or self.interface.get_configs()["connection_method"] != "gdrive" - ): - self.query_one( - "#configs_setup_gdrive_connection_button" - ).visible = False - else: - self.query_one( - "#configs_setup_gdrive_connection_button" - ).visible = display_gdrive - def switch_aws_widgets_display(self, display_aws: bool) -> None: - for widget in self.config_aws_s3_widgets: widget.display = display_aws - if ( - self.interface is None - or self.interface.get_configs()["connection_method"] != "aws_s3" - ): - self.query_one("#configs_setup_aws_connection_button").visible = ( - False - ) - else: - self.query_one("#configs_setup_aws_connection_button").visible = ( - display_aws - ) - def on_button_pressed(self, event: Button.Pressed) -> None: """ Enables the Create Folders button to read out current input values @@ -472,14 +424,21 @@ def on_button_pressed(self, event: Button.Pressed) -> None: else: self.setup_configs_for_an_existing_project() - elif event.button.id == "configs_setup_ssh_connection_button": - self.setup_ssh_connection() + elif event.button.id == "configs_setup_connection_button": + assert ( + self.interface is not None + ), "type narrow flexible `interface`" - elif event.button.id == "configs_setup_gdrive_connection_button": - self.setup_gdrive_connection() + connection_method = self.interface.get_configs()[ + "connection_method" + ] - elif event.button.id == "configs_setup_aws_connection_button": - self.setup_aws_connection() + if connection_method == "ssh": + self.setup_ssh_connection() + elif connection_method == "gdrive": + self.setup_gdrive_connection() + elif connection_method == "aws_s3": + self.setup_aws_connection() elif event.button.id == "configs_go_to_project_screen_button": self.parent_class.dismiss(self.interface) @@ -641,37 +600,18 @@ def setup_configs_for_a_new_project(self) -> None: # Could not find a neater way to combine the push screen # while initiating the callback in one case but not the other. - if cfg_kwargs["connection_method"] == "ssh": + connection_method = cfg_kwargs["connection_method"] - self.query_one( - "#configs_setup_ssh_connection_button" - ).visible = True - self.query_one( - "#configs_setup_ssh_connection_button" - ).disabled = False + # To trigger the appearance of "Setup connection" button + self.setup_widgets_to_display(connection_method) + if connection_method == "ssh": message = message_template.format(method_name="SSH") - elif cfg_kwargs["connection_method"] == "gdrive": - - self.query_one( - "#configs_setup_gdrive_connection_button" - ).visible = True - self.query_one( - "#configs_setup_gdrive_connection_button" - ).disabled = False - + elif connection_method == "gdrive": message = message_template.format(method_name="Google Drive") - elif cfg_kwargs["connection_method"] == "aws_s3": - - self.query_one( - "#configs_setup_aws_connection_button" - ).visible = True - self.query_one( - "#configs_setup_aws_connection_button" - ).disabled = False - + elif connection_method == "aws_s3": message = message_template.format(method_name="AWS") else: @@ -716,7 +656,7 @@ def setup_configs_for_an_existing_project(self) -> None: ), lambda unused: self.post_message(self.ConfigsSaved()), ) - # to trigger the appearance of buttons + # To trigger the appearance of "Setup connection" button self.setup_widgets_to_display(cfg_kwargs["connection_method"]) else: self.parent_class.mainwindow.show_modal_error_dialog(output) @@ -855,18 +795,39 @@ def setup_widgets_to_display(self, connection_method: str | None) -> None: else: widget_func(False) + self.query_one("#configs_central_path_input").disabled = ( + connection_method is None + ) + self.query_one("#configs_central_path_select_button").disabled = ( + connection_method is None + ) + + # Local only project if not connection_method: - # local only project self.query_one("#configs_central_path_input").value = "" - self.query_one("#configs_central_path_input").disabled = True - self.query_one("#configs_central_path_select_button").disabled = ( - True - ) + + setup_connection_button = self.query_one( + "#configs_setup_connection_button" + ) + + # fmt: off + if ( + not connection_method + or connection_method == "local_filesystem" + or not self.interface + or connection_method != self.interface.get_configs()["connection_method"] + ): + setup_connection_button.visible = False + # fmt: on else: - self.query_one("#configs_central_path_select_button").disabled = ( - False - ) - self.query_one("#configs_central_path_input").disabled = False + setup_connection_button.visible = True + + if connection_method == "ssh": + setup_connection_button.label = "Setup SSH Connection" + elif connection_method == "gdrive": + setup_connection_button.label = "Setup Google Drive Connection" + elif connection_method == "aws": + setup_connection_button.label = "Setup AWS Connection" def get_datashuttle_inputs_from_widgets(self) -> Dict: """ From be8f6b1fbc2b4bde08f2117c53b53da69897bb18 Mon Sep 17 00:00:00 2001 From: shrey Date: Thu, 5 Jun 2025 22:52:50 +0530 Subject: [PATCH 26/39] add: backwards compatibility to configs while load from config file --- datashuttle/configs/config_class.py | 29 ++++++++++++++++++++++++++++- 1 file changed, 28 insertions(+), 1 deletion(-) diff --git a/datashuttle/configs/config_class.py b/datashuttle/configs/config_class.py index 562c3310..c74dde82 100644 --- a/datashuttle/configs/config_class.py +++ b/datashuttle/configs/config_class.py @@ -116,7 +116,10 @@ def dump_to_file(self) -> None: def load_from_file(self) -> None: """ Load a config dict saved at .yaml file. Note this will - not automatically check the configs are valid, this + do a minimal backwards compatibility check and add config + keys to ensure backwards compatibility with new connection + methods added to Datashuttle. + But this will not automatically check the configs are valid, this requires calling self.check_dict_values_raise_on_fail() """ with open(self.file_path, "r") as config_file: @@ -124,8 +127,32 @@ def load_from_file(self) -> None: load_configs.convert_str_and_pathlib_paths(config_dict, "str_to_path") + config_dict = self.ensure_backwards_compatibilty_for_config( + config_dict + ) self.data = config_dict + def ensure_backwards_compatibilty_for_config( + self, config_dict: Dict + ) -> Dict: + canonical_config_keys_to_add = [ + "gdrive_client_id", + "gdrive_root_folder_id", + "aws_access_key_id", + "aws_s3_region", + ] + + # All keys shall be missing for a backwards compatibility update + needs_update = all( + key not in config_dict.keys() + for key in canonical_config_keys_to_add + ) + if needs_update: + for key in canonical_config_keys_to_add: + config_dict[key] = None + + return config_dict + # ------------------------------------------------------------------------- # Utils # ------------------------------------------------------------------------- From 0b7483be13e2789ac7fb46c0dcf96829e11851c9 Mon Sep 17 00:00:00 2001 From: shrey Date: Fri, 6 Jun 2025 00:54:13 +0530 Subject: [PATCH 27/39] edit: raise error on bucket not present --- datashuttle/datashuttle_class.py | 4 ++-- datashuttle/tui/interface.py | 3 ++- datashuttle/tui/shared/configs_content.py | 2 +- datashuttle/utils/aws.py | 21 +++++++++++---------- 4 files changed, 16 insertions(+), 14 deletions(-) diff --git a/datashuttle/datashuttle_class.py b/datashuttle/datashuttle_class.py index 397af6fc..2ee5141b 100644 --- a/datashuttle/datashuttle_class.py +++ b/datashuttle/datashuttle_class.py @@ -947,9 +947,9 @@ def setup_aws_s3_connection(self) -> None: self._setup_rclone_aws_config(aws_secret_access_key, log=True) rclone.check_successful_connection_and_raise_error_on_fail(self.cfg) - utils.log_and_message("AWS Connection Successful.") + aws.raise_if_bucket_absent(self.cfg) - aws.warn_if_bucket_absent(self.cfg) + utils.log_and_message("AWS Connection Successful.") ds_logger.close_log_filehandler() diff --git a/datashuttle/tui/interface.py b/datashuttle/tui/interface.py index cf1f0494..c62e1d8f 100644 --- a/datashuttle/tui/interface.py +++ b/datashuttle/tui/interface.py @@ -11,7 +11,7 @@ from datashuttle import DataShuttle from datashuttle.configs import load_configs -from datashuttle.utils import gdrive, rclone, ssh +from datashuttle.utils import aws, gdrive, rclone, ssh class Interface: @@ -540,6 +540,7 @@ def setup_aws_connection( rclone.check_successful_connection_and_raise_error_on_fail( self.project.cfg ) + aws.raise_if_bucket_absent(self.project.cfg) return True, None except BaseException as e: return False, str(e) diff --git a/datashuttle/tui/shared/configs_content.py b/datashuttle/tui/shared/configs_content.py index 5f6f625d..a868d679 100644 --- a/datashuttle/tui/shared/configs_content.py +++ b/datashuttle/tui/shared/configs_content.py @@ -826,7 +826,7 @@ def setup_widgets_to_display(self, connection_method: str | None) -> None: setup_connection_button.label = "Setup SSH Connection" elif connection_method == "gdrive": setup_connection_button.label = "Setup Google Drive Connection" - elif connection_method == "aws": + elif connection_method == "aws_s3": setup_connection_button.label = "Setup AWS Connection" def get_datashuttle_inputs_from_widgets(self) -> Dict: diff --git a/datashuttle/utils/aws.py b/datashuttle/utils/aws.py index d9e8d55e..c6e506a4 100644 --- a/datashuttle/utils/aws.py +++ b/datashuttle/utils/aws.py @@ -4,6 +4,7 @@ from datashuttle.configs.config_class import Configs from datashuttle.utils import rclone, utils +from datashuttle.utils.custom_exceptions import ConfigError def check_if_aws_bucket_exists(cfg: Configs) -> bool: @@ -23,22 +24,22 @@ def check_if_aws_bucket_exists(cfg: Configs) -> bool: return True -# ----------------------------------------------------------------------------- -# For Python API -# ----------------------------------------------------------------------------- - - -def warn_if_bucket_absent(cfg: Configs) -> None: - +def raise_if_bucket_absent(cfg: Configs) -> None: if not check_if_aws_bucket_exists(cfg): bucket_name = cfg["central_path"].as_posix().strip("/").split("/")[0] - utils.print_message_to_user( - f'WARNING: The bucket "{bucket_name}" does not exist.\n' + utils.log_and_raise_error( + f'The bucket "{bucket_name}" does not exist.\n' f"For data transfer to happen, the bucket must exist.\n" - f"Please change the bucket name in the `central_path`. " + f"Please change the bucket name in the `central_path`.", + ConfigError, ) +# ----------------------------------------------------------------------------- +# For Python API +# ----------------------------------------------------------------------------- + + def get_aws_secret_access_key(log: bool = True) -> str: if not sys.stdin.isatty(): proceed = input( From 2579827b83f4d67e9acdfcd5d80fbfa65167e6b6 Mon Sep 17 00:00:00 2001 From: shrey Date: Mon, 9 Jun 2025 23:41:13 +0530 Subject: [PATCH 28/39] rename: aws region config key --- datashuttle/configs/canonical_configs.py | 7 +++---- datashuttle/configs/config_class.py | 2 +- datashuttle/datashuttle_class.py | 4 ++-- datashuttle/tui/css/tui_menu.tcss | 2 +- datashuttle/tui/shared/configs_content.py | 18 ++++++++---------- datashuttle/utils/decorators.py | 4 ++-- datashuttle/utils/rclone.py | 4 ++-- 7 files changed, 19 insertions(+), 22 deletions(-) diff --git a/datashuttle/configs/canonical_configs.py b/datashuttle/configs/canonical_configs.py index 29ab03fb..63db3f78 100644 --- a/datashuttle/configs/canonical_configs.py +++ b/datashuttle/configs/canonical_configs.py @@ -48,7 +48,7 @@ def get_canonical_configs() -> dict: "gdrive_client_id": Optional[str], "gdrive_root_folder_id": Optional[str], "aws_access_key_id": Optional[str], - "aws_s3_region": Optional[Literal[*get_aws_regions_list()]], + "aws_region": Optional[Literal[*get_aws_regions_list()]], # "aws_s3_endpoint_url": Optional[str], } @@ -152,11 +152,10 @@ def check_dict_values_raise_on_fail(config_dict: Configs) -> None: # Check AWS settings elif config_dict["connection_method"] == "aws_s3" and ( - not config_dict["aws_access_key_id"] - or not config_dict["aws_s3_region"] + not config_dict["aws_access_key_id"] or not config_dict["aws_region"] ): utils.log_and_raise_error( - "Both aws_access_key_id and aws_s3_region must be present for AWS connection.", + "Both aws_access_key_id and aws_region must be present for AWS connection.", ConfigError, ) diff --git a/datashuttle/configs/config_class.py b/datashuttle/configs/config_class.py index c74dde82..5baad039 100644 --- a/datashuttle/configs/config_class.py +++ b/datashuttle/configs/config_class.py @@ -139,7 +139,7 @@ def ensure_backwards_compatibilty_for_config( "gdrive_client_id", "gdrive_root_folder_id", "aws_access_key_id", - "aws_s3_region", + "aws_region", ] # All keys shall be missing for a backwards compatibility update diff --git a/datashuttle/datashuttle_class.py b/datashuttle/datashuttle_class.py index 2ee5141b..a7b2f606 100644 --- a/datashuttle/datashuttle_class.py +++ b/datashuttle/datashuttle_class.py @@ -967,7 +967,7 @@ def make_config_file( gdrive_client_id: Optional[str] = None, gdrive_root_folder_id: Optional[str] = None, aws_access_key_id: Optional[str] = None, - aws_s3_region: Optional[str] = None, + aws_region: Optional[str] = None, ) -> None: """ Initialise the configurations for datashuttle to use on the @@ -1035,7 +1035,7 @@ def make_config_file( "gdrive_client_id": gdrive_client_id, "gdrive_root_folder_id": gdrive_root_folder_id, "aws_access_key_id": aws_access_key_id, - "aws_s3_region": aws_s3_region, + "aws_region": aws_region, }, ) diff --git a/datashuttle/tui/css/tui_menu.tcss b/datashuttle/tui/css/tui_menu.tcss index 9e10158d..e32a7b55 100644 --- a/datashuttle/tui/css/tui_menu.tcss +++ b/datashuttle/tui/css/tui_menu.tcss @@ -277,7 +277,7 @@ MessageBox:light > #messagebox_top_container { padding: 0 0 2 0; } -#configs_aws_s3_region_select { +#configs_aws_region_select { width: 70%; } diff --git a/datashuttle/tui/shared/configs_content.py b/datashuttle/tui/shared/configs_content.py index a868d679..43fcb829 100644 --- a/datashuttle/tui/shared/configs_content.py +++ b/datashuttle/tui/shared/configs_content.py @@ -119,10 +119,10 @@ def compose(self) -> ComposeResult: placeholder="AWS Access Key ID eg. EJIBCLSIP2K2PQK3CDON", id="configs_aws_access_key_id_input", ), - Label("AWS S3 Region", id="configs_aws_s3_region_label"), + Label("AWS S3 Region", id="configs_aws_region_label"), Select( ((region, region) for region in get_aws_regions_list()), - id="configs_aws_s3_region_select", + id="configs_aws_region_select", ), ] @@ -765,11 +765,11 @@ def fill_inputs_with_project_configs(self) -> None: input.value = value # AWS S3 Region - select = self.query_one("#configs_aws_s3_region_select") + select = self.query_one("#configs_aws_region_select") value = ( Select.BLANK - if cfg_to_load.get("aws_s3_region", None) is None - else cfg_to_load["aws_s3_region"] + if cfg_to_load.get("aws_region", None) is None + else cfg_to_load["aws_region"] ) select.value = value @@ -905,11 +905,9 @@ def get_datashuttle_inputs_from_widgets(self) -> Dict: None if aws_access_key_id == "" else aws_access_key_id ) - aws_s3_region = self.query_one( - "#configs_aws_s3_region_select" - ).value - cfg_kwargs["aws_s3_region"] = ( - None if aws_s3_region == Select.BLANK else aws_s3_region + aws_region = self.query_one("#configs_aws_region_select").value + cfg_kwargs["aws_region"] = ( + None if aws_region == Select.BLANK else aws_region ) return cfg_kwargs diff --git a/datashuttle/utils/decorators.py b/datashuttle/utils/decorators.py index 36e7d0d5..0bb64c20 100644 --- a/datashuttle/utils/decorators.py +++ b/datashuttle/utils/decorators.py @@ -34,11 +34,11 @@ def requires_aws_configs(func): def wrapper(*args, **kwargs): if ( not args[0].cfg["aws_access_key_id"] - or not args[0].cfg["aws_s3_region"] + or not args[0].cfg["aws_region"] ): log_and_raise_error( "Cannot setup AWS connection, 'aws_access_key_id' " - "or 'aws_s3_region' is not set in the " + "or 'aws_region' is not set in the " "configuration file", ConfigError, ) diff --git a/datashuttle/utils/rclone.py b/datashuttle/utils/rclone.py index 42518009..5c7cd208 100644 --- a/datashuttle/utils/rclone.py +++ b/datashuttle/utils/rclone.py @@ -206,8 +206,8 @@ def setup_rclone_config_for_aws_s3( "s3 provider AWS " f"access_key_id {cfg['aws_access_key_id']} " f"secret_access_key {aws_secret_access_key} " - f"region {cfg['aws_s3_region']} " - f"location_constraint {cfg['aws_s3_region']}", + f"region {cfg['aws_region']} " + f"location_constraint {cfg['aws_region']}", pipe_std=True, ) From 0a1ca872930f20cd24695a7dedca279f6f17f781 Mon Sep 17 00:00:00 2001 From: shrey Date: Tue, 10 Jun 2025 00:12:08 +0530 Subject: [PATCH 29/39] rename: connection method from aws_s3 to aws --- datashuttle/configs/canonical_configs.py | 4 ++-- datashuttle/datashuttle_class.py | 8 +++---- datashuttle/tui/screens/setup_aws.py | 2 +- datashuttle/tui/shared/configs_content.py | 28 +++++++++++------------ datashuttle/utils/folders.py | 2 +- datashuttle/utils/rclone.py | 2 +- 6 files changed, 23 insertions(+), 23 deletions(-) diff --git a/datashuttle/configs/canonical_configs.py b/datashuttle/configs/canonical_configs.py index 63db3f78..0969c858 100644 --- a/datashuttle/configs/canonical_configs.py +++ b/datashuttle/configs/canonical_configs.py @@ -41,7 +41,7 @@ def get_canonical_configs() -> dict: "local_path": Union[str, Path], "central_path": Optional[Union[str, Path]], "connection_method": Optional[ - Literal["ssh", "local_filesystem", "gdrive", "aws_s3"] + Literal["ssh", "local_filesystem", "gdrive", "aws"] ], "central_host_id": Optional[str], "central_host_username": Optional[str], @@ -151,7 +151,7 @@ def check_dict_values_raise_on_fail(config_dict: Configs) -> None: ) # Check AWS settings - elif config_dict["connection_method"] == "aws_s3" and ( + elif config_dict["connection_method"] == "aws" and ( not config_dict["aws_access_key_id"] or not config_dict["aws_region"] ): utils.log_and_raise_error( diff --git a/datashuttle/datashuttle_class.py b/datashuttle/datashuttle_class.py index a7b2f606..672dd4a4 100644 --- a/datashuttle/datashuttle_class.py +++ b/datashuttle/datashuttle_class.py @@ -936,9 +936,9 @@ def setup_google_drive_connection(self) -> None: @requires_aws_configs @check_configs_set - def setup_aws_s3_connection(self) -> None: + def setup_aws_connection(self) -> None: self._start_log( - "setup-aws-s3-connection-to-central-server", + "setup-aws-connection-to-central-server", local_vars=locals(), ) @@ -1556,10 +1556,10 @@ def _setup_rclone_gdrive_config( def _setup_rclone_aws_config( self, aws_secret_access_key: str, log: bool ) -> None: - rclone.setup_rclone_config_for_aws_s3( + rclone.setup_rclone_config_for_aws( self.cfg, aws_secret_access_key, - self.cfg.get_rclone_config_name("aws_s3"), + self.cfg.get_rclone_config_name("aws"), log=log, ) diff --git a/datashuttle/tui/screens/setup_aws.py b/datashuttle/tui/screens/setup_aws.py index 84e65137..8f59d7ca 100644 --- a/datashuttle/tui/screens/setup_aws.py +++ b/datashuttle/tui/screens/setup_aws.py @@ -15,7 +15,7 @@ class SetupAwsScreen(ModalScreen): """ This dialog window handles the TUI equivalent of API's - `setup_aws_s3_connection()`. This asks the user for confirmation + `setup_aws_connection()`. This asks the user for confirmation to proceed with the setup, and then prompts the user for the AWS Secret Access Key. The secret access key is then used to setup rclone config for AWS S3. Works similar to `SetupSshScreen`. diff --git a/datashuttle/tui/shared/configs_content.py b/datashuttle/tui/shared/configs_content.py index 43fcb829..9eb5a172 100644 --- a/datashuttle/tui/shared/configs_content.py +++ b/datashuttle/tui/shared/configs_content.py @@ -112,7 +112,7 @@ def compose(self) -> ComposeResult: ), ] - self.config_aws_s3_widgets = [ + self.config_aws_widgets = [ Label("AWS Access Key ID", id="configs_aws_access_key_id_label"), ClickableInput( self.parent_class.mainwindow, @@ -158,13 +158,13 @@ def compose(self) -> ComposeResult: ), RadioButton( "AWS S3", - id=self.radiobutton_id_from_connection_method("aws_s3"), + id=self.radiobutton_id_from_connection_method("aws"), ), id="configs_connect_method_radioset", ), *self.config_ssh_widgets, *self.config_gdrive_widgets, - *self.config_aws_s3_widgets, + *self.config_aws_widgets, Label("Central Path", id="configs_central_path_label"), Horizontal( ClickableInput( @@ -341,7 +341,7 @@ def connection_method_from_radiobutton_id( in [ "ssh", "gdrive", - "aws_s3", + "aws", "local_filesystem", ] else None @@ -410,7 +410,7 @@ def switch_gdrive_widgets_display(self, display_gdrive: bool) -> None: widget.display = display_gdrive def switch_aws_widgets_display(self, display_aws: bool) -> None: - for widget in self.config_aws_s3_widgets: + for widget in self.config_aws_widgets: widget.display = display_aws def on_button_pressed(self, event: Button.Pressed) -> None: @@ -437,7 +437,7 @@ def on_button_pressed(self, event: Button.Pressed) -> None: self.setup_ssh_connection() elif connection_method == "gdrive": self.setup_gdrive_connection() - elif connection_method == "aws_s3": + elif connection_method == "aws": self.setup_aws_connection() elif event.button.id == "configs_go_to_project_screen_button": @@ -611,7 +611,7 @@ def setup_configs_for_a_new_project(self) -> None: elif connection_method == "gdrive": message = message_template.format(method_name="Google Drive") - elif connection_method == "aws_s3": + elif connection_method == "aws": message = message_template.format(method_name="AWS") else: @@ -687,8 +687,8 @@ def fill_widgets_with_project_configs(self) -> None: cfg_to_load["connection_method"] == "local_filesystem", "configs_gdrive_radiobutton": cfg_to_load["connection_method"] == "gdrive", - "configs_aws_s3_radiobutton": - cfg_to_load["connection_method"] == "aws_s3", + "configs_aws_radiobutton": + cfg_to_load["connection_method"] == "aws", "configs_local_only_radiobutton": cfg_to_load["connection_method"] is None, } @@ -780,13 +780,13 @@ def setup_widgets_to_display(self, connection_method: str | None) -> None: "local_filesystem", "ssh", "gdrive", - "aws_s3", + "aws", ], "Unexpected Connection Method" connection_widget_display_functions = { "ssh": self.switch_ssh_widgets_display, "gdrive": self.switch_gdrive_widgets_display, - "aws_s3": self.switch_aws_widgets_display, + "aws": self.switch_aws_widgets_display, } for name, widget_func in connection_widget_display_functions.items(): @@ -826,7 +826,7 @@ def setup_widgets_to_display(self, connection_method: str | None) -> None: setup_connection_button.label = "Setup SSH Connection" elif connection_method == "gdrive": setup_connection_button.label = "Setup Google Drive Connection" - elif connection_method == "aws_s3": + elif connection_method == "aws": setup_connection_button.label = "Setup AWS Connection" def get_datashuttle_inputs_from_widgets(self) -> Dict: @@ -852,7 +852,7 @@ def get_datashuttle_inputs_from_widgets(self) -> Dict: "configs_local_filesystem_radiobutton", "configs_ssh_radiobutton", "configs_gdrive_radiobutton", - "configs_aws_s3_radiobutton", + "configs_aws_radiobutton", "configs_local_only_radiobutton", ]: if self.query_one("#" + id).value: @@ -897,7 +897,7 @@ def get_datashuttle_inputs_from_widgets(self) -> Dict: ) # AWS specific - elif connection_method == "aws_s3": + elif connection_method == "aws": aws_access_key_id = self.query_one( "#configs_aws_access_key_id_input" ).value diff --git a/datashuttle/utils/folders.py b/datashuttle/utils/folders.py index 1e1c9f21..95f1aee2 100644 --- a/datashuttle/utils/folders.py +++ b/datashuttle/utils/folders.py @@ -520,7 +520,7 @@ def search_for_folders( if local_or_central == "central" and cfg["connection_method"] in [ "ssh", "gdrive", - "aws_s3", + "aws", ]: if cfg["connection_method"] == "ssh": all_folder_names, all_filenames = ( diff --git a/datashuttle/utils/rclone.py b/datashuttle/utils/rclone.py index 5c7cd208..f2e17539 100644 --- a/datashuttle/utils/rclone.py +++ b/datashuttle/utils/rclone.py @@ -194,7 +194,7 @@ def setup_rclone_config_for_gdrive( log_rclone_config_output() -def setup_rclone_config_for_aws_s3( +def setup_rclone_config_for_aws( cfg: Configs, aws_secret_access_key: str, rclone_config_name: str, From 8bb7c2824a5eded06a13a4119db485ac5448b261 Mon Sep 17 00:00:00 2001 From: shrey Date: Tue, 10 Jun 2025 00:37:39 +0530 Subject: [PATCH 30/39] add: utility function to remove duplicate code --- datashuttle/tui/shared/configs_content.py | 43 ++++++++++++----------- 1 file changed, 22 insertions(+), 21 deletions(-) diff --git a/datashuttle/tui/shared/configs_content.py b/datashuttle/tui/shared/configs_content.py index 9eb5a172..8e1af2de 100644 --- a/datashuttle/tui/shared/configs_content.py +++ b/datashuttle/tui/shared/configs_content.py @@ -865,44 +865,38 @@ def get_datashuttle_inputs_from_widgets(self) -> Dict: # SSH specific if connection_method == "ssh": - central_host_id = self.query_one( - "#configs_central_host_id_input" - ).value cfg_kwargs["central_host_id"] = ( - None if central_host_id == "" else central_host_id + self.get_config_value_from_input_value( + "#configs_central_host_id_input" + ) ) - central_host_username = self.query_one( - "#configs_central_host_username_input" - ).value - cfg_kwargs["central_host_username"] = ( - None if central_host_username == "" else central_host_username + self.get_config_value_from_input_value( + "#configs_central_host_username_input" + ) ) # Google Drive specific elif connection_method == "gdrive": - gdrive_client_id = self.query_one( - "#configs_gdrive_client_id_input" - ).value cfg_kwargs["gdrive_client_id"] = ( - None if gdrive_client_id == "" else gdrive_client_id + self.get_config_value_from_input_value( + "#configs_gdrive_client_id_input" + ) ) - gdrive_root_folder_id = self.query_one( - "#configs_gdrive_root_folder_id" - ).value cfg_kwargs["gdrive_root_folder_id"] = ( - None if gdrive_root_folder_id == "" else gdrive_root_folder_id + self.get_config_value_from_input_value( + "#configs_gdrive_root_folder_id" + ) ) # AWS specific elif connection_method == "aws": - aws_access_key_id = self.query_one( - "#configs_aws_access_key_id_input" - ).value cfg_kwargs["aws_access_key_id"] = ( - None if aws_access_key_id == "" else aws_access_key_id + self.get_config_value_from_input_value( + "#configs_aws_access_key_id_input" + ) ) aws_region = self.query_one("#configs_aws_region_select").value @@ -911,3 +905,10 @@ def get_datashuttle_inputs_from_widgets(self) -> Dict: ) return cfg_kwargs + + def get_config_value_from_input_value( + self, input_box_selector: str + ) -> str | None: + input_value = self.query_one(input_box_selector).value + + return None if input_value == "" else input_value From 8beaa423a3fd7a68afb112660b80f4eedf804a71 Mon Sep 17 00:00:00 2001 From: shrey Date: Tue, 10 Jun 2025 01:35:17 +0530 Subject: [PATCH 31/39] add: docstrings to setup gdrive dialog --- datashuttle/tui/screens/setup_gdrive.py | 61 ++++++++++++++++++++++++- 1 file changed, 59 insertions(+), 2 deletions(-) diff --git a/datashuttle/tui/screens/setup_gdrive.py b/datashuttle/tui/screens/setup_gdrive.py index e4798673..cc11035f 100644 --- a/datashuttle/tui/screens/setup_gdrive.py +++ b/datashuttle/tui/screens/setup_gdrive.py @@ -21,7 +21,14 @@ class SetupGdriveScreen(ModalScreen): - """ """ + """ + This dialog window handles the TUI equivalent of API's `setup_google_drive_connection`. + If the config contains a "gdrive_client_id", the user is prompted to enter a client secret. + Then, the user is asked if the their machine has access to a browser. If yes, a + google drive authentication page on the user's default browser opens up. Else, + the user is asked to run an rclone command on a machine with access to a browser and + input the config token generated by rclone. + """ def __init__(self, interface: Interface) -> None: super(SetupGdriveScreen, self).__init__() @@ -54,7 +61,28 @@ def compose(self) -> ComposeResult: ) def on_button_pressed(self, event: Button.Pressed) -> None: - """ """ + """ + This dialog window operates using 6 buttons. + + 1) "ok" button : Starts the connection setup process. It has an intermediate + step that asks the user for client secret if "gdrive_client_id" is present + in configs. And then proceeds to ask the user for browser availability. + This is done via a `stage` variable. If asking the user for client secret, + the stage is incremented by 0.5 on two steps. Else, `ask_user_for_browser` + increments the stage directly by 1. + + 2) "yes" button : A "yes" answer to the availability of browser question. On click, + proceeds to a browser authentication. + + 3) "no" button : A "no" answer to the availability of browser question. On click, + prompts the user to enter a config token by running an rclone command. + + 4) "enter" button : To enter the config token generated by rclone. + + 5) "finish" button : To finish the setup. + + 6) "cancel" button : To cancel the setup at any step before completion. + """ if ( event.button.id == "setup_gdrive_cancel_button" or event.button.id == "setup_gdrive_finish_button" @@ -88,6 +116,10 @@ def on_button_pressed(self, event: Button.Pressed) -> None: self.setup_gdrive_connection_using_config_token() def ask_user_for_gdrive_client_secret(self) -> None: + """ + Asks the user for google drive client secret. Only called if + the datashuttle config has a gdrive_client_id. + """ message = ( "Please provide the client secret for Google Drive. " "You can find it in your Google Cloud Console." @@ -102,6 +134,9 @@ def ask_user_for_gdrive_client_secret(self) -> None: self.stage += 0.5 def ask_user_for_browser(self) -> None: + """ + Asks the user if their machine has access to a browser. + """ message = ( "Are you running Datashuttle on a machine " "that can open a web browser?" @@ -126,6 +161,13 @@ def ask_user_for_browser(self) -> None: self.stage += 0.5 if self.stage == 0.5 else 1 def open_browser_and_setup_gdrive_connection(self) -> None: + """ + This removes the "yes" and "no" buttons that were asked during the + browser question and starts an asyncio task that sets up google drive + connection and updates the UI with success/failure. The connection setup + is asynchronous so that the user is able to cancel the setup if anything + goes wrong without quitting datashuttle altogether. + """ message = "Please authenticate through browser." self.update_message_box_message(message) @@ -163,6 +205,10 @@ def prompt_user_for_config_token(self) -> None: self.mount_input_box_before_buttons() def setup_gdrive_connection_using_config_token(self) -> None: + """ + Disables the enter button and starts an asyncio task to setup + google drive connection and update the UI with success/failure. + """ self.input_box.disabled = True @@ -178,6 +224,11 @@ def setup_gdrive_connection_using_config_token(self) -> None: async def setup_gdrive_connection_and_update_ui( self, config_token: Optional[str] = None ) -> None: + """ + This starts the google drive connection setup in a separate thread (required + to have asynchronous processing) and awaits for its completion. After completion, + it displays a success / failure screen. + """ worker = self.setup_gdrive_connection(config_token) self.setup_worker = worker if worker.is_running: @@ -197,6 +248,12 @@ async def setup_gdrive_connection_and_update_ui( def setup_gdrive_connection( self, config_token: Optional[str] = None ) -> Worker[InterfaceOutput]: + """ + This function runs in a worker thread to setup google drive connection. + If the user had access to a browser, the underlying rclone commands called + by this function are responsible for opening google's auth page to authenticate + with google drive. + """ success, output = self.interface.setup_google_drive_connection( self.gdrive_client_secret, config_token ) From e53984ba7777100932d7167c71edda5749db6da0 Mon Sep 17 00:00:00 2001 From: shrey Date: Thu, 19 Jun 2025 23:23:50 +0530 Subject: [PATCH 32/39] update: config dict inplace change for backward compatibility, use existing rclone function, moving type hint imports in the conditional block --- datashuttle/configs/config_class.py | 27 +++++++++++++++------------ datashuttle/tui/screens/setup_aws.py | 2 +- datashuttle/utils/folders.py | 14 ++++---------- datashuttle/utils/rclone.py | 14 ++++++++++---- 4 files changed, 30 insertions(+), 27 deletions(-) diff --git a/datashuttle/configs/config_class.py b/datashuttle/configs/config_class.py index 5baad039..43457ac4 100644 --- a/datashuttle/configs/config_class.py +++ b/datashuttle/configs/config_class.py @@ -127,14 +127,13 @@ def load_from_file(self) -> None: load_configs.convert_str_and_pathlib_paths(config_dict, "str_to_path") - config_dict = self.ensure_backwards_compatibilty_for_config( - config_dict - ) + self.update_config_for_backward_compatability_if_required(config_dict) + self.data = config_dict - def ensure_backwards_compatibilty_for_config( + def update_config_for_backward_compatability_if_required( self, config_dict: Dict - ) -> Dict: + ): canonical_config_keys_to_add = [ "gdrive_client_id", "gdrive_root_folder_id", @@ -143,16 +142,20 @@ def ensure_backwards_compatibilty_for_config( ] # All keys shall be missing for a backwards compatibility update - needs_update = all( - key not in config_dict.keys() - for key in canonical_config_keys_to_add - ) - if needs_update: + if not ( + all( + key in config_dict.keys() + for key in canonical_config_keys_to_add + ) + ): + assert not any( + key in config_dict.keys() + for key in canonical_config_keys_to_add + ) + for key in canonical_config_keys_to_add: config_dict[key] = None - return config_dict - # ------------------------------------------------------------------------- # Utils # ------------------------------------------------------------------------- diff --git a/datashuttle/tui/screens/setup_aws.py b/datashuttle/tui/screens/setup_aws.py index 8f59d7ca..8f429635 100644 --- a/datashuttle/tui/screens/setup_aws.py +++ b/datashuttle/tui/screens/setup_aws.py @@ -31,7 +31,7 @@ def compose(self) -> ComposeResult: yield Container( Horizontal( Static( - "Ready to setup AWS connection. " "Press OK to proceed", + "Ready to setup AWS connection. Press OK to proceed", id="setup_aws_messagebox_message", ), id="setup_aws_messagebox_message_container", diff --git a/datashuttle/utils/folders.py b/datashuttle/utils/folders.py index 95f1aee2..4eaf00ea 100644 --- a/datashuttle/utils/folders.py +++ b/datashuttle/utils/folders.py @@ -1,7 +1,6 @@ from __future__ import annotations import json -import subprocess from typing import ( TYPE_CHECKING, Any, @@ -22,7 +21,7 @@ from pathlib import Path from datashuttle.configs import canonical_folders, canonical_tags -from datashuttle.utils import ssh, utils, validation +from datashuttle.utils import rclone, ssh, utils, validation from datashuttle.utils.custom_exceptions import NeuroBlueprintError # ----------------------------------------------------------------------------- @@ -564,16 +563,11 @@ def search_gdrive_or_aws_for_folders( The json contains file/folder info about each file/folder like name, type, etc. """ - command = ( - "rclone lsjson " + output = rclone.call_rclone( + "lsjson " f"{cfg.get_rclone_config_name()}:{search_path.as_posix()} " f'--include "{search_prefix}"', - ) - output = subprocess.run( - command, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - shell=True, + pipe_std=True, ) all_folder_names: List[str] = [] diff --git a/datashuttle/utils/rclone.py b/datashuttle/utils/rclone.py index f2e17539..30361094 100644 --- a/datashuttle/utils/rclone.py +++ b/datashuttle/utils/rclone.py @@ -1,14 +1,20 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Dict, List, Literal, Optional + +if TYPE_CHECKING: + from pathlib import Path + + from datashuttle.configs.config_class import Configs + from datashuttle.utils.custom_types import TopLevelFolder + import os import platform import subprocess import tempfile -from pathlib import Path from subprocess import CompletedProcess -from typing import Dict, List, Literal, Optional -from datashuttle.configs.config_class import Configs from datashuttle.utils import utils -from datashuttle.utils.custom_types import TopLevelFolder def call_rclone(command: str, pipe_std: bool = False) -> CompletedProcess: From 772c3c1c64b2df7a21565afa446c1af058195a04 Mon Sep 17 00:00:00 2001 From: shrey Date: Fri, 20 Jun 2025 06:13:14 +0530 Subject: [PATCH 33/39] add: docstrings to setup connection functions; remove: aws region class --- datashuttle/__init__.py | 1 - datashuttle/configs/aws_regions.py | 45 ------------------------ datashuttle/configs/canonical_configs.py | 1 - datashuttle/datashuttle_class.py | 25 ++++++++++++- 4 files changed, 24 insertions(+), 48 deletions(-) diff --git a/datashuttle/__init__.py b/datashuttle/__init__.py index 2b9121e2..501cb203 100644 --- a/datashuttle/__init__.py +++ b/datashuttle/__init__.py @@ -2,7 +2,6 @@ from datashuttle.datashuttle_class import DataShuttle from datashuttle.datashuttle_functions import quick_validate_project -from datashuttle.configs.aws_regions import AWS_REGION try: diff --git a/datashuttle/configs/aws_regions.py b/datashuttle/configs/aws_regions.py index 60e14beb..84576030 100644 --- a/datashuttle/configs/aws_regions.py +++ b/datashuttle/configs/aws_regions.py @@ -41,48 +41,3 @@ def get_aws_regions() -> Dict[str, str]: def get_aws_regions_list() -> List[str]: return list(get_aws_regions().values()) - - -AWS_REGIONS_DICT = get_aws_regions() # runtime constant - - -class AWS_REGION: - """ - A class to represent AWS regions as constants. - It is used to provide intellisense for AWS regions in IDEs. - """ - - US_EAST_1 = AWS_REGIONS_DICT["US_EAST_1"] - US_EAST_2 = AWS_REGIONS_DICT["US_EAST_2"] - US_WEST_1 = AWS_REGIONS_DICT["US_WEST_1"] - US_WEST_2 = AWS_REGIONS_DICT["US_WEST_2"] - CA_CENTRAL_1 = AWS_REGIONS_DICT["CA_CENTRAL_1"] - EU_WEST_1 = AWS_REGIONS_DICT["EU_WEST_1"] - EU_WEST_2 = AWS_REGIONS_DICT["EU_WEST_2"] - EU_WEST_3 = AWS_REGIONS_DICT["EU_WEST_3"] - EU_NORTH_1 = AWS_REGIONS_DICT["EU_NORTH_1"] - EU_SOUTH_1 = AWS_REGIONS_DICT["EU_SOUTH_1"] - EU_CENTRAL_1 = AWS_REGIONS_DICT["EU_CENTRAL_1"] - AP_SOUTHEAST_1 = AWS_REGIONS_DICT["AP_SOUTHEAST_1"] - AP_SOUTHEAST_2 = AWS_REGIONS_DICT["AP_SOUTHEAST_2"] - AP_NORTHEAST_1 = AWS_REGIONS_DICT["AP_NORTHEAST_1"] - AP_NORTHEAST_2 = AWS_REGIONS_DICT["AP_NORTHEAST_2"] - AP_NORTHEAST_3 = AWS_REGIONS_DICT["AP_NORTHEAST_3"] - AP_SOUTH_1 = AWS_REGIONS_DICT["AP_SOUTH_1"] - AP_EAST_1 = AWS_REGIONS_DICT["AP_EAST_1"] - SA_EAST_1 = AWS_REGIONS_DICT["SA_EAST_1"] - IL_CENTRAL_1 = AWS_REGIONS_DICT["IL_CENTRAL_1"] - ME_SOUTH_1 = AWS_REGIONS_DICT["ME_SOUTH_1"] - AF_SOUTH_1 = AWS_REGIONS_DICT["AF_SOUTH_1"] - CN_NORTH_1 = AWS_REGIONS_DICT["CN_NORTH_1"] - CN_NORTHWEST_1 = AWS_REGIONS_DICT["CN_NORTHWEST_1"] - US_GOV_EAST_1 = AWS_REGIONS_DICT["US_GOV_EAST_1"] - US_GOV_WEST_1 = AWS_REGIONS_DICT["US_GOV_WEST_1"] - - @classmethod - def get_all_regions(cls): - return [ - value - for key, value in vars(cls).items() - if not key.startswith("__") and isinstance(value, str) - ] diff --git a/datashuttle/configs/canonical_configs.py b/datashuttle/configs/canonical_configs.py index 0969c858..2b5b1998 100644 --- a/datashuttle/configs/canonical_configs.py +++ b/datashuttle/configs/canonical_configs.py @@ -49,7 +49,6 @@ def get_canonical_configs() -> dict: "gdrive_root_folder_id": Optional[str], "aws_access_key_id": Optional[str], "aws_region": Optional[Literal[*get_aws_regions_list()]], - # "aws_s3_endpoint_url": Optional[str], } return canonical_configs diff --git a/datashuttle/datashuttle_class.py b/datashuttle/datashuttle_class.py index 672dd4a4..76125829 100644 --- a/datashuttle/datashuttle_class.py +++ b/datashuttle/datashuttle_class.py @@ -901,6 +901,20 @@ def write_public_key(self, filepath: str) -> None: @check_configs_set def setup_google_drive_connection(self) -> None: + """ + Setup a connection to Google Drive using the provided credentials. + Assumes `gdrive_root_folder_id` is set in configs. + + First, the user will be prompted to enter their Google Drive client + secret if `gdrive_client_id` is set in the configs. + + Next, the user will be asked if their machine has access to a browser. + If not, they will be prompted to input a config_token after running an + rclone command displayed to the user on a machine with access to a browser. + + Next, with the provided credentials, the final setup will be done. This + opens up a browser if the user confirmed access to a browser. + """ self._start_log( "setup-google-drive-connection-to-central-server", local_vars=locals(), @@ -912,7 +926,6 @@ def setup_google_drive_connection(self) -> None: gdrive_client_secret = None browser_available = gdrive.ask_user_for_browser(log=True) - config_token = None if not browser_available: config_token = gdrive.prompt_and_get_config_token( @@ -921,6 +934,8 @@ def setup_google_drive_connection(self) -> None: self.cfg.get_rclone_config_name("gdrive"), log=True, ) + else: + config_token = None self._setup_rclone_gdrive_config( gdrive_client_secret, config_token, log=True @@ -937,6 +952,14 @@ def setup_google_drive_connection(self) -> None: @requires_aws_configs @check_configs_set def setup_aws_connection(self) -> None: + """ + Setup a connection to AWS S3 buckets using the provided credentials. + Assumes `aws_access_key_id` and `aws_region` are set in configs. + + First, the user will be prompted to input their AWS secret access key. + + Next, with the provided credentials, the final connection setup will be done. + """ self._start_log( "setup-aws-connection-to-central-server", local_vars=locals(), From 91f2454e1f8f761d28503d4ac1cedb3c10ff452d Mon Sep 17 00:00:00 2001 From: shrey Date: Fri, 20 Jun 2025 06:57:27 +0530 Subject: [PATCH 34/39] add: docstrings to setup widgets function; use backwards compatibility --- datashuttle/tui/shared/configs_content.py | 21 +++++++++++++++++---- 1 file changed, 17 insertions(+), 4 deletions(-) diff --git a/datashuttle/tui/shared/configs_content.py b/datashuttle/tui/shared/configs_content.py index 8e1af2de..bd2f8e4e 100644 --- a/datashuttle/tui/shared/configs_content.py +++ b/datashuttle/tui/shared/configs_content.py @@ -741,7 +741,7 @@ def fill_inputs_with_project_configs(self) -> None: input = self.query_one("#configs_gdrive_client_id_input") value = ( "" - if cfg_to_load.get("gdrive_client_id", None) is None + if cfg_to_load["gdrive_client_id"] is None else cfg_to_load["gdrive_client_id"] ) input.value = value @@ -750,7 +750,7 @@ def fill_inputs_with_project_configs(self) -> None: input = self.query_one("#configs_gdrive_root_folder_id") value = ( "" - if cfg_to_load.get("gdrive_root_folder_id", None) is None + if cfg_to_load["gdrive_root_folder_id"] is None else cfg_to_load["gdrive_root_folder_id"] ) input.value = value @@ -759,7 +759,7 @@ def fill_inputs_with_project_configs(self) -> None: input = self.query_one("#configs_aws_access_key_id_input") value = ( "" - if cfg_to_load.get("aws_access_key_id", None) is None + if cfg_to_load["aws_access_key_id"] is None else cfg_to_load["aws_access_key_id"] ) input.value = value @@ -768,13 +768,23 @@ def fill_inputs_with_project_configs(self) -> None: select = self.query_one("#configs_aws_region_select") value = ( Select.BLANK - if cfg_to_load.get("aws_region", None) is None + if cfg_to_load["aws_region"] is None else cfg_to_load["aws_region"] ) select.value = value def setup_widgets_to_display(self, connection_method: str | None) -> None: + """ + Sets up widgets to display based on the chosen `connection_method` on the + radiobutton. The widgets pertaining to the chosen connection method will be + be displayed. This is done by dedicated functions for each connection method + which display widgets on receiving a `True` flag. + + Also, this function handles other TUI changes like displaying "setup connection" + button, disabling central path input in a local only project, etc. + Called on mount, on radiobuttons' switch and upon saving project configs. + """ if connection_method: assert connection_method in [ "local_filesystem", @@ -783,6 +793,7 @@ def setup_widgets_to_display(self, connection_method: str | None) -> None: "aws", ], "Unexpected Connection Method" + # Connection specific widgets connection_widget_display_functions = { "ssh": self.switch_ssh_widgets_display, "gdrive": self.switch_gdrive_widgets_display, @@ -795,6 +806,7 @@ def setup_widgets_to_display(self, connection_method: str | None) -> None: else: widget_func(False) + # Central path input self.query_one("#configs_central_path_input").disabled = ( connection_method is None ) @@ -811,6 +823,7 @@ def setup_widgets_to_display(self, connection_method: str | None) -> None: ) # fmt: off + # Setup connection button if ( not connection_method or connection_method == "local_filesystem" From ee88875a2feadd0ce435ad4257f39c674c1e551c Mon Sep 17 00:00:00 2001 From: shrey Date: Fri, 20 Jun 2025 07:32:45 +0530 Subject: [PATCH 35/39] add: docstrings to rclone function, change arugment order --- datashuttle/datashuttle_class.py | 4 +-- datashuttle/utils/rclone.py | 45 ++++++++++++++++++++++++++++++-- 2 files changed, 45 insertions(+), 4 deletions(-) diff --git a/datashuttle/datashuttle_class.py b/datashuttle/datashuttle_class.py index 76125829..513e254c 100644 --- a/datashuttle/datashuttle_class.py +++ b/datashuttle/datashuttle_class.py @@ -1570,8 +1570,8 @@ def _setup_rclone_gdrive_config( ) -> None: rclone.setup_rclone_config_for_gdrive( self.cfg, - gdrive_client_secret, self.cfg.get_rclone_config_name("gdrive"), + gdrive_client_secret, config_token, log=log, ) @@ -1581,8 +1581,8 @@ def _setup_rclone_aws_config( ) -> None: rclone.setup_rclone_config_for_aws( self.cfg, - aws_secret_access_key, self.cfg.get_rclone_config_name("aws"), + aws_secret_access_key, log=log, ) diff --git a/datashuttle/utils/rclone.py b/datashuttle/utils/rclone.py index 30361094..2e96b8a3 100644 --- a/datashuttle/utils/rclone.py +++ b/datashuttle/utils/rclone.py @@ -158,11 +158,34 @@ def setup_rclone_config_for_ssh( def setup_rclone_config_for_gdrive( cfg: Configs, - gdrive_client_secret: str | None, rclone_config_name: str, + gdrive_client_secret: str | None, config_token: Optional[str] = None, log: bool = True, ): + """ + Sets up rclone config for connections to Google Drive. This must + contain the `gdrive_root_folder_id` and optionally a `gdrive_client_id` + which also mandates for the presence of a Google Drive client secret. + + Parameters + ---------- + + cfg : Configs + datashuttle configs UserDict. + + rclone_config_name : rclone config name + canonical config name, generated by + datashuttle.cfg.get_rclone_config_name() + + gdrive_client_secret : Google Drive client secret, mandatory when + using a Google Drive client. + + config_token : a token to setup rclone config without opening a browser, + needed if the user's machine doesn't have access to a browser. + + log : whether to log, if True logger must already be initialised. + """ client_id_key_value = ( f"client_id {cfg['gdrive_client_id']} " if cfg["gdrive_client_id"] @@ -202,10 +225,28 @@ def setup_rclone_config_for_gdrive( def setup_rclone_config_for_aws( cfg: Configs, - aws_secret_access_key: str, rclone_config_name: str, + aws_secret_access_key: str, log: bool = True, ): + """ + Sets up rclone config for connections to AWS S3 buckets. This must + contain the `aws_access_key_id` and `aws_region`. + + Parameters + ---------- + + cfg : Configs + datashuttle configs UserDict. + + rclone_config_name : rclone config name + canonical config name, generated by + datashuttle.cfg.get_rclone_config_name() + + aws_secret_access_key : the aws secret access key provided by the user. + + log : whether to log, if True logger must already be initialised. + """ output = call_rclone( "config create " f"{rclone_config_name} " From c0a7ecad8088a078febc5ef2fb5a373ae5ab20d8 Mon Sep 17 00:00:00 2001 From: shrey Date: Fri, 20 Jun 2025 08:30:41 +0530 Subject: [PATCH 36/39] minor changes --- datashuttle/tui/shared/configs_content.py | 8 +++++--- datashuttle/utils/gdrive.py | 5 +---- 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/datashuttle/tui/shared/configs_content.py b/datashuttle/tui/shared/configs_content.py index bd2f8e4e..8ea5ea8a 100644 --- a/datashuttle/tui/shared/configs_content.py +++ b/datashuttle/tui/shared/configs_content.py @@ -806,16 +806,18 @@ def setup_widgets_to_display(self, connection_method: str | None) -> None: else: widget_func(False) + has_connection_method = connection_method is not None + # Central path input self.query_one("#configs_central_path_input").disabled = ( - connection_method is None + not has_connection_method ) self.query_one("#configs_central_path_select_button").disabled = ( - connection_method is None + not has_connection_method ) # Local only project - if not connection_method: + if not has_connection_method: self.query_one("#configs_central_path_input").value = "" setup_connection_button = self.query_one( diff --git a/datashuttle/utils/gdrive.py b/datashuttle/utils/gdrive.py index b80f2615..fdebe749 100644 --- a/datashuttle/utils/gdrive.py +++ b/datashuttle/utils/gdrive.py @@ -83,10 +83,7 @@ def ask_user_for_browser(log: bool = True) -> bool: utils.print_message_to_user("Invalid input. Press either 'y' or 'n'.") input_ = utils.get_user_input(message).lower() - if input_ == "y": - answer = True - else: - answer = False + answer = input_ == "y" if log: utils.log(message) From eb3f098f4bc563b33573b66ead431a65b3d4c31f Mon Sep 17 00:00:00 2001 From: shrey Date: Fri, 20 Jun 2025 12:37:45 +0530 Subject: [PATCH 37/39] refactor: centralize the get secret function --- datashuttle/utils/aws.py | 37 ++++++---------------------------- datashuttle/utils/gdrive.py | 37 ++++++---------------------------- datashuttle/utils/utils.py | 40 +++++++++++++++++++++++++++++++++++++ 3 files changed, 52 insertions(+), 62 deletions(-) diff --git a/datashuttle/utils/aws.py b/datashuttle/utils/aws.py index c6e506a4..aebaf225 100644 --- a/datashuttle/utils/aws.py +++ b/datashuttle/utils/aws.py @@ -1,6 +1,4 @@ -import getpass import json -import sys from datashuttle.configs.config_class import Configs from datashuttle.utils import rclone, utils @@ -41,34 +39,11 @@ def raise_if_bucket_absent(cfg: Configs) -> None: def get_aws_secret_access_key(log: bool = True) -> str: - if not sys.stdin.isatty(): - proceed = input( - "\nWARNING!\nThe next step is to enter a AWS secret access key, but it is not possible\n" - "to hide your secret access key while entering it in the current terminal.\n" - "This can occur if running the command in an IDE.\n\n" - "Press 'y' to proceed to secret key entry. " - "The characters will not be hidden!\n" - "Alternatively, run AWS S3 setup after starting Python in your " - "system terminal \nrather than through an IDE: " - ) - if proceed != "y": - utils.print_message_to_user( - "Quitting AWS S3 setup as 'y' not pressed." - ) - utils.log_and_raise_error( - "AWS S3 setup aborted by user.", ConnectionAbortedError - ) - - aws_secret_access_key = input( - "Please enter your AWS secret access key. Characters will not be hidden: " - ) - - else: - aws_secret_access_key = getpass.getpass( - "Please enter your AWS secret access key: " - ) - - if log: - utils.log("AWS secret access key entered by user.") + aws_secret_access_key = utils.get_connection_secret_from_user( + connection_method_name="AWS", + key_name_full="AWS secret access key", + key_name_short="secret key", + log_status=log, + ) return aws_secret_access_key.strip() diff --git a/datashuttle/utils/gdrive.py b/datashuttle/utils/gdrive.py index fdebe749..440c7c08 100644 --- a/datashuttle/utils/gdrive.py +++ b/datashuttle/utils/gdrive.py @@ -1,6 +1,4 @@ -import getpass import json -import sys from datashuttle.configs.config_class import Configs from datashuttle.utils import rclone, utils @@ -112,34 +110,11 @@ def prompt_and_get_config_token( def get_client_secret(log: bool = True) -> str: - if not sys.stdin.isatty(): - proceed = input( - "\nWARNING!\nThe next step is to enter a google drive client secret, but it is not possible\n" - "to hide your client secret while entering it in the current terminal.\n" - "This can occur if running the command in an IDE.\n\n" - "Press 'y' to proceed to client secret entry. " - "The characters will not be hidden!\n" - "Alternatively, run google drive setup after starting Python in your " - "system terminal \nrather than through an IDE: " - ) - if proceed != "y": - utils.print_message_to_user( - "Quitting google drive setup as 'y' not pressed." - ) - utils.log_and_raise_error( - "Google Drive setup aborted by user.", ConnectionAbortedError - ) - - gdrive_client_secret = input( - "Please enter your google drive client secret. Characters will not be hidden: " - ) - - else: - gdrive_client_secret = getpass.getpass( - "Please enter your google drive client secret: " - ) - - if log: - utils.log("Google Drive client secret entered by user.") + gdrive_client_secret = utils.get_connection_secret_from_user( + connection_method_name="Google Drive", + key_name_full="Google Drive client secret", + key_name_short="secret key", + log_status=log, + ) return gdrive_client_secret.strip() diff --git a/datashuttle/utils/utils.py b/datashuttle/utils/utils.py index 87a39e5c..862e7bb4 100644 --- a/datashuttle/utils/utils.py +++ b/datashuttle/utils/utils.py @@ -1,6 +1,8 @@ from __future__ import annotations +import getpass import re +import sys import traceback import warnings from typing import TYPE_CHECKING, Any, List, Literal, Union, overload @@ -87,6 +89,44 @@ def get_user_input(message: str) -> str: return input_ +def get_connection_secret_from_user( + connection_method_name: str, + key_name_full: str, + key_name_short: str, + log_status: bool, +) -> str: + if not sys.stdin.isatty(): + proceed = input( + f"\nWARNING!\nThe next step is to enter a {key_name_full}, but it is not possible\n" + f"to hide your {key_name_short} while entering it in the current terminal.\n" + f"This can occur if running the command in an IDE.\n\n" + f"Press 'y' to proceed to {key_name_short} entry. " + f"The characters will not be hidden!\n" + f"Alternatively, run {connection_method_name} setup after starting Python in your " + f"system terminal \nrather than through an IDE: " + ) + if proceed != "y": + print_message_to_user( + f"Quitting {connection_method_name} setup as 'y' not pressed." + ) + log_and_raise_error( + f"{connection_method_name} setup aborted by user.", + ConnectionAbortedError, + ) + + input_ = input( + f"Please enter your {key_name_full}. Characters will not be hidden: " + ) + + else: + input_ = getpass.getpass(f"Please enter your {key_name_full}: ") + + if log_status: + log(f"{key_name_full} entered by user.") + + return input_ + + # ----------------------------------------------------------------------------- # Paths # ----------------------------------------------------------------------------- From 77d9a716363b7d05bbb0fc2bd2681da7961907fa Mon Sep 17 00:00:00 2001 From: JoeZiminski Date: Fri, 20 Jun 2025 13:06:01 +0100 Subject: [PATCH 38/39] Edit GDrive headless method and add test. --- .github/workflows/code_test_and_deploy.yml | 12 +++- datashuttle/configs/canonical_configs.py | 3 +- datashuttle/datashuttle_class.py | 23 ++++--- datashuttle/utils/folders.py | 2 +- datashuttle/utils/gdrive.py | 10 ++- datashuttle/utils/rclone.py | 19 ++++-- .../test_gdrive_connection.py | 64 +++++++++++++++++++ 7 files changed, 111 insertions(+), 22 deletions(-) create mode 100644 tests/tests_integration/test_gdrive_connection.py diff --git a/.github/workflows/code_test_and_deploy.yml b/.github/workflows/code_test_and_deploy.yml index 49054c6b..dc658c6e 100644 --- a/.github/workflows/code_test_and_deploy.yml +++ b/.github/workflows/code_test_and_deploy.yml @@ -55,8 +55,16 @@ jobs: conda uninstall datashuttle --force python -m pip install --upgrade pip pip install .[dev] - - name: Test - run: pytest +# - name: Test + # run: pytest + - name: Set up Google Drive secrets + run: | + printf '%s' '${{ secrets.GDRIVE_SERVICE_ACCOUNT_JSON }}' > "$HOME/gdrive.json" + echo "GDRIVE_SERVICE_ACCOUNT_FILE=$HOME/gdrive.json" >> $GITHUB_ENV + echo "GDRIVE_ROOT_FOLDER_ID=${{ secrets.GDRIVE_ROOT_FOLDER_ID }}" >> $GITHUB_ENV + + - name: Run Google Drive tests + run: pytest -q -k test_gdrive_connection build_sdist_wheels: name: Build source distribution diff --git a/datashuttle/configs/canonical_configs.py b/datashuttle/configs/canonical_configs.py index 2b5b1998..b0d6d8ef 100644 --- a/datashuttle/configs/canonical_configs.py +++ b/datashuttle/configs/canonical_configs.py @@ -27,7 +27,6 @@ import typeguard -from datashuttle.configs.aws_regions import get_aws_regions_list from datashuttle.utils import folders, utils from datashuttle.utils.custom_exceptions import ConfigError @@ -48,7 +47,7 @@ def get_canonical_configs() -> dict: "gdrive_client_id": Optional[str], "gdrive_root_folder_id": Optional[str], "aws_access_key_id": Optional[str], - "aws_region": Optional[Literal[*get_aws_regions_list()]], + "aws_region": Optional[str], } return canonical_configs diff --git a/datashuttle/datashuttle_class.py b/datashuttle/datashuttle_class.py index 513e254c..6e5fb742 100644 --- a/datashuttle/datashuttle_class.py +++ b/datashuttle/datashuttle_class.py @@ -928,17 +928,20 @@ def setup_google_drive_connection(self) -> None: browser_available = gdrive.ask_user_for_browser(log=True) if not browser_available: - config_token = gdrive.prompt_and_get_config_token( - self.cfg, - gdrive_client_secret, - self.cfg.get_rclone_config_name("gdrive"), - log=True, - ) + service_account_filepath = utils.get_user_input( + "Please input the path to your credentials json" + ) # TODO: add more explanation + # config_token = gdrive.prompt_and_get_config_token( + # self.cfg, + # gdrive_client_secret, + # self.cfg.get_rclone_config_name("gdrive"), + # log=True, + # ) else: - config_token = None + service_account_filepath = None self._setup_rclone_gdrive_config( - gdrive_client_secret, config_token, log=True + gdrive_client_secret, service_account_filepath, log=True ) rclone.check_successful_connection_and_raise_error_on_fail(self.cfg) @@ -1565,14 +1568,14 @@ def _setup_rclone_central_local_filesystem_config(self) -> None: def _setup_rclone_gdrive_config( self, gdrive_client_secret: str | None, - config_token: str | None, + service_account_filepath: str | None, log: bool, ) -> None: rclone.setup_rclone_config_for_gdrive( self.cfg, self.cfg.get_rclone_config_name("gdrive"), gdrive_client_secret, - config_token, + service_account_filepath, log=log, ) diff --git a/datashuttle/utils/folders.py b/datashuttle/utils/folders.py index 4eaf00ea..b42478b0 100644 --- a/datashuttle/utils/folders.py +++ b/datashuttle/utils/folders.py @@ -575,7 +575,7 @@ def search_gdrive_or_aws_for_folders( if output.returncode != 0: utils.log_and_message( - f"Error searching files at {search_path.as_posix()} \n {output.stderr.decode('utf-8') if output.stderr else ""}" + f"Error searching files at {search_path.as_posix()} \n {output.stderr.decode('utf-8') if output.stderr else ''}" ) return all_folder_names, all_filenames diff --git a/datashuttle/utils/gdrive.py b/datashuttle/utils/gdrive.py index 440c7c08..bd8df990 100644 --- a/datashuttle/utils/gdrive.py +++ b/datashuttle/utils/gdrive.py @@ -1,6 +1,14 @@ +from __future__ import annotations + +from typing import ( + TYPE_CHECKING, +) + +if TYPE_CHECKING: + from datashuttle.configs.config_class import Configs + import json -from datashuttle.configs.config_class import Configs from datashuttle.utils import rclone, utils # ----------------------------------------------------------------------------- diff --git a/datashuttle/utils/rclone.py b/datashuttle/utils/rclone.py index 2e96b8a3..35123599 100644 --- a/datashuttle/utils/rclone.py +++ b/datashuttle/utils/rclone.py @@ -160,7 +160,7 @@ def setup_rclone_config_for_gdrive( cfg: Configs, rclone_config_name: str, gdrive_client_secret: str | None, - config_token: Optional[str] = None, + service_account_filepath: Optional[str] = None, log: bool = True, ): """ @@ -197,10 +197,16 @@ def setup_rclone_config_for_gdrive( else "" ) - extra_args = ( - f"config_is_local=false config_token={config_token}" - if config_token - else "" + # extra_args = ( + # f"config_is_local=false config_token={config_token}" + # if config_token + # else "" + # ) + + service_account_filepath_arg = ( + "" + if service_account_filepath is None + else f"service_account_file {service_account_filepath} " ) output = call_rclone( f"config create " @@ -210,7 +216,8 @@ def setup_rclone_config_for_gdrive( f"{client_secret_key_value}" f"scope drive " f"root_folder_id {cfg['gdrive_root_folder_id']} " - f"{extra_args}", + f"{service_account_filepath_arg}", + # f"{extra_args}", pipe_std=True, ) diff --git a/tests/tests_integration/test_gdrive_connection.py b/tests/tests_integration/test_gdrive_connection.py new file mode 100644 index 00000000..2c89bfd3 --- /dev/null +++ b/tests/tests_integration/test_gdrive_connection.py @@ -0,0 +1,64 @@ +import builtins +import os +import random +import string + +import pytest +import test_utils +from base import BaseTest + +from datashuttle.configs.canonical_configs import get_broad_datatypes +from datashuttle.utils import rclone + + +@pytest.mark.skipif(os.getenv("CI") is None, reason="Only runs in CI") +class TestGoogleDriveGithubCI(BaseTest): + + def test_google_drive_connection(self, no_cfg_project, tmp_path): + + central_path = ( + f"test-id-{''.join(random.choices(string.digits, k=15))}" + ) + + root_id = os.environ["GDRIVE_ROOT_FOLDER_ID"] + sa_path = os.environ["GDRIVE_SERVICE_ACCOUNT_FILE"] + + no_cfg_project.make_config_file( + local_path=str(tmp_path), # any temp location TODO UPDATE + connection_method="gdrive", + central_path=central_path, + gdrive_root_folder_id=root_id, + gdrive_client_id=None, # keep None + ) + + state = {"first": True} + + def mock_input(_: str) -> str: + if state["first"]: + state["first"] = False + return "n" + return sa_path + + original_input = builtins.input + builtins.input = mock_input + + no_cfg_project.setup_google_drive_connection() + + builtins.input = original_input + + subs, sessions = test_utils.get_default_sub_sessions_to_test() + + test_utils.make_and_check_local_project_folders( + no_cfg_project, "rawdata", subs, sessions, get_broad_datatypes() + ) + + no_cfg_project.upload_entire_project() + + # get a list of files on gdrive and check they are as expected + # assert the test id if its failed + + # only tidy up if as expected, otherwise we can leave the folder there to have a look + # and delete manually later + rclone.call_rclone( + f"purge central_{no_cfg_project.project_name}_gdrive:{central_path}" + ) From 34d158e1efad87dd7cc33741b46f24ace4943784 Mon Sep 17 00:00:00 2001 From: JoeZiminski Date: Fri, 20 Jun 2025 17:43:21 +0100 Subject: [PATCH 39/39] Add prototype AWS test. --- .github/workflows/code_test_and_deploy.yml | 21 ++++-- .../tests_integration/test_aws_connection.py | 64 +++++++++++++++++++ 2 files changed, 79 insertions(+), 6 deletions(-) create mode 100644 tests/tests_integration/test_aws_connection.py diff --git a/.github/workflows/code_test_and_deploy.yml b/.github/workflows/code_test_and_deploy.yml index dc658c6e..fe1ffdd4 100644 --- a/.github/workflows/code_test_and_deploy.yml +++ b/.github/workflows/code_test_and_deploy.yml @@ -57,14 +57,23 @@ jobs: pip install .[dev] # - name: Test # run: pytest - - name: Set up Google Drive secrets + # - name: Set up Google Drive secrets + # run: | + # printf '%s' '${{ secrets.GDRIVE_SERVICE_ACCOUNT_JSON }}' > "$HOME/gdrive.json" + # echo "GDRIVE_SERVICE_ACCOUNT_FILE=$HOME/gdrive.json" >> $GITHUB_ENV + # echo "GDRIVE_ROOT_FOLDER_ID=${{ secrets.GDRIVE_ROOT_FOLDER_ID }}" >> $GITHUB_ENV + + # - name: Run Google Drive tests + # run: pytest -q -k test_gdrive_connection + + - name: Set up AWS secrets run: | - printf '%s' '${{ secrets.GDRIVE_SERVICE_ACCOUNT_JSON }}' > "$HOME/gdrive.json" - echo "GDRIVE_SERVICE_ACCOUNT_FILE=$HOME/gdrive.json" >> $GITHUB_ENV - echo "GDRIVE_ROOT_FOLDER_ID=${{ secrets.GDRIVE_ROOT_FOLDER_ID }}" >> $GITHUB_ENV + echo "AWS_ACCESS_KEY_ID=${{ secrets.AWS_ACCESS_KEY_ID }}" >> $GITHUB_ENV + echo "AWS_ACCESS_KEY_ID_SECRET=${{ secrets.AWS_ACCESS_KEY_ID_SECRET }}" >> $GITHUB_ENV + echo "AWS_REGION=${{ secrets.AWS_REGION }}" >> $GITHUB_ENV - - name: Run Google Drive tests - run: pytest -q -k test_gdrive_connection + - name: Run AWS tests + run: pytest -q -k test_aws_connection build_sdist_wheels: name: Build source distribution diff --git a/tests/tests_integration/test_aws_connection.py b/tests/tests_integration/test_aws_connection.py new file mode 100644 index 00000000..e50ce9c8 --- /dev/null +++ b/tests/tests_integration/test_aws_connection.py @@ -0,0 +1,64 @@ +import builtins +import os +import random +import string + +import pytest +import test_utils +from base import BaseTest + +from datashuttle.configs.canonical_configs import get_broad_datatypes +from datashuttle.utils import rclone + + +@pytest.mark.skipif(os.getenv("CI") is None, reason="Only runs in CI") +class TestGoogleDriveGithubCI(BaseTest): + + def test_google_drive_connection(self, no_cfg_project, tmp_path): + + central_path = f"test-datashuttle/test-id-{''.join(random.choices(string.digits, k=15))}" + + aws_access_key_id = os.environ["AWS_ACCESS_KEY_ID"] + aws_access_key_id_secret = os.environ["AWS_ACCESS_KEY_ID_SECRET"] + aws_region = os.environ["AWS_REGION"] + + no_cfg_project.make_config_file( + local_path=str(tmp_path), # any temp location TODO UPDATE + connection_method="aws", + central_path=central_path, + aws_access_key_id=aws_access_key_id, + aws_region=aws_region, + ) + + state = {"first": True} + + def mock_input(_: str) -> str: + if state["first"]: + state["first"] = False + return "y" + return aws_access_key_id_secret + + original_input = builtins.input + builtins.input = mock_input + + no_cfg_project.setup_aws_connection() # TODO: check that the connection method is correct for these funcs + + builtins.input = original_input + + subs, sessions = test_utils.get_default_sub_sessions_to_test() + + test_utils.make_and_check_local_project_folders( + no_cfg_project, "rawdata", subs, sessions, get_broad_datatypes() + ) + + no_cfg_project.upload_entire_project() + + # get a list of files on gdrive and check they are as expected + # assert the test id if its failed + + # only tidy up if as expected, otherwise we can leave the folder there to have a look + # and delete manually later + + rclone.call_rclone( + f"purge central_{no_cfg_project.project_name}_aws:{central_path}" + )