Skip to content

Commit

Permalink
Merge pull request #60 from apoorvkh/exception-propagation
Browse files Browse the repository at this point in the history
  • Loading branch information
apoorvkh committed Sep 11, 2024
2 parents b9110d6 + aef0aa7 commit 81f5e91
Show file tree
Hide file tree
Showing 5 changed files with 83 additions and 90 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ include = ["pyproject.toml", "src/**/*.py", "tests/**/*.py"]
line-length = 100
src = ["src", "tests"]
[tool.ruff.lint]
extend-select = ["I"]
select = ["E", "F", "B", "UP", "I"]

[tool.pyright]
include = ["src", "tests"]
Expand Down
37 changes: 18 additions & 19 deletions src/torchrunx/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import sys
import tempfile
from dataclasses import dataclass
from typing import Callable, Literal
from typing import Any, Callable, Literal

import cloudpickle
import torch
Expand All @@ -20,7 +20,7 @@
AgentPayload,
AgentStatus,
LauncherAgentGroup,
LauncherPayload,
WorkerException,
get_open_port,
)

Expand Down Expand Up @@ -48,7 +48,7 @@ def from_bytes(cls, serialized: bytes) -> Self:
return cloudpickle.loads(serialized)


def entrypoint(serialized_worker_args: bytes):
def entrypoint(serialized_worker_args: bytes) -> Any | WorkerException:
worker_args = WorkerArgs.from_bytes(serialized_worker_args)

logger = logging.getLogger()
Expand Down Expand Up @@ -93,13 +93,14 @@ def entrypoint(serialized_worker_args: bytes):

logger.debug(f"executing function: {worker_args.function}")

r = worker_args.function()

# flush streams
sys.stdout.flush()
sys.stderr.flush()

return r
try:
return worker_args.function()
except Exception as e:
logger.error(e)
return WorkerException(exception=e)
finally:
sys.stdout.flush()
sys.stderr.flush()


def main(launcher_agent_group: LauncherAgentGroup, logger_hostname: str, logger_port: int):
Expand All @@ -111,9 +112,8 @@ def main(launcher_agent_group: LauncherAgentGroup, logger_hostname: str, logger_
process_id=os.getpid(),
)

all_payloads = launcher_agent_group.sync_payloads(payload=payload)
launcher_payload: LauncherPayload = all_payloads[0] # pyright: ignore[reportAssignmentType]
main_agent_payload: AgentPayload = all_payloads[1] # pyright: ignore[reportAssignmentType]
launcher_payload, agent_payloads = launcher_agent_group.sync_payloads(payload=payload)
main_agent_payload = agent_payloads[0]

