Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Place input files in /opt/ml/input/data/{pk}-input and link to them from /input #29

Merged
merged 2 commits into from
Aug 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 49 additions & 5 deletions sagemaker_shim/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ def _get_group_id(id_or_name: str) -> int | None:

def clean_path(path: Path) -> None:
for f in path.glob("*"):
if f.is_file():
if f.is_symlink() or f.is_file():
f.chmod(0o700)
f.unlink()
elif f.is_dir():
Expand Down Expand Up @@ -493,6 +493,19 @@ def input_path(self) -> Path:
logger.debug(f"{input_path=}")
return input_path

@property
def linked_input_path(self) -> Path:
    """Directory where the input files are placed and that the input links to.

    The parent directory is read from the
    ``GRAND_CHALLENGE_COMPONENT_LINKED_INPUT_PARENT`` environment variable,
    falling back to ``/opt/ml/input/data/``. The final component is this
    task's ``pk`` followed by ``-input``.
    """
    parent_setting = os.environ.get(
        "GRAND_CHALLENGE_COMPONENT_LINKED_INPUT_PARENT",
        "/opt/ml/input/data/",
    )
    # The local's name is part of the debug text ("linked_input_path=..."),
    # so it must keep this exact spelling.
    linked_input_path = Path(parent_setting).joinpath(f"{self.pk}-input")
    logger.debug(f"{linked_input_path=}")
    return linked_input_path

@property
def output_path(self) -> Path:
"""Local path where the subprocess is expected to write its files"""
Expand Down Expand Up @@ -590,7 +603,7 @@ async def _invoke(self) -> InferenceResult:
logger.info(f"Invoking {self.pk=}")

try:
self.clean_io()
self.reset_io()

try:
self.download_input()
Expand All @@ -609,12 +622,43 @@ async def _invoke(self) -> InferenceResult:
pk=self.pk, return_code=return_code, outputs=outputs
)
finally:
self.clean_io()
self.reset_io()

def clean_io(self) -> None:
"""Clean all contents of input and output folders"""
def reset_io(self) -> None:
    """Empty the input and output directories, then reset the linked input.

    The input path is cleaned first, then the output path, and only after
    both are cleaned is the linked input reset.
    """
    # Look each path up right before cleaning it so the properties are
    # evaluated in the same order as the clean-up calls.
    for attribute in ("input_path", "output_path"):
        clean_path(path=getattr(self, attribute))

    self.reset_linked_input()

def reset_linked_input(self) -> None:
    """Recreate the symlink from the input path to the linked input directory.

    Does nothing unless the ``GRAND_CHALLENGE_COMPONENT_USE_LINKED_INPUT``
    environment variable (default ``"True"``) equals ``"true"`` when
    lower-cased.

    When enabled, an existing symlink or empty directory at ``input_path``
    is removed, the linked input directory is removed and recreated with
    mode ``0o755``, and ``input_path`` is made a directory symlink to it.

    Raises:
        OSError: If ``input_path`` or ``linked_input_path`` is a non-empty
            directory (``rmdir`` refuses to remove it), or if something
            other than a symlink or directory occupies ``input_path``.
    """
    use_linked_input = os.environ.get(
        "GRAND_CHALLENGE_COMPONENT_USE_LINKED_INPUT", "True"
    )
    if use_linked_input.lower() != "true":
        return

    input_path = self.input_path
    linked_input_path = self.linked_input_path

    logger.info(
        f"Setting up linked input from {input_path} "
        f"to {linked_input_path}"
    )

    # Check is_symlink() on its own rather than behind exists(): exists()
    # follows the link and is False for a dangling symlink, which would
    # leave the stale link in place and make symlink_to() below fail.
    if input_path.is_symlink():
        input_path.unlink()
    elif input_path.is_dir():
        input_path.rmdir()

    if linked_input_path.exists():
        linked_input_path.rmdir()

    linked_input_path.mkdir(parents=True)
    linked_input_path.chmod(0o755)

    input_path.symlink_to(linked_input_path, target_is_directory=True)
    # chmod follows the symlink, so this applies to the link's target.
    input_path.chmod(0o755)

def download_input(self) -> None:
"""Download all the inputs to the input path"""
Expand Down
4 changes: 4 additions & 0 deletions tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ def test_inference_from_task_list(
encode_b64j(val=cmd),
)
monkeypatch.setenv("GRAND_CHALLENGE_COMPONENT_SET_EXTRA_GROUPS", "False")
monkeypatch.setenv("GRAND_CHALLENGE_COMPONENT_USE_LINKED_INPUT", "False")

