DDP Enhancements (mosaicml#63)
Added a timeout to the hparams for initialize_process_group. The default of 30 minutes was too long for failing tests and prevented getting any meaningful log output.
In cleanup(), use os.killpg to terminate DDP subprocesses instead of just subprocess.kill(); otherwise zombie processes (e.g. from DDP / dataloader workers) were sometimes left behind.
In cleanup(), attempt a SIGTERM before resorting to a SIGKILL 5 seconds later, since graceful cleanup is preferred (a minimal sketch of this pattern appears below, just before the diff).
Direct subprocess stdout and stderr to tempfiles instead of subprocess.PIPE, which can hang if a subprocess generates significant output.
ravi-mosaicml authored and coryMosaicML committed Feb 23, 2022
1 parent 26920d4 commit 016c2f5
Showing 2 changed files with 77 additions and 32 deletions.
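
The cleanup changes combine three techniques: terminating each worker's whole process group with os.killpg, giving a SIGTERM grace period before SIGKILL, and backing subprocess output with tempfiles rather than subprocess.PIPE. The following is a minimal, hedged sketch of that pattern, not the commit's code: it assumes each worker is launched as its own process-group leader (start_new_session=True), and run_worker, cleanup, and the exact grace period are illustrative.

```python
import os
import signal
import subprocess
import tempfile
import time
from typing import List

GRACE_SECONDS = 5.0  # illustrative; the commit uses a 5-second CLEANUP_TIMEOUT


def run_worker(cmd: List[str]) -> subprocess.Popen:
    # start_new_session=True makes the child a process-group leader, so
    # os.killpg(pid, ...) below also reaches any workers it forks.
    return subprocess.Popen(
        cmd,
        stdout=tempfile.TemporaryFile(),  # tempfiles never back-pressure the child
        stderr=tempfile.TemporaryFile(),
        start_new_session=True,
    )


def cleanup(processes: List[subprocess.Popen]) -> None:
    # First pass: ask every still-running worker to exit gracefully.
    for p in processes:
        if p.poll() is None:
            try:
                os.killpg(p.pid, signal.SIGTERM)
            except ProcessLookupError:
                pass  # already gone
    # Wait up to the grace period for everything to finish.
    deadline = time.monotonic() + GRACE_SECONDS
    while time.monotonic() < deadline and any(p.poll() is None for p in processes):
        time.sleep(0.1)
    # Second pass: force-kill whatever survived the grace period.
    for p in processes:
        if p.poll() is None:
            try:
                os.killpg(p.pid, signal.SIGKILL)
            except ProcessLookupError:
                pass
```

Redirecting output to TemporaryFile keeps a chatty subprocess from filling a pipe buffer and deadlocking, while signaling the whole group ensures dataloader workers spawned by the child are cleaned up too.
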
100 changes: 70 additions & 30 deletions composer/trainer/ddp.py
@@ -3,17 +3,20 @@
from __future__ import annotations

import collections.abc
import datetime
import logging
import os
import signal
import subprocess
import sys
import tempfile
import time
import warnings
from abc import ABC, abstractmethod
from contextlib import contextmanager, nullcontext
from dataclasses import dataclass
from threading import Thread
from typing import Callable, ContextManager, Iterator, List, Optional, Sequence, TypeVar, cast
from typing import Callable, ContextManager, Iterator, List, Optional, Sequence, Set, TypeVar, Union, cast

import torch
import torch.distributed
@@ -32,6 +35,8 @@

TObj = TypeVar("TObj")

CLEANUP_TIMEOUT = datetime.timedelta(seconds=5)


class DataloaderMultipleIterationWarning(Warning):
pass
@@ -114,17 +119,20 @@ def __init__(self,
num_nodes: int,
backend: str,
fork_rank_0: bool,
timeout: float = 5,
find_unused_parameters: bool = False,
ddp_sync_strategy: Optional[str] = None):
ddp_sync_strategy: Optional[Union[str, DDPSyncStrategy]] = None):
self.hparams = DDPHparams(
store=store_hparams,
node_rank=node_rank,
num_nodes=num_nodes,
fork_rank_0=fork_rank_0,
timeout=timeout,
)
self.nproc_per_node = nproc_per_node
self.world_size = num_nodes * nproc_per_node
self.num_nodes = num_nodes
self.node_rank = node_rank
self.store_hparams = store_hparams
self.last_return_code: Optional[int] = None
self.backend = backend
self.fork_rank_0 = fork_rank_0
self.processes: List[subprocess.Popen[str]] = []
self.killed_pids: Set[int] = set() # track which pids have been killed
self.find_unused_parameters = find_unused_parameters
if ddp_sync_strategy is None:
self.ddp_sync_strategy = DDPSyncStrategy.SINGLE_AUTO_SYNC if not find_unused_parameters else DDPSyncStrategy.FORCED_SYNC
@@ -140,6 +148,10 @@ def __init__(self,
if not torch.distributed.is_nccl_available():
raise ValueError('Requested NCCL backend not available in torch.distributed')

@property
def world_size(self) -> int:
return self.hparams.num_nodes * self.nproc_per_node

def barrier(self) -> None:
if torch.distributed.is_available():
torch.distributed.barrier()
@@ -194,7 +206,7 @@ def all_gather_object(self, obj: TObj) -> List[TObj]:
def launch(self, state: State, loop: Callable[[], None]):
if os.environ.get("RANK") is None:
os.environ["WORLD_SIZE"] = str(self.world_size)
logger.info("Starting DDP on node_rank(%d) with world_size(%d)", self.node_rank, self.world_size)
logger.info("Starting DDP on node_rank(%d) with world_size(%d)", self.hparams.node_rank, self.world_size)

if torch.distributed.is_available():
# Adapted from torch.distributed.launch
@@ -205,14 +217,14 @@ def launch(self, state: State, loop: Callable[[], None]):
# TODO omp num threads -- this parameter needs to be auto-tuned
for local_rank in range(self.nproc_per_node):
# each process's rank
global_rank = self.nproc_per_node * self.node_rank + local_rank
global_rank = self.nproc_per_node * self.hparams.node_rank + local_rank
current_env["RANK"] = str(global_rank)

if local_rank == 0 and not self.fork_rank_0:
if local_rank == 0 and not self.hparams.fork_rank_0:
os.environ["RANK"] = str(global_rank)
else:
logger.info("Launching process for global_rank(%d) on node_rank(%d)", global_rank,
self.node_rank)
self.hparams.node_rank)
# spawn the processes
cmd = [
sys.executable,
@@ -229,35 +241,36 @@
process = subprocess.Popen(
cmd,
env=current_env,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
stdout=tempfile.TemporaryFile(),
stderr=tempfile.TemporaryFile(),
text=True,
)
self.processes.append(process)
if self.fork_rank_0:
if self.hparams.fork_rank_0:
self.monitor()
return
else:
Thread(target=self.monitor, daemon=True).start()
else:
if self.world_size != 1:
raise ValueError("Must have world size == 1 when torch.distributed is not available")
if self.node_rank != 0:
if self.hparams.node_rank != 0:
raise ValueError("Must have a node_rank == 0 when torch.distributed is not available")
os.environ["RANK"] = "0"
# We are now on the correct process
global_rank = int(os.environ["RANK"])
assert global_rank // self.world_size == self.node_rank
assert global_rank // self.world_size == self.hparams.node_rank
assert os.environ["WORLD_SIZE"] == str(
self.world_size
), f"os.environ['WORLD_SIZE']({os.environ['WORLD_SIZE']}) != self.world_size({self.world_size})"
is_main = global_rank == 0
if torch.distributed.is_available():
logger.info("Initializing ddp: GLOBAL_RANK: %s, WORLD_SIZE: %s", global_rank, self.world_size)
store = self.store_hparams.initialize_object(is_main, state.world_size)
store = self.hparams.store.initialize_object(is_main, state.world_size)
torch.distributed.init_process_group(self.backend,
rank=global_rank,
world_size=self.world_size,
timeout=datetime.timedelta(seconds=self.hparams.timeout),
store=store)
assert torch.distributed.is_initialized()
assert state.is_rank_set, "state.is_rank_set should be set after torch.distributed is initialized"
@@ -308,7 +321,10 @@ def monitor(self) -> None:
# return code of 0 implies clean exit
# return code of -9 implies sigkill, presumably from
# cleanup() in the main process
if process.returncode not in (0, -9):
if process.pid in self.killed_pids or process.returncode == 0:
# exited cleanly
finished_processes.append(process)
else:
if process.stdout is None:
output = ""
else:
@@ -324,24 +340,46 @@
output=output,
stderr=stderr,
)
if self.fork_rank_0:
if self.hparams.fork_rank_0:
raise exc
else:
logger.exception("Error in subprocess", exc_info=exc)
sys.exit(1)
else:
# exited cleanly
finished_processes.append(process)
error_msg = [
"Error in subprocess",
"----------Subprocess STDOUT----------",
exc.output,
"----------Subprocess STDERR----------",
exc.stderr,
]
logger.exception("\n".join(error_msg), exc_info=exc)
sys.exit(process.returncode)
alive_processes = set(alive_processes) - set(finished_processes)
time.sleep(1)

def cleanup(self) -> None:
for process in self.processes:
logger.info("Killing subprocess %s", process.pid)
try:
process.kill()
except Exception:
pass
if process.returncode is None:
logger.info("Killing subprocess %s with SIGTERM", process.pid)
self.killed_pids.add(process.pid)
try:
os.killpg(process.pid, signal.SIGTERM)
except ProcessLookupError:
pass
current_time = datetime.datetime.now()
while datetime.datetime.now() - current_time < CLEANUP_TIMEOUT:
all_finished = all(p.poll() is not None for p in self.processes)
if all_finished:
break
time.sleep(0.1)

for process in self.processes:
if process.returncode is None:
logger.error("Failed to kill subprocess %s with SIGTERM, using SIGKILL instead", process.pid)
self.killed_pids.add(process.pid)
try:
os.killpg(process.pid, signal.SIGKILL)
except ProcessLookupError:
pass

if torch.distributed.is_initialized():
torch.distributed.destroy_process_group()

@@ -420,6 +458,7 @@ class DDPHparams(hp.Hparams):
doc="Whether to fork the local rank 0 process, or use the existing process for rank 0 training.",
default=False,
)
timeout: float = hp.optional(doc="Timeout, in seconds, for initializing the DDP process group.", default=5.0)

def initialize_object(self, nproc_per_node: int, backend: str, find_unused_parameters: bool) -> DDP:
return DDP(
@@ -430,4 +469,5 @@ def initialize_object(self, nproc_per_node: int, backend: str, find_unused_param
num_nodes=self.num_nodes,
fork_rank_0=self.fork_rank_0,
find_unused_parameters=find_unused_parameters,
timeout=self.timeout,
)
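
For context, a hedged sketch of how a short timeout reaches torch.distributed.init_process_group. init_ddp is a hypothetical helper, not the trainer's API; the store host and port simply mirror the TCPStoreHparams("127.0.0.1", 43297) default documented in trainer.py below.

```python
import datetime

import torch.distributed as dist


def init_ddp(backend: str, rank: int, world_size: int, timeout_seconds: float = 5.0) -> None:
    # A short timeout makes a broken rendezvous fail fast with a useful error
    # instead of hanging for torch's 30-minute default.
    store = dist.TCPStore("127.0.0.1", 43297, world_size, is_master=(rank == 0))
    dist.init_process_group(
        backend,
        rank=rank,
        world_size=world_size,
        store=store,
        timeout=datetime.timedelta(seconds=timeout_seconds),
    )
    assert dist.is_initialized()
```
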
9 changes: 7 additions & 2 deletions composer/trainer/trainer.py
@@ -90,6 +90,8 @@ class Trainer:
(default: ``TCPStoreHparams("127.0.0.1", 43297)``)
fork_rank_0 (bool, optional): True to fork the rank 0 process in distributed data parallel,
False to not. (default: ``True``)
ddp_timeout (float, optional): Timeout, in seconds, for initializing the DDP process group.
(default: ``5.0``)
ddp_sync_strategy (DDPSyncStrategy, optional): The strategy to use for synchronizing gradients.
Leave unset to let the trainer auto-configure this.
seed (int, optional): The seed used in randomization. When not provided a random seed
@@ -150,6 +152,7 @@ def __init__(
# ddp hparams
ddp_store_hparams: Optional[StoreHparams] = None,
fork_rank_0: bool = False,
ddp_timeout: float = 5.0,
ddp_sync_strategy: Optional[str] = None,

# Randomness
@@ -206,6 +209,7 @@ def __init__(
backend=self.device.ddp_backend,
fork_rank_0=fork_rank_0,
find_unused_parameters=find_unused_parameters,
timeout=ddp_timeout,
ddp_sync_strategy=ddp_sync_strategy,
)

@@ -344,8 +348,9 @@ def create_from_hparams(cls, hparams: TrainerHparams) -> Trainer:
timeout=hparams.dataloader.timeout,

# ddp hparams
ddp_store_hparams=ddp.store_hparams,
fork_rank_0=ddp.fork_rank_0,
ddp_store_hparams=ddp.hparams.store,
fork_rank_0=ddp.hparams.fork_rank_0,
ddp_timeout=ddp.hparams.timeout,

# Randomness
seed=seed,
