Skip to content

Commit

Permalink
[dask] hold ports until training (#5890)
Browse files Browse the repository at this point in the history
  • Loading branch information
jmoralez authored Jun 20, 2023
1 parent 07e3cf4 commit ac57d5a
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 111 deletions.
8 changes: 7 additions & 1 deletion python-package/lightgbm/compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ class _LGBMRegressorBase: # type: ignore
from dask.bag import from_delayed as dask_bag_from_delayed
from dask.dataframe import DataFrame as dask_DataFrame
from dask.dataframe import Series as dask_Series
from dask.distributed import Client, default_client, wait
from dask.distributed import Client, Future, default_client, wait
DASK_INSTALLED = True
except ImportError:
DASK_INSTALLED = False
Expand All @@ -161,6 +161,12 @@ class Client: # type: ignore
def __init__(self, *args, **kwargs):
pass

class Future: # type: ignore
"""Dummy class for dask.distributed.Future."""

def __init__(self, *args, **kwargs):
pass

class dask_Array: # type: ignore
"""Dummy class for dask.array.Array."""

Expand Down
115 changes: 45 additions & 70 deletions python-package/lightgbm/dask.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
It is based on dask-lightgbm, which was based on dask-xgboost.
"""
import operator
import socket
from collections import defaultdict
from copy import deepcopy
Expand All @@ -18,7 +19,7 @@
import scipy.sparse as ss

from .basic import LightGBMError, _choose_param_value, _ConfigAliases, _log_info, _log_warning
from .compat import (DASK_INSTALLED, PANDAS_INSTALLED, SKLEARN_INSTALLED, Client, LGBMNotFittedError, concat,
from .compat import (DASK_INSTALLED, PANDAS_INSTALLED, SKLEARN_INSTALLED, Client, Future, LGBMNotFittedError, concat,
dask_Array, dask_array_from_delayed, dask_bag_from_delayed, dask_DataFrame, dask_Series,
default_client, delayed, pd_DataFrame, pd_Series, wait)
from .sklearn import (LGBMClassifier, LGBMModel, LGBMRanker, LGBMRegressor, _LGBM_ScikitCustomObjectiveFunction,
Expand All @@ -38,18 +39,21 @@
_PredictionDtype = Union[Type[np.float32], Type[np.float64], Type[np.int32], Type[np.int64]]


class _HostWorkers:
class _RemoteSocket:
def acquire(self) -> int:
self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
self.socket.bind(('', 0))
return self.socket.getsockname()[1]

def __init__(self, default: str, all_workers: List[str]):
self.default = default
self.all_workers = all_workers
def release(self) -> None:
self.socket.close()

def __eq__(self, other: object) -> bool:
return (
isinstance(other, type(self))
and self.default == other.default
and self.all_workers == other.all_workers
)

def _acquire_port() -> Tuple[_RemoteSocket, int]:
s = _RemoteSocket()
port = s.acquire()
return s, port


class _DatasetNames(Enum):
Expand Down Expand Up @@ -83,73 +87,40 @@ def _get_dask_client(client: Optional[Client]) -> Client:
return client


def _find_n_open_ports(n: int) -> List[int]:
"""Find n random open ports on localhost.
Returns
-------
ports : list of int
n random open ports on localhost.
"""
sockets = []
for _ in range(n):
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
s.bind(('', 0))
sockets.append(s)
ports = []
for s in sockets:
ports.append(s.getsockname()[1])
s.close()
return ports


def _group_workers_by_host(worker_addresses: Iterable[str]) -> Dict[str, _HostWorkers]:
"""Group all worker addresses by hostname.
Returns
-------
host_to_workers : dict
mapping from hostname to all its workers.
"""
host_to_workers: Dict[str, _HostWorkers] = {}
for address in worker_addresses:
hostname = urlparse(address).hostname
if not hostname:
raise ValueError(f"Could not parse host name from worker address '{address}'")
if hostname not in host_to_workers:
host_to_workers[hostname] = _HostWorkers(default=address, all_workers=[address])
else:
host_to_workers[hostname].all_workers.append(address)
return host_to_workers


def _assign_open_ports_to_workers(
client: Client,
host_to_workers: Dict[str, _HostWorkers]
) -> Dict[str, int]:
workers: List[str],
) -> Tuple[Dict[str, Future], Dict[str, int]]:
"""Assign an open port to each worker.
Returns
-------
worker_to_socket_future: dict
mapping from worker address to a future pointing to the remote socket.
worker_to_port: dict
mapping from worker address to an open port.
mapping from worker address to an open port in the worker's host.
"""
host_ports_futures = {}
for hostname, workers in host_to_workers.items():
n_workers_in_host = len(workers.all_workers)
host_ports_futures[hostname] = client.submit(
_find_n_open_ports,
n=n_workers_in_host,
workers=[workers.default],
pure=False,
# Acquire port in worker
worker_to_future = {}
for worker in workers:
worker_to_future[worker] = client.submit(
_acquire_port,
workers=[worker],
allow_other_workers=False,
pure=False,
)
found_ports = client.gather(host_ports_futures)
worker_to_port = {}
for hostname, workers in host_to_workers.items():
for worker, port in zip(workers.all_workers, found_ports[hostname]):
worker_to_port[worker] = port
return worker_to_port

# schedule futures to retrieve each element of the tuple
worker_to_socket_future = {}
worker_to_port_future = {}
for worker, socket_future in worker_to_future.items():
worker_to_socket_future[worker] = client.submit(operator.itemgetter(0), socket_future)
worker_to_port_future[worker] = client.submit(operator.itemgetter(1), socket_future)

# retrieve ports
worker_to_port = client.gather(worker_to_port_future)

return worker_to_socket_future, worker_to_port


def _concat(seq: List[_DaskPart]) -> _DaskPart:
Expand Down Expand Up @@ -190,6 +161,7 @@ def _train_part(
num_machines: int,
return_model: bool,
time_out: int,
remote_socket: _RemoteSocket,
**kwargs: Any
) -> Optional[LGBMModel]:
network_params = {
Expand Down Expand Up @@ -320,6 +292,8 @@ def _train_part(
kwargs['eval_class_weight'] = [eval_class_weight[i] for i in eval_component_idx]

model = model_factory(**params)
if remote_socket is not None:
remote_socket.release()
try:
if is_ranker:
model.fit(
Expand Down Expand Up @@ -777,6 +751,7 @@ def _train(
machines = params.pop("machines")

# figure out network params
worker_to_socket_future: Dict[str, Future] = {}
worker_addresses = worker_map.keys()
if machines is not None:
_log_info("Using passed-in 'machines' parameter")
Expand All @@ -802,8 +777,7 @@ def _train(
}
else:
_log_info("Finding random open ports for workers")
host_to_workers = _group_workers_by_host(worker_map.keys())
worker_address_to_port = _assign_open_ports_to_workers(client, host_to_workers)
worker_to_socket_future, worker_address_to_port = _assign_open_ports_to_workers(client, list(worker_map.keys()))

machines = ','.join([
f'{urlparse(worker_address).hostname}:{port}'
Expand Down Expand Up @@ -831,6 +805,7 @@ def _train(
local_listen_port=worker_address_to_port[worker],
num_machines=num_machines,
time_out=params.get('time_out', 120),
remote_socket=worker_to_socket_future.get(worker, None),
return_model=(worker == master_worker),
workers=[worker],
allow_other_workers=False,
Expand Down
45 changes: 5 additions & 40 deletions tests/python_package_test/test_dask.py
Original file line number Diff line number Diff line change
Expand Up @@ -519,50 +519,13 @@ def test_classifier_custom_objective(output, task, cluster):
assert_eq(p1_proba, p1_proba_local)


def test_group_workers_by_host():
hosts = [f'0.0.0.{i}' for i in range(2)]
workers = [f'tcp://{host}:{p}' for p in range(2) for host in hosts]
expected = {
host: lgb.dask._HostWorkers(
default=f'tcp://{host}:0',
all_workers=[f'tcp://{host}:0', f'tcp://{host}:1']
)
for host in hosts
}
host_to_workers = lgb.dask._group_workers_by_host(workers)
assert host_to_workers == expected


def test_group_workers_by_host_unparseable_host_names():
workers_without_protocol = ['0.0.0.1:80', '0.0.0.2:80']
with pytest.raises(ValueError, match="Could not parse host name from worker address '0.0.0.1:80'"):
lgb.dask._group_workers_by_host(workers_without_protocol)


def test_machines_to_worker_map_unparseable_host_names():
workers = {'0.0.0.1:80': {}, '0.0.0.2:80': {}}
machines = "0.0.0.1:80,0.0.0.2:80"
with pytest.raises(ValueError, match="Could not parse host name from worker address '0.0.0.1:80'"):
lgb.dask._machines_to_worker_map(machines=machines, worker_addresses=workers.keys())


def test_assign_open_ports_to_workers(cluster):
with Client(cluster) as client:
workers = client.scheduler_info()['workers'].keys()
n_workers = len(workers)
host_to_workers = lgb.dask._group_workers_by_host(workers)
for _ in range(25):
worker_address_to_port = lgb.dask._assign_open_ports_to_workers(client, host_to_workers)
found_ports = worker_address_to_port.values()
assert len(found_ports) == n_workers
# check that found ports are different for same address (LocalCluster)
assert len(set(found_ports)) == len(found_ports)
# check that the ports are indeed open
for port in found_ports:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.bind(('', port))


def test_training_does_not_fail_on_port_conflicts(cluster):
with Client(cluster) as client:
_, _, _, _, dX, dy, dw, _ = _create_data('binary-classification', output='array')
Expand Down Expand Up @@ -1588,15 +1551,17 @@ def test_network_params_not_required_but_respected_if_given(task, listen_port, c
assert 'machines' not in params

# model 2 - machines given
workers = list(client.scheduler_info()['workers'])
workers_hostname = _get_workers_hostname(cluster)
n_workers = len(client.scheduler_info()['workers'])
open_ports = lgb.dask._find_n_open_ports(n_workers)
remote_sockets, open_ports = lgb.dask._assign_open_ports_to_workers(client, workers)
for s in remote_sockets.values():
s.release()
dask_model2 = dask_model_factory(
n_estimators=5,
num_leaves=5,
machines=",".join([
f"{workers_hostname}:{port}"
for port in open_ports
for port in open_ports.values()
]),
)

Expand Down

0 comments on commit ac57d5a

Please sign in to comment.