-
Notifications
You must be signed in to change notification settings - Fork 3.8k
/
tree_learner.cpp
55 lines (51 loc) · 2.18 KB
/
tree_learner.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
/*!
* Copyright (c) 2016 Microsoft Corporation. All rights reserved.
* Licensed under the MIT License. See LICENSE file in the project root for license information.
*/
#include <LightGBM/tree_learner.h>
#include "gpu_tree_learner.h"
#include "linear_tree_learner.h"
#include "parallel_tree_learner.h"
#include "serial_tree_learner.h"
#include "cuda/cuda_single_gpu_tree_learner.hpp"
namespace LightGBM {

/*!
 * \brief Factory that picks a concrete tree learner from a (device, learner) pair.
 *
 * Recognized combinations:
 *   - "cpu"  + "serial"                    : LinearTreeLearner when config->linear_tree is set,
 *                                            otherwise SerialTreeLearner
 *   - "cpu"  + "feature"/"data"/"voting"   : the matching parallel learner built on SerialTreeLearner
 *   - "gpu"  + "serial"                    : GPUTreeLearner
 *   - "gpu"  + "feature"/"data"/"voting"   : the matching parallel learner built on GPUTreeLearner
 *   - "cuda" + "serial"                    : CUDASingleGPUTreeLearner, only when config->num_gpu == 1
 *
 * For "cuda", any other learner type or GPU count is reported through Log::Fatal.
 *
 * \param learner_type Learner name; compared against "serial", "feature", "data" and "voting"
 * \param device_type Device name; compared against "cpu", "gpu" and "cuda"
 * \param config Handed to the constructor of the selected learner; linear_tree is read on the
 *        cpu/serial path and num_gpu on the cuda/serial path
 * \param boosting_on_cuda Forwarded to CUDASingleGPUTreeLearner; not used on the other paths
 * \return Learner allocated with new (nothing in this function releases it), or nullptr when
 *         no combination above matched or control comes back from Log::Fatal
 */
TreeLearner* TreeLearner::CreateTreeLearner(const std::string& learner_type, const std::string& device_type,
                                            const Config* config, const bool boosting_on_cuda) {
  const bool want_serial = (learner_type == "serial");
  const bool want_feature = (learner_type == "feature");
  const bool want_data = (learner_type == "data");
  const bool want_voting = (learner_type == "voting");

  if (device_type == "cpu") {
    if (want_serial) {
      // linear_tree only matters for the serial CPU learner.
      if (!config->linear_tree) {
        return new SerialTreeLearner(config);
      }
      return new LinearTreeLearner(config);
    }
    if (want_feature) {
      return new FeatureParallelTreeLearner<SerialTreeLearner>(config);
    }
    if (want_data) {
      return new DataParallelTreeLearner<SerialTreeLearner>(config);
    }
    if (want_voting) {
      return new VotingParallelTreeLearner<SerialTreeLearner>(config);
    }
    return nullptr;
  }

  if (device_type == "gpu") {
    if (want_serial) {
      return new GPUTreeLearner(config);
    }
    if (want_feature) {
      return new FeatureParallelTreeLearner<GPUTreeLearner>(config);
    }
    if (want_data) {
      return new DataParallelTreeLearner<GPUTreeLearner>(config);
    }
    if (want_voting) {
      return new VotingParallelTreeLearner<GPUTreeLearner>(config);
    }
    return nullptr;
  }

  if (device_type == "cuda") {
    if (!want_serial) {
      Log::Fatal("Currently cuda version only supports training on a single machine.");
    } else if (config->num_gpu != 1) {
      Log::Fatal("Currently cuda version only supports training on a single GPU.");
    } else {
      return new CUDASingleGPUTreeLearner(config, boosting_on_cuda);
    }
    return nullptr;
  }

  // Unknown device type.
  return nullptr;
}

}  // namespace LightGBM