Skip to content

Commit

Permalink
gpu mem pool strategy (apache#11041)
Browse files Browse the repository at this point in the history
* use nearest power of 2 for gpu memory pool sizes

* add linear

* add test
  • Loading branch information
szha authored and zheng-da committed Jun 28, 2018
1 parent d56f2f3 commit d281019
Show file tree
Hide file tree
Showing 27 changed files with 259 additions and 37 deletions.
181 changes: 175 additions & 6 deletions src/storage/pooled_storage_manager.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,11 @@
#if MXNET_USE_CUDA
#include <cuda_runtime.h>
#endif // MXNET_USE_CUDA

#include <mxnet/base.h>
#include <mxnet/storage.h>
#include <unordered_map>
#include <algorithm>
#include <vector>
#include <mutex>
#include <new>
Expand All @@ -43,7 +45,8 @@ namespace storage {

#if MXNET_USE_CUDA
/*!
* \brief Storage manager with a memory pool on gpu.
* \brief Storage manager with a memory pool on gpu. Memory chunks are reused based on exact size
* match.
*/
class GPUPooledStorageManager final : public StorageManager {
public:
Expand All @@ -52,6 +55,11 @@ class GPUPooledStorageManager final : public StorageManager {
*/
GPUPooledStorageManager() {
reserve_ = dmlc::GetEnv("MXNET_GPU_MEM_POOL_RESERVE", 5);
page_size_ = dmlc::GetEnv("MXNET_GPU_MEM_POOL_PAGE_SIZE", 4096);
if (page_size_ < NDEV) {
LOG(FATAL) << "MXNET_GPU_MEM_POOL_PAGE_SIZE cannot be set to a value smaller than " << NDEV \
<< ". Got " << page_size_ << ".";
}
}
/*!
* \brief Default destructor.
Expand All @@ -71,7 +79,7 @@ class GPUPooledStorageManager final : public StorageManager {
private:
void DirectFreeNoLock(Storage::Handle handle) {
cudaError_t err = cudaFree(handle.dptr);
size_t size = handle.size + NDEV;
size_t size = std::max(handle.size, page_size_);
// ignore unloading error, as memory has already been recycled
if (err != cudaSuccess && err != cudaErrorCudartUnloading) {
LOG(FATAL) << "CUDA: " << cudaGetErrorString(err);
Expand All @@ -83,18 +91,20 @@ class GPUPooledStorageManager final : public StorageManager {
void ReleaseAll();
// used memory
size_t used_memory_ = 0;
// page size
size_t page_size_;
// percentage of reserved memory
int reserve_;
// number of devices
const int NDEV = 32;
const size_t NDEV = 32;
// memory pool
std::unordered_map<size_t, std::vector<void*>> memory_pool_;
DISALLOW_COPY_AND_ASSIGN(GPUPooledStorageManager);
}; // class GPUPooledStorageManager

void GPUPooledStorageManager::Alloc(Storage::Handle* handle) {
std::lock_guard<std::mutex> lock(Storage::Get()->GetMutex(Context::kGPU));
size_t size = handle->size + NDEV;
size_t size = std::max(handle->size, page_size_);
auto&& reuse_it = memory_pool_.find(size);
if (reuse_it == memory_pool_.end() || reuse_it->second.size() == 0) {
size_t free, total;
Expand All @@ -119,7 +129,7 @@ void GPUPooledStorageManager::Alloc(Storage::Handle* handle) {

/*!
 * \brief Return a chunk to the pool instead of releasing it to CUDA.
 *
 * The chunk is filed under max(handle.size, page_size_), the same key that Alloc uses for
 * lookup, so that a later request of the same rounded size can reuse it.
 *
 * \param handle handle whose dptr is placed back into the pool.
 */
void GPUPooledStorageManager::Free(Storage::Handle handle) {
  std::lock_guard<std::mutex> lock(Storage::Get()->GetMutex(Context::kGPU));
  // Only one definition of the key may exist here: the former `handle.size + NDEV` key
  // would both redeclare `size` and disagree with the key used on the allocation path.
  const size_t size = std::max(handle.size, page_size_);
  memory_pool_[size].push_back(handle.dptr);
}
Expand All @@ -129,13 +139,172 @@ void GPUPooledStorageManager::ReleaseAll() {
for (auto&& j : i.second) {
Storage::Handle handle;
handle.dptr = j;
handle.size = i.first - NDEV;
handle.size = i.first;
DirectFreeNoLock(handle);
}
}
memory_pool_.clear();
}

/*!
* \brief Storage manager with a memory pool, with rounded size, on gpu.
*
* This GPU mem pool uses a mixture of nearest pow2 (exponential) rounding and
* nearest multiple (linear) rounding to help alleviate the memory allocation stress
* in which the default naive exact-size-match pool falls short, such as in variable-length
* input/output cases like RNN workloads.
*
* \param cutoff the cutoff at which rounding is switched from exponential to linear. It's set
* through MXNET_GPU_MEM_POOL_ROUND_LINEAR_CUTOFF environment variable. Must be between 20 (1 MB)
* and 34 (16 GB).
* Suppose the cutoff is X, the memory size buckets look like this:
* exp2(0), exp2(1), ..., exp2(X), 2*exp2(X), 3*exp2(X), ...
*/
class GPUPooledRoundedStorageManager final : public StorageManager {
 public:
  /*!
   * \brief Default constructor.
   *
   * Reads MXNET_GPU_MEM_POOL_RESERVE, MXNET_GPU_MEM_POOL_PAGE_SIZE and
   * MXNET_GPU_MEM_POOL_ROUND_LINEAR_CUTOFF from the environment, rejects invalid
   * combinations with LOG(FATAL), and sizes the bucket table accordingly.
   */
  GPUPooledRoundedStorageManager() {
    reserve_ = dmlc::GetEnv("MXNET_GPU_MEM_POOL_RESERVE", 5);
    page_size_ = dmlc::GetEnv("MXNET_GPU_MEM_POOL_PAGE_SIZE", 4096);
    cut_off_ = dmlc::GetEnv("MXNET_GPU_MEM_POOL_ROUND_LINEAR_CUTOFF", 24);
    if (page_size_ < 32) {
      LOG(FATAL) << "MXNET_GPU_MEM_POOL_PAGE_SIZE cannot be set to a value smaller than 32. "
                 << "Got: " << page_size_ << ".";
    }
    // A power of two has exactly one bit set. This avoids shifting by a count derived from
    // an unvalidated value.
    if ((page_size_ & (page_size_ - 1)) != 0) {
      LOG(FATAL) << "MXNET_GPU_MEM_POOL_PAGE_SIZE must be a power of 2. Got: " << page_size_ << ".";
    }
    // From here on page_size_ holds the log2 of the page size, not the size in bytes.
    page_size_ = log2_round_up(page_size_);
    if (cut_off_ < 20 || cut_off_ > LOG2_MAX_MEM) {
      LOG(FATAL) << "MXNET_GPU_MEM_POOL_ROUND_LINEAR_CUTOFF cannot be set to a value "
                 << "smaller than 20 or greater than " << LOG2_MAX_MEM << ". Got: "
                 << cut_off_ << ".";
    }
    if (cut_off_ < page_size_) {
      LOG(FATAL) << "MXNET_GPU_MEM_POOL_ROUND_LINEAR_CUTOFF cannot be set to a value "
                 << "smaller than log2 of MXNET_GPU_MEM_POOL_PAGE_SIZE. Got: "
                 << cut_off_ << " vs " << page_size_ << ".";
    }
    // Buckets 0..cut_off_ are the exponential ones; the remaining entries cover the linear
    // multiples of 2^cut_off_ up to 2^LOG2_MAX_MEM.
    const size_t num_buckets =
        (static_cast<size_t>(1) << (LOG2_MAX_MEM - cut_off_)) + cut_off_;
    memory_pool_ = std::vector<std::vector<void*>>(num_buckets);
  }
  /*!
   * \brief Default destructor. Releases every chunk still held by the pool.
   */
  ~GPUPooledRoundedStorageManager() {
    ReleaseAll();
  }

  void Alloc(Storage::Handle* handle) override;
  void Free(Storage::Handle handle) override;

  /*!
   * \brief Release the chunk to CUDA immediately, bypassing the pool.
   * \param handle handle whose dptr is freed.
   */
  void DirectFree(Storage::Handle handle) override {
    std::lock_guard<std::mutex> lock(Storage::Get()->GetMutex(Context::kGPU));
    DirectFreeNoLock(handle);
  }

 private:
  /*!
   * \brief Smallest k such that 2^k >= s, computed with integer arithmetic only.
   *
   * Returns 0 for s == 0 and s == 1. Integer arithmetic keeps the result exact for every
   * size_t value and well defined for s == 0, where a floating-point log2 yields -inf.
   */
  inline int log2_round_up(size_t s) {
    int exponent = 0;
    for (size_t rest = (s > 0 ? s - 1 : 0); rest > 0; rest >>= 1) {
      ++exponent;
    }
    return exponent;
  }
  /*!
   * \brief ceil(s / 2^divisor_log2).
   */
  inline int div_pow2_round_up(size_t s, int divisor_log2) {
    // (1025, 10) -> 2
    // (2048, 10) -> 2
    // (2049, 10) -> 3
    const size_t quotient = s >> divisor_log2;
    const size_t round_up = (s > (quotient << divisor_log2)) ? 1 : 0;
    return static_cast<int>(quotient + round_up);
  }
  /*!
   * \brief Map a requested size in bytes to the index of its bucket in memory_pool_.
   *
   * Sizes whose rounded-up log2 is at most cut_off_ use that log2 (but never less than the
   * page-size log2) as the index. Larger sizes are rounded up to a multiple of 2^cut_off_
   * and indexed linearly after the exponential buckets.
   */
  inline int get_bucket(size_t s) {
    const int log_size = log2_round_up(s);
    const int cut_off = static_cast<int>(cut_off_);
    if (log_size > cut_off) {
      return div_pow2_round_up(s, cut_off) - 1 + cut_off;
    }
    return std::max(log_size, static_cast<int>(page_size_));
  }
  /*!
   * \brief Size in bytes of the chunks stored in the given bucket (inverse of get_bucket).
   */
  inline size_t get_size(int bucket) {
    // size_t-wide shifts: `unsigned long` is only 32 bits wide on some platforms.
    const size_t one = 1;
    if (bucket <= static_cast<int>(cut_off_)) {
      return one << bucket;
    }
    return (bucket - cut_off_ + 1) * (one << cut_off_);
  }

  /*!
   * \brief cudaFree the chunk and subtract its rounded size from used_memory_.
   *
   * The caller must already hold the GPU storage mutex.
   */
  void DirectFreeNoLock(Storage::Handle handle) {
    cudaError_t err = cudaFree(handle.dptr);
    size_t size = get_size(get_bucket(handle.size));
    // ignore unloading error, as memory has already been recycled
    if (err != cudaSuccess && err != cudaErrorCudartUnloading) {
      LOG(FATAL) << "CUDA: " << cudaGetErrorString(err);
    }
    used_memory_ -= size;
  }

  void ReleaseAll();
  // number of devices (not referenced by this class)
  const int NDEV = 32;
  // log2 of the largest chunk size the bucket table covers (2^34 bytes = 16GB)
  const size_t LOG2_MAX_MEM = 34;
  // address width in bits (not referenced by this class)
  static const int addr_width = sizeof(size_t) * 8;
  // used memory
  size_t used_memory_ = 0;
  // page size; after construction this holds log2 of the page size
  size_t page_size_;
  // log2 of the memory size at which rounding switches from exponential to linear mode
  size_t cut_off_;
  // percentage of reserved memory
  int reserve_;
  // memory pool, indexed by bucket
  std::vector<std::vector<void*>> memory_pool_;
  DISALLOW_COPY_AND_ASSIGN(GPUPooledRoundedStorageManager);
};  // class GPUPooledRoundedStorageManager

/*!
 * \brief Hand out a chunk of the bucket size that fits handle->size.
 *
 * A pooled chunk from the matching bucket is reused when one is available. Otherwise a new
 * chunk of the bucket's rounded size is requested from CUDA, after releasing the whole pool
 * if the request would eat into the reserved share of device memory.
 *
 * \param handle in: size; out: dptr is set to the allocated chunk.
 */
void GPUPooledRoundedStorageManager::Alloc(Storage::Handle* handle) {
  std::lock_guard<std::mutex> lock(Storage::Get()->GetMutex(Context::kGPU));
  const int bucket = get_bucket(handle->size);
  // The bucket table has a fixed length; indexing past it would be undefined behaviour.
  // A negative bucket becomes a huge unsigned value and is rejected by the same test.
  if (static_cast<size_t>(bucket) >= memory_pool_.size()) {
    LOG(FATAL) << "Requested GPU allocation of " << handle->size
               << " bytes does not fit into any bucket of the rounded memory pool.";
  }
  const size_t size = get_size(bucket);
  auto&& reuse_pool = memory_pool_[bucket];
  if (!reuse_pool.empty()) {
    handle->dptr = reuse_pool.back();
    reuse_pool.pop_back();
    return;
  }

  // Zero-initialised so that a failing cudaMemGetInfo leads to ReleaseAll() below instead
  // of a comparison on indeterminate values.
  size_t free_mem = 0, total_mem = 0;
  cudaMemGetInfo(&free_mem, &total_mem);
  const size_t reserved = total_mem * reserve_ / 100;
  if (free_mem <= reserved || size > free_mem - reserved)
    ReleaseAll();

  void* ret = nullptr;
  cudaError_t e = cudaMalloc(&ret, size);
  if (e != cudaSuccess && e != cudaErrorCudartUnloading) {
    LOG(FATAL) << "cudaMalloc failed: " << cudaGetErrorString(e);
  }
  used_memory_ += size;
  handle->dptr = ret;
}

/*!
 * \brief Put the chunk back into the bucket that matches handle.size so it can be reused.
 * \param handle handle whose dptr is returned to the pool.
 */
void GPUPooledRoundedStorageManager::Free(Storage::Handle handle) {
  std::lock_guard<std::mutex> guard(Storage::Get()->GetMutex(Context::kGPU));
  memory_pool_[get_bucket(handle.size)].push_back(handle.dptr);
}

/*!
 * \brief Free every pooled chunk back to CUDA and empty all buckets.
 *
 * Does not take the storage mutex itself; it frees through DirectFreeNoLock.
 */
void GPUPooledRoundedStorageManager::ReleaseAll() {
  for (size_t i = 0; i < memory_pool_.size(); i++) {
    // Keep the full size_t width: bucket sizes of 2 GiB and above do not fit in an int,
    // and a truncated value would corrupt handle.size and the used_memory_ accounting.
    const size_t size = get_size(static_cast<int>(i));
    for (auto& j : memory_pool_[i]) {
      Storage::Handle handle;
      handle.size = size;
      handle.dptr = j;
      DirectFreeNoLock(handle);
    }
    memory_pool_[i].clear();
  }
}

#endif // MXNET_USE_CUDA

} // namespace storage
Expand Down
16 changes: 15 additions & 1 deletion src/storage/storage.cc
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,21 @@ void StorageImpl::Alloc(Storage::Handle* handle) {
#if MXNET_USE_CUDA
CUDA_CALL(cudaGetDeviceCount(&num_gpu_device));
CHECK_GT(num_gpu_device, 0) << "GPU usage requires at least 1 GPU";
ptr = new storage::GPUPooledStorageManager();

const char *type = getenv("MXNET_GPU_MEM_POOL_TYPE");
const bool default_pool = (type == nullptr);
if (default_pool) type = "Naive";
std::string strategy = type;

if (strategy == "Round") {
ptr = new storage::GPUPooledRoundedStorageManager();
LOG(INFO) << "Using GPUPooledRoundedStorageManager.";
} else {
if (strategy != "Naive") {
LOG(FATAL) << "Unknown memory pool strategy specified: " << strategy << ".";
}
ptr = new storage::GPUPooledStorageManager();
}
#else
LOG(FATAL) << "Compile with USE_CUDA=1 to enable GPU usage";
#endif // MXNET_USE_CUDA
Expand Down
36 changes: 33 additions & 3 deletions tests/cpp/storage/storage_test.cc
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
/* * Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
Expand All @@ -22,6 +21,7 @@
* \file storage_test.cc
* \brief cpu/gpu storage tests
*/
#include <stdlib.h>
#include <gtest/gtest.h>
#include <dmlc/logging.h>
#include <mxnet/storage.h>
Expand All @@ -43,7 +43,37 @@ TEST(Storage, Basic_CPU) {
}

#if MXNET_USE_CUDA
TEST(Storage, Basic_GPU) {
TEST(Storage_GPU, Basic_GPU) {
if (mxnet::test::unitTestsWithCuda) {
putenv("MXNET_GPU_MEM_POOL_ROUND_LINEAR_CUTOFF=20");
putenv("MXNET_GPU_MEM_POOL_TYPE=Round");
auto &&storage = mxnet::Storage::Get();
mxnet::Context context_gpu = mxnet::Context::GPU(0);
auto &&handle = storage->Alloc(32, context_gpu);
auto &&handle2 = storage->Alloc(2097153, context_gpu);
EXPECT_EQ(handle.ctx, context_gpu);
EXPECT_EQ(handle.size, 32);
EXPECT_EQ(handle2.ctx, context_gpu);
EXPECT_EQ(handle2.size, 2097153);
auto ptr = handle.dptr;
auto ptr2 = handle2.dptr;
storage->Free(handle);
storage->Free(handle2);

handle = storage->Alloc(4095, context_gpu);
EXPECT_EQ(handle.ctx, context_gpu);
EXPECT_EQ(handle.size, 4095);
EXPECT_EQ(handle.dptr, ptr);
storage->Free(handle);

handle2 = storage->Alloc(3145728, context_gpu);
EXPECT_EQ(handle2.ctx, context_gpu);
EXPECT_EQ(handle2.size, 3145728);
EXPECT_EQ(handle2.dptr, ptr2);
storage->Free(handle2);
unsetenv("MXNET_GPU_MEM_POOL_ROUND_LINEAR_CUTOFF");
unsetenv("MXNET_GPU_MEM_POOL_TYPE");
}
if (mxnet::test::unitTestsWithCuda) {
constexpr size_t kSize = 1024;
mxnet::Context context_gpu = mxnet::Context::GPU(0);
Expand Down
2 changes: 1 addition & 1 deletion tests/python/gpu/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from mxnet.test_utils import *
curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
sys.path.insert(0, os.path.join(curr_path, '../unittest'))
from common import setup_module, with_seed
from common import setup_module, with_seed, teardown
from mxnet.gluon import utils

def _get_model():
Expand Down
2 changes: 1 addition & 1 deletion tests/python/gpu/test_gluon_model_zoo_gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
import unittest
curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
sys.path.insert(0, os.path.join(curr_path, '../unittest'))
from common import setup_module, with_seed
from common import setup_module, with_seed, teardown

def eprint(*args, **kwargs):
print(*args, file=sys.stderr, **kwargs)
Expand Down
4 changes: 2 additions & 2 deletions tests/python/gpu/test_kvstore_gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from mxnet.test_utils import assert_almost_equal, default_context
curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
sys.path.insert(0, os.path.join(curr_path, '../unittest'))
from common import setup_module, with_seed
from common import setup_module, with_seed, teardown

shape = (4, 4)
keys = [5, 7, 11]
Expand Down Expand Up @@ -83,7 +83,7 @@ def check_rsp_pull(kv, count, ctxs, is_same_rowid=False, use_slice=False):
check_rsp_pull(kv, 4, [mx.gpu(i//2) for i in range(4)], is_same_rowid=True)
check_rsp_pull(kv, 4, [mx.cpu(i) for i in range(4)])
check_rsp_pull(kv, 4, [mx.cpu(i) for i in range(4)], is_same_rowid=True)
check_rsp_pull(kv, 4, [mx.gpu(i//2) for i in range(4)], use_slice=True)
check_rsp_pull(kv, 4, [mx.gpu(i//2) for i in range(4)], use_slice=True)
check_rsp_pull(kv, 4, [mx.cpu(i) for i in range(4)], use_slice=True)

# test fails intermittently. temporarily disabled till it gets fixed. tracked at https://github.com/apache/incubator-mxnet/issues/9384
Expand Down
2 changes: 1 addition & 1 deletion tests/python/gpu/test_operator_gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@

curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
sys.path.insert(0, os.path.join(curr_path, '../unittest'))
from common import setup_module, with_seed
from common import setup_module, with_seed, teardown
from test_operator import *
from test_optimizer import *
from test_random import *
Expand Down
8 changes: 8 additions & 0 deletions tests/python/unittest/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,3 +241,11 @@ def __enter__(self):

def __exit__(self, exc_type, exc_value, traceback):
shutil.rmtree(self._dirname)

def teardown():
    """
    Module-level hook that nosetests discovers by its 'magic name' and runs automatically
    once a test module has finished.

    It blocks on ``mx.nd.waitall()`` so that every operation issued by one test file has
    completed before the next file starts.
    """
    mx.nd.waitall()
2 changes: 1 addition & 1 deletion tests/python/unittest/test_autograd.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from mxnet.ndarray import zeros_like
from mxnet.autograd import *
from mxnet.test_utils import *
from common import setup_module, with_seed
from common import setup_module, with_seed, teardown


def grad_and_loss(func, argnum=None):
Expand Down
Loading

0 comments on commit d281019

Please sign in to comment.