hostname = launcher_payload.hostnames[agent_rank]
worker_world_size = launcher_payload.worker_world_size
Expand Down Expand Up @@ -169,20 +169,19 @@ def main(launcher_agent_group: LauncherAgentGroup, logger_hostname: str, logger_
logger.info("starting processes")

try:
status = AgentStatus()
status = None
while True:
if status.is_running():
if status is None or status.state == "running":
status = AgentStatus.from_result(
result=ctx.wait(5), worker_global_ranks=worker_global_ranks
)

agent_statuses = launcher_agent_group.sync_agent_statuses(status=status)

if all(s.is_done() for s in agent_statuses):
if all(s.state == "done" for s in agent_statuses):
break
elif any(s.state == "failed" for s in agent_statuses):
break

if any(s.is_failed() for s in agent_statuses):
raise RuntimeError()
except:
raise
finally:
Expand Down
77 changes: 36 additions & 41 deletions src/torchrunx/launcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,22 +9,21 @@
import subprocess
import sys
from collections import ChainMap
from dataclasses import dataclass, field
from dataclasses import dataclass
from functools import partial
from logging import Handler
from multiprocessing import Process
from typing import Any, Callable, Literal
from typing import Any, Callable, Literal, Sequence

import fabric
import torch.distributed as dist

from .environment import auto_hosts, auto_workers, slurm_hosts, slurm_workers
from .logging_utils import LogRecordSocketReceiver, default_handlers
from .utils import (
AgentPayload,
AgentStatus,
LauncherAgentGroup,
LauncherPayload,
WorkerException,
get_open_port,
)

Expand Down Expand Up @@ -59,31 +58,29 @@ def execute_command(

@dataclass
class Launcher:
hostnames: list[str] | Literal["auto", "slurm"] = field(default_factory=lambda: ["localhost"])
workers_per_host: int | list[int] | Literal["auto", "slurm"] = 1
hostnames: list[str] | Literal["auto", "slurm"] = "auto"
workers_per_host: int | list[int] | Literal["auto", "slurm"] = "auto"
ssh_config_file: str | os.PathLike | None = None
backend: Literal["mpi", "gloo", "nccl", "ucc", None] = None
log_handlers: list[Handler] | Literal["auto"] | None = "auto"
env_vars: list[str] = field(
default_factory=lambda: [
"PATH",
"LD_LIBRARY",
"LIBRARY_PATH",
"PYTHON*",
"CUDA*",
"TORCH*",
"PYTORCH*",
"NCCL*",
]
env_vars: Sequence[str] = (
"PATH",
"LD_LIBRARY",
"LIBRARY_PATH",
"PYTHON*",
"CUDA*",
"TORCH*",
"PYTORCH*",
"NCCL*",
)
env_file: str | os.PathLike | None = None
timeout: int = 600

def run(
self,
func: Callable,
func_args: tuple[Any] = tuple(),
func_kwargs: dict[str, Any] = {},
func_args: tuple[Any] | None = None,
func_kwargs: dict[str, Any] | None = None,
) -> dict[int, Any]:
"""
Launch a distributed PyTorch function on the specified nodes. See :mod:`torchrunx.launch`
Expand Down Expand Up @@ -205,6 +202,11 @@ def run(
host_ranks = range(_cumulative_workers[n], _cumulative_workers[n + 1])
worker_global_ranks.append(list(host_ranks))

if func_args is None:
func_args = tuple()
if func_kwargs is None:
func_kwargs = dict()

payload = LauncherPayload(
fn=partial(func, *func_args, **func_kwargs),
hostnames=self.hostnames,
Expand All @@ -214,30 +216,23 @@ def run(
timeout=self.timeout,
)

agent_payloads: list[AgentPayload] = launcher_agent_group.sync_payloads(payload=payload)[1:] # pyright: ignore[reportAssignmentType]
launcher_payload, agent_payloads = launcher_agent_group.sync_payloads(payload=payload)
agent_pids = [p.process_id for p in agent_payloads]

# loop to monitor agent statuses (until failed or done)
try:
while True:
agent_statuses = launcher_agent_group.sync_agent_statuses(status=AgentStatus())
agent_statuses = launcher_agent_group.sync_agent_statuses(status=None)

for s in agent_statuses:
if s.state == "failed":
for value in s.return_values.values():
if isinstance(value, WorkerException):
raise value.exception

if all(s.is_done() for s in agent_statuses):
if all(s.state == "done" for s in agent_statuses):
break

if any(s.is_failed() for s in agent_statuses):
# TODO: cleaner way to print these?
e = ""
for i, s in enumerate(agent_statuses):
if s is not None and s.is_failed():
for k, v in s.failures.items():
e += f"Node {i}, local worker {k} exited with error: "
if isinstance(v.message, str):
e += f"{v.message}\n"
else:
e += f"{v.message['message']}\n"
e += f"{v.message['extraInfo']['py_callstack']}\n\n"
raise RuntimeError(e)
except:
# cleanup: SIGTERM all agents
for agent_pid, agent_hostname in zip(agent_pids, self.hostnames):
Expand All @@ -259,14 +254,14 @@ def run(

def launch(
func: Callable,
func_args: tuple[Any] = tuple(),
func_kwargs: dict[str, Any] = {},
hostnames: list[str] | Literal["auto", "slurm"] = ["localhost"],
workers_per_host: int | list[int] | Literal["auto", "slurm"] = 1,
func_args: tuple[Any] | None = None,
func_kwargs: dict[str, Any] | None = None,
hostnames: list[str] | Literal["auto", "slurm"] = "auto",
workers_per_host: int | list[int] | Literal["auto", "slurm"] = "auto",
ssh_config_file: str | os.PathLike | None = None,
backend: Literal["mpi", "gloo", "nccl", "ucc", None] = None,
log_handlers: list[Handler] | Literal["auto"] = "auto",
env_vars: list[str] = [
env_vars: Sequence[str] = (
"PATH",
"LD_LIBRARY",
"LIBRARY_PATH",
Expand All @@ -275,7 +270,7 @@ def launch(
"TORCH*",
"PYTORCH*",
"NCCL*",
],
),
env_file: str | os.PathLike | None = None,
timeout: int = 600,
) -> dict[int, Any]:
Expand Down
53 changes: 26 additions & 27 deletions src/torchrunx/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
import cloudpickle
import torch.distributed as dist
from torch.distributed.elastic.multiprocessing.api import RunProcsResult
from torch.distributed.elastic.multiprocessing.errors import ProcessFailure
from typing_extensions import Self


Expand All @@ -20,6 +19,11 @@ def get_open_port() -> int:
return port


@dataclass
class WorkerException:
    """Container that carries an exception raised by a worker's function.

    The worker entrypoint catches an ``Exception`` from the user function and
    returns it wrapped in this class instead of letting it propagate, so the
    failure travels through the same channel as ordinary return values.
    Consumers tell failures from results with
    ``isinstance(value, WorkerException)`` and can re-raise ``exception``.
    """

    # The original exception object caught in the worker process.
    # NOTE(review): this value is sent between processes, so it presumably
    # must be serializable by the project's pickler — confirm.
    exception: Exception


@dataclass
class LauncherPayload:
fn: Callable
Expand All @@ -39,33 +43,25 @@ class AgentPayload:

@dataclass
class AgentStatus:
running: bool = True
failed: bool = False
return_values: dict[int, Any] = field(default_factory=dict)
failures: dict[int, ProcessFailure] = field(default_factory=dict)
stdouts: dict[int, str] = field(default_factory=dict)
stderrs: dict[int, str] = field(default_factory=dict)
state: Literal["running", "failed", "done"]
return_values: dict[int, Any | WorkerException] = field(default_factory=dict)

@classmethod
def from_result(cls, result: RunProcsResult | None, worker_global_ranks: list[int]) -> Self:
if result is None:
return cls()
return cls(state="running")

return cls(
running=False,
failed=result.is_failed(),
return_values={worker_global_ranks[k]: v for k, v in result.return_values.items()},
failures={worker_global_ranks[k]: v for k, v in result.failures.items()},
)
return_values = result.return_values

def is_running(self) -> bool:
return self.running
if any(isinstance(v, WorkerException) for v in return_values.values()):
state = "failed"
else:
state = "done"

def is_failed(self) -> bool:
return self.failed

def is_done(self) -> bool:
return not self.running and not self.failed
return cls(
state=state,
return_values={worker_global_ranks[k]: v for k, v in return_values.items()},
)


@dataclass
Expand Down Expand Up @@ -98,15 +94,18 @@ def _deserialize(self, serialized: bytes) -> Any:
def _all_gather(self, object: Any) -> list:
"""gather object from every rank to list on every rank"""
object_bytes = self._serialize(object)
object_list = [bytes()] * self.world_size
object_list = [b""] * self.world_size
dist.all_gather_object(object_list=object_list, obj=object_bytes, group=self.group)
object_list = [self._deserialize(o) for o in object_list]
return object_list

def sync_payloads(
self, payload: LauncherPayload | AgentPayload
) -> list[LauncherPayload | AgentPayload]:
return self._all_gather(object=payload)

def sync_agent_statuses(self, status: AgentStatus) -> list[AgentStatus]:
return self._all_gather(object=status)[1:]
) -> tuple[LauncherPayload, list[AgentPayload]]:
payloads = self._all_gather(object=payload)
launcher_payload = payloads[0]
agent_payloads = payloads[1:]
return launcher_payload, agent_payloads

def sync_agent_statuses(self, status: AgentStatus | None) -> list[AgentStatus]:
return self._all_gather(object=status)[1:] # [0] is launcher (status=None)
4 changes: 2 additions & 2 deletions tests/test_CI.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def dist_func():
assert len(log_files) == 3

for file in log_files:
with open(f"{tmp}/{file}", "r") as f:
with open(f"{tmp}/{file}") as f:
contents = f.read()
print(contents)
if file.endswith("[0].log"):
Expand All @@ -79,7 +79,7 @@ def error_func():
tmp = tempfile.mkdtemp()
os.environ["TORCHRUNX_DIR"] = tmp

with pytest.raises(RuntimeError) as excinfo:
with pytest.raises(ValueError) as excinfo:
trx.launch(
func=error_func,
func_kwargs={},
Expand Down

0 comments on commit 81f5e91

Please sign in to comment.