Skip to content

Commit

Permalink
[python-package] use keyword arguments in predict() calls (#5755)
Browse files Browse the repository at this point in the history
  • Loading branch information
jameslamb authored Mar 1, 2023
1 parent 27e69e7 commit e423120
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 8 deletions.
20 changes: 16 additions & 4 deletions python-package/lightgbm/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -4052,9 +4052,16 @@ def predict(
num_iteration = self.best_iteration
else:
num_iteration = -1
return predictor.predict(data, start_iteration, num_iteration,
raw_score, pred_leaf, pred_contrib,
data_has_header, validate_features)
return predictor.predict(
data=data,
start_iteration=start_iteration,
num_iteration=num_iteration,
raw_score=raw_score,
pred_leaf=pred_leaf,
pred_contrib=pred_contrib,
data_has_header=data_has_header,
validate_features=validate_features
)

def refit(
self,
Expand Down Expand Up @@ -4130,7 +4137,12 @@ def refit(
if dataset_params is None:
dataset_params = {}
predictor = self._to_predictor(deepcopy(kwargs))
leaf_preds = predictor.predict(data, -1, pred_leaf=True, validate_features=validate_features)
leaf_preds = predictor.predict(
data=data,
start_iteration=-1,
pred_leaf=True,
validate_features=validate_features
)
nrow, ncol = leaf_preds.shape
out_is_linear = ctypes.c_int(0)
_safe_call(_LIB.LGBM_BoosterGetLinear(
Expand Down
24 changes: 20 additions & 4 deletions python-package/lightgbm/sklearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -1135,9 +1135,16 @@ def predict(
**kwargs: Any
):
"""Docstring is inherited from the LGBMModel."""
result = self.predict_proba(X, raw_score, start_iteration, num_iteration,
pred_leaf, pred_contrib, validate_features,
**kwargs)
result = self.predict_proba(
X=X,
raw_score=raw_score,
start_iteration=start_iteration,
num_iteration=num_iteration,
pred_leaf=pred_leaf,
pred_contrib=pred_contrib,
validate_features=validate_features,
**kwargs
)
if callable(self._objective) or raw_score or pred_leaf or pred_contrib:
return result
else:
Expand All @@ -1158,7 +1165,16 @@ def predict_proba(
**kwargs: Any
):
"""Docstring is set after definition, using a template."""
result = super().predict(X, raw_score, start_iteration, num_iteration, pred_leaf, pred_contrib, validate_features, **kwargs)
result = super().predict(
X=X,
raw_score=raw_score,
start_iteration=start_iteration,
num_iteration=num_iteration,
pred_leaf=pred_leaf,
pred_contrib=pred_contrib,
validate_features=validate_features,
**kwargs
)
if callable(self._objective) and not (raw_score or pred_leaf or pred_contrib):
_log_warning("Cannot compute class probabilities or labels "
"due to the usage of customized objective function.\n"
Expand Down

0 comments on commit e423120

Please sign in to comment.