From 709ea4cad3ff4c8834bdf389eef8016345d9a1ac Mon Sep 17 00:00:00 2001 From: James Lamb Date: Thu, 2 Mar 2023 19:37:33 -0600 Subject: [PATCH] [python-package] [dask] fix mypy errors about Dask fit() return types (#5756) --- python-package/lightgbm/dask.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/python-package/lightgbm/dask.py b/python-package/lightgbm/dask.py index 5155dad9df8e..ba8f234cefec 100644 --- a/python-package/lightgbm/dask.py +++ b/python-package/lightgbm/dask.py @@ -1175,7 +1175,7 @@ def fit( # type: ignore[override] **kwargs: Any ) -> "DaskLGBMClassifier": """Docstring is inherited from the lightgbm.LGBMClassifier.fit.""" - return self._lgb_dask_fit( + self._lgb_dask_fit( model_factory=LGBMClassifier, X=X, y=y, @@ -1189,6 +1189,7 @@ def fit( # type: ignore[override] eval_metric=eval_metric, **kwargs ) + return self _base_doc = _lgbmmodel_doc_fit.format( X_shape="Dask Array or Dask DataFrame of shape = [n_samples, n_features]", @@ -1378,7 +1379,7 @@ def fit( # type: ignore[override] **kwargs: Any ) -> "DaskLGBMRegressor": """Docstring is inherited from the lightgbm.LGBMRegressor.fit.""" - return self._lgb_dask_fit( + self._lgb_dask_fit( model_factory=LGBMRegressor, X=X, y=y, @@ -1391,6 +1392,7 @@ def fit( # type: ignore[override] eval_metric=eval_metric, **kwargs ) + return self _base_doc = _lgbmmodel_doc_fit.format( X_shape="Dask Array or Dask DataFrame of shape = [n_samples, n_features]", @@ -1550,7 +1552,7 @@ def fit( # type: ignore[override] **kwargs: Any ) -> "DaskLGBMRanker": """Docstring is inherited from the lightgbm.LGBMRanker.fit.""" - return self._lgb_dask_fit( + self._lgb_dask_fit( model_factory=LGBMRanker, X=X, y=y, @@ -1566,6 +1568,7 @@ def fit( # type: ignore[override] eval_at=eval_at, **kwargs ) + return self _base_doc = _lgbmmodel_doc_fit.format( X_shape="Dask Array or Dask DataFrame of shape = [n_samples, n_features]",