diff --git a/.github/workflows/code_test_and_deploy.yml b/.github/workflows/code_test_and_deploy.yml index 52975585..2fee3d26 100644 --- a/.github/workflows/code_test_and_deploy.yml +++ b/.github/workflows/code_test_and_deploy.yml @@ -32,7 +32,7 @@ jobs: # macos-14 is M1, macos-13 is intel. Run on earliest and # latest python versions. All python versions are tested in # the weekly cron job. - os: [windows-latest, ubuntu-latest, macos-14, macos-13] + os: [ ubuntu-latest, windows-latest, macos-14, macos-13] # Test all Python versions for cron job, and only first/last for other triggers python-version: ${{ fromJson(github.event_name == 'schedule' && '["3.9", "3.10", "3.11", "3.12"]' || '["3.9", "3.12"]') }} @@ -57,8 +57,17 @@ jobs: run: | python -m pip install --upgrade pip pip install .[dev] - - name: Test - run: pytest + # run SSH tests only on Linux because Windows and macOS + # are already run within a virtual container and so cannot + # run Linux containers because nested containerisation is disabled. + - name: Test SSH (Linux only) + if: runner.os == 'Linux' + run: | + sudo service mysql stop # free up port 3306 for ssh tests + pytest tests/tests_transfers/ssh + - name: All Other Tests + run: | + pytest --ignore tests/tests_transfers/ssh build_sdist_wheels: name: Build source distribution diff --git a/datashuttle/configs/canonical_configs.py b/datashuttle/configs/canonical_configs.py index ac41d6e4..9d041743 100644 --- a/datashuttle/configs/canonical_configs.py +++ b/datashuttle/configs/canonical_configs.py @@ -11,6 +11,7 @@ from __future__ import annotations +import os from typing import ( TYPE_CHECKING, Dict, @@ -58,6 +59,16 @@ def keys_str_on_file_but_path_in_class() -> list[str]: ] +def get_default_ssh_port() -> int: + """ + Get the default port used for SSH connections. 
+ """ + if "DS_SSH_PORT" in os.environ: + return int(os.environ["DS_SSH_PORT"]) + else: + return 22 + + # ----------------------------------------------------------------------------- # Check Configs # ----------------------------------------------------------------------------- diff --git a/datashuttle/configs/config_class.py b/datashuttle/configs/config_class.py index 562c3310..64e8364e 100644 --- a/datashuttle/configs/config_class.py +++ b/datashuttle/configs/config_class.py @@ -200,6 +200,9 @@ def get_rclone_config_name( return f"central_{self.project_name}_{connection_method}" + def get_rclone_config_name_local(self): + return f"local_{self.project_name}_local_filesystem" + def make_rclone_transfer_options( self, overwrite_existing_files: OverwriteExistingFiles, dry_run: bool ) -> Dict: diff --git a/datashuttle/utils/data_transfer.py b/datashuttle/utils/data_transfer.py index 21121b8e..c617f4c1 100644 --- a/datashuttle/utils/data_transfer.py +++ b/datashuttle/utils/data_transfer.py @@ -157,7 +157,6 @@ def build_a_list_of_all_files_and_folders_to_transfer(self) -> List[str]: self.update_list_with_non_ses_sub_level_folders( extra_folder_names, extra_filenames, sub ) - continue # Datatype (sub and ses level) -------------------------------- diff --git a/datashuttle/utils/folders.py b/datashuttle/utils/folders.py index 56852640..f28f519a 100644 --- a/datashuttle/utils/folders.py +++ b/datashuttle/utils/folders.py @@ -515,25 +515,66 @@ def search_for_folders( verbose : If `True`, when a search folder cannot be found, a message will be printed with the missing path. 
""" - if local_or_central == "central" and cfg["connection_method"] == "ssh": - all_folder_names, all_filenames = ssh.search_ssh_central_for_folders( - search_path, - search_prefix, - cfg, - verbose, - return_full_path, + if local_or_central == "local": + all_folder_names, all_filenames = search_gdrive_or_aws_for_folders( + search_path, search_prefix, None, return_full_path ) - else: - if not search_path.exists(): - if verbose: - utils.log_and_message( - f"No file found at {search_path.as_posix()}" - ) - return [], [] - all_folder_names, all_filenames = search_filesystem_path_for_folders( + all_folder_names_, all_filenames_ = search_filesystem_path_for_folders( search_path / search_prefix, return_full_path ) + + assert all_folder_names == all_folder_names_ + assert all_filenames == all_filenames_ + + else: + + if cfg["connection_method"] == "ssh": + all_folder_names, all_filenames = ( + ssh.search_ssh_central_for_folders( + search_path, + search_prefix, + cfg, + verbose, + return_full_path, + ) + ) + + all_folder_names_, all_filenames_ = ( + search_gdrive_or_aws_for_folders( + search_path, + search_prefix, + cfg.get_rclone_config_name("ssh"), + return_full_path, + ) + ) + assert sorted(all_folder_names) == all_folder_names_ + assert all_filenames == all_filenames_ + + else: + if not search_path.exists(): + if verbose: + utils.log_and_message( + f"No file found at {search_path.as_posix()}" + ) + return [], [] + + all_folder_names, all_filenames = search_gdrive_or_aws_for_folders( + search_path, + search_prefix, + cfg.get_rclone_config_name("local_filesystem"), + return_full_path, + ) + + all_folder_names_, all_filenames_ = ( + search_filesystem_path_for_folders( + search_path / search_prefix, return_full_path + ) + ) + + assert all_folder_names == all_folder_names_ + assert all_filenames == all_filenames_ + return all_folder_names, all_filenames @@ -565,3 +606,65 @@ def search_filesystem_path_for_folders( ) return all_folder_names, all_filenames + + +def 
search_gdrive_or_aws_for_folders( + search_path: Path, + search_prefix: str, + rclone_config_name: str | None, + return_full_path: bool = False, +) -> Tuple[List[Any], List[Any]]: + """ + Searches for files and folders in central path using `rclone lsjson` command. + This command lists all the files and folders in the central path in a json format. + The json contains file/folder info about each file/folder like name, type, etc. + """ + import fnmatch + import json + + from datashuttle.utils import rclone + + if rclone_config_name: + config_prefix = f"{rclone_config_name}:" + else: + config_prefix = "" + + output = rclone.call_rclone( + f'lsjson {config_prefix}"{search_path.as_posix()}"', + pipe_std=True, + ) + + all_folder_names: List[str] = [] + all_filenames: List[str] = [] + + if output.returncode != 0: + utils.log_and_message( + f"Error searching files at {search_path.as_posix()} \n {output.stderr.decode('utf-8') if output.stderr else ''}" + ) + return all_folder_names, all_filenames + + files_and_folders = json.loads(output.stdout) + + # try: + for file_or_folder in files_and_folders: + + name = file_or_folder["Name"] + + if not fnmatch.fnmatch(name, search_prefix): + continue + + is_dir = file_or_folder.get("IsDir", False) + + to_append = search_path / name if return_full_path else name + + if is_dir: + all_folder_names.append(to_append) + else: + all_filenames.append(to_append) + + # except Exception: + # utils.log_and_message( + # f"Error searching files at {search_path.as_posix()}" + # ) + + return all_folder_names, all_filenames diff --git a/datashuttle/utils/rclone.py b/datashuttle/utils/rclone.py index 49d7da82..890d0b9f 100644 --- a/datashuttle/utils/rclone.py +++ b/datashuttle/utils/rclone.py @@ -6,6 +6,7 @@ from subprocess import CompletedProcess from typing import Dict, List, Literal +from datashuttle.configs import canonical_configs from datashuttle.configs.config_class import Configs from datashuttle.utils import utils from 
datashuttle.utils.custom_types import TopLevelFolder @@ -141,7 +142,7 @@ def setup_rclone_config_for_ssh( f"sftp " f"host {cfg['central_host_id']} " f"user {cfg['central_host_username']} " - f"port 22 " + f"port {canonical_configs.get_default_ssh_port()} " f"key_file {ssh_key_path.as_posix()}", pipe_std=True, ) diff --git a/datashuttle/utils/ssh.py b/datashuttle/utils/ssh.py index 8f6de678..57024085 100644 --- a/datashuttle/utils/ssh.py +++ b/datashuttle/utils/ssh.py @@ -14,6 +14,7 @@ import paramiko +from datashuttle.configs import canonical_configs from datashuttle.utils import utils # ----------------------------------------------------------------------------- @@ -42,6 +43,7 @@ def connect_client_core( else None ), look_for_keys=True, + port=canonical_configs.get_default_ssh_port(), ) @@ -83,7 +85,9 @@ def get_remote_server_key(central_host_id: str): connection. """ transport: paramiko.Transport - with paramiko.Transport(central_host_id) as transport: + with paramiko.Transport( + (central_host_id, canonical_configs.get_default_ssh_port()) + ) as transport: transport.connect() key = transport.get_remote_server_key() return key @@ -91,7 +95,11 @@ def get_remote_server_key(central_host_id: str): def save_hostkey_locally(key, central_host_id, hostkeys_path) -> None: client = paramiko.SSHClient() - client.get_host_keys().add(central_host_id, key.get_name(), key) + client.get_host_keys().add( + f"[{central_host_id}]:{canonical_configs.get_default_ssh_port()}", + key.get_name(), + key, + ) client.get_host_keys().save(hostkeys_path.as_posix()) @@ -183,7 +191,7 @@ def connect_client_with_logging( f"Connection to { cfg['central_host_id']} made successfully." ) - except Exception: + except Exception as e: utils.log_and_raise_error( f"Could not connect to server. 
Ensure that \n" f"1) You have run setup_ssh_connection() \n" @@ -191,7 +199,8 @@ def connect_client_with_logging( f"3) The central_host_id: {cfg['central_host_id']} is" f" correct.\n" f"4) The central username:" - f" {cfg['central_host_username']}, and password are correct.", + f" {cfg['central_host_username']}, and password are correct." + f"Original error: {e}", ConnectionError, ) diff --git a/pyproject.toml b/pyproject.toml index 9ca0f69e..58714f17 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -107,6 +107,7 @@ select = ["I", "E", "F", "TCH", "TID252"] [tool.ruff.lint.per-file-ignores] "__init__.py" = ["F401"] +"tests/**/*" = ["TID252"] [tool.ruff.lint.mccabe] max-complexity = 18 diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/tests_integration/base.py b/tests/base.py similarity index 91% rename from tests/tests_integration/base.py rename to tests/base.py index bee36b01..7ce3f3c7 100644 --- a/tests/tests_integration/base.py +++ b/tests/base.py @@ -1,10 +1,11 @@ import warnings import pytest -import test_utils from datashuttle import DataShuttle +from . import test_utils + TEST_PROJECT_NAME = "test_project" @@ -13,8 +14,8 @@ class BaseTest: @pytest.fixture(scope="function") def no_cfg_project(test): """ - Fixture that creates an empty project. Ignore the warning - that no configs are setup yet. + Fixture that creates an empty project. Ignore the + warning that no configs are set up yet. """ test_utils.delete_project_if_it_exists(TEST_PROJECT_NAME) @@ -64,8 +65,8 @@ def project(self, tmp_path, request): def clean_project_name(self): """ Create an empty project, but ensure no - configs already exists, and delete created configs - after test. + configs already exists, and delete created + configs after test. 
""" project_name = TEST_PROJECT_NAME test_utils.delete_project_if_it_exists(project_name) diff --git a/tests/conftest.py b/tests/conftest.py deleted file mode 100644 index 2203b102..00000000 --- a/tests/conftest.py +++ /dev/null @@ -1,47 +0,0 @@ -""" -Test configs, used for setting up SSH tests. - -Before running these tests, it is necessary to setup -an SSH key. This can be done through datashuttle -ssh.setup_ssh_key(project.cfg, log=False). - -Store this path somewhere outside of the test environment, -and it will be copied to the project test folder before testing. - -FILESYSTEM_PATH and SERVER_PATH these must point -to the same folder on the HPC, filesystem, -as a mounted drive and server as the linux path to -connect through SSH -""" - -import platform -from types import SimpleNamespace - -import pytest -import test_utils - -test_ssh = False -username = "jziminski" -central_host_id = "ssh.swc.ucl.ac.uk" -server_path = r"/ceph/neuroinformatics/neuroinformatics/scratch/datashuttle_tests/fake_data" - - -if platform.system() == "Windows": - ssh_key_path = r"C:\Users\Joe\.datashuttle\test_file_conflicts_ssh_key" - filesystem_path = "X:/neuroinformatics/scratch/datashuttle_tests/fake_data" - -else: - ssh_key_path = "/home/joe/test_file_conflicts_ssh_key" - filesystem_path = "/home/joe/ceph_mount/neuroinformatics/scratch/datashuttle_tests/fake_data" - - -def pytest_configure(config): - pytest.ssh_config = SimpleNamespace( - TEST_SSH=test_ssh, - SSH_KEY_PATH=ssh_key_path, - USERNAME=username, - CENTRAL_HOST_ID=central_host_id, - FILESYSTEM_PATH=filesystem_path, # FILESYSTEM_PATH and SERVER_PATH these must point to the same folder on the HPC, filesystem - SERVER_PATH=server_path, # as a mounted drive and server as the linux path to connect through SSH - ) - test_utils.set_datashuttle_loggers(disable=True) diff --git a/tests/quick_make_project.py b/tests/quick_make_project.py deleted file mode 100644 index 250e6650..00000000 --- a/tests/quick_make_project.py +++ 
/dev/null @@ -1,5 +0,0 @@ -base_path = r"C:/Users/Joe/work/git-repos/forks/yxtuix/joe" - -from test_utils import quick_create_project - -quick_create_project(base_path) diff --git a/tests/ssh_test_utils.py b/tests/ssh_test_utils.py deleted file mode 100644 index 0838669f..00000000 --- a/tests/ssh_test_utils.py +++ /dev/null @@ -1,55 +0,0 @@ -import builtins -import copy - -from datashuttle.utils import rclone, ssh - - -def setup_project_for_ssh( - project, central_path, central_host_id, central_host_username -): - """ - Set up the project configs to use SSH connection - to central - """ - project.update_config_file( - central_path=central_path, - ) - project.update_config_file(central_host_id=central_host_id) - project.update_config_file(central_host_username=central_host_username) - project.update_config_file(connection_method="ssh") - - rclone.setup_rclone_config_for_ssh( - project.cfg, - project.cfg.get_rclone_config_name("ssh"), - project.cfg.ssh_key_path, - ) - - -def setup_mock_input(input_): - """ - This is very similar to pytest monkeypatch but - using that was giving me very strange output, - monkeypatch.setattr('builtins.input', lambda _: "n") - i.e. pdb went deep into some unrelated code stack - """ - orig_builtin = copy.deepcopy(builtins.input) - builtins.input = lambda _: input_ # type: ignore - return orig_builtin - - -def restore_mock_input(orig_builtin): - """ - orig_builtin: the copied, original builtins.input - """ - builtins.input = orig_builtin - - -def setup_hostkeys(project): - """ - Convenience function to verify the server hostkey. 
- """ - orig_builtin = setup_mock_input(input_="y") - ssh.verify_ssh_central_host( - project.cfg["central_host_id"], project.cfg.hostkeys_path, log=True - ) - restore_mock_input(orig_builtin) diff --git a/tests/test_utils.py b/tests/test_utils.py index 0e7b4569..1b3f9a9f 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -10,7 +10,6 @@ from pathlib import Path import yaml -from file_conflicts_pathtable import get_pathtable from datashuttle import DataShuttle from datashuttle.configs import canonical_configs, canonical_folders @@ -159,19 +158,6 @@ def make_test_path(base_path, local_or_central, test_project_name): return Path(base_path) / local_or_central / test_project_name -def create_all_pathtable_files(pathtable): - """ """ - for i in range(pathtable.shape[0]): - filepath = pathtable["base_folder"][i] / pathtable["path"][i] - filepath.parents[0].mkdir(parents=True, exist_ok=True) - write_file(filepath, contents="test_entry") - - -def quick_create_project(base_path): - pathtable = get_pathtable(base_path) - create_all_pathtable_files(pathtable) - - # ----------------------------------------------------------------------------- # Test Configs # ----------------------------------------------------------------------------- diff --git a/tests/tests_integration/__init__.py b/tests/tests_integration/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/tests_integration/test_configs.py b/tests/tests_integration/test_configs.py index feff2900..24e15245 100644 --- a/tests/tests_integration/test_configs.py +++ b/tests/tests_integration/test_configs.py @@ -1,13 +1,14 @@ import os import pytest -import test_utils -from base import BaseTest from datashuttle import DataShuttle from datashuttle.utils import getters from datashuttle.utils.custom_exceptions import ConfigError +from .. 
import test_utils +from ..base import BaseTest + class TestConfigs(BaseTest): # Test Errors diff --git a/tests/tests_integration/test_create_folders.py b/tests/tests_integration/test_create_folders.py index e5062377..0ca15ab7 100644 --- a/tests/tests_integration/test_create_folders.py +++ b/tests/tests_integration/test_create_folders.py @@ -5,12 +5,13 @@ from os.path import join import pytest -import test_utils -from base import BaseTest from datashuttle.configs import canonical_configs, canonical_folders from datashuttle.configs.canonical_tags import tags +from .. import test_utils +from ..base import BaseTest + class TestCreateFolders(BaseTest): diff --git a/tests/tests_integration/test_datatypes.py b/tests/tests_integration/test_datatypes.py index 840fc26e..cf5c6582 100644 --- a/tests/tests_integration/test_datatypes.py +++ b/tests/tests_integration/test_datatypes.py @@ -1,11 +1,12 @@ import os import pytest -import test_utils -from base import BaseTest from datashuttle.configs import canonical_configs +from .. 
import test_utils +from ..base import BaseTest + class TestDatatypes(BaseTest): """ diff --git a/tests/tests_integration/test_formatting.py b/tests/tests_integration/test_formatting.py index bf62c6a5..466c4eb9 100644 --- a/tests/tests_integration/test_formatting.py +++ b/tests/tests_integration/test_formatting.py @@ -1,9 +1,10 @@ import pytest -from base import BaseTest from datashuttle.utils import formatting from datashuttle.utils.custom_exceptions import NeuroBlueprintError +from ..base import BaseTest + class TestFormatting(BaseTest): @pytest.mark.parametrize("prefix", ["sub", "ses"]) diff --git a/tests/tests_integration/test_local_only_mode.py b/tests/tests_integration/test_local_only_mode.py index 6df59233..30bd2736 100644 --- a/tests/tests_integration/test_local_only_mode.py +++ b/tests/tests_integration/test_local_only_mode.py @@ -1,14 +1,15 @@ import shutil import pytest -import test_utils -from base import BaseTest from datashuttle import DataShuttle from datashuttle.utils.custom_exceptions import ( ConfigError, ) +from .. import test_utils +from ..base import BaseTest + TEST_PROJECT_NAME = "test_project" diff --git a/tests/tests_integration/test_logging.py b/tests/tests_integration/test_logging.py index 3cf2d733..82e510df 100644 --- a/tests/tests_integration/test_logging.py +++ b/tests/tests_integration/test_logging.py @@ -5,7 +5,6 @@ from pathlib import Path import pytest -import test_utils from datashuttle import DataShuttle from datashuttle.configs import canonical_configs @@ -16,6 +15,8 @@ NeuroBlueprintError, ) +from .. 
import test_utils + class TestLogging: diff --git a/tests/tests_integration/test_settings.py b/tests/tests_integration/test_settings.py index 1f19bede..0a655bc9 100644 --- a/tests/tests_integration/test_settings.py +++ b/tests/tests_integration/test_settings.py @@ -2,13 +2,14 @@ import shutil import pytest -from base import BaseTest from datashuttle import DataShuttle from datashuttle.configs import canonical_configs from datashuttle.utils import validation from datashuttle.utils.custom_exceptions import NeuroBlueprintError +from ..base import BaseTest + class TestPersistentSettings(BaseTest): diff --git a/tests/tests_integration/test_ssh_file_transfer.py b/tests/tests_integration/test_ssh_file_transfer.py deleted file mode 100644 index 393de076..00000000 --- a/tests/tests_integration/test_ssh_file_transfer.py +++ /dev/null @@ -1,311 +0,0 @@ -""" """ - -import copy -import glob -import shutil -import time -from pathlib import Path - -import pandas as pd -import pytest -import ssh_test_utils -import test_utils -from file_conflicts_pathtable import get_pathtable -from pytest import ssh_config - - -class TestFileTransfer: - @pytest.fixture( - scope="class", - params=[ # Set running SSH or local filesystem (see docstring). - # False, - pytest.param( - True, - marks=pytest.mark.skipif( - ssh_config.TEST_SSH is False, - reason="TEST_SSH is set to False.", - ), - ), - ], - ) - def pathtable_and_project(self, request, tmpdir_factory): - """ - Create a project for SSH testing. Setup - the project as normal, and switch configs - to use SSH connection. - - Although SSH is used for transfer, for SSH tests, - checking the created filepaths is always - done through the local filesystem for speed - and convenience. As such, the drive that is - SSH to must also be mounted and the path - supplied to the location SSH'd to. - - For speed, create the project once, - and all files to transfer. Then in the - test function, the folder are transferred. 
- Partial cleanup is done in the test function - i.e. deleting the central_path to which the - items have been transferred. This is achieved - by using "class" scope. - - NOTES - ----- - - Pytest params - The `params` key sets the - `params` attribute on the pytest `request` fixture. - This attribute is used to set the `testing_ssh` variable - to `True` or `False`. In the first run, this is set to - `False`, meaning local filesystem tests are run. In the - second run, this is set with a pytest parameter that is - `True` (i.e. SSH tests are run) but is skipped if `TEST_SSH` - in `ssh_config` (set in conftest.py` is `False`. - - - For convenience, files are transferred - with SSH and then checked through the local filesystem - mount. This is significantly easier than checking - everything through SFTP. However, on Windows the - mounted filesystem is quite slow to update, taking - a few seconds after SSH transfer. This makes the - tests run very slowly. We can get rid - of this limitation on linux. 
- """ - testing_ssh = request.param - tmp_path = tmpdir_factory.mktemp("test") - - if testing_ssh: - base_path = ssh_config.FILESYSTEM_PATH - central_path = ssh_config.SERVER_PATH - else: - base_path = tmp_path / "test with space" - central_path = base_path - test_project_name = "test_file_conflicts" - - project = test_utils.setup_project_fixture( - base_path, test_project_name - ) - - if testing_ssh: - ssh_test_utils.setup_project_for_ssh( - project, - test_utils.make_test_path( - central_path, "central", test_project_name - ), - ssh_config.CENTRAL_HOST_ID, - ssh_config.USERNAME, - ) - - # Initialise the SSH connection - ssh_test_utils.setup_hostkeys(project) - shutil.copy(ssh_config.SSH_KEY_PATH, project.cfg.file_path.parent) - - pathtable = get_pathtable(project.cfg["local_path"]) - test_utils.create_all_pathtable_files(pathtable) - project.testing_ssh = testing_ssh - - yield [pathtable, project] - - test_utils.teardown_project(project) - - if testing_ssh: - for result in glob.glob(ssh_config.FILESYSTEM_PATH): - shutil.rmtree(result) - - # ------------------------------------------------------------------------- - # Utils - # ------------------------------------------------------------------------- - - def central_from_local(self, path_): - return Path(str(copy.copy(path_)).replace("local", "central")) - - # ------------------------------------------------------------------------- - # Test File Transfer - All Options - # ------------------------------------------------------------------------- - - @pytest.mark.parametrize( - "sub_names", - [ - ["all"], - ["all_sub"], - ["all_non_sub"], - ["sub-001"], - ["sub-003_date-20231901"], - ["sub-002", "all_non_sub"], - ], - ) - @pytest.mark.parametrize( - "ses_names", - [ - ["all"], - ["all_non_ses"], - ["all_ses"], - ["ses-001"], - ["ses-002_random-key"], - ["all_non_ses", "ses-001"], - ], - ) - @pytest.mark.parametrize( - "datatype", - [ - ["all"], - ["all_non_datatype"], - ["all_datatype"], - ["behav"], - ["ephys"], - 
["anat"], - ["funcimg"], - ["anat", "behav", "all_non_datatype"], - ], - ) - @pytest.mark.parametrize("upload_or_download", ["upload", "download"]) - def test_all_data_transfer_options( - self, - pathtable_and_project, - sub_names, - ses_names, - datatype, - upload_or_download, - ): - """ - Parse the arguments to filter the pathtable, getting - the files expected to be transferred passed on the arguments - Note files in sub/ses/datatype folders must be handled - separately to those in non-sub, non-ses, non-datatype folders - - see test_utils.swap_local_and_central_paths() for the logic - on setting up and swapping local / central paths for - upload / download tests. - """ - pathtable, project = pathtable_and_project - - transfer_function = test_utils.handle_upload_or_download( - project, - upload_or_download, - transfer_method="custom", - swap_last_folder_only=project.testing_ssh, - )[0] - - transfer_function( - "rawdata", sub_names, ses_names, datatype, init_log=False - ) - - if upload_or_download == "download": - test_utils.swap_local_and_central_paths( - project, swap_last_folder_only=project.testing_ssh - ) - - sub_names = self.parse_arguments(pathtable, sub_names, "sub") - ses_names = self.parse_arguments(pathtable, ses_names, "ses") - datatype = self.parse_arguments(pathtable, datatype, "datatype") - - # Filter pathtable to get files that were expected - # to be transferred - ( - sub_ses_dtype_arguments, - extra_arguments, - ) = self.make_pathtable_search_filter(sub_names, ses_names, datatype) - - datatype_folders = self.query_table(pathtable, sub_ses_dtype_arguments) - extra_folders = self.query_table(pathtable, extra_arguments) - - expected_paths = pd.concat([datatype_folders, extra_folders]) - expected_paths = expected_paths.drop_duplicates(subset="path") - - central_base_paths = expected_paths.base_folder.map( - lambda x: str(x).replace("local", "central") - ) - expected_transferred_paths = central_base_paths / expected_paths.path - - # When transferring 
with SSH, there is a delay before - # filesystem catches up - if project.testing_ssh: - time.sleep(0.5) - - # Check what paths were actually moved - # (through the local filesystem), and test - path_to_search = ( - self.central_from_local(project.cfg["local_path"]) / "rawdata" - ) - all_transferred = path_to_search.glob("**/*") - paths_to_transferred_files = list( - filter(Path.is_file, all_transferred) - ) - - assert sorted(paths_to_transferred_files) == sorted( - expected_transferred_paths - ) - - # Teardown here, because we have session scope. - try: - shutil.rmtree(self.central_from_local(project.cfg["local_path"])) - except FileNotFoundError: - pass - - # --------------------------------------------------------------------------------------------------------------- - # Utils - # --------------------------------------------------------------------------------------------------------------- - - def query_table(self, pathtable, arguments): - """ - Search the table for arguments, return empty - if arguments empty - """ - if any(arguments): - folders = pathtable.query(" | ".join(arguments)) - else: - folders = pd.DataFrame() - return folders - - def parse_arguments(self, pathtable, list_of_names, field): - """ - Replicate datashuttle name formatting by parsing - "all" arguments and turning them into a list of all names, - (subject or session), taken from the pathtable. - """ - if list_of_names in [["all"], [f"all_{field}"]]: - entries = pathtable.query(f"parent_{field} != False")[ - f"parent_{field}" - ] - entries = list(set(entries)) - if list_of_names == ["all"]: - entries += ( - [f"all_non_{field}"] - if field != "datatype" - else ["all_non_datatype"] - ) - list_of_names = entries - return list_of_names - - def make_pathtable_search_filter(self, sub_names, ses_names, datatype): - """ - Create a string of arguments to pass to pd.query() that will - create the table of only transferred sub, ses and datatype. 
- - Two arguments must be created, one of all sub / ses / datatypes - and the other of all non sub/ non ses / non datatype - folders. These must be handled separately as they are - mutually exclusive. - """ - sub_ses_dtype_arguments = [] - extra_arguments = [] - - for sub in sub_names: - if sub == "all_non_sub": - extra_arguments += ["is_non_sub == True"] - else: - for ses in ses_names: - if ses == "all_non_ses": - extra_arguments += [ - f"(parent_sub == '{sub}' & is_non_ses == True)" - ] - else: - for dtype in datatype: - if dtype == "all_non_datatype": - extra_arguments += [ - f"(parent_sub == '{sub}' & parent_ses == '{ses}' & is_ses_level_non_datatype == True)" - ] - else: - sub_ses_dtype_arguments += [ - f"(parent_sub == '{sub}' & parent_ses == '{ses}' & (parent_datatype == '{dtype}' | parent_datatype == '{dtype}'))" - ] - - return sub_ses_dtype_arguments, extra_arguments diff --git a/tests/tests_integration/test_validation.py b/tests/tests_integration/test_validation.py index 4ef49688..6b294d5c 100644 --- a/tests/tests_integration/test_validation.py +++ b/tests/tests_integration/test_validation.py @@ -2,12 +2,13 @@ import shutil import pytest -from base import BaseTest from datashuttle import quick_validate_project from datashuttle.utils import formatting, validation from datashuttle.utils.custom_exceptions import NeuroBlueprintError +from ..base import BaseTest + # ----------------------------------------------------------------------------- # Inconsistent sub or ses value lengths # ----------------------------------------------------------------------------- diff --git a/tests/tests_regression/__init__.py b/tests/tests_regression/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/tests_regression/test_backwards_compatibility.py b/tests/tests_regression/test_backwards_compatibility.py index 7e243992..9a351414 100644 --- a/tests/tests_regression/test_backwards_compatibility.py +++ 
b/tests/tests_regression/test_backwards_compatibility.py @@ -3,10 +3,11 @@ from pathlib import Path import pytest -import test_utils from datashuttle import DataShuttle +from .. import test_utils + TEST_PROJECT_NAME = "test_project" diff --git a/tests/tests_transfers/__init__.py b/tests/tests_transfers/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/tests_transfers/base_transfer.py b/tests/tests_transfers/base_transfer.py new file mode 100644 index 00000000..32dac6de --- /dev/null +++ b/tests/tests_transfers/base_transfer.py @@ -0,0 +1,166 @@ +""" """ + +import copy +from pathlib import Path + +import pandas as pd +import pytest + +from .. import test_utils +from ..base import BaseTest +from .file_conflicts_pathtable import get_pathtable + + +class BaseTransfer(BaseTest): + """ + Class holding fixtures and methods for testing the + custom transfers with keys (e.g. all_non_sub). + """ + + @pytest.fixture( + scope="class", + ) + def pathtable_and_project(self, tmpdir_factory): + """ + Create a new test project with a test project folder + and file structure (see `get_pathtable()` for definition). 
+ """ + tmp_path = tmpdir_factory.mktemp("test") + + base_path = tmp_path / "test with space" + test_project_name = "test_file_conflicts" + + project = test_utils.setup_project_fixture( + base_path, test_project_name + ) + + pathtable = get_pathtable(project.cfg["local_path"]) + + self.create_all_pathtable_files(pathtable) + + yield [pathtable, project] + + test_utils.teardown_project(project) + + def get_expected_transferred_paths( + self, pathtable, sub_names, ses_names, datatype + ): + """ + Process the expected files that are transferred using the logic in + `make_pathtable_search_filter()` to + """ + parsed_sub_names = self.parse_arguments(pathtable, sub_names, "sub") + parsed_ses_names = self.parse_arguments(pathtable, ses_names, "ses") + parsed_datatype = self.parse_arguments(pathtable, datatype, "datatype") + + # Filter pathtable to get files that were expected to be transferred + ( + sub_ses_dtype_arguments, + extra_arguments, + ) = self.make_pathtable_search_filter( + parsed_sub_names, parsed_ses_names, parsed_datatype + ) + + datatype_folders = self.query_table(pathtable, sub_ses_dtype_arguments) + extra_folders = self.query_table(pathtable, extra_arguments) + + expected_paths = pd.concat([datatype_folders, extra_folders]) + expected_paths = expected_paths.drop_duplicates(subset="path") + + expected_paths = self.remove_path_before_rawdata(expected_paths.path) + + return expected_paths + + def make_pathtable_search_filter(self, sub_names, ses_names, datatype): + """ + Create a string of arguments to pass to pd.query() that will + create the table of only transferred sub, ses and datatype. + + Two arguments must be created, one of all sub / ses / datatypes + and the other of all non sub/ non ses / non datatype + folders. These must be handled separately as they are + mutually exclusive. 
+ """ + sub_ses_dtype_arguments = [] + extra_arguments = [] + + for sub in sub_names: + if sub == "all_non_sub": + extra_arguments += ["is_non_sub == True"] + else: + for ses in ses_names: + if ses == "all_non_ses": + extra_arguments += [ + f"(parent_sub == '{sub}' & is_non_ses == True)" + ] + else: + for dtype in datatype: + if dtype == "all_non_datatype": + extra_arguments += [ + f"(parent_sub == '{sub}' & parent_ses == '{ses}' " + f"& is_ses_level_non_datatype == True)" + ] + else: + sub_ses_dtype_arguments += [ + f"(parent_sub == '{sub}' & parent_ses == '{ses}' " + f"& (parent_datatype == '{dtype}' " + f"| parent_datatype == '{dtype}'))" + ] + + return sub_ses_dtype_arguments, extra_arguments + + def remove_path_before_rawdata(self, list_of_paths): + """ + Remove the path to project files before the "rawdata" so + they can be compared no matter where the project was stored + (e.g. on a central server vs. local filesystem). + """ + cut_paths = [] + for path_ in list_of_paths: + parts = Path(path_).parts + cut_paths.append(Path(*parts[parts.index("rawdata") :])) + return cut_paths + + def query_table(self, pathtable, arguments): + """ + Search the table for arguments, return empty + if arguments empty + """ + if any(arguments): + folders = pathtable.query(" | ".join(arguments)) + else: + folders = pd.DataFrame() + return folders + + def parse_arguments(self, pathtable, list_of_names, field): + """ + Replicate datashuttle name formatting by parsing + "all" arguments and turning them into a list of all names, + (subject or session), taken from the pathtable. 
+ """ + if list_of_names in [["all"], [f"all_{field}"]]: + entries = pathtable.query(f"parent_{field} != False")[ + f"parent_{field}" + ] + entries = list(set(entries)) + if list_of_names == ["all"]: + entries += ( + [f"all_non_{field}"] + if field != "datatype" + else ["all_non_datatype"] + ) + list_of_names = entries + return list_of_names + + def create_all_pathtable_files(self, pathtable): + """ + Create the entire test project in the defined + location (usually project's `local_path`). + """ + for i in range(pathtable.shape[0]): + filepath = pathtable["base_folder"][i] / pathtable["path"][i] + filepath.parents[0].mkdir(parents=True, exist_ok=True) + test_utils.write_file(filepath, contents="test_entry") + + def central_from_local(self, path_): + return Path(str(copy.copy(path_)).replace("local", "central")) diff --git a/tests/file_conflicts_pathtable.py b/tests/tests_transfers/file_conflicts_pathtable.py similarity index 100% rename from tests/file_conflicts_pathtable.py rename to tests/tests_transfers/file_conflicts_pathtable.py diff --git a/tests/tests_transfers/local_filesystem/__init__.py b/tests/tests_transfers/local_filesystem/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/tests_integration/test_filesystem_transfer.py b/tests/tests_transfers/local_filesystem/test_transfer.py similarity index 99% rename from tests/tests_integration/test_filesystem_transfer.py rename to tests/tests_transfers/local_filesystem/test_transfer.py index 8cbff488..069ce346 100644 --- a/tests/tests_integration/test_filesystem_transfer.py +++ b/tests/tests_transfers/local_filesystem/test_transfer.py @@ -4,13 +4,14 @@ from pathlib import Path import pytest -import test_utils -from base import BaseTest from datashuttle.configs import canonical_folders from datashuttle.configs.canonical_configs import get_broad_datatypes from datashuttle.configs.canonical_tags import tags +from ... 
import test_utils +from ...base import BaseTest + class TestFileTransfer(BaseTest): diff --git a/tests/tests_integration/test_transfer_checks.py b/tests/tests_transfers/local_filesystem/test_transfer_checks.py similarity index 98% rename from tests/tests_integration/test_transfer_checks.py rename to tests/tests_transfers/local_filesystem/test_transfer_checks.py index 342ed640..2f5c909d 100644 --- a/tests/tests_integration/test_transfer_checks.py +++ b/tests/tests_transfers/local_filesystem/test_transfer_checks.py @@ -3,11 +3,12 @@ from pathlib import Path import pytest -import test_utils -from base import BaseTest from datashuttle.utils.rclone import get_local_and_central_file_differences +from ... import test_utils +from ...base import BaseTest + class TestTransferChecks(BaseTest): @pytest.mark.parametrize( diff --git a/tests/tests_transfers/local_filesystem/test_transfer_special_arguments.py b/tests/tests_transfers/local_filesystem/test_transfer_special_arguments.py new file mode 100644 index 00000000..03e9295f --- /dev/null +++ b/tests/tests_transfers/local_filesystem/test_transfer_special_arguments.py @@ -0,0 +1,112 @@ +""" """ + +import shutil +from pathlib import Path + +import pytest + +from ... 
import test_utils +from ..base_transfer import BaseTransfer + +PARAM_SUBS = [ + ["all"], + ["all_sub"], + ["all_non_sub"], + ["sub-001"], + ["sub-003_date-20231201"], + ["sub-002", "all_non_sub"], +] +PARAM_SES = [ + ["all"], + ["all_non_ses"], + ["all_ses"], + ["ses-001"], + ["ses-002_random-key"], + ["all_non_ses", "ses-001"], +] +PARAM_DATATYPE = [ + ["all"], + ["all_non_datatype"], + ["all_datatype"], + ["behav"], + ["ephys"], + ["anat"], + ["funcimg"], + ["anat", "behav", "all_non_datatype"], +] + + +class TestFileTransfer(BaseTransfer): + + # ---------------------------------------------------------------------------------- + # Test File Transfer - All Options + # ---------------------------------------------------------------------------------- + + @pytest.mark.parametrize("sub_names", PARAM_SUBS) + @pytest.mark.parametrize("ses_names", PARAM_SES) + @pytest.mark.parametrize("datatype", PARAM_DATATYPE) + @pytest.mark.parametrize("upload_or_download", ["upload", "download"]) + def test_combinations_filesystem_transfer( + self, + pathtable_and_project, + sub_names, + ses_names, + datatype, + upload_or_download, + ): + """ + Test many combinations of possible file transfer commands. The + entire test project is created in the original `local_path` + and subset of it is uploaded and tested against. To test + upload vs. download, the `local_path` and `central_path` + locations are swapped. + """ + pathtable, project = pathtable_and_project + + # Transfer the data, swapping the paths to move a subset of + # files from the already set up directory to a new directory + # using upload or download. 
+ transfer_function = test_utils.handle_upload_or_download( + project, + upload_or_download, + transfer_method="custom", + swap_last_folder_only=False, + )[0] + + transfer_function( + "rawdata", sub_names, ses_names, datatype, init_log=False + ) + + if upload_or_download == "download": + test_utils.swap_local_and_central_paths( + project, swap_last_folder_only=False + ) + + expected_transferred_paths = self.get_expected_transferred_paths( + pathtable, sub_names, ses_names, datatype + ) + + # Check what paths were actually moved + # (through the local filesystem), and test + path_to_search = ( + self.central_from_local(project.cfg["local_path"]) / "rawdata" + ) + all_transferred = path_to_search.glob("**/*") + + paths_to_transferred_files = list( + filter(Path.is_file, all_transferred) + ) + + paths_to_transferred_files = self.remove_path_before_rawdata( + paths_to_transferred_files + ) + + assert sorted(paths_to_transferred_files) == sorted( + expected_transferred_paths + ) + + # Teardown here, because we have session scope. + try: + shutil.rmtree(self.central_from_local(project.cfg["local_path"])) + except FileNotFoundError: + pass diff --git a/tests/tests_transfers/ssh/__init__.py b/tests/tests_transfers/ssh/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/tests_transfers/ssh/base_ssh.py b/tests/tests_transfers/ssh/base_ssh.py new file mode 100644 index 00000000..0a8c52b9 --- /dev/null +++ b/tests/tests_transfers/ssh/base_ssh.py @@ -0,0 +1,76 @@ +""" """ + +import os +import platform +import subprocess +from pathlib import Path + +import pytest + +from ..base_transfer import BaseTransfer +from . import ssh_test_utils + +# Choose port 3306 for running on GH actions +# suggested in https://github.com/orgs/community/discussions/25550 +PORT = 3306 +os.environ["DS_SSH_PORT"] = str(PORT) + + +class BaseSSHTransfer(BaseTransfer): + """ + Class holding fixtures and methods for testing the + custom transfers with keys (e.g. all_non_sub). 
+ """ + + @pytest.fixture( + scope="class", + ) + def setup_ssh_container(self): + """ + Set up the Dockerfile container for SSH tests and + delete it on teardown. + """ + container_name = "datashuttle_ssh_tests" + + assert ssh_test_utils.docker_is_running(), ( + "docker is not running, " + "this should be checked at the top of test script" + ) + + image_path = Path(__file__).parent / "ssh_test_images" + os.chdir(image_path) + + if platform.system() != "Windows": + build_command = "sudo docker build -t ssh_server ." + run_command = ( + f"sudo docker run -d -p {PORT}:22 " + f"--name {container_name} ssh_server" + ) + else: + build_command = "docker build -t ssh_server ." + run_command = f"docker run -d -p {PORT}:22 --name {container_name} ssh_server" + + build_output = subprocess.run( + build_command, + shell=True, + capture_output=True, + ) + assert build_output.returncode == 0, ( + f"docker build failed with: STDOUT-{build_output.stdout} " + f"STDERR-{build_output.stderr}" + ) + + run_output = subprocess.run( + run_command, + shell=True, + capture_output=True, + ) + + assert run_output.returncode == 0, ( + f"docker run failed with: STDOUT-{run_output.stdout} " + f"STDERR-{run_output.stderr}" + ) + + yield + + subprocess.run(f"docker rm -f {container_name}", shell=True) diff --git a/tests/tests_transfers/ssh/ssh_test_images/Dockerfile b/tests/tests_transfers/ssh/ssh_test_images/Dockerfile new file mode 100644 index 00000000..474c8ecb --- /dev/null +++ b/tests/tests_transfers/ssh/ssh_test_images/Dockerfile @@ -0,0 +1,25 @@ +# Use a base image with the desired OS (e.g., Ubuntu, Debian, etc.) 
+FROM ubuntu:latest
+
+# Install SSH server
+RUN apt-get update && \
+    apt-get upgrade -y
+RUN apt-get install openssh-server -y supervisor
+RUN apt-get install -y nano
+
+# Create an SSH user
+RUN useradd -rm -d /home/sshuser -s /bin/bash -g root -G sudo sshuser
+
+# Set the SSH user's password (replace "password" with your desired password)
+RUN echo "sshuser:password" | chpasswd
+
+# Allow SSH access
+RUN mkdir /var/run/sshd
+
+RUN /usr/bin/ssh-keygen -A
+
+# Expose the SSH port
+EXPOSE 22
+
+# Start SSH server on container startup
+CMD ["/usr/sbin/sshd", "-D"]
diff --git a/tests/tests_transfers/ssh/ssh_test_utils.py b/tests/tests_transfers/ssh/ssh_test_utils.py
new file mode 100644
index 00000000..aa605019
--- /dev/null
+++ b/tests/tests_transfers/ssh/ssh_test_utils.py
@@ -0,0 +1,151 @@
+import builtins
+import copy
+import stat
+import subprocess
+import sys
+import warnings
+
+import paramiko
+
+from datashuttle.utils import rclone, ssh
+
+
+def setup_project_for_ssh(
+    project,
+):
+    """
+    Set up the project configs to use
+    SSH connection to central. The settings
+    set up a connection to the Dockerfile image
+    found in /ssh_test_images.
+    """
+    project.update_config_file(
+        connection_method="ssh",
+        central_path=f"/home/sshuser/datashuttle/{project.project_name}",
+        central_host_id="localhost",
+        central_host_username="sshuser",
+    )
+    rclone.setup_rclone_config_for_ssh(
+        project.cfg,
+        project.cfg.get_rclone_config_name("ssh"),
+        project.cfg.ssh_key_path,
+    )
+
+
+def setup_ssh_connection(project, setup_ssh_key_pair=True):
+    """
+    Convenience function to verify the server hostkey and ssh
+    key pairs to the Dockerfile image for ssh tests.
+
+    This requires monkeypatching a number of functions involved
+    in the SSH setup process. `input()` is patched to always
+    return the required hostkey confirmation "y". `getpass()` is
+    patched to always return the password for the container in which
+    SSH tests are run. 
`isatty()` is patched because when running this + for some reason it appears to be in a TTY - this might be a + container thing. + """ + # Monkeypatch + orig_builtin = copy.deepcopy(builtins.input) + builtins.input = lambda _: "y" # type: ignore + + orig_getpass = copy.deepcopy(ssh.getpass.getpass) + ssh.getpass.getpass = lambda _: "password" # type: ignore + + orig_isatty = copy.deepcopy(sys.stdin.isatty) + sys.stdin.isatty = lambda: True + + # Run setup + verified = ssh.verify_ssh_central_host( + project.cfg["central_host_id"], project.cfg.hostkeys_path, log=True + ) + + if setup_ssh_key_pair: + ssh.setup_ssh_key(project.cfg, log=False) + + # Restore functions + builtins.input = orig_builtin + ssh.getpass.getpass = orig_getpass + sys.stdin.isatty = orig_isatty + + return verified + + +def recursive_search_central(project): + """ + A convenience function to recursively search a + project for files through SSH, used during testing + across an SSH connection to collected names of + files that were transferred. + """ + with paramiko.SSHClient() as client: + ssh.connect_client_core(client, project.cfg) + + sftp = client.open_sftp() + + all_filenames = [] + + sftp_recursive_file_search( + sftp, + (project.cfg["central_path"] / "rawdata").as_posix(), + all_filenames, + ) + return all_filenames + + +def sftp_recursive_file_search(sftp, path_, all_filenames): + """ + Append all filenames found within a folder, + when searching over a sftp connection. + """ + try: + sftp.stat(path_) + except FileNotFoundError: + return + + for file_or_folder in sftp.listdir_attr(path_): + if stat.S_ISDIR(file_or_folder.st_mode): + sftp_recursive_file_search( + sftp, + path_ + "/" + file_or_folder.filename, + all_filenames, + ) + else: + all_filenames.append(path_ + "/" + file_or_folder.filename) + + +def get_test_ssh(): + """ + Return bool indicating whether Docker is installed and running, + which is required for ssh tests. 
+ """ + docker_installed = docker_is_running() + if not docker_installed: + warnings.warn( + "SSH tests are not run as docker either not installed or running." + ) + return docker_installed + + +def docker_is_running(): + if not is_docker_installed(): + return False + + is_running = check_sys_command_returns_0("docker stats --no-stream") + return is_running + + +def is_docker_installed(): + return check_sys_command_returns_0("docker -v") + + +def check_sys_command_returns_0(command): + return ( + subprocess.run( + command, + shell=True, + stdout=subprocess.DEVNULL, + stderr=subprocess.DEVNULL, + ).returncode + == 0 + ) diff --git a/tests/tests_integration/test_ssh_setup.py b/tests/tests_transfers/ssh/test_ssh_setup.py similarity index 56% rename from tests/tests_integration/test_ssh_setup.py rename to tests/tests_transfers/ssh/test_ssh_setup.py index 752985a6..3d823969 100644 --- a/tests/tests_integration/test_ssh_setup.py +++ b/tests/tests_transfers/ssh/test_ssh_setup.py @@ -1,35 +1,38 @@ -""" -SSH configs are set in conftest.py . The password -should be stored in a file called test_ssh_password.txt located -in the same folder as test_ssh.py -""" +import builtins +import copy +import platform import pytest -import ssh_test_utils -import test_utils -from pytest import ssh_config from datashuttle.utils import ssh +from ... import test_utils +from . import ssh_test_utils +from .base_ssh import BaseSSHTransfer + +TEST_SSH = ssh_test_utils.get_test_ssh() + + +@pytest.mark.skipif( + platform.system == "Darwin", reason="Docker set up is not robust on macOS." 
+) +@pytest.mark.skipif(not TEST_SSH, reason="TEST_SSH is false") +class TestSSH(BaseSSHTransfer): -@pytest.mark.skipif(ssh_config.TEST_SSH is False, reason="TEST_SSH is false") -class TestSSH: @pytest.fixture(scope="function") - def project(test, tmp_path): + def project(test, tmp_path, setup_ssh_container): """ - Make a project as per usual, but now add - in test ssh configurations + Set up a project with configs for SSH into + the test Dockerfile image. """ tmp_path = tmp_path / "test with space" test_project_name = "test_ssh" + project = test_utils.setup_project_fixture(tmp_path, test_project_name) ssh_test_utils.setup_project_for_ssh( project, - ssh_config.FILESYSTEM_PATH, - ssh_config.CENTRAL_HOST_ID, - ssh_config.USERNAME, ) yield project @@ -44,18 +47,14 @@ def test_verify_ssh_central_host_do_not_accept( self, capsys, project, input_ ): """ - Use the main function to test this. Test the sub-function - when accepting, because this main function will also - call setup ssh key pairs which we don't want to do yet - - This should only accept for "y" so try some random strings - including "n" and check they all do not make the connection. + Test that host not accepted if input is not "y". """ - orig_builtin = ssh_test_utils.setup_mock_input(input_) + orig_builtin = copy.deepcopy(builtins.input) + builtins.input = lambda _: input_ # type: ignore project.setup_ssh_connection() - ssh_test_utils.restore_mock_input(orig_builtin) + builtins.input = orig_builtin captured = capsys.readouterr() @@ -67,27 +66,27 @@ def test_verify_ssh_central_host_accept(self, capsys, project): and check hostkey is successfully accepted and written to configs. 
""" test_utils.clear_capsys(capsys) - orig_builtin = ssh_test_utils.setup_mock_input(input_="y") - verified = ssh.verify_ssh_central_host( - project.cfg["central_host_id"], project.cfg.hostkeys_path, log=True + verified = ssh_test_utils.setup_ssh_connection( + project, setup_ssh_key_pair=False ) - ssh_test_utils.restore_mock_input(orig_builtin) - assert verified captured = capsys.readouterr() + assert captured.out == "Host accepted.\n" with open(project.cfg.hostkeys_path, "r") as file: hostkey = file.readlines()[0] - assert f"{project.cfg['central_host_id']} ssh-ed25519 " in hostkey + assert ( + f"[{project.cfg['central_host_id']}]:3306 ssh-ed25519 " in hostkey + ) def test_generate_and_write_ssh_key(self, project): """ - Check ssh key for passwordless connection is written - to file + Check ssh key for passwordless connection + is written to file. """ path_to_save = project.cfg["local_path"] / "test" ssh.generate_and_write_ssh_key(path_to_save) diff --git a/tests/tests_transfers/ssh/test_ssh_transfer.py b/tests/tests_transfers/ssh/test_ssh_transfer.py new file mode 100644 index 00000000..60ce29c2 --- /dev/null +++ b/tests/tests_transfers/ssh/test_ssh_transfer.py @@ -0,0 +1,149 @@ +import platform +import shutil + +import paramiko +import pytest + +from datashuttle.utils import ssh + +from . import ssh_test_utils +from .base_ssh import BaseSSHTransfer + +TEST_SSH = ssh_test_utils.get_test_ssh() + + +@pytest.mark.skipif( + platform.system == "Darwin", reason="Docker set up is not robust on macOS." +) +@pytest.mark.skipif(not TEST_SSH, reason="TEST_SSH is false") +class TestSSHTransfer(BaseSSHTransfer): + + @pytest.fixture( + scope="class", + ) + def ssh_setup(self, pathtable_and_project, setup_ssh_container): + """ + After initial project setup (in `pathtable_and_project`) + setup a container and the project's SSH connection to the container. + Then upload the test project to the `central_path`. 
+ """ + pathtable, project = pathtable_and_project + + ssh_test_utils.setup_project_for_ssh( + project, + ) + ssh_test_utils.setup_ssh_connection(project) + + project.upload_rawdata() + + return [pathtable, project] + + # ----------------------------------------------------------------- + # Test Setup SSH Connection + # ----------------------------------------------------------------- + + @pytest.mark.parametrize( + "sub_names", [["all"], ["all_non_sub", "sub-002"]] + ) + @pytest.mark.parametrize( + "ses_names", [["all"], ["ses-002_random-key"], ["all_non_ses"]] + ) + @pytest.mark.parametrize( + "datatype", [["all"], ["anat", "all_non_datatype"]] + ) + def test_combinations_ssh_transfer( + self, + ssh_setup, + sub_names, + ses_names, + datatype, + ): + """ + Test a subset of argument combinations while testing over SSH connection + to a container. This is very slow, due to the rclone ssh transfer (which + is performed twice in this test, once for upload, once for download), around + 8 seconds per parameterization. + + In test setup, the entire project is created in the `local_path` and + is uploaded to `central_path`. So we only need to set up once per test, + upload and download is to temporary folders and these temporary folders + are cleaned at the end of each parameterization. + """ + pathtable, project = ssh_setup + + # Upload data from the setup local project to a temporary + # central directory. 
+ true_central_path = project.cfg["central_path"] + tmp_central_path = ( + project.cfg["central_path"] / "tmp" / project.project_name + ) + self.remake_logging_path(project) + + project.update_config_file(central_path=tmp_central_path) + + project.upload_custom( + "rawdata", sub_names, ses_names, datatype, init_log=False + ) + + expected_transferred_paths = self.get_expected_transferred_paths( + pathtable, sub_names, ses_names, datatype + ) + + # Search the paths that were transferred and tidy them up, + # then check against the paths that were expected to be transferred. + transferred_files = ssh_test_utils.recursive_search_central(project) + paths_to_transferred_files = self.remove_path_before_rawdata( + transferred_files + ) + + assert sorted(paths_to_transferred_files) == sorted( + expected_transferred_paths + ) + + # Now, move data from the central path where the project is + # setup, to a temp local folder to test download. + true_local_path = project.cfg["local_path"] + tmp_local_path = ( + project.cfg["local_path"] / "tmp" / project.project_name + ) + tmp_local_path.mkdir(exist_ok=True, parents=True) + + project.update_config_file(local_path=tmp_local_path) + project.update_config_file(central_path=true_central_path) + + project.download_custom( + "rawdata", sub_names, ses_names, datatype, init_log=False + ) + + # Find the transferred paths, tidy them up + # and check expected paths were transferred. + all_transferred = list((tmp_local_path / "rawdata").glob("**/*")) + all_transferred = [ + path_ for path_ in all_transferred if path_.is_file() + ] + + paths_to_transferred_files = self.remove_path_before_rawdata( + all_transferred + ) + + assert sorted(paths_to_transferred_files) == sorted( + expected_transferred_paths + ) + + # Clean up, removing the temp directories and + # resetting the project paths. 
+ with paramiko.SSHClient() as client: + ssh.connect_client_core(client, project.cfg) + client.exec_command(f"rm -rf {(tmp_central_path).as_posix()}") + + shutil.rmtree(tmp_local_path) + + self.remake_logging_path(project) + project.update_config_file(local_path=true_local_path) + + def remake_logging_path(self, project): + """ + Need to do this to compensate for switching + local_path location in the test environment. + """ + project.get_logging_path().mkdir(parents=True, exist_ok=True) diff --git a/tests/tests_tui/__init__.py b/tests/tests_tui/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/tests_tui/test_local_only_project.py b/tests/tests_tui/test_local_only_project.py index 0599ea14..c3e87385 100644 --- a/tests/tests_tui/test_local_only_project.py +++ b/tests/tests_tui/test_local_only_project.py @@ -1,8 +1,9 @@ import pytest -from tui_base import TuiBase from datashuttle.tui.app import TuiApp +from .tui_base import TuiBase + class TestTuiLocalOnlyProject(TuiBase): diff --git a/tests/tests_tui/test_tui_configs.py b/tests/tests_tui/test_tui_configs.py index e8f5d076..fc332078 100644 --- a/tests/tests_tui/test_tui_configs.py +++ b/tests/tests_tui/test_tui_configs.py @@ -3,8 +3,6 @@ from time import monotonic import pytest -import test_utils -from tui_base import TuiBase from datashuttle.configs import load_configs from datashuttle.tui.app import TuiApp @@ -13,6 +11,9 @@ ) from datashuttle.tui.screens.project_manager import ProjectManagerScreen +from .. 
import test_utils +from .tui_base import TuiBase + class TestTuiConfigs(TuiBase): diff --git a/tests/tests_tui/test_tui_create_folders.py b/tests/tests_tui/test_tui_create_folders.py index 85a06cd5..9e4df32c 100644 --- a/tests/tests_tui/test_tui_create_folders.py +++ b/tests/tests_tui/test_tui_create_folders.py @@ -1,8 +1,6 @@ import re import pytest -import test_utils -from tui_base import TuiBase from datashuttle.configs import canonical_configs from datashuttle.tui.app import TuiApp @@ -11,6 +9,9 @@ ) from datashuttle.tui.screens.project_manager import ProjectManagerScreen +from .. import test_utils +from .tui_base import TuiBase + class TestTuiCreateFolders(TuiBase): diff --git a/tests/tests_tui/test_tui_datatypes.py b/tests/tests_tui/test_tui_datatypes.py index 0e9ec8a6..bc72bea9 100644 --- a/tests/tests_tui/test_tui_datatypes.py +++ b/tests/tests_tui/test_tui_datatypes.py @@ -1,10 +1,11 @@ import pytest -import test_utils -from tui_base import TuiBase from datashuttle.configs import canonical_configs from datashuttle.tui.app import TuiApp +from .. 
import test_utils +from .tui_base import TuiBase + class TestDatatypesTUI(TuiBase): """ diff --git a/tests/tests_tui/test_tui_directorytree.py b/tests/tests_tui/test_tui_directorytree.py index 1a6e03d8..81493810 100644 --- a/tests/tests_tui/test_tui_directorytree.py +++ b/tests/tests_tui/test_tui_directorytree.py @@ -2,10 +2,11 @@ import pyperclip import pytest -from tui_base import TuiBase from datashuttle.tui.app import TuiApp +from .tui_base import TuiBase + try: pyperclip.paste() HAS_GUI = True diff --git a/tests/tests_tui/test_tui_get_help.py b/tests/tests_tui/test_tui_get_help.py index 55d876f3..18ae5354 100644 --- a/tests/tests_tui/test_tui_get_help.py +++ b/tests/tests_tui/test_tui_get_help.py @@ -1,8 +1,9 @@ import pytest -from tui_base import TuiBase from datashuttle.tui.app import TuiApp +from .tui_base import TuiBase + class TestTuiSettings(TuiBase): """ diff --git a/tests/tests_tui/test_tui_logging.py b/tests/tests_tui/test_tui_logging.py index cb4b8417..fe2cdad1 100644 --- a/tests/tests_tui/test_tui_logging.py +++ b/tests/tests_tui/test_tui_logging.py @@ -1,10 +1,11 @@ import pytest -from tui_base import TuiBase from datashuttle import DataShuttle from datashuttle.tui.app import TuiApp from datashuttle.tui.tabs.logging import RichLogScreen +from .tui_base import TuiBase + class TestTuiLogging(TuiBase): diff --git a/tests/tests_tui/test_tui_settings.py b/tests/tests_tui/test_tui_settings.py index 77a65b09..98d95c02 100644 --- a/tests/tests_tui/test_tui_settings.py +++ b/tests/tests_tui/test_tui_settings.py @@ -1,8 +1,9 @@ import pytest -from tui_base import TuiBase from datashuttle.tui.app import TuiApp +from .tui_base import TuiBase + class TestTuiSettings(TuiBase): """ diff --git a/tests/tests_tui/test_tui_transfer.py b/tests/tests_tui/test_tui_transfer.py index 3e8c389f..925f457b 100644 --- a/tests/tests_tui/test_tui_transfer.py +++ b/tests/tests_tui/test_tui_transfer.py @@ -1,10 +1,11 @@ import pytest -import test_utils -from tui_base import 
TuiBase from datashuttle.configs import canonical_configs from datashuttle.tui.app import TuiApp +from .. import test_utils +from .tui_base import TuiBase + class TestTuiTransfer(TuiBase): """ diff --git a/tests/tests_tui/test_tui_validate.py b/tests/tests_tui/test_tui_validate.py index 0972c975..fa07c1c3 100644 --- a/tests/tests_tui/test_tui_validate.py +++ b/tests/tests_tui/test_tui_validate.py @@ -1,10 +1,11 @@ import pytest import textual -from tui_base import TuiBase import datashuttle from datashuttle.tui.app import TuiApp +from .tui_base import TuiBase + class TestTuiValidate(TuiBase): diff --git a/tests/tests_tui/test_tui_widgets_and_defaults.py b/tests/tests_tui/test_tui_widgets_and_defaults.py index 73f06a22..5651cef8 100644 --- a/tests/tests_tui/test_tui_widgets_and_defaults.py +++ b/tests/tests_tui/test_tui_widgets_and_defaults.py @@ -2,7 +2,6 @@ from typing import Union import pytest -from tui_base import TuiBase from datashuttle import DataShuttle from datashuttle.configs import canonical_configs @@ -12,6 +11,8 @@ ) from datashuttle.tui.screens.new_project import NewProjectScreen +from .tui_base import TuiBase + class TestTuiWidgets(TuiBase): """ diff --git a/tests/tests_tui/tui_base.py b/tests/tests_tui/tui_base.py index 510b943b..6b5df9ba 100644 --- a/tests/tests_tui/tui_base.py +++ b/tests/tests_tui/tui_base.py @@ -1,11 +1,12 @@ import pytest_asyncio -import test_utils from textual.widgets._tabbed_content import ContentTab from datashuttle.configs import canonical_configs from datashuttle.tui.screens.project_manager import ProjectManagerScreen from datashuttle.tui.screens.project_selector import ProjectSelectorScreen +from .. import test_utils + class TuiBase: """ diff --git a/tests/tests_unit/__init__.py b/tests/tests_unit/__init__.py new file mode 100644 index 00000000..e69de29b