runner = CliRunner()
runner.invoke(cli, ["invoke", "-t", json.dumps(tasks)])
Expand Down Expand Up @@ -144,6 +145,7 @@ def test_inference_from_s3_uri(minio, monkeypatch, cmd, expected_return_code):
encode_b64j(val=cmd),
)
monkeypatch.setenv("GRAND_CHALLENGE_COMPONENT_SET_EXTRA_GROUPS", "False")
monkeypatch.setenv("GRAND_CHALLENGE_COMPONENT_USE_LINKED_INPUT", "False")

definition_key = f"{uuid4()}/invocations.json"

Expand Down Expand Up @@ -186,6 +188,7 @@ def test_logging_setup(minio, monkeypatch):
encode_b64j(val=["echo", "hello"]),
)
monkeypatch.setenv("GRAND_CHALLENGE_COMPONENT_SET_EXTRA_GROUPS", "False")
monkeypatch.setenv("GRAND_CHALLENGE_COMPONENT_USE_LINKED_INPUT", "False")

runner = CliRunner()
result = runner.invoke(cli, ["invoke", "-t", json.dumps(tasks)])
Expand Down Expand Up @@ -216,6 +219,7 @@ def test_logging_stderr_setup(minio, monkeypatch):
),
)
monkeypatch.setenv("GRAND_CHALLENGE_COMPONENT_SET_EXTRA_GROUPS", "False")
monkeypatch.setenv("GRAND_CHALLENGE_COMPONENT_USE_LINKED_INPUT", "False")

runner = CliRunner()
result = runner.invoke(cli, ["invoke", "-t", json.dumps(tasks)])
Expand Down
1 change: 1 addition & 0 deletions tests/test_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,6 +314,7 @@ async def test_inference_result_upload(
encode_b64j(val=cmd),
)
monkeypatch.setenv("GRAND_CHALLENGE_COMPONENT_SET_EXTRA_GROUPS", "False")
monkeypatch.setenv("GRAND_CHALLENGE_COMPONENT_USE_LINKED_INPUT", "False")

direct_invocation = await task.invoke()

Expand Down
47 changes: 47 additions & 0 deletions tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -407,3 +407,50 @@ def test_ensure_directories_are_writable(tmp_path, monkeypatch):
assert model.stat().st_mode == 0o40777
assert checkpoints.stat().st_mode == 0o40777
assert tmp.stat().st_mode == 0o40777


def test_linked_input_path_default():
    """The linked input path defaults to a directory under /opt/ml/input/data/."""
    task = InferenceTask(
        pk="test",
        inputs=[],
        output_bucket_name="test",
        output_prefix="test",
    )
    expected = Path("/opt/ml/input/data/test-input")

    assert task.linked_input_path == expected


def test_linked_input_path_setting(monkeypatch):
    """The parent of the linked input path can be overridden via the environment."""
    monkeypatch.setenv("GRAND_CHALLENGE_COMPONENT_LINKED_INPUT_PARENT", "/foo")
    expected = Path("/foo/test-input")

    task = InferenceTask(
        pk="test",
        inputs=[],
        output_bucket_name="test",
        output_prefix="test",
    )

    assert task.linked_input_path == expected


def test_reset_linked_input(tmp_path, monkeypatch):
    """reset_io() turns the input path into a symlink to the linked directory."""
    input_path = tmp_path / "input"
    linked_input_parent = tmp_path / "linked-input"

    # Environment values must be strings; passing a Path makes pytest emit a
    # PytestWarning and convert it implicitly.
    monkeypatch.setenv(
        "GRAND_CHALLENGE_COMPONENT_INPUT_PATH", str(input_path.absolute())
    )
    monkeypatch.setenv(
        "GRAND_CHALLENGE_COMPONENT_LINKED_INPUT_PARENT",
        str(linked_input_parent),
    )
    # Pin the feature switch so the test does not depend on the ambient
    # environment of the machine running it.
    monkeypatch.setenv("GRAND_CHALLENGE_COMPONENT_USE_LINKED_INPUT", "True")

    t = InferenceTask(
        pk="test", inputs=[], output_bucket_name="test", output_prefix="test"
    )
    t.reset_io()

    expected_input_directory = linked_input_parent / "test-input"

    assert input_path.exists()
    assert input_path.is_symlink()
    assert expected_input_directory.exists()
    assert expected_input_directory.is_dir()
    assert input_path.resolve(strict=True) == expected_input_directory

    # Ensure 0o755 permissions (os.stat follows the symlink to its target)
    assert os.stat(input_path).st_mode & 0o777 == 0o755
    assert os.stat(expected_input_directory).st_mode & 0o777 == 0o755