From d0d816eff9842072645887e867f46bcfb7840788 Mon Sep 17 00:00:00 2001 From: you-n-g Date: Thu, 28 Oct 2021 16:34:35 +0800 Subject: [PATCH 1/3] Update gbdt.py --- qlib/contrib/model/gbdt.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/qlib/contrib/model/gbdt.py b/qlib/contrib/model/gbdt.py index f0b0d2eb1e..61447933fc 100644 --- a/qlib/contrib/model/gbdt.py +++ b/qlib/contrib/model/gbdt.py @@ -14,11 +14,12 @@ class LGBModel(ModelFT, LightGBMFInt): """LightGBM Model""" - def __init__(self, loss="mse", **kwargs): + def __init__(self, loss="mse", early_stopping_rounds=50, **kwargs): if loss not in {"mse", "binary"}: raise NotImplementedError self.params = {"objective": loss, "verbosity": -1} self.params.update(kwargs) + self.early_stopping_rounds = early_stopping_rounds self.model = None def _prepare_data(self, dataset: DatasetH): @@ -44,7 +45,7 @@ def fit( self, dataset: DatasetH, num_boost_round=1000, - early_stopping_rounds=50, + early_stopping_rounds=None, verbose_eval=20, evals_result=dict(), **kwargs @@ -56,7 +57,7 @@ def fit( num_boost_round=num_boost_round, valid_sets=[dtrain, dvalid], valid_names=["train", "valid"], - early_stopping_rounds=early_stopping_rounds, + early_stopping_rounds=self.early_stopping_rounds if early_stopping_rounds is None else early_stopping_rounds, verbose_eval=verbose_eval, evals_result=evals_result, **kwargs From 65dce0eba256dd57965d72febe5faf8f5eadee60 Mon Sep 17 00:00:00 2001 From: you-n-g Date: Thu, 28 Oct 2021 16:37:42 +0800 Subject: [PATCH 2/3] Update gbdt.py --- qlib/contrib/model/gbdt.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/qlib/contrib/model/gbdt.py b/qlib/contrib/model/gbdt.py index 61447933fc..3dff4bb391 100644 --- a/qlib/contrib/model/gbdt.py +++ b/qlib/contrib/model/gbdt.py @@ -57,7 +57,7 @@ def fit( num_boost_round=num_boost_round, valid_sets=[dtrain, dvalid], valid_names=["train", "valid"], - early_stopping_rounds=self.early_stopping_rounds if early_stopping_rounds is None else early_stopping_rounds, + early_stopping_rounds=(self.early_stopping_rounds if early_stopping_rounds is None else early_stopping_rounds), verbose_eval=verbose_eval, evals_result=evals_result, **kwargs From c53084f3ecf249f982f43049af7252bbf7cf3072 Mon Sep 17 00:00:00 2001 From: you-n-g Date: Thu, 28 Oct 2021 16:40:57 +0800 Subject: [PATCH 3/3] Update gbdt.py --- qlib/contrib/model/gbdt.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/qlib/contrib/model/gbdt.py b/qlib/contrib/model/gbdt.py index 3dff4bb391..4e02c8eba1 100644 --- a/qlib/contrib/model/gbdt.py +++ b/qlib/contrib/model/gbdt.py @@ -57,7 +57,9 @@ def fit( num_boost_round=num_boost_round, valid_sets=[dtrain, dvalid], valid_names=["train", "valid"], - early_stopping_rounds=(self.early_stopping_rounds if early_stopping_rounds is None else early_stopping_rounds), + early_stopping_rounds=( + self.early_stopping_rounds if early_stopping_rounds is None else early_stopping_rounds + ), verbose_eval=verbose_eval, evals_result=evals_result, **kwargs