Skip to content

Commit

Permalink
[python-package] [dask] add type annotations on dask._HostWorkers (#5766)
Browse files Browse the repository at this point in the history
  • Loading branch information
jameslamb authored Mar 7, 2023
1 parent 98c1db7 commit bf1a604
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 7 deletions.
23 changes: 17 additions & 6 deletions python-package/lightgbm/dask.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
It is based on dask-lightgbm, which was based on dask-xgboost.
"""
import socket
from collections import defaultdict, namedtuple
from collections import defaultdict
from copy import deepcopy
from enum import Enum, auto
from functools import partial
Expand Down Expand Up @@ -37,7 +37,18 @@
_DaskPart = Union[np.ndarray, pd_DataFrame, pd_Series, ss.spmatrix]
_PredictionDtype = Union[Type[np.float32], Type[np.float64], Type[np.int32], Type[np.int64]]

_HostWorkers = namedtuple('_HostWorkers', ['default', 'all'])

class _HostWorkers:

def __init__(self, default: str, all_workers: List[str]):
self.default = default
self.all_workers = all_workers

def __eq__(self, other: "_HostWorkers") -> bool:
return (
self.default == other.default
and self.all_workers == other.all_workers
)


class _DatasetNames(Enum):
Expand Down Expand Up @@ -105,9 +116,9 @@ def _group_workers_by_host(worker_addresses: Iterable[str]) -> Dict[str, _HostWo
if not hostname:
raise ValueError(f"Could not parse host name from worker address '{address}'")
if hostname not in host_to_workers:
host_to_workers[hostname] = _HostWorkers(default=address, all=[address])
host_to_workers[hostname] = _HostWorkers(default=address, all_workers=[address])
else:
host_to_workers[hostname].all.append(address)
host_to_workers[hostname].all_workers.append(address)
return host_to_workers


Expand All @@ -124,7 +135,7 @@ def _assign_open_ports_to_workers(
"""
host_ports_futures = {}
for hostname, workers in host_to_workers.items():
n_workers_in_host = len(workers.all)
n_workers_in_host = len(workers.all_workers)
host_ports_futures[hostname] = client.submit(
_find_n_open_ports,
n=n_workers_in_host,
Expand All @@ -135,7 +146,7 @@ def _assign_open_ports_to_workers(
found_ports = client.gather(host_ports_futures)
worker_to_port = {}
for hostname, workers in host_to_workers.items():
for worker, port in zip(workers.all, found_ports[hostname]):
for worker, port in zip(workers.all_workers, found_ports[hostname]):
worker_to_port[worker] = port
return worker_to_port

Expand Down
2 changes: 1 addition & 1 deletion tests/python_package_test/test_dask.py
Original file line number Diff line number Diff line change
Expand Up @@ -525,7 +525,7 @@ def test_group_workers_by_host():
expected = {
host: lgb.dask._HostWorkers(
default=f'tcp://{host}:0',
all=[f'tcp://{host}:0', f'tcp://{host}:1']
all_workers=[f'tcp://{host}:0', f'tcp://{host}:1']
)
for host in hosts
}
Expand Down

0 comments on commit bf1a604

Please sign in to comment.