Commit

Merge pull request #61 from apoorvkh/misc-refactoring
Misc refactoring
  • Loading branch information
2 parents 81f5e91 + 9c97d09 commit 060850b
Showing 13 changed files with 399 additions and 304 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/main.yml
@@ -86,4 +86,4 @@ jobs:
cache: false
environments: default
activate-environment: default
- run: pytest tests/test_CI.py
- run: pytest tests/test_ci.py
4 changes: 2 additions & 2 deletions pixi.lock
@@ -2601,9 +2601,9 @@ packages:
requires_python: '>=3.8.0'
- kind: pypi
name: torchrunx
version: 0.1.3
version: 0.2.0
path: .
sha256: 7352054b1212a4ce0d60c055288dd4f51cea2093a84d0a1a48ea97bdaa703fad
sha256: 1753f43bee54bc0da38cdd524dc501c0c2be9fbaaa7036bced9c9d03a7a8e810
requires_dist:
- cloudpickle>=3.0.0
- fabric>=3.0.0
21 changes: 19 additions & 2 deletions pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "hatchling.build"

[project]
name = "torchrunx"
version = "0.1.3"
version = "0.2.0"
authors = [
{name = "Apoorv Khandelwal", email = "mail@apoorvkh.com"},
{name = "Peter Curtin", email = "peter_curtin@brown.edu"},
@@ -41,7 +41,24 @@ include = ["pyproject.toml", "src/**/*.py", "tests/**/*.py"]
line-length = 100
src = ["src", "tests"]
[tool.ruff.lint]
select = ["E", "F", "B", "UP", "I"]
select = ["ALL"]
ignore = [
"D", # documentation
"ANN101", "ANN102", "ANN401", # self / cls / Any annotations
"BLE001", # blind exceptions
"TD", # todo syntax
"FIX002", # existing todos
"PLR0913", # too many arguments
"DTZ005", # datetime timezone
"S301", # bandit: pickle
"S603", "S607", # bandit: subprocess
"COM812", "ISC001", # conflict with formatter
]
[tool.ruff.lint.per-file-ignores]
"tests/**/*.py" = [
"S101", # allow asserts
"T201" # allow prints
]

[tool.pyright]
include = ["src", "tests"]
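
For illustration, a minimal hypothetical test file of the kind the per-file-ignores above are meant to permit: bare asserts (rule S101) and prints (rule T201) stay allowed under tests/**, while the stricter select = ["ALL"] applies to the rest of the codebase. This file is not part of the PR.

# tests/test_example.py  (hypothetical, for illustration only)
def test_addition() -> None:
    result = 1 + 1
    print(f"result = {result}")  # T201 (print) is ignored under tests/**
    assert result == 2  # S101 (assert) is ignored under tests/**
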
4 changes: 4 additions & 0 deletions src/torchrunx/__init__.py
@@ -1,6 +1,10 @@
from .launcher import Launcher, launch
from .logging_utils import add_filter_to_handler, file_handler, stream_handler

__all__ = [
"Launcher",
"launch",
"add_filter_to_handler",
"file_handler",
"stream_handler",
]
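
With the explicit __all__ above, the package's public surface can be imported by name. A minimal sketch (call signatures are not shown in this diff, so only the imports are illustrated):

from torchrunx import (
    Launcher,
    launch,
    add_filter_to_handler,
    file_handler,
    stream_handler,
)

# Launcher / launch come from .launcher; the remaining three names are
# logging helpers re-exported from .logging_utils.
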
66 changes: 28 additions & 38 deletions src/torchrunx/agent.py
@@ -6,14 +6,14 @@
import socket
import sys
import tempfile
import traceback
from dataclasses import dataclass
from typing import Any, Callable, Literal

import cloudpickle
import torch
import torch.distributed as dist
from torch.distributed.elastic.multiprocessing import start_processes
from typing_extensions import Self
import torch.distributed.elastic.multiprocessing as dist_mp

from .logging_utils import log_records_to_socket, redirect_stdio_to_logger
from .utils import (
@@ -40,16 +40,20 @@ class WorkerArgs:
hostname: str
timeout: int

def to_bytes(self) -> bytes:
return cloudpickle.dumps(self)
def serialize(self) -> SerializedWorkerArgs:
return SerializedWorkerArgs(worker_args=self)

@classmethod
def from_bytes(cls, serialized: bytes) -> Self:
return cloudpickle.loads(serialized)

class SerializedWorkerArgs:
def __init__(self, worker_args: WorkerArgs) -> None:
self.bytes = cloudpickle.dumps(worker_args)

def entrypoint(serialized_worker_args: bytes) -> Any | WorkerException:
worker_args = WorkerArgs.from_bytes(serialized_worker_args)
def deserialize(self) -> WorkerArgs:
return cloudpickle.loads(self.bytes)


def entrypoint(serialized_worker_args: SerializedWorkerArgs) -> Any | WorkerException:
worker_args: WorkerArgs = serialized_worker_args.deserialize()

logger = logging.getLogger()
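
The change above swaps the to_bytes/from_bytes pair for a small wrapper object, presumably so the worker entrypoint receives a typed SerializedWorkerArgs rather than opaque bytes. A self-contained sketch of the same pattern, with illustrative names rather than torchrunx's actual classes:

from __future__ import annotations

from dataclasses import dataclass

import cloudpickle


@dataclass
class Args:
    rank: int
    hostname: str

    def serialize(self) -> SerializedArgs:
        return SerializedArgs(self)


class SerializedArgs:
    def __init__(self, args: Args) -> None:
        # cloudpickle handles closures and locally defined functions that the
        # stdlib pickle rejects
        self.bytes = cloudpickle.dumps(args)

    def deserialize(self) -> Args:
        return cloudpickle.loads(self.bytes)


restored = Args(rank=0, hostname="node0").serialize().deserialize()
assert restored.rank == 0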

@@ -63,18 +63,14 @@ def entrypoint(serialized_worker_args: bytes) -> Any | WorkerException:

redirect_stdio_to_logger(logger)

store = dist.TCPStore( # pyright: ignore[reportPrivateImportUsage]
store = dist.TCPStore( # pyright: ignore [reportPrivateImportUsage]
host_name=worker_args.main_agent_hostname,
port=worker_args.main_agent_port,
world_size=worker_args.world_size,
is_master=(worker_args.rank == 0),
)

backend = worker_args.backend
if backend is None:
backend = "nccl" if torch.cuda.is_available() else "gloo"

logger.debug(f"using backend: {backend}")
backend = worker_args.backend or ("nccl" if torch.cuda.is_available() else "gloo")

dist.init_process_group(
backend=backend,
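
The backend fallback above collapses the old if/else into a single expression. For context, a single-process sketch of the surrounding rendezvous pattern that is runnable on its own; the host, port, and world size are placeholder values, whereas in torchrunx they come from the launcher:

import datetime

import torch
import torch.distributed as dist

requested_backend = None  # stand-in for worker_args.backend
backend = requested_backend or ("nccl" if torch.cuda.is_available() else "gloo")

store = dist.TCPStore(  # single-member store: this process acts as master
    host_name="localhost",
    port=29500,
    world_size=1,
    is_master=True,
)
dist.init_process_group(
    backend=backend,
    store=store,
    rank=0,
    world_size=1,
    timeout=datetime.timedelta(seconds=30),
)
dist.destroy_process_group()
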
@@ -91,19 +91,17 @@ def entrypoint(serialized_worker_args: bytes) -> Any | WorkerException:
os.environ["MASTER_ADDR"] = worker_args.main_agent_hostname
os.environ["MASTER_PORT"] = str(worker_args.main_agent_port)

logger.debug(f"executing function: {worker_args.function}")

try:
return worker_args.function()
except Exception as e:
logger.error(e)
traceback.print_exc()
return WorkerException(exception=e)
finally:
sys.stdout.flush()
sys.stderr.flush()


def main(launcher_agent_group: LauncherAgentGroup, logger_hostname: str, logger_port: int):
def main(launcher_agent_group: LauncherAgentGroup, logger_hostname: str, logger_port: int) -> None:
agent_rank = launcher_agent_group.rank - 1

payload = AgentPayload(
@@ -132,16 +130,9 @@ def main(launcher_agent_group: LauncherAgentGroup, logger_hostname: str, logger_

redirect_stdio_to_logger(logger)

if torch.__version__ >= "2.3":
from torch.distributed.elastic.multiprocessing import DefaultLogsSpecs

log_kwargs = {"logs_specs": DefaultLogsSpecs(log_dir=tempfile.mkdtemp())}
else:
log_kwargs = {"log_dir": tempfile.mkdtemp()}

# spawn workers

ctx = start_processes(
ctx = dist_mp.start_processes(
name=f"{hostname}_",
entrypoint=entrypoint,
args={
@@ -159,31 +150,30 @@ def main(launcher_agent_group: LauncherAgentGroup, logger_hostname: str, logger_
world_size=worker_world_size,
hostname=launcher_payload.hostnames[agent_rank],
timeout=launcher_payload.timeout,
).to_bytes(),
).serialize(),
)
for i in range(num_workers)
},
envs={i: {} for i in range(num_workers)},
**log_kwargs, # pyright: ignore [reportArgumentType]
**(
{"logs_specs": dist_mp.DefaultLogsSpecs(log_dir=tempfile.mkdtemp())}
if torch.__version__ >= "2.3"
else {"log_dir": tempfile.mkdtemp()}
), # pyright: ignore [reportArgumentType]
)
logger.info("starting processes")

try:
status = None
while True:
if status is None or status.state == "running":
status = AgentStatus.from_result(
result=ctx.wait(5), worker_global_ranks=worker_global_ranks
)
status = AgentStatus.from_result(ctx.wait(5))

agent_statuses = launcher_agent_group.sync_agent_statuses(status=status)

if all(s.state == "done" for s in agent_statuses):
break
elif any(s.state == "failed" for s in agent_statuses):
all_done = all(s.state == "done" for s in agent_statuses)
any_failed = any(s.state == "failed" for s in agent_statuses)
if all_done or any_failed:
break
except:
raise
finally:
ctx.close()
sys.stdout.flush()
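
For context on the inlined log argument above, a minimal standalone sketch of the same version-dependent kwarg and the wait/close lifecycle around start_processes. The entrypoint and names here are hypothetical stand-ins, not torchrunx's actual worker function:

import tempfile

import torch
import torch.distributed.elastic.multiprocessing as dist_mp


def shout(msg: str) -> str:  # hypothetical stand-in for the worker entrypoint
    return msg.upper()


if __name__ == "__main__":
    ctx = dist_mp.start_processes(
        name="example_",
        entrypoint=shout,
        args={0: ("hello",), 1: ("world",)},  # one argument tuple per local rank
        envs={0: {}, 1: {}},
        **(
            # mirror the version check above: logs_specs on torch >= 2.3,
            # log_dir on older releases
            {"logs_specs": dist_mp.DefaultLogsSpecs(log_dir=tempfile.mkdtemp())}
            if torch.__version__ >= "2.3"
            else {"log_dir": tempfile.mkdtemp()}
        ),
    )
    try:
        result = ctx.wait()  # block until every local worker exits
        print(result.return_values)  # {0: "HELLO", 1: "WORLD"}
    finally:
        ctx.close()
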
17 changes: 11 additions & 6 deletions src/torchrunx/environment.py
@@ -17,7 +17,9 @@ def slurm_hosts() -> list[str]:
:rtype: list[str]
"""
# TODO: sanity check SLURM variables, commands
assert in_slurm_job()
if not in_slurm_job():
msg = "Not in a SLURM job"
raise RuntimeError(msg)
return (
subprocess.check_output(["scontrol", "show", "hostnames", os.environ["SLURM_JOB_NODELIST"]])
.decode()
@@ -35,15 +37,18 @@ def slurm_workers() -> int:
:rtype: int
"""
# TODO: sanity check SLURM variables, commands
assert in_slurm_job()
if not in_slurm_job():
msg = "Not in a SLURM job"
raise RuntimeError(msg)

if "SLURM_JOB_GPUS" in os.environ:
# TODO: is it possible to allocate uneven GPUs across nodes?
return len(os.environ["SLURM_JOB_GPUS"].split(","))
elif "SLURM_GPUS_PER_NODE" in os.environ:
if "SLURM_GPUS_PER_NODE" in os.environ:
return int(os.environ["SLURM_GPUS_PER_NODE"])
else:
# TODO: should we assume that we plan to do one worker per CPU?
return int(os.environ["SLURM_CPUS_ON_NODE"])

# TODO: should we assume that we plan to do one worker per CPU?
return int(os.environ["SLURM_CPUS_ON_NODE"])


def auto_hosts() -> list[str]:
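
To make the branch order in slurm_workers() above concrete, a small self-contained sketch of the same precedence, using fabricated environment values for illustration only:

def workers_per_node(env: dict) -> int:
    # Same precedence as slurm_workers() above: explicit GPU list, then
    # GPUs-per-node, then fall back to one worker per CPU.
    if "SLURM_JOB_GPUS" in env:
        return len(env["SLURM_JOB_GPUS"].split(","))
    if "SLURM_GPUS_PER_NODE" in env:
        return int(env["SLURM_GPUS_PER_NODE"])
    return int(env["SLURM_CPUS_ON_NODE"])


print(workers_per_node({"SLURM_JOB_GPUS": "0,1,2,3"}))  # -> 4
print(workers_per_node({"SLURM_CPUS_ON_NODE": "32"}))   # -> 32
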
(Diffs for the remaining changed files are not shown here.)
