From da9bdb466b0498e5611ad86bdc4a5a1008f97007 Mon Sep 17 00:00:00 2001 From: apoorvkh Date: Sat, 7 Sep 2024 05:10:51 -0400 Subject: [PATCH 1/6] more refactoring --- pixi.lock | 2 +- src/torchrunx/launcher.py | 221 ++++++++++++++++++++------------- src/torchrunx/logging_utils.py | 3 + src/torchrunx/utils.py | 4 +- tests/test_CI.py | 5 +- tests/test_func.py | 6 +- 6 files changed, 145 insertions(+), 96 deletions(-) diff --git a/pixi.lock b/pixi.lock index 8e8c95d..377385d 100644 --- a/pixi.lock +++ b/pixi.lock @@ -2603,7 +2603,7 @@ packages: name: torchrunx version: 0.1.3 path: . - sha256: 7352054b1212a4ce0d60c055288dd4f51cea2093a84d0a1a48ea97bdaa703fad + sha256: 0a30b1182ca7c101ff1d147eba62de2ba883f822fdedd13fa49207c5484f6cd8 requires_dist: - cloudpickle>=3.0.0 - fabric>=3.0.0 diff --git a/src/torchrunx/launcher.py b/src/torchrunx/launcher.py index 89cde69..58011ae 100644 --- a/src/torchrunx/launcher.py +++ b/src/torchrunx/launcher.py @@ -8,11 +8,11 @@ import socket import subprocess import sys -from collections import ChainMap from dataclasses import dataclass from functools import partial from logging import Handler from multiprocessing import Process +from pathlib import Path from typing import Any, Callable, Literal, Sequence import fabric @@ -28,6 +28,57 @@ ) +def resolve_hostnames(hostnames: list[str] | Literal["auto", "slurm"]) -> list[str]: + if hostnames == "auto": + return auto_hosts() + elif hostnames == "slurm": + return slurm_hosts() + return hostnames + + +def resolve_workers_per_host( + workers_per_host: int | list[int] | Literal["auto", "slurm"], num_hosts: int +) -> list[int]: + if workers_per_host == "auto": + workers_per_host = auto_workers() + elif workers_per_host == "slurm": + workers_per_host = slurm_workers() + + if isinstance(workers_per_host, int): + workers_per_host = [workers_per_host] * num_hosts + else: + assert len(workers_per_host) == num_hosts + + return workers_per_host + + +def build_logging_server( + log_handlers: list[Handler] | Literal["auto"] | None, + launcher_hostname: str, + hostnames: list[str], + workers_per_host: list[int], + log_dir: str | os.PathLike, + log_level: int, +) -> LogRecordSocketReceiver: + if log_handlers is None: + log_handlers = [] + elif log_handlers == "auto": + log_handlers = default_handlers( + hostnames=hostnames, + workers_per_host=workers_per_host, + log_dir=log_dir, + log_level=log_level, + ) + + log_receiver = LogRecordSocketReceiver( + host=launcher_hostname, + port=get_open_port(), + handlers=log_handlers, + ) + + return log_receiver + + def is_localhost(hostname_or_ip: str) -> bool: # check if host is "loopback" address (i.e. 
designated to send to self) try: @@ -56,6 +107,43 @@ def execute_command( conn.run(f"{command} >> /dev/null 2>&1 &", asynchronous=True) +def build_command( + launcher_hostname: str, + launcher_port: int, + logger_port: int, + world_size: int, + rank: int, + env_vars: Sequence[str], + env_file: str | os.PathLike | None, +) -> str: + current_dir = os.getcwd() + + env_exports = [] + for k, v in os.environ.items(): + if any(fnmatch.fnmatch(k, e) for e in env_vars): + env_exports.append(f"{k}={v}") + + env_export_string = "" + if len(env_exports) > 0: + env_export_string = f"export {' '.join(env_exports)} && " + + env_file_string = "" + if env_file is not None: + env_file_string = f"source {env_file} && " + + return ( + f"cd {current_dir} && " + f"{env_export_string}" + f"{env_file_string}" + f"{sys.executable} -u -m torchrunx " + f"--launcher-hostname {launcher_hostname} " + f"--launcher-port {launcher_port} " + f"--logger-port {logger_port} " + f"--world-size {world_size} " + f"--rank {rank}" + ) + + @dataclass class Launcher: hostnames: list[str] | Literal["auto", "slurm"] = "auto" @@ -81,7 +169,7 @@ def run( func: Callable, func_args: tuple[Any] | None = None, func_kwargs: dict[str, Any] | None = None, - ) -> dict[int, Any]: + ) -> dict[str, dict[int, Any]]: """ Launch a distributed PyTorch function on the specified nodes. See :mod:`torchrunx.launch` @@ -98,91 +186,50 @@ def run( if not dist.is_available(): raise RuntimeError("The torch.distributed package is not available.") - if self.hostnames == "auto": - self.hostnames = auto_hosts() - elif self.hostnames == "slurm": - self.hostnames = slurm_hosts() - - num_hosts = len(self.hostnames) - - if self.workers_per_host == "auto": - self.workers_per_host = auto_workers() - elif self.workers_per_host == "slurm": - self.workers_per_host = slurm_workers() - - if isinstance(self.workers_per_host, int): - self.workers_per_host = [self.workers_per_host] * num_hosts - - assert num_hosts == len(self.workers_per_host) - - # + hostnames = resolve_hostnames(self.hostnames) + workers_per_host = resolve_workers_per_host(self.workers_per_host, len(hostnames)) launcher_hostname = socket.getfqdn() + launcher_port = get_open_port() + world_size = len(hostnames) + 1 - # setup logging - - if self.log_handlers is None: - self.log_handlers = [] - elif self.log_handlers == "auto": - self.log_handlers = default_handlers( - hostnames=self.hostnames, - workers_per_host=self.workers_per_host, - log_dir=os.environ.get("TORCHRUNX_DIR", "./torchrunx_logs"), - log_level=logging._nameToLevel.get( - os.environ.get("TORCHRUNX_LOG_LEVEL", "INFO"), logging.NOTSET - ), - ) + # start logging server - logger_port = get_open_port() - log_receiver = LogRecordSocketReceiver( - host=launcher_hostname, port=logger_port, handlers=self.log_handlers + log_receiver = build_logging_server( + log_handlers=self.log_handlers, + launcher_hostname=launcher_hostname, + hostnames=hostnames, + workers_per_host=workers_per_host, + log_dir=Path(os.environ.get("TORCHRUNX_LOG_DIR", "torchrunx_logs")), + log_level=logging._nameToLevel[os.environ.get("TORCHRUNX_LOG_LEVEL", "INFO")], ) + log_process = Process( target=log_receiver.serve_forever, daemon=True, ) - log_process.start() - - # launch command - - current_dir = os.getcwd() - - env_exports = [] - for k, v in os.environ.items(): - if any(fnmatch.fnmatch(k, e) for e in self.env_vars): - env_exports.append(f"{k}={v}") - env_export_string = "" - if len(env_exports) > 0: - env_export_string = f"export {' '.join(env_exports)} && " - - env_file_string = 
"" - if self.env_file is not None: - env_file_string = f"source {self.env_file} && " - - launcher_port = get_open_port() - world_size = num_hosts + 1 # launcher + agents + log_process.start() # start agents on each node - for i, hostname in enumerate(self.hostnames): + + for i, hostname in enumerate(hostnames): execute_command( - command=( - f"cd {current_dir} && " - f"{env_export_string}" - f"{env_file_string}" - f"{sys.executable} -u -m torchrunx " - f"--launcher-hostname {launcher_hostname} " - f"--launcher-port {launcher_port} " - f"--logger-port {logger_port} " - f"--world-size {world_size} " - f"--rank {i+1}" + command=build_command( + launcher_hostname=launcher_hostname, + launcher_port=launcher_port, + logger_port=log_receiver.port, + world_size=world_size, + rank=i + 1, + env_vars=self.env_vars, + env_file=self.env_file, ), hostname=hostname, ssh_config_file=self.ssh_config_file, ) # initialize launcher–agent process group - # ranks = (launcher, agent_0, ..., agent_{num_hosts-1}) + # ranks = (launcher, agent_{hostnames[0]}, ..., agent[-1]) launcher_agent_group = LauncherAgentGroup( launcher_hostname=launcher_hostname, @@ -193,36 +240,30 @@ def run( # build and sync payloads between launcher and agents - _cumulative_workers = [0] + list(itertools.accumulate(self.workers_per_host)) - - worker_world_size = _cumulative_workers[-1] + _cumulative_workers = [0] + list(itertools.accumulate(workers_per_host)) - worker_global_ranks = [] # list of worker ranks per host - for n in range(num_hosts): - host_ranks = range(_cumulative_workers[n], _cumulative_workers[n + 1]) - worker_global_ranks.append(list(host_ranks)) - - if func_args is None: - func_args = tuple() - if func_kwargs is None: - func_kwargs = dict() + worker_global_ranks = [ + list(range(_cumulative_workers[n], _cumulative_workers[n + 1])) + for n in range(len(hostnames)) + ] payload = LauncherPayload( - fn=partial(func, *func_args, **func_kwargs), - hostnames=self.hostnames, - worker_world_size=worker_world_size, + fn=partial(func, *(func_args or ()), **(func_kwargs or {})), + hostnames=hostnames, worker_global_ranks=worker_global_ranks, + worker_world_size=sum(workers_per_host), backend=self.backend, timeout=self.timeout, ) launcher_payload, agent_payloads = launcher_agent_group.sync_payloads(payload=payload) - agent_pids = [p.process_id for p in agent_payloads] # loop to monitor agent statuses (until failed or done) + try: while True: agent_statuses = launcher_agent_group.sync_agent_statuses(status=None) + # raises exception if communication timeout due to death of any agent for s in agent_statuses: if s.state == "failed": @@ -235,9 +276,9 @@ def run( except: # cleanup: SIGTERM all agents - for agent_pid, agent_hostname in zip(agent_pids, self.hostnames): + for agent_payload, agent_hostname in zip(agent_payloads, hostnames): execute_command( - command=f"kill {agent_pid}", + command=f"kill {agent_payload.process_id}", hostname=agent_hostname, ssh_config_file=self.ssh_config_file, ) @@ -248,8 +289,10 @@ def run( log_process.kill() dist.destroy_process_group() - return_values: dict[int, Any] = dict(ChainMap(*[s.return_values for s in agent_statuses])) - return return_values + return { + hostname: agent_status.return_values + for hostname, agent_status in zip(hostnames, agent_statuses) + } def launch( @@ -273,7 +316,7 @@ def launch( ), env_file: str | os.PathLike | None = None, timeout: int = 600, -) -> dict[int, Any]: +) -> dict[str, dict[int, Any]]: """ Launch a distributed PyTorch function on the specified nodes. 
diff --git a/src/torchrunx/logging_utils.py b/src/torchrunx/logging_utils.py index 469c845..e8a5b14 100644 --- a/src/torchrunx/logging_utils.py +++ b/src/torchrunx/logging_utils.py @@ -142,6 +142,9 @@ def flush(self): class LogRecordSocketReceiver(ThreadingTCPServer): def __init__(self, host: str, port: int, handlers: list[Handler]): + self.host = host + self.port = port + class _LogRecordStreamHandler(StreamRequestHandler): def handle(self): while True: diff --git a/src/torchrunx/utils.py b/src/torchrunx/utils.py index 3a14d34..82274e6 100644 --- a/src/torchrunx/utils.py +++ b/src/torchrunx/utils.py @@ -28,8 +28,8 @@ class WorkerException: class LauncherPayload: fn: Callable hostnames: list[str] - worker_world_size: int worker_global_ranks: list[list[int]] + worker_world_size: int backend: Literal["mpi", "gloo", "nccl", "ucc", None] timeout: int @@ -60,7 +60,7 @@ def from_result(cls, result: RunProcsResult | None, worker_global_ranks: list[in return cls( state=state, - return_values={worker_global_ranks[k]: v for k, v in return_values.items()}, + return_values=return_values, ) diff --git a/tests/test_CI.py b/tests/test_CI.py index b86cad6..472d0c5 100644 --- a/tests/test_CI.py +++ b/tests/test_CI.py @@ -38,7 +38,8 @@ def dist_func(): backend="gloo", # log_dir="./test_logs" ) - assert torch.all(r[0] == r[1]) + results = next(iter(r.values())) + assert torch.all(results[0] == results[1]) def test_logging(): @@ -47,7 +48,7 @@ def dist_func(): print(f"worker rank: {rank}") tmp = tempfile.mkdtemp() - os.environ["TORCHRUNX_DIR"] = tmp + os.environ["TORCHRUNX_LOG_DIR"] = tmp trx.launch( func=dist_func, diff --git a/tests/test_func.py b/tests/test_func.py index 9db6454..444e783 100644 --- a/tests/test_func.py +++ b/tests/test_func.py @@ -13,9 +13,11 @@ def test_launch(): workers_per_host="slurm", ) + result_values = [v for host_results in result.values() for v in host_results.values()] + t = True - for i in range(len(result)): - t = t and torch.all(result[i] == result[0]) + for i in range(len(result_values)): + t = t and torch.all(result_values[i] == result_values[0]) assert t, "Not all tensors equal" From f41025fc01610d86983b60cf3576d11289f0d5f3 Mon Sep 17 00:00:00 2001 From: apoorvkh Date: Wed, 11 Sep 2024 16:03:35 -0400 Subject: [PATCH 2/6] ruff ANN rules --- pyproject.toml | 3 ++- src/torchrunx/agent.py | 2 +- src/torchrunx/logging_utils.py | 16 ++++++++-------- tests/test_CI.py | 13 +++++++------ tests/test_func.py | 4 ++-- tests/test_submitit.py | 14 ++++++++------ tests/test_train.py | 20 ++++++++++---------- 7 files changed, 38 insertions(+), 34 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 6f029d8..693b5f1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -41,7 +41,8 @@ include = ["pyproject.toml", "src/**/*.py", "tests/**/*.py"] line-length = 100 src = ["src", "tests"] [tool.ruff.lint] -select = ["E", "F", "B", "UP", "I"] +select = ["E", "F", "W", "ANN", "B", "UP", "I"] +ignore = ["ANN101", "ANN102", "ANN401"] [tool.pyright] include = ["src", "tests"] diff --git a/src/torchrunx/agent.py b/src/torchrunx/agent.py index f4dfab3..96aa383 100644 --- a/src/torchrunx/agent.py +++ b/src/torchrunx/agent.py @@ -103,7 +103,7 @@ def entrypoint(serialized_worker_args: bytes) -> Any | WorkerException: sys.stderr.flush() -def main(launcher_agent_group: LauncherAgentGroup, logger_hostname: str, logger_port: int): +def main(launcher_agent_group: LauncherAgentGroup, logger_hostname: str, logger_port: int) -> None: agent_rank = launcher_agent_group.rank - 1 payload = AgentPayload( diff 
--git a/src/torchrunx/logging_utils.py b/src/torchrunx/logging_utils.py index e8a5b14..20a0051 100644 --- a/src/torchrunx/logging_utils.py +++ b/src/torchrunx/logging_utils.py @@ -101,12 +101,12 @@ def log_records_to_socket( worker_rank: int | None, logger_hostname: str, logger_port: int, -): +) -> None: logger.setLevel(logging.NOTSET) old_factory = logging.getLogRecordFactory() - def record_factory(*args, **kwargs): + def record_factory(*args, **kwargs) -> logging.LogRecord: # noqa: ANN002, ANN003 record = old_factory(*args, **kwargs) record.hostname = hostname record.worker_rank = worker_rank @@ -117,14 +117,14 @@ def record_factory(*args, **kwargs): logger.addHandler(SocketHandler(host=logger_hostname, port=logger_port)) -def redirect_stdio_to_logger(logger: Logger): +def redirect_stdio_to_logger(logger: Logger) -> None: class _LoggingStream(StringIO): - def __init__(self, logger: Logger, level: int = logging.NOTSET): + def __init__(self, logger: Logger, level: int = logging.NOTSET) -> None: super().__init__() self.logger = logger self.level = level - def flush(self): + def flush(self) -> None: super().flush() value = self.getvalue() if value != "": @@ -141,12 +141,12 @@ def flush(self): class LogRecordSocketReceiver(ThreadingTCPServer): - def __init__(self, host: str, port: int, handlers: list[Handler]): + def __init__(self, host: str, port: int, handlers: list[Handler]) -> None: self.host = host self.port = port class _LogRecordStreamHandler(StreamRequestHandler): - def handle(self): + def handle(self) -> None: while True: chunk = self.connection.recv(4) if len(chunk) < 4: @@ -168,7 +168,7 @@ def handle(self): ) self.daemon_threads = True - def shutdown(self): + def shutdown(self) -> None: """override BaseServer.shutdown() with added timeout""" self._BaseServer__shutdown_request = True self._BaseServer__is_shut_down.wait(timeout=3) # pyright: ignore[reportAttributeAccessIssue] diff --git a/tests/test_CI.py b/tests/test_CI.py index 472d0c5..bcd70bf 100644 --- a/tests/test_CI.py +++ b/tests/test_CI.py @@ -1,5 +1,6 @@ import os import tempfile +from typing import NoReturn import pytest import torch @@ -8,8 +9,8 @@ import torchrunx as trx -def test_simple_localhost(): - def dist_func(): +def test_simple_localhost() -> None: + def dist_func() -> torch.Tensor: rank = int(os.environ["RANK"]) if rank == 0: @@ -42,8 +43,8 @@ def dist_func(): assert torch.all(results[0] == results[1]) -def test_logging(): - def dist_func(): +def test_logging() -> None: + def dist_func() -> None: rank = int(os.environ["RANK"]) print(f"worker rank: {rank}") @@ -73,8 +74,8 @@ def dist_func(): assert "starting processes" in contents -def test_error(): - def error_func(): +def test_error() -> None: + def error_func() -> NoReturn: raise ValueError("abcdefg") tmp = tempfile.mkdtemp() diff --git a/tests/test_func.py b/tests/test_func.py index 444e783..8fb264b 100644 --- a/tests/test_func.py +++ b/tests/test_func.py @@ -6,7 +6,7 @@ import torchrunx as trx -def test_launch(): +def test_launch() -> None: result = trx.launch( func=simple_matmul, hostnames="slurm", @@ -22,7 +22,7 @@ def test_launch(): assert t, "Not all tensors equal" -def simple_matmul(): +def simple_matmul() -> torch.Tensor: rank = int(os.environ["RANK"]) local_rank = int(os.environ["LOCAL_RANK"]) device = torch.device(local_rank) if torch.cuda.is_available() else torch.device("cpu") diff --git a/tests/test_submitit.py b/tests/test_submitit.py index 290f7aa..225268d 100644 --- a/tests/test_submitit.py +++ b/tests/test_submitit.py @@ -1,3 +1,5 @@ +from 
__future__ import annotations + import copy import submitit @@ -9,22 +11,22 @@ class DummyDataset(Dataset): - def __init__(self, max_text_length=16, num_samples=20000) -> None: + def __init__(self, max_text_length: int = 16, num_samples: int = 20000) -> None: super().__init__() self.input_ids = torch.randint(0, 30522, (num_samples, max_text_length)) self.labels = copy.deepcopy(self.input_ids) - def __len__(self): + def __len__(self) -> int: return len(self.input_ids) - def __getitem__(self, index): + def __getitem__(self, index: int) -> dict[str, torch.Tensor]: return { "input_ids": self.input_ids[index], "labels": self.labels[index], } -def main(): +def main() -> None: model = BertForMaskedLM.from_pretrained("bert-base-uncased") train_dataset = DummyDataset() @@ -46,11 +48,11 @@ def main(): trainer.train() -def launch(): +def launch() -> None: trx.launch(func=main, func_kwargs={}, hostnames="slurm", workers_per_host="slurm") -def test_submitit(): +def test_submitit() -> None: executor = submitit.SlurmExecutor(folder="logs") executor.update_parameters( diff --git a/tests/test_train.py b/tests/test_train.py index d28f5ef..9f63728 100644 --- a/tests/test_train.py +++ b/tests/test_train.py @@ -3,23 +3,23 @@ import torchrunx as trx -def worker(): +def worker() -> None: import torch - class TwoLinLayerNet(torch.nn.Module): - def __init__(self): + class TwoLayerNN(torch.nn.Module): + def __init__(self) -> None: super().__init__() self.a = torch.nn.Linear(10, 10, bias=False) self.b = torch.nn.Linear(10, 1, bias=False) - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: a = self.a(x) - b = self.b(x) - return (a, b) + b = self.b(a) + return b local_rank = int(os.environ["LOCAL_RANK"]) print("init model") - model = TwoLinLayerNet().to(local_rank) + model = TwoLayerNN().to(local_rank) print("init ddp") ddp_model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[local_rank]) @@ -28,11 +28,11 @@ def forward(self, x): for _ in range(20): output = ddp_model(inp) - loss = output[0] + output[1] - loss.sum().backward() + loss = output.sum() + loss.backward() -def test_distributed_train(): +def test_distributed_train() -> None: trx.launch( worker, hostnames="slurm", From 52220fbb73e7a56ee6f1fe254de2c273d12edb1c Mon Sep 17 00:00:00 2001 From: apoorvkh Date: Wed, 11 Sep 2024 16:40:17 -0400 Subject: [PATCH 3/6] print traceback in agent; import as dist_mp --- src/torchrunx/agent.py | 31 ++++++++++--------------------- tests/test_CI.py | 5 +++-- 2 files changed, 13 insertions(+), 23 deletions(-) diff --git a/src/torchrunx/agent.py b/src/torchrunx/agent.py index 96aa383..4777c8f 100644 --- a/src/torchrunx/agent.py +++ b/src/torchrunx/agent.py @@ -6,13 +6,14 @@ import socket import sys import tempfile +import traceback from dataclasses import dataclass from typing import Any, Callable, Literal import cloudpickle import torch import torch.distributed as dist -from torch.distributed.elastic.multiprocessing import start_processes +import torch.distributed.elastic.multiprocessing as dist_mp from typing_extensions import Self from .logging_utils import log_records_to_socket, redirect_stdio_to_logger @@ -70,11 +71,7 @@ def entrypoint(serialized_worker_args: bytes) -> Any | WorkerException: is_master=(worker_args.rank == 0), ) - backend = worker_args.backend - if backend is None: - backend = "nccl" if torch.cuda.is_available() else "gloo" - - logger.debug(f"using backend: {backend}") + backend = worker_args.backend or ("nccl" if torch.cuda.is_available() else "gloo") 
dist.init_process_group( backend=backend, @@ -91,12 +88,10 @@ def entrypoint(serialized_worker_args: bytes) -> Any | WorkerException: os.environ["MASTER_ADDR"] = worker_args.main_agent_hostname os.environ["MASTER_PORT"] = str(worker_args.main_agent_port) - logger.debug(f"executing function: {worker_args.function}") - try: return worker_args.function() except Exception as e: - logger.error(e) + traceback.print_exc() return WorkerException(exception=e) finally: sys.stdout.flush() @@ -132,16 +127,9 @@ def main(launcher_agent_group: LauncherAgentGroup, logger_hostname: str, logger_ redirect_stdio_to_logger(logger) - if torch.__version__ >= "2.3": - from torch.distributed.elastic.multiprocessing import DefaultLogsSpecs - - log_kwargs = {"logs_specs": DefaultLogsSpecs(log_dir=tempfile.mkdtemp())} - else: - log_kwargs = {"log_dir": tempfile.mkdtemp()} - # spawn workers - ctx = start_processes( + ctx = dist_mp.start_processes( name=f"{hostname}_", entrypoint=entrypoint, args={ @@ -164,9 +152,12 @@ def main(launcher_agent_group: LauncherAgentGroup, logger_hostname: str, logger_ for i in range(num_workers) }, envs={i: {} for i in range(num_workers)}, - **log_kwargs, # pyright: ignore [reportArgumentType] + **( + {"logs_specs": dist_mp.DefaultLogsSpecs(log_dir=tempfile.mkdtemp())} + if torch.__version__ >= "2.3" + else {"log_dir": tempfile.mkdtemp()} + ), # pyright: ignore [reportArgumentType] ) - logger.info("starting processes") try: status = None @@ -182,8 +173,6 @@ def main(launcher_agent_group: LauncherAgentGroup, logger_hostname: str, logger_ break elif any(s.state == "failed" for s in agent_statuses): break - except: - raise finally: ctx.close() sys.stdout.flush() diff --git a/tests/test_CI.py b/tests/test_CI.py index bcd70bf..bc3e268 100644 --- a/tests/test_CI.py +++ b/tests/test_CI.py @@ -70,8 +70,9 @@ def dist_func() -> None: assert "worker rank: 0\n" in contents elif file.endswith("[1].log"): assert "worker rank: 1\n" in contents - else: - assert "starting processes" in contents + # TODO ? + # else: + # assert "starting processes" in contents def test_error() -> None: From bd52e7a6cef213bd860d56768dc1432e894b3940 Mon Sep 17 00:00:00 2001 From: apoorvkh Date: Thu, 12 Sep 2024 17:44:22 -0400 Subject: [PATCH 4/6] added more ruff lint rules --- .github/workflows/main.yml | 2 +- pixi.lock | 4 +- pyproject.toml | 22 +++++- src/torchrunx/agent.py | 10 ++- src/torchrunx/environment.py | 17 +++-- src/torchrunx/launcher.py | 118 +++++++++++++++++-------------- src/torchrunx/logging_utils.py | 24 ++++--- src/torchrunx/utils.py | 29 ++++---- tests/{test_CI.py => test_ci.py} | 22 +++--- tests/test_submitit.py | 2 +- tests/test_train.py | 8 +-- 11 files changed, 143 insertions(+), 115 deletions(-) rename tests/{test_CI.py => test_ci.py} (81%) diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index c172693..45a3b48 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -86,4 +86,4 @@ jobs: cache: false environments: default activate-environment: default - - run: pytest tests/test_CI.py + - run: pytest tests/test_ci.py diff --git a/pixi.lock b/pixi.lock index 377385d..e4fae5c 100644 --- a/pixi.lock +++ b/pixi.lock @@ -2601,9 +2601,9 @@ packages: requires_python: '>=3.8.0' - kind: pypi name: torchrunx - version: 0.1.3 + version: 0.2.0 path: . 
- sha256: 0a30b1182ca7c101ff1d147eba62de2ba883f822fdedd13fa49207c5484f6cd8 + sha256: 1753f43bee54bc0da38cdd524dc501c0c2be9fbaaa7036bced9c9d03a7a8e810 requires_dist: - cloudpickle>=3.0.0 - fabric>=3.0.0 diff --git a/pyproject.toml b/pyproject.toml index 693b5f1..33acfae 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "hatchling.build" [project] name = "torchrunx" -version = "0.1.3" +version = "0.2.0" authors = [ {name = "Apoorv Khandelwal", email = "mail@apoorvkh.com"}, {name = "Peter Curtin", email = "peter_curtin@brown.edu"}, @@ -41,8 +41,24 @@ include = ["pyproject.toml", "src/**/*.py", "tests/**/*.py"] line-length = 100 src = ["src", "tests"] [tool.ruff.lint] -select = ["E", "F", "W", "ANN", "B", "UP", "I"] -ignore = ["ANN101", "ANN102", "ANN401"] +select = ["ALL"] +ignore = [ + "D", # documentation + "ANN101", "ANN102", "ANN401", # self / cls / Any annotations + "BLE001", # blind exceptions + "TD", # todo syntax + "FIX002", # existing todos + "PLR0913", # too many arguments + "DTZ005", # datetime timezone + "S301", # bandit: pickle + "S603", "S607", # bandit: subprocess + "COM812", "ISC001", # conflict with formatter +] +[tool.ruff.lint.per-file-ignores] +"tests/**/*.py" = [ + "S101", # allow asserts + "T201" # allow prints +] [tool.pyright] include = ["src", "tests"] diff --git a/src/torchrunx/agent.py b/src/torchrunx/agent.py index 4777c8f..37b3cb4 100644 --- a/src/torchrunx/agent.py +++ b/src/torchrunx/agent.py @@ -163,15 +163,13 @@ def main(launcher_agent_group: LauncherAgentGroup, logger_hostname: str, logger_ status = None while True: if status is None or status.state == "running": - status = AgentStatus.from_result( - result=ctx.wait(5), worker_global_ranks=worker_global_ranks - ) + status = AgentStatus.from_result(ctx.wait(5)) agent_statuses = launcher_agent_group.sync_agent_statuses(status=status) - if all(s.state == "done" for s in agent_statuses): - break - elif any(s.state == "failed" for s in agent_statuses): + all_done = all(s.state == "done" for s in agent_statuses) + any_failed = any(s.state == "failed" for s in agent_statuses) + if all_done or any_failed: break finally: ctx.close() diff --git a/src/torchrunx/environment.py b/src/torchrunx/environment.py index edf1431..179cfb8 100644 --- a/src/torchrunx/environment.py +++ b/src/torchrunx/environment.py @@ -17,7 +17,9 @@ def slurm_hosts() -> list[str]: :rtype: list[str] """ # TODO: sanity check SLURM variables, commands - assert in_slurm_job() + if not in_slurm_job(): + msg = "Not in a SLURM job" + raise RuntimeError(msg) return ( subprocess.check_output(["scontrol", "show", "hostnames", os.environ["SLURM_JOB_NODELIST"]]) .decode() @@ -35,15 +37,18 @@ def slurm_workers() -> int: :rtype: int """ # TODO: sanity check SLURM variables, commands - assert in_slurm_job() + if not in_slurm_job(): + msg = "Not in a SLURM job" + raise RuntimeError(msg) + if "SLURM_JOB_GPUS" in os.environ: # TODO: is it possible to allocate uneven GPUs across nodes? return len(os.environ["SLURM_JOB_GPUS"].split(",")) - elif "SLURM_GPUS_PER_NODE" in os.environ: + if "SLURM_GPUS_PER_NODE" in os.environ: return int(os.environ["SLURM_GPUS_PER_NODE"]) - else: - # TODO: should we assume that we plan to do one worker per CPU? - return int(os.environ["SLURM_CPUS_ON_NODE"]) + + # TODO: should we assume that we plan to do one worker per CPU? 
+ return int(os.environ["SLURM_CPUS_ON_NODE"]) def auto_hosts() -> list[str]: diff --git a/src/torchrunx/launcher.py b/src/torchrunx/launcher.py index 58011ae..4c826a6 100644 --- a/src/torchrunx/launcher.py +++ b/src/torchrunx/launcher.py @@ -5,6 +5,7 @@ import itertools import logging import os +import shlex import socket import subprocess import sys @@ -31,13 +32,14 @@ def resolve_hostnames(hostnames: list[str] | Literal["auto", "slurm"]) -> list[str]: if hostnames == "auto": return auto_hosts() - elif hostnames == "slurm": + if hostnames == "slurm": return slurm_hosts() return hostnames def resolve_workers_per_host( - workers_per_host: int | list[int] | Literal["auto", "slurm"], num_hosts: int + workers_per_host: int | list[int] | Literal["auto", "slurm"], + num_hosts: int, ) -> list[int]: if workers_per_host == "auto": workers_per_host = auto_workers() @@ -46,8 +48,9 @@ def resolve_workers_per_host( if isinstance(workers_per_host, int): workers_per_host = [workers_per_host] * num_hosts - else: - assert len(workers_per_host) == num_hosts + elif len(workers_per_host) != num_hosts: + msg = "len(workers_per_host) != len(hostnames)" + raise ValueError(msg) return workers_per_host @@ -70,13 +73,53 @@ def build_logging_server( log_level=log_level, ) - log_receiver = LogRecordSocketReceiver( + return LogRecordSocketReceiver( host=launcher_hostname, port=get_open_port(), handlers=log_handlers, ) - return log_receiver + +def build_command( + launcher_hostname: str, + launcher_port: int, + logger_port: int, + world_size: int, + rank: int, + env_vars: Sequence[str], + env_file: str | os.PathLike | None, +) -> str: + # shlex.quote prevents shell injection here (resolves S602 in execute_command) + + commands = [] + + current_dir = shlex.quote(str(Path.cwd())) + commands.append("cd " + current_dir) + + env_exports = [] + for k, v in os.environ.items(): + if any(fnmatch.fnmatch(k, e) for e in env_vars): + env_exports.append(shlex.quote(f"{k}={v}")) + + if len(env_exports) > 0: + commands.append("export " + " ".join(env_exports)) + + if env_file is not None: + commands.append("source " + shlex.quote(str(env_file))) + + python = shlex.quote(sys.executable) + launcher_hostname = shlex.quote(launcher_hostname) + + commands.append( + f"{python} -u -m torchrunx " + f"--launcher-hostname {launcher_hostname} " + f"--launcher-port {launcher_port} " + f"--logger-port {logger_port} " + f"--world-size {world_size} " + f"--rank {rank}", + ) + + return " && ".join(commands) def is_localhost(hostname_or_ip: str) -> bool: @@ -99,51 +142,17 @@ def execute_command( ssh_config_file: str | os.PathLike | None = None, ) -> None: if is_localhost(hostname): - subprocess.Popen(command, shell=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL) + # S602: subprocess.Popen is called with shell=True (https://docs.python.org/3.8/library/subprocess.html#security-considerations) + # Made sure to shlex.quote arguments in build_command to prevent shell injection + subprocess.Popen(command, shell=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL) # noqa: S602 else: with fabric.Connection( - host=hostname, config=fabric.Config(runtime_ssh_path=ssh_config_file) + host=hostname, + config=fabric.Config(runtime_ssh_path=ssh_config_file), ) as conn: conn.run(f"{command} >> /dev/null 2>&1 &", asynchronous=True) -def build_command( - launcher_hostname: str, - launcher_port: int, - logger_port: int, - world_size: int, - rank: int, - env_vars: Sequence[str], - env_file: str | os.PathLike | None, -) -> str: - current_dir = 
os.getcwd() - - env_exports = [] - for k, v in os.environ.items(): - if any(fnmatch.fnmatch(k, e) for e in env_vars): - env_exports.append(f"{k}={v}") - - env_export_string = "" - if len(env_exports) > 0: - env_export_string = f"export {' '.join(env_exports)} && " - - env_file_string = "" - if env_file is not None: - env_file_string = f"source {env_file} && " - - return ( - f"cd {current_dir} && " - f"{env_export_string}" - f"{env_file_string}" - f"{sys.executable} -u -m torchrunx " - f"--launcher-hostname {launcher_hostname} " - f"--launcher-port {launcher_port} " - f"--logger-port {logger_port} " - f"--world-size {world_size} " - f"--rank {rank}" - ) - - @dataclass class Launcher: hostnames: list[str] | Literal["auto", "slurm"] = "auto" @@ -184,7 +193,8 @@ def run( :rtype: dict[int, Any] """ if not dist.is_available(): - raise RuntimeError("The torch.distributed package is not available.") + msg = "The torch.distributed package is not available." + raise RuntimeError(msg) hostnames = resolve_hostnames(self.hostnames) workers_per_host = resolve_workers_per_host(self.workers_per_host, len(hostnames)) @@ -201,7 +211,7 @@ def run( hostnames=hostnames, workers_per_host=workers_per_host, log_dir=Path(os.environ.get("TORCHRUNX_LOG_DIR", "torchrunx_logs")), - log_level=logging._nameToLevel[os.environ.get("TORCHRUNX_LOG_LEVEL", "INFO")], + log_level=logging._nameToLevel[os.environ.get("TORCHRUNX_LOG_LEVEL", "INFO")], # noqa: SLF001 ) log_process = Process( @@ -228,7 +238,7 @@ def run( ssh_config_file=self.ssh_config_file, ) - # initialize launcher–agent process group + # initialize launcher-agent process group # ranks = (launcher, agent_{hostnames[0]}, ..., agent[-1]) launcher_agent_group = LauncherAgentGroup( @@ -240,7 +250,7 @@ def run( # build and sync payloads between launcher and agents - _cumulative_workers = [0] + list(itertools.accumulate(workers_per_host)) + _cumulative_workers = [0, *itertools.accumulate(workers_per_host)] worker_global_ranks = [ list(range(_cumulative_workers[n], _cumulative_workers[n + 1])) @@ -262,14 +272,14 @@ def run( try: while True: - agent_statuses = launcher_agent_group.sync_agent_statuses(status=None) # raises exception if communication timeout due to death of any agent + agent_statuses = launcher_agent_group.sync_agent_statuses(status=None) + # raises exception if any agent failed for s in agent_statuses: - if s.state == "failed": - for value in s.return_values.values(): - if isinstance(value, WorkerException): - raise value.exception + for value in s.return_values.values(): + if isinstance(value, WorkerException): + raise value.exception if all(s.state == "done" for s in agent_statuses): break diff --git a/src/torchrunx/logging_utils.py b/src/torchrunx/logging_utils.py index 20a0051..36ec67b 100644 --- a/src/torchrunx/logging_utils.py +++ b/src/torchrunx/logging_utils.py @@ -2,7 +2,7 @@ import datetime import logging -import os +import os # noqa: TCH003 import pickle import struct from contextlib import redirect_stderr, redirect_stdout @@ -52,11 +52,11 @@ def file_handlers( ) -> list[Handler]: handlers = [] - os.makedirs(log_dir, exist_ok=True) + Path(log_dir).mkdir(parents=True, exist_ok=True) timestamp = datetime.datetime.now().isoformat(timespec="seconds") for hostname, num_workers in zip(hostnames, workers_per_host): - for rank in [None] + list(range(num_workers)): + for rank in [None, *range(num_workers)]: file_path = ( f"{log_dir}/{timestamp}-{hostname}" + (f"[{rank}]" if rank is not None else "") @@ -74,8 +74,8 @@ def stream_handler(hostname: str, 
rank: int | None, log_level: int = logging.NOT logging.Formatter( "%(asctime)s:%(levelname)s:%(hostname)s[%(worker_rank)s]: %(message)s" if rank is not None - else "%(asctime)s:%(levelname)s:%(hostname)s: %(message)s" - ) + else "%(asctime)s:%(levelname)s:%(hostname)s: %(message)s", + ), ) return handler @@ -89,7 +89,8 @@ def default_handlers( return [ stream_handler(hostname=hostnames[0], rank=None, log_level=log_level), stream_handler(hostname=hostnames[0], rank=0, log_level=log_level), - ] + file_handlers(hostnames, workers_per_host, log_dir=log_dir, log_level=log_level) + *file_handlers(hostnames, workers_per_host, log_dir=log_dir, log_level=log_level), + ] ## Agent/worker utilities @@ -128,11 +129,11 @@ def flush(self) -> None: super().flush() value = self.getvalue() if value != "": - self.logger.log(self.level, f"\n{value}") + self.logger.log(self.level, value) self.truncate(0) self.seek(0) - logging.captureWarnings(True) + logging.captureWarnings(capture=True) redirect_stderr(_LoggingStream(logger, level=logging.ERROR)).__enter__() redirect_stdout(_LoggingStream(logger, level=logging.INFO)).__enter__() @@ -148,8 +149,9 @@ def __init__(self, host: str, port: int, handlers: list[Handler]) -> None: class _LogRecordStreamHandler(StreamRequestHandler): def handle(self) -> None: while True: - chunk = self.connection.recv(4) - if len(chunk) < 4: + chunk_size = 4 + chunk = self.connection.recv(chunk_size) + if len(chunk) < chunk_size: break slen = struct.unpack(">L", chunk)[0] chunk = self.connection.recv(slen) @@ -157,7 +159,7 @@ def handle(self) -> None: chunk = chunk + self.connection.recv(slen - len(chunk)) obj = pickle.loads(chunk) record = logging.makeLogRecord(obj) - # + for handler in handlers: handler.handle(record) diff --git a/src/torchrunx/utils.py b/src/torchrunx/utils.py index 82274e6..c2559fe 100644 --- a/src/torchrunx/utils.py +++ b/src/torchrunx/utils.py @@ -4,19 +4,20 @@ import socket from contextlib import closing from dataclasses import dataclass, field -from typing import Any, Callable, Literal +from typing import TYPE_CHECKING, Any, Callable, Literal import cloudpickle import torch.distributed as dist -from torch.distributed.elastic.multiprocessing.api import RunProcsResult from typing_extensions import Self +if TYPE_CHECKING: + from torch.distributed.elastic.multiprocessing.api import RunProcsResult + def get_open_port() -> int: with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s: s.bind(("", 0)) - port = s.getsockname()[1] - return port + return s.getsockname()[1] @dataclass @@ -47,7 +48,7 @@ class AgentStatus: return_values: dict[int, Any | WorkerException] = field(default_factory=dict) @classmethod - def from_result(cls, result: RunProcsResult | None, worker_global_ranks: list[int]) -> Self: + def from_result(cls, result: RunProcsResult | None) -> Self: if result is None: return cls(state="running") @@ -85,27 +86,27 @@ def __post_init__(self) -> None: timeout=datetime.timedelta(seconds=30), ) - def _serialize(self, object: Any) -> bytes: - return cloudpickle.dumps(object) + def _serialize(self, obj: Any) -> bytes: + return cloudpickle.dumps(obj) def _deserialize(self, serialized: bytes) -> Any: return cloudpickle.loads(serialized) - def _all_gather(self, object: Any) -> list: + def _all_gather(self, obj: Any) -> list: """gather object from every rank to list on every rank""" - object_bytes = self._serialize(object) + object_bytes = self._serialize(obj) object_list = [b""] * self.world_size dist.all_gather_object(object_list=object_list, 
obj=object_bytes, group=self.group) - object_list = [self._deserialize(o) for o in object_list] - return object_list + return [self._deserialize(o) for o in object_list] def sync_payloads( - self, payload: LauncherPayload | AgentPayload + self, + payload: LauncherPayload | AgentPayload, ) -> tuple[LauncherPayload, list[AgentPayload]]: - payloads = self._all_gather(object=payload) + payloads = self._all_gather(payload) launcher_payload = payloads[0] agent_payloads = payloads[1:] return launcher_payload, agent_payloads def sync_agent_statuses(self, status: AgentStatus | None) -> list[AgentStatus]: - return self._all_gather(object=status)[1:] # [0] is launcher (status=None) + return self._all_gather(status)[1:] # [0] is launcher (status=None) diff --git a/tests/test_CI.py b/tests/test_ci.py similarity index 81% rename from tests/test_CI.py rename to tests/test_ci.py index bc3e268..f72f3ef 100644 --- a/tests/test_CI.py +++ b/tests/test_ci.py @@ -1,5 +1,6 @@ import os import tempfile +from pathlib import Path from typing import NoReturn import pytest @@ -13,10 +14,7 @@ def test_simple_localhost() -> None: def dist_func() -> torch.Tensor: rank = int(os.environ["RANK"]) - if rank == 0: - w = torch.rand((100, 100)) # in_dim, out_dim - else: - w = torch.zeros((100, 100)) + w = torch.rand((100, 100)) if rank == 0 else torch.zeros((100, 100)) dist.broadcast(w, 0) @@ -51,38 +49,38 @@ def dist_func() -> None: tmp = tempfile.mkdtemp() os.environ["TORCHRUNX_LOG_DIR"] = tmp + num_workers = 2 + trx.launch( func=dist_func, func_kwargs={}, - workers_per_host=2, + workers_per_host=num_workers, backend="gloo", ) log_files = next(os.walk(tmp), (None, None, []))[2] - assert len(log_files) == 3 + assert len(log_files) == num_workers + 1 for file in log_files: - with open(f"{tmp}/{file}") as f: + with Path(f"{tmp}/{file}").open() as f: contents = f.read() print(contents) if file.endswith("[0].log"): assert "worker rank: 0\n" in contents elif file.endswith("[1].log"): assert "worker rank: 1\n" in contents - # TODO ? 
- # else: - # assert "starting processes" in contents def test_error() -> None: def error_func() -> NoReturn: - raise ValueError("abcdefg") + msg = "abcdefg" + raise ValueError(msg) tmp = tempfile.mkdtemp() os.environ["TORCHRUNX_DIR"] = tmp - with pytest.raises(ValueError) as excinfo: + with pytest.raises(ValueError) as excinfo: # noqa: PT011 trx.launch( func=error_func, func_kwargs={}, diff --git a/tests/test_submitit.py b/tests/test_submitit.py index 225268d..433e338 100644 --- a/tests/test_submitit.py +++ b/tests/test_submitit.py @@ -40,7 +40,7 @@ def main() -> None: ) trainer = Trainer( - model=model, # type: ignore + model=model, args=training_arguments, train_dataset=train_dataset, ) diff --git a/tests/test_train.py b/tests/test_train.py index 9f63728..b654a8b 100644 --- a/tests/test_train.py +++ b/tests/test_train.py @@ -6,20 +6,18 @@ def worker() -> None: import torch - class TwoLayerNN(torch.nn.Module): + class MLP(torch.nn.Module): def __init__(self) -> None: super().__init__() self.a = torch.nn.Linear(10, 10, bias=False) self.b = torch.nn.Linear(10, 1, bias=False) def forward(self, x: torch.Tensor) -> torch.Tensor: - a = self.a(x) - b = self.b(a) - return b + return self.b(self.a(x)) local_rank = int(os.environ["LOCAL_RANK"]) print("init model") - model = TwoLayerNN().to(local_rank) + model = MLP().to(local_rank) print("init ddp") ddp_model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[local_rank]) From e4ae2200865cb1fd33fb7f4c45a876051d6db28b Mon Sep 17 00:00:00 2001 From: apoorvkh Date: Thu, 12 Sep 2024 18:04:39 -0400 Subject: [PATCH 5/6] refactoring worker args serialization --- src/torchrunx/agent.py | 21 ++++++++++++--------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/src/torchrunx/agent.py b/src/torchrunx/agent.py index 37b3cb4..030316a 100644 --- a/src/torchrunx/agent.py +++ b/src/torchrunx/agent.py @@ -14,7 +14,6 @@ import torch import torch.distributed as dist import torch.distributed.elastic.multiprocessing as dist_mp -from typing_extensions import Self from .logging_utils import log_records_to_socket, redirect_stdio_to_logger from .utils import ( @@ -41,16 +40,20 @@ class WorkerArgs: hostname: str timeout: int - def to_bytes(self) -> bytes: - return cloudpickle.dumps(self) + def serialize(self) -> SerializedWorkerArgs: + return SerializedWorkerArgs(worker_args=self) - @classmethod - def from_bytes(cls, serialized: bytes) -> Self: - return cloudpickle.loads(serialized) +class SerializedWorkerArgs: + def __init__(self, worker_args: WorkerArgs) -> None: + self.bytes = cloudpickle.dumps(worker_args) -def entrypoint(serialized_worker_args: bytes) -> Any | WorkerException: - worker_args = WorkerArgs.from_bytes(serialized_worker_args) + def deserialize(self) -> WorkerArgs: + return cloudpickle.loads(self.bytes) + + +def entrypoint(serialized_worker_args: SerializedWorkerArgs) -> Any | WorkerException: + worker_args: WorkerArgs = serialized_worker_args.deserialize() logger = logging.getLogger() @@ -147,7 +150,7 @@ def main(launcher_agent_group: LauncherAgentGroup, logger_hostname: str, logger_ world_size=worker_world_size, hostname=launcher_payload.hostnames[agent_rank], timeout=launcher_payload.timeout, - ).to_bytes(), + ).serialize(), ) for i in range(num_workers) }, From 9c97d09fb362a3f15ac8c7e563d5ea6ca29993de Mon Sep 17 00:00:00 2001 From: apoorvkh Date: Thu, 12 Sep 2024 18:38:31 -0400 Subject: [PATCH 6/6] WorkerLogRecord class --- src/torchrunx/__init__.py | 4 + src/torchrunx/agent.py | 2 +- src/torchrunx/logging_utils.py | 219 
++++++++++++++++++--------------- src/torchrunx/utils.py | 2 +- 4 files changed, 124 insertions(+), 103 deletions(-) diff --git a/src/torchrunx/__init__.py b/src/torchrunx/__init__.py index 46b3b1b..74214cb 100644 --- a/src/torchrunx/__init__.py +++ b/src/torchrunx/__init__.py @@ -1,6 +1,10 @@ from .launcher import Launcher, launch +from .logging_utils import add_filter_to_handler, file_handler, stream_handler __all__ = [ "Launcher", "launch", + "add_filter_to_handler", + "file_handler", + "stream_handler", ] diff --git a/src/torchrunx/agent.py b/src/torchrunx/agent.py index 030316a..04d1ec9 100644 --- a/src/torchrunx/agent.py +++ b/src/torchrunx/agent.py @@ -67,7 +67,7 @@ def entrypoint(serialized_worker_args: SerializedWorkerArgs) -> Any | WorkerExce redirect_stdio_to_logger(logger) - store = dist.TCPStore( # pyright: ignore[reportPrivateImportUsage] + store = dist.TCPStore( # pyright: ignore [reportPrivateImportUsage] host_name=worker_args.main_agent_hostname, port=worker_args.main_agent_port, world_size=worker_args.world_size, diff --git a/src/torchrunx/logging_utils.py b/src/torchrunx/logging_utils.py index 36ec67b..d12b27f 100644 --- a/src/torchrunx/logging_utils.py +++ b/src/torchrunx/logging_utils.py @@ -2,15 +2,115 @@ import datetime import logging -import os # noqa: TCH003 import pickle import struct from contextlib import redirect_stderr, redirect_stdout +from dataclasses import dataclass from io import StringIO from logging import Handler, Logger from logging.handlers import SocketHandler from pathlib import Path from socketserver import StreamRequestHandler, ThreadingTCPServer +from typing import TYPE_CHECKING + +from typing_extensions import Self + +if TYPE_CHECKING: + import os + +## Launcher utilities + + +class LogRecordSocketReceiver(ThreadingTCPServer): + def __init__(self, host: str, port: int, handlers: list[Handler]) -> None: + self.host = host + self.port = port + + class _LogRecordStreamHandler(StreamRequestHandler): + def handle(self) -> None: + while True: + chunk_size = 4 + chunk = self.connection.recv(chunk_size) + if len(chunk) < chunk_size: + break + slen = struct.unpack(">L", chunk)[0] + chunk = self.connection.recv(slen) + while len(chunk) < slen: + chunk = chunk + self.connection.recv(slen - len(chunk)) + obj = pickle.loads(chunk) + record = logging.makeLogRecord(obj) + + for handler in handlers: + handler.handle(record) + + super().__init__( + server_address=(host, port), + RequestHandlerClass=_LogRecordStreamHandler, + bind_and_activate=True, + ) + self.daemon_threads = True + + def shutdown(self) -> None: + """override BaseServer.shutdown() with added timeout""" + self._BaseServer__shutdown_request = True + self._BaseServer__is_shut_down.wait(timeout=3) # pyright: ignore[reportAttributeAccessIssue] + + +## Agent/worker utilities + + +@dataclass +class WorkerLogRecord(logging.LogRecord): + hostname: str + worker_rank: int | None + + @classmethod + def from_record(cls, record: logging.LogRecord, hostname: str, worker_rank: int | None) -> Self: + record.hostname = hostname + record.worker_rank = worker_rank + record.__class__ = cls + return record # pyright: ignore [reportReturnType] + + +def log_records_to_socket( + logger: Logger, + hostname: str, + worker_rank: int | None, + logger_hostname: str, + logger_port: int, +) -> None: + logger.setLevel(logging.NOTSET) + + old_factory = logging.getLogRecordFactory() + + def record_factory(*args, **kwargs) -> WorkerLogRecord: # noqa: ANN002, ANN003 + record = old_factory(*args, **kwargs) + return 
WorkerLogRecord.from_record(record, hostname, worker_rank) + + logging.setLogRecordFactory(record_factory) + + logger.addHandler(SocketHandler(host=logger_hostname, port=logger_port)) + + +def redirect_stdio_to_logger(logger: Logger) -> None: + class _LoggingStream(StringIO): + def __init__(self, logger: Logger, level: int = logging.NOTSET) -> None: + super().__init__() + self.logger = logger + self.level = level + + def flush(self) -> None: + super().flush() + value = self.getvalue() + if value != "": + self.logger.log(self.level, value) + self.truncate(0) + self.seek(0) + + logging.captureWarnings(capture=True) + redirect_stderr(_LoggingStream(logger, level=logging.ERROR)).__enter__() + redirect_stdout(_LoggingStream(logger, level=logging.INFO)).__enter__() + ## Handler utilities @@ -21,14 +121,27 @@ def add_filter_to_handler( worker_rank: int | None, log_level: int = logging.NOTSET, ) -> None: - def _filter(record: logging.LogRecord) -> bool: + def _filter(record: WorkerLogRecord) -> bool: return ( - record.hostname == hostname # pyright: ignore[reportAttributeAccessIssue] - and record.worker_rank == worker_rank # pyright: ignore[reportAttributeAccessIssue] + record.hostname == hostname + and record.worker_rank == worker_rank and record.levelno >= log_level ) - handler.addFilter(_filter) + handler.addFilter(_filter) # pyright: ignore [reportArgumentType] + + +def stream_handler(hostname: str, rank: int | None, log_level: int = logging.NOTSET) -> Handler: + handler = logging.StreamHandler() + add_filter_to_handler(handler, hostname, rank, log_level=log_level) + handler.setFormatter( + logging.Formatter( + "%(asctime)s:%(levelname)s:%(hostname)s[%(worker_rank)s]: %(message)s" + if rank is not None + else "%(asctime)s:%(levelname)s:%(hostname)s: %(message)s", + ), + ) + return handler def file_handler( @@ -67,19 +180,6 @@ def file_handlers( return handlers -def stream_handler(hostname: str, rank: int | None, log_level: int = logging.NOTSET) -> Handler: - handler = logging.StreamHandler() - add_filter_to_handler(handler, hostname, rank, log_level=log_level) - handler.setFormatter( - logging.Formatter( - "%(asctime)s:%(levelname)s:%(hostname)s[%(worker_rank)s]: %(message)s" - if rank is not None - else "%(asctime)s:%(levelname)s:%(hostname)s: %(message)s", - ), - ) - return handler - - def default_handlers( hostnames: list[str], workers_per_host: list[int], @@ -91,86 +191,3 @@ def default_handlers( stream_handler(hostname=hostnames[0], rank=0, log_level=log_level), *file_handlers(hostnames, workers_per_host, log_dir=log_dir, log_level=log_level), ] - - -## Agent/worker utilities - - -def log_records_to_socket( - logger: Logger, - hostname: str, - worker_rank: int | None, - logger_hostname: str, - logger_port: int, -) -> None: - logger.setLevel(logging.NOTSET) - - old_factory = logging.getLogRecordFactory() - - def record_factory(*args, **kwargs) -> logging.LogRecord: # noqa: ANN002, ANN003 - record = old_factory(*args, **kwargs) - record.hostname = hostname - record.worker_rank = worker_rank - return record - - logging.setLogRecordFactory(record_factory) - - logger.addHandler(SocketHandler(host=logger_hostname, port=logger_port)) - - -def redirect_stdio_to_logger(logger: Logger) -> None: - class _LoggingStream(StringIO): - def __init__(self, logger: Logger, level: int = logging.NOTSET) -> None: - super().__init__() - self.logger = logger - self.level = level - - def flush(self) -> None: - super().flush() - value = self.getvalue() - if value != "": - self.logger.log(self.level, value) - 
self.truncate(0) - self.seek(0) - - logging.captureWarnings(capture=True) - redirect_stderr(_LoggingStream(logger, level=logging.ERROR)).__enter__() - redirect_stdout(_LoggingStream(logger, level=logging.INFO)).__enter__() - - -## Launcher utilities - - -class LogRecordSocketReceiver(ThreadingTCPServer): - def __init__(self, host: str, port: int, handlers: list[Handler]) -> None: - self.host = host - self.port = port - - class _LogRecordStreamHandler(StreamRequestHandler): - def handle(self) -> None: - while True: - chunk_size = 4 - chunk = self.connection.recv(chunk_size) - if len(chunk) < chunk_size: - break - slen = struct.unpack(">L", chunk)[0] - chunk = self.connection.recv(slen) - while len(chunk) < slen: - chunk = chunk + self.connection.recv(slen - len(chunk)) - obj = pickle.loads(chunk) - record = logging.makeLogRecord(obj) - - for handler in handlers: - handler.handle(record) - - super().__init__( - server_address=(host, port), - RequestHandlerClass=_LogRecordStreamHandler, - bind_and_activate=True, - ) - self.daemon_threads = True - - def shutdown(self) -> None: - """override BaseServer.shutdown() with added timeout""" - self._BaseServer__shutdown_request = True - self._BaseServer__is_shut_down.wait(timeout=3) # pyright: ignore[reportAttributeAccessIssue] diff --git a/src/torchrunx/utils.py b/src/torchrunx/utils.py index c2559fe..1bd25c5 100644 --- a/src/torchrunx/utils.py +++ b/src/torchrunx/utils.py @@ -77,7 +77,7 @@ def __post_init__(self) -> None: backend="gloo", world_size=self.world_size, rank=self.rank, - store=dist.TCPStore( # pyright: ignore[reportPrivateImportUsage] + store=dist.TCPStore( # pyright: ignore [reportPrivateImportUsage] host_name=self.launcher_hostname, port=self.launcher_port, world_size=self.world_size,
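
A closing note on the logging refactor in this last patch: `stream_handler`, `file_handler`, and `add_filter_to_handler` are now exported from the package root, and records carry `hostname` / `worker_rank` via `WorkerLogRecord`, so callers can assemble their own handler list instead of relying on `"auto"`. A rough sketch follows; it assumes `Launcher` exposes `log_handlers` (along with `workers_per_host` and `backend`) as dataclass fields, which the use of `self.log_handlers` in `run()` suggests but the hunks above do not show in full, and the hostname, file path, and worker function are placeholders:

    import logging

    import torchrunx as trx

    def train() -> None:
        print("hello from a worker")  # stdout is redirected to the logging server

    host = "localhost"  # placeholder hostname

    # console output restricted to worker 0 of this host
    console = trx.stream_handler(hostname=host, rank=0, log_level=logging.INFO)

    # a plain logging.FileHandler, filtered (as stream_handler does internally)
    # to agent-level records (rank None) from the same host
    agent_log = logging.FileHandler("agent.log")  # placeholder path
    trx.add_filter_to_handler(agent_log, host, None, log_level=logging.DEBUG)

    results = trx.Launcher(
        workers_per_host=2,
        backend="gloo",
        log_handlers=[console, agent_log],  # assumed constructor field, see note above
    ).run(func=train)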