Skip to content

Commit

Permalink
Set $HOME (#15)
Browse files Browse the repository at this point in the history
  • Loading branch information
jmsmkn authored Dec 16, 2023
1 parent efb3564 commit 4fdc07c
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 13 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "sagemaker-shim"
version = "0.2.1a0"
version = "0.2.1a1"
description = "Adapts algorithms that implement the Grand Challenge inference API for running in SageMaker"
authors = ["James Meakin <12661555+jmsmkn@users.noreply.github.com>"]
license = "Apache-2.0"
Expand Down
20 changes: 12 additions & 8 deletions sagemaker_shim/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,8 +157,8 @@ class InferenceResult(BaseModel):


class UserGroup(NamedTuple):
user: int | None
group: int | None
uid: int | None
gid: int | None


class InferenceTask(BaseModel):
Expand Down Expand Up @@ -284,6 +284,10 @@ def proc_env(self) -> dict[str, str]:
else:
env.pop(lp_key, None)

if self.proc_user.uid is not None:
pw_record = pwd.getpwuid(self.proc_user.uid)
env["HOME"] = pw_record.pw_dir

return env

@staticmethod
Expand Down Expand Up @@ -318,19 +322,19 @@ def _get_user_or_group_id(*, match: re.Match[str], key: str) -> int | None:

return out

@property
@cached_property
def proc_user(self) -> UserGroup:
match = re.fullmatch(
r"^(?P<user>[0-9a-zA-Z]*):?(?P<group>[0-9a-zA-Z]*)$", self.user
)

if match:
return UserGroup(
user=self._get_user_or_group_id(match=match, key="user"),
group=self._get_user_or_group_id(match=match, key="group"),
uid=self._get_user_or_group_id(match=match, key="user"),
gid=self._get_user_or_group_id(match=match, key="group"),
)
else:
return UserGroup(user=None, group=None)
return UserGroup(uid=None, gid=None)

async def invoke(self) -> InferenceResult:
"""Run the inference on a single case"""
Expand Down Expand Up @@ -458,8 +462,8 @@ async def execute(self) -> int:

process = await asyncio.create_subprocess_exec(
*self.proc_args,
user=self.proc_user.user,
group=self.proc_user.group,
user=self.proc_user.uid,
group=self.proc_user.gid,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
env=self.proc_env,
Expand Down
19 changes: 15 additions & 4 deletions tests/test_models.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import grp
import os
import pwd

import pytest

Expand Down Expand Up @@ -84,8 +85,8 @@ def test_proc_user(monkeypatch, user, expected_user, expected_group):
)

assert t.user == user
assert t.proc_user.user == expected_user
assert t.proc_user.group == expected_group
assert t.proc_user.uid == expected_user
assert t.proc_user.gid == expected_group


def test_proc_user_unset():
Expand All @@ -94,5 +95,15 @@ def test_proc_user_unset():
)

assert t.user == ""
assert t.proc_user.user is None
assert t.proc_user.group is None
assert t.proc_user.uid is None
assert t.proc_user.gid is None


def test_home_is_set(monkeypatch):
monkeypatch.setenv("GRAND_CHALLENGE_COMPONENT_USER", "root")

t = InferenceTask(
pk="test", inputs=[], output_bucket_name="test", output_prefix="test"
)

assert t.proc_env["HOME"] == pwd.getpwnam("root").pw_dir

0 comments on commit 4fdc07c

Please sign in to comment.