diff --git a/sagemaker_shim/models.py b/sagemaker_shim/models.py index 1d85ceb..17deea5 100644 --- a/sagemaker_shim/models.py +++ b/sagemaker_shim/models.py @@ -144,7 +144,7 @@ def _get_group_id(id_or_name: str) -> int | None: def clean_path(path: Path) -> None: for f in path.glob("*"): - if f.is_file(): + if f.is_symlink() or f.is_file(): f.chmod(0o700) f.unlink() elif f.is_dir(): @@ -493,6 +493,19 @@ def input_path(self) -> Path: logger.debug(f"{input_path=}") return input_path + @property + def linked_input_path(self) -> Path: + """Local path where the input files will be placed and linked to""" + linked_input_parent = Path( + os.environ.get( + "GRAND_CHALLENGE_COMPONENT_LINKED_INPUT_PARENT", + "/opt/ml/input/data/", + ) + ) + linked_input_path = linked_input_parent / f"{self.pk}-input" + logger.debug(f"{linked_input_path=}") + return linked_input_path + @property def output_path(self) -> Path: """Local path where the subprocess is expected to write its files""" @@ -590,7 +603,7 @@ async def _invoke(self) -> InferenceResult: logger.info(f"Invoking {self.pk=}") try: - self.clean_io() + self.reset_io() try: self.download_input() @@ -609,12 +622,43 @@ async def _invoke(self) -> InferenceResult: pk=self.pk, return_code=return_code, outputs=outputs ) finally: - self.clean_io() + self.reset_io() - def clean_io(self) -> None: - """Clean all contents of input and output folders""" + def reset_io(self) -> None: + """Resets the input and output directories""" clean_path(path=self.input_path) clean_path(path=self.output_path) + self.reset_linked_input() + + def reset_linked_input(self) -> None: + """Resets the symlink from the input to the linked directory""" + if ( + os.environ.get( + "GRAND_CHALLENGE_COMPONENT_USE_LINKED_INPUT", "True" + ).lower() + == "true" + ): + logger.info( + f"Setting up linked input from {self.input_path} " + f"to {self.linked_input_path}" + ) + + if self.input_path.exists(): + if self.input_path.is_symlink(): + self.input_path.unlink() + elif self.input_path.is_dir(): + self.input_path.rmdir() + + if self.linked_input_path.exists(): + self.linked_input_path.rmdir() + + self.linked_input_path.mkdir(parents=True) + self.linked_input_path.chmod(0o755) + + self.input_path.symlink_to( + self.linked_input_path, target_is_directory=True + ) + self.input_path.chmod(0o755) def download_input(self) -> None: """Download all the inputs to the input path""" diff --git a/tests/test_cli.py b/tests/test_cli.py index 09a069f..e11697f 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -101,6 +101,7 @@ def test_inference_from_task_list( encode_b64j(val=cmd), ) monkeypatch.setenv("GRAND_CHALLENGE_COMPONENT_SET_EXTRA_GROUPS", "False") + monkeypatch.setenv("GRAND_CHALLENGE_COMPONENT_USE_LINKED_INPUT", "False") runner = CliRunner() runner.invoke(cli, ["invoke", "-t", json.dumps(tasks)]) @@ -144,6 +145,7 @@ def test_inference_from_s3_uri(minio, monkeypatch, cmd, expected_return_code): encode_b64j(val=cmd), ) monkeypatch.setenv("GRAND_CHALLENGE_COMPONENT_SET_EXTRA_GROUPS", "False") + monkeypatch.setenv("GRAND_CHALLENGE_COMPONENT_USE_LINKED_INPUT", "False") definition_key = f"{uuid4()}/invocations.json" @@ -186,6 +188,7 @@ def test_logging_setup(minio, monkeypatch): encode_b64j(val=["echo", "hello"]), ) monkeypatch.setenv("GRAND_CHALLENGE_COMPONENT_SET_EXTRA_GROUPS", "False") + monkeypatch.setenv("GRAND_CHALLENGE_COMPONENT_USE_LINKED_INPUT", "False") runner = CliRunner() result = runner.invoke(cli, ["invoke", "-t", json.dumps(tasks)]) @@ -216,6 +219,7 @@ def test_logging_stderr_setup(minio, monkeypatch): ), ) monkeypatch.setenv("GRAND_CHALLENGE_COMPONENT_SET_EXTRA_GROUPS", "False") + monkeypatch.setenv("GRAND_CHALLENGE_COMPONENT_USE_LINKED_INPUT", "False") runner = CliRunner() result = runner.invoke(cli, ["invoke", "-t", json.dumps(tasks)]) diff --git a/tests/test_io.py b/tests/test_io.py index 4a6871b..79c0227 100644 --- a/tests/test_io.py +++ b/tests/test_io.py @@ -314,6 +314,7 @@ async def test_inference_result_upload( encode_b64j(val=cmd), ) monkeypatch.setenv("GRAND_CHALLENGE_COMPONENT_SET_EXTRA_GROUPS", "False") + monkeypatch.setenv("GRAND_CHALLENGE_COMPONENT_USE_LINKED_INPUT", "False") direct_invocation = await task.invoke() diff --git a/tests/test_models.py b/tests/test_models.py index cdafbc4..bf4cb84 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -407,3 +407,50 @@ def test_ensure_directories_are_writable(tmp_path, monkeypatch): assert model.stat().st_mode == 0o40777 assert checkpoints.stat().st_mode == 0o40777 assert tmp.stat().st_mode == 0o40777 + + +def test_linked_input_path_default(): + t = InferenceTask( + pk="test", inputs=[], output_bucket_name="test", output_prefix="test" + ) + + assert t.linked_input_path == Path("/opt/ml/input/data/test-input") + + +def test_linked_input_path_setting(monkeypatch): + monkeypatch.setenv("GRAND_CHALLENGE_COMPONENT_LINKED_INPUT_PARENT", "/foo") + + t = InferenceTask( + pk="test", inputs=[], output_bucket_name="test", output_prefix="test" + ) + + assert t.linked_input_path == Path("/foo/test-input") + + +def test_reset_linked_input(tmp_path, monkeypatch): + input_path = tmp_path / "input" + linked_input_parent = tmp_path / "linked-input" + + monkeypatch.setenv( + "GRAND_CHALLENGE_COMPONENT_INPUT_PATH", input_path.absolute() + ) + monkeypatch.setenv( + "GRAND_CHALLENGE_COMPONENT_LINKED_INPUT_PARENT", linked_input_parent + ) + + t = InferenceTask( + pk="test", inputs=[], output_bucket_name="test", output_prefix="test" + ) + t.reset_io() + + expected_input_directory = linked_input_parent / "test-input" + + assert input_path.exists() + assert input_path.is_symlink() + assert expected_input_directory.exists() + assert expected_input_directory.is_dir() + assert input_path.resolve(strict=True) == expected_input_directory + + # Ensure 0o755 permissions + assert os.stat(input_path).st_mode & 0o777 == 0o755 + assert os.stat(expected_input_directory).st_mode & 0o777 == 0o755