From 96781432a2179eba1f9ae2a81bba8e525dcdc7af Mon Sep 17 00:00:00 2001 From: Carl Yang Date: Mon, 4 Jun 2018 03:51:34 +0000 Subject: [PATCH 01/36] add multiroot all-reduce communication pattern --- src/kvstore/comm_tree.h | 653 ++++++++++++++++++++++++ src/kvstore/gpu_topology.h | 971 ++++++++++++++++++++++++++++++++++++ src/kvstore/kvstore_local.h | 8 +- 3 files changed, 1631 insertions(+), 1 deletion(-) create mode 100644 src/kvstore/comm_tree.h create mode 100644 src/kvstore/gpu_topology.h diff --git a/src/kvstore/comm_tree.h b/src/kvstore/comm_tree.h new file mode 100644 index 000000000000..de9082c3a80a --- /dev/null +++ b/src/kvstore/comm_tree.h @@ -0,0 +1,653 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +/** + * Copyright (c) 2015 by Contributors + */ +#ifndef MXNET_KVSTORE_COMM_TREE_H_ +#define MXNET_KVSTORE_COMM_TREE_H_ +#include +#include +#include +#include +#include +#include +#include +#include +#include "mxnet/ndarray.h" +#include "gradient_compression.h" +#include "../ndarray/ndarray_function.h" +#include "../operator/tensor/sparse_retain-inl.h" +#include "./kvstore_utils.h" +#include "./gpu_topology.h" +namespace mxnet { +namespace kvstore { +/** + * \brief an implementation of Comm that performs reduction on device + * directly using tree. + * + * It is faster if the total device-to-device bandwidths is larger than + * device-to-cpu, which is often true for 4 or 8 GPUs. But it uses more device + * memory. + */ +class CommDeviceTree : public Comm { + public: + CommDeviceTree() { + inited_ = false; + bigarray_bound_ = dmlc::GetEnv("MXNET_KVSTORE_BIGARRAY_BOUND", 10000000); + link_usage_penalty_ = dmlc::GetEnv("MXNET_KVSTORE_LINK_USAGE_PENALTY", 0.7); + } + + virtual ~CommDeviceTree() { } + + void Init(int key, const NDArrayStorageType stype, const TShape& shape, + int dtype = mshadow::kFloat32) override { + sorted_key_attrs_.emplace_back(key, shape, dtype); + } + + void InitBuffersAndComm(const std::vector& src) { + if (!inited_) { + for (const auto& a : src) { + devs_.push_back(a.ctx()); + } + GetTopology(); + InitMergeBuffer(); + if (dmlc::GetEnv("MXNET_ENABLE_GPU_P2P", 1)) { + EnableP2P(); + } + } + } + + // src is sliced shape + // copy_buf not sliced + // merged not sliced + const NDArray& ReduceInner(int key, const std::vector& src, int root, + int merged_row, int priority) { + std::vector> reduce(devs_.size()); + + BufferEntry& random_buf = merge_buf_[0][key]; + const NDArrayStorageType stype = random_buf.merged[0].storage_type(); + std::vector& topology = topology_[root]; + NDArray buf_slice; + + if (stype == kDefaultStorage) { + + // Copy everything into buf.merged for each gpu + for (size_t i = 0; i < src.size(); ++i) { + int start = 
scan_[root][depth_ ]; + int end = scan_[root][depth_+1]; + + for (int j = start; j < end; ++j) { + int topo_id = topology[j]; + BufferEntry& buf = merge_buf_[topo_id][key]; + + if ( devs_[topo_id] == src[i].ctx() ) { + //buf.merged = src[i]; + CopyFromTo(src[i], &(buf.merged[merged_row]), priority); + //LOG(WARNING) << "Initial reduce copy from " << src[i].ctx() << " to " << buf.merged[merged_row].ctx(); + } + } + } + //LOG(WARNING) << "Copy to merged"; + + for (int level = depth_; level > 0; --level) { + int start = scan_[root][level ]; + int end = scan_[root][level+1]; + //LOG(WARNING) << "Reduce level: " << level; + + //LOG(WARNING) << "From " << start << " to " << end; + unsigned is_dest = 0; + int dest_id = 0; + for (int j = start; j < end; ++j) { + int topo_id = topology[j]; + dest_id = (is_dest==0) ? topo_id : dest_id; + //LOG(WARNING) << topo_id << " -> " << dest_id; + + BufferEntry& buf_dest = merge_buf_[dest_id][key]; + BufferEntry& buf_from = merge_buf_[topo_id][key]; + //LOG(WARNING) << "Dest shape " << buf_dest.merged[merged_row].ctx() << buf_dest.copy_buf[merged_row][0].ctx(); + //LOG(WARNING) << "From shape " << buf_from.merged[merged_row].ctx() << buf_from.copy_buf[merged_row][0].ctx(); + + if (!is_dest) { + reduce[dest_id].push_back( buf_dest.merged[merged_row] ); + //LOG(WARNING) << topo_id << " == " << dest_id; + } else { + if (dest_id != topo_id) { + //buf_dest.copy_buf[is_dest-1] = NDArray( + // buf_dest.merged.shape(), buf_dest.merged.ctx(), false, + // buf_dest.merged.dtype()); + CopyFromTo(buf_from.merged[merged_row], + &(buf_dest.copy_buf[merged_row][is_dest-1]), + priority); + reduce[dest_id].push_back( buf_dest.copy_buf[merged_row][is_dest-1] ); + //LOG(WARNING) << "Reduce copy from " << buf_from.merged[merged_row].ctx() << " to " << buf_dest.copy_buf[merged_row][is_dest-1].ctx(); + //LOG(WARNING) << topo_id << " != " << dest_id; + } + } + + is_dest = (is_dest == static_cast(kBranch)-1) ? 
+ 0 : is_dest+1; + } + + start = scan_[root][level-1]; + end = scan_[root][level ]; + for (int i = start; i < end; ++i) { + int gpu_id = topology[i]; + //LOG(WARNING) << "Doing reduce on GPU" << gpu_id; + //LOG(WARNING) << "With #elems " << reduce[gpu_id].size(); + + // conditional to detect whether operation must be done + if ( reduce[gpu_id].size() > 1 ) { + BufferEntry& buf = merge_buf_[gpu_id][key]; + //LOG(WARNING) << "reduce input 1 " << reduce[gpu_id][0].ctx(); + //LOG(WARNING) << "reduce input 2 " << reduce[gpu_id][1].ctx(); + //LOG(WARNING) << "buf.mg output " << buf.merged[merged_row].ctx(); + ElementwiseSum(reduce[gpu_id], &(buf.merged[merged_row]), priority); + } + } + + // reset + //LOG(WARNING) << "Clear reduce array"; + for (unsigned i = 0; i < devs_.size(); ++i) { + reduce[i].clear(); + } + } + } else { + LOG(WARNING) << "Only dense input supported for now"; + } + + int topo_id = topology[0]; + BufferEntry& buf = merge_buf_[topo_id][key]; + return buf.merged[merged_row]; + } + + const NDArray& Reduce(int key, const std::vector& src, + int priority) override { + // when this reduce is called from kvstore_dist, gc is not set + // we don't do compression twice in dist_sync_device + if ((gc_ != nullptr) && (gc_->get_type() != CompressionType::kNone)) { + return ReduceCompressed(key, src, priority); + } + + // avoid extra copy for single device, but it may bring problems for + // abnormal usage of kvstore + if (src.size() == 1) { + return src[0]; + } + + InitBuffersAndComm(src); + std::vector> slice(devs_.size()); + std::vector> broadcast_slice(devs_.size()); + std::vector slice_scan(devs_.size()+1); + + //LOG(WARNING) << key << " " << src[0].shape() << " " << src[0].shape().Size(); + int total_size = src[0].shape().Size(); + unsigned first_size = src[0].shape()[0]; + + const NDArrayStorageType stype = src[0].storage_type(); + // normal dense reduce + if (stype == kDefaultStorage) { + if (total_size > bigarray_bound_ && first_size >= devs_.size()) { + // 
Find slice bounds + slice_scan[0] = 0; + int slice_size = (first_size + devs_.size()-1)/devs_.size(); + for (unsigned i = 1; i < devs_.size(); ++i) { + slice_scan[i] = slice_scan[i-1] + slice_size; + //LOG(WARNING) << slice_scan[i]; + } + slice_scan[devs_.size()] = src[0].shape()[0]; + + // row: which slice + // col: which gpu + for (unsigned row = 0; row < devs_.size(); ++row) { + for (unsigned col = 0; col < devs_.size(); ++col) { + BufferEntry& buf = merge_buf_[col][key]; + NDArray curr_slice = src[col].Slice(slice_scan[row], + slice_scan[row+1]); + slice[row].push_back(curr_slice); + broadcast_slice[row].push_back(&(buf.merged[row])); + } + } + + // Do reduce-scatter (multiroot reduce) + // input: slice (src) + // output: buf.merge_buf + for (unsigned i = 0; i < devs_.size(); ++i) { + ReduceInner(key, slice[i], i, i, priority); + } + + for (unsigned i = 0; i < devs_.size(); ++i) { + BroadcastInner(key, *(broadcast_slice[i][i]), broadcast_slice[i], i, i, priority); + } + } else { + int root = 0; + //LOG(WARNING) << "Executing single tree reduce for key " << key << " root " << root; + ReduceInner(key, src, root, 0, priority); + + BufferEntry& buf = merge_buf_[root][key]; + return buf.merged[0]; + } + + // Copy from list of small NDArrays to one big NDArray, which is returned + int gpu_id = 0; + return src[gpu_id]; + } else { + // sparse reduce + LOG(WARNING) << "Only dense input supported for now using multiple trees"; + } + } + + const NDArray& ReduceCompressed(int key, const std::vector& src, + int priority) { + LOG(WARNING) << "ReduceCompressed not supported using multiple trees"; + /*InitBuffersAndComm(src); + auto& buf = merge_buf_[key]; + std::vector reduce(src.size()); + if (buf.copy_buf.empty()) { + // one buf for each context + buf.copy_buf.resize(src.size()); + buf.compressed_recv_buf.resize(src.size()); + buf.compressed_send_buf.resize(src.size()); + buf.residual.resize(src.size()); + + for (size_t i = 0; i < src.size(); ++i) { + buf.copy_buf[i] = 
NDArray(buf.merged.shape(), buf.merged.ctx(), + false, buf.merged.dtype()); + buf.residual[i] = NDArray(buf.merged.shape(), src[i].ctx(), + false, buf.merged.dtype()); + buf.residual[i] = 0; + int64_t small_size = gc_->GetCompressedSize(buf.merged.shape().Size()); + buf.compressed_recv_buf[i] = NDArray(TShape{small_size}, buf.merged.ctx(), + false, buf.merged.dtype()); + buf.compressed_send_buf[i] = NDArray(TShape{small_size}, src[i].ctx(), + false, buf.merged.dtype()); + } + } + + for (size_t i = 0; i < src.size(); ++i) { + // compress before copy + // this is done even if the data is on same context as copy_buf because + // we don't want the training to be biased towards data on this GPU + gc_->Quantize(src[i], &(buf.compressed_send_buf[i]), &(buf.residual[i]), priority); + + if (buf.compressed_send_buf[i].ctx() != buf.compressed_recv_buf[i].ctx()) { + CopyFromTo(buf.compressed_send_buf[i], &(buf.compressed_recv_buf[i]), priority); + } else { + // avoid memory copy when they are on same context + buf.compressed_recv_buf[i] = buf.compressed_send_buf[i]; + } + + gc_->Dequantize(buf.compressed_recv_buf[i], &(buf.copy_buf[i]), priority); + reduce[i] = buf.copy_buf[i]; + } + ElementwiseSum(reduce, &buf.merged); + return buf.merged;*/ + } + + void BroadcastInner(int key, const NDArray& src, + const std::vector dst, int root, int merged_row, + int priority) { + // copy to root of tree + std::vector& topology = topology_[root]; + std::vector temp(devs_.size()); + int gpu_id = topology[0]; + CopyFromTo(src, dst[gpu_id], priority); + temp[gpu_id] = *dst[gpu_id]; + //LOG(WARNING) << "Bcast copy from " << src.ctx() << " to " << buf.merged[merged_row].ctx(); + + for (int level = 1; level <= depth_; ++level) { + int start = scan_[root][level]; + int end = scan_[root][level+1]; + //LOG(WARNING) << "Bcast level: " << level; + + //LOG(WARNING) << "From " << start << " to " << end; + unsigned is_src = 0; + int src_id = 0; + for (int j = start; j < end; ++j) { + int topo_id = 
topology[j]; + src_id = (is_src==0) ? topo_id : src_id; + //LOG(WARNING) << src_id << " -> " << topo_id; + + if (is_src && src_id != topo_id) { + //LOG(WARNING) << src_id << " != " << topo_id; + + CopyFromTo(temp[src_id], dst[topo_id], priority); + + temp[topo_id] = *dst[topo_id]; + + //LOG(WARNING) << "Bcast copy from " << buf_from.merged[merged_row].ctx() << " to " << buf_dest.merged[merged_row].ctx(); + } + + is_src = (is_src == static_cast(kBranch)-1) ? 0 : is_src+1; + } + } + } + + void Broadcast(int key, const NDArray& src, + const std::vector dst, int priority) override { + if (!inited_) { + // copy to a random device first + int dev_id = key % dst.size(); + CopyFromTo(src, dst[dev_id], priority); + for (size_t i = 0; i < dst.size(); ++i) { + if (i != static_cast(dev_id)) { + CopyFromTo(*dst[dev_id], dst[i], priority); + } + } + } else { + int total_size = src.shape().Size(); + unsigned first_size = src.shape()[0]; + if (total_size > bigarray_bound_ && first_size >= devs_.size()) { + std::vector slice_scan(devs_.size()+1); + slice_scan[0] = 0; + int slice_size = (dst[0]->shape()[0]+devs_.size()-1)/devs_.size(); + for (unsigned i = 1; i < devs_.size(); ++i) { + slice_scan[i] = slice_scan[i-1] + slice_size; + //LOG(WARNING) << slice_scan[i]; + } + slice_scan[devs_.size()] = dst[0]->shape()[0]; + + for (unsigned gpu_id = 0; gpu_id < dst.size(); ++gpu_id) { + BufferEntry& buf = merge_buf_[gpu_id][key]; + for (unsigned i = 0; i < devs_.size(); ++i) { + if ( devs_[gpu_id] == dst[gpu_id]->ctx() ) { + NDArray curr_slice = dst[gpu_id]->Slice(slice_scan[i], slice_scan[i+1]); + CopyFromTo(buf.merged[i], &curr_slice, priority); + //LOG(WARNING) << "Bcast return copy from " << buf.merged[i].ctx() << " to " << curr_slice.ctx(); + } + } + } + } else { + int root = 0; + //LOG(WARNING) << "Executing single tree broadcast for key " << key << " root " << root; + BroadcastInner(key, src, dst, root, 0, priority); + } + } + } + + void BroadcastRowSparse(int key, const NDArray& 
src, + const std::vector>& dst, + const int priority) override { + LOG(WARNING) << "BroadcastRowSparse not supported by multiple trees"; + /*CHECK_EQ(src.storage_type(), kRowSparseStorage) + << "BroadcastRowSparse expects row-sparse src NDArray"; + + for (size_t i = 0; i < dst.size(); ++i) { + NDArray* out = dst[i].first; + NDArray row_id = dst[i].second; + CHECK_EQ(out->storage_type(), kRowSparseStorage) + << "BroadcastRowSparse expects row_sparse dst NDArray"; + CHECK_EQ(row_id.ctx(), src.ctx()) + << "row_id and src are expected to be on the same context"; + + // retain according to indices + const bool is_same_ctx = out->ctx() == src.ctx(); + const bool is_diff_var = out->var() != src.var(); + NDArray retained_gpu = (is_same_ctx && is_diff_var) ? *out : + NDArray(kRowSparseStorage, out->shape(), src.ctx(), true, + out->dtype(), out->aux_types()); + if (!is_diff_var) { + common::LogOnce("The output of row_sparse_pull() on key " + std::to_string(key) + + "refers to the same NDArray as the one stored in KVStore." + "Performing row_sparse_pull() with such output is going to change the " + "data stored in KVStore. Incorrect result may be generated " + "next time row_sparse_pull() is called. 
To avoid such an issue," + "consider create a new NDArray buffer to store the output."); + } + + Engine::Get()->PushAsync([=](RunContext rctx, Engine::CallbackOnComplete on_complete) { + const TBlob& indices = row_id.data(); + using namespace mxnet::common; + NDArray temp = retained_gpu; + switch (temp.ctx().dev_mask()) { + case cpu::kDevMask: { + SparseRetainOpForwardRspWrapper(rctx.get_stream(), + src, indices, kWriteTo, &temp); + break; + } +#if MXNET_USE_CUDA + case gpu::kDevMask: { + SparseRetainOpForwardRspWrapper(rctx.get_stream(), + src, indices, kWriteTo, &temp); + // wait for GPU operations to complete + rctx.get_stream()->Wait(); + break; + } +#endif + default: LOG(FATAL) << MXNET_GPU_NOT_ENABLED_ERROR; + } + on_complete(); + }, retained_gpu.ctx(), {src.var(), row_id.var()}, {retained_gpu.var()}, + FnProperty::kNormal, priority, "KVStoreSparseRetain"); + CopyFromTo(retained_gpu, out, priority); + }*/ + } + + private: + void EnableP2P() { +#if MXNET_USE_CUDA + std::vector gpus; + for (const auto& d : devs_) { + if (d.dev_mask() == gpu::kDevMask) { + gpus.push_back(d.dev_id); + } + } + int n = static_cast(gpus.size()); + int enabled = 0; + std::vector p2p(n*n); + for (int i = 0; i < n; ++i) { + cudaSetDevice(gpus[i]); + for (int j = 0; j < n; j++) { + int access; + cudaDeviceCanAccessPeer(&access, gpus[i], gpus[j]); + if (access) { + cudaError_t e = cudaDeviceEnablePeerAccess(gpus[j], 0); + if (e == cudaSuccess || e == cudaErrorPeerAccessAlreadyEnabled) { + ++enabled; + p2p[i*n+j] = 1; + } + } + } + } + if (enabled != n*(n-1)) { + // print warning info if not fully enabled + LOG(WARNING) << "only " << enabled << " out of " + << n*(n-1) << " GPU pairs are enabled direct access. " + << "It may affect the performance. " + << "You can set MXNET_ENABLE_GPU_P2P=0 to turn it off"; + std::string access(n, '.'); + for (int i = 0; i < n; ++i) { + for (int j = 0; j < n; ++j) { + access[j] = p2p[i*n+j] ? 
'v' : '.'; + } + LOG(WARNING) << access; + } + } +#endif + } + + void GetTopology() { +#if MXNET_USE_CUDA + + std::vector link_matrix(devs_.size()*devs_.size()); + std::vector zero_dev_id(devs_.size(), -1); + GetP2PWeight( link_matrix, devs_, zero_dev_id ); + PartitionGraph( link_matrix, devs_.size(), zero_dev_id, topology_, scan_, + link_usage_penalty_ ); + + depth_ = ComputeDepth(devs_.size()); +#endif + } + + using KeyAttrs = std::tuple; + // try to allocate buff on device evenly + void InitMergeBuffer() { + + LOG(WARNING) << "Using Tree"; + + // same as all-reduce, except: + // 1) Allocate copy_buf here instead of in Reduce() + // 2) Force copy_buf to be of kRecvBufferSize + // 3) Do not use greedy assignment; all keys are assigned to each GPU + for (unsigned i = 0; i < devs_.size(); ++i) + merge_buf_.push_back( std::unordered_map() ); + + std::map key_dist; + + for (size_t i = 0; i < sorted_key_attrs_.size(); ++i) { + const int key = std::get<0>(sorted_key_attrs_[i]); + const TShape& shape = std::get<1>(sorted_key_attrs_[i]); + const int type = std::get<2>(sorted_key_attrs_[i]); + + if (key_dist.find(shape.Size()) == key_dist.end()) + key_dist[shape.Size()] = 1; + else + key_dist[shape.Size()]++; + + int start = scan_[0][depth_ ]; + int end = scan_[0][depth_+1]; + //LOG(WARNING) << "From: " << start << " to: " << end; + + // In order to generalize to any number of GPUs, must support 2 things: + // 1) detect whether we are encountering gpu for first time + // first time => allocate memory + // second time => do nothing + // 2) must use either mapping from dev_id to 0, 1, ..., n_gpus or must + // allocate merge_buf_ to be next biggest power of 2 sized or use + // 0, 1, ..., n_gpus (same mapping as dev_id) + // e.g. 
5, 6, 7, 8 must all have merge_buf_.size() == 8 + // -Design decision: use second approach for now + for (int j = start; j < end; ++j) { + int topo_id = topology_[0][j]; + auto& buf = merge_buf_[topo_id][key]; + Context ctx = devs_[topo_id]; + + // buf.merged enforces that we only visit each GPU once + if (buf.merged.empty()) { + TShape shape_copy = shape; + int total_size = shape.Size(); + unsigned first_size = shape[0]; + if (total_size > bigarray_bound_ && first_size >= devs_.size()) { + // Find slice bounds + int slice_size = (first_size+devs_.size()-1)/devs_.size(); + int last_slice = first_size-(devs_.size()-1)*slice_size; + shape_copy[0] = slice_size; + //LOG(WARNING) << "Split Check emptiness of copy buf on GPU" << topo_id << " " << ctx; + buf.merged.resize(devs_.size()); + for (unsigned row = 0; row < devs_.size(); ++row) { + if (row == devs_.size()-1) + shape_copy[0] = last_slice; + //LOG(WARNING) << "Split Allocating merg buf to GPU" << topo_id << " of shape" << shape_copy; + buf.merged[row] = NDArray(shape_copy, ctx, false, type); + buf.copy_buf.push_back(std::vector()); + if (buf.copy_buf[row].empty()) { + buf.copy_buf[row].resize(kBranch-1); + for (size_t col = 0; col < buf.copy_buf[0].size(); ++col) { + //LOG(WARNING) << "Split Allocating copy buf to GPU" << topo_id; + buf.copy_buf[row][col] = NDArray(buf.merged[row].shape(), + buf.merged[row].ctx(), false, + buf.merged[row].dtype()); + } + } + } + } else { + buf.merged.push_back(NDArray(shape, ctx, false, type)); + if (buf.copy_buf.empty()) { + //LOG(WARNING) << "Check emptiness of copy buf on GPU" << topo_id<< " " << ctx; + buf.copy_buf.push_back(std::vector()); + buf.copy_buf[0].resize(kBranch-1); + for (size_t col = 0; col < buf.copy_buf[0].size(); ++col) { + //LOG(WARNING) << "Allocating copy buf to GPU" << topo_id; + buf.copy_buf[0][col] = NDArray(buf.merged[0].shape(), + buf.merged[0].ctx(), false, + buf.merged[0].dtype()); + //LOG(WARNING) << "Success allocating copy buf to GPU" << topo_id; + 
} + } + } + } else { + //LOG(WARNING) << topo_id << " has been allocated already"; + } + } + } + + for (auto it = key_dist.begin(); it != key_dist.end(); ++it) { + LOG(WARNING) << "Size " << it->first << " occurs " << it->second << " times"; + } + inited_ = true; + } + + std::vector sorted_key_attrs_; + /// \brief temporal space for pushing and pulling + struct BufferEntry { + /// \brief the dense merged value for reduce and broadcast operations + std::vector merged; + /// \brief the gpu buffer for copy during reduce operation + std::vector> copy_buf; + /// \brief the residual buffer for gradient compression + std::vector residual; + /// \brief the small buffer for compressed data in sender + std::vector compressed_send_buf; + /// \brief the small buffer for compressed data in receiver + std::vector compressed_recv_buf; + + /// \brief the merged buffer for the given storage type (could be either dense or row_sparse) + inline NDArray& merged_buf(NDArrayStorageType stype) { + if (stype == kDefaultStorage) { + CHECK(merged.size() > 0 && !merged[0].is_none()) << "unintialized merge buffer detected"; + return merged[0]; + } + CHECK(stype == kRowSparseStorage) << "unexpected storage type " << stype; + // check if sparse_merged is initialized + if (sparse_merged.is_none()) { + CHECK(merged.size() > 0 && !merged[0].is_none()); + sparse_merged = NDArray(kRowSparseStorage, merged[0].shape(), + merged[0].ctx(), true, merged[0].dtype()); + } + return sparse_merged; + } + + private: + /// \brief the sparse merged value for reduce and rowsparse broadcast operations + NDArray sparse_merged; + }; + /// \brief intent of merge_buf_ in old comm.h: store key->gpu mapping + /// new intent: for every gpu: store key->memory mapping + std::vector> merge_buf_; + + /// \brief NVLink-connected topology in full binary tree format + std::vector> topology_; + std::vector> scan_; + std::vector devs_; + + /// \brief Highest numbered device + int max_dev_; + int depth_; + int bigarray_bound_; + 
bool inited_; + float link_usage_penalty_; + + /// \brief constant for maximum size of recv buffer per GPU + /// 2: only receive from 1 other GPU + const int kBranch = 2; +}; + +} // namespace kvstore +} // namespace mxnet +#endif // MXNET_KVSTORE_COMM_TREE_H_ diff --git a/src/kvstore/gpu_topology.h b/src/kvstore/gpu_topology.h new file mode 100644 index 000000000000..11d091889014 --- /dev/null +++ b/src/kvstore/gpu_topology.h @@ -0,0 +1,971 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +/** + * Copyright (c) 2015 by Contributors + */ +#ifndef MXNET_KVSTORE_GPU_TOPOLOGY_H_ +#define MXNET_KVSTORE_GPU_TOPOLOGY_H_ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#define MAX_DEPTH 16 + +namespace mxnet { +namespace kvstore { + +void PrettyPrintTopology(const std::vector> topo) { + std::cout << " ={"; + for (unsigned row = 0; row < topo.size(); ++row) { + if (row != 0) + std::cout << " "; + std::cout << "{"; + for (unsigned col = 0; col < topo[0].size(); ++col) { + std::cout << topo[row][col]; + if( col != topo[0].size()-1 ) + std::cout << ", "; + } + std::cout << "}"; + if ( row == topo.size()-1 ) + std::cout << "};"; + else + std::cout << ","; + std::cout << std::endl; + } +} + +void PrintTopo( const std::string& str, const std::vector& topo_row, + std::vector scan_row ) { + std::cout << str << ":\n"; + int depth = scan_row.size()-1; + for (int row = 0; row < depth; ++row) { + int start = scan_row[row]; + int end = scan_row[row+1]; + for (; start +void PrintMatrix( const std::string& str, const std::vector& matrix, + int num_rows, int num_cols ) { + + std::cout << str << ":\n"; + int count = 0; + for (int row = 0; row < num_rows; ++row) { + for (int col = 0; col < num_cols; ++col) { + std::cout << matrix[count++] << " "; + } + std::cout << std::endl; + } +} + +template +void PrintVector( const std::string& str, const std::vector& vec ) { + std::cout << str << ":\n"; + for (unsigned i = 0; i < vec.size(); ++i) + std::cout << vec[i] << " "; + std::cout << std::endl; +} + +// Get relative performance of NVIDIA GPUs +// 0: Self-connection +// 1: PCI-E +// 2: 1 NVLink connection +// 3: 2 NVLink connections +// +// Generate 2 things: +// 1) adjacency matrix with row/col numbering from 0, 1, ..., n_gpu +// 2) mapping from 0, 1, ..., n_gpu to dev_id +// -used to map from 0, 1, ..., n_gpu back to dev_id for topology, which will +// be used by kvstore to do communication +// -used to build adjacency 
matrix with 0, 1, ..., n_gpu numbering +template +void GetP2PWeight( std::vector& matrix, + const std::vector& devs, + std::vector& zero_dev_id, + bool print=false ) { + int num_gpus = devs.size(); + int count = 0; + for (auto d : devs) { + zero_dev_id[count] = d.dev_id; + count++; + } + + cudaDeviceP2PAttr attr; + attr = cudaDevP2PAttrPerformanceRank; + std::vector max(num_gpus, 0); + + for (int row = 0; row < num_gpus; ++row) { + for (int col = 0; col < num_gpus; ++col) { + if (row==col) { + matrix[row*num_gpus+col] = 0; + } else { + int value; + int row_gpu = zero_dev_id[row]; + int col_gpu = zero_dev_id[col]; + cudaDeviceGetP2PAttribute( &value, attr, row_gpu, col_gpu ); + if (value > max[row]) + max[row] = value; + matrix[row*num_gpus+col] = static_cast(value)+1; + } + } + } + + // Check that all GPUs have at least 1 NVLink connection + int max_value = 0; + for (unsigned int i = 0; i < max.size(); ++i) { + if (max[i] > max_value) + max_value = max[i]; + } + + // If all GPUs have at least 1 NVLink connection, then we can use NVLink only + // to communicate instead of going over PCI-E + if (max_value > 0) { + for (auto matrix_value : matrix) { + matrix_value = (matrix_value==1) ? 
0 : matrix_value; + } + } + PrintMatrix( "Weight W", matrix, num_gpus, num_gpus ); +} + +// Dense matrix-vector multiplication +// Assume: matrix is square +template +void gemv( const std::vector& A, + const std::vector& x, + std::vector& y ) { + int nrows = x.size(); + int count = 0; + for (int row=0; row(x[col]); + count++; + } + } +} + +// Element-wise multiplication between 2 dense vectors +// w = w * alpha*u +template +void ewisemult( const std::vector& u, + T alpha, + std::vector& w ) { + int nelem = u.size(); + for (int i=0; i(u[i]); + } +} + +// Element-wise addition between 2 dense vectors +// w = w + alpha*u +template +void ewiseadd( const std::vector& u, + T alpha, + std::vector& w ) { + int nelem = u.size(); + for (int i=0; i(u[i]); + } +} + +// Computes best 2 nodes a,b to swap given objective function: +// g = max_{a \in A, b \in B} D(a) + D(b) - 2*W(a,b) +// +// Optimization: Only need to look at upper triangular since weight matrix is +// symmetric +template +void FindBestMove( const std::vector& W, + const std::vector& P_temp, + const std::vector& D, + const std::unordered_set& used, + int& a, + int& b, + T& g ) { + int nrows = P_temp.size(); + g = 0; + a = -1; + b = -1; + for (int row=0; rowg ) { + g = cost; + a = row; + b = col; + } + } + } +} + +// Performs partition on each existing partition in graph W if partition has +// more than 4 elements in it +// @output: stop returns true if no partitions with >=4 elements found +// returns false otherwise +template +bool KernighanLin( const std::vector& W, + std::vector& P, + int& num_partitions, + std::vector>& cluster_pairs, + std::mt19937& gen ) { + + std::vector histogram(num_partitions, 0); + std::vector P_temp(P.size(), 0); + std::vector P_temp2(P.size(), 0); + std::vector D(P.size(), 0); + std::vector D_temp(P.size(), 0); + + // 0) For every partition, determine if it can be partitioned further. 
+ // To do this, we must do a histogram of each partition: + for (unsigned i=0; i( + static_cast(color),-partition_size)); + + // Do Kernighan-Lin if clustering is necessary + } else { + stop = false; + + // 1) If it has more than 4 elements, we can partition further. + // Assign random balanced partition of it + // -balanced is more important than random, so allocate first half to A + // and rest to B + int first_partition = 0; + int target_partition = partition_size/2; + std::vector cluster_list; + + for (unsigned i = 0; i < P.size(); ++i) { + // Required to shift from [0,1] to {-1,1} + // 1 means vertex i is in Cluster A + // -1 means vertex i is in Cluster B + if (P[i] == static_cast(color)) { + cluster_list.push_back(i); + //std::cout << "Number in Cluster A: " << first_partition << "\n"; + //std::cout << "Put vertex " << i << " in Cluster " << P_temp[i] << "\n"; + } else + P_temp[i] = 0; + } + + // 1b) Shuffle using random generator + std::shuffle(cluster_list.begin(), cluster_list.end(), gen); + //PrintVector("Partition permutation", cluster_list); + for (unsigned i = 0; i < cluster_list.size(); ++i) { + if (first_partition < target_partition) { + int dest = cluster_list[i]; + P_temp[dest] = 1; + first_partition++; + } else { + int dest = cluster_list[i]; + P_temp[dest] = -1; + } + } + //PrintVector("Partition candidate", P_temp); + + // 2) Do iterations of Kernighan-Lin until convergence + T g_max = 0; + int g_k = -1; + unsigned count = 0; + do { + count++; + P_temp2 = P_temp; + + // a) Compute difference between external and internal costs of all + // elements in vector D + gemv( W, P_temp, D ); + //PrintVector( "D pre-ewisemult", D ); + ewisemult( P_temp, -1.f, D ); + //PrintVector( "D post-ewisemult", D ); + + // av and bv are used to hold candidates for moving + // gv stores the score associated with move + std::vector av; + std::vector bv; + std::vector gv; + + std::unordered_set used; + + for (int iter=0; iter 0) { + //std::cout << "Best move found in 
iter " << iter; + //std::cout << ": " << a << " -> " << b << " : " << g << "\n"; + } else { + //std::cout << "No moves found in iter " << iter << std::endl; + g_max = 0; + break; + } + + // c) Store best move to av, bv, gv + av.push_back(a); + bv.push_back(b); + gv.push_back(g); + + // d) Eliminate best move from consideration in vector P_temp + P_temp[a] *= -1; + P_temp[b] *= -1; + used.insert(a); + used.insert(b); + + // e) Update D using P_temp + //PrintVector( "P_temp post-update", P_temp ); + gemv( W, P_temp, D ); + //PrintVector( "D pre-ewisemult", D ); + ewisemult( P_temp, -1.f, D ); + //PrintVector( "D post-ewisemult", D ); + D[a] = 0; + D[b] = 0; + //PrintVector( "D post-ewisemult", D ); + } + + // 3) Find when to stop by doing linear scan through gv + // Recompute score g_max + for (unsigned k = 0; k < gv.size(); ++k) { + if (k > 0) + gv[k] += gv[k-1]; + if (gv[k] > g_max) { + g_max = gv[k]; + g_k = k + 1; + } + } + + // 4) If move is "good", commit moves by updating P_temp and P_temp2 + // Otherwise, rollback changes to P_temp2 + if (g_max > 0) { + for (int i = 0; i < g_k; i++) { + //std::cout << g_max << " " << g_k << " " << i << " " << av.size() << " " << bv.size() << " " << gv.size() << std::endl; + int a = av[i]; + int b = bv[i]; + int temp = P_temp2[a]; + P_temp2[a] = P_temp2[b]; + P_temp2[b] = temp; + + P_temp = P_temp2; + } + } else { + P_temp = P_temp2; + } + } while (g_max > 0 && count <= P.size()); + + // 5) Update P using P_temp + int moves = 0; + for (unsigned i=0; i(static_cast(color), + static_cast(num_partitions))); + + num_partitions++; + } + } + + return stop; +} + +// Returns root of a given color if found in roots +// Returns -1 if it is not found +int GetRoot( const std::vector& P, + int color, + const std::unordered_set& roots ) { + for (auto root : roots) { + if (P[root]==color) + return root; + } + return -1; +} + +// Returns root of a given color if found in roots +// Returns -1 if it is not found +int GetChild( const std::vector& 
P, + int color, + int parent ) { + int size = P.size(); + for (int i = 0; i < size; ++i) { + //std::cout << "Child " << i << ": " << P[i] << std::endl; + if (P[i] == color && i != parent) + return i; + } + return -1; +} + +// Computes best 2 nodes a,b to swap given objective function: +// g = max_{a \in A, b \in B} 2*W(a,b) +// +// Optimization: Only need to look at upper triangular since weight matrix is +// symmetric +template +void FindBestEdge( const std::vector& W, + const std::vector& P, + int parent, + int dest_cluster, + std::vector& b, + T& g ) { + int nrows = P.size(); + int row = parent; + g = 0; + b.push_back(-1); + for (int col=0; col g ) { + b.clear(); + } + if( cost >= g ) { + b.push_back(col); + g = cost; + } + } +} + +// Given a vector of color pairs, appends to binary tree matrix topo +template +int GenerateBinaryTree( std::vector& W, + const std::vector& P, + std::vector>& cluster_pairs, + std::unordered_set& roots, + std::vector& topo_row, + std::vector& scan_row, + std::mt19937& gen ) { + std::unordered_set new_roots; + std::unordered_map new_topo; + int reset = 0; + + for (unsigned i = 0; i < cluster_pairs.size(); ++i) { + //std::cout << "Cluster pair " << i << std::endl; + if (i==0) + scan_row.push_back(topo_row.size()); + //std::cout << "Pair " << i << ": " << cluster_pairs[i].first << " " << cluster_pairs[i].second << std::endl; + int parent, child = -1; + if (cluster_pairs[i].second==-2) { + // Root must exist in first element of pair + int color = cluster_pairs[i].first; + parent = GetRoot( P, color, roots ); + if (parent == -1) return 1; + child = GetChild(P, color, parent); + //std::cout << "Best link (case 1): " << color << ": " << parent << " -> " << child << ": " << std::endl; + } else if (cluster_pairs[i].second==-1) { + int color = cluster_pairs[i].first; + parent = GetRoot( P, color, roots ); + if (parent == -1) return 1; + child = parent; + //std::cout << "Best link (case 2): " << color << ": " << parent << " -> " << child << ": 
" << std::endl; + } else { + // Root must exist in either first or second element of pair + int color = cluster_pairs[i].first; + parent = GetRoot(P, color, roots); + color = (parent==-1) ? cluster_pairs[i].second : color; + parent = (parent==-1) ? GetRoot(P, color, roots) : parent; + + int from_cluster = color; + int dest_cluster = (from_cluster==cluster_pairs[i].first) ? + cluster_pairs[i].second : cluster_pairs[i].first; + + std::vector candidates; + T weight; + FindBestEdge( W, P, parent, dest_cluster, candidates, weight ); + + // If no candidates + if (candidates[0]!=-1) { + /*if (candidates[0] == -1) { + std::cout << "Appending candidates\n"; + candidates.clear(); + for (unsigned col = 0; col < P.size(); ++col) { + if (W[parent*P.size()+col] > 0) + for ( + candidates.push_back(col); + reset = 2; + } + }*/ + // Look for candidate that has not been used at this level or previous + // levels + /*for (unsigned i = 0; i < candidates.size(); ++i) { + bool exit = true; + int last = scan_row.size()-1; + for (auto it = new_topo.begin(); it != new_topo.end(); ++it) { + std::cout << "Testing " << candidates[i] << " " << it->second << std::endl; + if (candidates[i] == it->second) { + std::cout << candidates[i] << " has been encountered before\n"; + exit = false; + break; + } + } + if (exit) { + child = candidates[i]; + std::cout << "GPU " << child << " not found before!\n"; + break; + } + }*/ + std::shuffle(candidates.begin(), candidates.end(), gen); + child = candidates[0]; + } + + if (child == -1) { + //std::cout << "No path to other cluster found from " << parent << " at level " << scan_row.size() << std::endl; + new_roots.insert(parent); + + //child = parent; + return 1; + /*else { + child = parent; + std::cout << "Best link (case 4): " << parent << " -> " << child << ": " << std::endl; + }*/ + } else { + //std::cout << "Best link (case 3): " << parent << " -> " << child << ": " << weight << std::endl; + new_roots.insert(parent); + new_roots.insert(child); + } + } + 
+ new_topo[parent] = child; + int num_rows = P.size(); + } + + int depth = scan_row.size(); + int start = scan_row[depth-2]; + int end = scan_row[depth-1]; + + for (int i = start; i < end; ++i) { + int parent = topo_row[i]; + int child; + + // If not first, check previous level whether or not we are encountering + // this root for the first time in this level of the tree + if (i != start && parent == topo_row[i-1]) + child = parent; + else + child = new_topo[parent]; + topo_row.push_back(parent); + topo_row.push_back(child); + //std::cout << "New pair: " << parent << " " << child << " " << new_topo[parent] << std::endl; + } + + cluster_pairs.clear(); + roots.clear(); + roots = std::move(new_roots); + + return reset; +} + +int ComputeDepth( int n ) { + for (int depth = 0; depth < MAX_DEPTH; ++depth) { + int num = 2 << depth; + if (n <= num) + return depth+1; + } +} + +template +bool IsValid( const std::vector& W, + const std::vector& state, + int num_elements, + int row, + int depth ) { + + for (int i = 0; i < depth; ++i) { + int stride = 1 << i; + for (unsigned j = 0; j+stride < row; j += 2*stride) { + int from = state[j]; + int dest = state[j+stride]; + //std::cout << "Comparing " << j << " and " << j+stride << " in row " << row << std::endl; + if (W[from*num_elements + dest] <= static_cast(0) && from != dest) { + //std::cout << "Not valid: no edge from " << from << " to " << dest << std::endl; + return false; + } + } + } + + std::unordered_set found; + std::vector found_vec(num_elements,0); + for (auto val : state) { + if (val == -1) + continue; + if (val < num_elements) { + if (found.find(val) == found.end()) { + found.insert(val); + found_vec[val] = 1; + } + } else { + //std::cout << "Not valid: " << val << " exceeds # of GPUs\n"; + return false; + } + } + int modifier = (1 << depth) - num_elements; + int num_found= found.size(); + + if (row < num_elements) { + if (num_found > row || num_found < row - modifier) { + //std::cout << "Not valid: " << found.size() 
<< " rows found but expected between " << row << " and " << row - modifier << std::endl; + return false; + } + } else if (row == state.size()) + for (int i = 0; i < num_elements; ++i) + if (found_vec[i] == 0) + return false; + + return true; +} + +void Postprocess( std::vector& result, int num_elements, int depth) { + + std::vector histogram(num_elements, 0); + for (unsigned i = 0; i < result.size(); ++i) { + int val = result[i]; + histogram[val]++; + } + + for (int i = 0; i == 0; ++i) { + int stride = 1 << i; + for (int j = result.size()-1; j-stride >= 0; j -= 2*stride) { + //std::cout << "Comparing " << j << " and " << j-stride << std::endl; + int from = result[j]; + int dest = result[j-stride]; + if (histogram[from] > 1 && from != dest) { + //PrintVector("Old histogram", histogram); + //std::cout << "Swapping from " << from << " to " << dest << " on indices " << j << " and " << j-stride << std::endl; + result[j] = dest; + histogram[from]--; + //PrintVector("New histogram", histogram); + //PrintVector("New result", result); + } + } + } +} + +template +T GetTreeWeight( const std::vector& W, + const std::vector& result, + int num_elements, + int depth) { + T weight = 0.f; + std::unordered_set links_used; + + for (int i = 0; i < depth; ++i) { + int stride = 1 << i; + std::vector nodes_used(num_elements, false); + for (unsigned j = 0; j+stride < result.size(); j += 2*stride) { + int from = result[j]; + int dest = result[j+stride]; + if (from != dest) { + weight += W[from*num_elements+dest]; + + // Penalize: (1) use of redundant edges in a single tree + // (2) repeated use of a GPU in a single tree at the same + // level above the leaf level + if (links_used.find(from*num_elements+dest) != links_used.end()) { + weight -= 100; + //std::cout << "Penalty 1: " << from << " to " << dest << std::endl; + } + links_used.insert(from*num_elements+dest); + links_used.insert(dest*num_elements+from); + //std::cout << "Not valid: no edge from " << from << " to " << dest << 
std::endl; + } + + nodes_used[from] = true; + if (i > 0 && nodes_used[dest]) { + weight -= 10; + //std::cout << "Penalty 2: " << from << " and " << dest << " seen before\n"; + } + nodes_used[dest] = true; + } + } + + return weight; +} + +void FormTopology( const std::vector& result, + std::vector& topo_row, + std::vector& scan_row, + int depth ) { + scan_row.push_back(topo_row.size()); + for (int i = depth; i > 0; --i) { + int stride = 1 << i; + for (unsigned j = 0; j < result.size(); j += stride) { + int from = result[j]; + topo_row.push_back(from); + } + scan_row.push_back(topo_row.size()); + } + + // Insert at the end, result vector + topo_row.insert(topo_row.end(), result.begin(), result.end()); +} + +template +void Backtrack( const std::vector& W, + std::vector& state, + std::vector& best_result, + T& best_result_weight, + int row, + int num_elements, + int depth ) { + if (row == state.size()) { + std::vector result = state; + Postprocess(result, num_elements, depth); + T weight = GetTreeWeight(W, result, num_elements, depth); + if (weight > best_result_weight) { + std::swap(best_result_weight, weight); + best_result = result; + //std::cout << "New best weight: " << best_result_weight << " > " << weight << std::endl; + //PrintVector("New best", result); + } else { + //std::cout << "Not best weight: " << weight << " < " << best_result_weight << std::endl; + //PrintVector("Not best", result); + } + return; + } + + for (unsigned j = 0; j < num_elements; ++j) { + state[row] = j; + //PrintVector("Trying state", state); + if (IsValid(W, state, num_elements, row+1, depth)) { + Backtrack( W, state, best_result, best_result_weight, row+1, num_elements, + depth ); + state[row] = -1; + } else + state[row] = -1; + } +} + +template +void UpdateWeight( std::vector& W, + const std::vector& topo_row, + int num_elements, + float alpha ) { + for (unsigned i = 0; i < topo_row.size() - 1; i += 2) { + unsigned parent = topo_row[i]; + unsigned child = topo_row[i+1]; + if (parent >= 
num_elements*num_elements || + child >= num_elements*num_elements) + std::cout << "W array out of bounds\n"; + else if (parent != child) { + W[parent*num_elements+child] *= alpha; + W[child*num_elements+parent] *= alpha; + } + } +} + +// Do brute-force backtracking approach if Kernighan-Lin fails to find a binary +// tree of height Log P +// Metrics: +// 1) minimize depth +// 2) maximize edge weight +template +void BacktrackingGenerateBinaryTree( std::vector& W, + int num_elements, + int root, + std::vector& topo_row, + std::vector& scan_row ) { + + // Clear before starting + topo_row.clear(); + scan_row.clear(); + + // Compute depth + // 5: 3 + // 6: 3 + // 7: 3 + // 8: 3 + // 9: 4 + int depth = ComputeDepth(num_elements); + int depth_leaves = 1< state(depth_leaves, -1); + std::vector result(depth_leaves, -1); + T result_weight = std::numeric_limits::lowest(); + + // Place root and try all combinations + state[0] = root; + //PrintVector("state", state); + + Backtrack( W, state, result, result_weight, 1, num_elements, depth ); + FormTopology( result, topo_row, scan_row, depth ); +} + +template +void PartitionGraphFromRoot( std::vector& W, + int num_elements, + int root, + std::vector>& topo, + std::vector>& scan, + float alpha ) { + + int num_partitions = 1; + + // Initialize partition array to indicate which partition each element belongs + // to beginning with 0 + std::vector P(num_elements, 0); + + // Initialize vector of pairs that will tell us edges between what 2 clusters + // we should be looking to build the tree from + std::vector> cluster_pairs; + + // Initialize vector of roots that will tell us edges between + std::unordered_set roots; + roots.insert(root); + + // Will be used to obtain a seed for the random number engine + // RNG: Standard mersenne_twister_engine seeded with rd() + // -use 0 for testing (TODO: remove this) + std::random_device rd; + std::mt19937 gen(1); + //std::mt19937 gen(rd()); + + // Temporary variables for rewinding + std::vector 
P_temp; + int num_partitions_temp; + std::unordered_set roots_temp; + std::vector topo_temp; + std::vector scan_temp; + + // Determine number of partition levels + // If first partition, determine root of maximal spanning tree + bool stop = false; + int reset = 1; + int level = 0; + + bool backtrack = dmlc::GetEnv("MXNET_KVSTORE_BACKTRACK", 1); + while (!backtrack && (!stop || reset)) { + if (reset == 1) { + cluster_pairs.clear(); + P_temp = P; + num_partitions_temp = num_partitions; + roots_temp = roots; + topo_temp = topo[root]; + scan_temp = scan[root]; + } + + // Run Kernighan-Lin to generate partition + stop = KernighanLin(W, P_temp, num_partitions_temp, cluster_pairs, gen); + //PrintVector("New partition", P_temp); + + // Use partitions found and a given root to find best inter-cluster edge for // each pair of clusters, and returns them as roots of next cluster + // If reset is true, then rewind back to previous clustering + reset = GenerateBinaryTree(W, P_temp, cluster_pairs, roots_temp, + topo_temp, scan_temp, gen); + + if (reset) + level++; + if (level > 10) break; + } + + if (reset == 1) { + if (!backtrack) + std::cout << "No valid binary tree found from root " << root << ", try backtracking\n"; + //std::cout << "Trying backtracking\n"; + BacktrackingGenerateBinaryTree(W, num_elements, root, topo[root], + scan[root]); + } else { + topo[root] = topo_temp; + scan[root] = scan_temp; + } + UpdateWeight( W, topo[root], num_elements, alpha ); +} + +// Generalization from num_elements to list of devices done using zero_dev_id +// mapping, which gets us from 0, 1, ..., n_gpus to dev_id +template +void PartitionGraph( const std::vector& W, + int num_elements, + const std::vector& zero_dev_id, + std::vector>& topo, + std::vector>& scan, + float alpha=0.7 ) { + std::vector W_copy = W; + + topo.clear(); + scan.clear(); + for (int i = 0; i < num_elements; ++i) { + topo.push_back(std::vector()); + scan.push_back(std::vector()); + topo[i].push_back(i); + 
scan[i].push_back(0); + PartitionGraphFromRoot(W_copy, num_elements, i, topo, scan, alpha); + scan[i].push_back(topo[i].size()); + } + + // Note: must sum up adj matrix to show link usage before we readjust topo + // from 0, 1, ..., n_gpus format to dev_id format, which will cause segfault + std::vector adj(W.size(), 0); + for (int row = 0; row < num_elements; ++row) { + for (unsigned col = 1; col < topo[0].size(); col += 2) { + int from = std::min(topo[row][col], topo[row][col+1]); + int dest = std::max(topo[row][col], topo[row][col+1]); + if (from != dest) { + adj[from*num_elements+dest] += 1; + adj[dest*num_elements+from] += 1; + } + } + } + + std::vector> topo_temp(num_elements, + std::vector()); + for (int i = 0; i < num_elements; ++i) { + for (unsigned j = 0; j < topo[i].size(); ++j) { + int val = topo[i][j]; + topo_temp[i].push_back( zero_dev_id[val] ); + } + PrintTopo("Topo_temp", topo_temp[i], scan[i]); + } + + PrintMatrix("Links", adj, num_elements, num_elements); + bool backtrack = dmlc::GetEnv("MXNET_KVSTORE_BACKTRACK", 1); + if (backtrack) + LOG(WARNING) << "Using Backtracking to generate trees"; + else + LOG(WARNING) << "Using Kernighan-Lin to generate trees"; +} + +} // namespace kvstore +} // namespace mxnet +#endif // MXNET_KVSTORE_GPU_TOPOLOGY_H diff --git a/src/kvstore/kvstore_local.h b/src/kvstore/kvstore_local.h index 38ecf121dfeb..11763db03c61 100644 --- a/src/kvstore/kvstore_local.h +++ b/src/kvstore/kvstore_local.h @@ -34,6 +34,7 @@ #include #include #include "./comm.h" +#include "./comm_tree.h" #include "./kvstore_utils.h" #include "../ndarray/ndarray_function.h" @@ -56,7 +57,12 @@ class KVStoreLocal : public KVStore { */ explicit KVStoreLocal(bool use_device_comm) : KVStore() { if (use_device_comm) { - comm_ = new CommDevice(); + bool tree = dmlc::GetEnv("MXNET_USE_TREE", 1); + if (tree) { + comm_ = new CommDeviceTree(); + } else { + comm_ = new CommDevice(); + } } else { comm_ = new CommCPU(); } From 
d5e51d65a13d0c74e58697f976d09f2c925cb93b Mon Sep 17 00:00:00 2001 From: Carl Yang Date: Mon, 4 Jun 2018 04:33:05 +0000 Subject: [PATCH 02/36] fix bug with UpdateWeight --- src/kvstore/gpu_topology.h | 2 +- src/kvstore/kvstore_local.h | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/kvstore/gpu_topology.h b/src/kvstore/gpu_topology.h index 11d091889014..06a6afa3e8e2 100644 --- a/src/kvstore/gpu_topology.h +++ b/src/kvstore/gpu_topology.h @@ -780,7 +780,7 @@ void UpdateWeight( std::vector& W, const std::vector& topo_row, int num_elements, float alpha ) { - for (unsigned i = 0; i < topo_row.size() - 1; i += 2) { + for (unsigned i = 1; i < topo_row.size() - 1; i += 2) { unsigned parent = topo_row[i]; unsigned child = topo_row[i+1]; if (parent >= num_elements*num_elements || diff --git a/src/kvstore/kvstore_local.h b/src/kvstore/kvstore_local.h index 11763db03c61..df622de5f36c 100644 --- a/src/kvstore/kvstore_local.h +++ b/src/kvstore/kvstore_local.h @@ -57,7 +57,7 @@ class KVStoreLocal : public KVStore { */ explicit KVStoreLocal(bool use_device_comm) : KVStore() { if (use_device_comm) { - bool tree = dmlc::GetEnv("MXNET_USE_TREE", 1); + bool tree = dmlc::GetEnv("MXNET_KVSTORE_USETREE", 1); if (tree) { comm_ = new CommDeviceTree(); } else { From 0708dbcd45283dde890370593472eed1ed8495b4 Mon Sep 17 00:00:00 2001 From: Carl Yang Date: Mon, 4 Jun 2018 04:38:15 +0000 Subject: [PATCH 03/36] fix PCI-E links appearing in weight matrix bug --- src/kvstore/gpu_topology.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/kvstore/gpu_topology.h b/src/kvstore/gpu_topology.h index 06a6afa3e8e2..ff59f12510d8 100644 --- a/src/kvstore/gpu_topology.h +++ b/src/kvstore/gpu_topology.h @@ -152,7 +152,7 @@ void GetP2PWeight( std::vector& matrix, // If all GPUs have at least 1 NVLink connection, then we can use NVLink only // to communicate instead of going over PCI-E if (max_value > 0) { - for (auto matrix_value : matrix) { + for (auto& 
matrix_value : matrix) { matrix_value = (matrix_value==1) ? 0 : matrix_value; } } From 55909204234fc232fffa18b28725f835947bb4bc Mon Sep 17 00:00:00 2001 From: Carl Yang Date: Mon, 4 Jun 2018 22:27:25 +0000 Subject: [PATCH 04/36] optimization to skip CopyFromTo in ReduceInner gains a bit of throughput --- src/kvstore/comm_tree.h | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/src/kvstore/comm_tree.h b/src/kvstore/comm_tree.h index de9082c3a80a..75e34c0ba9b8 100644 --- a/src/kvstore/comm_tree.h +++ b/src/kvstore/comm_tree.h @@ -52,6 +52,7 @@ class CommDeviceTree : public Comm { inited_ = false; bigarray_bound_ = dmlc::GetEnv("MXNET_KVSTORE_BIGARRAY_BOUND", 10000000); link_usage_penalty_ = dmlc::GetEnv("MXNET_KVSTORE_LINK_USAGE_PENALTY", 0.7); + stream_ = dmlc::GetEnv("MXNET_KVSTORE_STREAM", 1); } virtual ~CommDeviceTree() { } @@ -169,7 +170,7 @@ class CommDeviceTree : public Comm { } } } else { - LOG(WARNING) << "Only dense input supported for now"; + //LOG(WARNING) << "Only dense input supported for now"; } int topo_id = topology[0]; @@ -301,13 +302,14 @@ class CommDeviceTree : public Comm { } void BroadcastInner(int key, const NDArray& src, - const std::vector dst, int root, int merged_row, + const std::vector& dst, int root, int merged_row, int priority) { // copy to root of tree std::vector& topology = topology_[root]; std::vector temp(devs_.size()); int gpu_id = topology[0]; - CopyFromTo(src, dst[gpu_id], priority); + if (merged_row == -1) + CopyFromTo(src, dst[gpu_id], priority); temp[gpu_id] = *dst[gpu_id]; //LOG(WARNING) << "Bcast copy from " << src.ctx() << " to " << buf.merged[merged_row].ctx(); @@ -376,7 +378,7 @@ class CommDeviceTree : public Comm { } else { int root = 0; //LOG(WARNING) << "Executing single tree broadcast for key " << key << " root " << root; - BroadcastInner(key, src, dst, root, 0, priority); + BroadcastInner(key, src, dst, root, -1, priority); } } } @@ -641,6 +643,7 @@ class CommDeviceTree : public Comm { 
int depth_; int bigarray_bound_; bool inited_; + bool stream_; float link_usage_penalty_; /// \brief constant for maximum size of recv buffer per GPU From 4f8f58b0ba4355d7b1ae5985521486aaa6ed7ab6 Mon Sep 17 00:00:00 2001 From: Carl Yang Date: Tue, 5 Jun 2018 17:15:28 +0000 Subject: [PATCH 05/36] remove unnecessary if statement --- src/kvstore/comm_tree.h | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/kvstore/comm_tree.h b/src/kvstore/comm_tree.h index 75e34c0ba9b8..775fcff51116 100644 --- a/src/kvstore/comm_tree.h +++ b/src/kvstore/comm_tree.h @@ -368,11 +368,9 @@ class CommDeviceTree : public Comm { for (unsigned gpu_id = 0; gpu_id < dst.size(); ++gpu_id) { BufferEntry& buf = merge_buf_[gpu_id][key]; for (unsigned i = 0; i < devs_.size(); ++i) { - if ( devs_[gpu_id] == dst[gpu_id]->ctx() ) { - NDArray curr_slice = dst[gpu_id]->Slice(slice_scan[i], slice_scan[i+1]); - CopyFromTo(buf.merged[i], &curr_slice, priority); + NDArray curr_slice = dst[gpu_id]->Slice(slice_scan[i], slice_scan[i+1]); + CopyFromTo(buf.merged[i], &curr_slice, priority); //LOG(WARNING) << "Bcast return copy from " << buf.merged[i].ctx() << " to " << curr_slice.ctx(); - } } } } else { From 908534a43f24f1454df0b98a0c5994aded5a04c5 Mon Sep 17 00:00:00 2001 From: Carl Yang Date: Fri, 15 Jun 2018 21:45:52 +0000 Subject: [PATCH 06/36] Add tests --- src/kvstore/comm_tree.h | 14 +- src/kvstore/gpu_topology.h | 17 ++- tests/cpp/kvstore/gpu_topology_test.cc | 204 +++++++++++++++++++++++++ 3 files changed, 221 insertions(+), 14 deletions(-) create mode 100644 tests/cpp/kvstore/gpu_topology_test.cc diff --git a/src/kvstore/comm_tree.h b/src/kvstore/comm_tree.h index 775fcff51116..f215139842cd 100644 --- a/src/kvstore/comm_tree.h +++ b/src/kvstore/comm_tree.h @@ -18,7 +18,7 @@ */ /** - * Copyright (c) 2015 by Contributors + * Copyright (c) 2018 by Contributors */ #ifndef MXNET_KVSTORE_COMM_TREE_H_ #define MXNET_KVSTORE_COMM_TREE_H_ @@ -67,7 +67,7 @@ class CommDeviceTree : public 
Comm { for (const auto& a : src) { devs_.push_back(a.ctx()); } - GetTopology(); + QueryTopology(); InitMergeBuffer(); if (dmlc::GetEnv("MXNET_ENABLE_GPU_P2P", 1)) { EnableP2P(); @@ -368,9 +368,11 @@ class CommDeviceTree : public Comm { for (unsigned gpu_id = 0; gpu_id < dst.size(); ++gpu_id) { BufferEntry& buf = merge_buf_[gpu_id][key]; for (unsigned i = 0; i < devs_.size(); ++i) { - NDArray curr_slice = dst[gpu_id]->Slice(slice_scan[i], slice_scan[i+1]); - CopyFromTo(buf.merged[i], &curr_slice, priority); + if ( devs_[gpu_id] == dst[gpu_id]->ctx() ) { + NDArray curr_slice = dst[gpu_id]->Slice(slice_scan[i], slice_scan[i+1]); + CopyFromTo(buf.merged[i], &curr_slice, priority); //LOG(WARNING) << "Bcast return copy from " << buf.merged[i].ctx() << " to " << curr_slice.ctx(); + } } } } else { @@ -482,7 +484,7 @@ class CommDeviceTree : public Comm { #endif } - void GetTopology() { + void QueryTopology() { #if MXNET_USE_CUDA std::vector link_matrix(devs_.size()*devs_.size()); @@ -491,7 +493,7 @@ class CommDeviceTree : public Comm { PartitionGraph( link_matrix, devs_.size(), zero_dev_id, topology_, scan_, link_usage_penalty_ ); - depth_ = ComputeDepth(devs_.size()); + depth_ = ComputeDepth(devs_.size()); #endif } diff --git a/src/kvstore/gpu_topology.h b/src/kvstore/gpu_topology.h index ff59f12510d8..e4e91a6f7368 100644 --- a/src/kvstore/gpu_topology.h +++ b/src/kvstore/gpu_topology.h @@ -566,7 +566,6 @@ int GenerateBinaryTree( std::vector& W, } new_topo[parent] = child; - int num_rows = P.size(); } int depth = scan_row.size(); @@ -579,7 +578,7 @@ int GenerateBinaryTree( std::vector& W, // If not first, check previous level whether or not we are encountering // this root for the first time in this level of the tree - if (i != start && parent == topo_row[i-1]) + if (i != start && parent == static_cast(topo_row[i-1])) child = parent; else child = new_topo[parent]; @@ -601,6 +600,7 @@ int ComputeDepth( int n ) { if (n <= num) return depth+1; } + return 0; } template @@ 
-612,7 +612,7 @@ bool IsValid( const std::vector& W, for (int i = 0; i < depth; ++i) { int stride = 1 << i; - for (unsigned j = 0; j+stride < row; j += 2*stride) { + for (int j = 0; j+stride < row; j += 2*stride) { int from = state[j]; int dest = state[j+stride]; //std::cout << "Comparing " << j << " and " << j+stride << " in row " << row << std::endl; @@ -646,7 +646,7 @@ bool IsValid( const std::vector& W, //std::cout << "Not valid: " << found.size() << " rows found but expected between " << row << " and " << row - modifier << std::endl; return false; } - } else if (row == state.size()) + } else if (row == static_cast(state.size())) for (int i = 0; i < num_elements; ++i) if (found_vec[i] == 0) return false; @@ -681,7 +681,7 @@ void Postprocess( std::vector& result, int num_elements, int depth) { } template -T GetTreeWeight( const std::vector& W, +T ComputeTreeWeight( const std::vector& W, const std::vector& result, int num_elements, int depth) { @@ -737,6 +737,7 @@ void FormTopology( const std::vector& result, // Insert at the end, result vector topo_row.insert(topo_row.end(), result.begin(), result.end()); + scan_row.push_back(topo_row.size()); } template @@ -747,10 +748,10 @@ void Backtrack( const std::vector& W, int row, int num_elements, int depth ) { - if (row == state.size()) { + if (row == static_cast(state.size())) { std::vector result = state; Postprocess(result, num_elements, depth); - T weight = GetTreeWeight(W, result, num_elements, depth); + T weight = ComputeTreeWeight(W, result, num_elements, depth); if (weight > best_result_weight) { std::swap(best_result_weight, weight); best_result = result; @@ -763,7 +764,7 @@ void Backtrack( const std::vector& W, return; } - for (unsigned j = 0; j < num_elements; ++j) { + for (int j = 0; j < num_elements; ++j) { state[row] = j; //PrintVector("Trying state", state); if (IsValid(W, state, num_elements, row+1, depth)) { diff --git a/tests/cpp/kvstore/gpu_topology_test.cc b/tests/cpp/kvstore/gpu_topology_test.cc 
new file mode 100644 index 000000000000..6331f0997941 --- /dev/null +++ b/tests/cpp/kvstore/gpu_topology_test.cc @@ -0,0 +1,204 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2018 by Contributors + * \file gpu_topology_test.cc + * \brief gpu topology tests +*/ + +#include +#include +#include +#include "../src/kvstore/gpu_topology.h" + +#define NUM_GPUS 8 + +// Permutes matrix W using permutation vector P and stores output in matrix A +// Assumption: W is square and symmetric +void PermuteMatrix( const std::vector& W, + const std::vector& P, + std::vector& A ) { + int nrows = P.size(); + std::vector temp(nrows*nrows,0); + + int count = 0; + for (int row=0; row state0 = {3, 2, 1, 5, 0, 0, 4, 6}; + std::vector topo0; + std::vector scan0; + std::vector correct0= {3, 3, 0, 3, 1, 0, 4, 3, 2, 1, 5, 0, 0, 4, 6}; + std::vector correct_scan0 = {0, 1, 3, 7, 15}; + mxnet::kvstore::FormTopology(state0, topo0, scan0, 3); + ASSERT_EQ(topo0.size(), correct0.size()); + for (unsigned i = 0; i < correct0.size(); ++i) + ASSERT_EQ(static_cast(topo0[i]), correct0[i]); + ASSERT_EQ(scan0.size(), correct_scan0.size()); + for (unsigned i = 0; i < correct_scan0.size(); ++i) + ASSERT_EQ(scan0[i], 
correct_scan0[i]); + + std::vector state1 = {3, 2, 0, 4, 1, 1, 5, 6}; + std::vector topo1; + std::vector scan1; + std::vector correct1= {3, 3, 1, 3, 0, 1, 5, 3, 2, 0, 4, 1, 1, 5, 6}; + std::vector correct_scan1 = {0, 1, 3, 7, 15}; + mxnet::kvstore::FormTopology(state1, topo1, scan1, 3); + ASSERT_EQ(topo1.size(), correct1.size()); + for (unsigned i = 0; i < correct1.size(); ++i) + ASSERT_EQ(topo1[i], correct1[i]); + ASSERT_EQ(scan1.size(), correct_scan1.size()); + for (unsigned i = 0; i < correct_scan1.size(); ++i) + ASSERT_EQ(scan1[i], correct_scan1[i]); +} + +TEST(ComputeTreeWeightTest, TestComputeTreeWeight) { + + std::vector W = {0, 2, 2, 3, 3, 0, 0, + 2, 0, 3, 2, 0, 3, 0, + 2, 3, 0, 3, 0, 0, 2, + 3, 2, 3, 0, 0, 0, 0, + 3, 0, 0, 0, 0, 2, 2, + 0, 3, 0, 0, 2, 0, 3, + 0, 0, 2, 0, 2, 3, 0}; + + std::vector state0 = {3, 2, 1, 5, 0, 0, 4, 6}; + ASSERT_EQ(mxnet::kvstore::ComputeTreeWeight(W, state0, 7, 3), 16); + + std::vector state1 = {3, 2, 0, 4, 1, 1, 5, 6}; + ASSERT_EQ(mxnet::kvstore::ComputeTreeWeight(W, state1, 7, 3), 17); +} + +TEST(PostprocessTest, TestPostprocess) { + std::vector result0 = {3, 0, 0, 4, 1, 2, 5, 6}; + std::vector correct0= {3, 3, 0, 4, 1, 2, 5, 6}; + mxnet::kvstore::Postprocess( result0, 7, 3 ); + for (unsigned i = 0; i < correct0.size(); ++i) + ASSERT_EQ(result0[i], correct0[i]); + + std::vector result1 = {2, 0, 0, 4, 1, 3, 5, 1}; + std::vector correct1= {2, 2, 0, 4, 1, 3, 5, 5}; + mxnet::kvstore::Postprocess( result1, 6, 3 ); + for (unsigned i = 0; i < correct1.size(); ++i) + ASSERT_EQ(result1[i], correct1[i]); + + std::vector result2 = {5, 4, 1, 3, 1, 0, 2, 0}; + std::vector correct2= {5, 4, 1, 3, 1, 0, 2, 2}; + mxnet::kvstore::Postprocess( result2, 6, 3 ); + for (unsigned i = 0; i < correct2.size(); ++i) + ASSERT_EQ(result2[i], correct2[i]); +} + +TEST(ComputeDepthTest, TestDepth) { + ASSERT_EQ(mxnet::kvstore::ComputeDepth(8), 3); + ASSERT_EQ(mxnet::kvstore::ComputeDepth(7), 3); + ASSERT_EQ(mxnet::kvstore::ComputeDepth(5), 3); + 
ASSERT_EQ(mxnet::kvstore::ComputeDepth(4), 2); + ASSERT_EQ(mxnet::kvstore::ComputeDepth(16), 4); +} + +TEST(IsValidTest, TestIsValid) { + + std::vector W = {0, 2, 2, 3, 3, 0, 0, + 2, 0, 3, 2, 0, 3, 0, + 2, 3, 0, 3, 0, 0, 2, + 3, 2, 3, 0, 0, 0, 0, + 3, 0, 0, 0, 0, 2, 2, + 0, 3, 0, 0, 2, 0, 3, + 0, 0, 2, 0, 2, 3, 0}; + + std::vector state0 = {3, 2, 1, 5, 0, 0, 4, 6}; + ASSERT_EQ(mxnet::kvstore::IsValid(W, state0, 7, 7, 3), true); + + // 3 connects to 1 first + std::vector state1 = {3, 2, 0, 4, 1, 1, 5, 6}; + ASSERT_EQ(mxnet::kvstore::IsValid(W, state1, 7, 7, 3), true); + + // 3 does not connect to 5 + std::vector state2 = {3, 2, 5, 1, 0, 4, 2, 5}; + ASSERT_EQ(mxnet::kvstore::IsValid(W, state2, 7, 7, 3), false); + + // 7 exceeds number of GPUs + std::vector state3 = {3, 7, 2, 6, 0, 1, 4, 5}; + ASSERT_EQ(mxnet::kvstore::IsValid(W, state3, 7, 7, 3), false); + + // Test -1 + std::vector state4 = {3, -1, 2, 6, 0, 1, 4, 5}; + ASSERT_EQ(mxnet::kvstore::IsValid(W, state4, 7, 7, 3), true); + + // Test -1 + std::vector state5 = {3, -1, 2, 6, 0, 1, 4, -1}; + ASSERT_EQ(mxnet::kvstore::IsValid(W, state5, 7, 8, 3), false); + + // Test 1 row + std::vector state6 = {3, -1, -1, -1, -1, -1, -1, -1}; + ASSERT_EQ(mxnet::kvstore::IsValid(W, state6, 7, 1, 3), true); + +} + +TEST(PermuteMatrixTest, TestIdentity) { + + std::vector W = {0, 2, 2, 3, 3, 1, 1, 1, + 2, 0, 3, 2, 1, 3, 1, 1, + 2, 3, 0, 3, 1, 1, 2, 1, + 3, 2, 3, 0, 1, 1, 1, 2, + 3, 1, 1, 1, 0, 2, 2, 3, + 1, 3, 1, 1, 2, 0, 3, 2, + 1, 1, 2, 1, 2, 3, 0, 3, + 1, 1, 1, 2, 3, 2, 3, 0}; + + std::vector P1 = {0, 1, 2, 3, 4, 5, 6, 7}; + std::vector A(NUM_GPUS*NUM_GPUS, 0); + PermuteMatrix( W, P1, A ); + //PrintMatrix("P1", A, NUM_GPUS, NUM_GPUS); + for (unsigned i=0; i W = {0, 1, 2, 3, 2, 4, + 1, 0, 1, 4, 2, 1, + 2, 1, 0, 3, 2, 1, + 3, 4, 3, 0, 4, 3, + 2, 2, 2, 4, 0, 2, + 4, 1, 1, 3, 2, 0}; + std::vector P(6, 0); + std::vector> cluster_pairs; + int num_partitions = 1; + std::mt19937 gen(1); + mxnet::kvstore::KernighanLin( W, P, 
num_partitions, cluster_pairs, gen ); + +} From 25cbbdc25269397b29a05c56badcdef70cd682ee Mon Sep 17 00:00:00 2001 From: Carl Yang Date: Sat, 16 Jun 2018 00:17:03 +0000 Subject: [PATCH 07/36] add more tests, 6 tests left to add --- src/kvstore/gpu_topology.h | 24 ++- tests/cpp/kvstore/gpu_topology_test.cc | 269 +++++++++++++++++++++++-- 2 files changed, 269 insertions(+), 24 deletions(-) diff --git a/src/kvstore/gpu_topology.h b/src/kvstore/gpu_topology.h index e4e91a6f7368..71bfa05d3399 100644 --- a/src/kvstore/gpu_topology.h +++ b/src/kvstore/gpu_topology.h @@ -161,10 +161,11 @@ void GetP2PWeight( std::vector& matrix, // Dense matrix-vector multiplication // Assume: matrix is square +// y = A*x (no accumulate) template -void gemv( const std::vector& A, +void gemv( const std::vector& A, const std::vector& x, - std::vector& y ) { + std::vector& y ) { int nrows = x.size(); int count = 0; for (int row=0; row& W, // more than 4 elements in it // @output: stop returns true if no partitions with >=4 elements found // returns false otherwise +// cluster_pairs stores the mapping that tells us which 2 clusters are +// the output of partitioning one large cluster template bool KernighanLin( const std::vector& W, std::vector& P, @@ -427,20 +430,23 @@ int GetRoot( const std::vector& P, int GetChild( const std::vector& P, int color, int parent ) { - int size = P.size(); - for (int i = 0; i < size; ++i) { + for (unsigned i = 0; i < P.size(); ++i) { //std::cout << "Child " << i << ": " << P[i] << std::endl; - if (P[i] == color && i != parent) + if (P[i] == color && static_cast(i) != parent) return i; } return -1; } -// Computes best 2 nodes a,b to swap given objective function: -// g = max_{a \in A, b \in B} 2*W(a,b) +// Computes highest weighted edge a-b // -// Optimization: Only need to look at upper triangular since weight matrix is -// symmetric +// Contraints: +// -vertex a must be parent +// -vertex b must be in dest_cluster +// +// @output: b is vector of candidates if a 
tie happens +// g is weight of edge +// Optimization: Only need to look at row a in matrix template void FindBestEdge( const std::vector& W, const std::vector& P, diff --git a/tests/cpp/kvstore/gpu_topology_test.cc b/tests/cpp/kvstore/gpu_topology_test.cc index 6331f0997941..c09df7646a84 100644 --- a/tests/cpp/kvstore/gpu_topology_test.cc +++ b/tests/cpp/kvstore/gpu_topology_test.cc @@ -28,8 +28,6 @@ #include #include "../src/kvstore/gpu_topology.h" -#define NUM_GPUS 8 - // Permutes matrix W using permutation vector P and stores output in matrix A // Assumption: W is square and symmetric void PermuteMatrix( const std::vector& W, @@ -57,7 +55,7 @@ void PermuteMatrix( const std::vector& W, } } -TEST(FormTopologyTest, TestFormTopology) { +TEST(GpuTopology, TestFormTopology) { std::vector state0 = {3, 2, 1, 5, 0, 0, 4, 6}; std::vector topo0; std::vector scan0; @@ -69,7 +67,7 @@ TEST(FormTopologyTest, TestFormTopology) { ASSERT_EQ(static_cast(topo0[i]), correct0[i]); ASSERT_EQ(scan0.size(), correct_scan0.size()); for (unsigned i = 0; i < correct_scan0.size(); ++i) - ASSERT_EQ(scan0[i], correct_scan0[i]); + ASSERT_EQ(static_cast(scan0[i]), correct_scan0[i]); std::vector state1 = {3, 2, 0, 4, 1, 1, 5, 6}; std::vector topo1; @@ -79,13 +77,13 @@ TEST(FormTopologyTest, TestFormTopology) { mxnet::kvstore::FormTopology(state1, topo1, scan1, 3); ASSERT_EQ(topo1.size(), correct1.size()); for (unsigned i = 0; i < correct1.size(); ++i) - ASSERT_EQ(topo1[i], correct1[i]); + ASSERT_EQ(static_cast(topo1[i]), correct1[i]); ASSERT_EQ(scan1.size(), correct_scan1.size()); for (unsigned i = 0; i < correct_scan1.size(); ++i) - ASSERT_EQ(scan1[i], correct_scan1[i]); + ASSERT_EQ(static_cast(scan1[i]), correct_scan1[i]); } -TEST(ComputeTreeWeightTest, TestComputeTreeWeight) { +TEST(GpuTopology, TestComputeTreeWeight) { std::vector W = {0, 2, 2, 3, 3, 0, 0, 2, 0, 3, 2, 0, 3, 0, @@ -102,7 +100,7 @@ TEST(ComputeTreeWeightTest, TestComputeTreeWeight) { 
ASSERT_EQ(mxnet::kvstore::ComputeTreeWeight(W, state1, 7, 3), 17); } -TEST(PostprocessTest, TestPostprocess) { +TEST(GpuTopology, TestPostprocess) { std::vector result0 = {3, 0, 0, 4, 1, 2, 5, 6}; std::vector correct0= {3, 3, 0, 4, 1, 2, 5, 6}; mxnet::kvstore::Postprocess( result0, 7, 3 ); @@ -122,7 +120,7 @@ TEST(PostprocessTest, TestPostprocess) { ASSERT_EQ(result2[i], correct2[i]); } -TEST(ComputeDepthTest, TestDepth) { +TEST(GpuTopology, TestDepth) { ASSERT_EQ(mxnet::kvstore::ComputeDepth(8), 3); ASSERT_EQ(mxnet::kvstore::ComputeDepth(7), 3); ASSERT_EQ(mxnet::kvstore::ComputeDepth(5), 3); @@ -130,7 +128,7 @@ TEST(ComputeDepthTest, TestDepth) { ASSERT_EQ(mxnet::kvstore::ComputeDepth(16), 4); } -TEST(IsValidTest, TestIsValid) { +TEST(GpuTopology, TestIsValid) { std::vector W = {0, 2, 2, 3, 3, 0, 0, 2, 0, 3, 2, 0, 3, 0, @@ -166,10 +164,195 @@ TEST(IsValidTest, TestIsValid) { // Test 1 row std::vector state6 = {3, -1, -1, -1, -1, -1, -1, -1}; ASSERT_EQ(mxnet::kvstore::IsValid(W, state6, 7, 1, 3), true); +} +// gemvTest +TEST(GpuTopology, TestGemv) { + std::vector A = {0, 2, 2, 3, 3, 1, 1, 1, // 13 + 2, 0, 3, 2, 1, 3, 1, 1, // 13 + 2, 3, 0, 3, 1, 1, 2, 1, // 13 + 3, 2, 3, 0, 1, 1, 1, 2, // 13 + 3, 1, 1, 1, 0, 2, 2, 3, // 13 + 1, 3, 1, 1, 2, 0, 3, 2, // 13 + 1, 1, 2, 1, 2, 3, 0, 3, // 13 + 1, 1, 1, 2, 3, 2, 3, 0}; // 13 + std::vector x(8, 1); + std::vector y(8, 0); + std::iota(y.begin(), y.end(), 0); + std::vector correct_y(8, 13); + mxnet::kvstore::gemv( A, x, y ); + + ASSERT_EQ(y.size(), correct_y.size()); + for (unsigned i = 0; i < y.size(); ++i ) + ASSERT_EQ( y[i], correct_y[i] ); } -TEST(PermuteMatrixTest, TestIdentity) { +// ewisemultTest +TEST(GpuTopology, TestEwisemult) { + std::vector x(8, 1); + std::vector y(8, 0); + std::iota(y.begin(), y.end(), 0); + int alpha = 5; + std::vector correct_y = {0, 5, 10, 15, 20, 25, 30, 35}; + mxnet::kvstore::ewisemult( x, alpha, y ); + + ASSERT_EQ(y.size(), correct_y.size()); + for (unsigned i = 0; i < y.size(); ++i ) + 
ASSERT_EQ( y[i], correct_y[i] ); +} + +// ewiseaddTest +TEST(GpuTopology, TestEwiseadd) { + std::vector x(8, 1); + std::vector y(8, 0); + std::iota(y.begin(), y.end(), 0); + int alpha = 5; + std::vector correct_y(8,0); + std::iota(correct_y.begin(), correct_y.end(), 5); + mxnet::kvstore::ewiseadd( x, alpha, y ); + + ASSERT_EQ(y.size(), correct_y.size()); + for (unsigned i = 0; i < y.size(); ++i ) + ASSERT_EQ( y[i], correct_y[i] ); +} + +// FindBestMoveTest +TEST(GpuTopology, TestFindBestMove) { + std::vector W = {0, 2, 2, 3, 3, 1, 1, 1, + 2, 0, 3, 2, 1, 3, 1, 1, + 2, 3, 0, 3, 1, 1, 2, 1, + 3, 2, 3, 0, 1, 1, 1, 2, + 3, 1, 1, 1, 0, 2, 2, 3, + 1, 3, 1, 1, 2, 0, 3, 2, + 1, 1, 2, 1, 2, 3, 0, 3, + 1, 1, 1, 2, 3, 2, 3, 0}; + std::vector P(8, 0); + std::iota(P.begin(), P.end(), 1); + std::unordered_set used; + + std::vector D1 = {20,0, 0, 0, 0, 0, 0,20}; + int a1, b1, g1; + int correct_a1 = 0; + int correct_b1 = 7; + int correct_g1 = 38; + mxnet::kvstore::FindBestMove( W, P, D1, used, a1, b1, g1 ); + ASSERT_EQ(a1, correct_a1); + ASSERT_EQ(b1, correct_b1); + ASSERT_EQ(g1, correct_g1); + + // -1, -1, 0 indicates no best edge found + std::vector D2 = {0, 0, 0, 0, 0, 0, 0, 0}; + int a2, b2, g2; + int correct_a2 = -1; + int correct_b2 = -1; + int correct_g2 = 0; + mxnet::kvstore::FindBestMove( W, P, D2, used, a2, b2, g2 ); + ASSERT_EQ(a2, correct_a2); + ASSERT_EQ(b2, correct_b2); + ASSERT_EQ(g2, correct_g2); +} + +// GetRootTest +TEST(GpuTopology, TestGetRoot) { + std::vector P = {0, 0, 1, 1, 2, 2, 3, 3}; + + // Test when roots are non-empty, and matches color + std::unordered_set roots1 = {0, 2, 4, 6}; + std::vector color1 = {0, 1, 2, 3}; + for (unsigned i = 0; i < color1.size(); ++i) { + int root1 = mxnet::kvstore::GetRoot(P, color1[i], roots1); + int correct_root1 = 2*i; + ASSERT_EQ(root1, correct_root1); + } + + // Test when roots is empty + std::unordered_set roots2; + int color2 = 0; + int correct_root2 = -1; + int root2 = mxnet::kvstore::GetRoot(P, color2, roots2); + 
ASSERT_EQ(root2, correct_root2); + + // Test when roots is non-empty, but no root matches color + std::unordered_set roots3 = {0}; + int color3 = 1; + int correct_root3 = -1; + int root3 = mxnet::kvstore::GetRoot(P, color3, roots3); + ASSERT_EQ(root3, correct_root3); +} + +// GetChildTest +TEST(GpuTopology, TestGetChild) { + std::vector P = {0, 0, 1, 2, 2, 2, 3, 3}; + + // Test when color is not found + int color1 = 4; + int parent1= 4; + int correct_child1 = -1; + int child1 = mxnet::kvstore::GetChild(P, color1, parent1); + ASSERT_EQ(child1, correct_child1); + + // Test when color is found, but is equal to parent + int color2 = 1; + int parent2= 2; + int correct_child2 = -1; + int child2 = mxnet::kvstore::GetChild(P, color2, parent2); + ASSERT_EQ(child2, correct_child2); + + // Test when color is found and not equal to parent + int color3 = 3; + int parent3= 6; + int correct_child3 = 7; + int child3 = mxnet::kvstore::GetChild(P, color3, parent3); + ASSERT_EQ(child3, correct_child3); +} + +// FindBestEdgeTest +TEST(GpuTopology, TestFindBestEdge) { + std::vector W = {0, 2, 2, 3, 3, 1, 1, 1, + 2, 0, 3, 2, 1, 3, 1, 1, + 2, 3, 0, 3, 1, 1, 2, 1, + 3, 2, 3, 0, 1, 1, 1, 2, + 3, 1, 1, 1, 0, 2, 2, 3, + 1, 3, 1, 1, 2, 0, 3, 2, + 1, 1, 2, 1, 2, 3, 0, 3, + 1, 1, 1, 2, 3, 2, 3, 0}; + std::vector P(8, 0); + std::unordered_set used; + + int parent1 = 3; + int dest1 = 0; + std::vector b1; + int g1; + std::vector correct_b1 = {0, 2}; + int correct_g1 = 3; + mxnet::kvstore::FindBestEdge( W, P, parent1, dest1, b1, g1 ); + ASSERT_EQ(b1.size(), correct_b1.size()); + for (unsigned i = 0; i < b1.size(); ++i) + ASSERT_EQ(b1[i], correct_b1[i]); + ASSERT_EQ(g1, correct_g1); + + // {-1}, 0 indicates no best edge found + int parent2 = 4; + int dest2 = 1; + std::vector b2; + int g2; + std::vector correct_b2 = {-1}; + int correct_g2 = 0; + mxnet::kvstore::FindBestEdge( W, P, parent2, dest2, b2, g2 ); + ASSERT_EQ(b2.size(), correct_b2.size()); + for (unsigned i = 0; i < b2.size(); ++i) + 
ASSERT_EQ(b2[i], correct_b2[i]); + ASSERT_EQ(g2, correct_g2); +} + +// GenerateBinaryTreeTest +// Backtrack +// UpdateWeight +// BacktrackingGenerateBinaryTree +// PartitionGraphFromRoot +// PartitionGraph + +TEST(GpuTopology, TestPermuteMatrix) { std::vector W = {0, 2, 2, 3, 3, 1, 1, 1, 2, 0, 3, 2, 1, 3, 1, 1, @@ -181,14 +364,13 @@ TEST(PermuteMatrixTest, TestIdentity) { 1, 1, 1, 2, 3, 2, 3, 0}; std::vector P1 = {0, 1, 2, 3, 4, 5, 6, 7}; - std::vector A(NUM_GPUS*NUM_GPUS, 0); + std::vector A(8*8, 0); PermuteMatrix( W, P1, A ); - //PrintMatrix("P1", A, NUM_GPUS, NUM_GPUS); for (unsigned i=0; i W = {0, 1, 2, 3, 2, 4, 1, 0, 1, 4, 2, 1, 2, 1, 0, 3, 2, 1, @@ -199,6 +381,63 @@ TEST(KernighanLinTest, Test1) { std::vector> cluster_pairs; int num_partitions = 1; std::mt19937 gen(1); - mxnet::kvstore::KernighanLin( W, P, num_partitions, cluster_pairs, gen ); + bool stop = mxnet::kvstore::KernighanLin( W, P, num_partitions, cluster_pairs, gen ); + + std::vector> correct_pairs; + correct_pairs.push_back(std::make_pair(0,1)); + std::vector correct_P = {0, 1, 0, 1, 1, 0}; + ASSERT_EQ(stop, false); + ASSERT_EQ(num_partitions, 2); + ASSERT_EQ(cluster_pairs.size(), correct_pairs.size()); + for (unsigned i = 0; i < cluster_pairs.size(); ++i) { + ASSERT_EQ(cluster_pairs[i].first, correct_pairs[i].first); + ASSERT_EQ(cluster_pairs[i].second, correct_pairs[i].second); + } + ASSERT_EQ(P.size(), correct_P.size()); + unsigned error = 0; + for (unsigned i = 0; i < P.size(); ++i) { + if (P[i] != correct_P[i]) + error++; + } + EXPECT_TRUE (error == 0 || error == P.size()) + << "Where real value: " << error + << " not equal neither: " << 0 + << " nor: " << P.size() << "."; +} +TEST(GpuTopology, TestKernighanLin2) { + std::vector W = {0, 1, 0, 0, 1, 1, 0, 0, + 1, 0, 0, 0, 1, 1, 0, 0, + 0, 0, 0, 1, 0, 1, 1, 1, + 0, 0, 1, 0, 0, 0, 1, 1, + 1, 1, 0, 0, 0, 1, 0, 0, + 1, 1, 1, 0, 1, 0, 0, 0, + 0, 0, 1, 1, 0, 0, 0, 1, + 0, 0, 1, 1, 0, 0, 1, 0}; + std::vector P(8, 0); + std::vector> cluster_pairs; + 
int num_partitions = 1; + std::mt19937 gen(1); + bool stop = mxnet::kvstore::KernighanLin( W, P, num_partitions, cluster_pairs, gen ); + + std::vector> correct_pairs; + correct_pairs.push_back(std::make_pair(0,1)); + std::vector correct_P = {0, 0, 1, 1, 0, 0, 1, 1}; + ASSERT_EQ(stop, false); + ASSERT_EQ(num_partitions, 2); + ASSERT_EQ(cluster_pairs.size(), correct_pairs.size()); + for (unsigned i = 0; i < cluster_pairs.size(); ++i) { + ASSERT_EQ(cluster_pairs[i].first, correct_pairs[i].first); + ASSERT_EQ(cluster_pairs[i].second, correct_pairs[i].second); + } + ASSERT_EQ(P.size(), correct_P.size()); + unsigned error = 0; + for (unsigned i = 0; i < P.size(); ++i) { + if (P[i] != correct_P[i]) + error++; + } + EXPECT_TRUE (error == 0 || error == P.size()) + << "Where real value: " << error + << " not equal neither: " << 0 + << " nor: " << P.size() << "."; } From 310ee4d168eb28033a558b578b76fd7fe203df8b Mon Sep 17 00:00:00 2001 From: Carl Yang Date: Sat, 16 Jun 2018 00:31:58 +0000 Subject: [PATCH 08/36] get rid of some dead code --- src/kvstore/gpu_topology.h | 40 +++++--------------------- tests/cpp/kvstore/gpu_topology_test.cc | 5 ++++ 2 files changed, 12 insertions(+), 33 deletions(-) diff --git a/src/kvstore/gpu_topology.h b/src/kvstore/gpu_topology.h index 71bfa05d3399..072a2bad45de 100644 --- a/src/kvstore/gpu_topology.h +++ b/src/kvstore/gpu_topology.h @@ -473,6 +473,13 @@ void FindBestEdge( const std::vector& W, } // Given a vector of color pairs, appends to binary tree matrix topo +// @input: cluster_pairs gives pairing between clusters, an edge is found +// between each pairing +// roots gives source vertex +// gen gives random number generation to break ties +// @output: cluster_pairs +// topo_row says where new edges are appended to +// scan_row says where we should start looking for topo_row template int GenerateBinaryTree( std::vector& W, const std::vector& P, @@ -521,35 +528,6 @@ int GenerateBinaryTree( std::vector& W, // If no candidates if 
(candidates[0]!=-1) { - /*if (candidates[0] == -1) { - std::cout << "Appending candidates\n"; - candidates.clear(); - for (unsigned col = 0; col < P.size(); ++col) { - if (W[parent*P.size()+col] > 0) - for ( - candidates.push_back(col); - reset = 2; - } - }*/ - // Look for candidate that has not been used at this level or previous - // levels - /*for (unsigned i = 0; i < candidates.size(); ++i) { - bool exit = true; - int last = scan_row.size()-1; - for (auto it = new_topo.begin(); it != new_topo.end(); ++it) { - std::cout << "Testing " << candidates[i] << " " << it->second << std::endl; - if (candidates[i] == it->second) { - std::cout << candidates[i] << " has been encountered before\n"; - exit = false; - break; - } - } - if (exit) { - child = candidates[i]; - std::cout << "GPU " << child << " not found before!\n"; - break; - } - }*/ std::shuffle(candidates.begin(), candidates.end(), gen); child = candidates[0]; } @@ -560,10 +538,6 @@ int GenerateBinaryTree( std::vector& W, //child = parent; return 1; - /*else { - child = parent; - std::cout << "Best link (case 4): " << parent << " -> " << child << ": " << std::endl; - }*/ } else { //std::cout << "Best link (case 3): " << parent << " -> " << child << ": " << weight << std::endl; new_roots.insert(parent); diff --git a/tests/cpp/kvstore/gpu_topology_test.cc b/tests/cpp/kvstore/gpu_topology_test.cc index c09df7646a84..6e29ba78e8e2 100644 --- a/tests/cpp/kvstore/gpu_topology_test.cc +++ b/tests/cpp/kvstore/gpu_topology_test.cc @@ -346,6 +346,11 @@ TEST(GpuTopology, TestFindBestEdge) { } // GenerateBinaryTreeTest +TEST(GpuTopology, TestGenerateBinaryTree) { + + mxnet::kvstore::GenerateBinaryTree(); +} + // Backtrack // UpdateWeight // BacktrackingGenerateBinaryTree From 9cce8eac2cc1a3b5f969ae9db7376ab3a14626ea Mon Sep 17 00:00:00 2001 From: Carl Yang Date: Mon, 18 Jun 2018 02:16:50 +0000 Subject: [PATCH 09/36] Add comments --- src/kvstore/gpu_topology.h | 103 ++++++++++++++++++++++++++++++------- 1 file changed, 83 
insertions(+), 20 deletions(-) diff --git a/src/kvstore/gpu_topology.h b/src/kvstore/gpu_topology.h index 072a2bad45de..92254cc4ce08 100644 --- a/src/kvstore/gpu_topology.h +++ b/src/kvstore/gpu_topology.h @@ -574,6 +574,8 @@ int GenerateBinaryTree( std::vector& W, return reset; } +// @input: n is the number of nodes in a balanced binary tree +// @output: returns how many levels of binary tree there are int ComputeDepth( int n ) { for (int depth = 0; depth < MAX_DEPTH; ++depth) { int num = 2 << depth; @@ -583,6 +585,11 @@ int ComputeDepth( int n ) { return 0; } +// Checks whether a given state forms a spanning tree that satisfies: +// -balanced +// -binary +// -each edge in tree corresponds to link in network topology +// -each edge in tree does not form self-loop template bool IsValid( const std::vector& W, const std::vector& state, @@ -590,6 +597,9 @@ bool IsValid( const std::vector& W, int row, int depth ) { + // At each level of tree, check whether edge: + // -corresponds to link in network topology + // -corresponds to self-loop for (int i = 0; i < depth; ++i) { int stride = 1 << i; for (int j = 0; j+stride < row; j += 2*stride) { @@ -603,6 +613,8 @@ bool IsValid( const std::vector& W, } } + // If we encounter GPU for first time, increment found_vec. + // Otherwise, do nothing std::unordered_set found; std::vector found_vec(num_elements,0); for (auto val : state) { @@ -618,14 +630,24 @@ bool IsValid( const std::vector& W, return false; } } + + // modifier is maximum number of repeats a single GPU can take + // e.g. 
5 GPUs in 3-level binary tree => one GPU can repeat 3x + // GPU0 GPU0 GPU0 GPU0 GPU1 GPU2 GPU3 GPU4 int modifier = (1 << depth) - num_elements; int num_found= found.size(); + // So we know we have an invalid state if we find: + // -only 4 unique GPUs + // -9 unique GPUs if (row < num_elements) { if (num_found > row || num_found < row - modifier) { //std::cout << "Not valid: " << found.size() << " rows found but expected between " << row << " and " << row - modifier << std::endl; return false; } + + // If we are at last recursive level, we can apply a more stringent check: + // -if some GPU is not found, then we are in invalid state } else if (row == static_cast(state.size())) for (int i = 0; i < num_elements; ++i) if (found_vec[i] == 0) @@ -634,6 +656,24 @@ bool IsValid( const std::vector& W, return true; } +// This function takes a spanning tree encoded as state (result), which may have// repeated GPUs representing NO-SENDs and converts it into a unique format. +// This has the effect of recognizing redundant sends, grouping them together, +// so that the Reduce call knows not to perform a CopyFromTo. 
+// +// Initial result: [3 0 0 4 1 2 5 6] +// Final result: [3 3 0 4 1 2 5 6] +// +// Initial: +// 3 +// 3 1 +// 3 0 1 5 +// 3 0 0 4 1 2 5 6 // GPU3 will make redundant send to GPU0 +// +// Final: +// 3 +// 3 1 +// 3 0 1 5 +// 3 3 0 4 1 2 5 6 // GPU3 knows not to make redundant send to itself void Postprocess( std::vector& result, int num_elements, int depth) { std::vector histogram(num_elements, 0); @@ -642,29 +682,29 @@ void Postprocess( std::vector& result, int num_elements, int depth) { histogram[val]++; } - for (int i = 0; i == 0; ++i) { - int stride = 1 << i; - for (int j = result.size()-1; j-stride >= 0; j -= 2*stride) { - //std::cout << "Comparing " << j << " and " << j-stride << std::endl; - int from = result[j]; - int dest = result[j-stride]; - if (histogram[from] > 1 && from != dest) { - //PrintVector("Old histogram", histogram); - //std::cout << "Swapping from " << from << " to " << dest << " on indices " << j << " and " << j-stride << std::endl; - result[j] = dest; - histogram[from]--; - //PrintVector("New histogram", histogram); - //PrintVector("New result", result); + int stride = 1; + for (int j = result.size()-1; j-stride >= 0; j -= 2*stride) { + //std::cout << "Comparing " << j << " and " << j-stride << std::endl; + int from = result[j]; + int dest = result[j-stride]; + if (histogram[from] > 1 && from != dest) { + //PrintVector("Old histogram", histogram); + //std::cout << "Swapping from " << from << " to " << dest << " on indices " << j << " and " << j-stride << std::endl; + result[j] = dest; + histogram[from]--; + //PrintVector("New histogram", histogram); + //PrintVector("New result", result); } } - } } +// Given a spanning tree encoded as a state (result) and weight of each edge +// in the link topology graph, compute its weight. 
template T ComputeTreeWeight( const std::vector& W, - const std::vector& result, - int num_elements, - int depth) { + const std::vector& result, + int num_elements, + int depth) { T weight = 0.f; std::unordered_set links_used; @@ -701,6 +741,20 @@ T ComputeTreeWeight( const std::vector& W, return weight; } +// Given a spanning tree encoded as result, which was convenient for performing +// backtracking, convert it topology_ and scan_ in the classic "binary tree +// stored in an array" format. For binary trees scan_ is redundant, but this +// additional data structure leaves future generalization to k-radix trees. +// +// Initial result: [3 3 0 4 1 2 5 6] +// topology_: [3 3 1 3 0 1 5 3 3 0 4 1 2 5 6] +// scan_: [0 1 3 7 15] +// +// topology_ is stored in the classic "binary tree stored in an array" format +// e.g. 3 +// 3 1 +// 3 0 1 5 +// 3 3 0 4 1 2 5 6 void FormTopology( const std::vector& result, std::vector& topo_row, std::vector& scan_row, @@ -720,6 +774,11 @@ void FormTopology( const std::vector& result, scan_row.push_back(topo_row.size()); } +// Recursive function that finds a spanning tree, which fulfills the following +// conditions: +// -balanced +// -binary +// -maximum weight template void Backtrack( const std::vector& W, std::vector& state, @@ -756,6 +815,8 @@ void Backtrack( const std::vector& W, } } +// Apply penalty factor alpha to each link in link topology graph that is used +// by the spanning tree template void UpdateWeight( std::vector& W, const std::vector& topo_row, @@ -775,10 +836,12 @@ void UpdateWeight( std::vector& W, } // Do brute-force backtracking approach if Kernighan-Lin fails to find a binary -// tree of height Log P -// Metrics: -// 1) minimize depth +// tree of height Log P. 
+// +// Constraints: +// 1) minimize depth (balance) // 2) maximize edge weight +// 3) tree is binary template void BacktrackingGenerateBinaryTree( std::vector& W, int num_elements, From 4d2790df8186a2edd68a8d0eb226d93a973ed0b6 Mon Sep 17 00:00:00 2001 From: Carl Yang Date: Mon, 18 Jun 2018 08:35:33 +0000 Subject: [PATCH 10/36] Add randomized tests for backtrack and kernighan-lin --- src/kvstore/comm_tree.h | 14 +- src/kvstore/gpu_topology.h | 191 +++++++++++------------ tests/cpp/kvstore/gpu_topology_test.cc | 207 ++++++++++++++++++++++++- 3 files changed, 304 insertions(+), 108 deletions(-) diff --git a/src/kvstore/comm_tree.h b/src/kvstore/comm_tree.h index f215139842cd..1e420b8be363 100644 --- a/src/kvstore/comm_tree.h +++ b/src/kvstore/comm_tree.h @@ -51,6 +51,7 @@ class CommDeviceTree : public Comm { CommDeviceTree() { inited_ = false; bigarray_bound_ = dmlc::GetEnv("MXNET_KVSTORE_BIGARRAY_BOUND", 10000000); + backtrack_ = dmlc::GetEnv("MXNET_KVSTORE_BACKTRACK", 1); link_usage_penalty_ = dmlc::GetEnv("MXNET_KVSTORE_LINK_USAGE_PENALTY", 0.7); stream_ = dmlc::GetEnv("MXNET_KVSTORE_STREAM", 1); } @@ -486,12 +487,14 @@ class CommDeviceTree : public Comm { void QueryTopology() { #if MXNET_USE_CUDA - std::vector link_matrix(devs_.size()*devs_.size()); - std::vector zero_dev_id(devs_.size(), -1); - GetP2PWeight( link_matrix, devs_, zero_dev_id ); - PartitionGraph( link_matrix, devs_.size(), zero_dev_id, topology_, scan_, - link_usage_penalty_ ); + GetP2PWeight( devs_, link_matrix ); + if (backtrack_) + LOG(WARNING) << "Using Backtracking to generate trees"; + else + LOG(WARNING) << "Using Kernighan-Lin to generate trees"; + ComputeTrees( link_matrix, devs_.size(), link_usage_penalty_, backtrack_, + topology_, scan_ ); depth_ = ComputeDepth(devs_.size()); #endif @@ -644,6 +647,7 @@ class CommDeviceTree : public Comm { int bigarray_bound_; bool inited_; bool stream_; + bool backtrack_; float link_usage_penalty_; /// \brief constant for maximum size of recv buffer per 
GPU diff --git a/src/kvstore/gpu_topology.h b/src/kvstore/gpu_topology.h index 92254cc4ce08..4632a2b3a11e 100644 --- a/src/kvstore/gpu_topology.h +++ b/src/kvstore/gpu_topology.h @@ -58,28 +58,18 @@ void PrettyPrintTopology(const std::vector> topo) { } } -void PrintTopo( const std::string& str, const std::vector& topo_row, - std::vector scan_row ) { +template +void PrintVector( const std::string& str, const std::vector& vec ) { std::cout << str << ":\n"; - int depth = scan_row.size()-1; - for (int row = 0; row < depth; ++row) { - int start = scan_row[row]; - int end = scan_row[row+1]; - for (; start void PrintMatrix( const std::string& str, const std::vector& matrix, int num_rows, int num_cols ) { - + PrintVector("Matrix vector", matrix); std::cout << str << ":\n"; int count = 0; for (int row = 0; row < num_rows; ++row) { @@ -90,33 +80,40 @@ void PrintMatrix( const std::string& str, const std::vector& matrix, } } -template -void PrintVector( const std::string& str, const std::vector& vec ) { +void PrintTopo( const std::string& str, const std::vector& topo_row, + std::vector scan_row ) { + PrintVector("Topo vector", topo_row); + PrintVector("Scan vector", scan_row); std::cout << str << ":\n"; - for (unsigned i = 0; i < vec.size(); ++i) - std::cout << vec[i] << " "; - std::cout << std::endl; + int depth = scan_row.size()-1; + for (int row = 0; row < depth; ++row) { + int start = scan_row[row]; + int end = scan_row[row+1]; + for (; start -void GetP2PWeight( std::vector& matrix, - const std::vector& devs, - std::vector& zero_dev_id, - bool print=false ) { +void GetP2PWeight( const std::vector& devs, + std::vector& matrix ) { int num_gpus = devs.size(); int count = 0; + std::vector zero_dev_id(num_gpus, -1); for (auto d : devs) { zero_dev_id[count] = d.dev_id; count++; @@ -475,19 +472,19 @@ void FindBestEdge( const std::vector& W, // Given a vector of color pairs, appends to binary tree matrix topo // @input: cluster_pairs gives pairing between clusters, an edge is 
found // between each pairing -// roots gives source vertex -// gen gives random number generation to break ties +// roots gives source vertex +// gen gives random number generation to break ties // @output: cluster_pairs // topo_row says where new edges are appended to -// scan_row says where we should start looking for topo_row +// scan_row says where we should start looking for topo_row template -int GenerateBinaryTree( std::vector& W, - const std::vector& P, - std::vector>& cluster_pairs, - std::unordered_set& roots, - std::vector& topo_row, - std::vector& scan_row, - std::mt19937& gen ) { +int KLGenerateBinaryTree( std::vector& W, + const std::vector& P, + std::vector>& cluster_pairs, + std::unordered_set& roots, + std::vector& topo_row, + std::vector& scan_row, + std::mt19937& gen ) { std::unordered_set new_roots; std::unordered_map new_topo; int reset = 0; @@ -499,7 +496,7 @@ int GenerateBinaryTree( std::vector& W, //std::cout << "Pair " << i << ": " << cluster_pairs[i].first << " " << cluster_pairs[i].second << std::endl; int parent, child = -1; if (cluster_pairs[i].second==-2) { - // Root must exist in first element of pair + // Root must be color of pair.first int color = cluster_pairs[i].first; parent = GetRoot( P, color, roots ); if (parent == -1) return 1; @@ -606,8 +603,8 @@ bool IsValid( const std::vector& W, int from = state[j]; int dest = state[j+stride]; //std::cout << "Comparing " << j << " and " << j+stride << " in row " << row << std::endl; - if (W[from*num_elements + dest] <= static_cast(0) && from != dest) { - //std::cout << "Not valid: no edge from " << from << " to " << dest << std::endl; + if (W[from*num_elements + dest] == static_cast(0) && from != dest) { + //std::cout << "Not valid: no edge from " << from << " to " << dest << " at index " << from*num_elements+dest << std::endl; return false; } } @@ -650,8 +647,10 @@ bool IsValid( const std::vector& W, // -if some GPU is not found, then we are in invalid state } else if (row == 
static_cast(state.size())) for (int i = 0; i < num_elements; ++i) - if (found_vec[i] == 0) + if (found_vec[i] == 0) { + //std::cout << "Not valid: " << i << " not found" << std::endl; return false; + } return true; } @@ -791,6 +790,8 @@ void Backtrack( const std::vector& W, std::vector result = state; Postprocess(result, num_elements, depth); T weight = ComputeTreeWeight(W, result, num_elements, depth); + + // Save this spanning tree if it is highest weight tree found sofar if (weight > best_result_weight) { std::swap(best_result_weight, weight); best_result = result; @@ -803,6 +804,7 @@ void Backtrack( const std::vector& W, return; } + // If not last recursive level, try to find valid tree for next level for (int j = 0; j < num_elements; ++j) { state[row] = j; //PrintVector("Trying state", state); @@ -827,8 +829,9 @@ void UpdateWeight( std::vector& W, unsigned child = topo_row[i+1]; if (parent >= num_elements*num_elements || child >= num_elements*num_elements) - std::cout << "W array out of bounds\n"; + LOG(WARNING) << "W array out of bounds\n"; else if (parent != child) { + //std::cout << W[child*num_elements+parent] << " " << alpha << std::endl; W[parent*num_elements+child] *= alpha; W[child*num_elements+parent] *= alpha; } @@ -843,11 +846,11 @@ void UpdateWeight( std::vector& W, // 2) maximize edge weight // 3) tree is binary template -void BacktrackingGenerateBinaryTree( std::vector& W, - int num_elements, - int root, - std::vector& topo_row, - std::vector& scan_row ) { +void BacktrackGenerateBinaryTree( std::vector& W, + int num_elements, + int root, + std::vector& topo_row, + std::vector& scan_row ) { // Clear before starting topo_row.clear(); @@ -870,19 +873,22 @@ void BacktrackingGenerateBinaryTree( std::vector& W, // Place root and try all combinations state[0] = root; - //PrintVector("state", state); Backtrack( W, state, result, result_weight, 1, num_elements, depth ); + //PrintVector("result", result); FormTopology( result, topo_row, scan_row, depth ); 
} +// ComputeTreesFromRoot does the same thing as ComputeTrees, with the only +// exception being it will do it from a fixed GPU as root template -void PartitionGraphFromRoot( std::vector& W, - int num_elements, - int root, - std::vector>& topo, - std::vector>& scan, - float alpha ) { +void ComputeTreesFromRoot( std::vector& W, + int num_elements, + int root, + float alpha, + bool backtrack, + std::vector& topo, + std::vector& scan ) { int num_partitions = 1; @@ -918,15 +924,14 @@ void PartitionGraphFromRoot( std::vector& W, int reset = 1; int level = 0; - bool backtrack = dmlc::GetEnv("MXNET_KVSTORE_BACKTRACK", 1); while (!backtrack && (!stop || reset)) { if (reset == 1) { cluster_pairs.clear(); P_temp = P; num_partitions_temp = num_partitions; roots_temp = roots; - topo_temp = topo[root]; - scan_temp = scan[root]; + topo_temp = topo; + scan_temp = scan; } // Run Kernighan-Lin to generate partition @@ -935,7 +940,7 @@ void PartitionGraphFromRoot( std::vector& W, // Use partitions found and a given root to find best inter-cluster edge for // each pair of clusters, and returns them as roots of next cluster // If reset is true, then rewind back to previous clustering - reset = GenerateBinaryTree(W, P_temp, cluster_pairs, roots_temp, + reset = KLGenerateBinaryTree(W, P_temp, cluster_pairs, roots_temp, topo_temp, scan_temp, gen); if (reset) @@ -947,24 +952,30 @@ void PartitionGraphFromRoot( std::vector& W, if (!backtrack) std::cout << "No valid binary tree found from root " << root << ", try backtracking\n"; //std::cout << "Trying backtracking\n"; - BacktrackingGenerateBinaryTree(W, num_elements, root, topo[root], - scan[root]); + BacktrackGenerateBinaryTree(W, num_elements, root, topo, scan); } else { - topo[root] = topo_temp; - scan[root] = scan_temp; + topo = topo_temp; + scan = scan_temp; + scan.push_back(topo.size()); } - UpdateWeight( W, topo[root], num_elements, alpha ); + UpdateWeight( W, topo, num_elements, alpha ); } -// Generalization from num_elements to 
list of devices done using zero_dev_id -// mapping, which gets us from 0, 1, ..., n_gpus to dev_id +// ComputeTrees computes balanced binary spanning trees of maximum edge weight +// given a link topology graph stored in adjacency matrix format +// @input: W is the link topology matrix +// num_elements is the number of GPUs +// alpha is the link usage penalty +// backtrack is whether or not we use backtracking to generate trees +// @output: topo stores the trees generated +// scan stores the start of each level of each tree template -void PartitionGraph( const std::vector& W, - int num_elements, - const std::vector& zero_dev_id, - std::vector>& topo, - std::vector>& scan, - float alpha=0.7 ) { +void ComputeTrees( const std::vector& W, + int num_elements, + float alpha, + bool backtrack, + std::vector>& topo, + std::vector>& scan ) { std::vector W_copy = W; topo.clear(); @@ -974,8 +985,8 @@ void PartitionGraph( const std::vector& W, scan.push_back(std::vector()); topo[i].push_back(i); scan[i].push_back(0); - PartitionGraphFromRoot(W_copy, num_elements, i, topo, scan, alpha); - scan[i].push_back(topo[i].size()); + ComputeTreesFromRoot(W_copy, num_elements, i, alpha, backtrack, topo[i], + scan[i]); } // Note: must sum up adj matrix to show link usage before we readjust topo @@ -994,20 +1005,10 @@ void PartitionGraph( const std::vector& W, std::vector> topo_temp(num_elements, std::vector()); - for (int i = 0; i < num_elements; ++i) { - for (unsigned j = 0; j < topo[i].size(); ++j) { - int val = topo[i][j]; - topo_temp[i].push_back( zero_dev_id[val] ); - } - PrintTopo("Topo_temp", topo_temp[i], scan[i]); - } + for (int i = 0; i < num_elements; ++i) + PrintTopo("Topo", topo[i], scan[i]); PrintMatrix("Links", adj, num_elements, num_elements); - bool backtrack = dmlc::GetEnv("MXNET_KVSTORE_BACKTRACK", 1); - if (backtrack) - LOG(WARNING) << "Using Backtracking to generate trees"; - else - LOG(WARNING) << "Using Kernighan-Lin to generate trees"; } } // namespace kvstore diff 
--git a/tests/cpp/kvstore/gpu_topology_test.cc b/tests/cpp/kvstore/gpu_topology_test.cc index 6e29ba78e8e2..c86d7e5b4d14 100644 --- a/tests/cpp/kvstore/gpu_topology_test.cc +++ b/tests/cpp/kvstore/gpu_topology_test.cc @@ -28,6 +28,71 @@ #include #include "../src/kvstore/gpu_topology.h" +void GenerateMatrix( std::vector& W, int num_gpus, float k, + std::mt19937& gen) { + std::uniform_real_distribution<> dis(0., 1.); + for (int row = 0; row < num_gpus; ++row) { + for (int col = row+1; col < num_gpus; ++col) { + float sample = dis(gen); + if (sample < k) + continue; + sample = dis(gen); + if (sample < 0.33f) { + W[row*num_gpus+col] = 1.f; + W[col*num_gpus+row] = 1.f; + } else if (sample < 0.66f) { + W[row*num_gpus+col] = 2.f; + W[col*num_gpus+row] = 2.f; + } else { + W[row*num_gpus+col] = 3.f; + W[col*num_gpus+row] = 3.f; + } + } + } +} + +bool IsSatisfactory( const std::vector& W, int num_gpus, int depth ) { + for (int row = 0; row < num_gpus; ++row) { + int out_edges = 0; + for (int col = 0; col < num_gpus; ++col) { + if (W[row*num_gpus+col] > 0.f) + out_edges++; + } + if (out_edges < depth) + return false; + } + return true; +} + +// Generates random link topology matrix using random number generator +void TestComputeTreesRandomized( int num_gpus, float alpha, int backtrack, + std::mt19937& gen ) { + std::uniform_real_distribution<> dis(0.f, 1.f); + bool satisfied = false; + std::vector W(num_gpus*num_gpus, 0.f); + int depth = mxnet::kvstore::ComputeDepth(num_gpus); + while (!satisfied) { + float k = dis(gen); + std::fill(W.begin(), W.end(), 0.f); + GenerateMatrix(W, num_gpus, k, gen); + satisfied = IsSatisfactory(W, num_gpus, depth); + if (!satisfied) + LOG(WARNING) << k << " is not satisfactory"; + } + + std::vector> topo; + std::vector> scan; + //mxnet::kvstore::PrintMatrix("W", W, num_gpus, num_gpus); + mxnet::kvstore::ComputeTrees( W, num_gpus, alpha, backtrack, topo, scan ); + + unsigned correct_topo_size = (1 << (depth + 1)) - 1; + unsigned correct_scan_size 
= depth+2; + for (int i = 0; i < num_gpus; ++i) { + ASSERT_EQ(correct_topo_size, topo[i].size()); + ASSERT_EQ(correct_scan_size, scan[i].size()); + } +} + // Permutes matrix W using permutation vector P and stores output in matrix A // Assumption: W is square and symmetric void PermuteMatrix( const std::vector& W, @@ -278,6 +343,13 @@ TEST(GpuTopology, TestGetRoot) { int correct_root3 = -1; int root3 = mxnet::kvstore::GetRoot(P, color3, roots3); ASSERT_EQ(root3, correct_root3); + + std::vector P2 = {0, 1, 1, 0, 2, 3, 3, 2}; + std::unordered_set roots4 = roots1; + int color4 = 0; + int correct_root4 = 0; + int root4 = mxnet::kvstore::GetRoot(P, color4, roots4); + ASSERT_EQ(root4, correct_root4); } // GetChildTest @@ -345,20 +417,139 @@ TEST(GpuTopology, TestFindBestEdge) { ASSERT_EQ(g2, correct_g2); } -// GenerateBinaryTreeTest -TEST(GpuTopology, TestGenerateBinaryTree) { +// KLGenerateBinaryTreeTest +TEST(GpuTopology, TestKLGenerateBinaryTree1) { + std::vector W = {0, 2, 3, 3, 3, 1, 1, 1, + 2, 0, 3, 2, 1, 3, 1, 1, + 2, 3, 0, 3, 1, 1, 2, 1, + 3, 2, 3, 0, 1, 1, 1, 2, + 3, 1, 1, 1, 0, 2, 3, 3, + 1, 3, 1, 1, 2, 0, 3, 2, + 1, 1, 2, 1, 2, 3, 0, 3, + 1, 1, 1, 2, 3, 2, 3, 0}; + std::vector P = {0, 1, 1, 0, 2, 3, 3, 2}; + std::vector> cluster_pairs; + cluster_pairs.push_back(std::make_pair(0,-2)); + cluster_pairs.push_back(std::make_pair(1,-2)); + cluster_pairs.push_back(std::make_pair(2,-2)); + cluster_pairs.push_back(std::make_pair(3,-2)); + std::unordered_set roots = {0, 2, 4, 6}; + std::vector topo = {0, 2, 4, 6}; + std::vector scan(2,0); + std::mt19937 gen(1); + mxnet::kvstore::KLGenerateBinaryTree(W, P, cluster_pairs, roots, topo, scan, + gen); + std::vector correct_topo = {0, 2, 4, 6, 0, 3, 2, 1, 4, 7, 6, 5}; + std::vector correct_scan = {0, 0, 4}; + ASSERT_EQ(topo.size(), correct_topo.size()); + for (unsigned i = 0; i < topo.size(); ++i) + ASSERT_EQ(topo[i], correct_topo[i]); + ASSERT_EQ(scan.size(), correct_scan.size()); + for (unsigned i = 0; i < scan.size(); ++i) 
+ ASSERT_EQ(scan[i], correct_scan[i]); +} - mxnet::kvstore::GenerateBinaryTree(); +TEST(GpuTopology, TestKLGenerateBinaryTree2) { + std::vector W = {0, 2, 3, 3, 3, 1, 1, 1, + 2, 0, 3, 2, 1, 3, 1, 1, + 2, 3, 0, 3, 1, 1, 2, 1, + 3, 2, 3, 0, 1, 1, 1, 2, + 3, 1, 1, 1, 0, 2, 3, 3, + 1, 3, 1, 1, 2, 0, 3, 2, + 1, 1, 2, 1, 2, 3, 0, 3, + 1, 1, 1, 2, 3, 2, 3, 0}; + std::vector P = {0, 1, 1, 0, 2, 3, 3, 2}; + std::vector> cluster_pairs; + cluster_pairs.push_back(std::make_pair(0,-2)); + cluster_pairs.push_back(std::make_pair(1,-2)); + cluster_pairs.push_back(std::make_pair(2,-2)); + cluster_pairs.push_back(std::make_pair(3,-2)); + std::unordered_set roots = {0, 2, 4, 6}; + std::vector topo = {0, 6, 4, 2}; + std::vector scan(2,0); + std::mt19937 gen(1); + mxnet::kvstore::KLGenerateBinaryTree(W, P, cluster_pairs, roots, topo, scan, + gen); + std::vector correct_topo = {0, 6, 4, 2, 0, 3, 6, 5, 4, 7, 2, 1}; + std::vector correct_scan = {0, 0, 4}; + ASSERT_EQ(topo.size(), correct_topo.size()); + for (unsigned i = 0; i < topo.size(); ++i) + ASSERT_EQ(topo[i], correct_topo[i]); + ASSERT_EQ(scan.size(), correct_scan.size()); + for (unsigned i = 0; i < scan.size(); ++i) + ASSERT_EQ(scan[i], correct_scan[i]); } +// UpdateWeightTest +TEST(GpuTopology, TestUpdateWeight) { + std::vector W = {0.f, 1.f, + 1.f, 0.f}; + std::vector topo= {1, 1, 0}; + int num_gpus = 2; + float alpha = 0.7; + std::vector correct_W = {0.f, 0.7f, + 0.7f, 0.f}; + mxnet::kvstore::UpdateWeight(W, topo, num_gpus, alpha); + ASSERT_EQ(W.size(), correct_W.size()); + for (unsigned i = 0; i < W.size(); ++i) { + ASSERT_EQ(W[i], correct_W[i]); + } +} // Backtrack -// UpdateWeight -// BacktrackingGenerateBinaryTree -// PartitionGraphFromRoot -// PartitionGraph +// BacktrackGenerateBinaryTree +// ComputeTreesFromRoot +TEST(GpuTopology, TestComputeTreesFromRoot) { + std::vector W = {0, 2, 2, 3, 3, 1, 1, 1, + 2, 0, 3, 2, 1, 3, 1, 1, + 2, 3, 0, 3, 1, 1, 2, 1, + 3, 2, 3, 0, 1, 1, 1, 2, + 3, 1, 1, 1, 0, 2, 2, 3, + 1, 3, 1, 1, 2, 
0, 3, 2, + 1, 1, 2, 1, 2, 3, 0, 3, + 1, 1, 1, 2, 3, 2, 3, 0}; + int num_gpus = 8; + int root = 0; + float alpha = 0.7; + bool backtrack = true; + unsigned correct_topo_size = 15; + unsigned correct_scan_size = 5; + std::vector topo; + std::vector scan; + + mxnet::kvstore::ComputeTreesFromRoot( W, num_gpus, root, alpha, backtrack, + topo, scan ); + + ASSERT_EQ(topo.size(), correct_topo_size); + ASSERT_EQ(scan.size(), correct_scan_size); +} -TEST(GpuTopology, TestPermuteMatrix) { +// ComputeTreesTest with backtracking +TEST(GpuTopology, TestComputeTrees1) { + std::mt19937 gen(1); + float alpha = 0.7; + bool backtrack = true; + // Do 100 randomized tests per GPU count from 2 to 16 + for (int num_gpus = 2; num_gpus <= 8; ++num_gpus) { + for (int i = 0; i < 5; ++i) { + TestComputeTreesRandomized( num_gpus, alpha, backtrack, gen ); + } + } +} +// ComputeTreesTest with Kernighan-Lin +TEST(GpuTopology, TestComputeTrees2) { + std::mt19937 gen(1); + float alpha = 0.7; + bool backtrack = false; + // Do 100 randomized tests per GPU count from 2 to 16 + for (int num_gpus = 2; num_gpus <= 16; ++num_gpus) { + for (int i = 0; i < 100; ++i) { + TestComputeTreesRandomized( num_gpus, alpha, backtrack, gen ); + } + } +} + +TEST(GpuTopology, TestPermuteMatrix) { std::vector W = {0, 2, 2, 3, 3, 1, 1, 1, 2, 0, 3, 2, 1, 3, 1, 1, 2, 3, 0, 3, 1, 1, 2, 1, From b5b42bc37c9f8a65291f5c6a1299a9d85cc6ba55 Mon Sep 17 00:00:00 2001 From: Carl Yang Date: Mon, 18 Jun 2018 21:01:49 +0000 Subject: [PATCH 11/36] Fix Postprocess --- src/kvstore/gpu_topology.h | 89 +++++++++++++++++--------- tests/cpp/kvstore/gpu_topology_test.cc | 20 ++++-- 2 files changed, 71 insertions(+), 38 deletions(-) diff --git a/src/kvstore/gpu_topology.h b/src/kvstore/gpu_topology.h index 4632a2b3a11e..565d8d148080 100644 --- a/src/kvstore/gpu_topology.h +++ b/src/kvstore/gpu_topology.h @@ -645,12 +645,14 @@ bool IsValid( const std::vector& W, // If we are at last recursive level, we can apply a more stringent check: // -if some 
GPU is not found, then we are in invalid state - } else if (row == static_cast(state.size())) - for (int i = 0; i < num_elements; ++i) + } else if (row == static_cast(state.size())) { + for (int i = 0; i < num_elements; ++i) { if (found_vec[i] == 0) { //std::cout << "Not valid: " << i << " not found" << std::endl; return false; } + } + } return true; } @@ -674,36 +676,47 @@ bool IsValid( const std::vector& W, // 3 0 1 5 // 3 3 0 4 1 2 5 6 // GPU3 knows not to make redundant send to itself void Postprocess( std::vector& result, int num_elements, int depth) { - - std::vector histogram(num_elements, 0); - for (unsigned i = 0; i < result.size(); ++i) { - int val = result[i]; - histogram[val]++; - } - - int stride = 1; - for (int j = result.size()-1; j-stride >= 0; j -= 2*stride) { - //std::cout << "Comparing " << j << " and " << j-stride << std::endl; - int from = result[j]; - int dest = result[j-stride]; - if (histogram[from] > 1 && from != dest) { - //PrintVector("Old histogram", histogram); - //std::cout << "Swapping from " << from << " to " << dest << " on indices " << j << " and " << j-stride << std::endl; - result[j] = dest; - histogram[from]--; - //PrintVector("New histogram", histogram); - //PrintVector("New result", result); + for (int level = depth - 1; level >= 0; --level) { + int stride = 1 << level; + std::vector histogram_above(num_elements,0); + for (unsigned i = 0; i < result.size(); i += 2*stride) { + int val = result[i]; + histogram_above[val]++; + } + std::vector histogram(num_elements, 0); + for (unsigned i = 0; i < result.size(); i += stride) { + int val = result[i]; + histogram[val]++; + } + //PrintVector("above histo", histogram_above); + + for (int i = result.size()-stride; i-stride >= 0; i -= 2*stride) { + //std::cout << "Comparing " << i << " and " << i-stride << std::endl; + int from = result[i]; + int dest = result[i-stride]; + if ((histogram[from] > 1 || histogram_above[from] >= 1) && from != dest) { + //PrintVector("Old histogram", 
histogram); + //std::cout << "Swapping from " << from << " to " << dest << " on indices " << i << " and " << i-stride << std::endl; + result[i] = dest; + histogram[from]--; + //PrintVector("New histogram", histogram); + //PrintVector("New result", result); } } + } } // Given a spanning tree encoded as a state (result) and weight of each edge // in the link topology graph, compute its weight. +// @input: penalty controls whether or not penalties are applied to tree +// -usually turned on when backtracking to get better solutions +// -usually turned off when outside the penalty to get weight of tree template T ComputeTreeWeight( const std::vector& W, const std::vector& result, int num_elements, - int depth) { + int depth, + bool penalty ) { T weight = 0.f; std::unordered_set links_used; @@ -719,7 +732,8 @@ T ComputeTreeWeight( const std::vector& W, // Penalize: (1) use of redundant edges in a single tree // (2) repeated use of a GPU in a single tree at the same // level above the leaf level - if (links_used.find(from*num_elements+dest) != links_used.end()) { + if (links_used.find(from*num_elements+dest) != links_used.end() + && penalty) { weight -= 100; //std::cout << "Penalty 1: " << from << " to " << dest << std::endl; } @@ -729,7 +743,7 @@ T ComputeTreeWeight( const std::vector& W, } nodes_used[from] = true; - if (i > 0 && nodes_used[dest]) { + if (i > 0 && nodes_used[dest] && penalty) { weight -= 10; //std::cout << "Penalty 2: " << from << " and " << dest << " seen before\n"; } @@ -779,17 +793,18 @@ void FormTopology( const std::vector& result, // -binary // -maximum weight template -void Backtrack( const std::vector& W, +bool Backtrack( const std::vector& W, std::vector& state, std::vector& best_result, T& best_result_weight, int row, int num_elements, - int depth ) { + int depth, + bool optimal ) { if (row == static_cast(state.size())) { std::vector result = state; Postprocess(result, num_elements, depth); - T weight = ComputeTreeWeight(W, result, num_elements, 
depth); + T weight = ComputeTreeWeight(W, result, num_elements, depth, true); // Save this spanning tree if it is highest weight tree found sofar if (weight > best_result_weight) { @@ -801,20 +816,24 @@ void Backtrack( const std::vector& W, //std::cout << "Not best weight: " << weight << " < " << best_result_weight << std::endl; //PrintVector("Not best", result); } - return; + return !optimal; } // If not last recursive level, try to find valid tree for next level + bool stop = false; for (int j = 0; j < num_elements; ++j) { state[row] = j; //PrintVector("Trying state", state); if (IsValid(W, state, num_elements, row+1, depth)) { - Backtrack( W, state, best_result, best_result_weight, row+1, num_elements, - depth ); + stop = Backtrack( W, state, best_result, best_result_weight, row+1, + num_elements, depth, optimal ); state[row] = -1; } else state[row] = -1; + if (stop) + return stop; } + return stop; } // Apply penalty factor alpha to each link in link topology graph that is used @@ -857,6 +876,7 @@ void BacktrackGenerateBinaryTree( std::vector& W, scan_row.clear(); // Compute depth + // num_elements: depth // 5: 3 // 6: 3 // 7: 3 @@ -874,8 +894,14 @@ void BacktrackGenerateBinaryTree( std::vector& W, // Place root and try all combinations state[0] = root; - Backtrack( W, state, result, result_weight, 1, num_elements, depth ); + Backtrack( W, state, result, result_weight, 1, num_elements, depth, false ); + //result_weight = ComputeTreeWeight(W, result, num_elements, depth, false); + //Backtrack( W, state, result, result_weight, 1, num_elements, depth, true ); + //T result_weight2 = ComputeTreeWeight(W, result, num_elements, depth, false); //PrintVector("result", result); + //std::cout << "First solution reached " << + // (result_weight2-result_weight)/result_weight2 << " of optimal " << + // result_weight << " " << result_weight2 << "\n"; FormTopology( result, topo_row, scan_row, depth ); } @@ -1008,6 +1034,7 @@ void ComputeTrees( const std::vector& W, for (int i = 
0; i < num_elements; ++i) PrintTopo("Topo", topo[i], scan[i]); + PrintMatrix("W", W, num_elements, num_elements); PrintMatrix("Links", adj, num_elements, num_elements); } diff --git a/tests/cpp/kvstore/gpu_topology_test.cc b/tests/cpp/kvstore/gpu_topology_test.cc index c86d7e5b4d14..d1aab257c085 100644 --- a/tests/cpp/kvstore/gpu_topology_test.cc +++ b/tests/cpp/kvstore/gpu_topology_test.cc @@ -76,8 +76,8 @@ void TestComputeTreesRandomized( int num_gpus, float alpha, int backtrack, std::fill(W.begin(), W.end(), 0.f); GenerateMatrix(W, num_gpus, k, gen); satisfied = IsSatisfactory(W, num_gpus, depth); - if (!satisfied) - LOG(WARNING) << k << " is not satisfactory"; + //if (!satisfied) + // LOG(WARNING) << k << " is not satisfactory"; } std::vector> topo; @@ -159,10 +159,10 @@ TEST(GpuTopology, TestComputeTreeWeight) { 0, 0, 2, 0, 2, 3, 0}; std::vector state0 = {3, 2, 1, 5, 0, 0, 4, 6}; - ASSERT_EQ(mxnet::kvstore::ComputeTreeWeight(W, state0, 7, 3), 16); + ASSERT_EQ(mxnet::kvstore::ComputeTreeWeight(W, state0, 7, 3, false), 16); std::vector state1 = {3, 2, 0, 4, 1, 1, 5, 6}; - ASSERT_EQ(mxnet::kvstore::ComputeTreeWeight(W, state1, 7, 3), 17); + ASSERT_EQ(mxnet::kvstore::ComputeTreeWeight(W, state1, 7, 3, false), 17); } TEST(GpuTopology, TestPostprocess) { @@ -179,10 +179,16 @@ TEST(GpuTopology, TestPostprocess) { ASSERT_EQ(result1[i], correct1[i]); std::vector result2 = {5, 4, 1, 3, 1, 0, 2, 0}; - std::vector correct2= {5, 4, 1, 3, 1, 0, 2, 2}; + std::vector correct2= {5, 4, 5, 3, 1, 0, 2, 2}; mxnet::kvstore::Postprocess( result2, 6, 3 ); for (unsigned i = 0; i < correct2.size(); ++i) ASSERT_EQ(result2[i], correct2[i]); + + std::vector result3 = {10,10, 0, 0, 0, 0, 0, 1, 2, 3, 6, 4, 7, 5, 8, 9}; + std::vector correct3= {10,10,10,10, 0, 0, 0, 1, 2, 3, 6, 4, 7, 5, 8, 9}; + mxnet::kvstore::Postprocess( result3, 11, 4 ); + for (unsigned i = 0; i < correct3.size(); ++i) + ASSERT_EQ(result3[i], correct3[i]); } TEST(GpuTopology, TestDepth) { @@ -529,7 +535,7 @@ 
TEST(GpuTopology, TestComputeTrees1) { float alpha = 0.7; bool backtrack = true; // Do 100 randomized tests per GPU count from 2 to 16 - for (int num_gpus = 2; num_gpus <= 8; ++num_gpus) { + for (int num_gpus = 2; num_gpus <= 16; ++num_gpus) { for (int i = 0; i < 5; ++i) { TestComputeTreesRandomized( num_gpus, alpha, backtrack, gen ); } @@ -543,7 +549,7 @@ TEST(GpuTopology, TestComputeTrees2) { bool backtrack = false; // Do 100 randomized tests per GPU count from 2 to 16 for (int num_gpus = 2; num_gpus <= 16; ++num_gpus) { - for (int i = 0; i < 100; ++i) { + for (int i = 0; i < 5; ++i) { TestComputeTreesRandomized( num_gpus, alpha, backtrack, gen ); } } From 6327ceb1c15f9be8d759bc8081116a8a870c6137 Mon Sep 17 00:00:00 2001 From: Carl Yang Date: Mon, 18 Jun 2018 22:40:22 +0000 Subject: [PATCH 12/36] Add switch for first valid tree when num_gpus > 8, and for maximum weight when num_gpus <= 8 --- src/kvstore/gpu_topology.h | 93 ++++++++++++++++++++++++++++++++++---- 1 file changed, 83 insertions(+), 10 deletions(-) diff --git a/src/kvstore/gpu_topology.h b/src/kvstore/gpu_topology.h index 565d8d148080..9479191c6307 100644 --- a/src/kvstore/gpu_topology.h +++ b/src/kvstore/gpu_topology.h @@ -69,7 +69,6 @@ void PrintVector( const std::string& str, const std::vector& vec ) { template void PrintMatrix( const std::string& str, const std::vector& matrix, int num_rows, int num_cols ) { - PrintVector("Matrix vector", matrix); std::cout << str << ":\n"; int count = 0; for (int row = 0; row < num_rows; ++row) { @@ -793,14 +792,14 @@ void FormTopology( const std::vector& result, // -binary // -maximum weight template -bool Backtrack( const std::vector& W, - std::vector& state, - std::vector& best_result, - T& best_result_weight, - int row, - int num_elements, - int depth, - bool optimal ) { +bool RecursiveBacktrack( const std::vector& W, + std::vector& state, + std::vector& best_result, + T& best_result_weight, + int row, + int num_elements, + int depth, + bool optimal ) { if 
(row == static_cast(state.size())) { std::vector result = state; Postprocess(result, num_elements, depth); @@ -836,6 +835,74 @@ bool Backtrack( const std::vector& W, return stop; } +template +void IterativeBacktrack( const std::vector& W, + std::vector& state, + std::vector& best_result, + T& best_result_weight, + int row, + int num_elements, + int depth, + bool optimal ) { + std::stack state_stack; + row = 1; + int pos = 0; + state_stack.push(pos); + + while (true) { + // If there is no valid position, 2 cases: + // a) if stack is empty, break and stop search + // b) if stack is not empty, pop stack and set current position to next + // position backtrack to previous row + while (!state_stack.empty() && pos >= num_elements) { + pos = state_stack.top(); + pos++; + state_stack.pop(); + state[state_stack.size()+1] = -1; + row--; + } + if (state_stack.empty()) break; + + state[row] = pos; + // If there is a valid position push the position to stack, set current + // position to 0 and move to next row + //PrintVector("Trying state", state); + if (IsValid(W, state, num_elements, row+1, depth)) { + state_stack.push(pos); + pos = 0; + row++; + } else { + pos++; + state[row] = -1; + } + + // If stack has size N, a solution is found + // Pop stack, set current position to next position + // Backtrack to find next solution + if (row == static_cast(state.size())) { + //PrintVector("state", state); + std::vector result = state; + Postprocess(result, num_elements, depth); + T weight = ComputeTreeWeight(W, result, num_elements, depth, true); + + // Save this spanning tree if it is highest weight tree found sofar + if (weight > best_result_weight) { + std::swap(best_result_weight, weight); + best_result = result; + //std::cout << "New best weight: " << best_result_weight << " > " << weight << std::endl; + //PrintVector("New best", result); + } + if (!optimal) break; + + pos = state_stack.top(); + pos++; + state_stack.pop(); + state[state_stack.size()+1] = -1; + row--; + } + } +} 
+ // Apply penalty factor alpha to each link in link topology graph that is used // by the spanning tree template @@ -894,7 +961,13 @@ void BacktrackGenerateBinaryTree( std::vector& W, // Place root and try all combinations state[0] = root; - Backtrack( W, state, result, result_weight, 1, num_elements, depth, false ); + // Seek optimal solution until depth <= 3 i.e. 8 GPUs + // For larger numbers of GPUs, settle for first tree found (non-optimal), but + // this saves a lot of runtime, because Backtrack is exponential time + if (depth <= 3) + IterativeBacktrack( W, state, result, result_weight, 1, num_elements, depth, true ); + else + IterativeBacktrack( W, state, result, result_weight, 1, num_elements, depth, false ); //result_weight = ComputeTreeWeight(W, result, num_elements, depth, false); //Backtrack( W, state, result, result_weight, 1, num_elements, depth, true ); //T result_weight2 = ComputeTreeWeight(W, result, num_elements, depth, false); From 8694fe702ed433476b21546742cc88ac00412f26 Mon Sep 17 00:00:00 2001 From: Carl Yang Date: Mon, 18 Jun 2018 22:54:01 +0000 Subject: [PATCH 13/36] Kernighan-Lin seems to find better trees --- src/kvstore/comm_tree.h | 2 +- src/kvstore/gpu_topology.h | 20 -------------------- 2 files changed, 1 insertion(+), 21 deletions(-) diff --git a/src/kvstore/comm_tree.h b/src/kvstore/comm_tree.h index 1e420b8be363..2ea1b213ccdf 100644 --- a/src/kvstore/comm_tree.h +++ b/src/kvstore/comm_tree.h @@ -51,7 +51,7 @@ class CommDeviceTree : public Comm { CommDeviceTree() { inited_ = false; bigarray_bound_ = dmlc::GetEnv("MXNET_KVSTORE_BIGARRAY_BOUND", 10000000); - backtrack_ = dmlc::GetEnv("MXNET_KVSTORE_BACKTRACK", 1); + backtrack_ = dmlc::GetEnv("MXNET_KVSTORE_BACKTRACK", 0); link_usage_penalty_ = dmlc::GetEnv("MXNET_KVSTORE_LINK_USAGE_PENALTY", 0.7); stream_ = dmlc::GetEnv("MXNET_KVSTORE_STREAM", 1); } diff --git a/src/kvstore/gpu_topology.h b/src/kvstore/gpu_topology.h index 9479191c6307..ed7a07b5f6ff 100644 --- 
a/src/kvstore/gpu_topology.h +++ b/src/kvstore/gpu_topology.h @@ -38,26 +38,6 @@ namespace mxnet { namespace kvstore { -void PrettyPrintTopology(const std::vector> topo) { - std::cout << " ={"; - for (unsigned row = 0; row < topo.size(); ++row) { - if (row != 0) - std::cout << " "; - std::cout << "{"; - for (unsigned col = 0; col < topo[0].size(); ++col) { - std::cout << topo[row][col]; - if( col != topo[0].size()-1 ) - std::cout << ", "; - } - std::cout << "}"; - if ( row == topo.size()-1 ) - std::cout << "};"; - else - std::cout << ","; - std::cout << std::endl; - } -} - template void PrintVector( const std::string& str, const std::vector& vec ) { std::cout << str << ":\n"; From c6cd67a9153e769955180dcb4522be5aa8268b9c Mon Sep 17 00:00:00 2001 From: Carl Yang Date: Wed, 20 Jun 2018 18:02:50 +0000 Subject: [PATCH 14/36] get rid of printfs --- src/kvstore/comm_tree.h | 48 ++------------------------ src/kvstore/gpu_topology.h | 71 ++++---------------------------------- 2 files changed, 9 insertions(+), 110 deletions(-) diff --git a/src/kvstore/comm_tree.h b/src/kvstore/comm_tree.h index 2ea1b213ccdf..a6e9b99a1ccc 100644 --- a/src/kvstore/comm_tree.h +++ b/src/kvstore/comm_tree.h @@ -100,46 +100,32 @@ class CommDeviceTree : public Comm { BufferEntry& buf = merge_buf_[topo_id][key]; if ( devs_[topo_id] == src[i].ctx() ) { - //buf.merged = src[i]; CopyFromTo(src[i], &(buf.merged[merged_row]), priority); - //LOG(WARNING) << "Initial reduce copy from " << src[i].ctx() << " to " << buf.merged[merged_row].ctx(); } } } - //LOG(WARNING) << "Copy to merged"; for (int level = depth_; level > 0; --level) { int start = scan_[root][level ]; int end = scan_[root][level+1]; - //LOG(WARNING) << "Reduce level: " << level; - //LOG(WARNING) << "From " << start << " to " << end; unsigned is_dest = 0; int dest_id = 0; for (int j = start; j < end; ++j) { int topo_id = topology[j]; dest_id = (is_dest==0) ? 
topo_id : dest_id; - //LOG(WARNING) << topo_id << " -> " << dest_id; BufferEntry& buf_dest = merge_buf_[dest_id][key]; BufferEntry& buf_from = merge_buf_[topo_id][key]; - //LOG(WARNING) << "Dest shape " << buf_dest.merged[merged_row].ctx() << buf_dest.copy_buf[merged_row][0].ctx(); - //LOG(WARNING) << "From shape " << buf_from.merged[merged_row].ctx() << buf_from.copy_buf[merged_row][0].ctx(); if (!is_dest) { reduce[dest_id].push_back( buf_dest.merged[merged_row] ); - //LOG(WARNING) << topo_id << " == " << dest_id; } else { if (dest_id != topo_id) { - //buf_dest.copy_buf[is_dest-1] = NDArray( - // buf_dest.merged.shape(), buf_dest.merged.ctx(), false, - // buf_dest.merged.dtype()); CopyFromTo(buf_from.merged[merged_row], &(buf_dest.copy_buf[merged_row][is_dest-1]), priority); reduce[dest_id].push_back( buf_dest.copy_buf[merged_row][is_dest-1] ); - //LOG(WARNING) << "Reduce copy from " << buf_from.merged[merged_row].ctx() << " to " << buf_dest.copy_buf[merged_row][is_dest-1].ctx(); - //LOG(WARNING) << topo_id << " != " << dest_id; } } @@ -151,27 +137,21 @@ class CommDeviceTree : public Comm { end = scan_[root][level ]; for (int i = start; i < end; ++i) { int gpu_id = topology[i]; - //LOG(WARNING) << "Doing reduce on GPU" << gpu_id; - //LOG(WARNING) << "With #elems " << reduce[gpu_id].size(); // conditional to detect whether operation must be done if ( reduce[gpu_id].size() > 1 ) { BufferEntry& buf = merge_buf_[gpu_id][key]; - //LOG(WARNING) << "reduce input 1 " << reduce[gpu_id][0].ctx(); - //LOG(WARNING) << "reduce input 2 " << reduce[gpu_id][1].ctx(); - //LOG(WARNING) << "buf.mg output " << buf.merged[merged_row].ctx(); ElementwiseSum(reduce[gpu_id], &(buf.merged[merged_row]), priority); } } // reset - //LOG(WARNING) << "Clear reduce array"; for (unsigned i = 0; i < devs_.size(); ++i) { reduce[i].clear(); } } } else { - //LOG(WARNING) << "Only dense input supported for now"; + LOG(WARNING) << "Only dense input supported for now"; } int topo_id = topology[0]; @@ 
-198,7 +178,6 @@ class CommDeviceTree : public Comm { std::vector> broadcast_slice(devs_.size()); std::vector slice_scan(devs_.size()+1); - //LOG(WARNING) << key << " " << src[0].shape() << " " << src[0].shape().Size(); int total_size = src[0].shape().Size(); unsigned first_size = src[0].shape()[0]; @@ -211,7 +190,6 @@ class CommDeviceTree : public Comm { int slice_size = (first_size + devs_.size()-1)/devs_.size(); for (unsigned i = 1; i < devs_.size(); ++i) { slice_scan[i] = slice_scan[i-1] + slice_size; - //LOG(WARNING) << slice_scan[i]; } slice_scan[devs_.size()] = src[0].shape()[0]; @@ -239,7 +217,6 @@ class CommDeviceTree : public Comm { } } else { int root = 0; - //LOG(WARNING) << "Executing single tree reduce for key " << key << " root " << root; ReduceInner(key, src, root, 0, priority); BufferEntry& buf = merge_buf_[root][key]; @@ -312,29 +289,20 @@ class CommDeviceTree : public Comm { if (merged_row == -1) CopyFromTo(src, dst[gpu_id], priority); temp[gpu_id] = *dst[gpu_id]; - //LOG(WARNING) << "Bcast copy from " << src.ctx() << " to " << buf.merged[merged_row].ctx(); for (int level = 1; level <= depth_; ++level) { int start = scan_[root][level]; int end = scan_[root][level+1]; - //LOG(WARNING) << "Bcast level: " << level; - //LOG(WARNING) << "From " << start << " to " << end; unsigned is_src = 0; int src_id = 0; for (int j = start; j < end; ++j) { int topo_id = topology[j]; src_id = (is_src==0) ? topo_id : src_id; - //LOG(WARNING) << src_id << " -> " << topo_id; if (is_src && src_id != topo_id) { - //LOG(WARNING) << src_id << " != " << topo_id; - CopyFromTo(temp[src_id], dst[topo_id], priority); - temp[topo_id] = *dst[topo_id]; - - //LOG(WARNING) << "Bcast copy from " << buf_from.merged[merged_row].ctx() << " to " << buf_dest.merged[merged_row].ctx(); } is_src = (is_src == static_cast(kBranch)-1) ? 
0 : is_src+1; @@ -362,7 +330,6 @@ class CommDeviceTree : public Comm { int slice_size = (dst[0]->shape()[0]+devs_.size()-1)/devs_.size(); for (unsigned i = 1; i < devs_.size(); ++i) { slice_scan[i] = slice_scan[i-1] + slice_size; - //LOG(WARNING) << slice_scan[i]; } slice_scan[devs_.size()] = dst[0]->shape()[0]; @@ -372,13 +339,11 @@ class CommDeviceTree : public Comm { if ( devs_[gpu_id] == dst[gpu_id]->ctx() ) { NDArray curr_slice = dst[gpu_id]->Slice(slice_scan[i], slice_scan[i+1]); CopyFromTo(buf.merged[i], &curr_slice, priority); - //LOG(WARNING) << "Bcast return copy from " << buf.merged[i].ctx() << " to " << curr_slice.ctx(); } } } } else { int root = 0; - //LOG(WARNING) << "Executing single tree broadcast for key " << key << " root " << root; BroadcastInner(key, src, dst, root, -1, priority); } } @@ -527,9 +492,9 @@ class CommDeviceTree : public Comm { int start = scan_[0][depth_ ]; int end = scan_[0][depth_+1]; - //LOG(WARNING) << "From: " << start << " to: " << end; - // In order to generalize to any number of GPUs, must support 2 things: + // In order to generalize to any number of GPUs, there are many + // strategies: // 1) detect whether we are encountering gpu for first time // first time => allocate memory // second time => do nothing @@ -537,7 +502,6 @@ class CommDeviceTree : public Comm { // allocate merge_buf_ to be next biggest power of 2 sized or use // 0, 1, ..., n_gpus (same mapping as dev_id) // e.g. 
5, 6, 7, 8 must all have merge_buf_.size() == 8 - // -Design decision: use second approach for now for (int j = start; j < end; ++j) { int topo_id = topology_[0][j]; auto& buf = merge_buf_[topo_id][key]; @@ -553,18 +517,15 @@ class CommDeviceTree : public Comm { int slice_size = (first_size+devs_.size()-1)/devs_.size(); int last_slice = first_size-(devs_.size()-1)*slice_size; shape_copy[0] = slice_size; - //LOG(WARNING) << "Split Check emptiness of copy buf on GPU" << topo_id << " " << ctx; buf.merged.resize(devs_.size()); for (unsigned row = 0; row < devs_.size(); ++row) { if (row == devs_.size()-1) shape_copy[0] = last_slice; - //LOG(WARNING) << "Split Allocating merg buf to GPU" << topo_id << " of shape" << shape_copy; buf.merged[row] = NDArray(shape_copy, ctx, false, type); buf.copy_buf.push_back(std::vector()); if (buf.copy_buf[row].empty()) { buf.copy_buf[row].resize(kBranch-1); for (size_t col = 0; col < buf.copy_buf[0].size(); ++col) { - //LOG(WARNING) << "Split Allocating copy buf to GPU" << topo_id; buf.copy_buf[row][col] = NDArray(buf.merged[row].shape(), buf.merged[row].ctx(), false, buf.merged[row].dtype()); @@ -574,15 +535,12 @@ class CommDeviceTree : public Comm { } else { buf.merged.push_back(NDArray(shape, ctx, false, type)); if (buf.copy_buf.empty()) { - //LOG(WARNING) << "Check emptiness of copy buf on GPU" << topo_id<< " " << ctx; buf.copy_buf.push_back(std::vector()); buf.copy_buf[0].resize(kBranch-1); for (size_t col = 0; col < buf.copy_buf[0].size(); ++col) { - //LOG(WARNING) << "Allocating copy buf to GPU" << topo_id; buf.copy_buf[0][col] = NDArray(buf.merged[0].shape(), buf.merged[0].ctx(), false, buf.merged[0].dtype()); - //LOG(WARNING) << "Success allocating copy buf to GPU" << topo_id; } } } diff --git a/src/kvstore/gpu_topology.h b/src/kvstore/gpu_topology.h index ed7a07b5f6ff..fe061ff76e44 100644 --- a/src/kvstore/gpu_topology.h +++ b/src/kvstore/gpu_topology.h @@ -260,15 +260,12 @@ bool KernighanLin( const std::vector& W, // -1 means 
vertex i is in Cluster B if (P[i] == static_cast(color)) { cluster_list.push_back(i); - //std::cout << "Number in Cluster A: " << first_partition << "\n"; - //std::cout << "Put vertex " << i << " in Cluster " << P_temp[i] << "\n"; } else P_temp[i] = 0; } // 1b) Shuffle using random generator std::shuffle(cluster_list.begin(), cluster_list.end(), gen); - //PrintVector("Partition permutation", cluster_list); for (unsigned i = 0; i < cluster_list.size(); ++i) { if (first_partition < target_partition) { int dest = cluster_list[i]; @@ -279,7 +276,6 @@ bool KernighanLin( const std::vector& W, P_temp[dest] = -1; } } - //PrintVector("Partition candidate", P_temp); // 2) Do iterations of Kernighan-Lin until convergence T g_max = 0; @@ -292,9 +288,7 @@ bool KernighanLin( const std::vector& W, // a) Compute difference between external and internal costs of all // elements in vector D gemv( W, P_temp, D ); - //PrintVector( "D pre-ewisemult", D ); ewisemult( P_temp, -1.f, D ); - //PrintVector( "D post-ewisemult", D ); // av and bv are used to hold candidates for moving // gv stores the score associated with move @@ -310,10 +304,7 @@ bool KernighanLin( const std::vector& W, T g; FindBestMove( W, P_temp, D, used, a, b, g ); if (g > 0) { - //std::cout << "Best move found in iter " << iter; - //std::cout << ": " << a << " -> " << b << " : " << g << "\n"; } else { - //std::cout << "No moves found in iter " << iter << std::endl; g_max = 0; break; } @@ -330,14 +321,10 @@ bool KernighanLin( const std::vector& W, used.insert(b); // e) Update D using P_temp - //PrintVector( "P_temp post-update", P_temp ); gemv( W, P_temp, D ); - //PrintVector( "D pre-ewisemult", D ); ewisemult( P_temp, -1.f, D ); - //PrintVector( "D post-ewisemult", D ); D[a] = 0; D[b] = 0; - //PrintVector( "D post-ewisemult", D ); } // 3) Find when to stop by doing linear scan through gv @@ -355,7 +342,6 @@ bool KernighanLin( const std::vector& W, // Otherwise, rollback changes to P_temp2 if (g_max > 0) { for (int i = 
0; i < g_k; i++) { - //std::cout << g_max << " " << g_k << " " << i << " " << av.size() << " " << bv.size() << " " << gv.size() << std::endl; int a = av[i]; int b = bv[i]; int temp = P_temp2[a]; @@ -377,8 +363,6 @@ bool KernighanLin( const std::vector& W, moves++; } } - //std::cout << "New color " << num_partitions << " with " << moves; - //std::cout << " elements\n"; cluster_pairs.push_back(std::make_pair(static_cast(color), static_cast(num_partitions))); @@ -407,7 +391,6 @@ int GetChild( const std::vector& P, int color, int parent ) { for (unsigned i = 0; i < P.size(); ++i) { - //std::cout << "Child " << i << ": " << P[i] << std::endl; if (P[i] == color && static_cast(i) != parent) return i; } @@ -469,10 +452,8 @@ int KLGenerateBinaryTree( std::vector& W, int reset = 0; for (unsigned i = 0; i < cluster_pairs.size(); ++i) { - //std::cout << "Cluster pair " << i << std::endl; if (i==0) scan_row.push_back(topo_row.size()); - //std::cout << "Pair " << i << ": " << cluster_pairs[i].first << " " << cluster_pairs[i].second << std::endl; int parent, child = -1; if (cluster_pairs[i].second==-2) { // Root must be color of pair.first @@ -480,13 +461,11 @@ int KLGenerateBinaryTree( std::vector& W, parent = GetRoot( P, color, roots ); if (parent == -1) return 1; child = GetChild(P, color, parent); - //std::cout << "Best link (case 1): " << color << ": " << parent << " -> " << child << ": " << std::endl; } else if (cluster_pairs[i].second==-1) { int color = cluster_pairs[i].first; parent = GetRoot( P, color, roots ); if (parent == -1) return 1; child = parent; - //std::cout << "Best link (case 2): " << color << ": " << parent << " -> " << child << ": " << std::endl; } else { // Root must exist in either first or second element of pair int color = cluster_pairs[i].first; @@ -509,13 +488,11 @@ int KLGenerateBinaryTree( std::vector& W, } if (child == -1) { - //std::cout << "No path to other cluster found from " << parent << " at level " << scan_row.size() << std::endl; 
new_roots.insert(parent); //child = parent; return 1; } else { - //std::cout << "Best link (case 3): " << parent << " -> " << child << ": " << weight << std::endl; new_roots.insert(parent); new_roots.insert(child); } @@ -540,7 +517,6 @@ int KLGenerateBinaryTree( std::vector& W, child = new_topo[parent]; topo_row.push_back(parent); topo_row.push_back(child); - //std::cout << "New pair: " << parent << " " << child << " " << new_topo[parent] << std::endl; } cluster_pairs.clear(); @@ -581,9 +557,7 @@ bool IsValid( const std::vector& W, for (int j = 0; j+stride < row; j += 2*stride) { int from = state[j]; int dest = state[j+stride]; - //std::cout << "Comparing " << j << " and " << j+stride << " in row " << row << std::endl; if (W[from*num_elements + dest] == static_cast(0) && from != dest) { - //std::cout << "Not valid: no edge from " << from << " to " << dest << " at index " << from*num_elements+dest << std::endl; return false; } } @@ -602,7 +576,6 @@ bool IsValid( const std::vector& W, found_vec[val] = 1; } } else { - //std::cout << "Not valid: " << val << " exceeds # of GPUs\n"; return false; } } @@ -618,7 +591,6 @@ bool IsValid( const std::vector& W, // -9 unique GPUs if (row < num_elements) { if (num_found > row || num_found < row - modifier) { - //std::cout << "Not valid: " << found.size() << " rows found but expected between " << row << " and " << row - modifier << std::endl; return false; } @@ -627,7 +599,6 @@ bool IsValid( const std::vector& W, } else if (row == static_cast(state.size())) { for (int i = 0; i < num_elements; ++i) { if (found_vec[i] == 0) { - //std::cout << "Not valid: " << i << " not found" << std::endl; return false; } } @@ -667,19 +638,13 @@ void Postprocess( std::vector& result, int num_elements, int depth) { int val = result[i]; histogram[val]++; } - //PrintVector("above histo", histogram_above); for (int i = result.size()-stride; i-stride >= 0; i -= 2*stride) { - //std::cout << "Comparing " << i << " and " << i-stride << std::endl; int from 
= result[i]; int dest = result[i-stride]; if ((histogram[from] > 1 || histogram_above[from] >= 1) && from != dest) { - //PrintVector("Old histogram", histogram); - //std::cout << "Swapping from " << from << " to " << dest << " on indices " << i << " and " << i-stride << std::endl; result[i] = dest; histogram[from]--; - //PrintVector("New histogram", histogram); - //PrintVector("New result", result); } } } @@ -714,17 +679,14 @@ T ComputeTreeWeight( const std::vector& W, if (links_used.find(from*num_elements+dest) != links_used.end() && penalty) { weight -= 100; - //std::cout << "Penalty 1: " << from << " to " << dest << std::endl; } links_used.insert(from*num_elements+dest); links_used.insert(dest*num_elements+from); - //std::cout << "Not valid: no edge from " << from << " to " << dest << std::endl; } nodes_used[from] = true; if (i > 0 && nodes_used[dest] && penalty) { weight -= 10; - //std::cout << "Penalty 2: " << from << " and " << dest << " seen before\n"; } nodes_used[dest] = true; } @@ -789,11 +751,6 @@ bool RecursiveBacktrack( const std::vector& W, if (weight > best_result_weight) { std::swap(best_result_weight, weight); best_result = result; - //std::cout << "New best weight: " << best_result_weight << " > " << weight << std::endl; - //PrintVector("New best", result); - } else { - //std::cout << "Not best weight: " << weight << " < " << best_result_weight << std::endl; - //PrintVector("Not best", result); } return !optimal; } @@ -802,7 +759,6 @@ bool RecursiveBacktrack( const std::vector& W, bool stop = false; for (int j = 0; j < num_elements; ++j) { state[row] = j; - //PrintVector("Trying state", state); if (IsValid(W, state, num_elements, row+1, depth)) { stop = Backtrack( W, state, best_result, best_result_weight, row+1, num_elements, depth, optimal ); @@ -846,7 +802,6 @@ void IterativeBacktrack( const std::vector& W, state[row] = pos; // If there is a valid position push the position to stack, set current // position to 0 and move to next row - 
//PrintVector("Trying state", state); if (IsValid(W, state, num_elements, row+1, depth)) { state_stack.push(pos); pos = 0; @@ -860,7 +815,6 @@ void IterativeBacktrack( const std::vector& W, // Pop stack, set current position to next position // Backtrack to find next solution if (row == static_cast(state.size())) { - //PrintVector("state", state); std::vector result = state; Postprocess(result, num_elements, depth); T weight = ComputeTreeWeight(W, result, num_elements, depth, true); @@ -869,8 +823,6 @@ void IterativeBacktrack( const std::vector& W, if (weight > best_result_weight) { std::swap(best_result_weight, weight); best_result = result; - //std::cout << "New best weight: " << best_result_weight << " > " << weight << std::endl; - //PrintVector("New best", result); } if (!optimal) break; @@ -893,11 +845,8 @@ void UpdateWeight( std::vector& W, for (unsigned i = 1; i < topo_row.size() - 1; i += 2) { unsigned parent = topo_row[i]; unsigned child = topo_row[i+1]; - if (parent >= num_elements*num_elements || - child >= num_elements*num_elements) - LOG(WARNING) << "W array out of bounds\n"; - else if (parent != child) { - //std::cout << W[child*num_elements+parent] << " " << alpha << std::endl; + if (!(parent >= num_elements*num_elements || + child >= num_elements*num_elements) && (parent != child)) { W[parent*num_elements+child] *= alpha; W[child*num_elements+parent] *= alpha; } @@ -948,13 +897,6 @@ void BacktrackGenerateBinaryTree( std::vector& W, IterativeBacktrack( W, state, result, result_weight, 1, num_elements, depth, true ); else IterativeBacktrack( W, state, result, result_weight, 1, num_elements, depth, false ); - //result_weight = ComputeTreeWeight(W, result, num_elements, depth, false); - //Backtrack( W, state, result, result_weight, 1, num_elements, depth, true ); - //T result_weight2 = ComputeTreeWeight(W, result, num_elements, depth, false); - //PrintVector("result", result); - //std::cout << "First solution reached " << - // 
(result_weight2-result_weight)/result_weight2 << " of optimal " << - // result_weight << " " << result_weight2 << "\n"; FormTopology( result, topo_row, scan_row, depth ); } @@ -1015,7 +957,6 @@ void ComputeTreesFromRoot( std::vector& W, // Run Kernighan-Lin to generate partition stop = KernighanLin(W, P_temp, num_partitions_temp, cluster_pairs, gen); - //PrintVector("New partition", P_temp); // Use partitions found and a given root to find best inter-cluster edge for // each pair of clusters, and returns them as roots of next cluster // If reset is true, then rewind back to previous clustering @@ -1029,8 +970,7 @@ void ComputeTreesFromRoot( std::vector& W, if (reset == 1) { if (!backtrack) - std::cout << "No valid binary tree found from root " << root << ", try backtracking\n"; - //std::cout << "Trying backtracking\n"; + LOG(WARNING) << "No valid binary tree found from root " << root << ", try backtracking"; BacktrackGenerateBinaryTree(W, num_elements, root, topo, scan); } else { topo = topo_temp; @@ -1084,11 +1024,12 @@ void ComputeTrees( const std::vector& W, std::vector> topo_temp(num_elements, std::vector()); - for (int i = 0; i < num_elements; ++i) + + /*for (int i = 0; i < num_elements; ++i) PrintTopo("Topo", topo[i], scan[i]); PrintMatrix("W", W, num_elements, num_elements); - PrintMatrix("Links", adj, num_elements, num_elements); + PrintMatrix("Links", adj, num_elements, num_elements);*/ } } // namespace kvstore From 7466c4deb24e28f0ba1187a36cced34baf80fdce Mon Sep 17 00:00:00 2001 From: Carl Yang Date: Thu, 21 Jun 2018 20:59:12 +0000 Subject: [PATCH 15/36] change defaults --- src/kvstore/comm_tree.h | 10 +++++----- src/kvstore/kvstore_local.h | 2 +- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/kvstore/comm_tree.h b/src/kvstore/comm_tree.h index 2ea1b213ccdf..f5b90776f469 100644 --- a/src/kvstore/comm_tree.h +++ b/src/kvstore/comm_tree.h @@ -50,7 +50,7 @@ class CommDeviceTree : public Comm { public: CommDeviceTree() { inited_ = false; - 
bigarray_bound_ = dmlc::GetEnv("MXNET_KVSTORE_BIGARRAY_BOUND", 10000000); + gpuarray_bound_ = dmlc::GetEnv("MXNET_KVSTORE_GPUARRAY_BOUND", 10000000); backtrack_ = dmlc::GetEnv("MXNET_KVSTORE_BACKTRACK", 0); link_usage_penalty_ = dmlc::GetEnv("MXNET_KVSTORE_LINK_USAGE_PENALTY", 0.7); stream_ = dmlc::GetEnv("MXNET_KVSTORE_STREAM", 1); @@ -205,7 +205,7 @@ class CommDeviceTree : public Comm { const NDArrayStorageType stype = src[0].storage_type(); // normal dense reduce if (stype == kDefaultStorage) { - if (total_size > bigarray_bound_ && first_size >= devs_.size()) { + if (total_size > gpuarray_bound_ && first_size >= devs_.size()) { // Find slice bounds slice_scan[0] = 0; int slice_size = (first_size + devs_.size()-1)/devs_.size(); @@ -356,7 +356,7 @@ class CommDeviceTree : public Comm { } else { int total_size = src.shape().Size(); unsigned first_size = src.shape()[0]; - if (total_size > bigarray_bound_ && first_size >= devs_.size()) { + if (total_size > gpuarray_bound_ && first_size >= devs_.size()) { std::vector slice_scan(devs_.size()+1); slice_scan[0] = 0; int slice_size = (dst[0]->shape()[0]+devs_.size()-1)/devs_.size(); @@ -548,7 +548,7 @@ class CommDeviceTree : public Comm { TShape shape_copy = shape; int total_size = shape.Size(); unsigned first_size = shape[0]; - if (total_size > bigarray_bound_ && first_size >= devs_.size()) { + if (total_size > gpuarray_bound_ && first_size >= devs_.size()) { // Find slice bounds int slice_size = (first_size+devs_.size()-1)/devs_.size(); int last_slice = first_size-(devs_.size()-1)*slice_size; @@ -644,7 +644,7 @@ class CommDeviceTree : public Comm { /// \brief Highest numbered device int max_dev_; int depth_; - int bigarray_bound_; + int gpuarray_bound_; bool inited_; bool stream_; bool backtrack_; diff --git a/src/kvstore/kvstore_local.h b/src/kvstore/kvstore_local.h index df622de5f36c..89550ed49b66 100644 --- a/src/kvstore/kvstore_local.h +++ b/src/kvstore/kvstore_local.h @@ -57,7 +57,7 @@ class KVStoreLocal : public 
KVStore { */ explicit KVStoreLocal(bool use_device_comm) : KVStore() { if (use_device_comm) { - bool tree = dmlc::GetEnv("MXNET_KVSTORE_USETREE", 1); + bool tree = dmlc::GetEnv("MXNET_KVSTORE_USETREE", 0); if (tree) { comm_ = new CommDeviceTree(); } else { From cc935a2c6fa730f2ef60178ad996f77a576a2459 Mon Sep 17 00:00:00 2001 From: Carl Yang Date: Fri, 22 Jun 2018 00:39:53 +0000 Subject: [PATCH 16/36] inherit from CommDevice instead of Comm --- src/kvstore/comm.h | 49 ++++++++----- src/kvstore/comm_tree.h | 150 ++++++---------------------------------- 2 files changed, 53 insertions(+), 146 deletions(-) diff --git a/src/kvstore/comm.h b/src/kvstore/comm.h index a5d6a1dabeff..39d040dbbf2c 100644 --- a/src/kvstore/comm.h +++ b/src/kvstore/comm.h @@ -474,6 +474,31 @@ class CommDevice : public Comm { } } + const NDArray& ReduceRowSparse(int key, const std::vector& src, + int priority) { + auto& buf = merge_buf_[key]; + std::vector reduce(src.size()); + + const NDArrayStorageType stype = src[0].storage_type(); + NDArray& buf_merged = buf.merged_buf(stype); + if (buf.copy_buf.empty()) { + // initialize buffer for copying during reduce + buf.copy_buf.resize(src.size()); + for (size_t j = 0; j < src.size(); ++j) { + buf.copy_buf[j] = NDArray(stype, src[0].shape(), buf_merged.ctx(), true, src[0].dtype()); + } + } + CHECK(src[0].storage_type() == buf.copy_buf[0].storage_type()) + << "Storage type mismatch detected. " << src[0].storage_type() << "(src) vs. 
" + << buf.copy_buf[0].storage_type() << "(buf.copy_buf)"; + for (size_t i = 0; i < src.size(); ++i) { + CopyFromTo(src[i], &(buf.copy_buf[i]), priority); + reduce[i] = buf.copy_buf[i]; + } + ElementwiseSum(reduce, &buf_merged, priority); + return buf_merged; + } + const NDArray& Reduce(int key, const std::vector& src, int priority) override { // when this reduce is called from kvstore_dist, gc is not set @@ -490,13 +515,14 @@ class CommDevice : public Comm { InitBuffersAndComm(src); auto& buf = merge_buf_[key]; - std::vector reduce(src.size()); const NDArrayStorageType stype = src[0].storage_type(); NDArray& buf_merged = buf.merged_buf(stype); // normal dense reduce if (stype == kDefaultStorage) { CopyFromTo(src[0], &buf_merged, priority); + + std::vector reduce(src.size()); reduce[0] = buf_merged; if (buf.copy_buf.empty()) { @@ -514,24 +540,11 @@ class CommDevice : public Comm { CopyFromTo(src[i+1], &(buf.copy_buf[i]), priority); reduce[i+1] = buf.copy_buf[i]; } + ElementwiseSum(reduce, &buf_merged, priority); } else { // sparse reduce - if (buf.copy_buf.empty()) { - // initialize buffer for copying during reduce - buf.copy_buf.resize(src.size()); - for (size_t j = 0; j < src.size(); ++j) { - buf.copy_buf[j] = NDArray(stype, src[0].shape(), buf_merged.ctx(), true, src[0].dtype()); - } - } - CHECK(src[0].storage_type() == buf.copy_buf[0].storage_type()) - << "Storage type mismatch detected. " << src[0].storage_type() << "(src) vs. 
" - << buf.copy_buf[0].storage_type() << "(buf.copy_buf)"; - for (size_t i = 0; i < src.size(); ++i) { - CopyFromTo(src[i], &(buf.copy_buf[i]), priority); - reduce[i] = buf.copy_buf[i]; - } + buf_merged = ReduceRowSparse( key, src, priority ); } - ElementwiseSum(reduce, &buf_merged, priority); return buf_merged; } @@ -767,11 +780,13 @@ class CommDevice : public Comm { return sparse_merged; } - private: + private: /// \brief the sparse merged value for reduce and rowsparse broadcast operations NDArray sparse_merged; }; std::unordered_map merge_buf_; + + public: bool inited_; }; diff --git a/src/kvstore/comm_tree.h b/src/kvstore/comm_tree.h index dc1f1803ad97..2a175e04b76d 100644 --- a/src/kvstore/comm_tree.h +++ b/src/kvstore/comm_tree.h @@ -46,14 +46,13 @@ namespace kvstore { * device-to-cpu, which is often true for 4 or 8 GPUs. But it uses more device * memory. */ -class CommDeviceTree : public Comm { +class CommDeviceTree : public CommDevice { public: CommDeviceTree() { inited_ = false; gpuarray_bound_ = dmlc::GetEnv("MXNET_KVSTORE_GPUARRAY_BOUND", 10000000); backtrack_ = dmlc::GetEnv("MXNET_KVSTORE_BACKTRACK", 0); link_usage_penalty_ = dmlc::GetEnv("MXNET_KVSTORE_LINK_USAGE_PENALTY", 0.7); - stream_ = dmlc::GetEnv("MXNET_KVSTORE_STREAM", 1); } virtual ~CommDeviceTree() { } @@ -83,7 +82,7 @@ class CommDeviceTree : public Comm { int merged_row, int priority) { std::vector> reduce(devs_.size()); - BufferEntry& random_buf = merge_buf_[0][key]; + TreeBufferEntry& random_buf = tree_merge_buf_[0][key]; const NDArrayStorageType stype = random_buf.merged[0].storage_type(); std::vector& topology = topology_[root]; NDArray buf_slice; @@ -97,7 +96,7 @@ class CommDeviceTree : public Comm { for (int j = start; j < end; ++j) { int topo_id = topology[j]; - BufferEntry& buf = merge_buf_[topo_id][key]; + TreeBufferEntry& buf = tree_merge_buf_[topo_id][key]; if ( devs_[topo_id] == src[i].ctx() ) { CopyFromTo(src[i], &(buf.merged[merged_row]), priority); @@ -115,8 +114,8 @@ class 
CommDeviceTree : public Comm { int topo_id = topology[j]; dest_id = (is_dest==0) ? topo_id : dest_id; - BufferEntry& buf_dest = merge_buf_[dest_id][key]; - BufferEntry& buf_from = merge_buf_[topo_id][key]; + TreeBufferEntry& buf_dest = tree_merge_buf_[dest_id][key]; + TreeBufferEntry& buf_from = tree_merge_buf_[topo_id][key]; if (!is_dest) { reduce[dest_id].push_back( buf_dest.merged[merged_row] ); @@ -140,7 +139,7 @@ class CommDeviceTree : public Comm { // conditional to detect whether operation must be done if ( reduce[gpu_id].size() > 1 ) { - BufferEntry& buf = merge_buf_[gpu_id][key]; + TreeBufferEntry& buf = tree_merge_buf_[gpu_id][key]; ElementwiseSum(reduce[gpu_id], &(buf.merged[merged_row]), priority); } } @@ -155,7 +154,7 @@ class CommDeviceTree : public Comm { } int topo_id = topology[0]; - BufferEntry& buf = merge_buf_[topo_id][key]; + TreeBufferEntry& buf = tree_merge_buf_[topo_id][key]; return buf.merged[merged_row]; } @@ -197,7 +196,7 @@ class CommDeviceTree : public Comm { // col: which gpu for (unsigned row = 0; row < devs_.size(); ++row) { for (unsigned col = 0; col < devs_.size(); ++col) { - BufferEntry& buf = merge_buf_[col][key]; + TreeBufferEntry& buf = tree_merge_buf_[col][key]; NDArray curr_slice = src[col].Slice(slice_scan[row], slice_scan[row+1]); slice[row].push_back(curr_slice); @@ -219,7 +218,7 @@ class CommDeviceTree : public Comm { int root = 0; ReduceInner(key, src, root, 0, priority); - BufferEntry& buf = merge_buf_[root][key]; + TreeBufferEntry& buf = tree_merge_buf_[root][key]; return buf.merged[0]; } @@ -228,60 +227,13 @@ class CommDeviceTree : public Comm { return src[gpu_id]; } else { // sparse reduce - LOG(WARNING) << "Only dense input supported for now using multiple trees"; + return ReduceRowSparse( key, src, priority ); } } - const NDArray& ReduceCompressed(int key, const std::vector& src, - int priority) { - LOG(WARNING) << "ReduceCompressed not supported using multiple trees"; - /*InitBuffersAndComm(src); - auto& buf = 
merge_buf_[key]; - std::vector reduce(src.size()); - if (buf.copy_buf.empty()) { - // one buf for each context - buf.copy_buf.resize(src.size()); - buf.compressed_recv_buf.resize(src.size()); - buf.compressed_send_buf.resize(src.size()); - buf.residual.resize(src.size()); - - for (size_t i = 0; i < src.size(); ++i) { - buf.copy_buf[i] = NDArray(buf.merged.shape(), buf.merged.ctx(), - false, buf.merged.dtype()); - buf.residual[i] = NDArray(buf.merged.shape(), src[i].ctx(), - false, buf.merged.dtype()); - buf.residual[i] = 0; - int64_t small_size = gc_->GetCompressedSize(buf.merged.shape().Size()); - buf.compressed_recv_buf[i] = NDArray(TShape{small_size}, buf.merged.ctx(), - false, buf.merged.dtype()); - buf.compressed_send_buf[i] = NDArray(TShape{small_size}, src[i].ctx(), - false, buf.merged.dtype()); - } - } - - for (size_t i = 0; i < src.size(); ++i) { - // compress before copy - // this is done even if the data is on same context as copy_buf because - // we don't want the training to be biased towards data on this GPU - gc_->Quantize(src[i], &(buf.compressed_send_buf[i]), &(buf.residual[i]), priority); - - if (buf.compressed_send_buf[i].ctx() != buf.compressed_recv_buf[i].ctx()) { - CopyFromTo(buf.compressed_send_buf[i], &(buf.compressed_recv_buf[i]), priority); - } else { - // avoid memory copy when they are on same context - buf.compressed_recv_buf[i] = buf.compressed_send_buf[i]; - } - - gc_->Dequantize(buf.compressed_recv_buf[i], &(buf.copy_buf[i]), priority); - reduce[i] = buf.copy_buf[i]; - } - ElementwiseSum(reduce, &buf.merged); - return buf.merged;*/ - } - - void BroadcastInner(int key, const NDArray& src, - const std::vector& dst, int root, int merged_row, - int priority) { + void BroadcastInner(int key, const NDArray& src, + const std::vector& dst, int root, + int merged_row, int priority) { // copy to root of tree std::vector& topology = topology_[root]; std::vector temp(devs_.size()); @@ -334,7 +286,7 @@ class CommDeviceTree : public Comm { 
slice_scan[devs_.size()] = dst[0]->shape()[0]; for (unsigned gpu_id = 0; gpu_id < dst.size(); ++gpu_id) { - BufferEntry& buf = merge_buf_[gpu_id][key]; + TreeBufferEntry& buf = tree_merge_buf_[gpu_id][key]; for (unsigned i = 0; i < devs_.size(); ++i) { if ( devs_[gpu_id] == dst[gpu_id]->ctx() ) { NDArray curr_slice = dst[gpu_id]->Slice(slice_scan[i], slice_scan[i+1]); @@ -349,64 +301,6 @@ class CommDeviceTree : public Comm { } } - void BroadcastRowSparse(int key, const NDArray& src, - const std::vector>& dst, - const int priority) override { - LOG(WARNING) << "BroadcastRowSparse not supported by multiple trees"; - /*CHECK_EQ(src.storage_type(), kRowSparseStorage) - << "BroadcastRowSparse expects row-sparse src NDArray"; - - for (size_t i = 0; i < dst.size(); ++i) { - NDArray* out = dst[i].first; - NDArray row_id = dst[i].second; - CHECK_EQ(out->storage_type(), kRowSparseStorage) - << "BroadcastRowSparse expects row_sparse dst NDArray"; - CHECK_EQ(row_id.ctx(), src.ctx()) - << "row_id and src are expected to be on the same context"; - - // retain according to indices - const bool is_same_ctx = out->ctx() == src.ctx(); - const bool is_diff_var = out->var() != src.var(); - NDArray retained_gpu = (is_same_ctx && is_diff_var) ? *out : - NDArray(kRowSparseStorage, out->shape(), src.ctx(), true, - out->dtype(), out->aux_types()); - if (!is_diff_var) { - common::LogOnce("The output of row_sparse_pull() on key " + std::to_string(key) + - "refers to the same NDArray as the one stored in KVStore." - "Performing row_sparse_pull() with such output is going to change the " - "data stored in KVStore. Incorrect result may be generated " - "next time row_sparse_pull() is called. 
To avoid such an issue," - "consider create a new NDArray buffer to store the output."); - } - - Engine::Get()->PushAsync([=](RunContext rctx, Engine::CallbackOnComplete on_complete) { - const TBlob& indices = row_id.data(); - using namespace mxnet::common; - NDArray temp = retained_gpu; - switch (temp.ctx().dev_mask()) { - case cpu::kDevMask: { - SparseRetainOpForwardRspWrapper(rctx.get_stream(), - src, indices, kWriteTo, &temp); - break; - } -#if MXNET_USE_CUDA - case gpu::kDevMask: { - SparseRetainOpForwardRspWrapper(rctx.get_stream(), - src, indices, kWriteTo, &temp); - // wait for GPU operations to complete - rctx.get_stream()->Wait(); - break; - } -#endif - default: LOG(FATAL) << MXNET_GPU_NOT_ENABLED_ERROR; - } - on_complete(); - }, retained_gpu.ctx(), {src.var(), row_id.var()}, {retained_gpu.var()}, - FnProperty::kNormal, priority, "KVStoreSparseRetain"); - CopyFromTo(retained_gpu, out, priority); - }*/ - } - private: void EnableP2P() { #if MXNET_USE_CUDA @@ -476,7 +370,7 @@ class CommDeviceTree : public Comm { // 2) Force copy_buf to be of kRecvBufferSize // 3) Do not use greedy assignment; all keys are assigned to each GPU for (unsigned i = 0; i < devs_.size(); ++i) - merge_buf_.push_back( std::unordered_map() ); + tree_merge_buf_.push_back( std::unordered_map() ); std::map key_dist; @@ -499,12 +393,12 @@ class CommDeviceTree : public Comm { // first time => allocate memory // second time => do nothing // 2) must use either mapping from dev_id to 0, 1, ..., n_gpus or must - // allocate merge_buf_ to be next biggest power of 2 sized or use + // allocate tree_merge_buf_ to be next biggest power of 2 sized or use // 0, 1, ..., n_gpus (same mapping as dev_id) - // e.g. 5, 6, 7, 8 must all have merge_buf_.size() == 8 + // e.g. 
5, 6, 7, 8 must all have tree_merge_buf_.size() == 8 for (int j = start; j < end; ++j) { int topo_id = topology_[0][j]; - auto& buf = merge_buf_[topo_id][key]; + auto& buf = tree_merge_buf_[topo_id][key]; Context ctx = devs_[topo_id]; // buf.merged enforces that we only visit each GPU once @@ -558,7 +452,7 @@ class CommDeviceTree : public Comm { std::vector sorted_key_attrs_; /// \brief temporal space for pushing and pulling - struct BufferEntry { + struct TreeBufferEntry { /// \brief the dense merged value for reduce and broadcast operations std::vector merged; /// \brief the gpu buffer for copy during reduce operation @@ -590,9 +484,9 @@ class CommDeviceTree : public Comm { /// \brief the sparse merged value for reduce and rowsparse broadcast operations NDArray sparse_merged; }; - /// \brief intent of merge_buf_ in old comm.h: store key->gpu mapping + /// \brief intent of tree_merge_buf_ in old comm.h: store key->gpu mapping /// new intent: for every gpu: store key->memory mapping - std::vector> merge_buf_; + std::vector> tree_merge_buf_; /// \brief NVLink-connected topology in full binary tree format std::vector> topology_; @@ -603,8 +497,6 @@ class CommDeviceTree : public Comm { int max_dev_; int depth_; int gpuarray_bound_; - bool inited_; - bool stream_; bool backtrack_; float link_usage_penalty_; From ba60aaa927c6b83f22f501d7c349814d75d1d7be Mon Sep 17 00:00:00 2001 From: Carl Yang Date: Fri, 22 Jun 2018 21:50:03 +0000 Subject: [PATCH 17/36] Fix lint errors --- src/kvstore/comm.h | 6 +- src/kvstore/comm_tree.h | 57 ++- src/kvstore/gpu_topology.h | 593 +++++++++++++------------ tests/cpp/kvstore/gpu_topology_test.cc | 257 ++++++----- 4 files changed, 454 insertions(+), 459 deletions(-) diff --git a/src/kvstore/comm.h b/src/kvstore/comm.h index 39d040dbbf2c..09928c9da81f 100644 --- a/src/kvstore/comm.h +++ b/src/kvstore/comm.h @@ -543,7 +543,7 @@ class CommDevice : public Comm { ElementwiseSum(reduce, &buf_merged, priority); } else { // sparse reduce - 
buf_merged = ReduceRowSparse( key, src, priority ); + buf_merged = ReduceRowSparse(key, src, priority); } return buf_merged; } @@ -780,13 +780,13 @@ class CommDevice : public Comm { return sparse_merged; } - private: + private: /// \brief the sparse merged value for reduce and rowsparse broadcast operations NDArray sparse_merged; }; std::unordered_map merge_buf_; - public: + public: bool inited_; }; diff --git a/src/kvstore/comm_tree.h b/src/kvstore/comm_tree.h index 2a175e04b76d..3e0de61fe8e4 100644 --- a/src/kvstore/comm_tree.h +++ b/src/kvstore/comm_tree.h @@ -30,6 +30,7 @@ #include #include #include +#include #include "mxnet/ndarray.h" #include "gradient_compression.h" #include "../ndarray/ndarray_function.h" @@ -88,7 +89,6 @@ class CommDeviceTree : public CommDevice { NDArray buf_slice; if (stype == kDefaultStorage) { - // Copy everything into buf.merged for each gpu for (size_t i = 0; i < src.size(); ++i) { int start = scan_[root][depth_ ]; @@ -98,7 +98,7 @@ class CommDeviceTree : public CommDevice { int topo_id = topology[j]; TreeBufferEntry& buf = tree_merge_buf_[topo_id][key]; - if ( devs_[topo_id] == src[i].ctx() ) { + if (devs_[topo_id] == src[i].ctx()) { CopyFromTo(src[i], &(buf.merged[merged_row]), priority); } } @@ -112,19 +112,20 @@ class CommDeviceTree : public CommDevice { int dest_id = 0; for (int j = start; j < end; ++j) { int topo_id = topology[j]; - dest_id = (is_dest==0) ? topo_id : dest_id; + dest_id = (is_dest == 0) ? 
topo_id : dest_id; TreeBufferEntry& buf_dest = tree_merge_buf_[dest_id][key]; TreeBufferEntry& buf_from = tree_merge_buf_[topo_id][key]; if (!is_dest) { - reduce[dest_id].push_back( buf_dest.merged[merged_row] ); + reduce[dest_id].push_back(buf_dest.merged[merged_row]); } else { if (dest_id != topo_id) { CopyFromTo(buf_from.merged[merged_row], &(buf_dest.copy_buf[merged_row][is_dest-1]), priority); - reduce[dest_id].push_back( buf_dest.copy_buf[merged_row][is_dest-1] ); + reduce[dest_id].push_back( + buf_dest.copy_buf[merged_row][is_dest-1]); } } @@ -138,7 +139,7 @@ class CommDeviceTree : public CommDevice { int gpu_id = topology[i]; // conditional to detect whether operation must be done - if ( reduce[gpu_id].size() > 1 ) { + if (reduce[gpu_id].size() > 1) { TreeBufferEntry& buf = tree_merge_buf_[gpu_id][key]; ElementwiseSum(reduce[gpu_id], &(buf.merged[merged_row]), priority); } @@ -227,12 +228,12 @@ class CommDeviceTree : public CommDevice { return src[gpu_id]; } else { // sparse reduce - return ReduceRowSparse( key, src, priority ); + return ReduceRowSparse(key, src, priority); } } - void BroadcastInner(int key, const NDArray& src, - const std::vector& dst, int root, + void BroadcastInner(int key, const NDArray& src, + const std::vector& dst, int root, int merged_row, int priority) { // copy to root of tree std::vector& topology = topology_[root]; @@ -250,7 +251,7 @@ class CommDeviceTree : public CommDevice { int src_id = 0; for (int j = start; j < end; ++j) { int topo_id = topology[j]; - src_id = (is_src==0) ? topo_id : src_id; + src_id = (is_src == 0) ? 
topo_id : src_id; if (is_src && src_id != topo_id) { CopyFromTo(temp[src_id], dst[topo_id], priority); @@ -288,7 +289,7 @@ class CommDeviceTree : public CommDevice { for (unsigned gpu_id = 0; gpu_id < dst.size(); ++gpu_id) { TreeBufferEntry& buf = tree_merge_buf_[gpu_id][key]; for (unsigned i = 0; i < devs_.size(); ++i) { - if ( devs_[gpu_id] == dst[gpu_id]->ctx() ) { + if (devs_[gpu_id] == dst[gpu_id]->ctx()) { NDArray curr_slice = dst[gpu_id]->Slice(slice_scan[i], slice_scan[i+1]); CopyFromTo(buf.merged[i], &curr_slice, priority); } @@ -347,13 +348,13 @@ class CommDeviceTree : public CommDevice { void QueryTopology() { #if MXNET_USE_CUDA std::vector link_matrix(devs_.size()*devs_.size()); - GetP2PWeight( devs_, link_matrix ); + GetP2PWeight(devs_, &link_matrix); if (backtrack_) LOG(WARNING) << "Using Backtracking to generate trees"; else LOG(WARNING) << "Using Kernighan-Lin to generate trees"; - ComputeTrees( link_matrix, devs_.size(), link_usage_penalty_, backtrack_, - topology_, scan_ ); + ComputeTrees(link_matrix, devs_.size(), link_usage_penalty_, backtrack_, + &topology_, &scan_); depth_ = ComputeDepth(devs_.size()); #endif @@ -362,7 +363,6 @@ class CommDeviceTree : public CommDevice { using KeyAttrs = std::tuple; // try to allocate buff on device evenly void InitMergeBuffer() { - LOG(WARNING) << "Using Tree"; // same as all-reduce, except: @@ -370,9 +370,9 @@ class CommDeviceTree : public CommDevice { // 2) Force copy_buf to be of kRecvBufferSize // 3) Do not use greedy assignment; all keys are assigned to each GPU for (unsigned i = 0; i < devs_.size(); ++i) - tree_merge_buf_.push_back( std::unordered_map() ); + tree_merge_buf_.push_back(std::unordered_map()); - std::map key_dist; + std::map key_dist; for (size_t i = 0; i < sorted_key_attrs_.size(); ++i) { const int key = std::get<0>(sorted_key_attrs_[i]); @@ -387,15 +387,14 @@ class CommDeviceTree : public CommDevice { int start = scan_[0][depth_ ]; int end = scan_[0][depth_+1]; - // In order to generalize 
to any number of GPUs, there are many - // strategies: - // 1) detect whether we are encountering gpu for first time - // first time => allocate memory - // second time => do nothing - // 2) must use either mapping from dev_id to 0, 1, ..., n_gpus or must - // allocate tree_merge_buf_ to be next biggest power of 2 sized or use - // 0, 1, ..., n_gpus (same mapping as dev_id) - // e.g. 5, 6, 7, 8 must all have tree_merge_buf_.size() == 8 + // In order to generalize to any number of GPUs, we use strategy of having + // found the mapping from 0, 1, ..., n_gpus to dev_id i.e. + // idx: 0 1 2 3 4 5 6 + // dev_id: 4 2 3 1 7 5 0 + // and generated an n_gpus x n_gpus link topology matrix: + // + // 1) The reduction trees are saved as indices on 0, 1, ..., n_gpus + // 2) We use the mapping to retrieve dev_id and device context for (int j = start; j < end; ++j) { int topo_id = topology_[0][j]; auto& buf = tree_merge_buf_[topo_id][key]; @@ -438,8 +437,6 @@ class CommDeviceTree : public CommDevice { } } } - } else { - //LOG(WARNING) << topo_id << " has been allocated already"; } } } @@ -474,7 +471,7 @@ class CommDeviceTree : public CommDevice { // check if sparse_merged is initialized if (sparse_merged.is_none()) { CHECK(merged.size() > 0 && !merged[0].is_none()); - sparse_merged = NDArray(kRowSparseStorage, merged[0].shape(), + sparse_merged = NDArray(kRowSparseStorage, merged[0].shape(), merged[0].ctx(), true, merged[0].dtype()); } return sparse_merged; @@ -487,7 +484,7 @@ class CommDeviceTree : public CommDevice { /// \brief intent of tree_merge_buf_ in old comm.h: store key->gpu mapping /// new intent: for every gpu: store key->memory mapping std::vector> tree_merge_buf_; - + /// \brief NVLink-connected topology in full binary tree format std::vector> topology_; std::vector> scan_; diff --git a/src/kvstore/gpu_topology.h b/src/kvstore/gpu_topology.h index fe061ff76e44..9808f39c848b 100644 --- a/src/kvstore/gpu_topology.h +++ b/src/kvstore/gpu_topology.h @@ -28,8 +28,10 @@ 
#include #include #include +#include #include #include +#include #include #include @@ -39,7 +41,7 @@ namespace mxnet { namespace kvstore { template -void PrintVector( const std::string& str, const std::vector& vec ) { +void PrintVector(const std::string& str, const std::vector& vec) { std::cout << str << ":\n"; for (unsigned i = 0; i < vec.size(); ++i) std::cout << vec[i] << " "; @@ -47,8 +49,8 @@ void PrintVector( const std::string& str, const std::vector& vec ) { } template -void PrintMatrix( const std::string& str, const std::vector& matrix, - int num_rows, int num_cols ) { +void PrintMatrix(const std::string& str, const std::vector& matrix, + int num_rows, int num_cols) { std::cout << str << ":\n"; int count = 0; for (int row = 0; row < num_rows; ++row) { @@ -59,8 +61,8 @@ void PrintMatrix( const std::string& str, const std::vector& matrix, } } -void PrintTopo( const std::string& str, const std::vector& topo_row, - std::vector scan_row ) { +void PrintTopo(const std::string& str, const std::vector& topo_row, + std::vector scan_row) { PrintVector("Topo vector", topo_row); PrintVector("Scan vector", scan_row); std::cout << str << ":\n"; @@ -68,11 +70,10 @@ void PrintTopo( const std::string& str, const std::vector& topo_row, for (int row = 0; row < depth; ++row) { int start = scan_row[row]; int end = scan_row[row+1]; - for (; start& topo_row, // 2: 1 NVLink connection // 3: 2 NVLink connections template -void GetP2PWeight( const std::vector& devs, - std::vector& matrix ) { +void GetP2PWeight(const std::vector& devs, + std::vector* matrix) { int num_gpus = devs.size(); int count = 0; std::vector zero_dev_id(num_gpus, -1); @@ -104,16 +105,16 @@ void GetP2PWeight( const std::vector& devs, for (int row = 0; row < num_gpus; ++row) { for (int col = 0; col < num_gpus; ++col) { - if (row==col) { - matrix[row*num_gpus+col] = 0; + if (row == col) { + (*matrix)[row*num_gpus+col] = 0; } else { int value; int row_gpu = zero_dev_id[row]; int col_gpu = zero_dev_id[col]; - 
cudaDeviceGetP2PAttribute( &value, attr, row_gpu, col_gpu ); + cudaDeviceGetP2PAttribute(&value, attr, row_gpu, col_gpu); if (value > max[row]) max[row] = value; - matrix[row*num_gpus+col] = static_cast(value)+1; + (*matrix)[row*num_gpus+col] = static_cast(value)+1; } } } @@ -128,52 +129,52 @@ void GetP2PWeight( const std::vector& devs, // If all GPUs have at least 1 NVLink connection, then we can use NVLink only // to communicate instead of going over PCI-E if (max_value > 0) { - for (auto& matrix_value : matrix) { - matrix_value = (matrix_value==1) ? 0 : matrix_value; + for (auto& matrix_value : *matrix) { + matrix_value = (matrix_value == 1) ? 0 : matrix_value; } } - PrintMatrix( "Weight W", matrix, num_gpus, num_gpus ); + PrintMatrix("Weight W", *matrix, num_gpus, num_gpus); } // Dense matrix-vector multiplication // Assume: matrix is square // y = A*x (no accumulate) template -void gemv( const std::vector& A, - const std::vector& x, - std::vector& y ) { +void gemv(const std::vector& A, + const std::vector& x, + std::vector* y) { int nrows = x.size(); int count = 0; - for (int row=0; row(x[col]); + for (int row=0; row < nrows; ++row) { + (*y)[row] = 0; + for (int col=0; col < nrows; ++col) { + (*y)[row] += A[count]*static_cast(x[col]); count++; } - } + } } // Element-wise multiplication between 2 dense vectors // w = w * alpha*u template -void ewisemult( const std::vector& u, - T alpha, - std::vector& w ) { +void ewisemult(const std::vector& u, + T alpha, + std::vector* w) { int nelem = u.size(); - for (int i=0; i(u[i]); + for (int i=0; i < nelem; ++i) { + (*w)[i] *= alpha*static_cast(u[i]); } } // Element-wise addition between 2 dense vectors // w = w + alpha*u template -void ewiseadd( const std::vector& u, - T alpha, - std::vector& w ) { +void ewiseadd(const std::vector& u, + T alpha, + std::vector* w) { int nelem = u.size(); - for (int i=0; i(u[i]); + for (int i=0; i < nelem; ++i) { + (*w)[i] += alpha*static_cast(u[i]); } } @@ -183,27 +184,27 @@ void 
ewiseadd( const std::vector& u, // Optimization: Only need to look at upper triangular since weight matrix is // symmetric template -void FindBestMove( const std::vector& W, - const std::vector& P_temp, - const std::vector& D, - const std::unordered_set& used, - int& a, - int& b, - T& g ) { +void FindBestMove(const std::vector& W, + const std::vector& P_temp, + const std::vector& D, + const std::unordered_set& used, + int* a, + int* b, + T* g) { int nrows = P_temp.size(); - g = 0; - a = -1; - b = -1; - for (int row=0; rowg ) { - g = cost; - a = row; - b = col; + if (cost > *g) { + *g = cost; + *a = row; + *b = col; } } } @@ -216,31 +217,30 @@ void FindBestMove( const std::vector& W, // cluster_pairs stores the mapping that tells us which 2 clusters are // the output of partitioning one large cluster template -bool KernighanLin( const std::vector& W, - std::vector& P, - int& num_partitions, - std::vector>& cluster_pairs, - std::mt19937& gen ) { - - std::vector histogram(num_partitions, 0); - std::vector P_temp(P.size(), 0); - std::vector P_temp2(P.size(), 0); - std::vector D(P.size(), 0); - std::vector D_temp(P.size(), 0); +bool KernighanLin(const std::vector& W, + std::vector* P, + int* num_partitions, + std::vector>* cluster_pairs, + std::mt19937* gen) { + std::vector histogram(*num_partitions, 0); + std::vector P_temp(P->size(), 0); + std::vector P_temp2(P->size(), 0); + std::vector D(P->size(), 0); + std::vector D_temp(P->size(), 0); // 0) For every partition, determine if it can be partitioned further. 
// To do this, we must do a histogram of each partition: - for (unsigned i=0; isize(); ++i) { + histogram[(*P)[i]]++; } bool stop = true; - for (unsigned color=0; color( - static_cast(color),-partition_size)); + cluster_pairs->push_back( + std::pair(static_cast(color), -partition_size)); // Do Kernighan-Lin if clustering is necessary } else { @@ -254,18 +254,19 @@ bool KernighanLin( const std::vector& W, int target_partition = partition_size/2; std::vector cluster_list; - for (unsigned i = 0; i < P.size(); ++i) { + for (unsigned i = 0; i < P->size(); ++i) { // Required to shift from [0,1] to {-1,1} // 1 means vertex i is in Cluster A // -1 means vertex i is in Cluster B - if (P[i] == static_cast(color)) { + if ((*P)[i] == static_cast(color)) { cluster_list.push_back(i); - } else + } else { P_temp[i] = 0; + } } // 1b) Shuffle using random generator - std::shuffle(cluster_list.begin(), cluster_list.end(), gen); + std::shuffle(cluster_list.begin(), cluster_list.end(), *gen); for (unsigned i = 0; i < cluster_list.size(); ++i) { if (first_partition < target_partition) { int dest = cluster_list[i]; @@ -285,10 +286,10 @@ bool KernighanLin( const std::vector& W, count++; P_temp2 = P_temp; - // a) Compute difference between external and internal costs of all + // a) Compute difference between external and internal costs of all // elements in vector D - gemv( W, P_temp, D ); - ewisemult( P_temp, -1.f, D ); + gemv(W, P_temp, &D); + ewisemult(P_temp, -1.f, &D); // av and bv are used to hold candidates for moving // gv stores the score associated with move @@ -298,11 +299,11 @@ bool KernighanLin( const std::vector& W, std::unordered_set used; - for (int iter=0; iter 0) { } else { g_max = 0; @@ -321,8 +322,8 @@ bool KernighanLin( const std::vector& W, used.insert(b); // e) Update D using P_temp - gemv( W, P_temp, D ); - ewisemult( P_temp, -1.f, D ); + gemv(W, P_temp, &D); + ewisemult(P_temp, -1.f, &D); D[a] = 0; D[b] = 0; } @@ -353,20 +354,20 @@ bool KernighanLin( const 
std::vector& W, } else { P_temp = P_temp2; } - } while (g_max > 0 && count <= P.size()); + } while (g_max > 0 && count <= P->size()); // 5) Update P using P_temp int moves = 0; - for (unsigned i=0; isize(); ++i) { + if (P_temp[i] == -1) { + (*P)[i] = *num_partitions; moves++; } } - cluster_pairs.push_back(std::make_pair(static_cast(color), - static_cast(num_partitions))); + cluster_pairs->push_back(std::pair(static_cast(color), + static_cast(*num_partitions))); - num_partitions++; + (*num_partitions)++; } } @@ -375,11 +376,11 @@ bool KernighanLin( const std::vector& W, // Returns root of a given color if found in roots // Returns -1 if it is not found -int GetRoot( const std::vector& P, - int color, - const std::unordered_set& roots ) { +int GetRoot(const std::vector& P, + int color, + const std::unordered_set& roots) { for (auto root : roots) { - if (P[root]==color) + if (P[root] == color) return root; } return -1; @@ -387,9 +388,9 @@ int GetRoot( const std::vector& P, // Returns root of a given color if found in roots // Returns -1 if it is not found -int GetChild( const std::vector& P, - int color, - int parent ) { +int GetChild(const std::vector& P, + int color, + int parent) { for (unsigned i = 0; i < P.size(); ++i) { if (P[i] == color && static_cast(i) != parent) return i; @@ -407,90 +408,90 @@ int GetChild( const std::vector& P, // g is weight of edge // Optimization: Only need to look at row a in matrix template -void FindBestEdge( const std::vector& W, - const std::vector& P, - int parent, - int dest_cluster, - std::vector& b, - T& g ) { +void FindBestEdge(const std::vector& W, + const std::vector& P, + int parent, + int dest_cluster, + std::vector* b, + T* g) { int nrows = P.size(); int row = parent; - g = 0; - b.push_back(-1); - for (int col=0; colpush_back(-1); + for (int col=0; col < nrows; ++col) { + if (col == row || P[col] != dest_cluster) continue; T cost = W[row*nrows+col]; - if( cost > g ) { - b.clear(); + if (cost > *g) { + b->clear(); } - if( 
cost >= g ) { - b.push_back(col); - g = cost; + if (cost >= *g) { + b->push_back(col); + *g = cost; } } } // Given a vector of color pairs, appends to binary tree matrix topo -// @input: cluster_pairs gives pairing between clusters, an edge is found +// @input: W gives the link topology +// P gives the result of KL partitioning +// cluster_pairs gives pairing between clusters, an edge is found // between each pairing -// roots gives source vertex +// roots gives source vertices // gen gives random number generation to break ties // @output: cluster_pairs // topo_row says where new edges are appended to // scan_row says where we should start looking for topo_row template -int KLGenerateBinaryTree( std::vector& W, - const std::vector& P, - std::vector>& cluster_pairs, - std::unordered_set& roots, - std::vector& topo_row, - std::vector& scan_row, - std::mt19937& gen ) { - std::unordered_set new_roots; - std::unordered_map new_topo; +int KLGenerateBinaryTree(const std::vector& W, + const std::vector& P, + std::vector>* cluster_pairs, + std::unordered_set* roots, + std::vector* topo_row, + std::vector* scan_row, + std::mt19937* gen) { + std::unordered_set new_roots; + std::unordered_map new_topo; int reset = 0; - for (unsigned i = 0; i < cluster_pairs.size(); ++i) { - if (i==0) - scan_row.push_back(topo_row.size()); + for (unsigned i = 0; i < cluster_pairs->size(); ++i) { + if (i == 0) + scan_row->push_back(topo_row->size()); int parent, child = -1; - if (cluster_pairs[i].second==-2) { + if ((*cluster_pairs)[i].second == -2) { // Root must be color of pair.first - int color = cluster_pairs[i].first; - parent = GetRoot( P, color, roots ); + int color = (*cluster_pairs)[i].first; + parent = GetRoot(P, color, *roots); if (parent == -1) return 1; child = GetChild(P, color, parent); - } else if (cluster_pairs[i].second==-1) { - int color = cluster_pairs[i].first; - parent = GetRoot( P, color, roots ); + } else if ((*cluster_pairs)[i].second == -1) { + int color = 
(*cluster_pairs)[i].first; + parent = GetRoot(P, color, *roots); if (parent == -1) return 1; child = parent; } else { // Root must exist in either first or second element of pair - int color = cluster_pairs[i].first; - parent = GetRoot(P, color, roots); - color = (parent==-1) ? cluster_pairs[i].second : color; - parent = (parent==-1) ? GetRoot(P, color, roots) : parent; + int color = (*cluster_pairs)[i].first; + parent = GetRoot(P, color, *roots); + color = (parent == -1) ? (*cluster_pairs)[i].second : color; + parent = (parent == -1) ? GetRoot(P, color, *roots) : parent; int from_cluster = color; - int dest_cluster = (from_cluster==cluster_pairs[i].first) ? - cluster_pairs[i].second : cluster_pairs[i].first; + int dest_cluster = (from_cluster == (*cluster_pairs)[i].first) ? + (*cluster_pairs)[i].second : (*cluster_pairs)[i].first; std::vector candidates; T weight; - FindBestEdge( W, P, parent, dest_cluster, candidates, weight ); + FindBestEdge(W, P, parent, dest_cluster, &candidates, &weight); // If no candidates - if (candidates[0]!=-1) { - std::shuffle(candidates.begin(), candidates.end(), gen); + if (candidates[0] != -1) { + std::shuffle(candidates.begin(), candidates.end(), *gen); child = candidates[0]; } if (child == -1) { new_roots.insert(parent); - - //child = parent; return 1; } else { new_roots.insert(parent); @@ -501,34 +502,34 @@ int KLGenerateBinaryTree( std::vector& W, new_topo[parent] = child; } - int depth = scan_row.size(); - int start = scan_row[depth-2]; - int end = scan_row[depth-1]; + int depth = scan_row->size(); + int start = (*scan_row)[depth-2]; + int end = (*scan_row)[depth-1]; for (int i = start; i < end; ++i) { - int parent = topo_row[i]; + int parent = (*topo_row)[i]; int child; - // If not first, check previous level whether or not we are encountering + // If not first, check previous level whether or not we are encountering // this root for the first time in this level of the tree - if (i != start && parent == 
static_cast(topo_row[i-1])) + if (i != start && parent == static_cast((*topo_row)[i-1])) child = parent; else child = new_topo[parent]; - topo_row.push_back(parent); - topo_row.push_back(child); + topo_row->push_back(parent); + topo_row->push_back(child); } - cluster_pairs.clear(); - roots.clear(); - roots = std::move(new_roots); + cluster_pairs->clear(); + roots->clear(); + *roots = std::move(new_roots); return reset; } // @input: n is the number of nodes in a balanced binary tree // @output: returns how many levels of binary tree there are -int ComputeDepth( int n ) { +int ComputeDepth(int n) { for (int depth = 0; depth < MAX_DEPTH; ++depth) { int num = 2 << depth; if (n <= num) @@ -543,12 +544,11 @@ int ComputeDepth( int n ) { // -each edge in tree corresponds to link in network topology // -each edge in tree does not form self-loop template -bool IsValid( const std::vector& W, - const std::vector& state, - int num_elements, - int row, - int depth ) { - +bool IsValid(const std::vector& W, + const std::vector& state, + int num_elements, + int row, + int depth) { // At each level of tree, check whether edge: // -corresponds to link in network topology // -corresponds to self-loop @@ -566,7 +566,7 @@ bool IsValid( const std::vector& W, // If we encounter GPU for first time, increment found_vec. // Otherwise, do nothing std::unordered_set found; - std::vector found_vec(num_elements,0); + std::vector found_vec(num_elements, 0); for (auto val : state) { if (val == -1) continue; @@ -583,8 +583,8 @@ bool IsValid( const std::vector& W, // modifier is maximum number of repeats a single GPU can take // e.g. 
5 GPUs in 3-level binary tree => one GPU can repeat 3x // GPU0 GPU0 GPU0 GPU0 GPU1 GPU2 GPU3 GPU4 - int modifier = (1 << depth) - num_elements; - int num_found= found.size(); + int modifier = (1 << depth) - num_elements; + int num_found = found.size(); // So we know we have an invalid state if we find: // -only 4 unique GPUs @@ -607,7 +607,8 @@ bool IsValid( const std::vector& W, return true; } -// This function takes a spanning tree encoded as state (result), which may have// repeated GPUs representing NO-SENDs and converts it into a unique format. +// This function takes a spanning tree encoded as state (result), which may have +// repeated GPUs representing NO-SENDs and converts it into a unique format. // This has the effect of recognizing redundant sends, grouping them together, // so that the Reduce call knows not to perform a CopyFromTo. // @@ -625,25 +626,25 @@ bool IsValid( const std::vector& W, // 3 1 // 3 0 1 5 // 3 3 0 4 1 2 5 6 // GPU3 knows not to make redundant send to itself -void Postprocess( std::vector& result, int num_elements, int depth) { +void Postprocess(std::vector* result, int num_elements, int depth) { for (int level = depth - 1; level >= 0; --level) { int stride = 1 << level; - std::vector histogram_above(num_elements,0); - for (unsigned i = 0; i < result.size(); i += 2*stride) { - int val = result[i]; + std::vector histogram_above(num_elements, 0); + for (unsigned i = 0; i < result->size(); i += 2*stride) { + int val = (*result)[i]; histogram_above[val]++; } std::vector histogram(num_elements, 0); - for (unsigned i = 0; i < result.size(); i += stride) { - int val = result[i]; + for (unsigned i = 0; i < result->size(); i += stride) { + int val = (*result)[i]; histogram[val]++; } - for (int i = result.size()-stride; i-stride >= 0; i -= 2*stride) { - int from = result[i]; - int dest = result[i-stride]; + for (int i = result->size()-stride; i-stride >= 0; i -= 2*stride) { + int from = (*result)[i]; + int dest = (*result)[i-stride]; if 
((histogram[from] > 1 || histogram_above[from] >= 1) && from != dest) { - result[i] = dest; + (*result)[i] = dest; histogram[from]--; } } @@ -656,11 +657,11 @@ void Postprocess( std::vector& result, int num_elements, int depth) { // -usually turned on when backtracking to get better solutions // -usually turned off when outside the penalty to get weight of tree template -T ComputeTreeWeight( const std::vector& W, - const std::vector& result, - int num_elements, - int depth, - bool penalty ) { +T ComputeTreeWeight(const std::vector& W, + const std::vector& result, + int num_elements, + int depth, + bool penalty) { T weight = 0.f; std::unordered_set links_used; @@ -674,10 +675,10 @@ T ComputeTreeWeight( const std::vector& W, weight += W[from*num_elements+dest]; // Penalize: (1) use of redundant edges in a single tree - // (2) repeated use of a GPU in a single tree at the same + // (2) repeated use of a GPU in a single tree at the same // level above the leaf level - if (links_used.find(from*num_elements+dest) != links_used.end() - && penalty) { + if (links_used.find(from*num_elements+dest) != links_used.end() + && penalty) { weight -= 100; } links_used.insert(from*num_elements+dest); @@ -709,48 +710,48 @@ T ComputeTreeWeight( const std::vector& W, // 3 1 // 3 0 1 5 // 3 3 0 4 1 2 5 6 -void FormTopology( const std::vector& result, - std::vector& topo_row, - std::vector& scan_row, - int depth ) { - scan_row.push_back(topo_row.size()); +void FormTopology(const std::vector& result, + std::vector* topo_row, + std::vector* scan_row, + int depth) { + scan_row->push_back(topo_row->size()); for (int i = depth; i > 0; --i) { int stride = 1 << i; for (unsigned j = 0; j < result.size(); j += stride) { int from = result[j]; - topo_row.push_back(from); + topo_row->push_back(from); } - scan_row.push_back(topo_row.size()); + scan_row->push_back(topo_row->size()); } // Insert at the end, result vector - topo_row.insert(topo_row.end(), result.begin(), result.end()); - 
scan_row.push_back(topo_row.size()); + topo_row->insert(topo_row->end(), result.begin(), result.end()); + scan_row->push_back(topo_row->size()); } // Recursive function that finds a spanning tree, which fulfills the following // conditions: // -balanced // -binary -// -maximum weight +// -maximum weight template -bool RecursiveBacktrack( const std::vector& W, - std::vector& state, - std::vector& best_result, - T& best_result_weight, - int row, - int num_elements, - int depth, - bool optimal ) { - if (row == static_cast(state.size())) { - std::vector result = state; - Postprocess(result, num_elements, depth); +bool RecursiveBacktrack(const std::vector& W, + std::vector* state, + std::vector* best_result, + T* best_result_weight, + int row, + int num_elements, + int depth, + bool optimal) { + if (row == static_cast(state->size())) { + std::vector result = *state; + Postprocess(&result, num_elements, depth); T weight = ComputeTreeWeight(W, result, num_elements, depth, true); // Save this spanning tree if it is highest weight tree found sofar - if (weight > best_result_weight) { - std::swap(best_result_weight, weight); - best_result = result; + if (weight > *best_result_weight) { + std::swap(*best_result_weight, weight); + *best_result = result; } return !optimal; } @@ -758,13 +759,11 @@ bool RecursiveBacktrack( const std::vector& W, // If not last recursive level, try to find valid tree for next level bool stop = false; for (int j = 0; j < num_elements; ++j) { - state[row] = j; - if (IsValid(W, state, num_elements, row+1, depth)) { - stop = Backtrack( W, state, best_result, best_result_weight, row+1, - num_elements, depth, optimal ); - state[row] = -1; - } else - state[row] = -1; + (*state)[row] = j; + if (IsValid(W, state, num_elements, row+1, depth)) + stop = RecursiveBacktrack(W, state, best_result, best_result_weight, + row+1, num_elements, depth, optimal); + (*state)[row] = -1; if (stop) return stop; } @@ -772,14 +771,14 @@ bool RecursiveBacktrack( const 
std::vector& W, } template -void IterativeBacktrack( const std::vector& W, - std::vector& state, - std::vector& best_result, - T& best_result_weight, - int row, - int num_elements, - int depth, - bool optimal ) { +void IterativeBacktrack(const std::vector& W, + std::vector* state, + std::vector* best_result, + T* best_result_weight, + int row, + int num_elements, + int depth, + bool optimal) { std::stack state_stack; row = 1; int pos = 0; @@ -788,48 +787,48 @@ void IterativeBacktrack( const std::vector& W, while (true) { // If there is no valid position, 2 cases: // a) if stack is empty, break and stop search - // b) if stack is not empty, pop stack and set current position to next + // b) if stack is not empty, pop stack and set current position to next // position backtrack to previous row while (!state_stack.empty() && pos >= num_elements) { pos = state_stack.top(); pos++; state_stack.pop(); - state[state_stack.size()+1] = -1; + (*state)[state_stack.size()+1] = -1; row--; } if (state_stack.empty()) break; - state[row] = pos; - // If there is a valid position push the position to stack, set current + (*state)[row] = pos; + // If there is a valid position push the position to stack, set current // position to 0 and move to next row - if (IsValid(W, state, num_elements, row+1, depth)) { + if (IsValid(W, *state, num_elements, row+1, depth)) { state_stack.push(pos); pos = 0; row++; } else { pos++; - state[row] = -1; + (*state)[row] = -1; } // If stack has size N, a solution is found // Pop stack, set current position to next position // Backtrack to find next solution - if (row == static_cast(state.size())) { - std::vector result = state; - Postprocess(result, num_elements, depth); + if (row == static_cast(state->size())) { + std::vector result = *state; + Postprocess(&result, num_elements, depth); T weight = ComputeTreeWeight(W, result, num_elements, depth, true); - // Save this spanning tree if it is highest weight tree found sofar - if (weight > 
best_result_weight) { - std::swap(best_result_weight, weight); - best_result = result; + // Save this spanning tree if it is highest weight tree found so far + if (weight > *best_result_weight) { + std::swap(*best_result_weight, weight); + *best_result = result; } if (!optimal) break; - + pos = state_stack.top(); pos++; state_stack.pop(); - state[state_stack.size()+1] = -1; + (*state)[state_stack.size()+1] = -1; row--; } } @@ -838,17 +837,17 @@ void IterativeBacktrack( const std::vector& W, // Apply penalty factor alpha to each link in link topology graph that is used // by the spanning tree template -void UpdateWeight( std::vector& W, - const std::vector& topo_row, - int num_elements, - float alpha ) { +void UpdateWeight(std::vector* W, + const std::vector& topo_row, + int num_elements, + float alpha) { for (unsigned i = 1; i < topo_row.size() - 1; i += 2) { unsigned parent = topo_row[i]; unsigned child = topo_row[i+1]; - if (!(parent >= num_elements*num_elements || + if (!(parent >= num_elements*num_elements || child >= num_elements*num_elements) && (parent != child)) { - W[parent*num_elements+child] *= alpha; - W[child*num_elements+parent] *= alpha; + (*W)[parent*num_elements+child] *= alpha; + (*W)[child*num_elements+parent] *= alpha; } } } @@ -861,15 +860,14 @@ void UpdateWeight( std::vector& W, // 2) maximize edge weight // 3) tree is binary template -void BacktrackGenerateBinaryTree( std::vector& W, - int num_elements, - int root, - std::vector& topo_row, - std::vector& scan_row ) { - +void BacktrackGenerateBinaryTree(std::vector* W, + int num_elements, + int root, + std::vector* topo_row, + std::vector* scan_row) { // Clear before starting - topo_row.clear(); - scan_row.clear(); + topo_row->clear(); + scan_row->clear(); // Compute depth // num_elements: depth @@ -879,7 +877,7 @@ void BacktrackGenerateBinaryTree( std::vector& W, // 8: 3 // 9: 4 int depth = ComputeDepth(num_elements); - int depth_leaves = 1<& W, state[0] = root; // Seek optimal solution until 
depth <= 3 i.e. 8 GPUs - // For larger numbers of GPUs, settle for first tree found (non-optimal), but + // For larger numbers of GPUs, settle for first tree found (non-optimal), but // this saves a lot of runtime, because Backtrack is exponential time if (depth <= 3) - IterativeBacktrack( W, state, result, result_weight, 1, num_elements, depth, true ); + IterativeBacktrack(*W, &state, &result, &result_weight, 1, num_elements, + depth, true); else - IterativeBacktrack( W, state, result, result_weight, 1, num_elements, depth, false ); - FormTopology( result, topo_row, scan_row, depth ); + IterativeBacktrack(*W, &state, &result, &result_weight, 1, num_elements, + depth, false); + FormTopology(result, topo_row, scan_row, depth); } // ComputeTreesFromRoot does the same thing as ComputeTrees, with the only // exception being it will do it from a fixed GPU as root template -void ComputeTreesFromRoot( std::vector& W, - int num_elements, - int root, - float alpha, - bool backtrack, - std::vector& topo, - std::vector& scan ) { - +void ComputeTreesFromRoot(std::vector* W, + int num_elements, + int root, + float alpha, + bool backtrack, + std::vector* topo, + std::vector* scan) { int num_partitions = 1; // Initialize partition array to indicate which partition each element belongs @@ -919,18 +918,18 @@ void ComputeTreesFromRoot( std::vector& W, // Initialize vector of pairs that will tell us edges between what 2 clusters // we should be looking to build the tree from - std::vector> cluster_pairs; + std::vector> cluster_pairs; - // Initialize vector of roots that will tell us edges between + // Initialize vector of roots that will tell us edges between std::unordered_set roots; roots.insert(root); // Will be used to obtain a seed for the random number engine // RNG: Standard mersenne_twister_engine seeded with rd() // -use 0 for testing (TODO: remove this) - std::random_device rd; + // std::random_device rd; + // std::mt19937 gen(rd()); std::mt19937 gen(1); - //std::mt19937 
gen(rd()); // Temporary variables for rewinding std::vector P_temp; @@ -951,17 +950,19 @@ void ComputeTreesFromRoot( std::vector& W, P_temp = P; num_partitions_temp = num_partitions; roots_temp = roots; - topo_temp = topo; - scan_temp = scan; + topo_temp = *topo; + scan_temp = *scan; } // Run Kernighan-Lin to generate partition - stop = KernighanLin(W, P_temp, num_partitions_temp, cluster_pairs, gen); + stop = KernighanLin(*W, &P_temp, &num_partitions_temp, &cluster_pairs, + &gen); - // Use partitions found and a given root to find best inter-cluster edge for // each pair of clusters, and returns them as roots of next cluster + // Use partitions found and a given root to find best inter-cluster edge for + // each pair of clusters, and returns them as roots of next cluster // If reset is true, then rewind back to previous clustering - reset = KLGenerateBinaryTree(W, P_temp, cluster_pairs, roots_temp, - topo_temp, scan_temp, gen); + reset = KLGenerateBinaryTree(*W, P_temp, &cluster_pairs, &roots_temp, + &topo_temp, &scan_temp, &gen); if (reset) level++; @@ -969,15 +970,15 @@ void ComputeTreesFromRoot( std::vector& W, } if (reset == 1) { - if (!backtrack) - LOG(WARNING) << "No valid binary tree found from root " << root << ", try backtracking"; + // if (!backtrack) + // LOG(WARNING) << "No valid binary tree found from root " << root << ", try backtracking"; BacktrackGenerateBinaryTree(W, num_elements, root, topo, scan); } else { - topo = topo_temp; - scan = scan_temp; - scan.push_back(topo.size()); + *topo = topo_temp; + *scan = scan_temp; + scan->push_back(topo->size()); } - UpdateWeight( W, topo, num_elements, alpha ); + UpdateWeight(W, *topo, num_elements, alpha); } // ComputeTrees computes balanced binary spanning trees of maximum edge weight @@ -989,32 +990,32 @@ void ComputeTreesFromRoot( std::vector& W, // @output: topo stores the trees generated // scan stores the start of each level of each tree template -void ComputeTrees( const std::vector& W, - int 
num_elements, - float alpha, - bool backtrack, - std::vector>& topo, - std::vector>& scan ) { +void ComputeTrees(const std::vector& W, + int num_elements, + float alpha, + bool backtrack, + std::vector>* topo, + std::vector>* scan) { std::vector W_copy = W; - topo.clear(); - scan.clear(); + topo->clear(); + scan->clear(); for (int i = 0; i < num_elements; ++i) { - topo.push_back(std::vector()); - scan.push_back(std::vector()); - topo[i].push_back(i); - scan[i].push_back(0); - ComputeTreesFromRoot(W_copy, num_elements, i, alpha, backtrack, topo[i], - scan[i]); + topo->push_back(std::vector()); + scan->push_back(std::vector()); + (*topo)[i].push_back(i); + (*scan)[i].push_back(0); + ComputeTreesFromRoot(&W_copy, num_elements, i, alpha, backtrack, + &((*topo)[i]), &((*scan)[i])); } // Note: must sum up adj matrix to show link usage before we readjust topo // from 0, 1, ..., n_gpus format to dev_id format, which will cause segfault std::vector adj(W.size(), 0); for (int row = 0; row < num_elements; ++row) { - for (unsigned col = 1; col < topo[0].size(); col += 2) { - int from = std::min(topo[row][col], topo[row][col+1]); - int dest = std::max(topo[row][col], topo[row][col+1]); + for (unsigned col = 1; col < (*topo)[0].size(); col += 2) { + int from = std::min((*topo)[row][col], (*topo)[row][col+1]); + int dest = std::max((*topo)[row][col], (*topo)[row][col+1]); if (from != dest) { adj[from*num_elements+dest] += 1; adj[dest*num_elements+from] += 1; @@ -1034,4 +1035,4 @@ void ComputeTrees( const std::vector& W, } // namespace kvstore } // namespace mxnet -#endif // MXNET_KVSTORE_GPU_TOPOLOGY_H +#endif // MXNET_KVSTORE_GPU_TOPOLOGY_H_ diff --git a/tests/cpp/kvstore/gpu_topology_test.cc b/tests/cpp/kvstore/gpu_topology_test.cc index d1aab257c085..1aed5f568cb3 100644 --- a/tests/cpp/kvstore/gpu_topology_test.cc +++ b/tests/cpp/kvstore/gpu_topology_test.cc @@ -28,30 +28,30 @@ #include #include "../src/kvstore/gpu_topology.h" -void GenerateMatrix( std::vector& W, int 
num_gpus, float k, - std::mt19937& gen) { +void GenerateMatrix(std::vector* W, int num_gpus, float k, + std::mt19937* gen) { std::uniform_real_distribution<> dis(0., 1.); for (int row = 0; row < num_gpus; ++row) { for (int col = row+1; col < num_gpus; ++col) { - float sample = dis(gen); + float sample = dis(*gen); if (sample < k) continue; - sample = dis(gen); + sample = dis(*gen); if (sample < 0.33f) { - W[row*num_gpus+col] = 1.f; - W[col*num_gpus+row] = 1.f; + (*W)[row*num_gpus+col] = 1.f; + (*W)[col*num_gpus+row] = 1.f; } else if (sample < 0.66f) { - W[row*num_gpus+col] = 2.f; - W[col*num_gpus+row] = 2.f; + (*W)[row*num_gpus+col] = 2.f; + (*W)[col*num_gpus+row] = 2.f; } else { - W[row*num_gpus+col] = 3.f; - W[col*num_gpus+row] = 3.f; + (*W)[row*num_gpus+col] = 3.f; + (*W)[col*num_gpus+row] = 3.f; } } } } -bool IsSatisfactory( const std::vector& W, int num_gpus, int depth ) { +bool IsSatisfactory(const std::vector& W, int num_gpus, int depth) { for (int row = 0; row < num_gpus; ++row) { int out_edges = 0; for (int col = 0; col < num_gpus; ++col) { @@ -65,25 +65,22 @@ bool IsSatisfactory( const std::vector& W, int num_gpus, int depth ) { } // Generates random link topology matrix using random number generator -void TestComputeTreesRandomized( int num_gpus, float alpha, int backtrack, - std::mt19937& gen ) { +void TestComputeTreesRandomized(int num_gpus, float alpha, int backtrack, + std::mt19937* gen) { std::uniform_real_distribution<> dis(0.f, 1.f); bool satisfied = false; std::vector W(num_gpus*num_gpus, 0.f); int depth = mxnet::kvstore::ComputeDepth(num_gpus); while (!satisfied) { - float k = dis(gen); + float k = dis(*gen); std::fill(W.begin(), W.end(), 0.f); - GenerateMatrix(W, num_gpus, k, gen); + GenerateMatrix(&W, num_gpus, k, gen); satisfied = IsSatisfactory(W, num_gpus, depth); - //if (!satisfied) - // LOG(WARNING) << k << " is not satisfactory"; } std::vector> topo; std::vector> scan; - //mxnet::kvstore::PrintMatrix("W", W, num_gpus, num_gpus); - 
mxnet::kvstore::ComputeTrees( W, num_gpus, alpha, backtrack, topo, scan ); + mxnet::kvstore::ComputeTrees(W, num_gpus, alpha, backtrack, &topo, &scan); unsigned correct_topo_size = (1 << (depth + 1)) - 1; unsigned correct_scan_size = depth+2; @@ -95,15 +92,15 @@ void TestComputeTreesRandomized( int num_gpus, float alpha, int backtrack, // Permutes matrix W using permutation vector P and stores output in matrix A // Assumption: W is square and symmetric -void PermuteMatrix( const std::vector& W, - const std::vector& P, - std::vector& A ) { +void PermuteMatrix(const std::vector& W, + const std::vector& P, + std::vector* A) { int nrows = P.size(); - std::vector temp(nrows*nrows,0); + std::vector temp(nrows*nrows, 0); int count = 0; - for (int row=0; row& W, } count = 0; - for (int row=0; row state0 = {3, 2, 1, 5, 0, 0, 4, 6}; + std::vector state0 = {3, 2, 1, 5, 0, 0, 4, 6}; std::vector topo0; std::vector scan0; - std::vector correct0= {3, 3, 0, 3, 1, 0, 4, 3, 2, 1, 5, 0, 0, 4, 6}; + std::vector correct0 = {3, 3, 0, 3, 1, 0, 4, 3, 2, 1, 5, 0, 0, 4, 6}; std::vector correct_scan0 = {0, 1, 3, 7, 15}; - mxnet::kvstore::FormTopology(state0, topo0, scan0, 3); + mxnet::kvstore::FormTopology(state0, &topo0, &scan0, 3); ASSERT_EQ(topo0.size(), correct0.size()); for (unsigned i = 0; i < correct0.size(); ++i) ASSERT_EQ(static_cast(topo0[i]), correct0[i]); @@ -134,12 +131,12 @@ TEST(GpuTopology, TestFormTopology) { for (unsigned i = 0; i < correct_scan0.size(); ++i) ASSERT_EQ(static_cast(scan0[i]), correct_scan0[i]); - std::vector state1 = {3, 2, 0, 4, 1, 1, 5, 6}; + std::vector state1 = {3, 2, 0, 4, 1, 1, 5, 6}; std::vector topo1; std::vector scan1; - std::vector correct1= {3, 3, 1, 3, 0, 1, 5, 3, 2, 0, 4, 1, 1, 5, 6}; + std::vector correct1 = {3, 3, 1, 3, 0, 1, 5, 3, 2, 0, 4, 1, 1, 5, 6}; std::vector correct_scan1 = {0, 1, 3, 7, 15}; - mxnet::kvstore::FormTopology(state1, topo1, scan1, 3); + mxnet::kvstore::FormTopology(state1, &topo1, &scan1, 3); ASSERT_EQ(topo1.size(), 
correct1.size()); for (unsigned i = 0; i < correct1.size(); ++i) ASSERT_EQ(static_cast(topo1[i]), correct1[i]); @@ -149,8 +146,7 @@ TEST(GpuTopology, TestFormTopology) { } TEST(GpuTopology, TestComputeTreeWeight) { - - std::vector W = {0, 2, 2, 3, 3, 0, 0, + std::vector W = {0, 2, 2, 3, 3, 0, 0, 2, 0, 3, 2, 0, 3, 0, 2, 3, 0, 3, 0, 0, 2, 3, 2, 3, 0, 0, 0, 0, @@ -166,27 +162,27 @@ TEST(GpuTopology, TestComputeTreeWeight) { } TEST(GpuTopology, TestPostprocess) { - std::vector result0 = {3, 0, 0, 4, 1, 2, 5, 6}; - std::vector correct0= {3, 3, 0, 4, 1, 2, 5, 6}; - mxnet::kvstore::Postprocess( result0, 7, 3 ); + std::vector result0 = {3, 0, 0, 4, 1, 2, 5, 6}; + std::vector correct0 = {3, 3, 0, 4, 1, 2, 5, 6}; + mxnet::kvstore::Postprocess(&result0, 7, 3); for (unsigned i = 0; i < correct0.size(); ++i) ASSERT_EQ(result0[i], correct0[i]); - std::vector result1 = {2, 0, 0, 4, 1, 3, 5, 1}; - std::vector correct1= {2, 2, 0, 4, 1, 3, 5, 5}; - mxnet::kvstore::Postprocess( result1, 6, 3 ); + std::vector result1 = {2, 0, 0, 4, 1, 3, 5, 1}; + std::vector correct1 = {2, 2, 0, 4, 1, 3, 5, 5}; + mxnet::kvstore::Postprocess(&result1, 6, 3); for (unsigned i = 0; i < correct1.size(); ++i) ASSERT_EQ(result1[i], correct1[i]); - std::vector result2 = {5, 4, 1, 3, 1, 0, 2, 0}; - std::vector correct2= {5, 4, 5, 3, 1, 0, 2, 2}; - mxnet::kvstore::Postprocess( result2, 6, 3 ); + std::vector result2 = {5, 4, 1, 3, 1, 0, 2, 0}; + std::vector correct2 = {5, 4, 5, 3, 1, 0, 2, 2}; + mxnet::kvstore::Postprocess(&result2, 6, 3); for (unsigned i = 0; i < correct2.size(); ++i) ASSERT_EQ(result2[i], correct2[i]); - std::vector result3 = {10,10, 0, 0, 0, 0, 0, 1, 2, 3, 6, 4, 7, 5, 8, 9}; - std::vector correct3= {10,10,10,10, 0, 0, 0, 1, 2, 3, 6, 4, 7, 5, 8, 9}; - mxnet::kvstore::Postprocess( result3, 11, 4 ); + std::vector result3 = {10, 10, 0, 0, 0, 0, 0, 1, 2, 3, 6, 4, 7, 5, 8, 9}; + std::vector correct3 = {10, 10, 10, 10, 0, 0, 0, 1, 2, 3, 6, 4, 7, 5, 8, 9}; + mxnet::kvstore::Postprocess(&result3, 11, 
4); for (unsigned i = 0; i < correct3.size(); ++i) ASSERT_EQ(result3[i], correct3[i]); } @@ -200,8 +196,7 @@ TEST(GpuTopology, TestDepth) { } TEST(GpuTopology, TestIsValid) { - - std::vector W = {0, 2, 2, 3, 3, 0, 0, + std::vector W = {0, 2, 2, 3, 3, 0, 0, 2, 0, 3, 2, 0, 3, 0, 2, 3, 0, 3, 0, 0, 2, 3, 2, 3, 0, 0, 0, 0, @@ -239,23 +234,23 @@ TEST(GpuTopology, TestIsValid) { // gemvTest TEST(GpuTopology, TestGemv) { - std::vector A = {0, 2, 2, 3, 3, 1, 1, 1, // 13 - 2, 0, 3, 2, 1, 3, 1, 1, // 13 - 2, 3, 0, 3, 1, 1, 2, 1, // 13 - 3, 2, 3, 0, 1, 1, 1, 2, // 13 - 3, 1, 1, 1, 0, 2, 2, 3, // 13 - 1, 3, 1, 1, 2, 0, 3, 2, // 13 - 1, 1, 2, 1, 2, 3, 0, 3, // 13 - 1, 1, 1, 2, 3, 2, 3, 0}; // 13 + std::vector A = {0, 2, 2, 3, 3, 1, 1, 1, // 13 + 2, 0, 3, 2, 1, 3, 1, 1, // 13 + 2, 3, 0, 3, 1, 1, 2, 1, // 13 + 3, 2, 3, 0, 1, 1, 1, 2, // 13 + 3, 1, 1, 1, 0, 2, 2, 3, // 13 + 1, 3, 1, 1, 2, 0, 3, 2, // 13 + 1, 1, 2, 1, 2, 3, 0, 3, // 13 + 1, 1, 1, 2, 3, 2, 3, 0}; // 13 std::vector x(8, 1); std::vector y(8, 0); std::iota(y.begin(), y.end(), 0); std::vector correct_y(8, 13); - mxnet::kvstore::gemv( A, x, y ); + mxnet::kvstore::gemv(A, x, &y); ASSERT_EQ(y.size(), correct_y.size()); - for (unsigned i = 0; i < y.size(); ++i ) - ASSERT_EQ( y[i], correct_y[i] ); + for (unsigned i = 0; i < y.size(); ++i) + ASSERT_EQ(y[i], correct_y[i]); } // ewisemultTest @@ -265,11 +260,11 @@ TEST(GpuTopology, TestEwisemult) { std::iota(y.begin(), y.end(), 0); int alpha = 5; std::vector correct_y = {0, 5, 10, 15, 20, 25, 30, 35}; - mxnet::kvstore::ewisemult( x, alpha, y ); + mxnet::kvstore::ewisemult(x, alpha, &y); ASSERT_EQ(y.size(), correct_y.size()); - for (unsigned i = 0; i < y.size(); ++i ) - ASSERT_EQ( y[i], correct_y[i] ); + for (unsigned i = 0; i < y.size(); ++i) + ASSERT_EQ(y[i], correct_y[i]); } // ewiseaddTest @@ -278,18 +273,18 @@ TEST(GpuTopology, TestEwiseadd) { std::vector y(8, 0); std::iota(y.begin(), y.end(), 0); int alpha = 5; - std::vector correct_y(8,0); + std::vector correct_y(8, 0); 
std::iota(correct_y.begin(), correct_y.end(), 5); - mxnet::kvstore::ewiseadd( x, alpha, y ); + mxnet::kvstore::ewiseadd(x, alpha, &y); ASSERT_EQ(y.size(), correct_y.size()); - for (unsigned i = 0; i < y.size(); ++i ) - ASSERT_EQ( y[i], correct_y[i] ); + for (unsigned i = 0; i < y.size(); ++i) + ASSERT_EQ(y[i], correct_y[i]); } // FindBestMoveTest TEST(GpuTopology, TestFindBestMove) { - std::vector W = {0, 2, 2, 3, 3, 1, 1, 1, + std::vector W = {0, 2, 2, 3, 3, 1, 1, 1, 2, 0, 3, 2, 1, 3, 1, 1, 2, 3, 0, 3, 1, 1, 2, 1, 3, 2, 3, 0, 1, 1, 1, 2, @@ -301,12 +296,12 @@ TEST(GpuTopology, TestFindBestMove) { std::iota(P.begin(), P.end(), 1); std::unordered_set used; - std::vector D1 = {20,0, 0, 0, 0, 0, 0,20}; + std::vector D1 = {20, 0, 0, 0, 0, 0, 0, 20}; int a1, b1, g1; int correct_a1 = 0; int correct_b1 = 7; int correct_g1 = 38; - mxnet::kvstore::FindBestMove( W, P, D1, used, a1, b1, g1 ); + mxnet::kvstore::FindBestMove(W, P, D1, used, &a1, &b1, &g1); ASSERT_EQ(a1, correct_a1); ASSERT_EQ(b1, correct_b1); ASSERT_EQ(g1, correct_g1); @@ -317,7 +312,7 @@ TEST(GpuTopology, TestFindBestMove) { int correct_a2 = -1; int correct_b2 = -1; int correct_g2 = 0; - mxnet::kvstore::FindBestMove( W, P, D2, used, a2, b2, g2 ); + mxnet::kvstore::FindBestMove(W, P, D2, used, &a2, &b2, &g2); ASSERT_EQ(a2, correct_a2); ASSERT_EQ(b2, correct_b2); ASSERT_EQ(g2, correct_g2); @@ -363,30 +358,30 @@ TEST(GpuTopology, TestGetChild) { std::vector P = {0, 0, 1, 2, 2, 2, 3, 3}; // Test when color is not found - int color1 = 4; - int parent1= 4; + int color1 = 4; + int parent1 = 4; int correct_child1 = -1; - int child1 = mxnet::kvstore::GetChild(P, color1, parent1); + int child1 = mxnet::kvstore::GetChild(P, color1, parent1); ASSERT_EQ(child1, correct_child1); // Test when color is found, but is equal to parent - int color2 = 1; - int parent2= 2; + int color2 = 1; + int parent2 = 2; int correct_child2 = -1; - int child2 = mxnet::kvstore::GetChild(P, color2, parent2); + int child2 = 
mxnet::kvstore::GetChild(P, color2, parent2); ASSERT_EQ(child2, correct_child2); // Test when color is found and not equal to parent - int color3 = 3; - int parent3= 6; + int color3 = 3; + int parent3 = 6; int correct_child3 = 7; - int child3 = mxnet::kvstore::GetChild(P, color3, parent3); + int child3 = mxnet::kvstore::GetChild(P, color3, parent3); ASSERT_EQ(child3, correct_child3); } // FindBestEdgeTest TEST(GpuTopology, TestFindBestEdge) { - std::vector W = {0, 2, 2, 3, 3, 1, 1, 1, + std::vector W = {0, 2, 2, 3, 3, 1, 1, 1, 2, 0, 3, 2, 1, 3, 1, 1, 2, 3, 0, 3, 1, 1, 2, 1, 3, 2, 3, 0, 1, 1, 1, 2, @@ -403,7 +398,7 @@ TEST(GpuTopology, TestFindBestEdge) { int g1; std::vector correct_b1 = {0, 2}; int correct_g1 = 3; - mxnet::kvstore::FindBestEdge( W, P, parent1, dest1, b1, g1 ); + mxnet::kvstore::FindBestEdge(W, P, parent1, dest1, &b1, &g1); ASSERT_EQ(b1.size(), correct_b1.size()); for (unsigned i = 0; i < b1.size(); ++i) ASSERT_EQ(b1[i], correct_b1[i]); @@ -416,7 +411,7 @@ TEST(GpuTopology, TestFindBestEdge) { int g2; std::vector correct_b2 = {-1}; int correct_g2 = 0; - mxnet::kvstore::FindBestEdge( W, P, parent2, dest2, b2, g2 ); + mxnet::kvstore::FindBestEdge(W, P, parent2, dest2, &b2, &g2); ASSERT_EQ(b2.size(), correct_b2.size()); for (unsigned i = 0; i < b2.size(); ++i) ASSERT_EQ(b2[i], correct_b2[i]); @@ -425,7 +420,7 @@ TEST(GpuTopology, TestFindBestEdge) { // KLGenerateBinaryTreeTest TEST(GpuTopology, TestKLGenerateBinaryTree1) { - std::vector W = {0, 2, 3, 3, 3, 1, 1, 1, + std::vector W = {0, 2, 3, 3, 3, 1, 1, 1, 2, 0, 3, 2, 1, 3, 1, 1, 2, 3, 0, 3, 1, 1, 2, 1, 3, 2, 3, 0, 1, 1, 1, 2, @@ -434,17 +429,17 @@ TEST(GpuTopology, TestKLGenerateBinaryTree1) { 1, 1, 2, 1, 2, 3, 0, 3, 1, 1, 1, 2, 3, 2, 3, 0}; std::vector P = {0, 1, 1, 0, 2, 3, 3, 2}; - std::vector> cluster_pairs; - cluster_pairs.push_back(std::make_pair(0,-2)); - cluster_pairs.push_back(std::make_pair(1,-2)); - cluster_pairs.push_back(std::make_pair(2,-2)); - 
cluster_pairs.push_back(std::make_pair(3,-2)); + std::vector> cluster_pairs; + cluster_pairs.push_back(std::pair(0, -2)); + cluster_pairs.push_back(std::pair(1, -2)); + cluster_pairs.push_back(std::pair(2, -2)); + cluster_pairs.push_back(std::pair(3, -2)); std::unordered_set roots = {0, 2, 4, 6}; std::vector topo = {0, 2, 4, 6}; - std::vector scan(2,0); + std::vector scan(2, 0); std::mt19937 gen(1); - mxnet::kvstore::KLGenerateBinaryTree(W, P, cluster_pairs, roots, topo, scan, - gen); + mxnet::kvstore::KLGenerateBinaryTree(W, P, &cluster_pairs, &roots, &topo, + &scan, &gen); std::vector correct_topo = {0, 2, 4, 6, 0, 3, 2, 1, 4, 7, 6, 5}; std::vector correct_scan = {0, 0, 4}; ASSERT_EQ(topo.size(), correct_topo.size()); @@ -456,7 +451,7 @@ TEST(GpuTopology, TestKLGenerateBinaryTree1) { } TEST(GpuTopology, TestKLGenerateBinaryTree2) { - std::vector W = {0, 2, 3, 3, 3, 1, 1, 1, + std::vector W = {0, 2, 3, 3, 3, 1, 1, 1, 2, 0, 3, 2, 1, 3, 1, 1, 2, 3, 0, 3, 1, 1, 2, 1, 3, 2, 3, 0, 1, 1, 1, 2, @@ -465,17 +460,17 @@ TEST(GpuTopology, TestKLGenerateBinaryTree2) { 1, 1, 2, 1, 2, 3, 0, 3, 1, 1, 1, 2, 3, 2, 3, 0}; std::vector P = {0, 1, 1, 0, 2, 3, 3, 2}; - std::vector> cluster_pairs; - cluster_pairs.push_back(std::make_pair(0,-2)); - cluster_pairs.push_back(std::make_pair(1,-2)); - cluster_pairs.push_back(std::make_pair(2,-2)); - cluster_pairs.push_back(std::make_pair(3,-2)); + std::vector> cluster_pairs; + cluster_pairs.push_back(std::pair(0, -2)); + cluster_pairs.push_back(std::pair(1, -2)); + cluster_pairs.push_back(std::pair(2, -2)); + cluster_pairs.push_back(std::pair(3, -2)); std::unordered_set roots = {0, 2, 4, 6}; std::vector topo = {0, 6, 4, 2}; - std::vector scan(2,0); + std::vector scan(2, 0); std::mt19937 gen(1); - mxnet::kvstore::KLGenerateBinaryTree(W, P, cluster_pairs, roots, topo, scan, - gen); + mxnet::kvstore::KLGenerateBinaryTree(W, P, &cluster_pairs, &roots, &topo, + &scan, &gen); std::vector correct_topo = {0, 6, 4, 2, 0, 3, 6, 5, 4, 7, 2, 1}; 
std::vector correct_scan = {0, 0, 4}; ASSERT_EQ(topo.size(), correct_topo.size()); @@ -490,12 +485,12 @@ TEST(GpuTopology, TestKLGenerateBinaryTree2) { TEST(GpuTopology, TestUpdateWeight) { std::vector W = {0.f, 1.f, 1.f, 0.f}; - std::vector topo= {1, 1, 0}; + std::vector topo = {1, 1, 0}; int num_gpus = 2; float alpha = 0.7; std::vector correct_W = {0.f, 0.7f, 0.7f, 0.f}; - mxnet::kvstore::UpdateWeight(W, topo, num_gpus, alpha); + mxnet::kvstore::UpdateWeight(&W, topo, num_gpus, alpha); ASSERT_EQ(W.size(), correct_W.size()); for (unsigned i = 0; i < W.size(); ++i) { ASSERT_EQ(W[i], correct_W[i]); @@ -505,7 +500,7 @@ TEST(GpuTopology, TestUpdateWeight) { // BacktrackGenerateBinaryTree // ComputeTreesFromRoot TEST(GpuTopology, TestComputeTreesFromRoot) { - std::vector W = {0, 2, 2, 3, 3, 1, 1, 1, + std::vector W = {0, 2, 2, 3, 3, 1, 1, 1, 2, 0, 3, 2, 1, 3, 1, 1, 2, 3, 0, 3, 1, 1, 2, 1, 3, 2, 3, 0, 1, 1, 1, 2, @@ -522,8 +517,8 @@ TEST(GpuTopology, TestComputeTreesFromRoot) { std::vector topo; std::vector scan; - mxnet::kvstore::ComputeTreesFromRoot( W, num_gpus, root, alpha, backtrack, - topo, scan ); + mxnet::kvstore::ComputeTreesFromRoot(&W, num_gpus, root, alpha, backtrack, + &topo, &scan); ASSERT_EQ(topo.size(), correct_topo_size); ASSERT_EQ(scan.size(), correct_scan_size); @@ -534,10 +529,10 @@ TEST(GpuTopology, TestComputeTrees1) { std::mt19937 gen(1); float alpha = 0.7; bool backtrack = true; - // Do 100 randomized tests per GPU count from 2 to 16 + // Do 5 randomized tests per GPU count from 2 to 16 for (int num_gpus = 2; num_gpus <= 16; ++num_gpus) { for (int i = 0; i < 5; ++i) { - TestComputeTreesRandomized( num_gpus, alpha, backtrack, gen ); + TestComputeTreesRandomized(num_gpus, alpha, backtrack, &gen); } } } @@ -547,16 +542,16 @@ TEST(GpuTopology, TestComputeTrees2) { std::mt19937 gen(1); float alpha = 0.7; bool backtrack = false; - // Do 100 randomized tests per GPU count from 2 to 16 + // Do 5 randomized tests per GPU count from 2 to 16 for (int 
num_gpus = 2; num_gpus <= 16; ++num_gpus) { for (int i = 0; i < 5; ++i) { - TestComputeTreesRandomized( num_gpus, alpha, backtrack, gen ); + TestComputeTreesRandomized(num_gpus, alpha, backtrack, &gen); } } } TEST(GpuTopology, TestPermuteMatrix) { - std::vector W = {0, 2, 2, 3, 3, 1, 1, 1, + std::vector W = {0, 2, 2, 3, 3, 1, 1, 1, 2, 0, 3, 2, 1, 3, 1, 1, 2, 3, 0, 3, 1, 1, 2, 1, 3, 2, 3, 0, 1, 1, 1, 2, @@ -567,8 +562,8 @@ TEST(GpuTopology, TestPermuteMatrix) { std::vector P1 = {0, 1, 2, 3, 4, 5, 6, 7}; std::vector A(8*8, 0); - PermuteMatrix( W, P1, A ); - for (unsigned i=0; i P(6, 0); - std::vector> cluster_pairs; + std::vector> cluster_pairs; int num_partitions = 1; std::mt19937 gen(1); - bool stop = mxnet::kvstore::KernighanLin( W, P, num_partitions, cluster_pairs, gen ); + bool stop = mxnet::kvstore::KernighanLin(W, &P, &num_partitions, + &cluster_pairs, &gen); - std::vector> correct_pairs; - correct_pairs.push_back(std::make_pair(0,1)); + std::vector> correct_pairs; + correct_pairs.push_back(std::pair(0, 1)); std::vector correct_P = {0, 1, 0, 1, 1, 0}; ASSERT_EQ(stop, false); ASSERT_EQ(num_partitions, 2); @@ -601,10 +597,10 @@ TEST(GpuTopology, TestKernighanLin1) { if (P[i] != correct_P[i]) error++; } - EXPECT_TRUE (error == 0 || error == P.size()) - << "Where real value: " << error - << " not equal neither: " << 0 - << " nor: " << P.size() << "."; + EXPECT_TRUE(error == 0 || error == P.size()) + << "Where real value: " << error + << " not equal neither: " << 0 + << " nor: " << P.size() << "."; } TEST(GpuTopology, TestKernighanLin2) { @@ -617,13 +613,14 @@ TEST(GpuTopology, TestKernighanLin2) { 0, 0, 1, 1, 0, 0, 0, 1, 0, 0, 1, 1, 0, 0, 1, 0}; std::vector P(8, 0); - std::vector> cluster_pairs; + std::vector> cluster_pairs; int num_partitions = 1; std::mt19937 gen(1); - bool stop = mxnet::kvstore::KernighanLin( W, P, num_partitions, cluster_pairs, gen ); + bool stop = mxnet::kvstore::KernighanLin(W, &P, &num_partitions, + &cluster_pairs, &gen); - std::vector> 
correct_pairs; - correct_pairs.push_back(std::make_pair(0,1)); + std::vector> correct_pairs; + correct_pairs.push_back(std::pair(0, 1)); std::vector correct_P = {0, 0, 1, 1, 0, 0, 1, 1}; ASSERT_EQ(stop, false); ASSERT_EQ(num_partitions, 2); @@ -638,8 +635,8 @@ TEST(GpuTopology, TestKernighanLin2) { if (P[i] != correct_P[i]) error++; } - EXPECT_TRUE (error == 0 || error == P.size()) - << "Where real value: " << error - << " not equal neither: " << 0 - << " nor: " << P.size() << "."; + EXPECT_TRUE(error == 0 || error == P.size()) + << "Where real value: " << error + << " not equal neither: " << 0 + << " nor: " << P.size() << "."; } From 972e9c0de980b7ef31891ef8cbce30fa45bb11c4 Mon Sep 17 00:00:00 2001 From: Carl Yang Date: Wed, 27 Jun 2018 00:25:56 +0000 Subject: [PATCH 18/36] Add Python test using MXNET_KVSTORE_USETREE, fix CMake compilation problem, add header guard --- src/kvstore/comm.h | 77 +++++++++++++++------------- src/kvstore/comm_tree.h | 46 ++++++++--------- src/kvstore/gpu_topology.h | 52 +++++++++---------- src/kvstore/kvstore_local.h | 8 ++- tests/python/gpu/test_kvstore_gpu.py | 51 +++++++++++------- 5 files changed, 125 insertions(+), 109 deletions(-) diff --git a/src/kvstore/comm.h b/src/kvstore/comm.h index 09928c9da81f..0b139f9a7c21 100644 --- a/src/kvstore/comm.h +++ b/src/kvstore/comm.h @@ -671,6 +671,42 @@ class CommDevice : public Comm { } } + using KeyAttrs = std::tuple; + // try to allocate buff on device evenly + void InitMergeBuffer(const std::vector& devs) { + std::sort(sorted_key_attrs_.begin(), sorted_key_attrs_.end(), []( + const KeyAttrs& a, const KeyAttrs& b) { + return std::get<1>(a).Size() > std::get<1>(b).Size(); + }); + + std::unordered_map> ctx_info; + for (auto d : devs) { + ctx_info[d.dev_id] = std::make_pair(d, 0); + } + + for (size_t i = 0; i < sorted_key_attrs_.size(); ++i) { + const int key = std::get<0>(sorted_key_attrs_[i]); + const TShape& shape = std::get<1>(sorted_key_attrs_[i]); + const int type = 
std::get<2>(sorted_key_attrs_[i]); + auto& buf = merge_buf_[key]; + Context ctx; + size_t min_size = std::numeric_limits::max(); + for (auto it = ctx_info.begin(); it != ctx_info.end(); ++it) { + size_t size = it->second.second; + if (size <= min_size) { + ctx = it->second.first; + min_size = size; + } + } + // Delayed allocation - as the dense merged buffer might not be used at all if push() + // only sees sparse arrays + bool delay_alloc = true; + buf.merged = NDArray(shape, ctx, delay_alloc, type); + ctx_info[ctx.dev_id].second += shape.Size(); + } + inited_ = true; + } + private: void EnableP2P(const std::vector& devs) { #if MXNET_USE_CUDA @@ -714,43 +750,6 @@ class CommDevice : public Comm { #endif } - using KeyAttrs = std::tuple; - // try to allocate buff on device evenly - void InitMergeBuffer(const std::vector& devs) { - std::sort(sorted_key_attrs_.begin(), sorted_key_attrs_.end(), []( - const KeyAttrs& a, const KeyAttrs& b) { - return std::get<1>(a).Size() > std::get<1>(b).Size(); - }); - - std::unordered_map> ctx_info; - for (auto d : devs) { - ctx_info[d.dev_id] = std::make_pair(d, 0); - } - - for (size_t i = 0; i < sorted_key_attrs_.size(); ++i) { - const int key = std::get<0>(sorted_key_attrs_[i]); - const TShape& shape = std::get<1>(sorted_key_attrs_[i]); - const int type = std::get<2>(sorted_key_attrs_[i]); - auto& buf = merge_buf_[key]; - Context ctx; - size_t min_size = std::numeric_limits::max(); - for (auto it = ctx_info.begin(); it != ctx_info.end(); ++it) { - size_t size = it->second.second; - if (size <= min_size) { - ctx = it->second.first; - min_size = size; - } - } - // Delayed allocation - as the dense merged buffer might not be used at all if push() - // only sees sparse arrays - bool delay_alloc = true; - buf.merged = NDArray(shape, ctx, delay_alloc, type); - ctx_info[ctx.dev_id].second += shape.Size(); - } - inited_ = true; - } - - std::vector sorted_key_attrs_; /// \brief temporal space for pushing and pulling struct BufferEntry { /// 
\brief the dense merged value for reduce and broadcast operations @@ -763,6 +762,8 @@ class CommDevice : public Comm { std::vector compressed_send_buf; /// \brief the small buffer for compressed data in receiver std::vector compressed_recv_buf; + /// \brief size of allocation in case we do not actually allocate merged + TShape merged_size; /// \brief the merged buffer for the given storage type (could be either dense or row_sparse) inline NDArray& merged_buf(NDArrayStorageType stype) { @@ -785,9 +786,11 @@ class CommDevice : public Comm { NDArray sparse_merged; }; std::unordered_map merge_buf_; + public: bool inited_; + std::vector sorted_key_attrs_; }; } // namespace kvstore diff --git a/src/kvstore/comm_tree.h b/src/kvstore/comm_tree.h index 3e0de61fe8e4..b44d84401edc 100644 --- a/src/kvstore/comm_tree.h +++ b/src/kvstore/comm_tree.h @@ -60,7 +60,9 @@ class CommDeviceTree : public CommDevice { void Init(int key, const NDArrayStorageType stype, const TShape& shape, int dtype = mshadow::kFloat32) override { + tree_sorted_key_attrs_.emplace_back(key, shape, dtype); sorted_key_attrs_.emplace_back(key, shape, dtype); + bool delay_alloc = true; } void InitBuffersAndComm(const std::vector& src) { @@ -69,7 +71,13 @@ class CommDeviceTree : public CommDevice { devs_.push_back(a.ctx()); } QueryTopology(); - InitMergeBuffer(); + // Note: delayed allocation set to true, because we do not want to allocate + // both in TreeBufferEntry and BufferEntry, so we use a size_t to keep + // track of each key's shape within BufferEntry + // -this information is required for inherited Reduce- and + // BroadcastRowSparse + InitMergeBuffer(devs_); + InitMergeBufferTree(); if (dmlc::GetEnv("MXNET_ENABLE_GPU_P2P", 1)) { EnableP2P(); } @@ -362,7 +370,7 @@ class CommDeviceTree : public CommDevice { using KeyAttrs = std::tuple; // try to allocate buff on device evenly - void InitMergeBuffer() { + void InitMergeBufferTree() { LOG(WARNING) << "Using Tree"; // same as all-reduce, except: @@ 
-372,12 +380,13 @@ class CommDeviceTree : public CommDevice { for (unsigned i = 0; i < devs_.size(); ++i) tree_merge_buf_.push_back(std::unordered_map()); + bool delay_alloc = true; std::map key_dist; - for (size_t i = 0; i < sorted_key_attrs_.size(); ++i) { - const int key = std::get<0>(sorted_key_attrs_[i]); - const TShape& shape = std::get<1>(sorted_key_attrs_[i]); - const int type = std::get<2>(sorted_key_attrs_[i]); + for (size_t i = 0; i < tree_sorted_key_attrs_.size(); ++i) { + const int key = std::get<0>(tree_sorted_key_attrs_[i]); + const TShape& shape = std::get<1>(tree_sorted_key_attrs_[i]); + const int type = std::get<2>(tree_sorted_key_attrs_[i]); if (key_dist.find(shape.Size()) == key_dist.end()) key_dist[shape.Size()] = 1; @@ -414,13 +423,14 @@ class CommDeviceTree : public CommDevice { for (unsigned row = 0; row < devs_.size(); ++row) { if (row == devs_.size()-1) shape_copy[0] = last_slice; - buf.merged[row] = NDArray(shape_copy, ctx, false, type); + buf.merged[row] = NDArray(shape_copy, ctx, delay_alloc, type); buf.copy_buf.push_back(std::vector()); if (buf.copy_buf[row].empty()) { buf.copy_buf[row].resize(kBranch-1); for (size_t col = 0; col < buf.copy_buf[0].size(); ++col) { buf.copy_buf[row][col] = NDArray(buf.merged[row].shape(), - buf.merged[row].ctx(), false, + buf.merged[row].ctx(), + delay_alloc, buf.merged[row].dtype()); } } @@ -432,7 +442,7 @@ class CommDeviceTree : public CommDevice { buf.copy_buf[0].resize(kBranch-1); for (size_t col = 0; col < buf.copy_buf[0].size(); ++col) { buf.copy_buf[0][col] = NDArray(buf.merged[0].shape(), - buf.merged[0].ctx(), false, + buf.merged[0].ctx(), delay_alloc, buf.merged[0].dtype()); } } @@ -447,7 +457,7 @@ class CommDeviceTree : public CommDevice { inited_ = true; } - std::vector sorted_key_attrs_; + std::vector tree_sorted_key_attrs_; /// \brief temporal space for pushing and pulling struct TreeBufferEntry { /// \brief the dense merged value for reduce and broadcast operations @@ -461,22 +471,6 @@ 
class CommDeviceTree : public CommDevice { /// \brief the small buffer for compressed data in receiver std::vector compressed_recv_buf; - /// \brief the merged buffer for the given storage type (could be either dense or row_sparse) - inline NDArray& merged_buf(NDArrayStorageType stype) { - if (stype == kDefaultStorage) { - CHECK(merged.size() > 0 && !merged[0].is_none()) << "unintialized merge buffer detected"; - return merged[0]; - } - CHECK(stype == kRowSparseStorage) << "unexpected storage type " << stype; - // check if sparse_merged is initialized - if (sparse_merged.is_none()) { - CHECK(merged.size() > 0 && !merged[0].is_none()); - sparse_merged = NDArray(kRowSparseStorage, merged[0].shape(), - merged[0].ctx(), true, merged[0].dtype()); - } - return sparse_merged; - } - private: /// \brief the sparse merged value for reduce and rowsparse broadcast operations NDArray sparse_merged; diff --git a/src/kvstore/gpu_topology.h b/src/kvstore/gpu_topology.h index 9808f39c848b..87bdb26cd867 100644 --- a/src/kvstore/gpu_topology.h +++ b/src/kvstore/gpu_topology.h @@ -35,13 +35,13 @@ #include #include -#define MAX_DEPTH 16 +#define MXNET_KVSTORE_MAXDEPTH 16 namespace mxnet { namespace kvstore { template -void PrintVector(const std::string& str, const std::vector& vec) { +inline void PrintVector(const std::string& str, const std::vector& vec) { std::cout << str << ":\n"; for (unsigned i = 0; i < vec.size(); ++i) std::cout << vec[i] << " "; @@ -49,7 +49,7 @@ void PrintVector(const std::string& str, const std::vector& vec) { } template -void PrintMatrix(const std::string& str, const std::vector& matrix, +inline void PrintMatrix(const std::string& str, const std::vector& matrix, int num_rows, int num_cols) { std::cout << str << ":\n"; int count = 0; @@ -61,7 +61,7 @@ void PrintMatrix(const std::string& str, const std::vector& matrix, } } -void PrintTopo(const std::string& str, const std::vector& topo_row, +inline void PrintTopo(const std::string& str, const std::vector& 
topo_row, std::vector scan_row) { PrintVector("Topo vector", topo_row); PrintVector("Scan vector", scan_row); @@ -89,7 +89,7 @@ void PrintTopo(const std::string& str, const std::vector& topo_row, // 2: 1 NVLink connection // 3: 2 NVLink connections template -void GetP2PWeight(const std::vector& devs, +inline void GetP2PWeight(const std::vector& devs, std::vector* matrix) { int num_gpus = devs.size(); int count = 0; @@ -140,7 +140,7 @@ void GetP2PWeight(const std::vector& devs, // Assume: matrix is square // y = A*x (no accumulate) template -void gemv(const std::vector& A, +inline void gemv(const std::vector& A, const std::vector& x, std::vector* y) { int nrows = x.size(); @@ -157,7 +157,7 @@ void gemv(const std::vector& A, // Element-wise multiplication between 2 dense vectors // w = w * alpha*u template -void ewisemult(const std::vector& u, +inline void ewisemult(const std::vector& u, T alpha, std::vector* w) { int nelem = u.size(); @@ -169,7 +169,7 @@ void ewisemult(const std::vector& u, // Element-wise addition between 2 dense vectors // w = w + alpha*u template -void ewiseadd(const std::vector& u, +inline void ewiseadd(const std::vector& u, T alpha, std::vector* w) { int nelem = u.size(); @@ -184,7 +184,7 @@ void ewiseadd(const std::vector& u, // Optimization: Only need to look at upper triangular since weight matrix is // symmetric template -void FindBestMove(const std::vector& W, +inline void FindBestMove(const std::vector& W, const std::vector& P_temp, const std::vector& D, const std::unordered_set& used, @@ -217,7 +217,7 @@ void FindBestMove(const std::vector& W, // cluster_pairs stores the mapping that tells us which 2 clusters are // the output of partitioning one large cluster template -bool KernighanLin(const std::vector& W, +inline bool KernighanLin(const std::vector& W, std::vector* P, int* num_partitions, std::vector>* cluster_pairs, @@ -376,7 +376,7 @@ bool KernighanLin(const std::vector& W, // Returns root of a given color if found in roots // 
Returns -1 if it is not found -int GetRoot(const std::vector& P, +inline int GetRoot(const std::vector& P, int color, const std::unordered_set& roots) { for (auto root : roots) { @@ -388,7 +388,7 @@ int GetRoot(const std::vector& P, // Returns root of a given color if found in roots // Returns -1 if it is not found -int GetChild(const std::vector& P, +inline int GetChild(const std::vector& P, int color, int parent) { for (unsigned i = 0; i < P.size(); ++i) { @@ -408,7 +408,7 @@ int GetChild(const std::vector& P, // g is weight of edge // Optimization: Only need to look at row a in matrix template -void FindBestEdge(const std::vector& W, +inline void FindBestEdge(const std::vector& W, const std::vector& P, int parent, int dest_cluster, @@ -443,7 +443,7 @@ void FindBestEdge(const std::vector& W, // topo_row says where new edges are appended to // scan_row says where we should start looking for topo_row template -int KLGenerateBinaryTree(const std::vector& W, +inline int KLGenerateBinaryTree(const std::vector& W, const std::vector& P, std::vector>* cluster_pairs, std::unordered_set* roots, @@ -529,8 +529,8 @@ int KLGenerateBinaryTree(const std::vector& W, // @input: n is the number of nodes in a balanced binary tree // @output: returns how many levels of binary tree there are -int ComputeDepth(int n) { - for (int depth = 0; depth < MAX_DEPTH; ++depth) { +inline int ComputeDepth(int n) { + for (int depth = 0; depth < MXNET_KVSTORE_MAXDEPTH; ++depth) { int num = 2 << depth; if (n <= num) return depth+1; @@ -544,7 +544,7 @@ int ComputeDepth(int n) { // -each edge in tree corresponds to link in network topology // -each edge in tree does not form self-loop template -bool IsValid(const std::vector& W, +inline bool IsValid(const std::vector& W, const std::vector& state, int num_elements, int row, @@ -626,7 +626,7 @@ bool IsValid(const std::vector& W, // 3 1 // 3 0 1 5 // 3 3 0 4 1 2 5 6 // GPU3 knows not to make redundant send to itself -void Postprocess(std::vector* 
result, int num_elements, int depth) { +inline void Postprocess(std::vector* result, int num_elements, int depth) { for (int level = depth - 1; level >= 0; --level) { int stride = 1 << level; std::vector histogram_above(num_elements, 0); @@ -657,7 +657,7 @@ void Postprocess(std::vector* result, int num_elements, int depth) { // -usually turned on when backtracking to get better solutions // -usually turned off when outside the penalty to get weight of tree template -T ComputeTreeWeight(const std::vector& W, +inline T ComputeTreeWeight(const std::vector& W, const std::vector& result, int num_elements, int depth, @@ -710,7 +710,7 @@ T ComputeTreeWeight(const std::vector& W, // 3 1 // 3 0 1 5 // 3 3 0 4 1 2 5 6 -void FormTopology(const std::vector& result, +inline void FormTopology(const std::vector& result, std::vector* topo_row, std::vector* scan_row, int depth) { @@ -735,7 +735,7 @@ void FormTopology(const std::vector& result, // -binary // -maximum weight template -bool RecursiveBacktrack(const std::vector& W, +inline bool RecursiveBacktrack(const std::vector& W, std::vector* state, std::vector* best_result, T* best_result_weight, @@ -771,7 +771,7 @@ bool RecursiveBacktrack(const std::vector& W, } template -void IterativeBacktrack(const std::vector& W, +inline void IterativeBacktrack(const std::vector& W, std::vector* state, std::vector* best_result, T* best_result_weight, @@ -837,7 +837,7 @@ void IterativeBacktrack(const std::vector& W, // Apply penalty factor alpha to each link in link topology graph that is used // by the spanning tree template -void UpdateWeight(std::vector* W, +inline void UpdateWeight(std::vector* W, const std::vector& topo_row, int num_elements, float alpha) { @@ -860,7 +860,7 @@ void UpdateWeight(std::vector* W, // 2) maximize edge weight // 3) tree is binary template -void BacktrackGenerateBinaryTree(std::vector* W, +inline void BacktrackGenerateBinaryTree(std::vector* W, int num_elements, int root, std::vector* topo_row, @@ -903,7 +903,7 
@@ void BacktrackGenerateBinaryTree(std::vector* W, // ComputeTreesFromRoot does the same thing as ComputeTrees, with the only // exception being it will do it from a fixed GPU as root template -void ComputeTreesFromRoot(std::vector* W, +inline void ComputeTreesFromRoot(std::vector* W, int num_elements, int root, float alpha, @@ -990,7 +990,7 @@ void ComputeTreesFromRoot(std::vector* W, // @output: topo stores the trees generated // scan stores the start of each level of each tree template -void ComputeTrees(const std::vector& W, +inline void ComputeTrees(const std::vector& W, int num_elements, float alpha, bool backtrack, diff --git a/src/kvstore/kvstore_local.h b/src/kvstore/kvstore_local.h index 89550ed49b66..3d85f70ca8c8 100644 --- a/src/kvstore/kvstore_local.h +++ b/src/kvstore/kvstore_local.h @@ -34,7 +34,9 @@ #include #include #include "./comm.h" -#include "./comm_tree.h" +#if MXNET_USE_CUDA + #include "./comm_tree.h" +#endif #include "./kvstore_utils.h" #include "../ndarray/ndarray_function.h" @@ -57,9 +59,11 @@ class KVStoreLocal : public KVStore { */ explicit KVStoreLocal(bool use_device_comm) : KVStore() { if (use_device_comm) { - bool tree = dmlc::GetEnv("MXNET_KVSTORE_USETREE", 0); + bool tree = dmlc::GetEnv("MXNET_KVSTORE_USETREE", 0) & MXNET_USE_CUDA; if (tree) { +#if MXNET_USE_CUDA comm_ = new CommDeviceTree(); +#endif } else { comm_ = new CommDevice(); } diff --git a/tests/python/gpu/test_kvstore_gpu.py b/tests/python/gpu/test_kvstore_gpu.py index 76231fbe90ee..2435bc784ee3 100644 --- a/tests/python/gpu/test_kvstore_gpu.py +++ b/tests/python/gpu/test_kvstore_gpu.py @@ -21,6 +21,7 @@ import mxnet as mx import numpy as np import unittest +import logging from mxnet.test_utils import assert_almost_equal, default_context curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__))) sys.path.insert(0, os.path.join(curr_path, '../unittest')) @@ -88,34 +89,48 @@ def check_rsp_pull(kv, count, ctxs, is_same_rowid=False, use_slice=False): # test 
fails intermittently. temporarily disabled till it gets fixed. tracked at https://github.com/apache/incubator-mxnet/issues/9384 # check_rsp_push_pull('local') + os.environ["MXNET_KVSTORE_USETREE"] = "" + check_rsp_push_pull('device') + check_rsp_push_pull('device', is_push_cpu=False) + os.environ["MXNET_KVSTORE_USETREE"] = "1" + logging.info("Setting env to use tree reduce...") check_rsp_push_pull('device') check_rsp_push_pull('device', is_push_cpu=False) def test_row_sparse_pull_single_device(): - kvstore = mx.kv.create('device') - copy = mx.nd.random_normal(shape=(4,4), ctx=mx.gpu(0)) - grad = copy.tostype("row_sparse") + envs = ["","1"] + for env in envs: + os.environ["MXNET_KVSTORE_USETREE"] = env - key = 0 - kvstore.init(key, grad) - idx = grad.indices - kvstore.push(key, grad) - kvstore.row_sparse_pull(key, out=grad, row_ids=idx) + kvstore = mx.kv.create('device') + copy = mx.nd.random_normal(shape=(4,4), ctx=mx.gpu(0)) + grad = copy.tostype("row_sparse") - assert_almost_equal(grad.asnumpy(), copy.asnumpy()) + key = 0 + kvstore.init(key, grad) + idx = grad.indices + kvstore.push(key, grad) + kvstore.row_sparse_pull(key, out=grad, row_ids=idx) + assert_almost_equal(grad.asnumpy(), copy.asnumpy()) -def test_rsp_push_pull_large_rowid(): - num_rows = 793470 - val = mx.nd.ones((num_rows, 1)).tostype('row_sparse').copyto(mx.gpu()) - kv = mx.kv.create('device') - kv.init('a', val) - out = mx.nd.zeros((num_rows,1), stype='row_sparse').copyto(mx.gpu()) - kv.push('a', val) - kv.row_sparse_pull('a', out=out, row_ids=mx.nd.arange(0, num_rows, dtype='int64')) - assert(out.indices.shape[0] == num_rows) +def test_rsp_push_pull_large_rowid(): + envs = ["","1"] + for env in envs: + os.environ["MXNET_KVSTORE_USETREE"] = env + + num_rows = 793470 + val = mx.nd.ones((num_rows, 1)).tostype('row_sparse').copyto(mx.gpu()) + kv = mx.kv.create('device') + kv.init('a', val) + out = mx.nd.zeros((num_rows,1), stype='row_sparse').copyto(mx.gpu()) + kv.push('a', val) + 
kv.row_sparse_pull('a', out=out, row_ids=mx.nd.arange(0, num_rows, dtype='int64')) + assert(out.indices.shape[0] == num_rows) + + os.environ["MXNET_KVSTORE_USETREE"] = "" if __name__ == '__main__': import nose nose.runmodule() From 6627dcfdec13dbbd853991be478d25f90d0e582f Mon Sep 17 00:00:00 2001 From: Carl Yang Date: Wed, 27 Jun 2018 00:28:44 +0000 Subject: [PATCH 19/36] fix lint errors --- src/kvstore/comm.h | 1 - src/kvstore/comm_tree.h | 6 +++--- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/src/kvstore/comm.h b/src/kvstore/comm.h index 0b139f9a7c21..cbbd1e90e970 100644 --- a/src/kvstore/comm.h +++ b/src/kvstore/comm.h @@ -786,7 +786,6 @@ class CommDevice : public Comm { NDArray sparse_merged; }; std::unordered_map merge_buf_; - public: bool inited_; diff --git a/src/kvstore/comm_tree.h b/src/kvstore/comm_tree.h index b44d84401edc..59d6a0ad405e 100644 --- a/src/kvstore/comm_tree.h +++ b/src/kvstore/comm_tree.h @@ -72,9 +72,9 @@ class CommDeviceTree : public CommDevice { } QueryTopology(); // Note: delayed allocation set to true, because we do not want to allocate - // both in TreeBufferEntry and BufferEntry, so we use a size_t to keep + // both in TreeBufferEntry and BufferEntry, so we use a size_t to keep // track of each key's shape within BufferEntry - // -this information is required for inherited Reduce- and + // -this information is required for inherited Reduce- and // BroadcastRowSparse InitMergeBuffer(devs_); InitMergeBufferTree(); @@ -429,7 +429,7 @@ class CommDeviceTree : public CommDevice { buf.copy_buf[row].resize(kBranch-1); for (size_t col = 0; col < buf.copy_buf[0].size(); ++col) { buf.copy_buf[row][col] = NDArray(buf.merged[row].shape(), - buf.merged[row].ctx(), + buf.merged[row].ctx(), delay_alloc, buf.merged[row].dtype()); } From 4de89a75de6e89a925c1f3a88f3e2a4c8d6db35b Mon Sep 17 00:00:00 2001 From: Carl Yang Date: Tue, 26 Jun 2018 23:35:54 -0700 Subject: [PATCH 20/36] better header guard that works for tests --- 
src/kvstore/gpu_topology.h | 10 ++++++++-- src/kvstore/kvstore_local.h | 6 +----- 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/src/kvstore/gpu_topology.h b/src/kvstore/gpu_topology.h index 87bdb26cd867..f029b9d0afef 100644 --- a/src/kvstore/gpu_topology.h +++ b/src/kvstore/gpu_topology.h @@ -22,8 +22,10 @@ */ #ifndef MXNET_KVSTORE_GPU_TOPOLOGY_H_ #define MXNET_KVSTORE_GPU_TOPOLOGY_H_ -#include -#include +#if MXNET_USE_CUDA + #include + #include +#endif #include #include #include @@ -99,6 +101,7 @@ inline void GetP2PWeight(const std::vector& devs, count++; } +#if MXNET_USE_CUDA cudaDeviceP2PAttr attr; attr = cudaDevP2PAttrPerformanceRank; std::vector max(num_gpus, 0); @@ -134,6 +137,9 @@ inline void GetP2PWeight(const std::vector& devs, } } PrintMatrix("Weight W", *matrix, num_gpus, num_gpus); +#else + LOG(WARNING) << "GPU required for link topology"; +#endif } // Dense matrix-vector multiplication diff --git a/src/kvstore/kvstore_local.h b/src/kvstore/kvstore_local.h index 3d85f70ca8c8..791ad3362010 100644 --- a/src/kvstore/kvstore_local.h +++ b/src/kvstore/kvstore_local.h @@ -34,9 +34,7 @@ #include #include #include "./comm.h" -#if MXNET_USE_CUDA - #include "./comm_tree.h" -#endif +#include "./comm_tree.h" #include "./kvstore_utils.h" #include "../ndarray/ndarray_function.h" @@ -61,9 +59,7 @@ class KVStoreLocal : public KVStore { if (use_device_comm) { bool tree = dmlc::GetEnv("MXNET_KVSTORE_USETREE", 0) & MXNET_USE_CUDA; if (tree) { -#if MXNET_USE_CUDA comm_ = new CommDeviceTree(); -#endif } else { comm_ = new CommDevice(); } From 317c66bfb24d11ba18faf0f86dfd0f5b8d7ae8b7 Mon Sep 17 00:00:00 2001 From: Carl Yang Date: Tue, 26 Jun 2018 23:44:14 -0700 Subject: [PATCH 21/36] get rid of unused variable warning --- src/kvstore/comm_tree.h | 1 - 1 file changed, 1 deletion(-) diff --git a/src/kvstore/comm_tree.h b/src/kvstore/comm_tree.h index 59d6a0ad405e..f2cf4861ca2b 100644 --- a/src/kvstore/comm_tree.h +++ b/src/kvstore/comm_tree.h @@ -62,7 +62,6 @@ 
class CommDeviceTree : public CommDevice { int dtype = mshadow::kFloat32) override { tree_sorted_key_attrs_.emplace_back(key, shape, dtype); sorted_key_attrs_.emplace_back(key, shape, dtype); - bool delay_alloc = true; } void InitBuffersAndComm(const std::vector& src) { From c364fd3b1672e7f5f0fd91a399c585aae22b8a27 Mon Sep 17 00:00:00 2001 From: Carl Yang Date: Thu, 28 Jun 2018 23:54:45 +0000 Subject: [PATCH 22/36] retrigger jenkins --- src/kvstore/gpu_topology.h | 1 - 1 file changed, 1 deletion(-) diff --git a/src/kvstore/gpu_topology.h b/src/kvstore/gpu_topology.h index f029b9d0afef..311d4f5f9453 100644 --- a/src/kvstore/gpu_topology.h +++ b/src/kvstore/gpu_topology.h @@ -1038,7 +1038,6 @@ inline void ComputeTrees(const std::vector& W, PrintMatrix("W", W, num_elements, num_elements); PrintMatrix("Links", adj, num_elements, num_elements);*/ } - } // namespace kvstore } // namespace mxnet #endif // MXNET_KVSTORE_GPU_TOPOLOGY_H_ From 3241d71662e19b92a47cbf38b98b738bc6bfbfc3 Mon Sep 17 00:00:00 2001 From: Carl Yang Date: Fri, 29 Jun 2018 19:05:09 +0000 Subject: [PATCH 23/36] resolve 2 comments --- src/kvstore/comm.h | 2 - src/kvstore/gpu_topology.h | 150 +++++++++++++++++-------------------- 2 files changed, 69 insertions(+), 83 deletions(-) diff --git a/src/kvstore/comm.h b/src/kvstore/comm.h index cbbd1e90e970..43081dec25c4 100644 --- a/src/kvstore/comm.h +++ b/src/kvstore/comm.h @@ -762,8 +762,6 @@ class CommDevice : public Comm { std::vector compressed_send_buf; /// \brief the small buffer for compressed data in receiver std::vector compressed_recv_buf; - /// \brief size of allocation in case we do not actually allocate merged - TShape merged_size; /// \brief the merged buffer for the given storage type (could be either dense or row_sparse) inline NDArray& merged_buf(NDArrayStorageType stype) { diff --git a/src/kvstore/gpu_topology.h b/src/kvstore/gpu_topology.h index 311d4f5f9453..04ad8963d1b5 100644 --- a/src/kvstore/gpu_topology.h +++ 
b/src/kvstore/gpu_topology.h @@ -92,7 +92,7 @@ inline void PrintTopo(const std::string& str, const std::vector& topo_ro // 3: 2 NVLink connections template inline void GetP2PWeight(const std::vector& devs, - std::vector* matrix) { + std::vector* matrix) { int num_gpus = devs.size(); int count = 0; std::vector zero_dev_id(num_gpus, -1); @@ -147,8 +147,8 @@ inline void GetP2PWeight(const std::vector& devs, // y = A*x (no accumulate) template inline void gemv(const std::vector& A, - const std::vector& x, - std::vector* y) { + const std::vector& x, + std::vector* y) { int nrows = x.size(); int count = 0; for (int row=0; row < nrows; ++row) { @@ -164,26 +164,14 @@ inline void gemv(const std::vector& A, // w = w * alpha*u template inline void ewisemult(const std::vector& u, - T alpha, - std::vector* w) { + T alpha, + std::vector* w) { int nelem = u.size(); for (int i=0; i < nelem; ++i) { (*w)[i] *= alpha*static_cast(u[i]); } } -// Element-wise addition between 2 dense vectors -// w = w + alpha*u -template -inline void ewiseadd(const std::vector& u, - T alpha, - std::vector* w) { - int nelem = u.size(); - for (int i=0; i < nelem; ++i) { - (*w)[i] += alpha*static_cast(u[i]); - } -} - // Computes best 2 nodes a,b to swap given objective function: // g = max_{a \in A, b \in B} D(a) + D(b) - 2*W(a,b) // @@ -191,8 +179,8 @@ inline void ewiseadd(const std::vector& u, // symmetric template inline void FindBestMove(const std::vector& W, - const std::vector& P_temp, - const std::vector& D, + const std::vector& P_temp, + const std::vector& D, const std::unordered_set& used, int* a, int* b, @@ -224,10 +212,10 @@ inline void FindBestMove(const std::vector& W, // the output of partitioning one large cluster template inline bool KernighanLin(const std::vector& W, - std::vector* P, - int* num_partitions, - std::vector>* cluster_pairs, - std::mt19937* gen) { + std::vector* P, + int* num_partitions, + std::vector>* cluster_pairs, + std::mt19937* gen) { std::vector 
histogram(*num_partitions, 0); std::vector P_temp(P->size(), 0); std::vector P_temp2(P->size(), 0); @@ -383,8 +371,8 @@ inline bool KernighanLin(const std::vector& W, // Returns root of a given color if found in roots // Returns -1 if it is not found inline int GetRoot(const std::vector& P, - int color, - const std::unordered_set& roots) { + int color, + const std::unordered_set& roots) { for (auto root : roots) { if (P[root] == color) return root; @@ -395,8 +383,8 @@ inline int GetRoot(const std::vector& P, // Returns root of a given color if found in roots // Returns -1 if it is not found inline int GetChild(const std::vector& P, - int color, - int parent) { + int color, + int parent) { for (unsigned i = 0; i < P.size(); ++i) { if (P[i] == color && static_cast(i) != parent) return i; @@ -415,11 +403,11 @@ inline int GetChild(const std::vector& P, // Optimization: Only need to look at row a in matrix template inline void FindBestEdge(const std::vector& W, - const std::vector& P, - int parent, - int dest_cluster, - std::vector* b, - T* g) { + const std::vector& P, + int parent, + int dest_cluster, + std::vector* b, + T* g) { int nrows = P.size(); int row = parent; *g = 0; @@ -450,12 +438,12 @@ inline void FindBestEdge(const std::vector& W, // scan_row says where we should start looking for topo_row template inline int KLGenerateBinaryTree(const std::vector& W, - const std::vector& P, - std::vector>* cluster_pairs, - std::unordered_set* roots, - std::vector* topo_row, - std::vector* scan_row, - std::mt19937* gen) { + const std::vector& P, + std::vector>* cluster_pairs, + std::unordered_set* roots, + std::vector* topo_row, + std::vector* scan_row, + std::mt19937* gen) { std::unordered_set new_roots; std::unordered_map new_topo; int reset = 0; @@ -551,10 +539,10 @@ inline int ComputeDepth(int n) { // -each edge in tree does not form self-loop template inline bool IsValid(const std::vector& W, - const std::vector& state, - int num_elements, - int row, - int depth) { + 
const std::vector& state, + int num_elements, + int row, + int depth) { // At each level of tree, check whether edge: // -corresponds to link in network topology // -corresponds to self-loop @@ -664,10 +652,10 @@ inline void Postprocess(std::vector* result, int num_elements, int depth) { // -usually turned off when outside the penalty to get weight of tree template inline T ComputeTreeWeight(const std::vector& W, - const std::vector& result, - int num_elements, - int depth, - bool penalty) { + const std::vector& result, + int num_elements, + int depth, + bool penalty) { T weight = 0.f; std::unordered_set links_used; @@ -717,9 +705,9 @@ inline T ComputeTreeWeight(const std::vector& W, // 3 0 1 5 // 3 3 0 4 1 2 5 6 inline void FormTopology(const std::vector& result, - std::vector* topo_row, - std::vector* scan_row, - int depth) { + std::vector* topo_row, + std::vector* scan_row, + int depth) { scan_row->push_back(topo_row->size()); for (int i = depth; i > 0; --i) { int stride = 1 << i; @@ -742,13 +730,13 @@ inline void FormTopology(const std::vector& result, // -maximum weight template inline bool RecursiveBacktrack(const std::vector& W, - std::vector* state, - std::vector* best_result, - T* best_result_weight, - int row, - int num_elements, - int depth, - bool optimal) { + std::vector* state, + std::vector* best_result, + T* best_result_weight, + int row, + int num_elements, + int depth, + bool optimal) { if (row == static_cast(state->size())) { std::vector result = *state; Postprocess(&result, num_elements, depth); @@ -778,13 +766,13 @@ inline bool RecursiveBacktrack(const std::vector& W, template inline void IterativeBacktrack(const std::vector& W, - std::vector* state, - std::vector* best_result, - T* best_result_weight, - int row, - int num_elements, - int depth, - bool optimal) { + std::vector* state, + std::vector* best_result, + T* best_result_weight, + int row, + int num_elements, + int depth, + bool optimal) { std::stack state_stack; row = 1; int pos = 0; 
@@ -844,9 +832,9 @@ inline void IterativeBacktrack(const std::vector& W, // by the spanning tree template inline void UpdateWeight(std::vector* W, - const std::vector& topo_row, - int num_elements, - float alpha) { + const std::vector& topo_row, + int num_elements, + float alpha) { for (unsigned i = 1; i < topo_row.size() - 1; i += 2) { unsigned parent = topo_row[i]; unsigned child = topo_row[i+1]; @@ -867,10 +855,10 @@ inline void UpdateWeight(std::vector* W, // 3) tree is binary template inline void BacktrackGenerateBinaryTree(std::vector* W, - int num_elements, - int root, - std::vector* topo_row, - std::vector* scan_row) { + int num_elements, + int root, + std::vector* topo_row, + std::vector* scan_row) { // Clear before starting topo_row->clear(); scan_row->clear(); @@ -910,12 +898,12 @@ inline void BacktrackGenerateBinaryTree(std::vector* W, // exception being it will do it from a fixed GPU as root template inline void ComputeTreesFromRoot(std::vector* W, - int num_elements, - int root, - float alpha, - bool backtrack, - std::vector* topo, - std::vector* scan) { + int num_elements, + int root, + float alpha, + bool backtrack, + std::vector* topo, + std::vector* scan) { int num_partitions = 1; // Initialize partition array to indicate which partition each element belongs @@ -997,11 +985,11 @@ inline void ComputeTreesFromRoot(std::vector* W, // scan stores the start of each level of each tree template inline void ComputeTrees(const std::vector& W, - int num_elements, - float alpha, - bool backtrack, - std::vector>* topo, - std::vector>* scan) { + int num_elements, + float alpha, + bool backtrack, + std::vector>* topo, + std::vector>* scan) { std::vector W_copy = W; topo->clear(); From bd926bf57b0c1dc4af6d1f9c0b9f91bfa25d7871 Mon Sep 17 00:00:00 2001 From: Carl Yang Date: Mon, 2 Jul 2018 22:18:20 +0000 Subject: [PATCH 24/36] address comment using Class to do test, get rid of extraneous test, use PCI-E as fallback for GPUs that are not linked by NVLink --- 
src/kvstore/gpu_topology.h | 40 ++++++++++++- tests/cpp/kvstore/gpu_topology_test.cc | 65 ++++++++++++++++------ tests/python/gpu/test_kvstore_gpu.py | 77 +++++++++++++++----------- 3 files changed, 130 insertions(+), 52 deletions(-) diff --git a/src/kvstore/gpu_topology.h b/src/kvstore/gpu_topology.h index 04ad8963d1b5..bb466bc4f58b 100644 --- a/src/kvstore/gpu_topology.h +++ b/src/kvstore/gpu_topology.h @@ -33,6 +33,7 @@ #include #include #include +#include #include #include #include @@ -82,6 +83,39 @@ inline void PrintTopo(const std::string& str, const std::vector& topo_ro } } +// Uses BFS to find whether undirected graph is connected or not given its +// adjacency matrix +// Note: only consider matrix values > 1, because we care about whether it is +// connected using only NVLink connections +template +inline bool IsConnected(const std::vector& matrix, + int num_gpus) { + int source = 0; + std::vector visited(num_gpus, false); + std::queue work_list; + + work_list.push(source); + visited[source] = true; + while (!work_list.empty()) { + int curr = work_list.front(); + work_list.pop(); + + for (int i = 0; i < num_gpus; ++i) { + int neighbour = matrix[curr*num_gpus + i]; + if (i != curr && neighbour > 1 && visited[i] == false) { + visited[i] = true; + work_list.push(i); + } + } + } + + for (int i = 0; i < num_gpus; ++i) { + if (visited[i] == false) + return false; + } + return true; +} + // Generate adjacency matrix with row/col numbering from 0, 1, ..., n_gpu // @input: devs is a vector of GPU contexts // @output: matrix is adjacency matrix of link topology graph @@ -129,9 +163,11 @@ inline void GetP2PWeight(const std::vector& devs, max_value = max[i]; } - // If all GPUs have at least 1 NVLink connection, then we can use NVLink only + // If all GPUs are connected by NVLink, then we can use NVLink only // to communicate instead of going over PCI-E - if (max_value > 0) { + bool connected = IsConnected(*matrix, num_gpus); + + if (connected) { for (auto& matrix_value 
: *matrix) { matrix_value = (matrix_value == 1) ? 0 : matrix_value; } diff --git a/tests/cpp/kvstore/gpu_topology_test.cc b/tests/cpp/kvstore/gpu_topology_test.cc index 1aed5f568cb3..1fba05ef64cf 100644 --- a/tests/cpp/kvstore/gpu_topology_test.cc +++ b/tests/cpp/kvstore/gpu_topology_test.cc @@ -267,21 +267,6 @@ TEST(GpuTopology, TestEwisemult) { ASSERT_EQ(y[i], correct_y[i]); } -// ewiseaddTest -TEST(GpuTopology, TestEwiseadd) { - std::vector x(8, 1); - std::vector y(8, 0); - std::iota(y.begin(), y.end(), 0); - int alpha = 5; - std::vector correct_y(8, 0); - std::iota(correct_y.begin(), correct_y.end(), 5); - mxnet::kvstore::ewiseadd(x, alpha, &y); - - ASSERT_EQ(y.size(), correct_y.size()); - for (unsigned i = 0; i < y.size(); ++i) - ASSERT_EQ(y[i], correct_y[i]); -} - // FindBestMoveTest TEST(GpuTopology, TestFindBestMove) { std::vector W = {0, 2, 2, 3, 3, 1, 1, 1, @@ -496,10 +481,9 @@ TEST(GpuTopology, TestUpdateWeight) { ASSERT_EQ(W[i], correct_W[i]); } } -// Backtrack -// BacktrackGenerateBinaryTree + // ComputeTreesFromRoot -TEST(GpuTopology, TestComputeTreesFromRoot) { +TEST(GpuTopology, TestComputeTreesFromRoot1) { std::vector W = {0, 2, 2, 3, 3, 1, 1, 1, 2, 0, 3, 2, 1, 3, 1, 1, 2, 3, 0, 3, 1, 1, 2, 1, @@ -524,6 +508,51 @@ TEST(GpuTopology, TestComputeTreesFromRoot) { ASSERT_EQ(scan.size(), correct_scan_size); } +// IsConnected +// Test on graph that is "disconnected" by NVLink +TEST(GpuTopology, TestIsConnected1) { + std::vector W = {0, 0, 2, 0, + 0, 0, 0, 2, + 2, 0, 0, 0, + 0, 2, 0, 0}; + int num_gpus = 4; + + bool connected = mxnet::kvstore::IsConnected(W, num_gpus); + + bool correct_connected = false; + ASSERT_EQ(connected, correct_connected); +} + +// IsConnected +// Test on graph that is "disconnected" by NVLink +TEST(GpuTopology, TestIsConnected2) { + std::vector W = {1, 1, 2, 1, + 1, 1, 1, 2, + 2, 1, 1, 1, + 1, 2, 1, 1}; + int num_gpus = 4; + + bool connected = mxnet::kvstore::IsConnected(W, num_gpus); + + bool correct_connected = false; + 
ASSERT_EQ(connected, correct_connected); +} + +// IsConnected +// Test on graph that is "connected" by NVLink +TEST(GpuTopology, TestIsConnected3) { + std::vector W = {1, 1, 2, 2, + 1, 1, 1, 2, + 2, 1, 1, 1, + 2, 2, 1, 1}; + int num_gpus = 4; + + bool connected = mxnet::kvstore::IsConnected(W, num_gpus); + + bool correct_connected = true; + ASSERT_EQ(connected, correct_connected); +} + // ComputeTreesTest with backtracking TEST(GpuTopology, TestComputeTrees1) { std::mt19937 gen(1); diff --git a/tests/python/gpu/test_kvstore_gpu.py b/tests/python/gpu/test_kvstore_gpu.py index 2435bc784ee3..4406a50060e9 100644 --- a/tests/python/gpu/test_kvstore_gpu.py +++ b/tests/python/gpu/test_kvstore_gpu.py @@ -31,6 +31,21 @@ keys = [5, 7, 11] str_keys = ['b', 'c', 'd'] +class EnvManager: + def __init__(self, key, val): + self._key = key + self._next_val = val + self._prev_val = None + + def __enter__(self): + try: + self._prev_val = os.environ[self._key] + except KeyError: + self._prev_val = "" + os.environ[self._key] = self._next_val + + def __exit__(self, ptype, value, trace): + os.environ[self._key] = self._prev_val def init_kv_with_str(stype='default', kv_type='local'): """init kv """ @@ -89,48 +104,46 @@ def check_rsp_pull(kv, count, ctxs, is_same_rowid=False, use_slice=False): # test fails intermittently. temporarily disabled till it gets fixed.
tracked at https://github.com/apache/incubator-mxnet/issues/9384 # check_rsp_push_pull('local') - os.environ["MXNET_KVSTORE_USETREE"] = "" - check_rsp_push_pull('device') - check_rsp_push_pull('device', is_push_cpu=False) - os.environ["MXNET_KVSTORE_USETREE"] = "1" - logging.info("Setting env to use tree reduce...") - check_rsp_push_pull('device') - check_rsp_push_pull('device', is_push_cpu=False) + envs = ["","1"] + key = "MXNET_KVSTORE_USETREE" + for val in envs: + with EnvManager(key, val): + check_rsp_push_pull('device') + check_rsp_push_pull('device', is_push_cpu=False) def test_row_sparse_pull_single_device(): envs = ["","1"] - for env in envs: - os.environ["MXNET_KVSTORE_USETREE"] = env - - kvstore = mx.kv.create('device') - copy = mx.nd.random_normal(shape=(4,4), ctx=mx.gpu(0)) - grad = copy.tostype("row_sparse") + key = "MXNET_KVSTORE_USETREE" + for val in envs: + with EnvManager(key, val): + kvstore = mx.kv.create('device') + copy = mx.nd.random_normal(shape=(4,4), ctx=mx.gpu(0)) + grad = copy.tostype("row_sparse") - key = 0 - kvstore.init(key, grad) - idx = grad.indices - kvstore.push(key, grad) - kvstore.row_sparse_pull(key, out=grad, row_ids=idx) + k = 0 + kvstore.init(k, grad) + idx = grad.indices + kvstore.push(k, grad) + kvstore.row_sparse_pull(k, out=grad, row_ids=idx) - assert_almost_equal(grad.asnumpy(), copy.asnumpy()) + assert_almost_equal(grad.asnumpy(), copy.asnumpy()) def test_rsp_push_pull_large_rowid(): envs = ["","1"] - for env in envs: - os.environ["MXNET_KVSTORE_USETREE"] = env - - num_rows = 793470 - val = mx.nd.ones((num_rows, 1)).tostype('row_sparse').copyto(mx.gpu()) - kv = mx.kv.create('device') - kv.init('a', val) - out = mx.nd.zeros((num_rows,1), stype='row_sparse').copyto(mx.gpu()) - kv.push('a', val) - kv.row_sparse_pull('a', out=out, row_ids=mx.nd.arange(0, num_rows, dtype='int64')) - assert(out.indices.shape[0] == num_rows) - - os.environ["MXNET_KVSTORE_USETREE"] = "" + key = "MXNET_KVSTORE_USETREE" + for val in envs: + with 
EnvManager(key, val): + num_rows = 793470 + val = mx.nd.ones((num_rows, 1)).tostype('row_sparse').copyto(mx.gpu()) + kv = mx.kv.create('device') + kv.init('a', val) + out = mx.nd.zeros((num_rows,1), stype='row_sparse').copyto(mx.gpu()) + kv.push('a', val) + kv.row_sparse_pull('a', out=out, row_ids=mx.nd.arange(0, num_rows, dtype='int64')) + assert(out.indices.shape[0] == num_rows) + if __name__ == '__main__': import nose nose.runmodule() From a29f284e1f3dac34ff32af4aeb091b799332e910 Mon Sep 17 00:00:00 2001 From: Carl Yang Date: Fri, 13 Jul 2018 22:42:47 +0000 Subject: [PATCH 25/36] address comments --- src/kvstore/comm_tree.h | 45 +++--- src/kvstore/gpu_topology.h | 228 ++++++++++++--------------- tests/python/gpu/test_kvstore_gpu.py | 51 +++++- 3 files changed, 176 insertions(+), 148 deletions(-) diff --git a/src/kvstore/comm_tree.h b/src/kvstore/comm_tree.h index f2cf4861ca2b..c5ba97ba52f8 100644 --- a/src/kvstore/comm_tree.h +++ b/src/kvstore/comm_tree.h @@ -83,9 +83,14 @@ class CommDeviceTree : public CommDevice { } } - // src is sliced shape - // copy_buf not sliced - // merged not sliced + /** + * \brief Reduce src to tree_merge_buf_ + * \param key is the id of the gradient we are doing Reduce on + * \param src is the array of values located on different GPUs + * \param root is the id of the GPU we want to send result of reduce to + * \param merged_row is the id of the slice we are taking + * \param priority the priority of the operation + */ const NDArray& ReduceInner(int key, const std::vector& src, int root, int merged_row, int priority) { std::vector> reduce(devs_.size()); @@ -98,8 +103,8 @@ class CommDeviceTree : public CommDevice { if (stype == kDefaultStorage) { // Copy everything into buf.merged for each gpu for (size_t i = 0; i < src.size(); ++i) { - int start = scan_[root][depth_ ]; - int end = scan_[root][depth_+1]; + int start = scan_[root][depth_]; + int end = scan_[root][depth_+1]; for (int j = start; j < end; ++j) { int topo_id = topology[j]; 
@@ -113,13 +118,13 @@ class CommDeviceTree : public CommDevice { for (int level = depth_; level > 0; --level) { int start = scan_[root][level ]; - int end = scan_[root][level+1]; + int end = scan_[root][level+1]; unsigned is_dest = 0; - int dest_id = 0; + int dest_id = 0; for (int j = start; j < end; ++j) { int topo_id = topology[j]; - dest_id = (is_dest == 0) ? topo_id : dest_id; + dest_id = (is_dest == 0) ? topo_id : dest_id; TreeBufferEntry& buf_dest = tree_merge_buf_[dest_id][key]; TreeBufferEntry& buf_from = tree_merge_buf_[topo_id][key]; @@ -141,7 +146,7 @@ class CommDeviceTree : public CommDevice { } start = scan_[root][level-1]; - end = scan_[root][level ]; + end = scan_[root][level]; for (int i = start; i < end; ++i) { int gpu_id = topology[i]; @@ -158,7 +163,7 @@ class CommDeviceTree : public CommDevice { } } } else { - LOG(WARNING) << "Only dense input supported for now"; + LOG(FATAL) << "Only dense input supported for now"; } int topo_id = topology[0]; @@ -231,7 +236,7 @@ class CommDeviceTree : public CommDevice { } // Copy from list of small NDArrays to one big NDArray, which is returned - int gpu_id = 0; + int gpu_id = 0; return src[gpu_id]; } else { // sparse reduce @@ -252,13 +257,13 @@ class CommDeviceTree : public CommDevice { for (int level = 1; level <= depth_; ++level) { int start = scan_[root][level]; - int end = scan_[root][level+1]; + int end = scan_[root][level+1]; unsigned is_src = 0; - int src_id = 0; + int src_id = 0; for (int j = start; j < end; ++j) { int topo_id = topology[j]; - src_id = (is_src == 0) ? topo_id : src_id; + src_id = (is_src == 0) ? 
topo_id : src_id; if (is_src && src_id != topo_id) { CopyFromTo(temp[src_id], dst[topo_id], priority); @@ -392,8 +397,8 @@ class CommDeviceTree : public CommDevice { else key_dist[shape.Size()]++; - int start = scan_[0][depth_ ]; - int end = scan_[0][depth_+1]; + int start = scan_[0][depth_]; + int end = scan_[0][depth_+1]; // In order to generalize to any number of GPUs, we use strategy of having // found the mapping from 0, 1, ..., n_gpus to dev_id i.e. @@ -484,10 +489,10 @@ class CommDeviceTree : public CommDevice { std::vector devs_; /// \brief Highest numbered device - int max_dev_; - int depth_; - int gpuarray_bound_; - bool backtrack_; + int max_dev_; + int depth_; + int gpuarray_bound_; + bool backtrack_; float link_usage_penalty_; /// \brief constant for maximum size of recv buffer per GPU diff --git a/src/kvstore/gpu_topology.h b/src/kvstore/gpu_topology.h index bb466bc4f58b..3b67d2395082 100644 --- a/src/kvstore/gpu_topology.h +++ b/src/kvstore/gpu_topology.h @@ -43,6 +43,8 @@ namespace mxnet { namespace kvstore { +static bool kLogTree = dmlc::GetEnv("MXNET_KVSTORE_LOGTREE", false); + template inline void PrintVector(const std::string& str, const std::vector& vec) { std::cout << str << ":\n"; @@ -72,7 +74,7 @@ inline void PrintTopo(const std::string& str, const std::vector& topo_ro int depth = scan_row.size()-1; for (int row = 0; row < depth; ++row) { int start = scan_row[row]; - int end = scan_row[row+1]; + int end = scan_row[row+1]; for (; start < end; start++) { for (int i = 0; i < (2 << (depth-row-2))+1; ++i) { std::cout << " "; @@ -88,8 +90,7 @@ inline void PrintTopo(const std::string& str, const std::vector& topo_ro // Note: only consider matrix values > 1, because we care about whether it is // connected using only NVLink connections template -inline bool IsConnected(const std::vector& matrix, - int num_gpus) { +inline bool IsConnected(const std::vector& matrix, int num_gpus) { int source = 0; std::vector visited(num_gpus, false); std::queue 
work_list; @@ -125,8 +126,7 @@ inline bool IsConnected(const std::vector& matrix, // 2: 1 NVLink connection // 3: 2 NVLink connections template -inline void GetP2PWeight(const std::vector& devs, - std::vector* matrix) { +inline void GetP2PWeight(const std::vector& devs, std::vector* matrix) { int num_gpus = devs.size(); int count = 0; std::vector zero_dev_id(num_gpus, -1); @@ -172,7 +172,8 @@ inline void GetP2PWeight(const std::vector& devs, matrix_value = (matrix_value == 1) ? 0 : matrix_value; } } - PrintMatrix("Weight W", *matrix, num_gpus, num_gpus); + if (kLogTree) + PrintMatrix("Weight W", *matrix, num_gpus, num_gpus); #else LOG(WARNING) << "GPU required for link topology"; #endif @@ -182,9 +183,8 @@ inline void GetP2PWeight(const std::vector& devs, // Assume: matrix is square // y = A*x (no accumulate) template -inline void gemv(const std::vector& A, - const std::vector& x, - std::vector* y) { +inline void gemv(const std::vector& A, const std::vector& x, + std::vector* y) { int nrows = x.size(); int count = 0; for (int row=0; row < nrows; ++row) { @@ -199,9 +199,7 @@ inline void gemv(const std::vector& A, // Element-wise multiplication between 2 dense vectors // w = w * alpha*u template -inline void ewisemult(const std::vector& u, - T alpha, - std::vector* w) { +inline void ewisemult(const std::vector& u, T alpha, std::vector* w) { int nelem = u.size(); for (int i=0; i < nelem; ++i) { (*w)[i] *= alpha*static_cast(u[i]); @@ -214,13 +212,11 @@ inline void ewisemult(const std::vector& u, // Optimization: Only need to look at upper triangular since weight matrix is // symmetric template -inline void FindBestMove(const std::vector& W, - const std::vector& P_temp, - const std::vector& D, - const std::unordered_set& used, - int* a, - int* b, - T* g) { +inline void FindBestMove(const std::vector& W, + const std::vector& P_temp, + const std::vector& D, + const std::unordered_set& used, + int* a, int* b, T* g) { int nrows = P_temp.size(); *g = 0; *a = -1; @@ -247,16 
+243,15 @@ inline void FindBestMove(const std::vector& W, // cluster_pairs stores the mapping that tells us which 2 clusters are // the output of partitioning one large cluster template -inline bool KernighanLin(const std::vector& W, - std::vector* P, - int* num_partitions, +inline bool KernighanLin(const std::vector& W, std::vector* P, + int* num_partitions, std::vector>* cluster_pairs, - std::mt19937* gen) { + std::mt19937* gen) { std::vector histogram(*num_partitions, 0); std::vector P_temp(P->size(), 0); std::vector P_temp2(P->size(), 0); - std::vector D(P->size(), 0); - std::vector D_temp(P->size(), 0); + std::vector D(P->size(), 0); + std::vector D_temp(P->size(), 0); // 0) For every partition, determine if it can be partitioned further. // To do this, we must do a histogram of each partition: @@ -264,7 +259,7 @@ inline bool KernighanLin(const std::vector& W, histogram[(*P)[i]]++; } - bool stop = true; + bool stop = true; for (unsigned color=0; color < histogram.size(); ++color) { int partition_size = histogram[color]; // Save cluster in preparation for push to topo in GenerateBinaryTree() @@ -309,8 +304,8 @@ inline bool KernighanLin(const std::vector& W, } // 2) Do iterations of Kernighan-Lin until convergence - T g_max = 0; - int g_k = -1; + T g_max = 0; + int g_k = -1; unsigned count = 0; do { count++; @@ -325,14 +320,14 @@ inline bool KernighanLin(const std::vector& W, // gv stores the score associated with move std::vector av; std::vector bv; - std::vector gv; + std::vector gv; std::unordered_set used; for (int iter=0; iter < partition_size/2; ++iter) { // b) Find best move by looking through upper triangular of W matrix int a, b; - T g; + T g; FindBestMove(W, P_temp, D, used, &a, &b, &g); if (g > 0) { } else { @@ -365,7 +360,7 @@ inline bool KernighanLin(const std::vector& W, gv[k] += gv[k-1]; if (gv[k] > g_max) { g_max = gv[k]; - g_k = k + 1; + g_k = k + 1; } } @@ -373,13 +368,13 @@ inline bool KernighanLin(const std::vector& W, // Otherwise, rollback 
changes to P_temp2 if (g_max > 0) { for (int i = 0; i < g_k; i++) { - int a = av[i]; - int b = bv[i]; - int temp = P_temp2[a]; + int a = av[i]; + int b = bv[i]; + int temp = P_temp2[a]; P_temp2[a] = P_temp2[b]; P_temp2[b] = temp; - P_temp = P_temp2; + P_temp = P_temp2; } } else { P_temp = P_temp2; @@ -406,8 +401,7 @@ inline bool KernighanLin(const std::vector& W, // Returns root of a given color if found in roots // Returns -1 if it is not found -inline int GetRoot(const std::vector& P, - int color, +inline int GetRoot(const std::vector& P, int color, const std::unordered_set& roots) { for (auto root : roots) { if (P[root] == color) @@ -418,9 +412,7 @@ inline int GetRoot(const std::vector& P, // Returns root of a given color if found in roots // Returns -1 if it is not found -inline int GetChild(const std::vector& P, - int color, - int parent) { +inline int GetChild(const std::vector& P, int color, int parent) { for (unsigned i = 0; i < P.size(); ++i) { if (P[i] == color && static_cast(i) != parent) return i; @@ -438,15 +430,11 @@ inline int GetChild(const std::vector& P, // g is weight of edge // Optimization: Only need to look at row a in matrix template -inline void FindBestEdge(const std::vector& W, - const std::vector& P, - int parent, - int dest_cluster, - std::vector* b, - T* g) { +inline void FindBestEdge(const std::vector& W, const std::vector& P, + int parent, int dest_cluster, std::vector* b, T* g) { int nrows = P.size(); - int row = parent; - *g = 0; + int row = parent; + *g = 0; b->push_back(-1); for (int col=0; col < nrows; ++col) { if (col == row || P[col] != dest_cluster) continue; @@ -473,14 +461,14 @@ inline void FindBestEdge(const std::vector& W, // topo_row says where new edges are appended to // scan_row says where we should start looking for topo_row template -inline int KLGenerateBinaryTree(const std::vector& W, - const std::vector& P, +inline int KLGenerateBinaryTree(const std::vector& W, + const std::vector& P, std::vector>* cluster_pairs, 
- std::unordered_set* roots, - std::vector* topo_row, - std::vector* scan_row, - std::mt19937* gen) { - std::unordered_set new_roots; + std::unordered_set* roots, + std::vector* topo_row, + std::vector* scan_row, + std::mt19937* gen) { + std::unordered_set new_roots; std::unordered_map new_topo; int reset = 0; @@ -491,20 +479,20 @@ inline int KLGenerateBinaryTree(const std::vector& W, if ((*cluster_pairs)[i].second == -2) { // Root must be color of pair.first int color = (*cluster_pairs)[i].first; - parent = GetRoot(P, color, *roots); + parent = GetRoot(P, color, *roots); if (parent == -1) return 1; - child = GetChild(P, color, parent); + child = GetChild(P, color, parent); } else if ((*cluster_pairs)[i].second == -1) { - int color = (*cluster_pairs)[i].first; - parent = GetRoot(P, color, *roots); + int color = (*cluster_pairs)[i].first; + parent = GetRoot(P, color, *roots); if (parent == -1) return 1; - child = parent; + child = parent; } else { // Root must exist in either first or second element of pair - int color = (*cluster_pairs)[i].first; - parent = GetRoot(P, color, *roots); - color = (parent == -1) ? (*cluster_pairs)[i].second : color; - parent = (parent == -1) ? GetRoot(P, color, *roots) : parent; + int color = (*cluster_pairs)[i].first; + parent = GetRoot(P, color, *roots); + color = (parent == -1) ? (*cluster_pairs)[i].second : color; + parent = (parent == -1) ? GetRoot(P, color, *roots) : parent; int from_cluster = color; int dest_cluster = (from_cluster == (*cluster_pairs)[i].first) ? 
@@ -534,7 +522,7 @@ inline int KLGenerateBinaryTree(const std::vector& W, int depth = scan_row->size(); int start = (*scan_row)[depth-2]; - int end = (*scan_row)[depth-1]; + int end = (*scan_row)[depth-1]; for (int i = start; i < end; ++i) { int parent = (*topo_row)[i]; @@ -574,19 +562,16 @@ inline int ComputeDepth(int n) { // -each edge in tree corresponds to link in network topology // -each edge in tree does not form self-loop template -inline bool IsValid(const std::vector& W, - const std::vector& state, - int num_elements, - int row, - int depth) { +inline bool IsValid(const std::vector& W, const std::vector& state, + int num_elements, int row, int depth) { // At each level of tree, check whether edge: // -corresponds to link in network topology // -corresponds to self-loop for (int i = 0; i < depth; ++i) { int stride = 1 << i; for (int j = 0; j+stride < row; j += 2*stride) { - int from = state[j]; - int dest = state[j+stride]; + int from = state[j]; + int dest = state[j+stride]; if (W[from*num_elements + dest] == static_cast(0) && from != dest) { return false; } @@ -596,7 +581,7 @@ inline bool IsValid(const std::vector& W, // If we encounter GPU for first time, increment found_vec. // Otherwise, do nothing std::unordered_set found; - std::vector found_vec(num_elements, 0); + std::vector found_vec(num_elements, 0); for (auto val : state) { if (val == -1) continue; @@ -613,7 +598,7 @@ inline bool IsValid(const std::vector& W, // modifier is maximum number of repeats a single GPU can take // e.g. 
5 GPUs in 3-level binary tree => one GPU can repeat 3x // GPU0 GPU0 GPU0 GPU0 GPU1 GPU2 GPU3 GPU4 - int modifier = (1 << depth) - num_elements; + int modifier = (1 << depth) - num_elements; int num_found = found.size(); // So we know we have an invalid state if we find: @@ -687,11 +672,8 @@ inline void Postprocess(std::vector* result, int num_elements, int depth) { // -usually turned on when backtracking to get better solutions // -usually turned off when outside the penalty to get weight of tree template -inline T ComputeTreeWeight(const std::vector& W, - const std::vector& result, - int num_elements, - int depth, - bool penalty) { +inline T ComputeTreeWeight(const std::vector& W, const std::vector& result, + int num_elements, int depth, bool penalty) { T weight = 0.f; std::unordered_set links_used; @@ -699,8 +681,8 @@ inline T ComputeTreeWeight(const std::vector& W, int stride = 1 << i; std::vector nodes_used(num_elements, false); for (unsigned j = 0; j+stride < result.size(); j += 2*stride) { - int from = result[j]; - int dest = result[j+stride]; + int from = result[j]; + int dest = result[j+stride]; if (from != dest) { weight += W[from*num_elements+dest]; @@ -741,9 +723,9 @@ inline T ComputeTreeWeight(const std::vector& W, // 3 0 1 5 // 3 3 0 4 1 2 5 6 inline void FormTopology(const std::vector& result, - std::vector* topo_row, - std::vector* scan_row, - int depth) { + std::vector* topo_row, + std::vector* scan_row, + int depth) { scan_row->push_back(topo_row->size()); for (int i = depth; i > 0; --i) { int stride = 1 << i; @@ -766,13 +748,13 @@ inline void FormTopology(const std::vector& result, // -maximum weight template inline bool RecursiveBacktrack(const std::vector& W, - std::vector* state, - std::vector* best_result, - T* best_result_weight, - int row, - int num_elements, - int depth, - bool optimal) { + std::vector* state, + std::vector* best_result, + T* best_result_weight, + int row, + int num_elements, + int depth, + bool optimal) { if (row == 
static_cast(state->size())) { std::vector result = *state; Postprocess(&result, num_elements, depth); @@ -802,13 +784,13 @@ inline bool RecursiveBacktrack(const std::vector& W, template inline void IterativeBacktrack(const std::vector& W, - std::vector* state, - std::vector* best_result, - T* best_result_weight, - int row, - int num_elements, - int depth, - bool optimal) { + std::vector* state, + std::vector* best_result, + T* best_result_weight, + int row, + int num_elements, + int depth, + bool optimal) { std::stack state_stack; row = 1; int pos = 0; @@ -867,13 +849,11 @@ inline void IterativeBacktrack(const std::vector& W, // Apply penalty factor alpha to each link in link topology graph that is used // by the spanning tree template -inline void UpdateWeight(std::vector* W, - const std::vector& topo_row, - int num_elements, - float alpha) { +inline void UpdateWeight(std::vector* W, const std::vector& topo_row, + int num_elements, float alpha) { for (unsigned i = 1; i < topo_row.size() - 1; i += 2) { unsigned parent = topo_row[i]; - unsigned child = topo_row[i+1]; + unsigned child = topo_row[i+1]; if (!(parent >= num_elements*num_elements || child >= num_elements*num_elements) && (parent != child)) { (*W)[parent*num_elements+child] *= alpha; @@ -890,9 +870,9 @@ inline void UpdateWeight(std::vector* W, // 2) maximize edge weight // 3) tree is binary template -inline void BacktrackGenerateBinaryTree(std::vector* W, - int num_elements, - int root, +inline void BacktrackGenerateBinaryTree(std::vector* W, + int num_elements, + int root, std::vector* topo_row, std::vector* scan_row) { // Clear before starting @@ -933,11 +913,11 @@ inline void BacktrackGenerateBinaryTree(std::vector* W, // ComputeTreesFromRoot does the same thing as ComputeTrees, with the only // exception being it will do it from a fixed GPU as root template -inline void ComputeTreesFromRoot(std::vector* W, - int num_elements, - int root, - float alpha, - bool backtrack, +inline void 
ComputeTreesFromRoot(std::vector* W, + int num_elements, + int root, + float alpha, + bool backtrack, std::vector* topo, std::vector* scan) { int num_partitions = 1; @@ -977,11 +957,11 @@ inline void ComputeTreesFromRoot(std::vector* W, while (!backtrack && (!stop || reset)) { if (reset == 1) { cluster_pairs.clear(); - P_temp = P; + P_temp = P; num_partitions_temp = num_partitions; - roots_temp = roots; - topo_temp = *topo; - scan_temp = *scan; + roots_temp = roots; + topo_temp = *topo; + scan_temp = *scan; } // Run Kernighan-Lin to generate partition @@ -1020,10 +1000,10 @@ inline void ComputeTreesFromRoot(std::vector* W, // @output: topo stores the trees generated // scan stores the start of each level of each tree template -inline void ComputeTrees(const std::vector& W, - int num_elements, - float alpha, - bool backtrack, +inline void ComputeTrees(const std::vector& W, + int num_elements, + float alpha, + bool backtrack, std::vector>* topo, std::vector>* scan) { std::vector W_copy = W; @@ -1056,11 +1036,13 @@ inline void ComputeTrees(const std::vector& W, std::vector> topo_temp(num_elements, std::vector()); - /*for (int i = 0; i < num_elements; ++i) - PrintTopo("Topo", topo[i], scan[i]); + if (kLogTree) { + for (int i = 0; i < num_elements; ++i) + PrintTopo("Topo", (*topo)[i], (*scan)[i]); - PrintMatrix("W", W, num_elements, num_elements); - PrintMatrix("Links", adj, num_elements, num_elements);*/ + PrintMatrix("W", W, num_elements, num_elements); + PrintMatrix("Links", adj, num_elements, num_elements); + } } } // namespace kvstore } // namespace mxnet diff --git a/tests/python/gpu/test_kvstore_gpu.py b/tests/python/gpu/test_kvstore_gpu.py index 70a776e2d76c..7f7fddf03348 100644 --- a/tests/python/gpu/test_kvstore_gpu.py +++ b/tests/python/gpu/test_kvstore_gpu.py @@ -30,6 +30,7 @@ shape = (4, 4) keys = [5, 7, 11] str_keys = ['b', 'c', 'd'] +logging.basicConfig(level=logging.INFO) class EnvManager: def __init__(self, key, val): @@ -41,7 +42,7 @@ def 
__enter__(self): try: self._prev_val = os.environ[self._key] except KeyError: - self._prev_val = "" + self._prev_val = '' os.environ[self._key] = self._next_val def __exit__(self, ptype, value, trace): @@ -56,6 +57,47 @@ def init_kv_with_str(stype='default', kv_type='local'): kv.init(str_keys, [mx.nd.zeros(shape=shape, stype=stype)] * len(keys)) return kv +def test_dense_push_pull(): + shapes = [(1026), (1,2,3,4,5,6,7,8)] + keys = [1,2,3,4,5,6,7] + + def check_dense_push_pull(kv_type): + def check_dense_pull(kv_type, ctxs): + n = 0 + n_devs = len(ctxs) + for context in ctxs: + kv = mx.kv.create(kv_type) + a = mx.nd.ones(shape, ctxs[0]) + cur_key = str(key*n_devs+n) + kv.init(cur_key, a) + arr_list = [mx.nd.ones(shape, ctx=context) for context in ctxs] + res = [mx.nd.zeros(shape, ctx=context) for context in ctxs] + kv.push(cur_key, arr_list) + kv.pull(cur_key, res) + n += 1 + for x in range(n_devs): + #if np.sum(np.abs((res[x]-n_devs).asnumpy()))!=0: + print(x, (res[x]-n_devs).asnumpy()) + assert(np.sum(np.abs((res[x]-n_devs).asnumpy()))==0) + + for key in keys: + check_dense_pull(kv_type, [mx.gpu(0)]) + check_dense_pull(kv_type, [mx.cpu(0)]) + check_dense_pull(kv_type, [mx.gpu(i) for i in range(4)]) + check_dense_pull(kv_type, [mx.cpu(i) for i in range(4)]) + + key1 = 'MXNET_KVSTORE_GPUARRAY_BOUND' + envs2 = ['','1'] + key2 = 'MXNET_KVSTORE_USETREE' + for i in range(2): + for val2 in envs2: + with EnvManager(key2, val2): + check_dense_push_pull('local') + check_dense_push_pull('device') + + os.environ[key1] = '0' + os.environ[key1] = '' + # Test seed 89411477 (module seed 1829754103) resulted in a py3-gpu CI runner core dump. # Not reproducible, so this test is back on random seeds. @with_seed() @@ -102,17 +144,16 @@ def check_rsp_pull(kv, count, ctxs, is_same_rowid=False, use_slice=False): check_rsp_pull(kv, 4, [mx.gpu(i//2) for i in range(4)], use_slice=True) check_rsp_pull(kv, 4, [mx.cpu(i) for i in range(4)], use_slice=True) - # test fails intermittently. 
temporarily disabled till it gets fixed. tracked at https://github.com/apache/incubator-mxnet/issues/9384 - # check_rsp_push_pull('local') envs = ["","1"] key = "MXNET_KVSTORE_USETREE" for val in envs: with EnvManager(key, val): + print('done') check_rsp_push_pull('local') check_rsp_push_pull('device') check_rsp_push_pull('device', is_push_cpu=False) - +@with_seed() def test_row_sparse_pull_single_device(): envs = ["","1"] key = "MXNET_KVSTORE_USETREE" @@ -130,7 +171,7 @@ def test_row_sparse_pull_single_device(): assert_almost_equal(grad.asnumpy(), copy.asnumpy()) - +@with_seed() def test_rsp_push_pull_large_rowid(): envs = ["","1"] key = "MXNET_KVSTORE_USETREE" From 18c1700aaaa0985ccf6bc8160785cb028a4a67dc Mon Sep 17 00:00:00 2001 From: Carl Yang Date: Sat, 21 Jul 2018 00:07:21 +0000 Subject: [PATCH 26/36] fix a few bugs --- src/kvstore/comm_tree.h | 29 +++++- src/kvstore/gpu_topology.h | 16 +++- tests/python/gpu/test_kvstore_gpu.py | 133 +++++++++------------------ tests/python/gpu/test_nccl.py | 51 +++++++++- 4 files changed, 132 insertions(+), 97 deletions(-) diff --git a/src/kvstore/comm_tree.h b/src/kvstore/comm_tree.h index c5ba97ba52f8..51e8c71371f8 100644 --- a/src/kvstore/comm_tree.h +++ b/src/kvstore/comm_tree.h @@ -119,20 +119,27 @@ class CommDeviceTree : public CommDevice { for (int level = depth_; level > 0; --level) { int start = scan_[root][level ]; int end = scan_[root][level+1]; + //LOG(WARNING) << "Reduce level: " << level; + //LOG(WARNING) << "From " << start << " to " << end; unsigned is_dest = 0; int dest_id = 0; for (int j = start; j < end; ++j) { int topo_id = topology[j]; dest_id = (is_dest == 0) ? 
topo_id : dest_id; + //LOG(WARNING) << is_dest << ": " << topo_id << " -> " << dest_id; TreeBufferEntry& buf_dest = tree_merge_buf_[dest_id][key]; TreeBufferEntry& buf_from = tree_merge_buf_[topo_id][key]; if (!is_dest) { - reduce[dest_id].push_back(buf_dest.merged[merged_row]); + if (reduce[dest_id].size() == 0) { + reduce[dest_id].push_back(buf_dest.merged[merged_row]); + //LOG(WARNING) << topo_id << " == " << dest_id; + } } else { if (dest_id != topo_id) { + //LOG(WARNING) << "Reduce from " << dest_id << " " << topo_id; CopyFromTo(buf_from.merged[merged_row], &(buf_dest.copy_buf[merged_row][is_dest-1]), priority); @@ -149,11 +156,16 @@ class CommDeviceTree : public CommDevice { end = scan_[root][level]; for (int i = start; i < end; ++i) { int gpu_id = topology[i]; + //LOG(WARNING) << "Doing reduce on GPU" << gpu_id; + //LOG(WARNING) << "With #elems " << reduce[gpu_id].size(); // conditional to detect whether operation must be done if (reduce[gpu_id].size() > 1) { TreeBufferEntry& buf = tree_merge_buf_[gpu_id][key]; ElementwiseSum(reduce[gpu_id], &(buf.merged[merged_row]), priority); + //LOG(WARNING) << "reduce input 1 " << reduce[gpu_id][0].ctx(); + //LOG(WARNING) << "reduce input 2 " << reduce[gpu_id][1].ctx(); + //LOG(WARNING) << "buf.mg output " << buf.merged[merged_row].ctx(); } } @@ -196,7 +208,7 @@ class CommDeviceTree : public CommDevice { const NDArrayStorageType stype = src[0].storage_type(); // normal dense reduce if (stype == kDefaultStorage) { - if (total_size > gpuarray_bound_ && first_size >= devs_.size()) { + if (total_size > gpuarray_bound_ && first_size >= 2*devs_.size()) { // Find slice bounds slice_scan[0] = 0; int slice_size = (first_size + devs_.size()-1)/devs_.size(); @@ -204,6 +216,7 @@ class CommDeviceTree : public CommDevice { slice_scan[i] = slice_scan[i-1] + slice_size; } slice_scan[devs_.size()] = src[0].shape()[0]; + LOG(WARNING) << "Using multiple tree"; // row: which slice // col: which gpu @@ -228,6 +241,7 @@ class CommDeviceTree 
: public CommDevice { BroadcastInner(key, *(broadcast_slice[i][i]), broadcast_slice[i], i, i, priority); } } else { + LOG(WARNING) << "Using single tree"; int root = 0; ReduceInner(key, src, root, 0, priority); @@ -240,6 +254,7 @@ class CommDeviceTree : public CommDevice { return src[gpu_id]; } else { // sparse reduce + //LOG(WARNING) << "Using sparse reduce"; return ReduceRowSparse(key, src, priority); } } @@ -266,6 +281,7 @@ class CommDeviceTree : public CommDevice { src_id = (is_src == 0) ? topo_id : src_id; if (is_src && src_id != topo_id) { + //LOG(WARNING) << "Broadcast from " << src_id << " " << topo_id; CopyFromTo(temp[src_id], dst[topo_id], priority); temp[topo_id] = *dst[topo_id]; } @@ -289,7 +305,10 @@ class CommDeviceTree : public CommDevice { } else { int total_size = src.shape().Size(); unsigned first_size = src.shape()[0]; - if (total_size > gpuarray_bound_ && first_size >= devs_.size()) { + const NDArrayStorageType stype = src.storage_type(); + // normal dense reduce + if (stype == kDefaultStorage) { + if (total_size > gpuarray_bound_ && first_size >= 2*devs_.size()) { std::vector slice_scan(devs_.size()+1); slice_scan[0] = 0; int slice_size = (dst[0]->shape()[0]+devs_.size()-1)/devs_.size(); @@ -310,6 +329,8 @@ class CommDeviceTree : public CommDevice { } else { int root = 0; BroadcastInner(key, src, dst, root, -1, priority); + }} else { + LOG(FATAL) << "Only dense input supported for now"; } } } @@ -418,7 +439,7 @@ class CommDeviceTree : public CommDevice { TShape shape_copy = shape; int total_size = shape.Size(); unsigned first_size = shape[0]; - if (total_size > gpuarray_bound_ && first_size >= devs_.size()) { + if (total_size > gpuarray_bound_ && first_size >= 2*devs_.size()) { // Find slice bounds int slice_size = (first_size+devs_.size()-1)/devs_.size(); int last_slice = first_size-(devs_.size()-1)*slice_size; diff --git a/src/kvstore/gpu_topology.h b/src/kvstore/gpu_topology.h index 3b67d2395082..912056248541 100644 --- 
a/src/kvstore/gpu_topology.h +++ b/src/kvstore/gpu_topology.h @@ -163,15 +163,27 @@ inline void GetP2PWeight(const std::vector& devs, std::vector* matri max_value = max[i]; } - // If all GPUs are connected by NVLink, then we can use NVLink only - // to communicate instead of going over PCI-E + // If all GPUs are connected by NVLink, then we can use NVLink only + // to communicate instead of going over PCI-E, so we set PCI-E links to 0 + // + // Otherwise, we will make a distinction between PCI-E GPUDirect links and + // PCI-E through CPU links, which are slower and show a queuing effect (i.e. + // the more packets there are, the slower the link). + // + // For the latter links, we will set links that were 0 to 1/num_gpus to + // account for this queuing effect. bool connected = IsConnected(*matrix, num_gpus); if (connected) { for (auto& matrix_value : *matrix) { matrix_value = (matrix_value == 1) ? 0 : matrix_value; } + } else { + for (auto& matrix_value : *matrix) { + matrix_value = (matrix_value == 1) ? 
1./num_gpus : matrix_value; + } } + if (kLogTree) PrintMatrix("Weight W", *matrix, num_gpus, num_gpus); #else diff --git a/tests/python/gpu/test_kvstore_gpu.py b/tests/python/gpu/test_kvstore_gpu.py index dd02ff0ecd12..a38b9db64f25 100644 --- a/tests/python/gpu/test_kvstore_gpu.py +++ b/tests/python/gpu/test_kvstore_gpu.py @@ -21,7 +21,6 @@ import mxnet as mx import numpy as np import unittest -import logging from mxnet.test_utils import assert_almost_equal, default_context curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__))) sys.path.insert(0, os.path.join(curr_path, '../unittest')) @@ -30,7 +29,6 @@ shape = (4, 4) keys = [5, 7, 11] str_keys = ['b', 'c', 'd'] -logging.basicConfig(level=logging.INFO) class EnvManager: def __init__(self, key, val): @@ -57,58 +55,17 @@ def init_kv_with_str(stype='default', kv_type='local'): kv.init(str_keys, [mx.nd.zeros(shape=shape, stype=stype)] * len(keys)) return kv -def test_dense_push_pull(): - shapes = [(1026), (1,2,3,4,5,6,7,8)] - keys = [1,2,3,4,5,6,7] - - def check_dense_push_pull(kv_type): - def check_dense_pull(kv_type, ctxs): - n = 0 - n_devs = len(ctxs) - for context in ctxs: - kv = mx.kv.create(kv_type) - a = mx.nd.ones(shape, ctxs[0]) - cur_key = str(key*n_devs+n) - kv.init(cur_key, a) - arr_list = [mx.nd.ones(shape, ctx=context) for context in ctxs] - res = [mx.nd.zeros(shape, ctx=context) for context in ctxs] - kv.push(cur_key, arr_list) - kv.pull(cur_key, res) - n += 1 - for x in range(n_devs): - #if np.sum(np.abs((res[x]-n_devs).asnumpy()))!=0: - print(x, (res[x]-n_devs).asnumpy()) - assert(np.sum(np.abs((res[x]-n_devs).asnumpy()))==0) - - for key in keys: - check_dense_pull(kv_type, [mx.gpu(0)]) - check_dense_pull(kv_type, [mx.cpu(0)]) - check_dense_pull(kv_type, [mx.gpu(i) for i in range(4)]) - check_dense_pull(kv_type, [mx.cpu(i) for i in range(4)]) - - key1 = 'MXNET_KVSTORE_GPUARRAY_BOUND' - envs2 = ['','1'] - key2 = 'MXNET_KVSTORE_USETREE' - for i in range(2): - for val2 in envs2: - with 
EnvManager(key2, val2): - check_dense_push_pull('local') - check_dense_push_pull('device') - - os.environ[key1] = '0' - os.environ[key1] = '' - # Test seed 89411477 (module seed 1829754103) resulted in a py3-gpu CI runner core dump. # Not reproducible, so this test is back on random seeds. @with_seed() def test_rsp_push_pull(): - def check_rsp_push_pull(kv_type, is_push_cpu=True): + def check_rsp_push_pull(kv_type, sparse_pull, is_push_cpu=True): kv = init_kv_with_str('row_sparse', kv_type) kv.init('e', mx.nd.ones(shape).tostype('row_sparse')) push_ctxs = [mx.cpu(i) if is_push_cpu else mx.gpu(i) for i in range(2)] kv.push('e', [mx.nd.ones(shape, ctx=context).tostype('row_sparse') for context in push_ctxs]) - def check_rsp_pull(kv, count, ctxs, is_same_rowid=False, use_slice=False): + def check_rsp_pull(kv, count, ctxs, sparse_pull, is_same_rowid=False, use_slice=False): num_rows = shape[0] row_ids = [] all_row_ids = np.arange(num_rows) @@ -135,63 +92,59 @@ def check_rsp_pull(kv, count, ctxs, is_same_rowid=False, use_slice=False): expected_val += 0 if row in excluded_row_ids else 2 assert_almost_equal(retained[row], expected_val) - kv.pull('e', out=vals_to_pull, ignore_sparse=False) - for val in vals: - retained = val.asnumpy() - expected_val = np.zeros_like(retained) - expected_val[:] = 2 - assert_almost_equal(retained, expected_val) - - check_rsp_pull(kv, 1, [mx.gpu(0)]) - check_rsp_pull(kv, 1, [mx.cpu(0)]) - check_rsp_pull(kv, 4, [mx.gpu(i//2) for i in range(4)]) - check_rsp_pull(kv, 4, [mx.gpu(i//2) for i in range(4)], is_same_rowid=True) - check_rsp_pull(kv, 4, [mx.cpu(i) for i in range(4)]) - check_rsp_pull(kv, 4, [mx.cpu(i) for i in range(4)], is_same_rowid=True) - check_rsp_pull(kv, 4, [mx.gpu(i//2) for i in range(4)], use_slice=True) - check_rsp_pull(kv, 4, [mx.cpu(i) for i in range(4)], use_slice=True) + if sparse_pull is True: + kv.pull('e', out=vals_to_pull, ignore_sparse=False) + for val in vals: + retained = val.asnumpy() + expected_val = 
np.zeros_like(retained) + expected_val[:] = 2 + assert_almost_equal(retained, expected_val) + + check_rsp_pull(kv, 1, [mx.gpu(0)], sparse_pull) + check_rsp_pull(kv, 1, [mx.cpu(0)], sparse_pull) + check_rsp_pull(kv, 4, [mx.gpu(i//2) for i in range(4)], sparse_pull) + check_rsp_pull(kv, 4, [mx.gpu(i//2) for i in range(4)], sparse_pull, is_same_rowid=True) + check_rsp_pull(kv, 4, [mx.cpu(i) for i in range(4)], sparse_pull) + check_rsp_pull(kv, 4, [mx.cpu(i) for i in range(4)], sparse_pull, is_same_rowid=True) + check_rsp_pull(kv, 4, [mx.gpu(i//2) for i in range(4)], sparse_pull, use_slice=True) + check_rsp_pull(kv, 4, [mx.cpu(i) for i in range(4)], sparse_pull, use_slice=True) envs = ["","1"] key = "MXNET_KVSTORE_USETREE" for val in envs: with EnvManager(key, val): print('done') - check_rsp_push_pull('local') - check_rsp_push_pull('device') - check_rsp_push_pull('device', is_push_cpu=False) + if val is "1": + sparse_pull = False + else: + sparse_pull = True + check_rsp_push_pull('local', sparse_pull) + check_rsp_push_pull('device', sparse_pull) + check_rsp_push_pull('device', sparse_pull, is_push_cpu=False) -@with_seed() def test_row_sparse_pull_single_device(): - envs = ["","1"] - key = "MXNET_KVSTORE_USETREE" - for val in envs: - with EnvManager(key, val): - kvstore = mx.kv.create('device') - copy = mx.nd.random_normal(shape=(4,4), ctx=mx.gpu(0)) - grad = copy.tostype("row_sparse") + kvstore = mx.kv.create('device') + copy = mx.nd.random_normal(shape=(4,4), ctx=mx.gpu(0)) + grad = copy.tostype("row_sparse") - k = 0 - kvstore.init(k, grad) - idx = grad.indices - kvstore.push(k, grad) - kvstore.row_sparse_pull(k, out=grad, row_ids=idx) + key = 0 + kvstore.init(key, grad) + idx = grad.indices + kvstore.push(key, grad) + kvstore.row_sparse_pull(key, out=grad, row_ids=idx) + + assert_almost_equal(grad.asnumpy(), copy.asnumpy()) - assert_almost_equal(grad.asnumpy(), copy.asnumpy()) -@with_seed() def test_rsp_push_pull_large_rowid(): - envs = ["","1"] - key = 
"MXNET_KVSTORE_USETREE" - for val in envs: - with EnvManager(key, val): - num_rows = 793470 - val = mx.nd.ones((num_rows, 1)).tostype('row_sparse').copyto(mx.gpu()) - kv = mx.kv.create('device') - kv.init('a', val) - out = mx.nd.zeros((num_rows,1), stype='row_sparse').copyto(mx.gpu()) - kv.push('a', val) - kv.row_sparse_pull('a', out=out, row_ids=mx.nd.arange(0, num_rows, dtype='int64')) - assert(out.indices.shape[0] == num_rows) + num_rows = 793470 + val = mx.nd.ones((num_rows, 1)).tostype('row_sparse').copyto(mx.gpu()) + kv = mx.kv.create('device') + kv.init('a', val) + out = mx.nd.zeros((num_rows,1), stype='row_sparse').copyto(mx.gpu()) + kv.push('a', val) + kv.row_sparse_pull('a', out=out, row_ids=mx.nd.arange(0, num_rows, dtype='int64')) + assert(out.indices.shape[0] == num_rows) if __name__ == '__main__': import nose diff --git a/tests/python/gpu/test_nccl.py b/tests/python/gpu/test_nccl.py index 8e00ba05f7f6..51dfe786e937 100644 --- a/tests/python/gpu/test_nccl.py +++ b/tests/python/gpu/test_nccl.py @@ -18,6 +18,7 @@ import mxnet as mx import numpy as np import unittest +import os shapes = [(10), (100), (1000), (10000), (100000), (2,2), (2,3,4,5,6,7,8)] keys = [1,2,3,4,5,6,7] @@ -29,7 +30,54 @@ print("There is a limit for all PCI-E hardware on creating number of P2P peers. 
The limit is 8.") num_gpus = 8; -gpus = range(1,1+num_gpus) +gpus = range(1, 1+num_gpus) + +class EnvManager: + def __init__(self, key, val): + self._key = key + self._next_val = val + self._prev_val = None + + def __enter__(self): + try: + self._prev_val = os.environ[self._key] + except KeyError: + self._prev_val = '' + os.environ[self._key] = self._next_val + + def __exit__(self, ptype, value, trace): + os.environ[self._key] = self._prev_val + +def test_device_pushpull(): + def check_dense_pushpull(kv_type): + for shape, key in zip(shapes, keys): + for n_gpus in gpus: + kv_device = mx.kv.create(kv_type) + a = mx.nd.ones(shape, mx.gpu(0)) + cur_key = str(key*max(gpus)+n_gpus) + kv_device.init(cur_key, a) + arr_list = [mx.nd.ones(shape, mx.gpu(x)) for x in range(n_gpus)] + res = [mx.nd.zeros(shape, mx.gpu(x)) for x in range(n_gpus)] + kv_device.push(cur_key, arr_list) + kv_device.pull(cur_key, res) + for x in range(n_gpus): + if np.sum(np.abs((res[x]-n_gpus).asnumpy()))!=0: + print(shape, key, n_gpus, x, (res[x]-n_gpus).asnumpy()) + assert(np.sum(np.abs((res[x]-n_gpus).asnumpy()))==0) + + envs1 = '1' + key1 = 'MXNET_KVSTORE_GPUARRAY_BOUND' + envs2 = ['','1'] + key2 = 'MXNET_KVSTORE_USETREE' + for i in range(2): + for val2 in envs2: + with EnvManager(key2, val2): + print(i, val2) + #check_dense_pushpull('local') + check_dense_pushpull('device') + + os.environ[key1] = envs1 + os.environ[key1] = '' @unittest.skip("Test requires NCCL library installed and enabled during build") def test_nccl_pushpull(): @@ -49,4 +97,5 @@ def test_nccl_pushpull(): print ("Passed") if __name__ == '__main__': + test_device_pushpull() test_nccl_pushpull() From c65a6202f264c57d81c78be5c039f83b21381017 Mon Sep 17 00:00:00 2001 From: Carl Yang Date: Sat, 21 Jul 2018 18:49:08 +0000 Subject: [PATCH 27/36] get rid of printfs --- src/kvstore/comm_tree.h | 14 -------------- 1 file changed, 14 deletions(-) diff --git a/src/kvstore/comm_tree.h b/src/kvstore/comm_tree.h index 
51e8c71371f8..54732d8a2d8d 100644 --- a/src/kvstore/comm_tree.h +++ b/src/kvstore/comm_tree.h @@ -119,15 +119,12 @@ class CommDeviceTree : public CommDevice { for (int level = depth_; level > 0; --level) { int start = scan_[root][level ]; int end = scan_[root][level+1]; - //LOG(WARNING) << "Reduce level: " << level; - //LOG(WARNING) << "From " << start << " to " << end; unsigned is_dest = 0; int dest_id = 0; for (int j = start; j < end; ++j) { int topo_id = topology[j]; dest_id = (is_dest == 0) ? topo_id : dest_id; - //LOG(WARNING) << is_dest << ": " << topo_id << " -> " << dest_id; TreeBufferEntry& buf_dest = tree_merge_buf_[dest_id][key]; TreeBufferEntry& buf_from = tree_merge_buf_[topo_id][key]; @@ -135,11 +132,9 @@ class CommDeviceTree : public CommDevice { if (!is_dest) { if (reduce[dest_id].size() == 0) { reduce[dest_id].push_back(buf_dest.merged[merged_row]); - //LOG(WARNING) << topo_id << " == " << dest_id; } } else { if (dest_id != topo_id) { - //LOG(WARNING) << "Reduce from " << dest_id << " " << topo_id; CopyFromTo(buf_from.merged[merged_row], &(buf_dest.copy_buf[merged_row][is_dest-1]), priority); @@ -156,16 +151,11 @@ class CommDeviceTree : public CommDevice { end = scan_[root][level]; for (int i = start; i < end; ++i) { int gpu_id = topology[i]; - //LOG(WARNING) << "Doing reduce on GPU" << gpu_id; - //LOG(WARNING) << "With #elems " << reduce[gpu_id].size(); // conditional to detect whether operation must be done if (reduce[gpu_id].size() > 1) { TreeBufferEntry& buf = tree_merge_buf_[gpu_id][key]; ElementwiseSum(reduce[gpu_id], &(buf.merged[merged_row]), priority); - //LOG(WARNING) << "reduce input 1 " << reduce[gpu_id][0].ctx(); - //LOG(WARNING) << "reduce input 2 " << reduce[gpu_id][1].ctx(); - //LOG(WARNING) << "buf.mg output " << buf.merged[merged_row].ctx(); } } @@ -216,7 +206,6 @@ class CommDeviceTree : public CommDevice { slice_scan[i] = slice_scan[i-1] + slice_size; } slice_scan[devs_.size()] = src[0].shape()[0]; - LOG(WARNING) << "Using 
multiple tree"; // row: which slice // col: which gpu @@ -241,7 +230,6 @@ class CommDeviceTree : public CommDevice { BroadcastInner(key, *(broadcast_slice[i][i]), broadcast_slice[i], i, i, priority); } } else { - LOG(WARNING) << "Using single tree"; int root = 0; ReduceInner(key, src, root, 0, priority); @@ -254,7 +242,6 @@ class CommDeviceTree : public CommDevice { return src[gpu_id]; } else { // sparse reduce - //LOG(WARNING) << "Using sparse reduce"; return ReduceRowSparse(key, src, priority); } } @@ -281,7 +268,6 @@ class CommDeviceTree : public CommDevice { src_id = (is_src == 0) ? topo_id : src_id; if (is_src && src_id != topo_id) { - //LOG(WARNING) << "Broadcast from " << src_id << " " << topo_id; CopyFromTo(temp[src_id], dst[topo_id], priority); temp[topo_id] = *dst[topo_id]; } From 628ba6e2a98d45bceb913ef28434e8fcad1c527a Mon Sep 17 00:00:00 2001 From: Carl Yang Date: Sat, 21 Jul 2018 19:35:19 +0000 Subject: [PATCH 28/36] get rid of print --- tests/python/gpu/test_kvstore_gpu.py | 1 - tests/python/gpu/test_nccl.py | 5 +---- 2 files changed, 1 insertion(+), 5 deletions(-) diff --git a/tests/python/gpu/test_kvstore_gpu.py b/tests/python/gpu/test_kvstore_gpu.py index a38b9db64f25..5e9b120f14ee 100644 --- a/tests/python/gpu/test_kvstore_gpu.py +++ b/tests/python/gpu/test_kvstore_gpu.py @@ -113,7 +113,6 @@ def check_rsp_pull(kv, count, ctxs, sparse_pull, is_same_rowid=False, use_slice= key = "MXNET_KVSTORE_USETREE" for val in envs: with EnvManager(key, val): - print('done') if val is "1": sparse_pull = False else: diff --git a/tests/python/gpu/test_nccl.py b/tests/python/gpu/test_nccl.py index 51dfe786e937..ee822784966f 100644 --- a/tests/python/gpu/test_nccl.py +++ b/tests/python/gpu/test_nccl.py @@ -61,8 +61,6 @@ def check_dense_pushpull(kv_type): kv_device.push(cur_key, arr_list) kv_device.pull(cur_key, res) for x in range(n_gpus): - if np.sum(np.abs((res[x]-n_gpus).asnumpy()))!=0: - print(shape, key, n_gpus, x, (res[x]-n_gpus).asnumpy()) 
assert(np.sum(np.abs((res[x]-n_gpus).asnumpy()))==0) envs1 = '1' @@ -72,8 +70,7 @@ def check_dense_pushpull(kv_type): for i in range(2): for val2 in envs2: with EnvManager(key2, val2): - print(i, val2) - #check_dense_pushpull('local') + check_dense_pushpull('local') check_dense_pushpull('device') os.environ[key1] = envs1 From a0e1366eb4697a819ad54ce01add0150d18d595f Mon Sep 17 00:00:00 2001 From: ctcyang Date: Sun, 22 Jul 2018 23:00:04 -0700 Subject: [PATCH 29/36] Comment out test for now --- tests/cpp/kvstore/gpu_topology_test.cc | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/cpp/kvstore/gpu_topology_test.cc b/tests/cpp/kvstore/gpu_topology_test.cc index 1fba05ef64cf..18a574e7e4e0 100644 --- a/tests/cpp/kvstore/gpu_topology_test.cc +++ b/tests/cpp/kvstore/gpu_topology_test.cc @@ -554,7 +554,8 @@ TEST(GpuTopology, TestIsConnected3) { } // ComputeTreesTest with backtracking -TEST(GpuTopology, TestComputeTrees1) { +// TODO(carlyang): comment out test for now +/*TEST(GpuTopology, TestComputeTrees1) { std::mt19937 gen(1); float alpha = 0.7; bool backtrack = true; @@ -577,7 +578,7 @@ TEST(GpuTopology, TestComputeTrees2) { TestComputeTreesRandomized(num_gpus, alpha, backtrack, &gen); } } -} +}*/ TEST(GpuTopology, TestPermuteMatrix) { std::vector W = {0, 2, 2, 3, 3, 1, 1, 1, From 63fd14ee1185e2f894cd57f3c4e09dd6213f5cf1 Mon Sep 17 00:00:00 2001 From: Carl Yang Date: Mon, 23 Jul 2018 07:01:33 +0000 Subject: [PATCH 30/36] fix 2 more bugs --- src/kvstore/comm_tree.h | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/src/kvstore/comm_tree.h b/src/kvstore/comm_tree.h index 51e8c71371f8..29db0292bba4 100644 --- a/src/kvstore/comm_tree.h +++ b/src/kvstore/comm_tree.h @@ -154,13 +154,19 @@ class CommDeviceTree : public CommDevice { start = scan_[root][level-1]; end = scan_[root][level]; + int source = end; for (int i = start; i < end; ++i) { int gpu_id = topology[i]; //LOG(WARNING) << "Doing reduce on GPU" << gpu_id; 
//LOG(WARNING) << "With #elems " << reduce[gpu_id].size(); + // source keeps track of 2 leaf nodes, while start keeps track of parent + int dest_id = topology[source]; + int from_id = topology[source+1]; + source += 2; + // conditional to detect whether operation must be done - if (reduce[gpu_id].size() > 1) { + if (reduce[gpu_id].size() > 1 && dest_id != from_id) { TreeBufferEntry& buf = tree_merge_buf_[gpu_id][key]; ElementwiseSum(reduce[gpu_id], &(buf.merged[merged_row]), priority); //LOG(WARNING) << "reduce input 1 " << reduce[gpu_id][0].ctx(); @@ -211,7 +217,7 @@ class CommDeviceTree : public CommDevice { if (total_size > gpuarray_bound_ && first_size >= 2*devs_.size()) { // Find slice bounds slice_scan[0] = 0; - int slice_size = (first_size + devs_.size()-1)/devs_.size(); + int slice_size = first_size/devs_.size(); for (unsigned i = 1; i < devs_.size(); ++i) { slice_scan[i] = slice_scan[i-1] + slice_size; } @@ -311,7 +317,7 @@ class CommDeviceTree : public CommDevice { if (total_size > gpuarray_bound_ && first_size >= 2*devs_.size()) { std::vector slice_scan(devs_.size()+1); slice_scan[0] = 0; - int slice_size = (dst[0]->shape()[0]+devs_.size()-1)/devs_.size(); + int slice_size = (dst[0]->shape()[0])/devs_.size(); for (unsigned i = 1; i < devs_.size(); ++i) { slice_scan[i] = slice_scan[i-1] + slice_size; } @@ -441,7 +447,7 @@ class CommDeviceTree : public CommDevice { unsigned first_size = shape[0]; if (total_size > gpuarray_bound_ && first_size >= 2*devs_.size()) { // Find slice bounds - int slice_size = (first_size+devs_.size()-1)/devs_.size(); + int slice_size = first_size/devs_.size(); int last_slice = first_size-(devs_.size()-1)*slice_size; shape_copy[0] = slice_size; buf.merged.resize(devs_.size()); From 9f5c24a347653b3b10f7d41e6991ca81b3e12825 Mon Sep 17 00:00:00 2001 From: Carl Yang Date: Mon, 23 Jul 2018 23:42:58 +0000 Subject: [PATCH 31/36] fix segfault --- src/kvstore/gpu_topology.h | 51 +++++++++++++++++++------- 
tests/cpp/kvstore/gpu_topology_test.cc | 14 +++++-- 2 files changed, 49 insertions(+), 16 deletions(-) diff --git a/src/kvstore/gpu_topology.h b/src/kvstore/gpu_topology.h index 912056248541..1d82a391c79a 100644 --- a/src/kvstore/gpu_topology.h +++ b/src/kvstore/gpu_topology.h @@ -734,10 +734,18 @@ inline T ComputeTreeWeight(const std::vector& W, const std::vector& resu // 3 1 // 3 0 1 5 // 3 3 0 4 1 2 5 6 -inline void FormTopology(const std::vector& result, +// +// Returns false if invalid tree in result +// Otherwise returns true +inline bool FormTopology(const std::vector& result, std::vector* topo_row, std::vector* scan_row, int depth) { + //PrintVector("Best result", result); + for (unsigned i = 0; i < result.size(); ++i) + if (result[i] == -1) + return false; + scan_row->push_back(topo_row->size()); for (int i = depth; i > 0; --i) { int stride = 1 << i; @@ -751,6 +759,7 @@ inline void FormTopology(const std::vector& result, // Insert at the end, result vector topo_row->insert(topo_row->end(), result.begin(), result.end()); scan_row->push_back(topo_row->size()); + return true; } // Recursive function that finds a spanning tree, which fulfills the following @@ -813,6 +822,8 @@ inline void IterativeBacktrack(const std::vector& W, // a) if stack is empty, break and stop search // b) if stack is not empty, pop stack and set current position to next // position backtrack to previous row + //LOG(WARNING) << "Stack size: " << state_stack.size(); + //PrintVector("state", *state); while (!state_stack.empty() && pos >= num_elements) { pos = state_stack.top(); pos++; @@ -838,8 +849,11 @@ inline void IterativeBacktrack(const std::vector& W, // Pop stack, set current position to next position // Backtrack to find next solution if (row == static_cast(state->size())) { + //LOG(WARNING) << row << " == " << state->size(); std::vector result = *state; + //PrintVector("Best result", result); Postprocess(&result, num_elements, depth); + //PrintVector("Best result", result); T 
weight = ComputeTreeWeight(W, result, num_elements, depth, true); // Save this spanning tree if it is highest weight tree found so far @@ -852,7 +866,8 @@ inline void IterativeBacktrack(const std::vector& W, pos = state_stack.top(); pos++; state_stack.pop(); - (*state)[state_stack.size()+1] = -1; + //LOG(WARNING) << "Setting " << state_stack.size() << " to 1"; + (*state)[state_stack.size()] = -1; row--; } } @@ -882,7 +897,7 @@ inline void UpdateWeight(std::vector* W, const std::vector& topo_row, // 2) maximize edge weight // 3) tree is binary template -inline void BacktrackGenerateBinaryTree(std::vector* W, +inline bool BacktrackGenerateBinaryTree(std::vector* W, int num_elements, int root, std::vector* topo_row, @@ -893,13 +908,14 @@ inline void BacktrackGenerateBinaryTree(std::vector* W, // Compute depth // num_elements: depth - // 5: 3 - // 6: 3 - // 7: 3 - // 8: 3 - // 9: 4 + // 5: 3 8 + // 6: 3 8 + // 7: 3 8 + // 8: 3 8 + // 9: 4 16 int depth = ComputeDepth(num_elements); int depth_leaves = 1 << depth; + //LOG(WARNING) << num_elements << " " << depth << " " << depth_leaves; // State vector // -1 means unplaced @@ -913,13 +929,15 @@ inline void BacktrackGenerateBinaryTree(std::vector* W, // Seek optimal solution until depth <= 3 i.e. 
8 GPUs // For larger numbers of GPUs, settle for first tree found (non-optimal), but // this saves a lot of runtime, because Backtrack is exponential time - if (depth <= 3) + if (depth <= 3) { IterativeBacktrack(*W, &state, &result, &result_weight, 1, num_elements, depth, true); - else + } else { IterativeBacktrack(*W, &state, &result, &result_weight, 1, num_elements, depth, false); - FormTopology(result, topo_row, scan_row, depth); + } + //LOG(WARNING) << "Exit Iterative backtrack " << num_elements; + return FormTopology(result, topo_row, scan_row, depth); } // ComputeTreesFromRoot does the same thing as ComputeTrees, with the only @@ -991,16 +1009,22 @@ inline void ComputeTreesFromRoot(std::vector* W, if (level > 10) break; } + //LOG(WARNING) << "ComputeFromRoot: " << num_elements; + + bool success = true; if (reset == 1) { // if (!backtrack) // LOG(WARNING) << "No valid binary tree found from root " << root << ", try backtracking"; - BacktrackGenerateBinaryTree(W, num_elements, root, topo, scan); + success = BacktrackGenerateBinaryTree(W, num_elements, root, topo, scan); } else { *topo = topo_temp; *scan = scan_temp; scan->push_back(topo->size()); } - UpdateWeight(W, *topo, num_elements, alpha); + if (success) + UpdateWeight(W, *topo, num_elements, alpha); + else + LOG(FATAL) << "No valid binary tree found from root " << root << " using backtracking"; } // ComputeTrees computes balanced binary spanning trees of maximum edge weight @@ -1022,6 +1046,7 @@ inline void ComputeTrees(const std::vector& W, topo->clear(); scan->clear(); + //LOG(WARNING) << "ComputeTrees: " << num_elements; for (int i = 0; i < num_elements; ++i) { topo->push_back(std::vector()); scan->push_back(std::vector()); diff --git a/tests/cpp/kvstore/gpu_topology_test.cc b/tests/cpp/kvstore/gpu_topology_test.cc index 18a574e7e4e0..054cfded92bb 100644 --- a/tests/cpp/kvstore/gpu_topology_test.cc +++ b/tests/cpp/kvstore/gpu_topology_test.cc @@ -77,6 +77,8 @@ void TestComputeTreesRandomized(int 
num_gpus, float alpha, int backtrack, GenerateMatrix(&W, num_gpus, k, gen); satisfied = IsSatisfactory(W, num_gpus, depth); } + //LOG(WARNING) << "num_gpus: " << num_gpus; + //mxnet::kvstore::PrintMatrix("Link topo", W, num_gpus, num_gpus); std::vector> topo; std::vector> scan; @@ -84,7 +86,9 @@ void TestComputeTreesRandomized(int num_gpus, float alpha, int backtrack, unsigned correct_topo_size = (1 << (depth + 1)) - 1; unsigned correct_scan_size = depth+2; - for (int i = 0; i < num_gpus; ++i) { + //LOG(WARNING) << topo.size() << " " << num_gpus; + ASSERT_EQ(topo.size(), static_cast(num_gpus)); + for (int i = 0; i < topo.size(); ++i) { ASSERT_EQ(correct_topo_size, topo[i].size()); ASSERT_EQ(correct_scan_size, scan[i].size()); } @@ -188,6 +192,8 @@ TEST(GpuTopology, TestPostprocess) { } TEST(GpuTopology, TestDepth) { + ASSERT_EQ(mxnet::kvstore::ComputeDepth(2), 1); + ASSERT_EQ(mxnet::kvstore::ComputeDepth(3), 2); ASSERT_EQ(mxnet::kvstore::ComputeDepth(8), 3); ASSERT_EQ(mxnet::kvstore::ComputeDepth(7), 3); ASSERT_EQ(mxnet::kvstore::ComputeDepth(5), 3); @@ -555,12 +561,13 @@ TEST(GpuTopology, TestIsConnected3) { // ComputeTreesTest with backtracking // TODO(carlyang): comment out test for now -/*TEST(GpuTopology, TestComputeTrees1) { +TEST(GpuTopology, TestComputeTrees1) { std::mt19937 gen(1); float alpha = 0.7; bool backtrack = true; // Do 5 randomized tests per GPU count from 2 to 16 for (int num_gpus = 2; num_gpus <= 16; ++num_gpus) { + LOG(WARNING) << "Testing " << num_gpus << " x " << num_gpus; for (int i = 0; i < 5; ++i) { TestComputeTreesRandomized(num_gpus, alpha, backtrack, &gen); } @@ -574,11 +581,12 @@ TEST(GpuTopology, TestComputeTrees2) { bool backtrack = false; // Do 5 randomized tests per GPU count from 2 to 16 for (int num_gpus = 2; num_gpus <= 16; ++num_gpus) { + LOG(WARNING) << "Testing " << num_gpus << " x " << num_gpus; for (int i = 0; i < 5; ++i) { TestComputeTreesRandomized(num_gpus, alpha, backtrack, &gen); } } -}*/ +} TEST(GpuTopology, 
TestPermuteMatrix) { std::vector W = {0, 2, 2, 3, 3, 1, 1, 1, From 9cc24d0bffeade8de3a76863eba5f2eecb497100 Mon Sep 17 00:00:00 2001 From: Carl Yang Date: Tue, 24 Jul 2018 00:49:59 +0000 Subject: [PATCH 32/36] change PrintVector, PrintTopo, PrintMatrix to LOG(INFO) instead of stdout --- docs/faq/env_var.md | 25 +++ src/kvstore/comm_tree.h | 36 ++-- src/kvstore/gpu_topology.h | 218 +++++++++++++------------ tests/cpp/kvstore/gpu_topology_test.cc | 9 +- tests/python/gpu/test_device.py | 82 ++++++++++ tests/python/gpu/test_nccl.py | 45 ----- 6 files changed, 251 insertions(+), 164 deletions(-) create mode 100644 tests/python/gpu/test_device.py diff --git a/docs/faq/env_var.md b/docs/faq/env_var.md index 12a898aadc24..fb44bf6ba927 100644 --- a/docs/faq/env_var.md +++ b/docs/faq/env_var.md @@ -83,6 +83,31 @@ export MXNET_GPU_WORKER_NTHREADS=3 - The minimum size of a "big array". - When the array size is bigger than this threshold, MXNET_KVSTORE_REDUCTION_NTHREADS threads are used for reduction. - This parameter is also used as a load balancer in kvstore. It controls when to partition a single weight to all the servers. If the size of a single weight is less than MXNET_KVSTORE_BIGARRAY_BOUND then, it is sent to a single randomly picked server otherwise it is partitioned to all the servers. + +* MXNET_KVSTORE_USETREE + - Values: 0(false) or 1(true) ```(default=0)``` + - If true, MXNet tries to use tree reduction for Push and Pull communication. + - Otherwise, MXNet uses the default Push and Pull implementation. + - [Tree reduction technology](http://www.sysml.cc/doc/178.pdf) has been shown to be faster than the standard ```--kv-store device``` Push/Pull and ```--kv-store nccl``` Push/Pull for small batch sizes. + +* MXNET_KVSTORE_LOGTREE + - Values: 0(false) or 1(true) ```(default=0)``` + - If true and MXNET_KVSTORE_USETREE is set to 1, MXNet will log the reduction trees that have been generated. 
+ +* MXNET_KVSTORE_GPUARRAY_BOUND + - Values: Int ```(default=10000000)``` + - The minimum size of a "big array". + - When the array size is bigger than this threshold and MXNET_KVSTORE_USETREE is set to 1, multiple trees are used to load balance the big gradient being communicated in order to better saturate link bandwidth. + +* MXNET_KVSTORE_BACKTRACK + - Values: 0(false) or 1(true) ```(default=0)``` + - If true and MXNET_KVSTORE_USETREE is set to 1, MXNet tries to use backtracking to generate the trees required for tree reduction. + - If false and MXNET_KVSTORE_USETREE is set to 1, MXNet tries to use Kernighan-Lin heuristic to generate the trees required for tree reduction. + +* MXNET_KVSTORE_LINK_USAGE_PENALTY + - Values: Float ```(default=0.7)``` + - The multiplicative penalty term to a link being used once. + * MXNET_ENABLE_GPU_P2P - Values: 0(false) or 1(true) ```(default=1)``` - If true, MXNet tries to use GPU peer-to-peer communication, if available on your device, diff --git a/src/kvstore/comm_tree.h b/src/kvstore/comm_tree.h index 8f3c5f363ea5..5d2595ef2334 100644 --- a/src/kvstore/comm_tree.h +++ b/src/kvstore/comm_tree.h @@ -375,9 +375,9 @@ class CommDeviceTree : public CommDevice { std::vector link_matrix(devs_.size()*devs_.size()); GetP2PWeight(devs_, &link_matrix); if (backtrack_) - LOG(WARNING) << "Using Backtracking to generate trees"; + LOG(INFO) << "Using Backtracking to generate trees"; else - LOG(WARNING) << "Using Kernighan-Lin to generate trees"; + LOG(INFO) << "Using Kernighan-Lin to generate trees"; ComputeTrees(link_matrix, devs_.size(), link_usage_penalty_, backtrack_, &topology_, &scan_); @@ -388,7 +388,7 @@ class CommDeviceTree : public CommDevice { using KeyAttrs = std::tuple; // try to allocate buff on device evenly void InitMergeBufferTree() { - LOG(WARNING) << "Using Tree"; + LOG(INFO) << "Using Tree"; // same as all-reduce, except: // 1) Allocate copy_buf here instead of in Reduce() @@ -413,14 +413,30 @@ class CommDeviceTree : public
CommDevice { int start = scan_[0][depth_]; int end = scan_[0][depth_+1]; - // In order to generalize to any number of GPUs, we use strategy of having - // found the mapping from 0, 1, ..., n_gpus to dev_id i.e. + // In order to generalize to any number of GPUs in arbitrary order, we use + // strategy of having found the mapping from 0, 1, ..., n_gpus to dev_id. + // For example, if the user wants to use --gpus 4,2,3,1,7,5,0, they can do // so: + // // idx: 0 1 2 3 4 5 6 // dev_id: 4 2 3 1 7 5 0 - // and generated an n_gpus x n_gpus link topology matrix: // - // 1) The reduction trees are saved as indices on 0, 1, ..., n_gpus - // 2) We use the mapping to retrieve dev_id and device context + // From this, we: + // 1) generate a link topology matrix with dimensions n_gpus x n_gpus + // (link_matrix) + // + // 2) the reduction trees are saved as indices from 0, 1, ..., n_gpus + // in a vector of vectors (topology_): + // + // index | topology_[index] + // ------------------------- + // 0 | [Tree 0] + // 1 | [Tree 1] + // . + // . + // . 
+ // n_gpus | [Tree n_gpus] + // + // 3) We use the mapping (devs_) to retrieve dev_id and device context for (int j = start; j < end; ++j) { int topo_id = topology_[0][j]; auto& buf = tree_merge_buf_[topo_id][key]; @@ -469,7 +485,7 @@ class CommDeviceTree : public CommDevice { } for (auto it = key_dist.begin(); it != key_dist.end(); ++it) { - LOG(WARNING) << "Size " << it->first << " occurs " << it->second << " times"; + LOG(INFO) << "Size " << it->first << " occurs " << it->second << " times"; } inited_ = true; } @@ -501,8 +517,6 @@ class CommDeviceTree : public CommDevice { std::vector> scan_; std::vector devs_; - /// \brief Highest numbered device - int max_dev_; int depth_; int gpuarray_bound_; bool backtrack_; diff --git a/src/kvstore/gpu_topology.h b/src/kvstore/gpu_topology.h index 1d82a391c79a..a8801499c3be 100644 --- a/src/kvstore/gpu_topology.h +++ b/src/kvstore/gpu_topology.h @@ -47,48 +47,51 @@ static bool kLogTree = dmlc::GetEnv("MXNET_KVSTORE_LOGTREE", false); template inline void PrintVector(const std::string& str, const std::vector& vec) { - std::cout << str << ":\n"; + LOG(INFO) << str << ":"; + std::string output; for (unsigned i = 0; i < vec.size(); ++i) - std::cout << vec[i] << " "; - std::cout << std::endl; + output += std::to_string(vec[i]) + " "; + LOG(INFO) << output; } template inline void PrintMatrix(const std::string& str, const std::vector& matrix, int num_rows, int num_cols) { - std::cout << str << ":\n"; + LOG(INFO) << str << ":"; int count = 0; for (int row = 0; row < num_rows; ++row) { + std::string output; for (int col = 0; col < num_cols; ++col) { - std::cout << matrix[count++] << " "; + output += std::to_string(static_cast(matrix[count++])) + " "; } - std::cout << std::endl; + LOG(INFO) << output; } } inline void PrintTopo(const std::string& str, const std::vector& topo_row, std::vector scan_row) { - PrintVector("Topo vector", topo_row); - PrintVector("Scan vector", scan_row); - std::cout << str << ":\n"; + LOG(INFO) << str << 
":"; int depth = scan_row.size()-1; for (int row = 0; row < depth; ++row) { int start = scan_row[row]; int end = scan_row[row+1]; + std::string output; for (; start < end; start++) { for (int i = 0; i < (2 << (depth-row-2))+1; ++i) { - std::cout << " "; + output += " "; } - std::cout << topo_row[start]; + output += std::to_string(topo_row[start]); } - std::cout << std::endl; + LOG(INFO) << output; } } -// Uses BFS to find whether undirected graph is connected or not given its -// adjacency matrix -// Note: only consider matrix values > 1, because we care about whether it is -// connected using only NVLink connections +/** + * \brief Uses BFS to find whether undirected graph is connected or not given its + * adjacency matrix + * Note: only consider matrix values > 1, because we care about whether it is + * connected using only NVLink connections + */ template inline bool IsConnected(const std::vector& matrix, int num_gpus) { int source = 0; @@ -117,14 +120,16 @@ inline bool IsConnected(const std::vector& matrix, int num_gpus) { return true; } -// Generate adjacency matrix with row/col numbering from 0, 1, ..., n_gpu -// @input: devs is a vector of GPU contexts -// @output: matrix is adjacency matrix of link topology graph -// where edge weight represents relative performance of NVIDIA GPUs -// 0: Self-connection -// 1: PCI-E -// 2: 1 NVLink connection -// 3: 2 NVLink connections +/** + * \brief Generate adjacency matrix with row/col numbering from 0, 1, ..., n_gpu + * \param devs is a vector of GPU contexts + * \param matrix is adjacency matrix of link topology graph + * where edge weight represents relative performance of NVIDIA GPUs + * 0: Self-connection + * 1: PCI-E + * 2: 1 NVLink connection + * 3: 2 NVLink connections + */ template inline void GetP2PWeight(const std::vector& devs, std::vector* matrix) { int num_gpus = devs.size(); @@ -184,16 +189,16 @@ inline void GetP2PWeight(const std::vector& devs, std::vector* matri } } - if (kLogTree) - 
PrintMatrix("Weight W", *matrix, num_gpus, num_gpus); #else LOG(WARNING) << "GPU required for link topology"; #endif } -// Dense matrix-vector multiplication -// Assume: matrix is square -// y = A*x (no accumulate) +/** + * \brief Dense matrix-vector multiplication + * Assume: matrix is square + * y = A*x (no accumulate) + */ template inline void gemv(const std::vector& A, const std::vector& x, std::vector* y) { @@ -208,8 +213,10 @@ inline void gemv(const std::vector& A, const std::vector& x, } } -// Element-wise multiplication between 2 dense vectors -// w = w * alpha*u +/** + * \brief Element-wise multiplication between 2 dense vectors + * w = w * alpha*u + */ template inline void ewisemult(const std::vector& u, T alpha, std::vector* w) { int nelem = u.size(); @@ -218,11 +225,13 @@ inline void ewisemult(const std::vector& u, T alpha, std::vector* w) { } } -// Computes best 2 nodes a,b to swap given objective function: -// g = max_{a \in A, b \in B} D(a) + D(b) - 2*W(a,b) -// -// Optimization: Only need to look at upper triangular since weight matrix is -// symmetric +/** + * \brief Computes best 2 nodes a,b to swap given objective function: + * g = max_{a \in A, b \in B} D(a) + D(b) - 2*W(a,b) + * + * Optimization: Only need to look at upper triangular since weight matrix is + * symmetric + */ template inline void FindBestMove(const std::vector& W, const std::vector& P_temp, @@ -248,12 +257,14 @@ inline void FindBestMove(const std::vector& W, } } -// Performs partition on each existing partition in graph W if partition has -// more than 4 elements in it -// @output: stop returns true if no partitions with >=4 elements found -// returns false otherwise -// cluster_pairs stores the mapping that tells us which 2 clusters are -// the output of partitioning one large cluster +/** + * \brief Performs partition on each existing partition in graph W if partition has + * more than 4 elements in it + * \param stop returns true if no partitions with >=4 elements found + * 
returns false otherwise + * \param cluster_pairs stores the mapping that tells us which 2 clusters are + * the output of partitioning one large cluster + */ template inline bool KernighanLin(const std::vector& W, std::vector* P, int* num_partitions, @@ -411,8 +422,10 @@ inline bool KernighanLin(const std::vector& W, std::vector* P, return stop; } -// Returns root of a given color if found in roots -// Returns -1 if it is not found +/** + * \brief Returns root of a given color if found in roots + * Returns -1 if it is not found + */ inline int GetRoot(const std::vector& P, int color, const std::unordered_set& roots) { for (auto root : roots) { @@ -422,8 +435,10 @@ inline int GetRoot(const std::vector& P, int color, return -1; } -// Returns root of a given color if found in roots -// Returns -1 if it is not found +/** + * \brief Returns root of a given color if found in roots + * Returns -1 if it is not found + */ inline int GetChild(const std::vector& P, int color, int parent) { for (unsigned i = 0; i < P.size(); ++i) { if (P[i] == color && static_cast(i) != parent) @@ -720,28 +735,29 @@ inline T ComputeTreeWeight(const std::vector& W, const std::vector& resu return weight; } -// Given a spanning tree encoded as result, which was convenient for performing -// backtracking, convert it topology_ and scan_ in the classic "binary tree -// stored in an array" format. For binary trees scan_ is redundant, but this -// additional data structure leaves future generalization to k-radix trees. -// -// Initial result: [3 3 0 4 1 2 5 6] -// topology_: [3 3 1 3 0 1 5 3 3 0 4 1 2 5 6] -// scan_: [0 1 3 7 15] -// -// topology_ is stored in the classic "binary tree stored in an array" format -// e.g. 
3 -// 3 1 -// 3 0 1 5 -// 3 3 0 4 1 2 5 6 -// -// Returns false if invalid tree in result -// Otherwise returns true +/** + * \brief Given a spanning tree encoded as result, which was convenient for performing + * backtracking, convert it topology_ and scan_ in the classic "binary tree + * stored in an array" format. For binary trees scan_ is redundant, but this + * additional data structure leaves future generalization to k-radix trees. + * + * Initial result: [3 3 0 4 1 2 5 6] + * topology_: [3 3 1 3 0 1 5 3 3 0 4 1 2 5 6] + * scan_: [0 1 3 7 15] + * + * topology_ is stored in the classic "binary tree stored in an array" format + * e.g. 3 + * 3 1 + * 3 0 1 5 + * 3 3 0 4 1 2 5 6 + * + * Returns false if invalid tree in result + * Otherwise returns true + */ inline bool FormTopology(const std::vector& result, std::vector* topo_row, std::vector* scan_row, int depth) { - //PrintVector("Best result", result); for (unsigned i = 0; i < result.size(); ++i) if (result[i] == -1) return false; @@ -762,11 +778,13 @@ inline bool FormTopology(const std::vector& result, return true; } -// Recursive function that finds a spanning tree, which fulfills the following -// conditions: -// -balanced -// -binary -// -maximum weight +/** + * \brief Recursive function that finds a spanning tree, which fulfills the following + * conditions: + * -balanced + * -binary + * -maximum weight + */ template inline bool RecursiveBacktrack(const std::vector& W, std::vector* state, @@ -822,8 +840,6 @@ inline void IterativeBacktrack(const std::vector& W, // a) if stack is empty, break and stop search // b) if stack is not empty, pop stack and set current position to next // position backtrack to previous row - //LOG(WARNING) << "Stack size: " << state_stack.size(); - //PrintVector("state", *state); while (!state_stack.empty() && pos >= num_elements) { pos = state_stack.top(); pos++; @@ -849,11 +865,8 @@ inline void IterativeBacktrack(const std::vector& W, // Pop stack, set current position to next 
position // Backtrack to find next solution if (row == static_cast(state->size())) { - //LOG(WARNING) << row << " == " << state->size(); std::vector result = *state; - //PrintVector("Best result", result); Postprocess(&result, num_elements, depth); - //PrintVector("Best result", result); T weight = ComputeTreeWeight(W, result, num_elements, depth, true); // Save this spanning tree if it is highest weight tree found so far @@ -866,15 +879,16 @@ inline void IterativeBacktrack(const std::vector& W, pos = state_stack.top(); pos++; state_stack.pop(); - //LOG(WARNING) << "Setting " << state_stack.size() << " to 1"; (*state)[state_stack.size()] = -1; row--; } } } -// Apply penalty factor alpha to each link in link topology graph that is used -// by the spanning tree +/** + * \brief Apply penalty factor alpha to each link in link topology graph that is used + * by the spanning tree + */ template inline void UpdateWeight(std::vector* W, const std::vector& topo_row, int num_elements, float alpha) { @@ -889,13 +903,15 @@ inline void UpdateWeight(std::vector* W, const std::vector& topo_row, } } -// Do brute-force backtracking approach if Kernighan-Lin fails to find a binary -// tree of height Log P. -// -// Constraints: -// 1) minimize depth (balance) -// 2) maximize edge weight -// 3) tree is binary +/** + * \brief Do brute-force backtracking approach if Kernighan-Lin fails to find a binary + * tree of height Log P. 
+ * + * Constraints: + * 1) minimize depth (balance) + * 2) maximize edge weight + * 3) tree is binary + */ template inline bool BacktrackGenerateBinaryTree(std::vector* W, int num_elements, @@ -915,7 +931,6 @@ inline bool BacktrackGenerateBinaryTree(std::vector* W, // 9: 4 16 int depth = ComputeDepth(num_elements); int depth_leaves = 1 << depth; - //LOG(WARNING) << num_elements << " " << depth << " " << depth_leaves; // State vector // -1 means unplaced @@ -936,12 +951,13 @@ inline bool BacktrackGenerateBinaryTree(std::vector* W, IterativeBacktrack(*W, &state, &result, &result_weight, 1, num_elements, depth, false); } - //LOG(WARNING) << "Exit Iterative backtrack " << num_elements; return FormTopology(result, topo_row, scan_row, depth); } -// ComputeTreesFromRoot does the same thing as ComputeTrees, with the only -// exception being it will do it from a fixed GPU as root +/** + * \brief ComputeTreesFromRoot does the same thing as ComputeTrees, with the only + * exception being it will do it from a fixed GPU as root + */ template inline void ComputeTreesFromRoot(std::vector* W, int num_elements, @@ -1009,12 +1025,9 @@ inline void ComputeTreesFromRoot(std::vector* W, if (level > 10) break; } - //LOG(WARNING) << "ComputeFromRoot: " << num_elements; - bool success = true; if (reset == 1) { - // if (!backtrack) - // LOG(WARNING) << "No valid binary tree found from root " << root << ", try backtracking"; + // LOG(INFO) << "No valid binary tree found from root " << root << ", try backtracking"; success = BacktrackGenerateBinaryTree(W, num_elements, root, topo, scan); } else { *topo = topo_temp; @@ -1027,14 +1040,16 @@ inline void ComputeTreesFromRoot(std::vector* W, LOG(FATAL) << "No valid binary tree found from root " << root << " using backtracking"; } -// ComputeTrees computes balanced binary spanning trees of maximum edge weight -// given a link topology graph stored in adjacency matrix format -// @input: W is the link topology matrix -// num_elements is the number 
of GPUs -// alpha is the link usage penalty -// backtrack is whether or not we use backtracking to generate trees -// @output: topo stores the trees generated -// scan stores the start of each level of each tree +/** + * \brief ComputeTrees computes balanced binary spanning trees of maximum edge weight + * given a link topology graph stored in adjacency matrix format + * \param W is the link topology matrix + * \param num_elements is the number of GPUs + * \param alpha is the link usage penalty + * \param backtrack is whether or not we use backtracking to generate trees + * \param topo stores the trees generated + * \param scan stores the start of each level of each tree + */ template inline void ComputeTrees(const std::vector& W, int num_elements, @@ -1046,7 +1061,6 @@ inline void ComputeTrees(const std::vector& W, topo->clear(); scan->clear(); - //LOG(WARNING) << "ComputeTrees: " << num_elements; for (int i = 0; i < num_elements; ++i) { topo->push_back(std::vector()); scan->push_back(std::vector()); @@ -1075,7 +1089,7 @@ inline void ComputeTrees(const std::vector& W, if (kLogTree) { for (int i = 0; i < num_elements; ++i) - PrintTopo("Topo", (*topo)[i], (*scan)[i]); + PrintTopo("Tree "+std::to_string(i), (*topo)[i], (*scan)[i]); PrintMatrix("W", W, num_elements, num_elements); PrintMatrix("Links", adj, num_elements, num_elements); diff --git a/tests/cpp/kvstore/gpu_topology_test.cc b/tests/cpp/kvstore/gpu_topology_test.cc index 054cfded92bb..8d8d99d2eaae 100644 --- a/tests/cpp/kvstore/gpu_topology_test.cc +++ b/tests/cpp/kvstore/gpu_topology_test.cc @@ -77,8 +77,6 @@ void TestComputeTreesRandomized(int num_gpus, float alpha, int backtrack, GenerateMatrix(&W, num_gpus, k, gen); satisfied = IsSatisfactory(W, num_gpus, depth); } - //LOG(WARNING) << "num_gpus: " << num_gpus; - //mxnet::kvstore::PrintMatrix("Link topo", W, num_gpus, num_gpus); std::vector> topo; std::vector> scan; @@ -86,9 +84,8 @@ void TestComputeTreesRandomized(int num_gpus, float alpha, int 
backtrack, unsigned correct_topo_size = (1 << (depth + 1)) - 1; unsigned correct_scan_size = depth+2; - //LOG(WARNING) << topo.size() << " " << num_gpus; ASSERT_EQ(topo.size(), static_cast(num_gpus)); - for (int i = 0; i < topo.size(); ++i) { + for (unsigned i = 0; i < topo.size(); ++i) { ASSERT_EQ(correct_topo_size, topo[i].size()); ASSERT_EQ(correct_scan_size, scan[i].size()); } @@ -567,7 +564,7 @@ TEST(GpuTopology, TestComputeTrees1) { bool backtrack = true; // Do 5 randomized tests per GPU count from 2 to 16 for (int num_gpus = 2; num_gpus <= 16; ++num_gpus) { - LOG(WARNING) << "Testing " << num_gpus << " x " << num_gpus; + LOG(INFO) << "Testing " << num_gpus << " x " << num_gpus; for (int i = 0; i < 5; ++i) { TestComputeTreesRandomized(num_gpus, alpha, backtrack, &gen); } @@ -581,7 +578,7 @@ TEST(GpuTopology, TestComputeTrees2) { bool backtrack = false; // Do 5 randomized tests per GPU count from 2 to 16 for (int num_gpus = 2; num_gpus <= 16; ++num_gpus) { - LOG(WARNING) << "Testing " << num_gpus << " x " << num_gpus; + LOG(INFO) << "Testing " << num_gpus << " x " << num_gpus; for (int i = 0; i < 5; ++i) { TestComputeTreesRandomized(num_gpus, alpha, backtrack, &gen); } diff --git a/tests/python/gpu/test_device.py b/tests/python/gpu/test_device.py new file mode 100644 index 000000000000..2cf10c5bcae9 --- /dev/null +++ b/tests/python/gpu/test_device.py @@ -0,0 +1,82 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import mxnet as mx +import numpy as np +import unittest +import os + +shapes = [(10), (100), (1000), (10000), (100000), (2,2), (2,3,4,5,6,7,8)] +keys = [1,2,3,4,5,6,7] +num_gpus = len(mx.test_utils.list_gpus()) + + +if num_gpus > 8 : + print("The machine has {} gpus. We will run the test on 8 gpus.".format(num_gpus)) + print("There is a limit for all PCI-E hardware on creating number of P2P peers. The limit is 8.") + num_gpus = 8; + +gpus = range(1, 1+num_gpus) + +class EnvManager: + def __init__(self, key, val): + self._key = key + self._next_val = val + self._prev_val = None + + def __enter__(self): + try: + self._prev_val = os.environ[self._key] + except KeyError: + self._prev_val = '' + os.environ[self._key] = self._next_val + + def __exit__(self, ptype, value, trace): + os.environ[self._key] = self._prev_val + +def test_device_pushpull(): + def check_dense_pushpull(kv_type): + for shape, key in zip(shapes, keys): + for n_gpus in gpus: + kv_device = mx.kv.create(kv_type) + a = mx.nd.ones(shape, mx.gpu(0)) + cur_key = str(key*max(gpus)+n_gpus) + kv_device.init(cur_key, a) + arr_list = [mx.nd.ones(shape, mx.gpu(x)) for x in range(n_gpus)] + res = [mx.nd.zeros(shape, mx.gpu(x)) for x in range(n_gpus)] + kv_device.push(cur_key, arr_list) + kv_device.pull(cur_key, res) + for x in range(n_gpus): + assert(np.sum(np.abs((res[x]-n_gpus).asnumpy()))==0) + + envs1 = '1' + key1 = 'MXNET_KVSTORE_GPUARRAY_BOUND' + envs2 = ['','1'] + key2 = 'MXNET_KVSTORE_USETREE' + for i in range(2): + for val2 in envs2: + with EnvManager(key2, val2): + 
check_dense_pushpull('local') + check_dense_pushpull('device') + + os.environ[key1] = envs1 + os.environ[key1] = '' + + print ("Passed") + +if __name__ == '__main__': + test_device_pushpull() diff --git a/tests/python/gpu/test_nccl.py b/tests/python/gpu/test_nccl.py index ee822784966f..40ef6fdfd0af 100644 --- a/tests/python/gpu/test_nccl.py +++ b/tests/python/gpu/test_nccl.py @@ -32,50 +32,6 @@ gpus = range(1, 1+num_gpus) -class EnvManager: - def __init__(self, key, val): - self._key = key - self._next_val = val - self._prev_val = None - - def __enter__(self): - try: - self._prev_val = os.environ[self._key] - except KeyError: - self._prev_val = '' - os.environ[self._key] = self._next_val - - def __exit__(self, ptype, value, trace): - os.environ[self._key] = self._prev_val - -def test_device_pushpull(): - def check_dense_pushpull(kv_type): - for shape, key in zip(shapes, keys): - for n_gpus in gpus: - kv_device = mx.kv.create(kv_type) - a = mx.nd.ones(shape, mx.gpu(0)) - cur_key = str(key*max(gpus)+n_gpus) - kv_device.init(cur_key, a) - arr_list = [mx.nd.ones(shape, mx.gpu(x)) for x in range(n_gpus)] - res = [mx.nd.zeros(shape, mx.gpu(x)) for x in range(n_gpus)] - kv_device.push(cur_key, arr_list) - kv_device.pull(cur_key, res) - for x in range(n_gpus): - assert(np.sum(np.abs((res[x]-n_gpus).asnumpy()))==0) - - envs1 = '1' - key1 = 'MXNET_KVSTORE_GPUARRAY_BOUND' - envs2 = ['','1'] - key2 = 'MXNET_KVSTORE_USETREE' - for i in range(2): - for val2 in envs2: - with EnvManager(key2, val2): - check_dense_pushpull('local') - check_dense_pushpull('device') - - os.environ[key1] = envs1 - os.environ[key1] = '' - @unittest.skip("Test requires NCCL library installed and enabled during build") def test_nccl_pushpull(): for shape, key in zip(shapes, keys): @@ -94,5 +50,4 @@ def test_nccl_pushpull(): print ("Passed") if __name__ == '__main__': - test_device_pushpull() test_nccl_pushpull() From 67b0db02044627832ba105cab2350fc95a058830 Mon Sep 17 00:00:00 2001 From: Carl Yang Date: 
Tue, 24 Jul 2018 00:52:32 +0000 Subject: [PATCH 33/36] Fix code alignment --- src/kvstore/comm_tree.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/kvstore/comm_tree.h b/src/kvstore/comm_tree.h index 5d2595ef2334..fb1b71b5e7fa 100644 --- a/src/kvstore/comm_tree.h +++ b/src/kvstore/comm_tree.h @@ -136,8 +136,8 @@ class CommDeviceTree : public CommDevice { } else { if (dest_id != topo_id) { CopyFromTo(buf_from.merged[merged_row], - &(buf_dest.copy_buf[merged_row][is_dest-1]), - priority); + &(buf_dest.copy_buf[merged_row][is_dest-1]), + priority); reduce[dest_id].push_back( buf_dest.copy_buf[merged_row][is_dest-1]); } From c8ebb87a2df7cbc37d66684a544ef476232221b2 Mon Sep 17 00:00:00 2001 From: Carl Yang Date: Tue, 24 Jul 2018 00:54:29 +0000 Subject: [PATCH 34/36] get rid of todo --- tests/cpp/kvstore/gpu_topology_test.cc | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/cpp/kvstore/gpu_topology_test.cc b/tests/cpp/kvstore/gpu_topology_test.cc index 8d8d99d2eaae..0f6d5f11cca1 100644 --- a/tests/cpp/kvstore/gpu_topology_test.cc +++ b/tests/cpp/kvstore/gpu_topology_test.cc @@ -557,7 +557,6 @@ TEST(GpuTopology, TestIsConnected3) { } // ComputeTreesTest with backtracking -// TODO(carlyang): comment out test for now TEST(GpuTopology, TestComputeTrees1) { std::mt19937 gen(1); float alpha = 0.7; From 5f7da5ed4e03f0ff2d927436014b3abf2be93ba1 Mon Sep 17 00:00:00 2001 From: Carl Yang Date: Tue, 24 Jul 2018 11:56:40 -0700 Subject: [PATCH 35/36] Make changes to env variable names to indicate they are TREE-related --- docs/faq/env_var.md | 8 ++++---- src/kvstore/comm_tree.h | 6 +++--- tests/python/gpu/test_device.py | 2 +- 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/docs/faq/env_var.md b/docs/faq/env_var.md index fb44bf6ba927..9a7ec4a8695c 100644 --- a/docs/faq/env_var.md +++ b/docs/faq/env_var.md @@ -94,17 +94,17 @@ export MXNET_GPU_WORKER_NTHREADS=3 - Values: 0(false) or 1(true) ```(default=0)``` - If true and 
MXNET_KVSTORE_USETREE is set to 1, MXNet will log the reduction trees that have been generated. -* MXNET_KVSTORE_GPUARRAY_BOUND +* MXNET_KVSTORE_TREE_ARRAY_BOUND - Values: Int ```(default=10000000)``` - The minimum size of a "big array". - When the array size is bigger than this threshold and MXNET_KVSTORE_USETREE is set to 1, multiple trees are used to load balance the big gradient being communicated in order to better saturate link bandwidth. -* MXNET_KVSTORE_BACKTRACK - - Values: 0(false) or 1(true) ```(Default=0) +* MXNET_KVSTORE_TREE_BACKTRACK + - Values: 0(false) or 1(true) ```(default=0) - If true and MXNET_KVSTORE_USETREE is set to 1, MXNet tries to use backtracking to generate the trees required for tree reduction. - If false and MXNET_KVSTORE_USETREE is set to 1, MXNet tries to use Kernighan-Lin heuristic to generate the trees required for tree reduction. -* MXNET_KVSTORE_LINK_USAGE_PENALTY +* MXNET_KVSTORE_TREE_LINK_USAGE_PENALTY - Values: Float ```(default=0.7)``` - The multiplicative penalty term to a link being used once. 
diff --git a/src/kvstore/comm_tree.h b/src/kvstore/comm_tree.h index fb1b71b5e7fa..1ebfcdc8010d 100644 --- a/src/kvstore/comm_tree.h +++ b/src/kvstore/comm_tree.h @@ -51,9 +51,9 @@ class CommDeviceTree : public CommDevice { public: CommDeviceTree() { inited_ = false; - gpuarray_bound_ = dmlc::GetEnv("MXNET_KVSTORE_GPUARRAY_BOUND", 10000000); - backtrack_ = dmlc::GetEnv("MXNET_KVSTORE_BACKTRACK", 0); - link_usage_penalty_ = dmlc::GetEnv("MXNET_KVSTORE_LINK_USAGE_PENALTY", 0.7); + gpuarray_bound_ = dmlc::GetEnv("MXNET_KVSTORE_TREE_ARRAY_BOUND", 10000000); + backtrack_ = dmlc::GetEnv("MXNET_KVSTORE_TREE_BACKTRACK", 0); + link_usage_penalty_ = dmlc::GetEnv("MXNET_KVSTORE_TREE_LINK_USAGE_PENALTY", 0.7); } virtual ~CommDeviceTree() { } diff --git a/tests/python/gpu/test_device.py b/tests/python/gpu/test_device.py index 2cf10c5bcae9..66772dc86c21 100644 --- a/tests/python/gpu/test_device.py +++ b/tests/python/gpu/test_device.py @@ -64,7 +64,7 @@ def check_dense_pushpull(kv_type): assert(np.sum(np.abs((res[x]-n_gpus).asnumpy()))==0) envs1 = '1' - key1 = 'MXNET_KVSTORE_GPUARRAY_BOUND' + key1 = 'MXNET_KVSTORE_TREE_ARRAY_BOUND' envs2 = ['','1'] key2 = 'MXNET_KVSTORE_USETREE' for i in range(2): From 16b8fb43fd3bb0942f1a767fdd77ffc4ae54ef37 Mon Sep 17 00:00:00 2001 From: Carl Yang Date: Tue, 24 Jul 2018 12:12:04 -0700 Subject: [PATCH 36/36] Add note saying when ARRAY_BOUND env var takes effect --- docs/faq/env_var.md | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/faq/env_var.md b/docs/faq/env_var.md index 9a7ec4a8695c..881bc14fdc89 100644 --- a/docs/faq/env_var.md +++ b/docs/faq/env_var.md @@ -98,6 +98,7 @@ export MXNET_GPU_WORKER_NTHREADS=3 - Values: Int ```(default=10000000)``` - The minimum size of a "big array". - When the array size is bigger than this threshold and MXNET_KVSTORE_USETREE is set to 1, multiple trees are used to load balance the big gradient being communicated in order to better saturate link bandwidth. 
+ - Note: This environment variable only takes effect if Tree KVStore is being used (MXNET_KVSTORE_USETREE=1). * MXNET_KVSTORE_TREE_BACKTRACK - Values: 0(false) or 1(true) ```(default=0)