Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Set the users primary group unless unset #16

Merged
merged 1 commit into from
Dec 16, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "sagemaker-shim"
version = "0.2.1a1"
version = "0.2.1a2"
description = "Adapts algorithms that implement the Grand Challenge inference API for running in SageMaker"
authors = ["James Meakin <12661555+jmsmkn@users.noreply.github.com>"]
license = "Apache-2.0"
Expand Down
73 changes: 41 additions & 32 deletions sagemaker_shim/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
import re
import subprocess
from base64 import b64decode
from collections.abc import Callable
from functools import cached_property
from importlib.metadata import version
from pathlib import Path
Expand Down Expand Up @@ -159,6 +158,7 @@ class InferenceResult(BaseModel):
class UserGroup(NamedTuple):
uid: int | None
gid: int | None
home: str | None


class InferenceTask(BaseModel):
Expand Down Expand Up @@ -284,43 +284,36 @@ def proc_env(self) -> dict[str, str]:
else:
env.pop(lp_key, None)

if self.proc_user.uid is not None:
pw_record = pwd.getpwuid(self.proc_user.uid)
env["HOME"] = pw_record.pw_dir
if self.proc_user.home is not None:
env["HOME"] = self.proc_user.home

return env

@staticmethod
def _get_user_or_group_id(*, match: re.Match[str], key: str) -> int | None:
value = match.group(key)

if value == "":
def _get_user_info(id_or_name: str) -> pwd.struct_passwd | None:
if id_or_name == "":
return None

if key == "user":
name_lookup: Callable[
[str], pwd.struct_passwd | grp.struct_group
] = pwd.getpwnam
id_lookup: Callable[
[int], pwd.struct_passwd | grp.struct_group
] = pwd.getpwuid
attr = "pw_uid"
elif key == "group":
name_lookup = grp.getgrnam
id_lookup = grp.getgrgid
attr = "gr_gid"
else:
raise RuntimeError("Unknown key")

try:
out: int = getattr(name_lookup(value), attr)
return pwd.getpwnam(id_or_name)
except (KeyError, AttributeError):
try:
out = getattr(id_lookup(int(value)), attr)
return pwd.getpwuid(int(id_or_name))
except (KeyError, ValueError, AttributeError) as error:
raise RuntimeError(f"{key} {value} not found") from error
raise RuntimeError(f"User {id_or_name} not found") from error

return out
@staticmethod
def _get_group_info(id_or_name: str) -> grp.struct_group | None:
if id_or_name == "":
return None

try:
return grp.getgrnam(id_or_name)
except (KeyError, AttributeError):
try:
return grp.getgrgid(int(id_or_name))
except (KeyError, ValueError, AttributeError) as error:
raise RuntimeError(f"Group {id_or_name} not found") from error

@cached_property
def proc_user(self) -> UserGroup:
Expand All @@ -329,12 +322,28 @@ def proc_user(self) -> UserGroup:
)

if match:
return UserGroup(
uid=self._get_user_or_group_id(match=match, key="user"),
gid=self._get_user_or_group_id(match=match, key="group"),
)
user = self._get_user_info(id_or_name=match.group("user"))
group = self._get_group_info(id_or_name=match.group("group"))

if user is None:
uid = None
home = None
else:
uid = user.pw_uid
home = user.pw_dir

if group is None:
if user is None:
gid = None
else:
# Switch to the users primary group
gid = user.pw_gid
else:
gid = group.gr_gid

return UserGroup(uid=uid, gid=gid, home=home)
else:
return UserGroup(uid=None, gid=None)
return UserGroup(uid=None, gid=None, home=None)

async def invoke(self) -> InferenceResult:
"""Run the inference on a single case"""
Expand Down
6 changes: 4 additions & 2 deletions tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,17 +64,19 @@ def test_removing_ld_library_path(monkeypatch):
@pytest.mark.parametrize(
"user,expected_user,expected_group",
(
("0", 0, None),
("0", 0, 0),
("0:0", 0, 0),
(":0", None, 0),
("", None, None),
("root", 0, None),
("root", 0, 0),
(f"root:{grp.getgrgid(0).gr_name}", 0, 0),
(f":{grp.getgrgid(0).gr_name}", None, 0),
("", None, None),
("🙈:🙉", None, None),
("root:0", 0, 0),
(f"0:{grp.getgrgid(0).gr_name}", 0, 0),
(f":{os.getgid()}", None, os.getgid()),
(f"root:{os.getgid()}", 0, os.getgid()),
),
)
def test_proc_user(monkeypatch, user, expected_user, expected_group):
Expand Down