[ci] [python-package] resolve remaining mypy errors in dask.py (#5858)
jameslamb authored May 4, 2023
1 parent 11e17f3 commit 33d90f4
Showing 1 changed file with 48 additions and 12 deletions.
60 changes: 48 additions & 12 deletions python-package/lightgbm/dask.py
@@ -576,13 +576,48 @@ def _train(
         # pad eval sets when they come in different sizes.
         n_largest_eval_parts = max(x[0].npartitions for x in eval_set)

-        eval_sets = defaultdict(list)
+        eval_sets: Dict[
+            int,
+            List[
+                Union[
+                    _DatasetNames,
+                    Tuple[
+                        List[Optional[_DaskMatrixLike]],
+                        List[Optional[_DaskVectorLike]]
+                    ]
+                ]
+            ]
+        ] = defaultdict(list)
         if eval_sample_weight:
-            eval_sample_weights = defaultdict(list)
+            eval_sample_weights: Dict[
+                int,
+                List[
+                    Union[
+                        _DatasetNames,
+                        List[Optional[_DaskVectorLike]]
+                    ]
+                ]
+            ] = defaultdict(list)
         if eval_group:
-            eval_groups = defaultdict(list)
+            eval_groups: Dict[
+                int,
+                List[
+                    Union[
+                        _DatasetNames,
+                        List[Optional[_DaskVectorLike]]
+                    ]
+                ]
+            ] = defaultdict(list)
         if eval_init_score:
-            eval_init_scores = defaultdict(list)
+            eval_init_scores: Dict[
+                int,
+                List[
+                    Union[
+                        _DatasetNames,
+                        List[Optional[_DaskMatrixLike]]
+                    ]
+                ]
+            ] = defaultdict(list)

         for i, (X_eval, y_eval) in enumerate(eval_set):
             n_this_eval_parts = X_eval.npartitions
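A note on why these annotations are needed: mypy cannot infer key and value types for a bare `defaultdict(list)` and reports `Need type annotation` (error code `var-annotated`), so the fix is to spell the container type out at the point of assignment. A minimal sketch of the same pattern, with illustrative names not taken from dask.py:

    from collections import defaultdict
    from typing import DefaultDict, List

    # Without the annotation mypy reports:
    #   Need type annotation for "groups"  [var-annotated]
    groups: DefaultDict[int, List[str]] = defaultdict(list)
    groups[7].append("part-a")  # value type is now checked

The commit annotates with `Dict[...]` rather than `DefaultDict[...]`, which also type-checks because `defaultdict` is a subclass of `dict`.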
@@ -610,8 +645,8 @@ def _train(
                         eval_sets[parts_idx].append(([x_e], [y_e]))
                     else:
                         # append additional chunks of this eval set to this part.
-                        eval_sets[parts_idx][-1][0].append(x_e)
-                        eval_sets[parts_idx][-1][1].append(y_e)
+                        eval_sets[parts_idx][-1][0].append(x_e)  # type: ignore[index, union-attr]
+                        eval_sets[parts_idx][-1][1].append(y_e)  # type: ignore[index, union-attr]

             if eval_sample_weight:
                 if eval_sample_weight[i] is sample_weight:
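For the `# type: ignore[...]` comments here and in the three hunks that follow: each value stored in these dicts is a union (either a `_DatasetNames` sentinel or a list/tuple of parts), so mypy cannot prove that indexing and `.append(...)` are valid on every union member. The code scopes the suppression to specific error codes rather than using a bare `# type: ignore`, so any unrelated error on those lines would still surface. A self-contained illustration of the `union-attr` case (names are made up for the example):

    from typing import List, Union

    items: List[Union[str, List[int]]] = ["sentinel", [1, 2]]

    # mypy: Item "str" of "Union[str, List[int]]" has no attribute
    # "append"  [union-attr].  The surrounding logic guarantees the last
    # element is a list, so only that one error code is suppressed.
    items[-1].append(3)  # type: ignore[union-attr]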
@@ -631,7 +666,7 @@ def _train(
                     if j < n_parts:
                         eval_sample_weights[parts_idx].append([w_e])
                     else:
-                        eval_sample_weights[parts_idx][-1].append(w_e)
+                        eval_sample_weights[parts_idx][-1].append(w_e)  # type: ignore[union-attr]

             if eval_init_score:
                 if eval_init_score[i] is init_score:
@@ -649,7 +684,7 @@ def _train(
                     if j < n_parts:
                         eval_init_scores[parts_idx].append([init_score_e])
                     else:
-                        eval_init_scores[parts_idx][-1].append(init_score_e)
+                        eval_init_scores[parts_idx][-1].append(init_score_e)  # type: ignore[union-attr]

             if eval_group:
                 if eval_group[i] is group:
@@ -667,7 +702,7 @@ def _train(
                     if j < n_parts:
                         eval_groups[parts_idx].append([g_e])
                     else:
-                        eval_groups[parts_idx][-1].append(g_e)
+                        eval_groups[parts_idx][-1].append(g_e)  # type: ignore[union-attr]

     # assign sub-eval_set components to worker parts.
     for parts_idx, e_set in eval_sets.items():
@@ -686,7 +721,8 @@ def _train(

     for part in parts:
         if part.status == 'error':  # type: ignore
-            return part  # trigger error locally
+            # trigger error locally
+            return part  # type: ignore[return-value]

     # Find locations of all parts and map them to particular Dask workers
     key_to_part_dict = {part.key: part for part in parts}  # type: ignore
@@ -701,7 +737,7 @@ def _train(
     for worker in worker_map:
         has_eval_set = False
         for part in worker_map[worker]:
-            if 'eval_set' in part.result():
+            if 'eval_set' in part.result():  # type: ignore[attr-defined]
                 has_eval_set = True
                 break
@@ -1002,7 +1038,7 @@ def _extract(items: List[Any], i: int) -> Any:
            **kwargs,
        )
        pred_row = predict_fn(data_row)
-       chunks = (data.chunks[0],)
+       chunks: Tuple[int, ...] = (data.chunks[0],)
        map_blocks_kwargs = {}
        if len(pred_row.shape) > 1:
            chunks += (pred_row.shape[1],)
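The `chunks` annotation fixes a different inference issue: from the first assignment mypy infers a fixed-length tuple type, so the later `chunks += (pred_row.shape[1],)` fails with an incompatible-assignment error. Annotating the variable as a variable-length tuple up front, as the commit does, accepts both shapes. A minimal reproduction with illustrative values:

    from typing import Tuple

    # Without the annotation mypy infers "Tuple[int]" here and then
    # rejects the +=:
    #   Incompatible types in assignment (expression has type
    #   "Tuple[int, int]", variable has type "Tuple[int]")
    chunks: Tuple[int, ...] = (4,)
    chunks += (2,)  # OK: variable-length tuple of ints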
