Extremely randomized trees #2671

Merged: 6 commits into microsoft:master on Feb 8, 2020
Conversation

btrotta (Collaborator) commented on Jan 6, 2020

Adds an option to use extremely randomized trees as the base learner, as requested in #2583.

btrotta (Collaborator, Author) commented on Jan 6, 2020

Here is a short script to test performance on the data from the Porto Seguro competition on Kaggle. Overfitting is reduced slightly using extra-trees; however, it's a little slower, and I'm not sure why.

import pandas as pd
import numpy as np
import lightgbm as lgb
import time
import matplotlib.pyplot as plt
import os

# Porto Seguro training data; drop id/target and columns containing 'cat' from the feature set
df = pd.read_csv(os.path.join('porto', 'train.csv'))
train_cols = [c for c in df.columns if c not in ['id', 'target'] and ('cat' not in c)]
cat_cols = [c for c in train_cols if '_cat' in c]
# random 50/50 train/validation split
train_ind = np.random.choice(df.index, len(df) // 2, replace=False)
train_bool = df.index.isin(train_ind)
lgb_train = lgb.Dataset(df.loc[train_bool, train_cols], label=df.loc[train_bool, 'target'])
lgb_test = lgb.Dataset(df.loc[~train_bool, train_cols], label=df.loc[~train_bool, 'target'])
valid_sets = [lgb_train, lgb_test]
valid_names = ['train', 'valid']

# base
params = {'objective': 'binary', 'seed': 0, 'num_leaves': 32, 'learning_rate': 0.01, 'metric': 'binary_logloss'}
t = time.time()
res = {}
est = lgb.train(params, lgb_train, valid_sets=valid_sets, valid_names=valid_names, num_boost_round=1000,
                categorical_feature=cat_cols, evals_result=res)
print('time ', time.time() - t)
plt.figure()
plt.plot(res['train']['binary_logloss'])
plt.plot(res['valid']['binary_logloss'])

# extra-trees
params['extra_trees'] = True
t = time.time()
res2 = {}
est = lgb.train(params, lgb_train, valid_sets=valid_sets, valid_names=valid_names, num_boost_round=1000,
                categorical_feature=cat_cols, evals_result=res2)
print('time ', time.time() - t)
plt.plot(res2['train']['binary_logloss'])
plt.plot(res2['valid']['binary_logloss'])
plt.legend(['train', 'test', 'train_extra', 'test_extra'])

[Image: extra_example — train/validation binary_logloss curves with and without extra_trees]

StrikerRUS (Collaborator) left a comment:
@btrotta As always, great contribution! Thank you very much!

If it's not too much trouble, could you please add to your benchmark an extra-trees forest from scikit-learn and LightGBM with {'boosting': 'rf', 'extra_trees': True}? I think it would be interesting to compare.

(Resolved review threads on include/LightGBM/config.h and src/treelearner/feature_histogram.hpp.)
guolinke (Collaborator) commented on Jan 7, 2020

@btrotta For random trees, I think there is no need to construct histograms, which is the most time-consuming part of LightGBM. You can simply generate a random tree, predict over it, do the boosting, and then build the next tree, and so on.

btrotta (Collaborator, Author) commented on Jan 8, 2020

@guolinke Thanks for the advice! I will try implementing it that way.

btrotta (Collaborator, Author) commented on Jan 18, 2020

@guolinke I made an attempt at implementing your suggestion but it actually made the performance worse, so I think I must be doing something wrong. Code is on a new branch (https://github.com/btrotta/LightGBM/tree/extra2) if you want to take a look.

Here's a summary of how I tried to implement it. When finding each new node, we need to try a random split of each feature on smaller_leaf_splits and larger_leaf_splits. For each of these, I first order the gradients and hessians (as is also done in the current code before constructing histograms). Then I do a parallel for-loop over the features. For each feature:

  1. Choose a random threshold to split.
  2. Iterate over the feature's bins to calculate the left and right sum of gradient and hessian.
  3. Calculate split gain.

Then we choose the feature with the best random split and split the node on that feature; a rough sketch of this search is below.
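
Roughly, the per-leaf search looks like the following NumPy sketch (illustrative only, not the actual code on the extra2 branch; it ignores min_data_in_leaf, L1 regularization, monotone constraints, and the ordered-gradient machinery, and find_random_split is a made-up name):

import numpy as np

def leaf_score(g, h, lambda_l2):
    # second-order score of a leaf with gradient sum g and hessian sum h
    return (g * g) / (h + lambda_l2 + 1e-15)

def find_random_split(binned, grad, hess, data_idx, max_bin, lambda_l2=0.0, rng=None):
    """For each feature, draw one random threshold bin, compute the left/right
    gradient and hessian sums over this leaf's data, and return the best split."""
    rng = np.random.default_rng() if rng is None else rng
    g, h = grad[data_idx], hess[data_idx]
    best_gain, best_feat, best_thresh = -np.inf, None, None
    for feat in range(binned.shape[1]):  # a parallel for-loop in the C++ attempt
        bins = binned[data_idx, feat]
        thresh = int(rng.integers(0, max_bin))       # 1. random threshold bin
        left = bins <= thresh                        # 2. left/right grad/hess sums
        gl, hl = g[left].sum(), h[left].sum()
        gr, hr = g[~left].sum(), h[~left].sum()
        gain = (leaf_score(gl, hl, lambda_l2)        # 3. split gain
                + leaf_score(gr, hr, lambda_l2)
                - leaf_score(gl + gr, hl + hr, lambda_l2))
        if gain > best_gain:
            best_gain, best_feat, best_thresh = gain, feat, thresh
    return best_feat, best_thresh, best_gain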

In theory, it seems to me that this should be at least as fast as the normal GBDT algorithm, since instead of constructing the full histogram (where we need to aggregate gradients and hessians for many bins and save the result in memory), we only have to sum up gradients and hessians for one split.

But in fact it takes around twice as long (using the example script above), and I don't understand why. I'd be grateful if you have any insights.

Note: I didn't re-implement ordered_sparse_bin, since it seems like that would add more complexity, and I'm not sure if the efficiency gains are worth it in this case. But I don't think this is the reason for the bad performance: I tried my test script with all zeros replaced by -1 (so no sparse features), and performance is similar.

guolinke (Collaborator) commented on Jan 29, 2020

@btrotta Sorry for the late response.
I think for extremely randomized trees, we don't need to gather the gradients/hessians during tree growing.
I think the procedure may be:

  1. Generate a random tree: maybe generate one random int for the selected feature and one random float for the threshold ("threshold_in_bin = std::round(random_float*max_bin)"). And then call Tree::Split to grow the tree.
  2. Similar to refit, update the leaf outputs according to the gradients and hessians:
    Tree* SerialTreeLearner::FitByExistingTree(const Tree* old_tree, const score_t* gradients, const score_t* hessians) const {
      auto tree = std::unique_ptr<Tree>(new Tree(*old_tree));
      CHECK(data_partition_->num_leaves() >= tree->num_leaves());
      OMP_INIT_EX();
      #pragma omp parallel for schedule(static)
      for (int i = 0; i < tree->num_leaves(); ++i) {
        OMP_LOOP_EX_BEGIN();
        data_size_t cnt_leaf_data = 0;
        auto tmp_idx = data_partition_->GetIndexOnLeaf(i, &cnt_leaf_data);
        double sum_grad = 0.0f;
        double sum_hess = kEpsilon;
        for (data_size_t j = 0; j < cnt_leaf_data; ++j) {
          auto idx = tmp_idx[j];
          sum_grad += gradients[idx];
          sum_hess += hessians[idx];
        }
        double output = FeatureHistogram::CalculateSplittedLeafOutput(sum_grad, sum_hess,
            config_->lambda_l1, config_->lambda_l2, config_->max_delta_step);
        auto old_leaf_output = tree->LeafOutput(i);
        auto new_leaf_output = output * tree->shrinkage();
        tree->SetLeafOutput(i, config_->refit_decay_rate * old_leaf_output + (1.0 - config_->refit_decay_rate) * new_leaf_output);
        OMP_LOOP_EX_END();
      }
      OMP_THROW_EX();
      return tree.release();
    }

    BTW, the refit_decay_rate is zero in this case.

In this procedure, you don't need to call SerialTreeLearner::Train; maybe just a function similar to SerialTreeLearner::FitByExistingTree, called from the GBDT class. A rough sketch of the overall idea follows.
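
Roughly, the idea is something like this Python sketch (illustrative only, not the actual implementation; random_split and refit_leaf_outputs are made-up helper names, and the leaf output uses the simplified formula -sum_grad / (sum_hess + lambda_l2), ignoring lambda_l1 and max_delta_step):

import numpy as np

def random_split(num_features, max_bin, rng):
    # one random int for the feature, one random float mapped to a threshold bin
    feature = int(rng.integers(0, num_features))
    threshold_in_bin = int(round(rng.random() * max_bin))
    return feature, threshold_in_bin

def refit_leaf_outputs(leaf_index, grad, hess, num_leaves,
                       lambda_l2=0.0, shrinkage=0.1, eps=1e-15):
    """Given each sample's leaf index under a fixed (random) tree structure,
    set each leaf's output from its gradient/hessian sums, as FitByExistingTree does."""
    outputs = np.zeros(num_leaves)
    for leaf in range(num_leaves):
        mask = leaf_index == leaf
        sum_grad = grad[mask].sum()
        sum_hess = hess[mask].sum() + eps
        outputs[leaf] = -sum_grad / (sum_hess + lambda_l2) * shrinkage
    return outputs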

guolinke (Collaborator) commented on Jan 29, 2020

Also refer to:

void GBDT::RefitTree(const std::vector<std::vector<int>>& tree_leaf_prediction) {
  CHECK(tree_leaf_prediction.size() > 0);
  CHECK(static_cast<size_t>(num_data_) == tree_leaf_prediction.size());
  CHECK(static_cast<size_t>(models_.size()) == tree_leaf_prediction[0].size());
  int num_iterations = static_cast<int>(models_.size() / num_tree_per_iteration_);
  std::vector<int> leaf_pred(num_data_);
  for (int iter = 0; iter < num_iterations; ++iter) {
    Boosting();
    for (int tree_id = 0; tree_id < num_tree_per_iteration_; ++tree_id) {
      int model_index = iter * num_tree_per_iteration_ + tree_id;
      #pragma omp parallel for schedule(static)
      for (int i = 0; i < num_data_; ++i) {
        leaf_pred[i] = tree_leaf_prediction[i][model_index];
        CHECK(leaf_pred[i] < models_[model_index]->num_leaves());
      }
      size_t offset = static_cast<size_t>(tree_id) * num_data_;
      auto grad = gradients_.data() + offset;
      auto hess = hessians_.data() + offset;
      auto new_tree = tree_learner_->FitByExistingTree(models_[model_index].get(), leaf_pred, grad, hess);
      train_score_updater_->AddScore(tree_learner_.get(), new_tree, tree_id);
      models_[model_index].reset(new_tree);
    }
  }
}

You may need to get the leaf index predictions before refitting.

However, leaf index prediction over feature bin values is not implemented; you can refer to Tree::AddPredictionToScore to implement a new one. A small sketch of the idea follows.
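
As an illustration only (the array-based tree layout and names below are a simplification, not the actual Tree class), leaf-index prediction over bin values could look like this, with a negative child value encoding a leaf as ~leaf_index:

import numpy as np

def predict_leaf_index(binned_row, split_feature, threshold_in_bin, left_child, right_child):
    """Walk the tree on one row of binned feature values and return its leaf index."""
    node = 0
    while node >= 0:
        if binned_row[split_feature[node]] <= threshold_in_bin[node]:
            node = left_child[node]
        else:
            node = right_child[node]
    return ~node  # recover the leaf index from the negative encoding

# tiny example: root splits feature 0 at bin 3; its right child splits feature 1 at bin 5
split_feature = [0, 1]
threshold_in_bin = [3, 5]
left_child = [~0, ~1]   # leaf 0, leaf 1
right_child = [1, ~2]   # internal node 1, leaf 2
print(predict_leaf_index(np.array([7, 2]), split_feature, threshold_in_bin, left_child, right_child))  # prints 1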

btrotta (Collaborator, Author) commented on Jan 30, 2020

@guolinke

1. generate a random tree, maybe generate one random int for selected feature, and one random float, for the threshold  ("threshold_in_bin = std::round(random_float*max_bin) "). And then, call `Tree::Split` to grow trees.

If I understand correctly, you're suggesting that for each new node we just randomly choose 1 feature and 1 threshold, and then split. I think this is not the usual definition of extremely randomized trees. For example, see the sklearn docs (https://scikit-learn.org/stable/modules/ensemble.html#forest), which say the following:

As in random forests, a random subset of candidate features is used, but instead of looking for the most discriminative thresholds, thresholds are drawn at random for each candidate feature and the best of these randomly-generated thresholds is picked as the splitting rule.

So it chooses 1 random threshold for each feature, but it evaluates many features then splits on the best one.

I think if we only choose 1 random feature, the algorithm may not fit the data well. Indeed, in the original paper on extremely randomized trees (https://link.springer.com/content/pdf/10.1007/s10994-006-6226-1.pdf), they experiment with varying the size K of the subset of candidate features considered for each split. The results for K=1 are very poor (see Figure 2). (Of course the situation is slightly different since their model is a random forest, not a boosting model like LightGBM. But still, I think the same idea probably applies.)

guolinke (Collaborator) commented on Feb 2, 2020

@btrotta I see.
So ConstructHistogram is also needed; therefore, the total time cost will be almost the same as before, and we can just use the code changes in this PR.

btrotta (Collaborator, Author) commented on Feb 2, 2020

Updated benchmark script:

import pandas as pd
import numpy as np
import lightgbm as lgb
import time
import matplotlib.pyplot as plt
import os
from sklearn.ensemble import ExtraTreesClassifier
from sklearn.metrics import log_loss

df = pd.read_csv(os.path.join('porto', 'train.csv'))
train_cols = [c for c in df.columns if c not in ['id', 'target'] and ('cat' not in c)]
cat_cols = [c for c in train_cols if '_cat' in c]
train_ind = np.random.choice(df.index, len(df) // 2, replace=False)
train_bool = df.index.isin(train_ind)
lgb_train = lgb.Dataset(df.loc[train_bool, train_cols], label=df.loc[train_bool, 'target'])
lgb_test = lgb.Dataset(df.loc[~train_bool, train_cols], label=df.loc[~train_bool, 'target'])
valid_sets = [lgb_train, lgb_test]
valid_names = ['train', 'valid']

# base
params = {'objective': 'binary', 'seed': 0, 'num_leaves': 32, 'learning_rate': 0.01, 'metric': 'binary_logloss'}
t = time.time()
res = {}
est = lgb.train(params, lgb_train, valid_sets=valid_sets, valid_names=valid_names, num_boost_round=1000,
                categorical_feature=cat_cols, evals_result=res)

with open('log.txt', 'a') as f:
    f.write('lightgbm without extra_trees, time {} \n'.format(time.time() - t))
plt.figure()
plt.plot(res['train']['binary_logloss'])
plt.plot(res['valid']['binary_logloss'])

# extra-trees
params['extra_trees'] = True
t = time.time()
res2 = {}
est = lgb.train(params, lgb_train, valid_sets=valid_sets, valid_names=valid_names, num_boost_round=1000,
                categorical_feature=cat_cols, evals_result=res2)
with open('log.txt', 'a') as f:
    f.write('lightgbm with extra_trees, time {} \n'.format(time.time() - t))
plt.plot(res2['train']['binary_logloss'])
plt.plot(res2['valid']['binary_logloss'])

# extra-trees random forest
params['extra_trees'] = True
t = time.time()
res3 = {}
params['boosting'] = 'rf'
params['bagging_fraction'] = 0.8
params['bagging_freq'] = 1
est = lgb.train(params, lgb_train, valid_sets=valid_sets, valid_names=valid_names, num_boost_round=1000,
                categorical_feature=cat_cols, evals_result=res3)
with open('log.txt', 'a') as f:
    f.write('lightgbm random forest with extra_trees, time {} \n'.format(time.time() - t))
plt.plot(res3['train']['binary_logloss'])
plt.plot(res3['valid']['binary_logloss'])

# extra-trees random forest from sklearn
# (note: recent scikit-learn versions require bootstrap=True when max_samples is set)
t = time.time()
est = ExtraTreesClassifier(n_estimators=1000, max_depth=5, max_features=1.0, max_samples=0.8, random_state=0)
est.fit(df.loc[train_bool, train_cols], df.loc[train_bool, 'target'])
prediction = pd.Series(est.predict_proba(df[train_cols])[:, 1], index=df.index)
train_loss = log_loss(df.loc[train_bool, 'target'], prediction.loc[train_bool])
test_loss = log_loss(df.loc[~train_bool, 'target'], prediction.loc[~train_bool])
with open('log.txt', 'a') as f:
    f.write('sklearn with extra_trees, time {} \n'.format(time.time() - t))
plt.scatter([1000], [train_loss])
plt.scatter([1000], [test_loss])

plt.legend(['train', 'test', 'train_extra', 'test_extra', 'train_extra_rf_lgb', 'test_extra_rf_lgb',
            'train_extra_rf_sklearn', 'test_extra_rf_sklearn'])

Output:

lightgbm without extra_trees, time 22.96505045890808 
lightgbm with extra_trees, time 22.61457848548889 
lightgbm random forest with extra_trees, time 27.370033502578735 
sklearn with extra_trees, time 1282.728404045105 

[Image: extra2 — logloss curves for the baseline, extra-trees, LightGBM extra-trees random forest, and scikit-learn extra-trees runs]

In src/treelearner/feature_histogram.hpp:

min_constraint, max_constraint, meta_->monotone_type);
// gain with split is worse than without split
if (current_gain <= min_gain_shift) continue;
if (!meta_->config->extra_trees || t - 1 + offset == rand_threshold) {
Review comment (Collaborator): Maybe change FindBestThresholdSequence to a template method, i.e. template<bool is_rand> FindBestThresholdSequence?

Review comment (Collaborator): With the template, the function is expanded at compile time, so it doesn't affect run-time performance.

guolinke (Collaborator) commented on Feb 8, 2020

@StrikerRUS is this ready to merge?

StrikerRUS (Collaborator):

@guolinke

is this ready to merge?

Yes, I think so. But we should wait for CI fixes.

StrikerRUS merged commit 446b8b6 into microsoft:master on Feb 8, 2020
StrikerRUS (Collaborator):

@btrotta
As usual, thank you very much for the high-quality contribution! We really appreciate your efforts.

btrotta (Collaborator, Author) commented on Feb 8, 2020

@StrikerRUS @guolinke Thanks for your reviews!

The lock bot locked this conversation as resolved and limited it to collaborators on Apr 15, 2020.