This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[MXNET-331] Single machine All Reduce Topology-aware Communication (Updated) #11591

Merged
merged 50 commits on Jul 24, 2018
Changes from all commits
50 commits
9678143
add multiroot all-reduce communication pattern
Jun 4, 2018
d5e51d6
fix bug with UpdateWeight
Jun 4, 2018
0708dbc
fix PCI-E links appearing in weight matrix bug
Jun 4, 2018
5590920
optimization to skip CopyFromTo in ReduceInner gains a bit of throughput
Jun 4, 2018
4f8f58b
remove unnecessary if statement
Jun 5, 2018
908534a
Add tests
Jun 15, 2018
25cbbdc
add more tests, 6 tests left to add
Jun 16, 2018
310ee4d
get rid of some dead code
Jun 16, 2018
9cce8ea
Add comments
Jun 18, 2018
4d2790d
Add randomized tests for backtrack and kernighan-lin
Jun 18, 2018
b5b42bc
Fix Postprocess
Jun 18, 2018
6327ceb
Add switch for first valid tree when num_gpus > 8, and for maximum we…
Jun 18, 2018
8694fe7
Kernighan-Lin seems to find better trees
Jun 18, 2018
c6cd67a
get rid of printfs
Jun 20, 2018
7466c4d
change defaults
Jun 21, 2018
153ec0b
Merge branch 'feature_multirootv9' of https://github.com/ctcyang/incu…
Jun 21, 2018
7c61b6c
Merge branch 'master' of https://github.com/apache/incubator-mxnet in…
Jun 21, 2018
cc935a2
inherit from CommDevice instead of Comm
Jun 22, 2018
ba60aaa
Fix lint errors
Jun 22, 2018
972e9c0
Add Python test using MXNET_KVSTORE_USETREE, fix CMake compilation pr…
Jun 27, 2018
6627dcf
fix lint errors
Jun 27, 2018
4de89a7
better header guard that works for tests
Jun 27, 2018
317c66b
get rid of unused variable warning
Jun 27, 2018
c364fd3
retrigger jenkins
Jun 28, 2018
3241d71
resolve 2 comments
Jun 29, 2018
bd926bf
address comment using Class to do test, get rid of extraneous test, u…
Jul 2, 2018
0e1a704
resolve merge conflicts
Jul 2, 2018
47b0b63
Merge remote-tracking branch 'apache/master' into feature_multirootv9
Jul 5, 2018
781a7fe
Merge remote-tracking branch 'apache/master' into feature_multirootv9…
Jul 6, 2018
a29f284
address comments
Jul 13, 2018
b310ab4
Merge branch 'feature_multirootv9merge2' into feature_multirootv9merge
Jul 13, 2018
abcb10e
Merge remote-tracking branch 'apache/master' into feature_multirootv9…
Jul 13, 2018
24b9c62
Merge remote-tracking branch 'apache/master' into feature_multirootv9…
Jul 20, 2018
7d0da7b
Merge remote-tracking branch 'apache/master' into feature_multirootv9…
Jul 20, 2018
18c1700
fix a few bugs
Jul 21, 2018
c65a620
get rid of printfs
Jul 21, 2018
a70b1b8
Merge branch 'feature_multirootv9merge3' into feature_multirootv9
Jul 21, 2018
263a4cb
Merge remote-tracking branch 'apache/master' into feature_multirootv9
Jul 21, 2018
628ba6e
get rid of print
Jul 21, 2018
b3f3235
Merge branch 'feature_multirootv9' into feature_multirootv9merge
Jul 21, 2018
a0e1366
Comment out test for now
Jul 23, 2018
63fd14e
fix 2 more bugs
Jul 23, 2018
6c0bff8
Merge branch 'feature_multirootv9merge3' into feature_multirootv9merge
Jul 23, 2018
9f5c24a
fix segfault
Jul 23, 2018
9cc24d0
change PrintVector, PrintTopo, PrintMatrix to LOG(INFO) instead of st…
Jul 24, 2018
691d5ac
Merge branch 'feature_multiv9merge4' into feature_multirootv9merge
Jul 24, 2018
67b0db0
Fix code alignment
Jul 24, 2018
c8ebb87
get rid of todo
Jul 24, 2018
5f7da5e
Make changes to env variable names to indicate they are TREE-related
Jul 24, 2018
16b8fb4
Add note saying when ARRAY_BOUND env var takes effect
Jul 24, 2018
26 changes: 26 additions & 0 deletions docs/faq/env_var.md
@@ -83,6 +83,32 @@ export MXNET_GPU_WORKER_NTHREADS=3
- The minimum size of a "big array".
- When the array size is bigger than this threshold, MXNET_KVSTORE_REDUCTION_NTHREADS threads are used for reduction.
- This parameter is also used as a load balancer in kvstore. It controls when to partition a single weight across all the servers. If the size of a single weight is less than MXNET_KVSTORE_BIGARRAY_BOUND, it is sent to a single randomly picked server; otherwise, it is partitioned across all the servers (see the sketch below).
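A rough sketch of this rule, in illustrative Python rather than the actual kvstore code (the function name and parameters are made up for the example):

```python
# Illustrative sketch (not MXNet source) of the load-balancing rule above.
def servers_for_weight(weight_size, num_servers, bigarray_bound):
    """Return how many servers a single weight is spread across."""
    if weight_size < bigarray_bound:   # smaller than MXNET_KVSTORE_BIGARRAY_BOUND
        return 1                       # whole weight goes to one randomly picked server
    return num_servers                 # weight is partitioned across all servers
```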

* MXNET_KVSTORE_USETREE
- Values: 0(false) or 1(true) ```(default=0)```
- If true, MXNet tries to use tree reduction for Push and Pull communication.
- Otherwise, MXNet uses the default Push and Pull implementation.
- [Tree reduction technology](http://www.sysml.cc/doc/178.pdf) has been shown to be faster than the standard ```--kv-store device``` Push/Pull and ```--kv-store nccl``` Push/Pull for small batch sizes.

* MXNET_KVSTORE_LOGTREE
- Values: 0(false) or 1(true) ```(default=0)```
- If true and MXNET_KVSTORE_USETREE is set to 1, MXNet will log the reduction trees that have been generated.

* MXNET_KVSTORE_TREE_ARRAY_BOUND
- Values: Int ```(default=10000000)```
- The minimum size of a "big array".
- When the array size is bigger than this threshold and MXNET_KVSTORE_USETREE is set to 1, multiple trees are used to load balance the big gradient being communicated in order to better saturate link bandwidth.
- Note: This environment variable only takes effect if the tree KVStore is being used (MXNET_KVSTORE_USETREE=1).

* MXNET_KVSTORE_TREE_BACKTRACK
- Values: 0(false) or 1(true) ```(default=0)```
- If true and MXNET_KVSTORE_USETREE is set to 1, MXNet tries to use backtracking to generate the trees required for tree reduction.
- If false and MXNET_KVSTORE_USETREE is set to 1, MXNet tries to use Kernighan-Lin heuristic to generate the trees required for tree reduction.

* MXNET_KVSTORE_TREE_LINK_USAGE_PENALTY
- Values: Float ```(default=0.7)```
- The multiplicative penalty term applied to a link after it has been used once (see the usage sketch below).
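A minimal usage sketch for the tree-reduction settings above, assuming a single machine with several GPUs and the standard `device` KVStore. The key, shape, and GPU count are placeholders, and the environment variables must be set before the KVStore is created so the MXNet backend can read them:

```python
# Sketch: enable topology-aware tree reduction for single-machine multi-GPU training.
# Values are illustrative; set the variables before the KVStore is created.
import os
os.environ['MXNET_KVSTORE_USETREE'] = '1'                    # use tree reduction for Push/Pull
os.environ['MXNET_KVSTORE_LOGTREE'] = '1'                    # log the generated reduction trees
os.environ['MXNET_KVSTORE_TREE_ARRAY_BOUND'] = '10000000'    # threshold for multi-tree load balancing
os.environ['MXNET_KVSTORE_TREE_BACKTRACK'] = '0'             # 0: Kernighan-Lin heuristic, 1: backtracking
os.environ['MXNET_KVSTORE_TREE_LINK_USAGE_PENALTY'] = '0.7'  # penalize reusing the same link

import mxnet as mx

num_gpus = 4                                  # placeholder: assume 4 GPUs on this machine
kv = mx.kv.create('device')                   # tree reduction applies to the device KVStore
key, shape = 0, (1024, 1024)                  # placeholder gradient key and shape
kv.init(key, mx.nd.zeros(shape, ctx=mx.gpu(0)))

# Push one gradient per GPU, then pull the reduced result back to GPU 0.
grads = [mx.nd.ones(shape, ctx=mx.gpu(i)) for i in range(num_gpus)]
kv.push(key, grads)
out = mx.nd.zeros(shape, ctx=mx.gpu(0))
kv.pull(key, out=out)
```

With MXNET_KVSTORE_LOGTREE=1 the generated trees are written through LOG(INFO) (see the PrintTopo/LOG(INFO) commit above).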

* MXNET_ENABLE_GPU_P2P
- Values: 0(false) or 1(true) ```(default=1)```
- If true, MXNet tries to use GPU peer-to-peer communication, if available on your device,
121 changes: 68 additions & 53 deletions src/kvstore/comm.h
@@ -474,6 +474,31 @@ class CommDevice : public Comm {
}
}

const NDArray& ReduceRowSparse(int key, const std::vector<NDArray>& src,
int priority) {
auto& buf = merge_buf_[key];
std::vector<NDArray> reduce(src.size());

const NDArrayStorageType stype = src[0].storage_type();
NDArray& buf_merged = buf.merged_buf(stype);
if (buf.copy_buf.empty()) {
// initialize buffer for copying during reduce
buf.copy_buf.resize(src.size());
for (size_t j = 0; j < src.size(); ++j) {
buf.copy_buf[j] = NDArray(stype, src[0].shape(), buf_merged.ctx(), true, src[0].dtype());
}
}
CHECK(src[0].storage_type() == buf.copy_buf[0].storage_type())
<< "Storage type mismatch detected. " << src[0].storage_type() << "(src) vs. "
<< buf.copy_buf[0].storage_type() << "(buf.copy_buf)";
for (size_t i = 0; i < src.size(); ++i) {
CopyFromTo(src[i], &(buf.copy_buf[i]), priority);
reduce[i] = buf.copy_buf[i];
}
ElementwiseSum(reduce, &buf_merged, priority);
return buf_merged;
}

const NDArray& Reduce(int key, const std::vector<NDArray>& src,
int priority) override {
// when this reduce is called from kvstore_dist, gc is not set
@@ -490,13 +515,14 @@

InitBuffersAndComm(src);
auto& buf = merge_buf_[key];
std::vector<NDArray> reduce(src.size());

const NDArrayStorageType stype = src[0].storage_type();
NDArray& buf_merged = buf.merged_buf(stype);
// normal dense reduce
if (stype == kDefaultStorage) {
CopyFromTo(src[0], &buf_merged, priority);

std::vector<NDArray> reduce(src.size());
reduce[0] = buf_merged;

if (buf.copy_buf.empty()) {
@@ -514,24 +540,11 @@
CopyFromTo(src[i+1], &(buf.copy_buf[i]), priority);
reduce[i+1] = buf.copy_buf[i];
}
ElementwiseSum(reduce, &buf_merged, priority);
} else {
// sparse reduce
if (buf.copy_buf.empty()) {
// initialize buffer for copying during reduce
buf.copy_buf.resize(src.size());
for (size_t j = 0; j < src.size(); ++j) {
buf.copy_buf[j] = NDArray(stype, src[0].shape(), buf_merged.ctx(), true, src[0].dtype());
}
}
CHECK(src[0].storage_type() == buf.copy_buf[0].storage_type())
<< "Storage type mismatch detected. " << src[0].storage_type() << "(src) vs. "
<< buf.copy_buf[0].storage_type() << "(buf.copy_buf)";
for (size_t i = 0; i < src.size(); ++i) {
CopyFromTo(src[i], &(buf.copy_buf[i]), priority);
reduce[i] = buf.copy_buf[i];
}
buf_merged = ReduceRowSparse(key, src, priority);
}
ElementwiseSum(reduce, &buf_merged, priority);
return buf_merged;
}

@@ -659,6 +672,42 @@
}
}

using KeyAttrs = std::tuple<int, TShape, int>;
// try to allocate buff on device evenly
void InitMergeBuffer(const std::vector<Context>& devs) {
std::sort(sorted_key_attrs_.begin(), sorted_key_attrs_.end(), [](
const KeyAttrs& a, const KeyAttrs& b) {
return std::get<1>(a).Size() > std::get<1>(b).Size();
});

std::unordered_map<int, std::pair<Context, size_t>> ctx_info;
for (auto d : devs) {
ctx_info[d.dev_id] = std::make_pair(d, 0);
}

for (size_t i = 0; i < sorted_key_attrs_.size(); ++i) {
const int key = std::get<0>(sorted_key_attrs_[i]);
const TShape& shape = std::get<1>(sorted_key_attrs_[i]);
const int type = std::get<2>(sorted_key_attrs_[i]);
auto& buf = merge_buf_[key];
Context ctx;
size_t min_size = std::numeric_limits<size_t>::max();
for (auto it = ctx_info.begin(); it != ctx_info.end(); ++it) {
size_t size = it->second.second;
if (size <= min_size) {
ctx = it->second.first;
min_size = size;
}
}
// Delayed allocation - as the dense merged buffer might not be used at all if push()
// only sees sparse arrays
bool delay_alloc = true;
buf.merged = NDArray(shape, ctx, delay_alloc, type);
ctx_info[ctx.dev_id].second += shape.Size();
}
inited_ = true;
}

private:
void EnableP2P(const std::vector<Context>& devs) {
#if MXNET_USE_CUDA
@@ -702,43 +751,6 @@
#endif
}

using KeyAttrs = std::tuple<int, TShape, int>;
// try to allocate buff on device evenly
void InitMergeBuffer(const std::vector<Context>& devs) {
std::sort(sorted_key_attrs_.begin(), sorted_key_attrs_.end(), [](
const KeyAttrs& a, const KeyAttrs& b) {
return std::get<1>(a).Size() > std::get<1>(b).Size();
});

std::unordered_map<int, std::pair<Context, size_t>> ctx_info;
for (auto d : devs) {
ctx_info[d.dev_id] = std::make_pair(d, 0);
}

for (size_t i = 0; i < sorted_key_attrs_.size(); ++i) {
const int key = std::get<0>(sorted_key_attrs_[i]);
const TShape& shape = std::get<1>(sorted_key_attrs_[i]);
const int type = std::get<2>(sorted_key_attrs_[i]);
auto& buf = merge_buf_[key];
Context ctx;
size_t min_size = std::numeric_limits<size_t>::max();
for (auto it = ctx_info.begin(); it != ctx_info.end(); ++it) {
size_t size = it->second.second;
if (size <= min_size) {
ctx = it->second.first;
min_size = size;
}
}
// Delayed allocation - as the dense merged buffer might not be used at all if push()
// only sees sparse arrays
bool delay_alloc = true;
buf.merged = NDArray(shape, ctx, delay_alloc, type);
ctx_info[ctx.dev_id].second += shape.Size();
}
inited_ = true;
}

std::vector<KeyAttrs> sorted_key_attrs_;
/// \brief temporal space for pushing and pulling
struct BufferEntry {
/// \brief the dense merged value for reduce and broadcast operations
@@ -773,7 +785,10 @@
NDArray sparse_merged;
};
std::unordered_map<int, BufferEntry> merge_buf_;

public:
bool inited_;
std::vector<KeyAttrs> sorted_key_attrs_;
};

} // namespace kvstore
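For orientation, the InitMergeBuffer routine in this diff sorts keys by array size and then places each merge buffer on the device whose buffers so far hold the fewest elements. A small standalone sketch of that greedy placement, in Python rather than the actual C++, with made-up inputs:

```python
# Standalone illustration of the greedy placement used by InitMergeBuffer:
# largest keys first, each buffer goes to the currently least-loaded device.
def place_merge_buffers(key_sizes, device_ids):
    """key_sizes: {key: number of elements}; returns {key: device_id}."""
    load = {dev: 0 for dev in device_ids}            # elements already assigned per device
    placement = {}
    for key, size in sorted(key_sizes.items(), key=lambda item: item[1], reverse=True):
        dev = min(load, key=load.get)                # pick the least-loaded device
        placement[key] = dev
        load[dev] += size
    return placement

# Example: three gradients spread over two GPUs.
print(place_merge_buffers({0: 4000000, 1: 1000000, 2: 3000000}, [0, 1]))
# {0: 0, 2: 1, 1: 1}  -- the 4M-element buffer lands on GPU 0, the other two on GPU 1
```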